From d67f8cd6b0a00110dfa1ca4de8548cbe41405c00 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:42 -0400 Subject: [PATCH 001/304] ql: bootstrap v2 protocol and session groundwork --- Cargo.lock | 104 +++++ Cargo.toml | 3 +- api/src/api/quantum_link.rs | 50 +++ ql-protocol/Cargo.toml | 19 + ql-protocol/src/executor.rs | 718 +++++++++++++++++++++++++++++++ ql-protocol/src/lib.rs | 19 + ql-protocol/src/test_identity.rs | 28 ++ ql-protocol/src/typed/handle.rs | 202 +++++++++ ql-protocol/src/typed/mod.rs | 123 ++++++ ql-protocol/src/typed/router.rs | 297 +++++++++++++ ql-protocol/src/typed/test.rs | 287 ++++++++++++ ql-protocol/src/wire.rs | 434 +++++++++++++++++++ 12 files changed, 2283 insertions(+), 1 deletion(-) create mode 100644 ql-protocol/Cargo.toml create mode 100644 ql-protocol/src/executor.rs create mode 100644 ql-protocol/src/lib.rs create mode 100644 ql-protocol/src/test_identity.rs create mode 100644 ql-protocol/src/typed/handle.rs create mode 100644 ql-protocol/src/typed/mod.rs create mode 100644 ql-protocol/src/typed/router.rs create mode 100644 ql-protocol/src/typed/test.rs create mode 100644 ql-protocol/src/wire.rs diff --git a/Cargo.lock b/Cargo.lock index f144305f..c953474e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -109,6 +109,18 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "async-channel" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2" +dependencies = [ + "concurrent-queue", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + [[package]] name = "atomic" version = "0.5.3" @@ -493,6 +505,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console" version = "0.15.11" @@ -565,6 +586,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crunchy" version = "0.2.4" @@ -809,6 +836,33 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "ff" version = "0.13.1" @@ -941,6 +995,19 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-lite" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "parking", + "pin-project-lite", +] + [[package]] name = "futures-macro" version = "0.3.31" @@ -1540,6 +1607,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "oneshot" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -1595,6 +1668,12 @@ dependencies = [ "sha2", ] +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot_core" version = "0.9.11" @@ -1863,6 +1942,19 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "ql-protocol" +version = "0.1.0" +dependencies = [ + "async-channel", + "bc-components", + "dcbor", + "futures-lite", + "oneshot", + "thiserror", + "tokio", +] + [[package]] name = "quantum-link-macros" version = "0.1.0" @@ -2448,6 +2540,18 @@ dependencies = [ "mio", "pin-project-lite", "slab", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 0fd0e755..540aa7ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["api", "backup-shard", "btp", "quantum-link-macros"] +members = ["api", "backup-shard", "btp", "ql-protocol", "quantum-link-macros"] [workspace.package] homepage = "https://github.com/Foundation-Devices/foundation-api" @@ -24,6 +24,7 @@ backup-shard = { path = "backup-shard" } btp = { path = "btp" } foundation-api = { path = "api" } quantum-link-macros = { path = "quantum-link-macros" } +ql-protocol = { path = "ql-protocol" } [patch.crates-io] pqcrypto-traits = { git = "https://github.com/Foundation-Devices/pqcrypto", rev = "ebadf71214f67cb970242fa1053b4acb65767737" } diff --git a/api/src/api/quantum_link.rs b/api/src/api/quantum_link.rs index a9033ed3..2b67c10e 100644 --- a/api/src/api/quantum_link.rs +++ b/api/src/api/quantum_link.rs @@ -239,6 +239,8 @@ impl QuantumLinkIdentity { #[cfg(test)] mod tests { + use dcbor::CBOREncodable; + use crate::{ api::{ message::{QuantumLinkMessage, PROTOCOL_VERSION}, @@ -247,6 +249,7 @@ mod tests { fx::ExchangeRate, message::EnvoyMessage, quantum_link::{ARIDCache, QlError, QuantumLinkIdentity}, + status::Heartbeat, }; #[test] @@ -309,6 +312,53 @@ mod tests { assert_eq!(fx_rate.rate, fx_rate_decoded.rate); } + #[test] + fn test_sealed_message_size() { + let envoy = QuantumLinkIdentity::generate(); + let passport = QuantumLinkIdentity::generate(); + + let fx_rate = ExchangeRate { + currency_code: String::from("USD"), + rate: 0.85, + timestamp: 0, + }; + let message = EnvoyMessage { + message: QuantumLinkMessage::ExchangeRate(fx_rate), + timestamp: 123456, + }; + + let envelope = QuantumLink::seal( + message, + (envoy.private_keys.as_ref().unwrap(), &envoy.xid_document), + &passport.xid_document, + ); + let bytes = envelope.to_cbor_data(); + + println!("sealed message size: {} bytes", bytes.len()); + assert!(!bytes.is_empty()); + } + + #[test] + fn test_sealed_heartbeat_size() { + let envoy = QuantumLinkIdentity::generate(); + let passport = QuantumLinkIdentity::generate(); + + let message = EnvoyMessage { + message: QuantumLinkMessage::Heartbeat(Heartbeat {}), + timestamp: 123456, + }; + + let envelope = QuantumLink::seal( + message, + (envoy.private_keys.as_ref().unwrap(), &envoy.xid_document), + &passport.xid_document, + ); + let bytes = envelope.to_cbor_data(); + + println!("sealed heartbeat size: {} bytes", bytes.len()); + assert!(!bytes.is_empty()); + } + #[test] fn test_serialize_ql_identity() { let identity = QuantumLinkIdentity::generate(); diff --git a/ql-protocol/Cargo.toml b/ql-protocol/Cargo.toml new file mode 100644 index 00000000..e741fce4 --- /dev/null +++ b/ql-protocol/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "ql-protocol" +version = "0.1.0" +edition = "2021" +description = "Quantum Link protocol primitives." +license = "Proprietary" + +[dependencies] +async-channel = { version = "2.5" } +bc-components = { version = "0.28.0", default-features = false, features = [ + "pqcrypto", +] } +dcbor = { version = "0.23.3" } +futures-lite = { version = "2.5" } +oneshot = { version = "0.1.11" } +thiserror = "2" + +[dev-dependencies] +tokio = { version = "1", features = ["rt", "time", "macros"] } diff --git a/ql-protocol/src/executor.rs b/ql-protocol/src/executor.rs new file mode 100644 index 00000000..65b79ccc --- /dev/null +++ b/ql-protocol/src/executor.rs @@ -0,0 +1,718 @@ +use std::{ + cmp::{Ordering, Reverse}, + collections::{BinaryHeap, HashMap, VecDeque}, + future::Future, + pin::{pin, Pin}, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use async_channel::{Receiver, Sender, WeakSender}; +use bc_components::{EncryptedMessage, Signer, ARID, XID}; + +use super::wire::{ + decode_ql_message, encode_ql_message, DecodeErrContext, EncodeQlConfig, MessageKind, QlMessage, +}; + +pub type PlatformFuture<'a, T> = Pin + 'a>>; + +pub trait QlPlatform { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; +} + +#[derive(Debug)] +pub enum QlError { + Cancelled, + Protocol, + SendFailed, + Timeout, + Decode(super::wire::DecodeError), +} + +#[derive(Debug, Clone, Copy)] +pub struct RequestConfig { + pub timeout: Option, +} + +impl Default for RequestConfig { + fn default() -> Self { + Self { timeout: None } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ExecutorConfig { + pub default_timeout: Duration, +} + +#[derive(Debug)] +pub struct InboundRequest { + pub message: QlMessage, + pub respond_to: Responder, +} + +#[derive(Debug)] +pub struct InboundEvent { + pub message: QlMessage, +} + +#[derive(Debug)] +pub enum HandlerEvent { + Request(InboundRequest), + Event(InboundEvent), +} + +#[derive(Debug, Clone)] +pub struct Responder { + id: ARID, + recipient: XID, + tx: Sender, +} + +impl Responder { + pub fn id(&self) -> ARID { + self.id + } + + pub fn recipient(&self) -> XID { + self.recipient + } + + pub fn respond( + self, + payload: EncryptedMessage, + encode_config: EncodeQlConfig, + signer: &dyn Signer, + ) -> Result<(), QlError> { + let bytes = encode_ql_message( + MessageKind::Response, + self.id, + encode_config, + payload, + signer, + ); + self.tx + .send_blocking(ExecutorEvent::SendResponse { bytes }) + .map_err(|_| QlError::Cancelled) + } +} + +#[derive(Debug)] +pub struct HandlerStream { + rx: Receiver, +} + +impl HandlerStream { + pub async fn next(&mut self) -> Result { + self.rx.recv().await.map_err(|_| QlError::Cancelled) + } +} + +impl futures_lite::Stream for HandlerStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> std::task::Poll> { + let rx = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.rx) }; + match rx.poll_next(cx) { + Poll::Ready(Some(event)) => Poll::Ready(Some(Ok(event))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +#[derive(Debug)] +enum ExecutorEvent { + SendRequest { + id: ARID, + bytes: Vec, + respond_to: oneshot::Sender>, + config: RequestConfig, + }, + SendEvent { + bytes: Vec, + }, + SendResponse { + bytes: Vec, + }, + Incoming { + message: QlMessage, + }, + IncomingDecodeError { + context: DecodeErrContext, + }, +} + +#[derive(Debug, Clone)] +pub struct ExecutorHandle { + tx: Sender, +} + +pub struct ExecutorResponse { + rx: oneshot::Receiver>, +} + +impl std::future::Future for ExecutorResponse { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + pin!(&mut self.rx) + .poll(cx) + .map(|result| result.unwrap_or(Err(QlError::Cancelled))) + } +} + +impl ExecutorHandle { + pub fn request( + &self, + id: ARID, + payload: EncryptedMessage, + encode_config: EncodeQlConfig, + request_config: RequestConfig, + signer: &dyn Signer, + ) -> ExecutorResponse { + let bytes = encode_ql_message(MessageKind::Request, id, encode_config, payload, signer); + let (tx, rx) = oneshot::channel(); + self.tx + .send_blocking(ExecutorEvent::SendRequest { + id, + bytes, + respond_to: tx, + config: request_config, + }) + .unwrap(); + ExecutorResponse { rx } + } + + pub fn send_event( + &self, + id: ARID, + payload: EncryptedMessage, + encode_config: EncodeQlConfig, + signer: &dyn Signer, + ) { + let tx = self.tx.clone(); + let bytes = encode_ql_message(MessageKind::Event, id, encode_config, payload, signer); + tx.send_blocking(ExecutorEvent::SendEvent { bytes }) + .unwrap(); + } + + pub fn send_incoming(&self, bytes: Vec) -> Result<(), QlError> { + match decode_ql_message(&bytes) { + Ok(message) => self + .tx + .send_blocking(ExecutorEvent::Incoming { message }) + .map_err(|_| QlError::Cancelled), + Err(context) => { + let _ = self + .tx + .send_blocking(ExecutorEvent::IncomingDecodeError { context }); + Ok(()) + } + } + } +} + +pub struct Executor

{ + platform: P, + rx: Receiver, + tx: WeakSender, + config: ExecutorConfig, + incoming: Sender, +} + +struct ExecutorState<'a> { + pending: HashMap, + timeouts: BinaryHeap>, + outbound: VecDeque, + in_flight: Option>, +} + +struct OutboundBytes { + id: Option, + bytes: Vec, +} + +struct InFlightWrite<'a> { + id: Option, + future: PlatformFuture<'a, Result<(), QlError>>, +} + +struct PendingEntry { + tx: oneshot::Sender>, +} + +#[derive(Debug, Clone)] +struct TimeoutEntry { + deadline: Instant, + id: ARID, +} + +impl PartialEq for TimeoutEntry { + fn eq(&self, other: &Self) -> bool { + self.deadline == other.deadline + } +} + +impl Eq for TimeoutEntry {} + +impl PartialOrd for TimeoutEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for TimeoutEntry { + fn cmp(&self, other: &Self) -> Ordering { + self.deadline.cmp(&other.deadline) + } +} + +enum LoopStep { + Event(Result), + WriteDone { + id: Option, + result: Result<(), QlError>, + }, + Timeout, +} + +impl

Executor

+where + P: QlPlatform, +{ + pub fn new(platform: P, config: ExecutorConfig) -> (Self, ExecutorHandle, HandlerStream) { + let (tx, rx) = async_channel::unbounded(); + let (incoming_tx, incoming_rx) = async_channel::unbounded(); + ( + Self { + rx, + tx: tx.downgrade(), + platform, + config, + incoming: incoming_tx, + }, + ExecutorHandle { tx }, + HandlerStream { rx: incoming_rx }, + ) + } + + pub async fn run<'a>(&'a mut self) { + let mut state = ExecutorState { + pending: HashMap::new(), + timeouts: BinaryHeap::new(), + outbound: VecDeque::new(), + in_flight: None, + }; + + loop { + Self::process_timeouts(&mut state); + + if state.in_flight.is_none() { + if let Some(message) = state.outbound.pop_front() { + state.in_flight = Some(InFlightWrite { + id: message.id, + future: self.platform.write_message(message.bytes), + }); + } + } + + let step = { + let recv_future = self.rx.recv(); + futures_lite::pin!(recv_future); + + let mut sleep_future = + Self::next_timeout_sleep(&state).map(|duration| self.platform.sleep(duration)); + + futures_lite::future::poll_fn(|cx| { + if let Some(in_flight) = state.in_flight.as_mut() { + if let Poll::Ready(result) = in_flight.future.as_mut().poll(cx) { + return Poll::Ready(LoopStep::WriteDone { + id: in_flight.id, + result, + }); + } + } + + if let Some(sleep_future) = sleep_future.as_mut() { + if let Poll::Ready(_result) = sleep_future.as_mut().poll(cx) { + return Poll::Ready(LoopStep::Timeout); + } + } + + match recv_future.as_mut().poll(cx) { + Poll::Ready(event) => Poll::Ready(LoopStep::Event(event)), + Poll::Pending => Poll::Pending, + } + }) + .await + }; + + match step { + LoopStep::Event(Ok(event)) => match event { + ExecutorEvent::SendRequest { + id, + bytes, + respond_to, + config, + } => { + let effective_timeout = + config.timeout.unwrap_or(self.config.default_timeout); + if effective_timeout.is_zero() { + let _ = respond_to.send(Err(QlError::Timeout)); + continue; + } + let deadline = Instant::now() + effective_timeout; + state.pending.insert(id, PendingEntry { tx: respond_to }); + state.timeouts.push(Reverse(TimeoutEntry { deadline, id })); + state.outbound.push_back(OutboundBytes { + id: Some(id), + bytes, + }); + } + ExecutorEvent::SendEvent { bytes } => { + state.outbound.push_back(OutboundBytes { id: None, bytes }); + } + ExecutorEvent::SendResponse { bytes } => { + state.outbound.push_back(OutboundBytes { id: None, bytes }); + } + ExecutorEvent::Incoming { message } => match message.header.kind { + MessageKind::Response => { + if let Some(entry) = state.pending.remove(&message.header.id) { + let _ = entry.tx.send(Ok(message)); + } + } + MessageKind::Request => { + let Some(tx) = self.tx.upgrade() else { return }; + let responder = Responder { + id: message.header.id, + recipient: message.header.sender, + tx, + }; + let _ = self + .incoming + .send(HandlerEvent::Request(InboundRequest { + message, + respond_to: responder, + })) + .await; + } + MessageKind::Event => { + let _ = self + .incoming + .send(HandlerEvent::Event(InboundEvent { message })) + .await; + } + }, + ExecutorEvent::IncomingDecodeError { context } => { + let Some(header) = context.header else { + continue; + }; + if header.kind == MessageKind::Response { + if let Some(entry) = state.pending.remove(&header.id) { + let _ = entry.tx.send(Err(QlError::Decode(context.error))); + } + } + } + }, + LoopStep::Event(Err(_)) => break, + LoopStep::WriteDone { id, result } => { + state.in_flight = None; + if let Err(e) = result { + if let Some(id) = id { + if let Some(entry) = state.pending.remove(&id) { + let _ = entry.tx.send(Err(e)); + } + } + } + } + LoopStep::Timeout => { + Self::process_timeouts(&mut state); + } + } + } + } + + fn process_timeouts(state: &mut ExecutorState<'_>) { + let now = Instant::now(); + while let Some(Reverse(entry)) = state.timeouts.peek().cloned() { + if entry.deadline > now { + break; + } + state.timeouts.pop(); + if let Some(pending) = state.pending.remove(&entry.id) { + let _ = pending.tx.send(Err(QlError::Timeout)); + } + } + } + + fn next_timeout_sleep(state: &ExecutorState<'_>) -> Option { + let Reverse(entry) = state.timeouts.peek()?; + let now = Instant::now(); + Some(entry.deadline.saturating_duration_since(now)) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bc_components::{Nonce, SymmetricKey}; + use std::time::{SystemTime, UNIX_EPOCH}; + use crate::test_identity::TestIdentity; + + struct TestPlatform { + tx: Sender>, + } + + impl TestPlatform { + fn new() -> (Self, Receiver>) { + let (tx, rx) = async_channel::unbounded(); + (Self { tx }, rx) + } + } + + impl QlPlatform for TestPlatform { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let tx = self.tx.clone(); + Box::pin(async move { tx.send(message).await.map_err(|_| QlError::Cancelled) }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(async move { + tokio::time::sleep(duration).await; + }) + } + } + + fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) + } + + fn encrypt_payload(data: &str) -> EncryptedMessage { + let key = SymmetricKey::new(); + key.encrypt( + data.as_bytes(), + None::>, + None::, + ) + } + + fn encode_config(sender: XID, recipient: XID, valid_until: u64) -> EncodeQlConfig { + EncodeQlConfig { + sender, + recipient, + valid_until, + kem_ct: None, + sign_header: false, + } + } + + #[tokio::test(flavor = "current_thread")] + async fn request_response_round_trip() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (platform, outbound_rx) = TestPlatform::new(); + let config = ExecutorConfig { + default_timeout: Duration::from_millis(50), + }; + let (mut core, handle, _incoming) = Executor::new(platform, config); + tokio::task::spawn_local(async move { core.run().await }); + + let requester = TestIdentity::generate(); + let responder = TestIdentity::generate(); + let recipient_xid = responder.xid; + let valid_until = now_secs().saturating_add(60); + let payload = encrypt_payload("ping"); + let request_id = ARID::new(); + + let response_task = tokio::task::spawn_local({ + let handle = handle.clone(); + let signer = requester.private_keys.clone(); + let config = encode_config(requester.xid, recipient_xid, valid_until); + async move { + handle + .request( + request_id, + payload, + config, + RequestConfig::default(), + &signer, + ) + .await + } + }); + + let outbound = outbound_rx.recv().await.expect("no outbound request"); + let outbound_message = decode_ql_message(&outbound).expect("decode outbound"); + assert_eq!(outbound_message.header.kind, MessageKind::Request); + let request_id = outbound_message.header.id; + + let response_payload = encrypt_payload("pong"); + let response_bytes = encode_ql_message( + MessageKind::Response, + request_id, + encode_config( + responder.xid, + outbound_message.header.sender, + now_secs().saturating_add(60), + ), + response_payload, + &responder.private_keys, + ); + handle.send_incoming(response_bytes).unwrap(); + + let response = response_task.await.unwrap().unwrap(); + assert_eq!(response.header.kind, MessageKind::Response); + assert_eq!(response.header.id, request_id); + }) + .await; + } + + #[tokio::test(flavor = "current_thread")] + async fn request_timeout_returns_error() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (platform, _outbound_rx) = TestPlatform::new(); + let config = ExecutorConfig { + default_timeout: Duration::from_millis(5), + }; + let (mut core, handle, _incoming) = Executor::new(platform, config); + tokio::task::spawn_local(async move { core.run().await }); + + let requester = TestIdentity::generate(); + let recipient_xid = requester.xid; + let valid_until = now_secs().saturating_add(60); + let payload = encrypt_payload("timeout"); + let request_id = ARID::new(); + let result = handle + .request( + request_id, + payload, + encode_config(requester.xid, recipient_xid, valid_until), + RequestConfig { + timeout: Some(Duration::from_millis(1)), + }, + &requester.private_keys, + ) + .await; + + assert!(matches!(result, Err(QlError::Timeout))); + }) + .await; + } + + #[tokio::test(flavor = "current_thread")] + async fn event_is_forwarded() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (platform, _outbound_rx) = TestPlatform::new(); + let config = ExecutorConfig { + default_timeout: Duration::from_secs(1), + }; + let (mut core, handle, mut handler_stream) = Executor::new(platform, config); + tokio::task::spawn_local(async move { core.run().await }); + + let sender = TestIdentity::generate(); + let recipient = TestIdentity::generate(); + let recipient_xid = recipient.xid; + let event_id = ARID::new(); + let payload = encrypt_payload("event"); + let event_bytes = encode_ql_message( + MessageKind::Event, + event_id, + encode_config( + sender.xid, + recipient_xid, + now_secs().saturating_add(60), + ), + payload, + &sender.private_keys, + ); + + handle.send_incoming(event_bytes).unwrap(); + + let event = handler_stream.next().await.unwrap(); + match event { + HandlerEvent::Event(event) => { + assert_eq!(event.message.header.kind, MessageKind::Event); + assert_eq!(event.message.header.id, event_id); + } + HandlerEvent::Request(_) => panic!("unexpected request"), + } + }) + .await; + } + + #[tokio::test(flavor = "current_thread")] + async fn expired_response_returns_error() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (platform, outbound_rx) = TestPlatform::new(); + let config = ExecutorConfig { + default_timeout: Duration::from_secs(2), + }; + let (mut core, handle, _incoming) = Executor::new(platform, config); + tokio::task::spawn_local(async move { core.run().await }); + + let requester = TestIdentity::generate(); + let responder = TestIdentity::generate(); + let recipient_xid = responder.xid; + let valid_until = now_secs().saturating_add(60); + let payload = encrypt_payload("ping"); + let request_id = ARID::new(); + + let response_task = tokio::task::spawn_local({ + let handle = handle.clone(); + let signer = requester.private_keys.clone(); + let config = encode_config(requester.xid, recipient_xid, valid_until); + async move { + handle + .request( + request_id, + payload, + config, + RequestConfig { + timeout: Some(Duration::from_secs(3)), + }, + &signer, + ) + .await + } + }); + + let outbound = outbound_rx.recv().await.expect("no outbound request"); + let outbound_message = decode_ql_message(&outbound).expect("decode outbound"); + let request_id = outbound_message.header.id; + + let response_payload = encrypt_payload("pong"); + let response_bytes = encode_ql_message( + MessageKind::Response, + request_id, + encode_config( + responder.xid, + outbound_message.header.sender, + 0, + ), + response_payload, + &responder.private_keys, + ); + tokio::time::sleep(Duration::from_secs(1)).await; + handle.send_incoming(response_bytes).unwrap(); + + let response = response_task.await.unwrap(); + assert!(matches!(response, Err(QlError::Decode(_)))); + }) + .await; + } +} diff --git a/ql-protocol/src/lib.rs b/ql-protocol/src/lib.rs new file mode 100644 index 00000000..67280849 --- /dev/null +++ b/ql-protocol/src/lib.rs @@ -0,0 +1,19 @@ +pub mod executor; +pub mod typed; +pub mod wire; + +#[cfg(test)] +mod test_identity; + +pub use executor::{ + Executor, ExecutorConfig, ExecutorHandle, HandlerEvent, HandlerStream, InboundEvent, + InboundRequest, PlatformFuture, QlError, QlPlatform, RequestConfig, Responder, +}; +pub use typed::{ + Event, EventHandler, QlCodec, RequestHandler, RequestResponse, Router, RouterBuilder, + RouterError, TypedExecutorHandle, TypedPayload, TypedRequest, TypedResponder, +}; +pub use wire::{ + decode_ql_message, encode_ql_message, DecodeErrContext, DecodeError, EncodeQlConfig, + MessageKind, QlHeader, QlHeaderUnsigned, QlMessage, +}; diff --git a/ql-protocol/src/test_identity.rs b/ql-protocol/src/test_identity.rs new file mode 100644 index 00000000..478d294e --- /dev/null +++ b/ql-protocol/src/test_identity.rs @@ -0,0 +1,28 @@ +use bc_components::{ + EncapsulationPublicKey, EncapsulationScheme, PrivateKeys, SignatureScheme, SigningPublicKey, + XID, +}; + +#[derive(Debug, Clone)] +pub(crate) struct TestIdentity { + pub(crate) private_keys: PrivateKeys, + pub(crate) signing_public_key: SigningPublicKey, + pub(crate) encapsulation_public_key: EncapsulationPublicKey, + pub(crate) xid: XID, +} + +impl TestIdentity { + pub(crate) fn generate() -> Self { + let (signing_private_key, signing_public_key) = SignatureScheme::MLDSA44.keypair(); + let (encapsulation_private_key, encapsulation_public_key) = + EncapsulationScheme::MLKEM512.keypair(); + let private_keys = PrivateKeys::with_keys(signing_private_key, encapsulation_private_key); + let xid = XID::new(&signing_public_key); + Self { + private_keys, + signing_public_key, + encapsulation_public_key, + xid, + } + } +} diff --git a/ql-protocol/src/typed/handle.rs b/ql-protocol/src/typed/handle.rs new file mode 100644 index 00000000..0c41c648 --- /dev/null +++ b/ql-protocol/src/typed/handle.rs @@ -0,0 +1,202 @@ +use std::{ + future::Future, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use bc_components::{EncapsulationCiphertext, ARID, XID}; + +use super::{Event, RequestResponse, RouterError, RouterPlatform, TypedPayload}; +use crate::{ + executor::ExecutorResponse, EncodeQlConfig, ExecutorHandle, MessageKind, QlCodec, + QlHeaderUnsigned, RequestConfig, +}; + +#[derive(Clone)] +pub struct TypedExecutorHandle { + handle: ExecutorHandle, + platform: Arc, +} + +pub struct Response { + inner: ResponseInner, + _type: PhantomData T>, +} + +enum ResponseInner { + Err(Option), + Ok { + response: ExecutorResponse, + platform: Arc, + }, +} + +impl Future for Response +where + T: QlCodec, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match &mut self.inner { + ResponseInner::Err(e) => { + let e = e.take(); + let e = e.unwrap_or(RouterError::Send(crate::QlError::Cancelled)); + Poll::Ready(Err(e)) + } + ResponseInner::Ok { response, platform } => { + Pin::new(response).poll(cx).map(|response| { + let response = response?; + let session_key = platform + .session_for_peer(response.header.sender) + .ok_or(RouterError::MissingSession(response.header.sender))?; + let decrypted = platform.decrypt_message( + &session_key, + &response.header.aad_data(), + &response.payload, + )?; + let message = T::try_from(decrypted)?; + Ok(message) + }) + } + } + } +} + +impl TypedExecutorHandle { + pub fn new(handle: ExecutorHandle, platform: Arc) -> Self { + Self { handle, platform } + } + + pub fn request( + &self, + message: M, + recipient: XID, + request_config: RequestConfig, + ) -> Response + where + M: RequestResponse, + { + let platform = self.platform.clone(); + let payload = TypedPayload { + message_id: M::ID, + payload: message.into(), + }; + let message_id = ARID::new(); + let inner = match self.encrypt_payload_for_recipient( + recipient, + MessageKind::Request, + message_id, + payload.into(), + ) { + Ok((encrypted, config)) => { + let response = self.handle.request( + message_id, + encrypted, + config, + request_config, + platform.signer(), + ); + + ResponseInner::Ok { + response, + platform: self.platform.clone(), + } + } + Err(e) => ResponseInner::Err(Some(e)), + }; + Response { + inner, + _type: Default::default(), + } + } + + pub fn send_event( + &self, + message: M, + recipient: XID, + _valid_for: Duration, + ) -> Result<(), RouterError> + where + M: Event, + { + let payload = TypedPayload { + message_id: M::ID, + payload: message.into(), + }; + let message_id = ARID::new(); + let (encrypted, config) = self.encrypt_payload_for_recipient( + recipient, + MessageKind::Event, + message_id, + payload.into(), + )?; + self.handle + .send_event(message_id, encrypted, config, self.platform.signer()); + Ok(()) + } + + fn encrypt_payload_for_recipient( + &self, + recipient: XID, + kind: MessageKind, + message_id: ARID, + payload: dcbor::CBOR, + ) -> Result<(bc_components::EncryptedMessage, EncodeQlConfig), RouterError> { + let platform = self.platform.as_ref(); + let (session_key, kem_ct, sign_header) = match platform.session_for_peer(recipient) { + Some(session_key) => (session_key, None, false), + None => self.create_session(recipient)?, + }; + let valid_until = now_secs().saturating_add(platform.message_expiration().as_secs()); + let header_unsigned = QlHeaderUnsigned { + kind, + id: message_id, + sender: platform.sender_xid(), + recipient, + valid_until, + kem_ct: kem_ct.clone(), + }; + let aad = header_unsigned.aad_data(); + let payload_bytes = payload.to_cbor_data(); + let encrypted = session_key.encrypt(payload_bytes, Some(aad), None::); + let config = EncodeQlConfig { + sender: platform.sender_xid(), + recipient, + valid_until, + kem_ct, + sign_header, + }; + Ok((encrypted, config)) + } + + fn create_session( + &self, + recipient: XID, + ) -> Result< + ( + bc_components::SymmetricKey, + Option, + bool, + ), + RouterError, + > { + let platform = self.platform.as_ref(); + let recipient_key = platform + .lookup_recipient(recipient) + .ok_or(RouterError::UnknownRecipient(recipient))?; + let (session_key, kem_ct) = recipient_key.encapsulate_new_shared_secret(); + platform.store_session(recipient, session_key.clone()); + Ok((session_key, Some(kem_ct), true)) + } +} + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} diff --git a/ql-protocol/src/typed/mod.rs b/ql-protocol/src/typed/mod.rs new file mode 100644 index 00000000..1c0efefc --- /dev/null +++ b/ql-protocol/src/typed/mod.rs @@ -0,0 +1,123 @@ +use std::time::Duration; + +use bc_components::{ + EncapsulationCiphertext, EncapsulationPrivateKey, EncapsulationPublicKey, + EncryptedMessage, Signer, SigningPublicKey, SymmetricKey, XID, +}; +use dcbor::CBOR; + +use crate::QlError; + +pub mod handle; +pub mod router; + +pub trait QlCodec: Into + TryFrom + Sized {} + +impl QlCodec for T where T: Into + TryFrom + Sized {} + +pub trait RequestResponse: QlCodec { + const ID: u64; + type Response: QlCodec; +} + +pub trait Event: QlCodec { + const ID: u64; +} + +#[derive(Debug, Clone)] +pub struct TypedPayload { + pub message_id: u64, + pub payload: CBOR, +} + +impl From for CBOR { + fn from(value: TypedPayload) -> Self { + CBOR::from(vec![CBOR::from(value.message_id), value.payload]) + } +} + +impl TryFrom for TypedPayload { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let array = value.try_into_array()?; + if array.len() != 2 { + return Err(dcbor::Error::msg("invalid typed payload length")); + } + let message_id: u64 = array[0].clone().try_into()?; + Ok(Self { + message_id, + payload: array[1].clone(), + }) + } +} + +#[derive(Debug)] +pub enum RouterError { + Decode(dcbor::Error), + InvalidPayload, + InvalidSignature, + MissingHandler(u64), + MissingSession(XID), + Send(QlError), + UnknownRecipient(XID), + UnknownSender(XID), +} + +impl From for RouterError { + fn from(error: dcbor::Error) -> Self { + Self::Decode(error) + } +} + +impl From for RouterError { + fn from(error: QlError) -> Self { + Self::Send(error) + } +} + +pub trait RouterPlatform { + fn lookup_recipient(&self, recipient: XID) -> Option<&EncapsulationPublicKey>; + fn lookup_signing_key(&self, sender: XID) -> Option<&SigningPublicKey>; + fn session_for_peer(&self, peer: XID) -> Option; + fn store_session(&self, peer: XID, key: SymmetricKey); + fn encapsulation_private_key(&self) -> EncapsulationPrivateKey; + fn signing_key(&self) -> &SigningPublicKey; + fn message_expiration(&self) -> Duration; + fn signer(&self) -> &dyn Signer; + fn handle_error(&self, e: RouterError); + + fn sender_xid(&self) -> XID { + XID::new(self.signing_key()) + } + + fn decapsulate_shared_secret( + &self, + ciphertext: &EncapsulationCiphertext, + ) -> Result { + self.encapsulation_private_key() + .decapsulate_shared_secret(ciphertext) + .map_err(|_| RouterError::InvalidPayload) + } + + fn decrypt_message( + &self, + key: &SymmetricKey, + header_aad: &[u8], + payload: &EncryptedMessage, + ) -> Result { + if payload.aad() != header_aad { + return Err(RouterError::InvalidPayload); + } + let plaintext = key.decrypt(payload).map_err(|_| RouterError::InvalidPayload)?; + CBOR::try_from_data(plaintext).map_err(RouterError::Decode) + } +} + +pub use handle::TypedExecutorHandle; +pub use router::{ + EventHandler, RequestHandler, Router, RouterBuilder, TypedRequest, TypedResponder, +}; + +#[cfg(test)] +mod test; diff --git a/ql-protocol/src/typed/router.rs b/ql-protocol/src/typed/router.rs new file mode 100644 index 00000000..31784b83 --- /dev/null +++ b/ql-protocol/src/typed/router.rs @@ -0,0 +1,297 @@ +use std::{collections::HashMap, sync::Arc}; + +use bc_components::{Verifier, XID}; + +use super::{Event, QlCodec, RequestResponse, RouterError, RouterPlatform, TypedPayload}; +use crate::{EncodeQlConfig, HandlerEvent, MessageKind, QlHeader, QlHeaderUnsigned, Responder}; + +pub trait RequestHandler +where + M: RequestResponse, +{ + fn handle(&mut self, request: TypedRequest); + fn default_response() -> M::Response; +} + +pub trait EventHandler +where + M: Event, +{ + fn handle(&mut self, event: M); +} + +pub struct TypedRequest +where + M: RequestResponse, +{ + pub message: M, + pub responder: TypedResponder, +} + +pub struct TypedResponder +where + R: QlCodec, +{ + responder: Option, + platform: Arc, + recipient: XID, + default: fn() -> R, +} + +impl TypedResponder +where + R: QlCodec, +{ + pub fn respond(mut self, response: R) -> Result<(), RouterError> { + self.respond_inner(response) + } + + fn respond_inner(&mut self, response: R) -> Result<(), RouterError> { + let responder = self.responder.take().unwrap(); + let payload = response.into(); + let session_key = self + .platform + .session_for_peer(self.recipient) + .ok_or(RouterError::MissingSession(self.recipient))?; + let now = now_secs(); + let valid_until = now.saturating_add(self.platform.message_expiration().as_secs()); + let header_unsigned = QlHeaderUnsigned { + kind: MessageKind::Response, + id: responder.id(), + sender: self.platform.sender_xid(), + recipient: self.recipient, + valid_until, + kem_ct: None, + }; + let aad = header_unsigned.aad_data(); + let payload_bytes = dcbor::CBOR::from(payload).to_cbor_data(); + let encrypted = session_key.encrypt(payload_bytes, Some(aad), None::); + let config = EncodeQlConfig { + sender: self.platform.sender_xid(), + recipient: self.recipient, + valid_until, + kem_ct: None, + sign_header: false, + }; + responder.respond(encrypted, config, self.platform.signer())?; + Ok(()) + } +} + +impl Drop for TypedResponder +where + R: QlCodec, +{ + fn drop(&mut self) { + if self.responder.is_some() { + let default = (self.default)(); + let _ = self.respond_inner(default); + } + } +} + +type RouterHandler = fn(&mut S, RouterEvent, Arc) -> Result<(), RouterError>; + +enum RouterEvent { + Event { + #[allow(unused)] + header: QlHeader, + payload: TypedPayload, + }, + Request { + header: QlHeader, + payload: TypedPayload, + responder: Responder, + }, +} + +impl RouterEvent { + fn message_id(&self) -> u64 { + match self { + RouterEvent::Event { payload, .. } => payload.message_id, + RouterEvent::Request { payload, .. } => payload.message_id, + } + } +} + +pub struct RouterBuilder { + handlers: HashMap>, +} + +impl Default for RouterBuilder { + fn default() -> Self { + Self { + handlers: HashMap::new(), + } + } +} + +impl RouterBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn add_request_handler(mut self) -> Self + where + M: RequestResponse, + S: RequestHandler, + { + self.handlers.insert(M::ID, handle_request::); + self + } + + pub fn add_event_handler(mut self) -> Self + where + M: Event, + S: EventHandler, + { + self.handlers.insert(M::ID, handle_event::); + self + } + + pub fn build(self, platform: Arc) -> Router { + Router { + platform, + handlers: self.handlers, + } + } +} + +pub struct Router { + platform: Arc, + handlers: HashMap>, +} + +impl Router { + pub fn builder() -> RouterBuilder { + RouterBuilder::new() + } + + pub fn handle(&self, state: &mut S, event: HandlerEvent) -> Result<(), RouterError> { + let event = decrypt_event(event, self.platform.as_ref())?; + let message_id = event.message_id(); + let handler = self + .handlers + .get(&message_id) + .ok_or(RouterError::MissingHandler(message_id))?; + handler(state, event, self.platform.clone()) + } +} + +fn handle_request( + state: &mut S, + event: RouterEvent, + platform: Arc, +) -> Result<(), RouterError> +where + M: RequestResponse, + S: RequestHandler, +{ + let (header, payload, responder) = match event { + RouterEvent::Request { + header, + payload, + responder, + } => (header, payload, responder), + RouterEvent::Event { .. } => unreachable!("expected request event"), + }; + let message = M::try_from(payload.payload)?; + let responder = TypedResponder { + responder: Some(responder), + platform, + recipient: header.sender, + default: S::default_response, + }; + state.handle(TypedRequest { message, responder }); + Ok(()) +} + +fn handle_event( + state: &mut S, + event: RouterEvent, + _platform: Arc, +) -> Result<(), RouterError> +where + M: Event, + S: EventHandler, +{ + let payload = match event { + RouterEvent::Event { payload, .. } => payload, + RouterEvent::Request { .. } => unreachable!("expected event"), + }; + let message = M::try_from(payload.payload)?; + state.handle(message); + Ok(()) +} + +fn decrypt_event( + event: HandlerEvent, + platform: &dyn RouterPlatform, +) -> Result { + match event { + HandlerEvent::Request(request) => { + verify_header(platform, &request.message.header)?; + let payload = + extract_typed_payload(platform, &request.message.header, request.message.payload)?; + Ok(RouterEvent::Request { + header: request.message.header, + payload, + responder: request.respond_to, + }) + } + HandlerEvent::Event(event) => { + verify_header(platform, &event.message.header)?; + let payload = + extract_typed_payload(platform, &event.message.header, event.message.payload)?; + Ok(RouterEvent::Event { + header: event.message.header, + payload, + }) + } + } +} + +fn verify_header(platform: &dyn RouterPlatform, header: &QlHeader) -> Result<(), RouterError> { + if header.kem_ct.is_none() { + return Ok(()); + } + let signature = header + .signature + .as_ref() + .ok_or(RouterError::InvalidSignature)?; + let signing_key = platform + .lookup_signing_key(header.sender) + .ok_or(RouterError::UnknownSender(header.sender))?; + if signing_key.verify(signature, &header.unsigned().aad_data()) { + Ok(()) + } else { + Err(RouterError::InvalidSignature) + } +} + +fn extract_typed_payload( + platform: &dyn RouterPlatform, + header: &QlHeader, + payload: bc_components::EncryptedMessage, +) -> Result { + let session_key = if let Some(kem_ct) = &header.kem_ct { + let key = platform.decapsulate_shared_secret(kem_ct)?; + platform.store_session(header.sender, key.clone()); + key + } else { + platform + .session_for_peer(header.sender) + .ok_or(RouterError::MissingSession(header.sender))? + }; + let decrypted = platform.decrypt_message(&session_key, &header.aad_data(), &payload)?; + TypedPayload::try_from(decrypted).map_err(RouterError::Decode) +} + +fn now_secs() -> u64 { + use std::time::{SystemTime, UNIX_EPOCH}; + + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} diff --git a/ql-protocol/src/typed/test.rs b/ql-protocol/src/typed/test.rs new file mode 100644 index 00000000..c21e5104 --- /dev/null +++ b/ql-protocol/src/typed/test.rs @@ -0,0 +1,287 @@ +use std::{collections::HashMap, sync::{Arc, Mutex}, time::Duration}; + +use async_channel::{Receiver, Sender}; +use oneshot; +use bc_components::{ + Decrypter, EncapsulationPrivateKey, EncapsulationPublicKey, Signer, SigningPublicKey, + SymmetricKey, XID, +}; +use dcbor::CBOR; + +use super::{ + Event, EventHandler, RequestHandler, RequestResponse, Router, RouterPlatform, + TypedExecutorHandle, TypedRequest, +}; +use crate::{ + test_identity::TestIdentity, Executor, ExecutorConfig, PlatformFuture, QlError, QlPlatform, + RequestConfig, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct Ping(u64); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct Pong(u64); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct Notice(u64); + +impl From for CBOR { + fn from(value: Ping) -> Self { + CBOR::from(value.0) + } +} + +impl TryFrom for Ping { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let value: u64 = value.try_into()?; + Ok(Self(value)) + } +} + +impl From for CBOR { + fn from(value: Pong) -> Self { + CBOR::from(value.0) + } +} + +impl TryFrom for Pong { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let value: u64 = value.try_into()?; + Ok(Self(value)) + } +} + +impl From for CBOR { + fn from(value: Notice) -> Self { + CBOR::from(value.0) + } +} + +impl TryFrom for Notice { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let value: u64 = value.try_into()?; + Ok(Self(value)) + } +} + +impl RequestResponse for Ping { + const ID: u64 = 100; + type Response = Pong; +} + +impl Event for Notice { + const ID: u64 = 200; +} + +struct TestPlatform { + tx: Sender>, +} + +impl TestPlatform { + fn new() -> (Self, Receiver>) { + let (tx, rx) = async_channel::unbounded(); + (Self { tx }, rx) + } +} + +impl QlPlatform for TestPlatform { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let tx = self.tx.clone(); + Box::pin(async move { tx.send(message).await.map_err(|_| QlError::Cancelled) }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(async move { tokio::time::sleep(duration).await }) + } +} + +struct TestRouterPlatform { + identity: TestIdentity, + peer: EncapsulationPublicKey, + peer_signing_key: SigningPublicKey, + sessions: Mutex>, +} + +impl TestRouterPlatform { + fn new( + identity: TestIdentity, + peer: EncapsulationPublicKey, + peer_signing_key: SigningPublicKey, + ) -> Self { + Self { + identity, + peer, + peer_signing_key, + sessions: Mutex::new(HashMap::new()), + } + } + + fn xid(&self) -> XID { + self.identity.xid + } +} + +impl RouterPlatform for TestRouterPlatform { + fn lookup_recipient(&self, _recipient: XID) -> Option<&EncapsulationPublicKey> { + Some(&self.peer) + } + + fn lookup_signing_key(&self, sender: XID) -> Option<&SigningPublicKey> { + if sender == XID::new(&self.peer_signing_key) { + Some(&self.peer_signing_key) + } else { + None + } + } + + fn session_for_peer(&self, peer: XID) -> Option { + self.sessions.lock().ok()?.get(&peer).cloned() + } + + fn store_session(&self, peer: XID, key: SymmetricKey) { + if let Ok(mut sessions) = self.sessions.lock() { + sessions.insert(peer, key); + } + } + + fn encapsulation_private_key(&self) -> EncapsulationPrivateKey { + self.identity.private_keys.encapsulation_private_key() + } + + fn signing_key(&self) -> &SigningPublicKey { + &self.identity.signing_public_key + } + + fn message_expiration(&self) -> Duration { + Duration::from_secs(60) + } + + fn signer(&self) -> &dyn Signer { + &self.identity.private_keys + } + + fn handle_error(&self, _e: super::RouterError) {} +} + +struct TestState { + event_tx: Option>, +} + +impl RequestHandler for TestState { + fn handle(&mut self, request: TypedRequest) { + let response = Pong(request.message.0 + 1); + let _ = request.responder.respond(response); + } + + fn default_response() -> Pong { + Pong(0) + } +} + +impl EventHandler for TestState { + fn handle(&mut self, event: Notice) { + if let Some(tx) = self.event_tx.take() { + let _ = tx.send(event.0); + } + } +} + +#[tokio::test(flavor = "current_thread")] +async fn typed_round_trip() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (client_platform, client_outbound) = TestPlatform::new(); + let (server_platform, server_outbound) = TestPlatform::new(); + let config = ExecutorConfig { + default_timeout: Duration::from_secs(1), + }; + + let (mut client_core, client_handle, _client_incoming) = + Executor::new(client_platform, config); + let (mut server_core, server_handle, mut server_incoming) = + Executor::new(server_platform, config); + + tokio::task::spawn_local(async move { client_core.run().await }); + tokio::task::spawn_local(async move { server_core.run().await }); + + tokio::task::spawn_local({ + let server_handle = server_handle.clone(); + async move { + while let Ok(bytes) = client_outbound.recv().await { + server_handle.send_incoming(bytes).unwrap(); + } + } + }); + + tokio::task::spawn_local({ + let client_handle = client_handle.clone(); + async move { + while let Ok(bytes) = server_outbound.recv().await { + client_handle.send_incoming(bytes).unwrap(); + } + } + }); + + let client_identity = TestIdentity::generate(); + let server_identity = TestIdentity::generate(); + let client_platform = Arc::new(TestRouterPlatform::new( + client_identity.clone(), + server_identity.encapsulation_public_key.clone(), + server_identity.signing_public_key.clone(), + )); + let server_platform = Arc::new(TestRouterPlatform::new( + server_identity.clone(), + client_identity.encapsulation_public_key.clone(), + client_identity.signing_public_key.clone(), + )); + let recipient = server_platform.xid(); + + let router = Router::builder() + .add_request_handler::() + .add_event_handler::() + .build(server_platform.clone()); + + let (event_tx, event_rx) = oneshot::channel(); + let mut state = TestState { + event_tx: Some(event_tx), + }; + + tokio::task::spawn_local({ + let server_platform = server_platform.clone(); + async move { + loop { + let event = match server_incoming.next().await { + Ok(event) => event, + Err(_) => break, + }; + if let Err(err) = router.handle(&mut state, event) { + server_platform.handle_error(err); + } + } + } + }); + + let client_typed = TypedExecutorHandle::new(client_handle, client_platform); + + client_typed + .send_event(Notice(7), recipient, Duration::from_secs(60)) + .unwrap(); + let event_value = event_rx.await.expect("event handled"); + assert_eq!(event_value, 7); + + let response = client_typed + .request(Ping(41), recipient, RequestConfig::default()) + .await + .expect("response"); + assert_eq!(response, Pong(42)); + }) + .await; +} diff --git a/ql-protocol/src/wire.rs b/ql-protocol/src/wire.rs new file mode 100644 index 00000000..8220ee86 --- /dev/null +++ b/ql-protocol/src/wire.rs @@ -0,0 +1,434 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +use bc_components::{EncapsulationCiphertext, EncryptedMessage, Signature, Signer, ARID, XID}; +use dcbor::CBOR; +use thiserror::Error; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MessageKind { + Request, + Response, + Event, +} + +#[derive(Debug, Clone)] +pub struct QlHeader { + pub kind: MessageKind, + pub id: ARID, + pub sender: XID, + pub recipient: XID, + pub valid_until: u64, + pub kem_ct: Option, + pub signature: Option, +} + +#[derive(Debug, Clone)] +pub struct QlHeaderUnsigned { + pub kind: MessageKind, + pub id: ARID, + pub sender: XID, + pub recipient: XID, + pub valid_until: u64, + pub kem_ct: Option, +} + +#[derive(Debug, Clone)] +pub struct EncodeQlConfig { + pub sender: XID, + pub recipient: XID, + pub valid_until: u64, + pub kem_ct: Option, + pub sign_header: bool, +} + +impl QlHeader { + pub fn unsigned(&self) -> QlHeaderUnsigned { + QlHeaderUnsigned { + kind: self.kind, + id: self.id, + sender: self.sender, + recipient: self.recipient, + valid_until: self.valid_until, + kem_ct: self.kem_ct.clone(), + } + } + + pub fn aad_data(&self) -> Vec { + CBOR::from(self.unsigned()).to_cbor_data() + } +} + +impl QlHeaderUnsigned { + pub fn aad_data(&self) -> Vec { + CBOR::from(self.clone()).to_cbor_data() + } +} + +impl From for dcbor::CBOR { + fn from(value: QlHeader) -> Self { + dcbor::CBOR::from(vec![ + dcbor::CBOR::from(value.kind), + dcbor::CBOR::from(value.id), + dcbor::CBOR::from(value.sender), + dcbor::CBOR::from(value.recipient), + dcbor::CBOR::from(value.valid_until), + option_to_cbor(value.kem_ct), + option_to_cbor(value.signature), + ]) + } +} + +impl From for dcbor::CBOR { + fn from(value: QlHeaderUnsigned) -> Self { + dcbor::CBOR::from(vec![ + dcbor::CBOR::from(value.kind), + dcbor::CBOR::from(value.id), + dcbor::CBOR::from(value.sender), + dcbor::CBOR::from(value.recipient), + dcbor::CBOR::from(value.valid_until), + option_to_cbor(value.kem_ct), + ]) + } +} + +impl TryFrom for QlHeader { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let array = value.try_into_array()?; + if array.len() != 7 { + return Err(dcbor::Error::msg("invalid header length")); + } + let kind = MessageKind::try_from(array[0].clone())?; + let id: ARID = array[1].clone().try_into()?; + let sender: XID = array[2].clone().try_into()?; + let recipient: XID = array[3].clone().try_into()?; + let valid_until: u64 = array[4].clone().try_into()?; + let kem_ct: Option = option_from_cbor(array[5].clone())?; + let signature: Option = option_from_cbor(array[6].clone())?; + Ok(Self { + kind, + id, + sender, + recipient, + valid_until, + kem_ct, + signature, + }) + } +} + +fn option_to_cbor(value: Option) -> CBOR +where + T: Into, +{ + value.map(Into::into).unwrap_or_else(CBOR::null) +} + +fn option_from_cbor(value: CBOR) -> dcbor::Result> +where + T: TryFrom, +{ + if value.is_null() { + Ok(None) + } else { + Ok(Some(value.try_into()?)) + } +} + +#[derive(Debug, Error)] +pub enum DecodeError { + #[error("invalid message encoding")] + InvalidEncoding, + #[error("message expired")] + Expired, + #[error(transparent)] + Cbor(#[from] dcbor::Error), +} + +#[derive(Debug)] +pub struct DecodeErrContext { + pub error: DecodeError, + pub header: Option, +} + +#[derive(Debug, Clone)] +pub struct QlMessage { + pub header: QlHeader, + pub payload: EncryptedMessage, +} + +pub fn encode_ql_message( + kind: MessageKind, + id: ARID, + config: EncodeQlConfig, + payload: EncryptedMessage, + signer: &dyn Signer, +) -> Vec { + let header_unsigned = QlHeaderUnsigned { + kind, + id, + sender: config.sender, + recipient: config.recipient, + valid_until: config.valid_until, + kem_ct: config.kem_ct.clone(), + }; + let signature = if config.sign_header { + Some( + signer + .sign(&header_unsigned.aad_data()) + .expect("failed to sign header"), + ) + } else { + None + }; + let header = QlHeader { + kind, + id, + sender: config.sender, + recipient: config.recipient, + valid_until: config.valid_until, + kem_ct: config.kem_ct, + signature, + }; + let cbor = CBOR::from(vec![CBOR::from(header), CBOR::from(payload)]); + cbor.to_cbor_data() +} + +pub fn decode_ql_message(bytes: &[u8]) -> Result { + let cbor = dcbor::CBOR::try_from_data(bytes).map_err(|error| DecodeErrContext { + error: DecodeError::Cbor(error), + header: None, + })?; + let array = cbor.try_into_array().map_err(|error| DecodeErrContext { + error: DecodeError::Cbor(error), + header: None, + })?; + if array.len() != 2 { + return Err(DecodeErrContext { + error: DecodeError::InvalidEncoding, + header: None, + }); + } + let header = QlHeader::try_from(array[0].clone()).map_err(|error| DecodeErrContext { + error: DecodeError::Cbor(error), + header: None, + })?; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0); + if now > header.valid_until { + return Err(DecodeErrContext { + error: DecodeError::Expired, + header: Some(header), + }); + } + let payload: EncryptedMessage = + array[1] + .clone() + .try_into() + .map_err(|error| DecodeErrContext { + error: DecodeError::Cbor(error), + header: Some(header.clone()), + })?; + Ok(QlMessage { header, payload }) +} + +impl TryFrom for MessageKind { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let kind: u64 = value.try_into()?; + match kind { + 1 => Ok(MessageKind::Request), + 2 => Ok(MessageKind::Response), + 3 => Ok(MessageKind::Event), + _ => Err(dcbor::Error::msg("unknown message kind")), + } + } +} + +impl From for CBOR { + fn from(value: MessageKind) -> Self { + let kind = match value { + MessageKind::Request => 1, + MessageKind::Response => 2, + MessageKind::Event => 3, + }; + CBOR::from(kind) + } +} + +#[cfg(test)] +mod tests { + use bc_components::Verifier; + + use super::*; + use crate::test_identity::TestIdentity; + + #[test] + fn round_trip() { + let sender = TestIdentity::generate(); + let recipient = TestIdentity::generate(); + let recipient_xid = recipient.xid; + let sender_xid = sender.xid; + let header_id = ARID::new(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0); + let valid_until = now.saturating_add(60); + let (session_key, kem_ct) = recipient + .encapsulation_public_key + .encapsulate_new_shared_secret(); + let header_unsigned = QlHeaderUnsigned { + kind: MessageKind::Request, + id: header_id, + sender: sender_xid, + recipient: recipient_xid, + valid_until, + kem_ct: Some(kem_ct.clone()), + }; + let payload = CBOR::from("secret"); + let payload_bytes = payload.to_cbor_data(); + let encrypted_payload = session_key.encrypt( + payload_bytes, + Some(header_unsigned.aad_data()), + None::, + ); + + let bytes = encode_ql_message( + MessageKind::Request, + header_id, + EncodeQlConfig { + sender: sender_xid, + recipient: recipient_xid, + valid_until, + kem_ct: Some(kem_ct), + sign_header: true, + }, + encrypted_payload, + &sender.private_keys, + ); + let decoded = decode_ql_message(&bytes).expect("decode failed"); + + assert_eq!(decoded.header.kind, MessageKind::Request); + assert_eq!(decoded.header.id, header_id); + assert_eq!(decoded.header.recipient, recipient_xid); + assert_eq!(decoded.header.sender, sender_xid); + + let signing_data = decoded.header.unsigned().aad_data(); + let signature = decoded.header.signature.as_ref().expect("signature"); + assert!(sender.signing_public_key.verify(signature, &signing_data)); + + let decrypted = session_key.decrypt(&decoded.payload).expect("decrypt"); + let decrypted_cbor = CBOR::try_from_data(decrypted).expect("cbor"); + assert_eq!(decrypted_cbor, payload); + } + + #[test] + fn header_size() { + let size = std::mem::size_of::(); + println!("header size: {} bytes", size); + assert!(size > 0); + } + + #[test] + fn encoded_message_size() { + let sender = TestIdentity::generate(); + let recipient = TestIdentity::generate(); + let recipient_xid = recipient.xid; + let sender_xid = sender.xid; + let header_id = ARID::new(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0); + let valid_until = now.saturating_add(60); + let (session_key, kem_ct) = recipient + .encapsulation_public_key + .encapsulate_new_shared_secret(); + let header_unsigned = QlHeaderUnsigned { + kind: MessageKind::Request, + id: header_id, + sender: sender_xid, + recipient: recipient_xid, + valid_until, + kem_ct: Some(kem_ct.clone()), + }; + let payload = CBOR::from("size"); + let payload_bytes = payload.to_cbor_data(); + let encrypted_payload = session_key.encrypt( + payload_bytes, + Some(header_unsigned.aad_data()), + None::, + ); + + let bytes = encode_ql_message( + MessageKind::Request, + header_id, + EncodeQlConfig { + sender: sender_xid, + recipient: recipient_xid, + valid_until, + kem_ct: Some(kem_ct), + sign_header: true, + }, + encrypted_payload, + &sender.private_keys, + ); + + println!("encoded message size: {} bytes", bytes.len()); + assert!(!bytes.is_empty()); + } + + #[test] + fn steady_state_message_size() { + let sender = TestIdentity::generate(); + let recipient = TestIdentity::generate(); + let recipient_xid = recipient.xid; + let sender_xid = sender.xid; + let header_id = ARID::new(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0); + let valid_until = now.saturating_add(60); + let (session_key, _kem_ct) = recipient + .encapsulation_public_key + .encapsulate_new_shared_secret(); + let header_unsigned = QlHeaderUnsigned { + kind: MessageKind::Request, + id: header_id, + sender: sender_xid, + recipient: recipient_xid, + valid_until, + kem_ct: None, + }; + let payload = CBOR::from("steady"); + let payload_bytes = payload.to_cbor_data(); + let encrypted_payload = session_key.encrypt( + payload_bytes, + Some(header_unsigned.aad_data()), + None::, + ); + + let bytes = encode_ql_message( + MessageKind::Request, + header_id, + EncodeQlConfig { + sender: sender_xid, + recipient: recipient_xid, + valid_until, + kem_ct: None, + sign_header: false, + }, + encrypted_payload, + &sender.private_keys, + ); + + println!("steady-state message size: {} bytes", bytes.len()); + assert!(!bytes.is_empty()); + } +} From 9ff17f1c72b836274922ae4cf2d30ee9cec473f0 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:42 -0400 Subject: [PATCH 002/304] ql: consolidate runtime, handshake flow, and promote ql2 to ql --- Cargo.lock | 2 +- Cargo.toml | 2 +- ql-protocol/src/executor.rs | 718 -------------- ql-protocol/src/lib.rs | 19 - ql-protocol/src/test_identity.rs | 28 - ql-protocol/src/typed/handle.rs | 202 ---- ql-protocol/src/typed/mod.rs | 123 --- ql-protocol/src/typed/router.rs | 297 ------ ql-protocol/src/typed/test.rs | 287 ------ ql-protocol/src/wire.rs | 434 -------- {ql-protocol => ql}/Cargo.toml | 8 +- ql/src/crypto/handshake.rs | 140 +++ ql/src/crypto/heartbeat.rs | 40 + ql/src/crypto/message.rs | 74 ++ ql/src/crypto/mod.rs | 25 + ql/src/crypto/pair.rs | 127 +++ ql/src/id.rs | 95 ++ ql/src/lib.rs | 45 + ql/src/platform.rs | 33 + ql/src/router.rs | 210 ++++ ql/src/runtime/core.rs | 1008 +++++++++++++++++++ ql/src/runtime/handle.rs | 161 +++ ql/src/runtime/internal.rs | 308 ++++++ ql/src/runtime/mod.rs | 146 +++ ql/src/runtime/tests.rs | 1588 ++++++++++++++++++++++++++++++ ql/src/wire/handshake.rs | 120 +++ ql/src/wire/heartbeat.rs | 32 + ql/src/wire/message.rs | 141 +++ ql/src/wire/mod.rs | 182 ++++ ql/src/wire/pair.rs | 68 ++ 30 files changed, 4549 insertions(+), 2114 deletions(-) delete mode 100644 ql-protocol/src/executor.rs delete mode 100644 ql-protocol/src/lib.rs delete mode 100644 ql-protocol/src/test_identity.rs delete mode 100644 ql-protocol/src/typed/handle.rs delete mode 100644 ql-protocol/src/typed/mod.rs delete mode 100644 ql-protocol/src/typed/router.rs delete mode 100644 ql-protocol/src/typed/test.rs delete mode 100644 ql-protocol/src/wire.rs rename {ql-protocol => ql}/Cargo.toml (67%) create mode 100644 ql/src/crypto/handshake.rs create mode 100644 ql/src/crypto/heartbeat.rs create mode 100644 ql/src/crypto/message.rs create mode 100644 ql/src/crypto/mod.rs create mode 100644 ql/src/crypto/pair.rs create mode 100644 ql/src/id.rs create mode 100644 ql/src/lib.rs create mode 100644 ql/src/platform.rs create mode 100644 ql/src/router.rs create mode 100644 ql/src/runtime/core.rs create mode 100644 ql/src/runtime/handle.rs create mode 100644 ql/src/runtime/internal.rs create mode 100644 ql/src/runtime/mod.rs create mode 100644 ql/src/runtime/tests.rs create mode 100644 ql/src/wire/handshake.rs create mode 100644 ql/src/wire/heartbeat.rs create mode 100644 ql/src/wire/message.rs create mode 100644 ql/src/wire/mod.rs create mode 100644 ql/src/wire/pair.rs diff --git a/Cargo.lock b/Cargo.lock index c953474e..e6733b10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1943,7 +1943,7 @@ dependencies = [ ] [[package]] -name = "ql-protocol" +name = "ql" version = "0.1.0" dependencies = [ "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 540aa7ab..80d9c698 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["api", "backup-shard", "btp", "ql-protocol", "quantum-link-macros"] +members = ["api", "backup-shard", "btp", "ql", "quantum-link-macros"] [workspace.package] homepage = "https://github.com/Foundation-Devices/foundation-api" diff --git a/ql-protocol/src/executor.rs b/ql-protocol/src/executor.rs deleted file mode 100644 index 65b79ccc..00000000 --- a/ql-protocol/src/executor.rs +++ /dev/null @@ -1,718 +0,0 @@ -use std::{ - cmp::{Ordering, Reverse}, - collections::{BinaryHeap, HashMap, VecDeque}, - future::Future, - pin::{pin, Pin}, - task::{Context, Poll}, - time::{Duration, Instant}, -}; - -use async_channel::{Receiver, Sender, WeakSender}; -use bc_components::{EncryptedMessage, Signer, ARID, XID}; - -use super::wire::{ - decode_ql_message, encode_ql_message, DecodeErrContext, EncodeQlConfig, MessageKind, QlMessage, -}; - -pub type PlatformFuture<'a, T> = Pin + 'a>>; - -pub trait QlPlatform { - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; -} - -#[derive(Debug)] -pub enum QlError { - Cancelled, - Protocol, - SendFailed, - Timeout, - Decode(super::wire::DecodeError), -} - -#[derive(Debug, Clone, Copy)] -pub struct RequestConfig { - pub timeout: Option, -} - -impl Default for RequestConfig { - fn default() -> Self { - Self { timeout: None } - } -} - -#[derive(Debug, Clone, Copy)] -pub struct ExecutorConfig { - pub default_timeout: Duration, -} - -#[derive(Debug)] -pub struct InboundRequest { - pub message: QlMessage, - pub respond_to: Responder, -} - -#[derive(Debug)] -pub struct InboundEvent { - pub message: QlMessage, -} - -#[derive(Debug)] -pub enum HandlerEvent { - Request(InboundRequest), - Event(InboundEvent), -} - -#[derive(Debug, Clone)] -pub struct Responder { - id: ARID, - recipient: XID, - tx: Sender, -} - -impl Responder { - pub fn id(&self) -> ARID { - self.id - } - - pub fn recipient(&self) -> XID { - self.recipient - } - - pub fn respond( - self, - payload: EncryptedMessage, - encode_config: EncodeQlConfig, - signer: &dyn Signer, - ) -> Result<(), QlError> { - let bytes = encode_ql_message( - MessageKind::Response, - self.id, - encode_config, - payload, - signer, - ); - self.tx - .send_blocking(ExecutorEvent::SendResponse { bytes }) - .map_err(|_| QlError::Cancelled) - } -} - -#[derive(Debug)] -pub struct HandlerStream { - rx: Receiver, -} - -impl HandlerStream { - pub async fn next(&mut self) -> Result { - self.rx.recv().await.map_err(|_| QlError::Cancelled) - } -} - -impl futures_lite::Stream for HandlerStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> std::task::Poll> { - let rx = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.rx) }; - match rx.poll_next(cx) { - Poll::Ready(Some(event)) => Poll::Ready(Some(Ok(event))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -#[derive(Debug)] -enum ExecutorEvent { - SendRequest { - id: ARID, - bytes: Vec, - respond_to: oneshot::Sender>, - config: RequestConfig, - }, - SendEvent { - bytes: Vec, - }, - SendResponse { - bytes: Vec, - }, - Incoming { - message: QlMessage, - }, - IncomingDecodeError { - context: DecodeErrContext, - }, -} - -#[derive(Debug, Clone)] -pub struct ExecutorHandle { - tx: Sender, -} - -pub struct ExecutorResponse { - rx: oneshot::Receiver>, -} - -impl std::future::Future for ExecutorResponse { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - pin!(&mut self.rx) - .poll(cx) - .map(|result| result.unwrap_or(Err(QlError::Cancelled))) - } -} - -impl ExecutorHandle { - pub fn request( - &self, - id: ARID, - payload: EncryptedMessage, - encode_config: EncodeQlConfig, - request_config: RequestConfig, - signer: &dyn Signer, - ) -> ExecutorResponse { - let bytes = encode_ql_message(MessageKind::Request, id, encode_config, payload, signer); - let (tx, rx) = oneshot::channel(); - self.tx - .send_blocking(ExecutorEvent::SendRequest { - id, - bytes, - respond_to: tx, - config: request_config, - }) - .unwrap(); - ExecutorResponse { rx } - } - - pub fn send_event( - &self, - id: ARID, - payload: EncryptedMessage, - encode_config: EncodeQlConfig, - signer: &dyn Signer, - ) { - let tx = self.tx.clone(); - let bytes = encode_ql_message(MessageKind::Event, id, encode_config, payload, signer); - tx.send_blocking(ExecutorEvent::SendEvent { bytes }) - .unwrap(); - } - - pub fn send_incoming(&self, bytes: Vec) -> Result<(), QlError> { - match decode_ql_message(&bytes) { - Ok(message) => self - .tx - .send_blocking(ExecutorEvent::Incoming { message }) - .map_err(|_| QlError::Cancelled), - Err(context) => { - let _ = self - .tx - .send_blocking(ExecutorEvent::IncomingDecodeError { context }); - Ok(()) - } - } - } -} - -pub struct Executor

{ - platform: P, - rx: Receiver, - tx: WeakSender, - config: ExecutorConfig, - incoming: Sender, -} - -struct ExecutorState<'a> { - pending: HashMap, - timeouts: BinaryHeap>, - outbound: VecDeque, - in_flight: Option>, -} - -struct OutboundBytes { - id: Option, - bytes: Vec, -} - -struct InFlightWrite<'a> { - id: Option, - future: PlatformFuture<'a, Result<(), QlError>>, -} - -struct PendingEntry { - tx: oneshot::Sender>, -} - -#[derive(Debug, Clone)] -struct TimeoutEntry { - deadline: Instant, - id: ARID, -} - -impl PartialEq for TimeoutEntry { - fn eq(&self, other: &Self) -> bool { - self.deadline == other.deadline - } -} - -impl Eq for TimeoutEntry {} - -impl PartialOrd for TimeoutEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for TimeoutEntry { - fn cmp(&self, other: &Self) -> Ordering { - self.deadline.cmp(&other.deadline) - } -} - -enum LoopStep { - Event(Result), - WriteDone { - id: Option, - result: Result<(), QlError>, - }, - Timeout, -} - -impl

Executor

-where - P: QlPlatform, -{ - pub fn new(platform: P, config: ExecutorConfig) -> (Self, ExecutorHandle, HandlerStream) { - let (tx, rx) = async_channel::unbounded(); - let (incoming_tx, incoming_rx) = async_channel::unbounded(); - ( - Self { - rx, - tx: tx.downgrade(), - platform, - config, - incoming: incoming_tx, - }, - ExecutorHandle { tx }, - HandlerStream { rx: incoming_rx }, - ) - } - - pub async fn run<'a>(&'a mut self) { - let mut state = ExecutorState { - pending: HashMap::new(), - timeouts: BinaryHeap::new(), - outbound: VecDeque::new(), - in_flight: None, - }; - - loop { - Self::process_timeouts(&mut state); - - if state.in_flight.is_none() { - if let Some(message) = state.outbound.pop_front() { - state.in_flight = Some(InFlightWrite { - id: message.id, - future: self.platform.write_message(message.bytes), - }); - } - } - - let step = { - let recv_future = self.rx.recv(); - futures_lite::pin!(recv_future); - - let mut sleep_future = - Self::next_timeout_sleep(&state).map(|duration| self.platform.sleep(duration)); - - futures_lite::future::poll_fn(|cx| { - if let Some(in_flight) = state.in_flight.as_mut() { - if let Poll::Ready(result) = in_flight.future.as_mut().poll(cx) { - return Poll::Ready(LoopStep::WriteDone { - id: in_flight.id, - result, - }); - } - } - - if let Some(sleep_future) = sleep_future.as_mut() { - if let Poll::Ready(_result) = sleep_future.as_mut().poll(cx) { - return Poll::Ready(LoopStep::Timeout); - } - } - - match recv_future.as_mut().poll(cx) { - Poll::Ready(event) => Poll::Ready(LoopStep::Event(event)), - Poll::Pending => Poll::Pending, - } - }) - .await - }; - - match step { - LoopStep::Event(Ok(event)) => match event { - ExecutorEvent::SendRequest { - id, - bytes, - respond_to, - config, - } => { - let effective_timeout = - config.timeout.unwrap_or(self.config.default_timeout); - if effective_timeout.is_zero() { - let _ = respond_to.send(Err(QlError::Timeout)); - continue; - } - let deadline = Instant::now() + effective_timeout; - state.pending.insert(id, PendingEntry { tx: respond_to }); - state.timeouts.push(Reverse(TimeoutEntry { deadline, id })); - state.outbound.push_back(OutboundBytes { - id: Some(id), - bytes, - }); - } - ExecutorEvent::SendEvent { bytes } => { - state.outbound.push_back(OutboundBytes { id: None, bytes }); - } - ExecutorEvent::SendResponse { bytes } => { - state.outbound.push_back(OutboundBytes { id: None, bytes }); - } - ExecutorEvent::Incoming { message } => match message.header.kind { - MessageKind::Response => { - if let Some(entry) = state.pending.remove(&message.header.id) { - let _ = entry.tx.send(Ok(message)); - } - } - MessageKind::Request => { - let Some(tx) = self.tx.upgrade() else { return }; - let responder = Responder { - id: message.header.id, - recipient: message.header.sender, - tx, - }; - let _ = self - .incoming - .send(HandlerEvent::Request(InboundRequest { - message, - respond_to: responder, - })) - .await; - } - MessageKind::Event => { - let _ = self - .incoming - .send(HandlerEvent::Event(InboundEvent { message })) - .await; - } - }, - ExecutorEvent::IncomingDecodeError { context } => { - let Some(header) = context.header else { - continue; - }; - if header.kind == MessageKind::Response { - if let Some(entry) = state.pending.remove(&header.id) { - let _ = entry.tx.send(Err(QlError::Decode(context.error))); - } - } - } - }, - LoopStep::Event(Err(_)) => break, - LoopStep::WriteDone { id, result } => { - state.in_flight = None; - if let Err(e) = result { - if let Some(id) = id { - if let Some(entry) = state.pending.remove(&id) { - let _ = entry.tx.send(Err(e)); - } - } - } - } - LoopStep::Timeout => { - Self::process_timeouts(&mut state); - } - } - } - } - - fn process_timeouts(state: &mut ExecutorState<'_>) { - let now = Instant::now(); - while let Some(Reverse(entry)) = state.timeouts.peek().cloned() { - if entry.deadline > now { - break; - } - state.timeouts.pop(); - if let Some(pending) = state.pending.remove(&entry.id) { - let _ = pending.tx.send(Err(QlError::Timeout)); - } - } - } - - fn next_timeout_sleep(state: &ExecutorState<'_>) -> Option { - let Reverse(entry) = state.timeouts.peek()?; - let now = Instant::now(); - Some(entry.deadline.saturating_duration_since(now)) - } -} - -#[cfg(test)] -mod test { - use super::*; - use bc_components::{Nonce, SymmetricKey}; - use std::time::{SystemTime, UNIX_EPOCH}; - use crate::test_identity::TestIdentity; - - struct TestPlatform { - tx: Sender>, - } - - impl TestPlatform { - fn new() -> (Self, Receiver>) { - let (tx, rx) = async_channel::unbounded(); - (Self { tx }, rx) - } - } - - impl QlPlatform for TestPlatform { - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let tx = self.tx.clone(); - Box::pin(async move { tx.send(message).await.map_err(|_| QlError::Cancelled) }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(async move { - tokio::time::sleep(duration).await; - }) - } - } - - fn now_secs() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) - } - - fn encrypt_payload(data: &str) -> EncryptedMessage { - let key = SymmetricKey::new(); - key.encrypt( - data.as_bytes(), - None::>, - None::, - ) - } - - fn encode_config(sender: XID, recipient: XID, valid_until: u64) -> EncodeQlConfig { - EncodeQlConfig { - sender, - recipient, - valid_until, - kem_ct: None, - sign_header: false, - } - } - - #[tokio::test(flavor = "current_thread")] - async fn request_response_round_trip() { - let local = tokio::task::LocalSet::new(); - local - .run_until(async { - let (platform, outbound_rx) = TestPlatform::new(); - let config = ExecutorConfig { - default_timeout: Duration::from_millis(50), - }; - let (mut core, handle, _incoming) = Executor::new(platform, config); - tokio::task::spawn_local(async move { core.run().await }); - - let requester = TestIdentity::generate(); - let responder = TestIdentity::generate(); - let recipient_xid = responder.xid; - let valid_until = now_secs().saturating_add(60); - let payload = encrypt_payload("ping"); - let request_id = ARID::new(); - - let response_task = tokio::task::spawn_local({ - let handle = handle.clone(); - let signer = requester.private_keys.clone(); - let config = encode_config(requester.xid, recipient_xid, valid_until); - async move { - handle - .request( - request_id, - payload, - config, - RequestConfig::default(), - &signer, - ) - .await - } - }); - - let outbound = outbound_rx.recv().await.expect("no outbound request"); - let outbound_message = decode_ql_message(&outbound).expect("decode outbound"); - assert_eq!(outbound_message.header.kind, MessageKind::Request); - let request_id = outbound_message.header.id; - - let response_payload = encrypt_payload("pong"); - let response_bytes = encode_ql_message( - MessageKind::Response, - request_id, - encode_config( - responder.xid, - outbound_message.header.sender, - now_secs().saturating_add(60), - ), - response_payload, - &responder.private_keys, - ); - handle.send_incoming(response_bytes).unwrap(); - - let response = response_task.await.unwrap().unwrap(); - assert_eq!(response.header.kind, MessageKind::Response); - assert_eq!(response.header.id, request_id); - }) - .await; - } - - #[tokio::test(flavor = "current_thread")] - async fn request_timeout_returns_error() { - let local = tokio::task::LocalSet::new(); - local - .run_until(async { - let (platform, _outbound_rx) = TestPlatform::new(); - let config = ExecutorConfig { - default_timeout: Duration::from_millis(5), - }; - let (mut core, handle, _incoming) = Executor::new(platform, config); - tokio::task::spawn_local(async move { core.run().await }); - - let requester = TestIdentity::generate(); - let recipient_xid = requester.xid; - let valid_until = now_secs().saturating_add(60); - let payload = encrypt_payload("timeout"); - let request_id = ARID::new(); - let result = handle - .request( - request_id, - payload, - encode_config(requester.xid, recipient_xid, valid_until), - RequestConfig { - timeout: Some(Duration::from_millis(1)), - }, - &requester.private_keys, - ) - .await; - - assert!(matches!(result, Err(QlError::Timeout))); - }) - .await; - } - - #[tokio::test(flavor = "current_thread")] - async fn event_is_forwarded() { - let local = tokio::task::LocalSet::new(); - local - .run_until(async { - let (platform, _outbound_rx) = TestPlatform::new(); - let config = ExecutorConfig { - default_timeout: Duration::from_secs(1), - }; - let (mut core, handle, mut handler_stream) = Executor::new(platform, config); - tokio::task::spawn_local(async move { core.run().await }); - - let sender = TestIdentity::generate(); - let recipient = TestIdentity::generate(); - let recipient_xid = recipient.xid; - let event_id = ARID::new(); - let payload = encrypt_payload("event"); - let event_bytes = encode_ql_message( - MessageKind::Event, - event_id, - encode_config( - sender.xid, - recipient_xid, - now_secs().saturating_add(60), - ), - payload, - &sender.private_keys, - ); - - handle.send_incoming(event_bytes).unwrap(); - - let event = handler_stream.next().await.unwrap(); - match event { - HandlerEvent::Event(event) => { - assert_eq!(event.message.header.kind, MessageKind::Event); - assert_eq!(event.message.header.id, event_id); - } - HandlerEvent::Request(_) => panic!("unexpected request"), - } - }) - .await; - } - - #[tokio::test(flavor = "current_thread")] - async fn expired_response_returns_error() { - let local = tokio::task::LocalSet::new(); - local - .run_until(async { - let (platform, outbound_rx) = TestPlatform::new(); - let config = ExecutorConfig { - default_timeout: Duration::from_secs(2), - }; - let (mut core, handle, _incoming) = Executor::new(platform, config); - tokio::task::spawn_local(async move { core.run().await }); - - let requester = TestIdentity::generate(); - let responder = TestIdentity::generate(); - let recipient_xid = responder.xid; - let valid_until = now_secs().saturating_add(60); - let payload = encrypt_payload("ping"); - let request_id = ARID::new(); - - let response_task = tokio::task::spawn_local({ - let handle = handle.clone(); - let signer = requester.private_keys.clone(); - let config = encode_config(requester.xid, recipient_xid, valid_until); - async move { - handle - .request( - request_id, - payload, - config, - RequestConfig { - timeout: Some(Duration::from_secs(3)), - }, - &signer, - ) - .await - } - }); - - let outbound = outbound_rx.recv().await.expect("no outbound request"); - let outbound_message = decode_ql_message(&outbound).expect("decode outbound"); - let request_id = outbound_message.header.id; - - let response_payload = encrypt_payload("pong"); - let response_bytes = encode_ql_message( - MessageKind::Response, - request_id, - encode_config( - responder.xid, - outbound_message.header.sender, - 0, - ), - response_payload, - &responder.private_keys, - ); - tokio::time::sleep(Duration::from_secs(1)).await; - handle.send_incoming(response_bytes).unwrap(); - - let response = response_task.await.unwrap(); - assert!(matches!(response, Err(QlError::Decode(_)))); - }) - .await; - } -} diff --git a/ql-protocol/src/lib.rs b/ql-protocol/src/lib.rs deleted file mode 100644 index 67280849..00000000 --- a/ql-protocol/src/lib.rs +++ /dev/null @@ -1,19 +0,0 @@ -pub mod executor; -pub mod typed; -pub mod wire; - -#[cfg(test)] -mod test_identity; - -pub use executor::{ - Executor, ExecutorConfig, ExecutorHandle, HandlerEvent, HandlerStream, InboundEvent, - InboundRequest, PlatformFuture, QlError, QlPlatform, RequestConfig, Responder, -}; -pub use typed::{ - Event, EventHandler, QlCodec, RequestHandler, RequestResponse, Router, RouterBuilder, - RouterError, TypedExecutorHandle, TypedPayload, TypedRequest, TypedResponder, -}; -pub use wire::{ - decode_ql_message, encode_ql_message, DecodeErrContext, DecodeError, EncodeQlConfig, - MessageKind, QlHeader, QlHeaderUnsigned, QlMessage, -}; diff --git a/ql-protocol/src/test_identity.rs b/ql-protocol/src/test_identity.rs deleted file mode 100644 index 478d294e..00000000 --- a/ql-protocol/src/test_identity.rs +++ /dev/null @@ -1,28 +0,0 @@ -use bc_components::{ - EncapsulationPublicKey, EncapsulationScheme, PrivateKeys, SignatureScheme, SigningPublicKey, - XID, -}; - -#[derive(Debug, Clone)] -pub(crate) struct TestIdentity { - pub(crate) private_keys: PrivateKeys, - pub(crate) signing_public_key: SigningPublicKey, - pub(crate) encapsulation_public_key: EncapsulationPublicKey, - pub(crate) xid: XID, -} - -impl TestIdentity { - pub(crate) fn generate() -> Self { - let (signing_private_key, signing_public_key) = SignatureScheme::MLDSA44.keypair(); - let (encapsulation_private_key, encapsulation_public_key) = - EncapsulationScheme::MLKEM512.keypair(); - let private_keys = PrivateKeys::with_keys(signing_private_key, encapsulation_private_key); - let xid = XID::new(&signing_public_key); - Self { - private_keys, - signing_public_key, - encapsulation_public_key, - xid, - } - } -} diff --git a/ql-protocol/src/typed/handle.rs b/ql-protocol/src/typed/handle.rs deleted file mode 100644 index 0c41c648..00000000 --- a/ql-protocol/src/typed/handle.rs +++ /dev/null @@ -1,202 +0,0 @@ -use std::{ - future::Future, - marker::PhantomData, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - time::{Duration, SystemTime, UNIX_EPOCH}, -}; - -use bc_components::{EncapsulationCiphertext, ARID, XID}; - -use super::{Event, RequestResponse, RouterError, RouterPlatform, TypedPayload}; -use crate::{ - executor::ExecutorResponse, EncodeQlConfig, ExecutorHandle, MessageKind, QlCodec, - QlHeaderUnsigned, RequestConfig, -}; - -#[derive(Clone)] -pub struct TypedExecutorHandle { - handle: ExecutorHandle, - platform: Arc, -} - -pub struct Response { - inner: ResponseInner, - _type: PhantomData T>, -} - -enum ResponseInner { - Err(Option), - Ok { - response: ExecutorResponse, - platform: Arc, - }, -} - -impl Future for Response -where - T: QlCodec, -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match &mut self.inner { - ResponseInner::Err(e) => { - let e = e.take(); - let e = e.unwrap_or(RouterError::Send(crate::QlError::Cancelled)); - Poll::Ready(Err(e)) - } - ResponseInner::Ok { response, platform } => { - Pin::new(response).poll(cx).map(|response| { - let response = response?; - let session_key = platform - .session_for_peer(response.header.sender) - .ok_or(RouterError::MissingSession(response.header.sender))?; - let decrypted = platform.decrypt_message( - &session_key, - &response.header.aad_data(), - &response.payload, - )?; - let message = T::try_from(decrypted)?; - Ok(message) - }) - } - } - } -} - -impl TypedExecutorHandle { - pub fn new(handle: ExecutorHandle, platform: Arc) -> Self { - Self { handle, platform } - } - - pub fn request( - &self, - message: M, - recipient: XID, - request_config: RequestConfig, - ) -> Response - where - M: RequestResponse, - { - let platform = self.platform.clone(); - let payload = TypedPayload { - message_id: M::ID, - payload: message.into(), - }; - let message_id = ARID::new(); - let inner = match self.encrypt_payload_for_recipient( - recipient, - MessageKind::Request, - message_id, - payload.into(), - ) { - Ok((encrypted, config)) => { - let response = self.handle.request( - message_id, - encrypted, - config, - request_config, - platform.signer(), - ); - - ResponseInner::Ok { - response, - platform: self.platform.clone(), - } - } - Err(e) => ResponseInner::Err(Some(e)), - }; - Response { - inner, - _type: Default::default(), - } - } - - pub fn send_event( - &self, - message: M, - recipient: XID, - _valid_for: Duration, - ) -> Result<(), RouterError> - where - M: Event, - { - let payload = TypedPayload { - message_id: M::ID, - payload: message.into(), - }; - let message_id = ARID::new(); - let (encrypted, config) = self.encrypt_payload_for_recipient( - recipient, - MessageKind::Event, - message_id, - payload.into(), - )?; - self.handle - .send_event(message_id, encrypted, config, self.platform.signer()); - Ok(()) - } - - fn encrypt_payload_for_recipient( - &self, - recipient: XID, - kind: MessageKind, - message_id: ARID, - payload: dcbor::CBOR, - ) -> Result<(bc_components::EncryptedMessage, EncodeQlConfig), RouterError> { - let platform = self.platform.as_ref(); - let (session_key, kem_ct, sign_header) = match platform.session_for_peer(recipient) { - Some(session_key) => (session_key, None, false), - None => self.create_session(recipient)?, - }; - let valid_until = now_secs().saturating_add(platform.message_expiration().as_secs()); - let header_unsigned = QlHeaderUnsigned { - kind, - id: message_id, - sender: platform.sender_xid(), - recipient, - valid_until, - kem_ct: kem_ct.clone(), - }; - let aad = header_unsigned.aad_data(); - let payload_bytes = payload.to_cbor_data(); - let encrypted = session_key.encrypt(payload_bytes, Some(aad), None::); - let config = EncodeQlConfig { - sender: platform.sender_xid(), - recipient, - valid_until, - kem_ct, - sign_header, - }; - Ok((encrypted, config)) - } - - fn create_session( - &self, - recipient: XID, - ) -> Result< - ( - bc_components::SymmetricKey, - Option, - bool, - ), - RouterError, - > { - let platform = self.platform.as_ref(); - let recipient_key = platform - .lookup_recipient(recipient) - .ok_or(RouterError::UnknownRecipient(recipient))?; - let (session_key, kem_ct) = recipient_key.encapsulate_new_shared_secret(); - platform.store_session(recipient, session_key.clone()); - Ok((session_key, Some(kem_ct), true)) - } -} - -fn now_secs() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) -} diff --git a/ql-protocol/src/typed/mod.rs b/ql-protocol/src/typed/mod.rs deleted file mode 100644 index 1c0efefc..00000000 --- a/ql-protocol/src/typed/mod.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::time::Duration; - -use bc_components::{ - EncapsulationCiphertext, EncapsulationPrivateKey, EncapsulationPublicKey, - EncryptedMessage, Signer, SigningPublicKey, SymmetricKey, XID, -}; -use dcbor::CBOR; - -use crate::QlError; - -pub mod handle; -pub mod router; - -pub trait QlCodec: Into + TryFrom + Sized {} - -impl QlCodec for T where T: Into + TryFrom + Sized {} - -pub trait RequestResponse: QlCodec { - const ID: u64; - type Response: QlCodec; -} - -pub trait Event: QlCodec { - const ID: u64; -} - -#[derive(Debug, Clone)] -pub struct TypedPayload { - pub message_id: u64, - pub payload: CBOR, -} - -impl From for CBOR { - fn from(value: TypedPayload) -> Self { - CBOR::from(vec![CBOR::from(value.message_id), value.payload]) - } -} - -impl TryFrom for TypedPayload { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let array = value.try_into_array()?; - if array.len() != 2 { - return Err(dcbor::Error::msg("invalid typed payload length")); - } - let message_id: u64 = array[0].clone().try_into()?; - Ok(Self { - message_id, - payload: array[1].clone(), - }) - } -} - -#[derive(Debug)] -pub enum RouterError { - Decode(dcbor::Error), - InvalidPayload, - InvalidSignature, - MissingHandler(u64), - MissingSession(XID), - Send(QlError), - UnknownRecipient(XID), - UnknownSender(XID), -} - -impl From for RouterError { - fn from(error: dcbor::Error) -> Self { - Self::Decode(error) - } -} - -impl From for RouterError { - fn from(error: QlError) -> Self { - Self::Send(error) - } -} - -pub trait RouterPlatform { - fn lookup_recipient(&self, recipient: XID) -> Option<&EncapsulationPublicKey>; - fn lookup_signing_key(&self, sender: XID) -> Option<&SigningPublicKey>; - fn session_for_peer(&self, peer: XID) -> Option; - fn store_session(&self, peer: XID, key: SymmetricKey); - fn encapsulation_private_key(&self) -> EncapsulationPrivateKey; - fn signing_key(&self) -> &SigningPublicKey; - fn message_expiration(&self) -> Duration; - fn signer(&self) -> &dyn Signer; - fn handle_error(&self, e: RouterError); - - fn sender_xid(&self) -> XID { - XID::new(self.signing_key()) - } - - fn decapsulate_shared_secret( - &self, - ciphertext: &EncapsulationCiphertext, - ) -> Result { - self.encapsulation_private_key() - .decapsulate_shared_secret(ciphertext) - .map_err(|_| RouterError::InvalidPayload) - } - - fn decrypt_message( - &self, - key: &SymmetricKey, - header_aad: &[u8], - payload: &EncryptedMessage, - ) -> Result { - if payload.aad() != header_aad { - return Err(RouterError::InvalidPayload); - } - let plaintext = key.decrypt(payload).map_err(|_| RouterError::InvalidPayload)?; - CBOR::try_from_data(plaintext).map_err(RouterError::Decode) - } -} - -pub use handle::TypedExecutorHandle; -pub use router::{ - EventHandler, RequestHandler, Router, RouterBuilder, TypedRequest, TypedResponder, -}; - -#[cfg(test)] -mod test; diff --git a/ql-protocol/src/typed/router.rs b/ql-protocol/src/typed/router.rs deleted file mode 100644 index 31784b83..00000000 --- a/ql-protocol/src/typed/router.rs +++ /dev/null @@ -1,297 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use bc_components::{Verifier, XID}; - -use super::{Event, QlCodec, RequestResponse, RouterError, RouterPlatform, TypedPayload}; -use crate::{EncodeQlConfig, HandlerEvent, MessageKind, QlHeader, QlHeaderUnsigned, Responder}; - -pub trait RequestHandler -where - M: RequestResponse, -{ - fn handle(&mut self, request: TypedRequest); - fn default_response() -> M::Response; -} - -pub trait EventHandler -where - M: Event, -{ - fn handle(&mut self, event: M); -} - -pub struct TypedRequest -where - M: RequestResponse, -{ - pub message: M, - pub responder: TypedResponder, -} - -pub struct TypedResponder -where - R: QlCodec, -{ - responder: Option, - platform: Arc, - recipient: XID, - default: fn() -> R, -} - -impl TypedResponder -where - R: QlCodec, -{ - pub fn respond(mut self, response: R) -> Result<(), RouterError> { - self.respond_inner(response) - } - - fn respond_inner(&mut self, response: R) -> Result<(), RouterError> { - let responder = self.responder.take().unwrap(); - let payload = response.into(); - let session_key = self - .platform - .session_for_peer(self.recipient) - .ok_or(RouterError::MissingSession(self.recipient))?; - let now = now_secs(); - let valid_until = now.saturating_add(self.platform.message_expiration().as_secs()); - let header_unsigned = QlHeaderUnsigned { - kind: MessageKind::Response, - id: responder.id(), - sender: self.platform.sender_xid(), - recipient: self.recipient, - valid_until, - kem_ct: None, - }; - let aad = header_unsigned.aad_data(); - let payload_bytes = dcbor::CBOR::from(payload).to_cbor_data(); - let encrypted = session_key.encrypt(payload_bytes, Some(aad), None::); - let config = EncodeQlConfig { - sender: self.platform.sender_xid(), - recipient: self.recipient, - valid_until, - kem_ct: None, - sign_header: false, - }; - responder.respond(encrypted, config, self.platform.signer())?; - Ok(()) - } -} - -impl Drop for TypedResponder -where - R: QlCodec, -{ - fn drop(&mut self) { - if self.responder.is_some() { - let default = (self.default)(); - let _ = self.respond_inner(default); - } - } -} - -type RouterHandler = fn(&mut S, RouterEvent, Arc) -> Result<(), RouterError>; - -enum RouterEvent { - Event { - #[allow(unused)] - header: QlHeader, - payload: TypedPayload, - }, - Request { - header: QlHeader, - payload: TypedPayload, - responder: Responder, - }, -} - -impl RouterEvent { - fn message_id(&self) -> u64 { - match self { - RouterEvent::Event { payload, .. } => payload.message_id, - RouterEvent::Request { payload, .. } => payload.message_id, - } - } -} - -pub struct RouterBuilder { - handlers: HashMap>, -} - -impl Default for RouterBuilder { - fn default() -> Self { - Self { - handlers: HashMap::new(), - } - } -} - -impl RouterBuilder { - pub fn new() -> Self { - Self::default() - } - - pub fn add_request_handler(mut self) -> Self - where - M: RequestResponse, - S: RequestHandler, - { - self.handlers.insert(M::ID, handle_request::); - self - } - - pub fn add_event_handler(mut self) -> Self - where - M: Event, - S: EventHandler, - { - self.handlers.insert(M::ID, handle_event::); - self - } - - pub fn build(self, platform: Arc) -> Router { - Router { - platform, - handlers: self.handlers, - } - } -} - -pub struct Router { - platform: Arc, - handlers: HashMap>, -} - -impl Router { - pub fn builder() -> RouterBuilder { - RouterBuilder::new() - } - - pub fn handle(&self, state: &mut S, event: HandlerEvent) -> Result<(), RouterError> { - let event = decrypt_event(event, self.platform.as_ref())?; - let message_id = event.message_id(); - let handler = self - .handlers - .get(&message_id) - .ok_or(RouterError::MissingHandler(message_id))?; - handler(state, event, self.platform.clone()) - } -} - -fn handle_request( - state: &mut S, - event: RouterEvent, - platform: Arc, -) -> Result<(), RouterError> -where - M: RequestResponse, - S: RequestHandler, -{ - let (header, payload, responder) = match event { - RouterEvent::Request { - header, - payload, - responder, - } => (header, payload, responder), - RouterEvent::Event { .. } => unreachable!("expected request event"), - }; - let message = M::try_from(payload.payload)?; - let responder = TypedResponder { - responder: Some(responder), - platform, - recipient: header.sender, - default: S::default_response, - }; - state.handle(TypedRequest { message, responder }); - Ok(()) -} - -fn handle_event( - state: &mut S, - event: RouterEvent, - _platform: Arc, -) -> Result<(), RouterError> -where - M: Event, - S: EventHandler, -{ - let payload = match event { - RouterEvent::Event { payload, .. } => payload, - RouterEvent::Request { .. } => unreachable!("expected event"), - }; - let message = M::try_from(payload.payload)?; - state.handle(message); - Ok(()) -} - -fn decrypt_event( - event: HandlerEvent, - platform: &dyn RouterPlatform, -) -> Result { - match event { - HandlerEvent::Request(request) => { - verify_header(platform, &request.message.header)?; - let payload = - extract_typed_payload(platform, &request.message.header, request.message.payload)?; - Ok(RouterEvent::Request { - header: request.message.header, - payload, - responder: request.respond_to, - }) - } - HandlerEvent::Event(event) => { - verify_header(platform, &event.message.header)?; - let payload = - extract_typed_payload(platform, &event.message.header, event.message.payload)?; - Ok(RouterEvent::Event { - header: event.message.header, - payload, - }) - } - } -} - -fn verify_header(platform: &dyn RouterPlatform, header: &QlHeader) -> Result<(), RouterError> { - if header.kem_ct.is_none() { - return Ok(()); - } - let signature = header - .signature - .as_ref() - .ok_or(RouterError::InvalidSignature)?; - let signing_key = platform - .lookup_signing_key(header.sender) - .ok_or(RouterError::UnknownSender(header.sender))?; - if signing_key.verify(signature, &header.unsigned().aad_data()) { - Ok(()) - } else { - Err(RouterError::InvalidSignature) - } -} - -fn extract_typed_payload( - platform: &dyn RouterPlatform, - header: &QlHeader, - payload: bc_components::EncryptedMessage, -) -> Result { - let session_key = if let Some(kem_ct) = &header.kem_ct { - let key = platform.decapsulate_shared_secret(kem_ct)?; - platform.store_session(header.sender, key.clone()); - key - } else { - platform - .session_for_peer(header.sender) - .ok_or(RouterError::MissingSession(header.sender))? - }; - let decrypted = platform.decrypt_message(&session_key, &header.aad_data(), &payload)?; - TypedPayload::try_from(decrypted).map_err(RouterError::Decode) -} - -fn now_secs() -> u64 { - use std::time::{SystemTime, UNIX_EPOCH}; - - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) -} diff --git a/ql-protocol/src/typed/test.rs b/ql-protocol/src/typed/test.rs deleted file mode 100644 index c21e5104..00000000 --- a/ql-protocol/src/typed/test.rs +++ /dev/null @@ -1,287 +0,0 @@ -use std::{collections::HashMap, sync::{Arc, Mutex}, time::Duration}; - -use async_channel::{Receiver, Sender}; -use oneshot; -use bc_components::{ - Decrypter, EncapsulationPrivateKey, EncapsulationPublicKey, Signer, SigningPublicKey, - SymmetricKey, XID, -}; -use dcbor::CBOR; - -use super::{ - Event, EventHandler, RequestHandler, RequestResponse, Router, RouterPlatform, - TypedExecutorHandle, TypedRequest, -}; -use crate::{ - test_identity::TestIdentity, Executor, ExecutorConfig, PlatformFuture, QlError, QlPlatform, - RequestConfig, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct Ping(u64); - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct Pong(u64); - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct Notice(u64); - -impl From for CBOR { - fn from(value: Ping) -> Self { - CBOR::from(value.0) - } -} - -impl TryFrom for Ping { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let value: u64 = value.try_into()?; - Ok(Self(value)) - } -} - -impl From for CBOR { - fn from(value: Pong) -> Self { - CBOR::from(value.0) - } -} - -impl TryFrom for Pong { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let value: u64 = value.try_into()?; - Ok(Self(value)) - } -} - -impl From for CBOR { - fn from(value: Notice) -> Self { - CBOR::from(value.0) - } -} - -impl TryFrom for Notice { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let value: u64 = value.try_into()?; - Ok(Self(value)) - } -} - -impl RequestResponse for Ping { - const ID: u64 = 100; - type Response = Pong; -} - -impl Event for Notice { - const ID: u64 = 200; -} - -struct TestPlatform { - tx: Sender>, -} - -impl TestPlatform { - fn new() -> (Self, Receiver>) { - let (tx, rx) = async_channel::unbounded(); - (Self { tx }, rx) - } -} - -impl QlPlatform for TestPlatform { - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let tx = self.tx.clone(); - Box::pin(async move { tx.send(message).await.map_err(|_| QlError::Cancelled) }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(async move { tokio::time::sleep(duration).await }) - } -} - -struct TestRouterPlatform { - identity: TestIdentity, - peer: EncapsulationPublicKey, - peer_signing_key: SigningPublicKey, - sessions: Mutex>, -} - -impl TestRouterPlatform { - fn new( - identity: TestIdentity, - peer: EncapsulationPublicKey, - peer_signing_key: SigningPublicKey, - ) -> Self { - Self { - identity, - peer, - peer_signing_key, - sessions: Mutex::new(HashMap::new()), - } - } - - fn xid(&self) -> XID { - self.identity.xid - } -} - -impl RouterPlatform for TestRouterPlatform { - fn lookup_recipient(&self, _recipient: XID) -> Option<&EncapsulationPublicKey> { - Some(&self.peer) - } - - fn lookup_signing_key(&self, sender: XID) -> Option<&SigningPublicKey> { - if sender == XID::new(&self.peer_signing_key) { - Some(&self.peer_signing_key) - } else { - None - } - } - - fn session_for_peer(&self, peer: XID) -> Option { - self.sessions.lock().ok()?.get(&peer).cloned() - } - - fn store_session(&self, peer: XID, key: SymmetricKey) { - if let Ok(mut sessions) = self.sessions.lock() { - sessions.insert(peer, key); - } - } - - fn encapsulation_private_key(&self) -> EncapsulationPrivateKey { - self.identity.private_keys.encapsulation_private_key() - } - - fn signing_key(&self) -> &SigningPublicKey { - &self.identity.signing_public_key - } - - fn message_expiration(&self) -> Duration { - Duration::from_secs(60) - } - - fn signer(&self) -> &dyn Signer { - &self.identity.private_keys - } - - fn handle_error(&self, _e: super::RouterError) {} -} - -struct TestState { - event_tx: Option>, -} - -impl RequestHandler for TestState { - fn handle(&mut self, request: TypedRequest) { - let response = Pong(request.message.0 + 1); - let _ = request.responder.respond(response); - } - - fn default_response() -> Pong { - Pong(0) - } -} - -impl EventHandler for TestState { - fn handle(&mut self, event: Notice) { - if let Some(tx) = self.event_tx.take() { - let _ = tx.send(event.0); - } - } -} - -#[tokio::test(flavor = "current_thread")] -async fn typed_round_trip() { - let local = tokio::task::LocalSet::new(); - local - .run_until(async { - let (client_platform, client_outbound) = TestPlatform::new(); - let (server_platform, server_outbound) = TestPlatform::new(); - let config = ExecutorConfig { - default_timeout: Duration::from_secs(1), - }; - - let (mut client_core, client_handle, _client_incoming) = - Executor::new(client_platform, config); - let (mut server_core, server_handle, mut server_incoming) = - Executor::new(server_platform, config); - - tokio::task::spawn_local(async move { client_core.run().await }); - tokio::task::spawn_local(async move { server_core.run().await }); - - tokio::task::spawn_local({ - let server_handle = server_handle.clone(); - async move { - while let Ok(bytes) = client_outbound.recv().await { - server_handle.send_incoming(bytes).unwrap(); - } - } - }); - - tokio::task::spawn_local({ - let client_handle = client_handle.clone(); - async move { - while let Ok(bytes) = server_outbound.recv().await { - client_handle.send_incoming(bytes).unwrap(); - } - } - }); - - let client_identity = TestIdentity::generate(); - let server_identity = TestIdentity::generate(); - let client_platform = Arc::new(TestRouterPlatform::new( - client_identity.clone(), - server_identity.encapsulation_public_key.clone(), - server_identity.signing_public_key.clone(), - )); - let server_platform = Arc::new(TestRouterPlatform::new( - server_identity.clone(), - client_identity.encapsulation_public_key.clone(), - client_identity.signing_public_key.clone(), - )); - let recipient = server_platform.xid(); - - let router = Router::builder() - .add_request_handler::() - .add_event_handler::() - .build(server_platform.clone()); - - let (event_tx, event_rx) = oneshot::channel(); - let mut state = TestState { - event_tx: Some(event_tx), - }; - - tokio::task::spawn_local({ - let server_platform = server_platform.clone(); - async move { - loop { - let event = match server_incoming.next().await { - Ok(event) => event, - Err(_) => break, - }; - if let Err(err) = router.handle(&mut state, event) { - server_platform.handle_error(err); - } - } - } - }); - - let client_typed = TypedExecutorHandle::new(client_handle, client_platform); - - client_typed - .send_event(Notice(7), recipient, Duration::from_secs(60)) - .unwrap(); - let event_value = event_rx.await.expect("event handled"); - assert_eq!(event_value, 7); - - let response = client_typed - .request(Ping(41), recipient, RequestConfig::default()) - .await - .expect("response"); - assert_eq!(response, Pong(42)); - }) - .await; -} diff --git a/ql-protocol/src/wire.rs b/ql-protocol/src/wire.rs deleted file mode 100644 index 8220ee86..00000000 --- a/ql-protocol/src/wire.rs +++ /dev/null @@ -1,434 +0,0 @@ -use std::time::{SystemTime, UNIX_EPOCH}; - -use bc_components::{EncapsulationCiphertext, EncryptedMessage, Signature, Signer, ARID, XID}; -use dcbor::CBOR; -use thiserror::Error; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MessageKind { - Request, - Response, - Event, -} - -#[derive(Debug, Clone)] -pub struct QlHeader { - pub kind: MessageKind, - pub id: ARID, - pub sender: XID, - pub recipient: XID, - pub valid_until: u64, - pub kem_ct: Option, - pub signature: Option, -} - -#[derive(Debug, Clone)] -pub struct QlHeaderUnsigned { - pub kind: MessageKind, - pub id: ARID, - pub sender: XID, - pub recipient: XID, - pub valid_until: u64, - pub kem_ct: Option, -} - -#[derive(Debug, Clone)] -pub struct EncodeQlConfig { - pub sender: XID, - pub recipient: XID, - pub valid_until: u64, - pub kem_ct: Option, - pub sign_header: bool, -} - -impl QlHeader { - pub fn unsigned(&self) -> QlHeaderUnsigned { - QlHeaderUnsigned { - kind: self.kind, - id: self.id, - sender: self.sender, - recipient: self.recipient, - valid_until: self.valid_until, - kem_ct: self.kem_ct.clone(), - } - } - - pub fn aad_data(&self) -> Vec { - CBOR::from(self.unsigned()).to_cbor_data() - } -} - -impl QlHeaderUnsigned { - pub fn aad_data(&self) -> Vec { - CBOR::from(self.clone()).to_cbor_data() - } -} - -impl From for dcbor::CBOR { - fn from(value: QlHeader) -> Self { - dcbor::CBOR::from(vec![ - dcbor::CBOR::from(value.kind), - dcbor::CBOR::from(value.id), - dcbor::CBOR::from(value.sender), - dcbor::CBOR::from(value.recipient), - dcbor::CBOR::from(value.valid_until), - option_to_cbor(value.kem_ct), - option_to_cbor(value.signature), - ]) - } -} - -impl From for dcbor::CBOR { - fn from(value: QlHeaderUnsigned) -> Self { - dcbor::CBOR::from(vec![ - dcbor::CBOR::from(value.kind), - dcbor::CBOR::from(value.id), - dcbor::CBOR::from(value.sender), - dcbor::CBOR::from(value.recipient), - dcbor::CBOR::from(value.valid_until), - option_to_cbor(value.kem_ct), - ]) - } -} - -impl TryFrom for QlHeader { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let array = value.try_into_array()?; - if array.len() != 7 { - return Err(dcbor::Error::msg("invalid header length")); - } - let kind = MessageKind::try_from(array[0].clone())?; - let id: ARID = array[1].clone().try_into()?; - let sender: XID = array[2].clone().try_into()?; - let recipient: XID = array[3].clone().try_into()?; - let valid_until: u64 = array[4].clone().try_into()?; - let kem_ct: Option = option_from_cbor(array[5].clone())?; - let signature: Option = option_from_cbor(array[6].clone())?; - Ok(Self { - kind, - id, - sender, - recipient, - valid_until, - kem_ct, - signature, - }) - } -} - -fn option_to_cbor(value: Option) -> CBOR -where - T: Into, -{ - value.map(Into::into).unwrap_or_else(CBOR::null) -} - -fn option_from_cbor(value: CBOR) -> dcbor::Result> -where - T: TryFrom, -{ - if value.is_null() { - Ok(None) - } else { - Ok(Some(value.try_into()?)) - } -} - -#[derive(Debug, Error)] -pub enum DecodeError { - #[error("invalid message encoding")] - InvalidEncoding, - #[error("message expired")] - Expired, - #[error(transparent)] - Cbor(#[from] dcbor::Error), -} - -#[derive(Debug)] -pub struct DecodeErrContext { - pub error: DecodeError, - pub header: Option, -} - -#[derive(Debug, Clone)] -pub struct QlMessage { - pub header: QlHeader, - pub payload: EncryptedMessage, -} - -pub fn encode_ql_message( - kind: MessageKind, - id: ARID, - config: EncodeQlConfig, - payload: EncryptedMessage, - signer: &dyn Signer, -) -> Vec { - let header_unsigned = QlHeaderUnsigned { - kind, - id, - sender: config.sender, - recipient: config.recipient, - valid_until: config.valid_until, - kem_ct: config.kem_ct.clone(), - }; - let signature = if config.sign_header { - Some( - signer - .sign(&header_unsigned.aad_data()) - .expect("failed to sign header"), - ) - } else { - None - }; - let header = QlHeader { - kind, - id, - sender: config.sender, - recipient: config.recipient, - valid_until: config.valid_until, - kem_ct: config.kem_ct, - signature, - }; - let cbor = CBOR::from(vec![CBOR::from(header), CBOR::from(payload)]); - cbor.to_cbor_data() -} - -pub fn decode_ql_message(bytes: &[u8]) -> Result { - let cbor = dcbor::CBOR::try_from_data(bytes).map_err(|error| DecodeErrContext { - error: DecodeError::Cbor(error), - header: None, - })?; - let array = cbor.try_into_array().map_err(|error| DecodeErrContext { - error: DecodeError::Cbor(error), - header: None, - })?; - if array.len() != 2 { - return Err(DecodeErrContext { - error: DecodeError::InvalidEncoding, - header: None, - }); - } - let header = QlHeader::try_from(array[0].clone()).map_err(|error| DecodeErrContext { - error: DecodeError::Cbor(error), - header: None, - })?; - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0); - if now > header.valid_until { - return Err(DecodeErrContext { - error: DecodeError::Expired, - header: Some(header), - }); - } - let payload: EncryptedMessage = - array[1] - .clone() - .try_into() - .map_err(|error| DecodeErrContext { - error: DecodeError::Cbor(error), - header: Some(header.clone()), - })?; - Ok(QlMessage { header, payload }) -} - -impl TryFrom for MessageKind { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let kind: u64 = value.try_into()?; - match kind { - 1 => Ok(MessageKind::Request), - 2 => Ok(MessageKind::Response), - 3 => Ok(MessageKind::Event), - _ => Err(dcbor::Error::msg("unknown message kind")), - } - } -} - -impl From for CBOR { - fn from(value: MessageKind) -> Self { - let kind = match value { - MessageKind::Request => 1, - MessageKind::Response => 2, - MessageKind::Event => 3, - }; - CBOR::from(kind) - } -} - -#[cfg(test)] -mod tests { - use bc_components::Verifier; - - use super::*; - use crate::test_identity::TestIdentity; - - #[test] - fn round_trip() { - let sender = TestIdentity::generate(); - let recipient = TestIdentity::generate(); - let recipient_xid = recipient.xid; - let sender_xid = sender.xid; - let header_id = ARID::new(); - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0); - let valid_until = now.saturating_add(60); - let (session_key, kem_ct) = recipient - .encapsulation_public_key - .encapsulate_new_shared_secret(); - let header_unsigned = QlHeaderUnsigned { - kind: MessageKind::Request, - id: header_id, - sender: sender_xid, - recipient: recipient_xid, - valid_until, - kem_ct: Some(kem_ct.clone()), - }; - let payload = CBOR::from("secret"); - let payload_bytes = payload.to_cbor_data(); - let encrypted_payload = session_key.encrypt( - payload_bytes, - Some(header_unsigned.aad_data()), - None::, - ); - - let bytes = encode_ql_message( - MessageKind::Request, - header_id, - EncodeQlConfig { - sender: sender_xid, - recipient: recipient_xid, - valid_until, - kem_ct: Some(kem_ct), - sign_header: true, - }, - encrypted_payload, - &sender.private_keys, - ); - let decoded = decode_ql_message(&bytes).expect("decode failed"); - - assert_eq!(decoded.header.kind, MessageKind::Request); - assert_eq!(decoded.header.id, header_id); - assert_eq!(decoded.header.recipient, recipient_xid); - assert_eq!(decoded.header.sender, sender_xid); - - let signing_data = decoded.header.unsigned().aad_data(); - let signature = decoded.header.signature.as_ref().expect("signature"); - assert!(sender.signing_public_key.verify(signature, &signing_data)); - - let decrypted = session_key.decrypt(&decoded.payload).expect("decrypt"); - let decrypted_cbor = CBOR::try_from_data(decrypted).expect("cbor"); - assert_eq!(decrypted_cbor, payload); - } - - #[test] - fn header_size() { - let size = std::mem::size_of::(); - println!("header size: {} bytes", size); - assert!(size > 0); - } - - #[test] - fn encoded_message_size() { - let sender = TestIdentity::generate(); - let recipient = TestIdentity::generate(); - let recipient_xid = recipient.xid; - let sender_xid = sender.xid; - let header_id = ARID::new(); - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0); - let valid_until = now.saturating_add(60); - let (session_key, kem_ct) = recipient - .encapsulation_public_key - .encapsulate_new_shared_secret(); - let header_unsigned = QlHeaderUnsigned { - kind: MessageKind::Request, - id: header_id, - sender: sender_xid, - recipient: recipient_xid, - valid_until, - kem_ct: Some(kem_ct.clone()), - }; - let payload = CBOR::from("size"); - let payload_bytes = payload.to_cbor_data(); - let encrypted_payload = session_key.encrypt( - payload_bytes, - Some(header_unsigned.aad_data()), - None::, - ); - - let bytes = encode_ql_message( - MessageKind::Request, - header_id, - EncodeQlConfig { - sender: sender_xid, - recipient: recipient_xid, - valid_until, - kem_ct: Some(kem_ct), - sign_header: true, - }, - encrypted_payload, - &sender.private_keys, - ); - - println!("encoded message size: {} bytes", bytes.len()); - assert!(!bytes.is_empty()); - } - - #[test] - fn steady_state_message_size() { - let sender = TestIdentity::generate(); - let recipient = TestIdentity::generate(); - let recipient_xid = recipient.xid; - let sender_xid = sender.xid; - let header_id = ARID::new(); - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0); - let valid_until = now.saturating_add(60); - let (session_key, _kem_ct) = recipient - .encapsulation_public_key - .encapsulate_new_shared_secret(); - let header_unsigned = QlHeaderUnsigned { - kind: MessageKind::Request, - id: header_id, - sender: sender_xid, - recipient: recipient_xid, - valid_until, - kem_ct: None, - }; - let payload = CBOR::from("steady"); - let payload_bytes = payload.to_cbor_data(); - let encrypted_payload = session_key.encrypt( - payload_bytes, - Some(header_unsigned.aad_data()), - None::, - ); - - let bytes = encode_ql_message( - MessageKind::Request, - header_id, - EncodeQlConfig { - sender: sender_xid, - recipient: recipient_xid, - valid_until, - kem_ct: None, - sign_header: false, - }, - encrypted_payload, - &sender.private_keys, - ); - - println!("steady-state message size: {} bytes", bytes.len()); - assert!(!bytes.is_empty()); - } -} diff --git a/ql-protocol/Cargo.toml b/ql/Cargo.toml similarity index 67% rename from ql-protocol/Cargo.toml rename to ql/Cargo.toml index e741fce4..315382eb 100644 --- a/ql-protocol/Cargo.toml +++ b/ql/Cargo.toml @@ -1,8 +1,8 @@ [package] -name = "ql-protocol" +name = "ql" version = "0.1.0" edition = "2021" -description = "Quantum Link protocol primitives." +description = "Quantum Link handshake prototype" license = "Proprietary" [dependencies] @@ -13,7 +13,7 @@ bc-components = { version = "0.28.0", default-features = false, features = [ dcbor = { version = "0.23.3" } futures-lite = { version = "2.5" } oneshot = { version = "0.1.11" } -thiserror = "2" +thiserror = { version = "2" } [dev-dependencies] -tokio = { version = "1", features = ["rt", "time", "macros"] } +tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } diff --git a/ql/src/crypto/handshake.rs b/ql/src/crypto/handshake.rs new file mode 100644 index 00000000..74842230 --- /dev/null +++ b/ql/src/crypto/handshake.rs @@ -0,0 +1,140 @@ +use bc_components::{ + Digest, EncapsulationCiphertext, EncapsulationPublicKey, Nonce, SigningPublicKey, SymmetricKey, + XID, +}; +use dcbor::CBOR; + +use crate::{ + platform::QlPlatform, + wire::handshake::{verify_transcript_signature, Confirm, Hello, HelloReply}, + QlError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ResponderSecrets { + pub initiator_secret: SymmetricKey, + pub responder_secret: SymmetricKey, +} + +pub fn build_hello( + platform: &impl QlPlatform, + _sender: XID, + _recipient: XID, + recipient_encapsulation_key: &EncapsulationPublicKey, +) -> Result<(Hello, SymmetricKey), QlError> { + let nonce = next_nonce(platform); + let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); + Ok((Hello { nonce, kem_ct }, session_key)) +} + +pub fn respond_hello( + platform: &impl QlPlatform, + initiator: XID, + responder: XID, + initiator_encapsulation_key: &EncapsulationPublicKey, + hello: &Hello, +) -> Result<(HelloReply, ResponderSecrets), QlError> { + let initiator_secret = platform + .encapsulation_private_key() + .decapsulate_shared_secret(&hello.kem_ct) + .map_err(|_| QlError::InvalidPayload)?; + let nonce = next_nonce(platform); + let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(); + let transcript = handshake_transcript(initiator, responder, hello, &nonce, &kem_ct); + let signature = platform + .signer() + .sign(&transcript) + .map_err(|_| QlError::InvalidPayload)?; + let reply = HelloReply { + nonce, + kem_ct, + signature, + }; + Ok(( + reply, + ResponderSecrets { + initiator_secret, + responder_secret, + }, + )) +} + +pub fn build_confirm( + platform: &impl QlPlatform, + initiator: XID, + responder: XID, + responder_signing_key: &SigningPublicKey, + hello: &Hello, + reply: &HelloReply, + initiator_secret: &SymmetricKey, +) -> Result<(Confirm, SymmetricKey), QlError> { + let transcript = handshake_transcript(initiator, responder, hello, &reply.nonce, &reply.kem_ct); + verify_transcript_signature(responder_signing_key, &reply.signature, &transcript)?; + let responder_secret = platform + .encapsulation_private_key() + .decapsulate_shared_secret(&reply.kem_ct) + .map_err(|_| QlError::InvalidPayload)?; + let signature = platform + .signer() + .sign(&transcript) + .map_err(|_| QlError::InvalidPayload)?; + let confirm = Confirm { signature }; + let session_key = derive_session_key(initiator_secret, &responder_secret, &transcript); + Ok((confirm, session_key)) +} + +pub fn finalize_confirm( + initiator: XID, + responder: XID, + initiator_signing_key: &SigningPublicKey, + hello: &Hello, + reply: &HelloReply, + confirm: &Confirm, + secrets: &ResponderSecrets, +) -> Result { + let transcript = handshake_transcript(initiator, responder, hello, &reply.nonce, &reply.kem_ct); + verify_transcript_signature(initiator_signing_key, &confirm.signature, &transcript)?; + Ok(derive_session_key( + &secrets.initiator_secret, + &secrets.responder_secret, + &transcript, + )) +} +fn handshake_transcript( + initiator: XID, + responder: XID, + hello: &Hello, + responder_nonce: &bc_components::Nonce, + responder_kem_ct: &EncapsulationCiphertext, +) -> Vec { + CBOR::from(vec![ + CBOR::from(initiator), + CBOR::from(responder), + CBOR::from(hello.nonce.clone()), + CBOR::from(responder_nonce.clone()), + CBOR::from(hello.kem_ct.clone()), + CBOR::from(responder_kem_ct.clone()), + ]) + .to_cbor_data() +} + +fn next_nonce(platform: &impl QlPlatform) -> Nonce { + let mut data = [0u8; Nonce::NONCE_SIZE]; + platform.fill_bytes(&mut data); + Nonce::from_data(data) +} + +fn derive_session_key( + initiator_secret: &SymmetricKey, + responder_secret: &SymmetricKey, + transcript: &[u8], +) -> SymmetricKey { + let payload = CBOR::from(vec![ + CBOR::from(initiator_secret.as_bytes()), + CBOR::from(responder_secret.as_bytes()), + CBOR::from(transcript), + ]) + .to_cbor_data(); + let digest = Digest::from_image(payload); + SymmetricKey::from_data(*digest.data()) +} diff --git a/ql/src/crypto/heartbeat.rs b/ql/src/crypto/heartbeat.rs new file mode 100644 index 00000000..0949e974 --- /dev/null +++ b/ql/src/crypto/heartbeat.rs @@ -0,0 +1,40 @@ +use bc_components::{Nonce, SymmetricKey}; +use dcbor::CBOR; + +use crate::{ + crypto::ensure_not_expired, + wire::{heartbeat::HeartbeatBody, QlHeader, QlPayload, QlRecord}, + QlError, +}; + +pub fn encrypt_heartbeat( + header: QlHeader, + session_key: &SymmetricKey, + body: HeartbeatBody, +) -> QlRecord { + let aad = header.aad(); + let body_bytes = CBOR::from(body).to_cbor_data(); + let encrypted = session_key.encrypt(body_bytes, Some(aad), None::); + QlRecord { + header, + payload: QlPayload::Heartbeat(encrypted), + } +} + +pub fn decrypt_heartbeat( + header: &QlHeader, + encrypted: &bc_components::EncryptedMessage, + session_key: &SymmetricKey, +) -> Result { + let aad = header.aad(); + if encrypted.aad() != aad { + return Err(QlError::InvalidPayload); + } + let plaintext = session_key + .decrypt(encrypted) + .map_err(|_| QlError::InvalidPayload)?; + let cbor = CBOR::try_from_data(plaintext).map_err(|_| QlError::InvalidPayload)?; + let body = HeartbeatBody::try_from(cbor).map_err(|_| QlError::InvalidPayload)?; + ensure_not_expired(body.message_id, body.valid_until)?; + Ok(body) +} diff --git a/ql/src/crypto/message.rs b/ql/src/crypto/message.rs new file mode 100644 index 00000000..aabc5837 --- /dev/null +++ b/ql/src/crypto/message.rs @@ -0,0 +1,74 @@ +use bc_components::{Nonce, SymmetricKey}; +use dcbor::CBOR; + +use crate::{ + crypto::ensure_not_expired, + wire::{ + message::{DecryptedMessage, MessageBody, MessageKind, Nack}, + QlHeader, QlPayload, QlRecord, + }, + MessageId, QlError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MessageError { + Nack { + id: MessageId, + nack: Nack, + kind: MessageKind, + }, + Error(QlError), +} + +impl From for MessageError { + fn from(value: QlError) -> Self { + Self::Error(value) + } +} + +pub fn encrypt_message( + header: QlHeader, + session_key: &SymmetricKey, + body: MessageBody, +) -> QlRecord { + let aad = CBOR::from(header.clone()).to_cbor_data(); + let body_bytes = CBOR::from(body).to_cbor_data(); + let encrypted = session_key.encrypt(body_bytes, Some(aad), None::); + QlRecord { + header, + payload: QlPayload::Message(encrypted), + } +} + +pub fn decrypt_message( + header: &QlHeader, + encrypted: &bc_components::EncryptedMessage, + session_key: &SymmetricKey, +) -> Result { + let aad = header.aad(); + if encrypted.aad() != aad { + return Err(QlError::InvalidPayload.into()); + } + let body = decrypt_body(session_key, encrypted)?; + ensure_not_expired(body.message_id, body.valid_until)?; + Ok(DecryptedMessage { + sender: header.sender, + recipient: header.recipient, + kind: body.kind, + message_id: body.message_id, + route_id: body.route_id, + valid_until: body.valid_until, + payload: body.payload, + }) +} + +fn decrypt_body( + session_key: &SymmetricKey, + encrypted: &bc_components::EncryptedMessage, +) -> Result { + let plaintext = session_key + .decrypt(encrypted) + .map_err(|_| QlError::InvalidPayload)?; + let cbor = CBOR::try_from_data(plaintext).map_err(|_| QlError::InvalidPayload)?; + MessageBody::try_from(cbor).map_err(|_| QlError::InvalidPayload) +} diff --git a/ql/src/crypto/mod.rs b/ql/src/crypto/mod.rs new file mode 100644 index 00000000..5d6e6a4e --- /dev/null +++ b/ql/src/crypto/mod.rs @@ -0,0 +1,25 @@ +use crate::{wire::message::Nack, MessageId, QlError}; + +pub mod handshake; +pub mod heartbeat; +pub mod message; +pub mod pair; + +fn ensure_not_expired(id: MessageId, valid_until: u64) -> Result<(), QlError> { + let now = now_secs(); + if now > valid_until { + Err(QlError::Nack { + id, + nack: Nack::Expired, + }) + } else { + Ok(()) + } +} + +fn now_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} diff --git a/ql/src/crypto/pair.rs b/ql/src/crypto/pair.rs new file mode 100644 index 00000000..1baec75d --- /dev/null +++ b/ql/src/crypto/pair.rs @@ -0,0 +1,127 @@ +use std::time::Duration; + +use bc_components::{EncapsulationPublicKey, Nonce, SigningPublicKey, SymmetricKey, Verifier, XID}; +use dcbor::CBOR; + +use crate::{ + crypto::ensure_not_expired, + platform::{QlPlatform, QlPlatformExt}, + wire::{ + pair::{PairRequestBody, PairRequestRecord}, + QlHeader, QlPayload, QlRecord, + }, + MessageId, QlError, +}; + +pub fn build_pair_request( + platform: &impl QlPlatform, + recipient: XID, + recipient_encapsulation_key: &EncapsulationPublicKey, + message_id: MessageId, + valid_for: Duration, +) -> Result { + let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); + let header = QlHeader { + sender: platform.xid(), + recipient, + }; + let valid_until = super::now_secs().saturating_add(valid_for.as_secs()); + let signing_pub_key = platform.signing_public_key().clone(); + let sender_encapsulation_key = platform.encapsulation_public_key().clone(); + let proof_data = pairing_proof_data( + &header, + &kem_ct, + message_id, + valid_until, + &signing_pub_key, + &sender_encapsulation_key, + ); + let proof = platform + .signer() + .sign(&proof_data) + .map_err(|_| QlError::InvalidPayload)?; + let body = PairRequestBody { + message_id, + valid_until, + signing_pub_key, + encapsulation_pub_key: sender_encapsulation_key, + proof, + }; + let body_bytes = CBOR::from(body).to_cbor_data(); + let aad = pairing_aad(&header, &kem_ct); + let encrypted = session_key.encrypt(body_bytes, Some(aad), None::); + Ok(QlRecord { + header, + payload: QlPayload::Pair(PairRequestRecord { kem_ct, encrypted }), + }) +} + +pub fn decrypt_pair_request( + platform: &impl QlPlatform, + header: &QlHeader, + request: PairRequestRecord, +) -> Result { + let PairRequestRecord { kem_ct, encrypted } = request; + let session_key = platform + .encapsulation_private_key() + .decapsulate_shared_secret(&kem_ct) + .map_err(|_| QlError::InvalidPayload)?; + let aad = pairing_aad(header, &kem_ct); + if encrypted.aad() != aad { + return Err(QlError::InvalidPayload); + } + let decrypted = decrypt_body(&session_key, &encrypted)?; + ensure_not_expired(decrypted.message_id, decrypted.valid_until)?; + if XID::new(&decrypted.signing_pub_key) != header.sender { + return Err(QlError::InvalidPayload); + } + let proof_data = pairing_proof_data( + header, + &kem_ct, + decrypted.message_id, + decrypted.valid_until, + &decrypted.signing_pub_key, + &decrypted.encapsulation_pub_key, + ); + if decrypted + .signing_pub_key + .verify(&decrypted.proof, &proof_data) + { + Ok(decrypted) + } else { + Err(QlError::InvalidSignature) + } +} + +fn pairing_proof_data( + header: &QlHeader, + kem_ct: &bc_components::EncapsulationCiphertext, + message_id: MessageId, + valid_until: u64, + signing_pub_key: &SigningPublicKey, + encapsulation_pub_key: &EncapsulationPublicKey, +) -> Vec { + CBOR::from(vec![ + CBOR::from(pairing_aad(header, kem_ct)), + CBOR::from(message_id), + CBOR::from(valid_until), + CBOR::from(signing_pub_key.clone()), + CBOR::from(encapsulation_pub_key.clone()), + ]) + .to_cbor_data() +} + +fn decrypt_body( + key: &SymmetricKey, + encrypted: &bc_components::EncryptedMessage, +) -> Result { + let plaintext = key + .decrypt(encrypted) + .map_err(|_| QlError::InvalidPayload)?; + let cbor = CBOR::try_from_data(plaintext).map_err(|_| QlError::InvalidPayload)?; + PairRequestBody::try_from(cbor).map_err(|_| QlError::InvalidPayload) +} + +fn pairing_aad(header: &QlHeader, kem_ct: &bc_components::EncapsulationCiphertext) -> Vec { + CBOR::from(vec![CBOR::from(header.clone()), CBOR::from(kem_ct.clone())]).to_cbor_data() +} diff --git a/ql/src/id.rs b/ql/src/id.rs new file mode 100644 index 00000000..eaab1d55 --- /dev/null +++ b/ql/src/id.rs @@ -0,0 +1,95 @@ +use std::fmt; + +use dcbor::CBOR; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct MessageId(u64); + +impl MessageId { + pub const fn new(value: u64) -> Self { + Self(value) + } + + pub const fn value(self) -> u64 { + self.0 + } +} + +impl fmt::Display for MessageId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for MessageId { + fn from(value: u64) -> Self { + Self(value) + } +} + +impl From for u64 { + fn from(value: MessageId) -> Self { + value.0 + } +} + +impl From for CBOR { + fn from(value: MessageId) -> Self { + CBOR::from(value.0) + } +} + +impl TryFrom for MessageId { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let value: u64 = value.try_into()?; + Ok(Self(value)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct RouteId(u64); + +impl RouteId { + pub const fn new(value: u64) -> Self { + Self(value) + } + + pub const fn value(self) -> u64 { + self.0 + } +} + +impl fmt::Display for RouteId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for RouteId { + fn from(value: u64) -> Self { + Self(value) + } +} + +impl From for u64 { + fn from(value: RouteId) -> Self { + value.0 + } +} + +impl From for CBOR { + fn from(value: RouteId) -> Self { + CBOR::from(value.0) + } +} + +impl TryFrom for RouteId { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let value: u64 = value.try_into()?; + Ok(Self(value)) + } +} diff --git a/ql/src/lib.rs b/ql/src/lib.rs new file mode 100644 index 00000000..2ef441ae --- /dev/null +++ b/ql/src/lib.rs @@ -0,0 +1,45 @@ +pub mod crypto; +mod id; +pub mod router; +pub mod platform; +pub mod runtime; +pub mod wire; + +pub use id::*; + +pub trait QlCodec: Into + TryFrom {} +impl QlCodec for T where T: Into + TryFrom {} + +pub trait RequestResponse: QlCodec { + const ID: RouteId; + type Response: QlCodec; +} + +pub trait Event: QlCodec { + const ID: RouteId; +} + +#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] +pub enum QlError { + #[error("invalid payload")] + InvalidPayload, + #[error("invalid handshake role")] + InvalidRole, + #[error("invalid signature")] + InvalidSignature, + #[error("missing session for {0}")] + MissingSession(bc_components::XID), + #[error("unknown peer {0}")] + UnknownPeer(bc_components::XID), + #[error("timeout")] + Timeout, + #[error("send failed")] + SendFailed, + #[error("nack {nack:?}")] + Nack { + id: MessageId, + nack: wire::message::Nack, + }, + #[error("cancelled")] + Cancelled, +} diff --git a/ql/src/platform.rs b/ql/src/platform.rs new file mode 100644 index 00000000..8e179435 --- /dev/null +++ b/ql/src/platform.rs @@ -0,0 +1,33 @@ +use std::{future::Future, pin::Pin, time::Duration}; + +use bc_components::{ + EncapsulationPrivateKey, EncapsulationPublicKey, Signer, SigningPublicKey, XID, +}; + +use crate::{ + runtime::{HandlerEvent, PeerSession}, + QlError, +}; + +pub type PlatformFuture<'a, T> = Pin + 'a>>; + +pub trait QlPlatform { + fn signer(&self) -> &dyn Signer; + fn signing_public_key(&self) -> &SigningPublicKey; + fn encapsulation_private_key(&self) -> &EncapsulationPrivateKey; + fn encapsulation_public_key(&self) -> &EncapsulationPublicKey; + + fn fill_bytes(&self, data: &mut [u8]); + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; + fn handle_peer_status(&self, peer: XID, session: &PeerSession); + fn handle_inbound(&self, event: HandlerEvent); +} + +pub(crate) trait QlPlatformExt: QlPlatform { + fn xid(&self) -> XID { + XID::new(&self.signing_public_key()) + } +} + +impl QlPlatformExt for T {} diff --git a/ql/src/router.rs b/ql/src/router.rs new file mode 100644 index 00000000..51a78af1 --- /dev/null +++ b/ql/src/router.rs @@ -0,0 +1,210 @@ +use std::collections::HashMap; + +use thiserror::Error; + +use crate::{ + runtime::{HandlerEvent, Responder}, + wire::message::{Ack, Nack}, + Event, QlCodec, QlError, RequestResponse, RouteId, +}; + +pub trait RequestHandler +where + M: RequestResponse, +{ + fn handle(&mut self, request: QlRequest); + fn default_response() -> M::Response; +} + +pub trait EventHandler +where + M: Event, +{ + fn handle(&mut self, event: M); +} + +pub struct QlRequest +where + M: RequestResponse, +{ + pub message: M, + pub responder: QlResponder, +} + +pub struct QlResponder +where + R: QlCodec, +{ + responder: Option, + default: fn() -> R, +} + +impl QlResponder +where + R: QlCodec, +{ + pub fn respond(mut self, response: R) -> Result<(), QlError> { + self.respond_inner(response) + } + + pub fn respond_nack(mut self, reason: Nack) -> Result<(), QlError> { + let responder = self.responder.take().unwrap(); + responder.respond_nack(reason) + } + + fn respond_inner(&mut self, response: R) -> Result<(), QlError> { + let responder = self.responder.take().unwrap(); + responder.respond(response) + } +} + +impl Drop for QlResponder +where + R: QlCodec, +{ + fn drop(&mut self) { + if self.responder.is_some() { + let default = (self.default)(); + let _ = self.respond_inner(default); + } + } +} + +#[derive(Debug, Error)] +pub enum RouterError { + #[error(transparent)] + Decode(#[from] dcbor::Error), + #[error("missing handler {0}")] + MissingHandler(RouteId), + #[error(transparent)] + Runtime(#[from] QlError), +} + +type RouterHandler = fn(&mut S, HandlerEvent) -> Result<(), RouterError>; + +pub struct RouterBuilder { + handlers: HashMap>, +} + +impl RouterBuilder { + pub fn new() -> Self { + Self { + handlers: HashMap::new(), + } + } + + pub fn add_request_handler(self) -> Self + where + M: RequestResponse, + S: RequestHandler, + { + self.add_handler(M::ID, handle_request::) + } + + pub fn add_event_handler(self) -> Self + where + M: Event, + S: EventHandler, + { + self.add_handler(M::ID, handle_event::) + } + + pub fn build(mut self, state: S) -> Router { + self.handlers.shrink_to_fit(); + Router { + handlers: self.handlers, + state, + } + } + + fn add_handler(mut self, id: RouteId, handler: RouterHandler) -> Self { + if self.handlers.insert(id, handler).is_some() { + panic!("duplicate route_id {id}"); + } + self + } +} + +pub struct Router { + state: S, + handlers: HashMap>, +} + +impl Router { + pub fn builder() -> RouterBuilder { + RouterBuilder::new() + } + + pub fn handle(&mut self, event: HandlerEvent) -> Result<(), RouterError> { + match event { + HandlerEvent::Request(request) => { + let route_id = request.message.route_id; + let handler = match self.handlers.get(&route_id) { + Some(handler) => handler, + None => { + let _ = request.respond_to.respond_nack(Nack::UnknownRoute); + return Ok(()); + } + }; + handler(&mut self.state, HandlerEvent::Request(request)) + } + HandlerEvent::Event(event) => { + let route_id = event.message.route_id; + let handler = self + .handlers + .get(&route_id) + .ok_or(RouterError::MissingHandler(route_id))?; + handler(&mut self.state, HandlerEvent::Event(event)) + } + } + } +} + +fn handle_request(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> +where + M: RequestResponse, + S: RequestHandler, +{ + let (payload, responder) = match event { + HandlerEvent::Request(request) => (request.message.payload, request.respond_to), + HandlerEvent::Event(_) => return Err(RouterError::Runtime(QlError::InvalidPayload)), + }; + let message = match M::try_from(payload) { + Ok(message) => message, + Err(error) => { + let _ = responder.respond_nack(Nack::InvalidPayload); + return Err(RouterError::Decode(error)); + } + }; + let responder = QlResponder { + responder: Some(responder), + default: S::default_response, + }; + state.handle(QlRequest { message, responder }); + Ok(()) +} + +fn handle_event(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> +where + M: Event, + S: EventHandler, +{ + let (payload, responder) = match event { + HandlerEvent::Event(event) => (event.message.payload, None), + HandlerEvent::Request(request) => (request.message.payload, Some(request.respond_to)), + }; + let message = match M::try_from(payload) { + Ok(message) => message, + Err(error) => { + if let Some(responder) = responder { + let _ = responder.respond_nack(Nack::InvalidPayload); + } + return Err(RouterError::Decode(error)); + } + }; + state.handle(message); + if let Some(responder) = responder { + responder.respond(Ack)?; + } + Ok(()) +} diff --git a/ql/src/runtime/core.rs b/ql/src/runtime/core.rs new file mode 100644 index 00000000..093a9719 --- /dev/null +++ b/ql/src/runtime/core.rs @@ -0,0 +1,1008 @@ +use std::{ + cmp::Reverse, collections::binary_heap::PeekMut, future::Future, task::Poll, time::Instant, +}; + +use bc_components::{EncapsulationPublicKey, XID}; +use dcbor::CBOR; +use futures_lite::future::poll_fn; + +use crate::{ + crypto::{handshake, heartbeat, message, pair}, + platform::{QlPlatform, QlPlatformExt}, + runtime::{ + internal::{ + next_timeout_deadline, now_secs, peer_hello_wins, HelloAction, InFlightWrite, + KeepAliveState, LoopStep, OutboundMessage, PendingEntry, RuntimeCommand, RuntimeState, + TimeoutEntry, TimeoutKind, + }, + HandlerEvent, InboundEvent, InboundRequest, InitiatorStage, KeepAliveConfig, PeerSession, + Responder, Runtime, Token, + }, + wire::{ + handshake::HandshakeRecord, + heartbeat::HeartbeatBody, + message::{MessageBody, MessageKind, Nack}, + pair::PairRequestRecord, + QlHeader, QlPayload, QlRecord, + }, + MessageId, QlError, RouteId, +}; + +impl Runtime

{ + pub async fn run(self) { + let mut state = RuntimeState::new(); + let mut in_flight: Option> = None; + while !self.rx.is_closed() { + if in_flight.is_none() { + in_flight = self.start_next_write(&mut state); + } + let step = self.next_step(&state, in_flight.as_mut()).await; + match step { + LoopStep::Event(command) => match command { + RuntimeCommand::RegisterPeer { + peer, + signing_key, + encapsulation_key, + } => { + self.handle_register_peer(&mut state, peer, signing_key, encapsulation_key); + } + RuntimeCommand::Connect { peer } => { + self.handle_connect(&mut state, peer); + } + RuntimeCommand::SendRequest { + recipient, + route_id, + payload, + respond_to, + config, + } => { + self.handle_send_request( + &mut state, recipient, route_id, payload, respond_to, config, + ); + } + RuntimeCommand::SendEvent { + recipient, + route_id, + payload, + } => { + self.handle_send_event(&mut state, recipient, route_id, payload); + } + RuntimeCommand::SendResponse { + id, + recipient, + payload, + kind, + } => { + self.handle_send_response(&mut state, id, recipient, payload, kind); + } + RuntimeCommand::Incoming(bytes) => { + self.handle_incoming(&mut state, bytes); + } + }, + LoopStep::Timeout => { + self.handle_timeouts(&mut state); + } + LoopStep::WriteDone { + peer, + token, + message_id, + result, + } => { + in_flight = None; + self.handle_write_done(&mut state, peer, token, message_id, result); + } + LoopStep::Quit => break, + } + } + } + + fn start_next_write<'a>(&'a self, state: &mut RuntimeState) -> Option> { + let Some(message) = state.outbound.pop_front() else { + return None; + }; + Some(InFlightWrite { + peer: message.peer, + token: message.token, + message_id: message.message_id, + future: self.platform.write_message(message.bytes), + }) + } + + async fn next_step<'a>( + &'a self, + state: &RuntimeState, + mut in_flight: Option<&mut InFlightWrite<'a>>, + ) -> LoopStep { + let recv_future = self.rx.recv(); + futures_lite::pin!(recv_future); + + let mut sleep_future = next_timeout_deadline(state).map(|deadline| { + let timeout = deadline.saturating_duration_since(Instant::now()); + self.platform.sleep(timeout) + }); + + poll_fn(|cx| { + if let Some(in_flight) = in_flight.as_mut() { + if let Poll::Ready(result) = in_flight.future.as_mut().poll(cx) { + return Poll::Ready(LoopStep::WriteDone { + peer: in_flight.peer, + token: in_flight.token, + message_id: in_flight.message_id, + result, + }); + } + } + + if let Some(future) = sleep_future.as_mut() { + if let Poll::Ready(()) = future.as_mut().poll(cx) { + return Poll::Ready(LoopStep::Timeout); + } + } + + recv_future.as_mut().poll(cx).map(|res| match res { + Ok(event) => LoopStep::Event(event), + Err(_) => LoopStep::Quit, + }) + }) + .await + } + + fn handle_connect(&self, state: &mut RuntimeState, peer: XID) { + let encapsulation_key = match state.peers.peer(peer) { + Some(entry) => match &entry.session { + PeerSession::Connected { .. } + | PeerSession::Initiator { .. } + | PeerSession::Responder { .. } => { + return; + } + PeerSession::Disconnected => entry.encapsulation_key.clone(), + }, + None => return, + }; + + let (hello, session_key) = match handshake::build_hello( + &self.platform, + self.platform.xid(), + peer, + &encapsulation_key, + ) { + Ok(result) => result, + Err(_) => return, + }; + + let deadline = Instant::now() + self.config.handshake_timeout; + let token = state.next_token(); + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Initiator { + handshake_token: token, + hello: hello.clone(), + session_key, + deadline, + stage: InitiatorStage::WaitingHelloReply, + }; + self.platform.handle_peer_status(peer, &entry.session); + } + + let message = QlRecord { + header: QlHeader { + sender: self.platform.xid(), + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), + }; + let bytes = CBOR::from(message).to_cbor_data(); + self.enqueue_handshake_message(state, peer, token, deadline, bytes); + } + + fn handle_register_peer( + &self, + state: &mut RuntimeState, + peer: XID, + signing_key: bc_components::SigningPublicKey, + encapsulation_key: EncapsulationPublicKey, + ) { + let entry = state + .peers + .upsert_peer(peer, signing_key, encapsulation_key); + if let PeerSession::Disconnected = entry.session { + self.platform.handle_peer_status(peer, &entry.session); + } + } + + fn handle_send_request( + &self, + state: &mut RuntimeState, + recipient: XID, + route_id: RouteId, + payload: CBOR, + respond_to: oneshot::Sender>, + config: super::RequestConfig, + ) { + let id = state.next_message_id(); + let timeout = config + .timeout + .unwrap_or(self.config.default_request_timeout); + if timeout.is_zero() { + let _ = respond_to.send(Err(QlError::Timeout)); + return; + } + let Some(entry) = state.peers.peer(recipient) else { + let _ = respond_to.send(Err(QlError::UnknownPeer(recipient))); + return; + }; + let session_key = match &entry.session { + PeerSession::Connected { session_key, .. } => session_key, + _ => { + let _ = respond_to.send(Err(QlError::MissingSession(recipient))); + return; + } + }; + let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); + let body = MessageBody { + message_id: id, + valid_until, + kind: MessageKind::Request, + route_id, + payload, + }; + let message = message::encrypt_message( + QlHeader { + sender: self.platform.xid(), + recipient, + }, + &session_key, + body, + ); + let bytes = CBOR::from(message).to_cbor_data(); + state.pending.insert( + id, + PendingEntry { + recipient, + tx: respond_to, + }, + ); + state.timeouts.push(Reverse(TimeoutEntry { + at: Instant::now() + timeout, + kind: TimeoutKind::Request { id }, + })); + let outbound_deadline = Instant::now() + self.config.message_expiration; + self.enqueue_outbound(state, recipient, bytes, outbound_deadline, Some(id)); + } + + fn handle_send_event( + &self, + state: &mut RuntimeState, + recipient: XID, + route_id: RouteId, + payload: CBOR, + ) { + let id = state.next_message_id(); + let Some(session_key) = state + .peers + .peer(recipient) + .and_then(|p| p.session.session_key()) + else { + return; + }; + let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); + let body = MessageBody { + message_id: id, + valid_until, + kind: MessageKind::Event, + route_id, + payload, + }; + let message = message::encrypt_message( + QlHeader { + sender: self.platform.xid(), + recipient, + }, + &session_key, + body, + ); + let bytes = CBOR::from(message).to_cbor_data(); + let outbound_deadline = Instant::now() + self.config.message_expiration; + self.enqueue_outbound(state, recipient, bytes, outbound_deadline, None); + } + + fn handle_send_response( + &self, + state: &mut RuntimeState, + id: MessageId, + recipient: XID, + payload: CBOR, + kind: MessageKind, + ) { + let kind = match kind { + MessageKind::Response | MessageKind::Nack => kind, + _ => return, + }; + let Some(session_key) = state + .peers + .peer(recipient) + .and_then(|p| p.session.session_key()) + else { + return; + }; + + let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); + let body = MessageBody { + message_id: id, + valid_until, + kind, + route_id: RouteId::new(0), + payload, + }; + let message = message::encrypt_message( + QlHeader { + sender: self.platform.xid(), + recipient, + }, + &session_key, + body, + ); + let bytes = CBOR::from(message).to_cbor_data(); + let outbound_deadline = Instant::now() + self.config.message_expiration; + self.enqueue_outbound(state, recipient, bytes, outbound_deadline, None); + } + + fn handle_incoming(&self, state: &mut RuntimeState, bytes: Vec) { + let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) else { + return; + }; + let QlRecord { header, payload } = record; + if header.recipient != self.platform.xid() { + return; + } + match payload { + QlPayload::Handshake(message) => { + self.handle_handshake(state, header, message); + } + QlPayload::Pair(request) => { + self.handle_pairing(state, header, request); + } + QlPayload::Message(encrypted) => { + self.handle_record(state, header, encrypted); + } + QlPayload::Heartbeat(encrypted) => { + self.handle_heartbeat(state, header, encrypted); + } + } + } + + fn handle_handshake( + &self, + state: &mut RuntimeState, + header: QlHeader, + message: HandshakeRecord, + ) { + match message { + HandshakeRecord::Hello(hello) => { + self.handle_hello(state, header, hello); + } + HandshakeRecord::HelloReply(reply) => { + self.handle_hello_reply(state, header, reply); + } + HandshakeRecord::Confirm(confirm) => { + self.handle_confirm(state, header, confirm); + } + } + } + + fn handle_pairing( + &self, + state: &mut RuntimeState, + header: QlHeader, + request: PairRequestRecord, + ) { + let payload = match pair::decrypt_pair_request(&self.platform, &header, request) { + Ok(payload) => payload, + Err(_) => return, + }; + let peer = XID::new(&payload.signing_pub_key); + state + .peers + .upsert_peer(peer, payload.signing_pub_key, payload.encapsulation_pub_key); + self.handle_connect(state, peer); + } + + fn handle_record( + &self, + state: &mut RuntimeState, + header: QlHeader, + encrypted: bc_components::EncryptedMessage, + ) { + let peer = header.sender; + let session_key = match state.peers.peer(peer) { + Some(entry) => match &entry.session { + PeerSession::Connected { session_key, .. } => session_key.clone(), + _ => return, + }, + None => return, + }; + let record = match message::decrypt_message(&header, &encrypted, &session_key) { + Ok(record) => record, + // TODO: fix this + Err(message::MessageError::Nack { .. }) => return, + Err(message::MessageError::Error(_)) => return, + }; + self.record_activity(state, peer); + match record.kind { + MessageKind::Response => { + self.resolve_pending_ok(state, peer, record.message_id, record.payload); + } + MessageKind::Nack => { + let nack = Nack::from(record.payload); + self.resolve_pending_nack(state, peer, record.message_id, nack); + } + MessageKind::Request => { + let Some(tx) = self.tx.upgrade() else { + return; + }; + let responder = Responder::new(record.message_id, record.sender, tx); + self.platform + .handle_inbound(HandlerEvent::Request(InboundRequest { + message: record, + respond_to: responder, + })); + } + MessageKind::Event => { + self.platform + .handle_inbound(HandlerEvent::Event(InboundEvent { message: record })); + } + } + } + + fn handle_heartbeat( + &self, + state: &mut RuntimeState, + header: QlHeader, + encrypted: bc_components::EncryptedMessage, + ) { + let peer = header.sender; + let (session_key, should_reply) = { + let Some(entry) = state.peers.peer(peer) else { + return; + }; + match &entry.session { + PeerSession::Connected { + session_key, + keepalive, + } => (session_key.clone(), !keepalive.pending), + _ => return, + } + }; + if heartbeat::decrypt_heartbeat(&header, &encrypted, &session_key).is_err() { + return; + } + self.record_activity(state, peer); + if should_reply { + self.send_heartbeat_message(state, peer, session_key); + } + } + + fn send_heartbeat_message( + &self, + state: &mut RuntimeState, + peer: XID, + session_key: bc_components::SymmetricKey, + ) { + let message_id = state.next_message_id(); + let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); + let message = heartbeat::encrypt_heartbeat( + QlHeader { + sender: self.platform.xid(), + recipient: peer, + }, + &session_key, + HeartbeatBody { + message_id, + valid_until, + }, + ); + let bytes = CBOR::from(message).to_cbor_data(); + let outbound_deadline = Instant::now() + self.config.message_expiration; + self.enqueue_outbound(state, peer, bytes, outbound_deadline, None); + } + + fn keep_alive_config(&self) -> Option { + self.config + .keep_alive + .filter(|config| !config.interval.is_zero() && !config.timeout.is_zero()) + } + + fn record_activity(&self, state: &mut RuntimeState, peer: XID) { + let Some(config) = self.keep_alive_config() else { + return; + }; + let token = state.next_token(); + let Some(entry) = state.peers.peer_mut(peer) else { + return; + }; + let PeerSession::Connected { keepalive, .. } = &mut entry.session else { + return; + }; + let now = Instant::now(); + keepalive.last_activity = Some(now); + keepalive.pending = false; + keepalive.token = token; + state.timeouts.push(Reverse(TimeoutEntry { + at: now + config.interval, + kind: TimeoutKind::KeepAliveSend { peer, token }, + })); + } + + fn drop_outbound_for_peer(&self, state: &mut RuntimeState, peer: XID) { + state.outbound.retain(|message| { + if message.peer == peer { + if let Some(id) = message.message_id { + if let Some(entry) = state.pending.remove(&id) { + let _ = entry.tx.send(Err(QlError::SendFailed)); + } + } + false + } else { + true + } + }); + } + + fn fail_pending_for_peer(&self, state: &mut RuntimeState, peer: XID) { + state + .pending + .extract_if(|_id, entry| entry.recipient == peer) + .for_each(|(_, entry)| { + let _ = entry.tx.send(Err(QlError::SendFailed)); + }); + } + + fn resolve_pending_ok( + &self, + state: &mut RuntimeState, + sender: XID, + id: MessageId, + payload: CBOR, + ) { + if let Some(entry) = state.pending.remove(&id) { + if entry.recipient == sender { + let _ = entry.tx.send(Ok(payload)); + } + } + } + + fn resolve_pending_nack( + &self, + state: &mut RuntimeState, + sender: XID, + id: MessageId, + nack: Nack, + ) { + if let Some(entry) = state.pending.remove(&id) { + if entry.recipient == sender { + let _ = entry.tx.send(Err(QlError::Nack { id, nack })); + } + } + } + + fn handle_hello( + &self, + state: &mut RuntimeState, + header: QlHeader, + hello: crate::wire::handshake::Hello, + ) { + let peer = header.sender; + let action = match state.peers.peer(peer) { + Some(entry) => match &entry.session { + PeerSession::Initiator { + hello: local_hello, .. + } => { + if peer_hello_wins(local_hello, self.platform.xid(), &hello, peer) { + HelloAction::StartResponder + } else { + HelloAction::Ignore + } + } + PeerSession::Responder { + hello: stored, + reply, + deadline, + .. + } => { + if stored.nonce == hello.nonce { + HelloAction::ResendReply { + reply: reply.clone(), + deadline: *deadline, + } + } else { + HelloAction::StartResponder + } + } + PeerSession::Disconnected | PeerSession::Connected { .. } => { + HelloAction::StartResponder + } + }, + None => return, + }; + + match action { + HelloAction::StartResponder => { + self.start_responder_handshake(state, peer, hello); + } + HelloAction::ResendReply { reply, deadline } => { + let message = QlRecord { + header: QlHeader { + sender: self.platform.xid(), + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), + }; + let bytes = CBOR::from(message).to_cbor_data(); + self.enqueue_outbound(state, peer, bytes, deadline, None); + } + HelloAction::Ignore => {} + } + } + + fn handle_hello_reply( + &self, + state: &mut RuntimeState, + header: QlHeader, + reply: crate::wire::handshake::HelloReply, + ) { + let peer = header.sender; + let (hello, initiator_secret, stage, responder_signing_key) = match state.peers.peer(peer) { + Some(entry) => match &entry.session { + PeerSession::Initiator { + hello, + session_key, + stage, + .. + } => ( + hello.clone(), + session_key.clone(), + *stage, + entry.signing_key.clone(), + ), + _ => return, + }, + None => return, + }; + + if stage != InitiatorStage::WaitingHelloReply { + return; + } + + let confirm = match handshake::build_confirm( + &self.platform, + self.platform.xid(), + peer, + &responder_signing_key, + &hello, + &reply, + &initiator_secret, + ) { + Ok((confirm, session_key)) => { + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::new(), + }; + self.platform.handle_peer_status(peer, &entry.session); + } + self.record_activity(state, peer); + confirm + } + Err(_) => { + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Disconnected; + self.platform.handle_peer_status(peer, &entry.session); + } + return; + } + }; + + let message = QlRecord { + header: QlHeader { + sender: self.platform.xid(), + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm)), + }; + let bytes = CBOR::from(message).to_cbor_data(); + let deadline = Instant::now() + self.config.handshake_timeout; + self.enqueue_outbound(state, peer, bytes, deadline, None); + } + + fn handle_confirm( + &self, + state: &mut RuntimeState, + header: QlHeader, + confirm: crate::wire::handshake::Confirm, + ) { + let peer = header.sender; + let (hello, reply, secrets, initiator_signing_key) = match state.peers.peer(peer) { + Some(entry) => match &entry.session { + PeerSession::Responder { + hello, + reply, + secrets, + .. + } => ( + hello.clone(), + reply.clone(), + secrets.clone(), + entry.signing_key.clone(), + ), + _ => return, + }, + None => return, + }; + + match handshake::finalize_confirm( + peer, + self.platform.xid(), + &initiator_signing_key, + &hello, + &reply, + &confirm, + &secrets, + ) { + Ok(session_key) => { + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::new(), + }; + self.platform.handle_peer_status(peer, &entry.session); + } + self.record_activity(state, peer); + } + Err(_) => { + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Disconnected; + self.platform.handle_peer_status(peer, &entry.session); + } + } + } + } + + fn start_responder_handshake( + &self, + state: &mut RuntimeState, + peer: XID, + hello: crate::wire::handshake::Hello, + ) { + let encapsulation_key = match state.peers.peer(peer) { + Some(entry) => entry.encapsulation_key.clone(), + None => return, + }; + let (reply, secrets) = match handshake::respond_hello( + &self.platform, + peer, + self.platform.xid(), + &encapsulation_key, + &hello, + ) { + Ok(result) => result, + Err(_) => { + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Disconnected; + self.platform.handle_peer_status(peer, &entry.session); + } + return; + } + }; + + let deadline = Instant::now() + self.config.handshake_timeout; + let token = state.next_token(); + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Responder { + handshake_token: token, + hello: hello.clone(), + reply: reply.clone(), + secrets, + deadline, + }; + self.platform.handle_peer_status(peer, &entry.session); + } + + let message = QlRecord { + header: QlHeader { + sender: self.platform.xid(), + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), + }; + let bytes = CBOR::from(message).to_cbor_data(); + self.enqueue_handshake_message(state, peer, token, deadline, bytes); + } + + fn enqueue_handshake_message( + &self, + state: &mut RuntimeState, + peer: XID, + token: Token, + deadline: Instant, + bytes: Vec, + ) { + state.outbound.push_back(OutboundMessage { + peer, + token, + message_id: None, + bytes, + }); + state.timeouts.push(Reverse(TimeoutEntry { + at: deadline, + kind: TimeoutKind::Handshake { peer, token }, + })); + state.timeouts.push(Reverse(TimeoutEntry { + at: deadline, + kind: TimeoutKind::Outbound { token }, + })); + } + + fn enqueue_outbound( + &self, + state: &mut RuntimeState, + peer: XID, + bytes: Vec, + deadline: Instant, + message_id: Option, + ) { + let token = state.next_token(); + state.outbound.push_back(OutboundMessage { + peer, + token, + message_id, + bytes, + }); + state.timeouts.push(Reverse(TimeoutEntry { + at: deadline, + kind: TimeoutKind::Outbound { token }, + })); + } + + fn handle_timeouts(&self, state: &mut RuntimeState) { + let now = Instant::now(); + loop { + let Some(entry) = state.timeouts.peek_mut().filter(|e| e.0.at <= now) else { + break; + }; + let entry = PeekMut::pop(entry).0; + match entry.kind { + TimeoutKind::Outbound { token } => { + let mut message_id = None; + state.outbound.retain(|message| { + if message.token == token { + message_id = message.message_id; + false + } else { + true + } + }); + if let Some(id) = message_id { + if let Some(entry) = state.pending.remove(&id) { + let _ = entry.tx.send(Err(QlError::SendFailed)); + } + } + } + TimeoutKind::Handshake { peer, token } => { + let Some(entry) = state.peers.peer(peer) else { + continue; + }; + let should_disconnect = match &entry.session { + PeerSession::Initiator { + handshake_token, .. + } + | PeerSession::Responder { + handshake_token, .. + } => *handshake_token == token, + _ => false, + }; + if should_disconnect { + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Disconnected; + self.platform.handle_peer_status(peer, &entry.session); + } + state.outbound.retain(|message| message.peer != peer); + } + } + TimeoutKind::Request { id } => { + if let Some(entry) = state.pending.remove(&id) { + let _ = entry.tx.send(Err(QlError::Timeout)); + } + } + TimeoutKind::KeepAliveSend { peer, token } => { + let Some(config) = self.keep_alive_config() else { + continue; + }; + let session_key = { + let Some(entry) = state.peers.peer(peer) else { + continue; + }; + let PeerSession::Connected { + session_key, + keepalive, + } = &entry.session + else { + continue; + }; + if keepalive.token == token && !keepalive.pending { + session_key.clone() + } else { + continue; + } + }; + self.send_heartbeat_message(state, peer, session_key); + if let Some(entry) = state.peers.peer_mut(peer) { + if let PeerSession::Connected { keepalive, .. } = &mut entry.session { + if keepalive.token == token { + keepalive.pending = true; + } + } + } + state.timeouts.push(Reverse(TimeoutEntry { + at: now + config.timeout, + kind: TimeoutKind::KeepAliveTimeout { peer, token }, + })); + } + TimeoutKind::KeepAliveTimeout { peer, token } => { + let Some(entry) = state.peers.peer(peer) else { + continue; + }; + + let should_disconnect = match &entry.session { + PeerSession::Connected { keepalive, .. } => { + keepalive.token == token && keepalive.pending + } + _ => false, + }; + + if should_disconnect { + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Disconnected; + self.platform.handle_peer_status(peer, &entry.session); + } + self.drop_outbound_for_peer(state, peer); + self.fail_pending_for_peer(state, peer); + } + } + } + } + } + + fn handle_write_done( + &self, + state: &mut RuntimeState, + peer: XID, + token: Token, + message_id: Option, + result: Result<(), QlError>, + ) { + if result.is_ok() { + return; + } + + if let Some(id) = message_id { + if let Some(entry) = state.pending.remove(&id) { + let _ = entry.tx.send(Err(QlError::SendFailed)); + } + } + let should_disconnect = match state.peers.peer(peer).map(|entry| &entry.session) { + Some(PeerSession::Initiator { + handshake_token, .. + }) if *handshake_token == token => true, + Some(PeerSession::Responder { + handshake_token, .. + }) if *handshake_token == token => true, + _ => false, + }; + if should_disconnect { + if let Some(entry) = state.peers.peer_mut(peer) { + entry.session = PeerSession::Disconnected; + self.platform.handle_peer_status(peer, &entry.session); + } + state.outbound.retain(|message| message.peer != peer); + } + } +} diff --git a/ql/src/runtime/handle.rs b/ql/src/runtime/handle.rs new file mode 100644 index 00000000..d08ef995 --- /dev/null +++ b/ql/src/runtime/handle.rs @@ -0,0 +1,161 @@ +use std::{ + future::Future, + marker::PhantomData, + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use bc_components::{EncapsulationPublicKey, SigningPublicKey, XID}; +use dcbor::CBOR; + +use crate::{ + runtime::{internal::RuntimeCommand, RequestConfig}, + wire::message::Ack, + Event, QlCodec, QlError, RequestResponse, RouteId, +}; + +#[derive(Clone)] +pub struct RuntimeHandle { + pub(crate) tx: async_channel::Sender, +} + +pub struct Response { + rx: oneshot::Receiver>, + _type: PhantomData T>, +} + +impl Response { + pub async fn recv(self) -> Result { + self.rx.await.unwrap_or(Err(QlError::Cancelled)) + } +} + +impl Future for Response +where + T: QlCodec, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + pin!(&mut self.rx).poll(cx).map(|result| { + let payload = result.unwrap_or(Err(QlError::Cancelled))?; + Ok(T::try_from(payload).map_err(|_| QlError::InvalidPayload)?) + }) + } +} + +impl RuntimeHandle { + pub fn register_peer( + &self, + peer: XID, + signing_key: SigningPublicKey, + encapsulation_key: EncapsulationPublicKey, + ) { + self.send(RuntimeCommand::RegisterPeer { + peer, + signing_key, + encapsulation_key, + }) + } + + pub fn connect(&self, peer: XID) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Connect { peer }) + .map_err(|_| QlError::Cancelled) + } + + pub fn send_incoming(&self, bytes: Vec) { + self.send(RuntimeCommand::Incoming(bytes)) + } + + pub fn request( + &self, + message: M, + recipient: XID, + config: RequestConfig, + ) -> Response + where + M: RequestResponse, + { + let (tx, rx) = oneshot::channel(); + self.send(RuntimeCommand::SendRequest { + recipient, + route_id: M::ID, + payload: message.into(), + respond_to: tx, + config, + }); + Response { + rx, + _type: PhantomData, + } + } + + pub fn send_event(&self, message: M, recipient: XID) + where + M: Event, + { + self.send_event_raw(recipient, M::ID, message.into()) + } + + pub fn send_event_with_ack( + &self, + message: M, + recipient: XID, + config: RequestConfig, + ) -> Response + where + M: Event, + { + let (tx, rx) = oneshot::channel(); + self.send(RuntimeCommand::SendRequest { + recipient, + route_id: M::ID, + payload: message.into(), + respond_to: tx, + config, + }); + Response { + rx, + _type: PhantomData, + } + } + + pub fn send_event_raw(&self, recipient: XID, route_id: RouteId, payload: CBOR) { + self.send(RuntimeCommand::SendEvent { + recipient, + route_id, + payload, + }) + } + + pub fn send_request_raw( + &self, + recipient: XID, + route_id: RouteId, + payload: CBOR, + config: RequestConfig, + ) -> Response { + let (tx, rx) = oneshot::channel(); + self.send(RuntimeCommand::SendRequest { + recipient, + route_id, + payload, + respond_to: tx, + config, + }); + Response { + rx, + _type: PhantomData, + } + } +} + +impl RuntimeHandle { + #[inline] + #[track_caller] + fn send(&self, cmd: RuntimeCommand) { + // send_blocking is ok bc queue is unbounded + self.tx.send_blocking(cmd).expect("runtime is alive") + } +} diff --git a/ql/src/runtime/internal.rs b/ql/src/runtime/internal.rs new file mode 100644 index 00000000..f763b47c --- /dev/null +++ b/ql/src/runtime/internal.rs @@ -0,0 +1,308 @@ +use std::{ + cell::Cell, + cmp::Reverse, + collections::{BinaryHeap, HashMap, VecDeque}, + time::{Instant, SystemTime, UNIX_EPOCH}, +}; + +use bc_components::{EncapsulationPublicKey, SigningPublicKey, SymmetricKey, XID}; +use dcbor::CBOR; + +use crate::{ + platform::PlatformFuture, + runtime::RequestConfig, + wire::{ + handshake::{Hello, HelloReply}, + message::MessageKind, + }, + MessageId, QlError, RouteId, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Token(u64); + +impl Token { + pub(crate) fn next(self) -> Self { + Self(self.0.wrapping_add(1)) + } +} + +#[derive(Debug, Clone)] +pub struct KeepAliveState { + pub token: Token, + pub pending: bool, + pub last_activity: Option, +} + +impl KeepAliveState { + pub fn new() -> Self { + Self { + token: Token(0), + pending: false, + last_activity: None, + } + } +} + +#[derive(Debug, Clone)] +pub struct PeerRecord { + pub peer: XID, + pub signing_key: SigningPublicKey, + pub encapsulation_key: EncapsulationPublicKey, + pub session: PeerSession, +} + +impl PeerRecord { + pub fn new( + peer: XID, + signing_key: SigningPublicKey, + encapsulation_key: EncapsulationPublicKey, + ) -> Self { + Self { + peer, + signing_key, + encapsulation_key, + session: PeerSession::Disconnected, + } + } +} + +#[derive(Debug, Clone)] +pub struct PeerStore { + peers: Vec, +} + +impl PeerStore { + pub fn new() -> Self { + Self { peers: Vec::new() } + } + + pub fn peer(&self, peer: XID) -> Option<&PeerRecord> { + self.peers.iter().find(|record| record.peer == peer) + } + + pub fn peer_mut(&mut self, peer: XID) -> Option<&mut PeerRecord> { + self.peers.iter_mut().find(|record| record.peer == peer) + } + + pub fn upsert_peer( + &mut self, + peer: XID, + signing_key: SigningPublicKey, + encapsulation_key: EncapsulationPublicKey, + ) -> &mut PeerRecord { + if let Some(index) = self.peers.iter().position(|record| record.peer == peer) { + let record = &mut self.peers[index]; + record.signing_key = signing_key; + record.encapsulation_key = encapsulation_key; + return record; + } + self.peers + .push(PeerRecord::new(peer, signing_key, encapsulation_key)); + self.peers.last_mut().expect("peer record just inserted") + } +} + +#[derive(Debug, Clone)] +pub enum PeerSession { + Disconnected, + Initiator { + handshake_token: Token, + hello: Hello, + session_key: SymmetricKey, + deadline: Instant, + stage: InitiatorStage, + }, + Responder { + handshake_token: Token, + hello: Hello, + reply: HelloReply, + secrets: crate::crypto::handshake::ResponderSecrets, + deadline: Instant, + }, + Connected { + session_key: SymmetricKey, + keepalive: KeepAliveState, + }, +} + +impl PeerSession { + #[inline] + pub fn is_connected(&self) -> bool { + match self { + PeerSession::Connected { .. } => true, + _ => false, + } + } + + #[inline] + pub fn session_key(&self) -> Option<&SymmetricKey> { + match self { + PeerSession::Connected { session_key, .. } => Some(session_key), + _ => None, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InitiatorStage { + WaitingHelloReply, + WaitingConfirmAck, +} + +pub(crate) enum RuntimeCommand { + RegisterPeer { + peer: XID, + signing_key: SigningPublicKey, + encapsulation_key: EncapsulationPublicKey, + }, + Connect { + peer: XID, + }, + SendRequest { + recipient: XID, + route_id: RouteId, + payload: CBOR, + respond_to: oneshot::Sender>, + config: RequestConfig, + }, + SendEvent { + recipient: XID, + route_id: RouteId, + payload: CBOR, + }, + SendResponse { + id: MessageId, + recipient: XID, + payload: CBOR, + kind: MessageKind, + }, + Incoming(Vec), +} + +pub struct RuntimeState { + pub peers: PeerStore, + pub next_token: Cell, + pub outbound: VecDeque, + pub timeouts: BinaryHeap>, + pub pending: HashMap, + pub next_message_id: u64, +} + +impl RuntimeState { + pub fn new() -> Self { + Self { + peers: PeerStore::new(), + next_token: Cell::new(Token(0)), + outbound: VecDeque::new(), + timeouts: BinaryHeap::new(), + pending: HashMap::new(), + next_message_id: 1, + } + } + + pub fn next_token(&self) -> Token { + let token = self.next_token.get(); + self.next_token.set(token.next()); + token + } + + pub fn next_message_id(&mut self) -> MessageId { + let id = self.next_message_id; + self.next_message_id = id.wrapping_add(1); + MessageId::new(id) + } +} + +pub struct PendingEntry { + pub recipient: XID, + pub tx: oneshot::Sender>, +} + +pub struct InFlightWrite<'a> { + pub peer: XID, + pub token: Token, + pub message_id: Option, + pub future: PlatformFuture<'a, Result<(), QlError>>, +} + +pub struct OutboundMessage { + pub peer: XID, + pub token: Token, + pub message_id: Option, + pub bytes: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TimeoutKind { + Outbound { token: Token }, + Handshake { peer: XID, token: Token }, + Request { id: MessageId }, + KeepAliveSend { peer: XID, token: Token }, + KeepAliveTimeout { peer: XID, token: Token }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TimeoutEntry { + pub at: Instant, + pub kind: TimeoutKind, +} + +impl Ord for TimeoutEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.at.cmp(&other.at) + } +} + +impl PartialOrd for TimeoutEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +pub enum LoopStep { + Event(RuntimeCommand), + Timeout, + WriteDone { + peer: XID, + token: Token, + message_id: Option, + result: Result<(), QlError>, + }, + Quit, +} + +pub enum HelloAction { + StartResponder, + ResendReply { + reply: HelloReply, + deadline: Instant, + }, + Ignore, +} + +pub fn next_timeout_deadline(state: &RuntimeState) -> Option { + state.timeouts.peek().map(|entry| entry.0.at) +} + +pub fn peer_hello_wins( + local_hello: &Hello, + local_sender: XID, + peer_hello: &Hello, + peer_sender: XID, +) -> bool { + use std::cmp::Ordering; + + match peer_hello.nonce.data().cmp(local_hello.nonce.data()) { + Ordering::Less => true, + Ordering::Greater => false, + Ordering::Equal => peer_sender.data().cmp(local_sender.data()) == Ordering::Less, + } +} + +pub fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} diff --git a/ql/src/runtime/mod.rs b/ql/src/runtime/mod.rs new file mode 100644 index 00000000..8a57fa76 --- /dev/null +++ b/ql/src/runtime/mod.rs @@ -0,0 +1,146 @@ +pub use handle::{Response, RuntimeHandle}; +pub use internal::{InitiatorStage, PeerSession, Token}; + +mod core; +pub mod handle; +pub(crate) mod internal; + +#[cfg(test)] +mod tests; + +use std::time::Duration; + +use bc_components::XID; +use dcbor::CBOR; + +use crate::{ + wire::message::{DecryptedMessage, MessageKind, Nack}, + MessageId, QlCodec, QlError, +}; + +#[derive(Debug, Clone, Default)] +pub struct RequestConfig { + pub timeout: Option, +} + +#[derive(Debug, Clone, Copy)] +pub struct KeepAliveConfig { + pub interval: Duration, + pub timeout: Duration, +} + +#[derive(Debug, Clone, Copy)] +pub struct RuntimeConfig { + pub handshake_timeout: Duration, + pub default_request_timeout: Duration, + pub message_expiration: Duration, + pub keep_alive: Option, +} + +impl RuntimeConfig { + pub fn new(handshake_timeout: Duration) -> Self { + Self { + handshake_timeout, + default_request_timeout: Duration::from_secs(5), + message_expiration: Duration::from_secs(30), + keep_alive: None, + } + } + + pub fn with_request_timeout(mut self, timeout: Duration) -> Self { + self.default_request_timeout = timeout; + self + } + + pub fn with_message_expiration(mut self, expiration: Duration) -> Self { + self.message_expiration = expiration; + self + } + + pub fn with_keep_alive(mut self, config: KeepAliveConfig) -> Self { + self.keep_alive = Some(config); + self + } +} + +#[derive(Debug)] +pub enum HandlerEvent { + Request(InboundRequest), + Event(InboundEvent), +} + +#[derive(Debug)] +pub struct InboundRequest { + pub message: DecryptedMessage, + pub respond_to: Responder, +} + +#[derive(Debug)] +pub struct InboundEvent { + pub message: DecryptedMessage, +} + +#[derive(Debug, Clone)] +pub struct Responder { + id: MessageId, + recipient: XID, + tx: async_channel::Sender, +} + +impl Responder { + pub(crate) fn new( + id: MessageId, + recipient: XID, + tx: async_channel::Sender, + ) -> Self { + Self { id, recipient, tx } + } + + pub fn respond(self, response: R) -> Result<(), QlError> + where + R: QlCodec, + { + self.tx + .try_send(internal::RuntimeCommand::SendResponse { + id: self.id, + recipient: self.recipient, + payload: response.into(), + kind: MessageKind::Response, + }) + .map_err(|_| QlError::Cancelled) + } + + pub fn respond_nack(self, reason: Nack) -> Result<(), QlError> { + self.tx + .try_send(internal::RuntimeCommand::SendResponse { + id: self.id, + recipient: self.recipient, + payload: CBOR::from(reason), + kind: MessageKind::Nack, + }) + .map_err(|_| QlError::Cancelled) + } +} + +pub struct Runtime

{ + platform: P, + config: RuntimeConfig, + rx: async_channel::Receiver, + tx: async_channel::WeakSender, +} + +pub fn new_runtime

(platform: P, config: RuntimeConfig) -> (Runtime

, RuntimeHandle) +where + P: crate::platform::QlPlatform, +{ + let (tx, rx) = async_channel::unbounded(); + ( + Runtime { + platform, + config, + rx, + tx: tx.downgrade(), + }, + RuntimeHandle { tx }, + ) +} diff --git a/ql/src/runtime/tests.rs b/ql/src/runtime/tests.rs new file mode 100644 index 00000000..ceec9294 --- /dev/null +++ b/ql/src/runtime/tests.rs @@ -0,0 +1,1588 @@ +use std::{ + future::Future, + sync::{ + atomic::{AtomicBool, AtomicU8, Ordering}, + Arc, + }, + time::Duration, +}; + +use async_channel::{Receiver, Sender}; +use bc_components::{ + EncapsulationPrivateKey, EncapsulationPublicKey, EncapsulationScheme, SignatureScheme, Signer, + SigningPrivateKey, SigningPublicKey, SymmetricKey, XID, +}; +use dcbor::CBOR; +use tokio::{sync::Semaphore, task::LocalSet}; + +use crate::{ + crypto::{handshake, heartbeat, pair}, + platform::{PlatformFuture, QlPlatform, QlPlatformExt}, + runtime::{ + internal::now_secs, new_runtime, HandlerEvent, KeepAliveConfig, PeerSession, RequestConfig, + RuntimeConfig, RuntimeHandle, + }, + wire::{ + handshake::HandshakeRecord, heartbeat::HeartbeatBody, message::Nack, QlHeader, QlPayload, + QlRecord, + }, + MessageId, QlError, RouteId, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PeerStage { + Disconnected, + Initiator, + Responder, + Connected, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct StatusEvent { + peer: XID, + stage: PeerStage, +} + +struct TestPlatform { + signing_private: SigningPrivateKey, + signing_public: SigningPublicKey, + encapsulation_private: EncapsulationPrivateKey, + encapsulation_public: EncapsulationPublicKey, + outbound: Sender>, + status: Sender, + nonce_seed: u8, + nonce_counter: AtomicU8, +} + +impl TestPlatform { + fn new(seed: u8) -> (Self, Receiver>, Receiver) { + let (signing_private, signing_public) = SignatureScheme::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = + EncapsulationScheme::default().keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + }, + outbound_rx, + status_rx, + ) + } + + fn signing_public_key(&self) -> &SigningPublicKey { + &self.signing_public + } + + fn encapsulation_public_key(&self) -> &EncapsulationPublicKey { + &self.encapsulation_public + } +} + +impl QlPlatform for TestPlatform { + fn signer(&self) -> &dyn Signer { + &self.signing_private + } + + fn signing_public_key(&self) -> &SigningPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &EncapsulationPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &EncapsulationPublicKey { + &self.encapsulation_public + } + + fn fill_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } + + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + Box::pin(async move { + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, _event: crate::runtime::HandlerEvent) {} +} + +struct BlockingPlatform { + signing_private: SigningPrivateKey, + signing_public: SigningPublicKey, + encapsulation_private: EncapsulationPrivateKey, + encapsulation_public: EncapsulationPublicKey, + outbound: Sender>, + status: Sender, + nonce_seed: u8, + nonce_counter: AtomicU8, + write_gate: Arc, +} + +struct InboundPlatform { + signing_private: SigningPrivateKey, + signing_public: SigningPublicKey, + encapsulation_private: EncapsulationPrivateKey, + encapsulation_public: EncapsulationPublicKey, + outbound: Sender>, + status: Sender, + inbound: Sender, + nonce_seed: u8, + nonce_counter: AtomicU8, +} + +impl InboundPlatform { + fn new( + seed: u8, + ) -> ( + Self, + Receiver>, + Receiver, + Receiver, + ) { + let (signing_private, signing_public) = SignatureScheme::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = + EncapsulationScheme::default().keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + let (inbound, inbound_rx) = async_channel::unbounded(); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + inbound, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + }, + outbound_rx, + status_rx, + inbound_rx, + ) + } +} + +impl QlPlatform for InboundPlatform { + fn signer(&self) -> &dyn Signer { + &self.signing_private + } + + fn signing_public_key(&self) -> &SigningPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &EncapsulationPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &EncapsulationPublicKey { + &self.encapsulation_public + } + + fn fill_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } + + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + Box::pin(async move { + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, event: HandlerEvent) { + let _ = self.inbound.try_send(event); + } +} + +impl BlockingPlatform { + fn new( + seed: u8, + ) -> ( + Self, + Receiver>, + Receiver, + Arc, + ) { + let (signing_private, signing_public) = SignatureScheme::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = + EncapsulationScheme::default().keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + let write_gate = Arc::new(Semaphore::new(0)); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + write_gate: write_gate.clone(), + }, + outbound_rx, + status_rx, + write_gate, + ) + } +} + +impl QlPlatform for BlockingPlatform { + fn signer(&self) -> &dyn Signer { + &self.signing_private + } + + fn signing_public_key(&self) -> &SigningPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &EncapsulationPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &EncapsulationPublicKey { + &self.encapsulation_public + } + + fn fill_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } + + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + let write_gate = self.write_gate.clone(); + Box::pin(async move { + let _permit = write_gate.acquire().await.unwrap(); + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, _event: crate::runtime::HandlerEvent) {} +} + +async fn run_local_test(future: F) +where + F: Future, +{ + let local = LocalSet::new(); + local.run_until(future).await; +} + +fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + let _ = handle.send_incoming(bytes); + } + }); +} + +fn is_heartbeat(bytes: &[u8]) -> bool { + let Ok(record) = CBOR::try_from_data(bytes).and_then(QlRecord::try_from) else { + return false; + }; + matches!(record.payload, QlPayload::Heartbeat(_)) +} + +fn spawn_heartbeat_tap_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + heartbeat_tx: Sender<()>, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if is_heartbeat(&bytes) { + let _ = heartbeat_tx.send(()).await; + } + let _ = handle.send_incoming(bytes); + } + }); +} + +fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if is_heartbeat(&bytes) { + continue; + } + let _ = handle.send_incoming(bytes); + } + }); +} + +fn spawn_gated_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + drop_flag: Arc, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if drop_flag.load(Ordering::Relaxed) { + continue; + } + let _ = handle.send_incoming(bytes); + } + }); +} + +fn spawn_routed_forwarder(outbound: Receiver>, routes: Vec<(XID, RuntimeHandle)>) { + spawn_routed_forwarder_with_filter(outbound, routes, |_| true); +} + +fn spawn_routed_forwarder_with_filter( + outbound: Receiver>, + routes: Vec<(XID, RuntimeHandle)>, + filter: F, +) where + F: Fn(&QlRecord) -> bool + Send + Sync + 'static, +{ + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) else { + continue; + }; + if !filter(&record) { + continue; + } + if let Some((_, handle)) = routes + .iter() + .find(|(peer, _)| *peer == record.header.recipient) + { + let _ = handle.send_incoming(bytes); + } + } + }); +} + +#[derive(Clone)] +struct PeerIdentity { + xid: XID, + signing_key: SigningPublicKey, + encapsulation_key: EncapsulationPublicKey, +} + +fn peer_identity(platform: &impl QlPlatformExt) -> PeerIdentity { + PeerIdentity { + xid: platform.xid(), + signing_key: platform.signing_public_key().clone(), + encapsulation_key: platform.encapsulation_public_key().clone(), + } +} + +fn register_peers( + handle_a: &RuntimeHandle, + handle_b: &RuntimeHandle, + identity_a: &PeerIdentity, + identity_b: &PeerIdentity, +) -> (XID, XID) { + let peer_a = identity_a.xid; + let peer_b = identity_b.xid; + handle_a.register_peer( + peer_b, + identity_b.signing_key.clone(), + identity_b.encapsulation_key.clone(), + ); + handle_b.register_peer( + peer_a, + identity_a.signing_key.clone(), + identity_a.encapsulation_key.clone(), + ); + (peer_a, peer_b) +} + +async fn await_status( + receiver: &Receiver, + peer: XID, + stage: PeerStage, +) -> StatusEvent { + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if let Ok(event) = receiver.recv().await { + if event.peer == peer && event.stage == stage { + return event; + } + } + } + }) + .await + .unwrap() +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_initiator_connects() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_timeout_disconnects() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(50)); + let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + + let peer_b = platform_b.xid(); + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer( + peer_b, + platform_b.signing_public_key().clone(), + platform_b.encapsulation_public_key().clone(), + ); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn simultaneous_handshakes_resolve() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + handle_b.connect(peer_a.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; + await_status(&status_b, peer_a.xid, PeerStage::Responder).await; + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn invalid_signature_disconnects() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, _status_b) = TestPlatform::new(2); + let (wrong_private, wrong_public) = SignatureScheme::MLDSA44.keypair(); + let _ = wrong_private; + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b.xid, wrong_public, peer_b.encapsulation_key.clone()); + handle_b.register_peer( + peer_a.xid, + peer_a.signing_key.clone(), + peer_a.encapsulation_key.clone(), + ); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn pairing_request_triggers_handshake() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let pairing_message = pair::build_pair_request( + &platform_a, + peer_b.xid, + &peer_b.encapsulation_key, + MessageId::new(1), + Duration::from_secs(1), + ) + .unwrap(); + let pairing_bytes = CBOR::from(pairing_message).to_cbor_data(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer( + peer_b.xid, + peer_b.signing_key.clone(), + peer_b.encapsulation_key.clone(), + ); + + handle_b.send_incoming(pairing_bytes); + + await_status(&status_b, peer_a.xid, PeerStage::Initiator).await; + await_status(&status_a, peer_b.xid, PeerStage::Responder).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_response_round_trip() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config.clone()); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let inbound_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let _ = request.respond_to.respond(99u8); + } + }); + + let response = handle_a.send_request_raw( + peer_b.xid, + RouteId::new(7), + CBOR::from(12u8), + RequestConfig::default(), + ); + + let response = response.recv().await.unwrap(); + let value: u8 = response.try_into().unwrap(); + assert_eq!(value, 99u8); + let _ = inbound_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_timeout_returns_error() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(30)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config.clone()); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let ticket = handle_a.send_request_raw( + peer_b.xid, + RouteId::new(1), + CBOR::from(1u8), + RequestConfig { + timeout: Some(Duration::from_millis(30)), + }, + ); + + let result = tokio::time::timeout(Duration::from_millis(200), ticket.recv()) + .await + .unwrap(); + assert!(matches!(result, Err(QlError::Timeout))); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_nack_resolves_pending() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config.clone()); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let inbound_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let _ = request.respond_to.respond_nack(Nack::InvalidPayload); + } + }); + + let response = handle_a.send_request_raw( + peer_b.xid, + RouteId::new(2), + CBOR::from(2u8), + RequestConfig::default(), + ); + + let result = response.recv().await; + assert!(matches!( + result, + Err(QlError::Nack { + nack: Nack::InvalidPayload, + .. + }) + )); + let _ = inbound_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_dispatches_to_platform_callback() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config.clone()); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let inbound_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let _ = request.respond_to.respond(7u8); + } + }); + + let ticket = handle_a.send_request_raw( + peer_b.xid, + RouteId::new(3), + CBOR::from(1u8), + RequestConfig::default(), + ); + + let response = ticket.recv().await.unwrap(); + let value: u8 = response.try_into().unwrap(); + assert_eq!(value, 7u8); + let _ = inbound_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn blocked_write_still_times_out() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(40)); + let (platform_a, _outbound_a, status_a, _write_gate) = BlockingPlatform::new(2); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(1); + + let signing_b = platform_b.signing_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_b = XID::new(&signing_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Initiator).await; + await_status(&status_a, peer_b, PeerStage::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_timeout_drops_queued_messages() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(60)); + let (platform_a, outbound_a, status_a, write_gate) = BlockingPlatform::new(2); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(1); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer( + peer_b.xid, + peer_b.signing_key.clone(), + peer_b.encapsulation_key.clone(), + ); + + handle_a.connect(peer_b.xid).unwrap(); + await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; + + let (hello, _secret) = + handshake::build_hello(&platform_b, peer_b.xid, peer_a.xid, &peer_a.encapsulation_key) + .unwrap(); + let message = QlRecord { + header: QlHeader { + sender: peer_b.xid, + recipient: peer_a.xid, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), + }; + let bytes = CBOR::from(message).to_cbor_data(); + handle_a.send_incoming(bytes); + + await_status(&status_a, peer_b.xid, PeerStage::Responder).await; + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + + write_gate.add_permits(1); + let _ = tokio::time::timeout(Duration::from_millis(100), outbound_a.recv()) + .await + .unwrap() + .unwrap(); + + write_gate.add_permits(1); + let second = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; + assert!( + second.is_err(), + "expected queued handshake reply to be dropped" + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_ignored_without_session() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, _status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer( + peer_b, + platform_b.signing_public_key().clone(), + platform_b.encapsulation_public_key().clone(), + ); + + let heartbeat = heartbeat::encrypt_heartbeat( + QlHeader { + sender: peer_b, + recipient: peer_a, + }, + &SymmetricKey::new(), + HeartbeatBody { + message_id: MessageId::new(1), + valid_until: now_secs().saturating_add(60), + }, + ); + let bytes = CBOR::from(heartbeat).to_cbor_data(); + handle_a.send_incoming(bytes); + + let result = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; + assert!(result.is_err(), "expected heartbeat to be ignored"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn keepalive_disabled_no_heartbeat() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = XID::new(&signing_a); + let peer_b = XID::new(&signing_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + let result = tokio::time::timeout(Duration::from_millis(120), heartbeat_rx.recv()).await; + assert!(result.is_err(), "unexpected heartbeat while disabled"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_sent_after_idle() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(30), + timeout: Duration::from_millis(80), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = XID::new(&signing_a); + let peer_b = XID::new(&signing_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_reply_when_connected() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(30), + timeout: Duration::from_millis(80), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = XID::new(&signing_a); + let peer_b = XID::new(&signing_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); + let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); + spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_ab_rx.recv()) + .await + .unwrap() + .unwrap(); + let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn any_message_clears_pending() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(120), + timeout: Duration::from_millis(40), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = XID::new(&signing_a); + let peer_b = XID::new(&signing_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) + .await + .unwrap() + .unwrap(); + + handle_b.send_event_raw(peer_a, RouteId::new(99), CBOR::from(1u8)); + + let window = keep_alive.timeout + Duration::from_millis(20); + let disconnect = tokio::time::timeout(window, async { + loop { + if let Ok(event) = status_a.recv().await { + if event.peer == peer_b && event.stage == PeerStage::Disconnected { + return; + } + } + } + }) + .await; + assert!(disconnect.is_err(), "unexpected disconnect"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_timeout_disconnects_and_drops_outbound() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(80), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(2); + let (platform_b, outbound_b, status_b) = TestPlatform::new(1); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = XID::new(&signing_a); + let peer_b = XID::new(&signing_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let drop_flag = Arc::new(AtomicBool::new(false)); + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + drop_flag.store(true, Ordering::Relaxed); + + let response = handle_a.send_request_raw( + peer_b, + RouteId::new(9), + CBOR::from(9u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ); + + await_status(&status_a, peer_b, PeerStage::Disconnected).await; + + let result = tokio::time::timeout(Duration::from_millis(300), response.recv()) + .await + .unwrap(); + assert!( + matches!(result, Err(QlError::SendFailed)), + "unexpected result: {result:?}" + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn no_ping_pong() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(200), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = XID::new(&signing_a); + let peer_b = XID::new(&signing_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); + let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); + spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + let _ = tokio::time::timeout(Duration::from_millis(300), heartbeat_ab_rx.recv()) + .await + .unwrap() + .unwrap(); + let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) + .await + .unwrap() + .unwrap(); + + let followup = + tokio::time::timeout(Duration::from_millis(50), heartbeat_ab_rx.recv()).await; + assert!(followup.is_err(), "unexpected heartbeat ping-pong"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn invalid_heartbeat_ignored() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = XID::new(&signing_a); + let peer_b = XID::new(&signing_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + let heartbeat = heartbeat::encrypt_heartbeat( + QlHeader { + sender: peer_b, + recipient: peer_a, + }, + &SymmetricKey::new(), + HeartbeatBody { + message_id: MessageId::new(42), + valid_until: now_secs().saturating_add(30), + }, + ); + let bytes = CBOR::from(heartbeat).to_cbor_data(); + handle_a.send_incoming(bytes); + + let result = tokio::time::timeout(Duration::from_millis(50), heartbeat_rx.recv()).await; + assert!(result.is_err(), "unexpected heartbeat reply"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_peer_simultaneous_handshakes() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let (platform_c, outbound_c, status_c) = TestPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let peer_c = peer_identity(&platform_c); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + let (runtime_c, handle_c) = + new_runtime(platform_c, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + tokio::task::spawn_local(async move { runtime_c.run().await }); + + spawn_routed_forwarder( + outbound_a, + vec![(peer_b.xid, handle_b.clone()), (peer_c.xid, handle_c.clone())], + ); + spawn_routed_forwarder(outbound_b, vec![(peer_a.xid, handle_a.clone())]); + spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); + + let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); + + handle_a.connect(peer_b.xid).unwrap(); + handle_a.connect(peer_c.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_a, peer_c.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_c, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_peer_keepalive_disconnect_isolated() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(40), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let config_c = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let (platform_c, outbound_c, status_c) = TestPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let peer_c = peer_identity(&platform_c); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + let (runtime_c, handle_c) = new_runtime(platform_c, config_c); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + tokio::task::spawn_local(async move { runtime_c.run().await }); + + let drop_b_to_a = Arc::new(AtomicBool::new(false)); + spawn_routed_forwarder( + outbound_a, + vec![(peer_b.xid, handle_b.clone()), (peer_c.xid, handle_c.clone())], + ); + spawn_routed_forwarder_with_filter(outbound_b, vec![(peer_a.xid, handle_a.clone())], { + let drop_b_to_a = drop_b_to_a.clone(); + move |record| { + !(drop_b_to_a.load(Ordering::Relaxed) && record.header.recipient == peer_a.xid) + } + }); + spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); + + let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); + + handle_a.connect(peer_b.xid).unwrap(); + handle_a.connect(peer_c.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_a, peer_c.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_c, peer_a.xid, PeerStage::Connected).await; + + drop_b_to_a.store(true, Ordering::Relaxed); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + + let disconnect = + tokio::time::timeout(keep_alive.timeout + Duration::from_millis(80), async { + loop { + if let Ok(event) = status_a.recv().await { + if event.peer == peer_c.xid && event.stage == PeerStage::Disconnected { + return; + } + } + } + }) + .await; + assert!(disconnect.is_err(), "unexpected disconnect for peer C"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_peer_disconnect_drops_outbound_for_one() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(40), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let config_c = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let (platform_c, outbound_c, status_c, inbound_c) = InboundPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let peer_c = peer_identity(&platform_c); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + let (runtime_c, handle_c) = new_runtime(platform_c, config_c); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + tokio::task::spawn_local(async move { runtime_c.run().await }); + + let drop_b_to_a = Arc::new(AtomicBool::new(false)); + spawn_routed_forwarder( + outbound_a, + vec![(peer_b.xid, handle_b.clone()), (peer_c.xid, handle_c.clone())], + ); + spawn_routed_forwarder_with_filter(outbound_b, vec![(peer_a.xid, handle_a.clone())], { + let drop_b_to_a = drop_b_to_a.clone(); + move |record| { + !(drop_b_to_a.load(Ordering::Relaxed) && record.header.recipient == peer_a.xid) + } + }); + spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); + + let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); + + handle_a.connect(peer_b.xid).unwrap(); + handle_a.connect(peer_c.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_a, peer_c.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_c, peer_a.xid, PeerStage::Connected).await; + + let inbound_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_c.recv().await { + let _ = request.respond_to.respond(55u8); + } + }); + + drop_b_to_a.store(true, Ordering::Relaxed); + + let request_b = handle_a.send_request_raw( + peer_b.xid, + RouteId::new(10), + CBOR::from(10u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ); + let request_c = handle_a.send_request_raw( + peer_c.xid, + RouteId::new(11), + CBOR::from(11u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ); + + let response_c = tokio::time::timeout(Duration::from_millis(200), request_c.recv()) + .await + .expect("response wait") + .expect("response channel"); + let value: u8 = response_c.try_into().unwrap(); + assert_eq!(value, 55u8); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + + let result_b = tokio::time::timeout(Duration::from_millis(200), request_b.recv()) + .await + .expect("response wait"); + assert!(matches!(result_b, Err(QlError::SendFailed))); + + let _ = inbound_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_peer_activity_is_per_peer() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(100), + timeout: Duration::from_millis(40), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let config_c = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let (platform_c, outbound_c, status_c) = TestPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let peer_c = peer_identity(&platform_c); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + let (runtime_c, handle_c) = new_runtime(platform_c, config_c); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + tokio::task::spawn_local(async move { runtime_c.run().await }); + + let drop_all_c = Arc::new(AtomicBool::new(false)); + spawn_routed_forwarder( + outbound_a, + vec![ + (peer_b.xid, handle_b.clone()), + (peer_c.xid, handle_c.clone()), + ], + ); + spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); + spawn_gated_forwarder(outbound_c, handle_a.clone(), drop_all_c.clone()); + + let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); + + handle_a.connect(peer_b.xid).unwrap(); + handle_a.connect(peer_c.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_a, peer_c.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_c, peer_a.xid, PeerStage::Connected).await; + + drop_all_c.store(true, Ordering::Relaxed); + + tokio::time::sleep(keep_alive.interval + Duration::from_millis(5)).await; + + handle_b.send_event_raw(peer_a.xid, RouteId::new(99), CBOR::from(1u8)); + + await_status(&status_a, peer_c.xid, PeerStage::Disconnected).await; + + let disconnect = + tokio::time::timeout(keep_alive.timeout + Duration::from_millis(30), async { + loop { + if let Ok(event) = status_a.recv().await { + if event.peer == peer_b.xid && event.stage == PeerStage::Disconnected { + return; + } + } + } + }) + .await; + assert!(disconnect.is_err(), "unexpected disconnect for peer B"); + }) + .await; +} diff --git a/ql/src/wire/handshake.rs b/ql/src/wire/handshake.rs new file mode 100644 index 00000000..eafc2da4 --- /dev/null +++ b/ql/src/wire/handshake.rs @@ -0,0 +1,120 @@ +use bc_components::{EncapsulationCiphertext, Nonce, Signature, SigningPublicKey, Verifier}; +use dcbor::CBOR; + +use super::take_fields; +use crate::QlError; + +#[derive(Debug, Clone, PartialEq)] +pub enum HandshakeRecord { + Hello(Hello), + HelloReply(HelloReply), + Confirm(Confirm), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Hello { + pub nonce: Nonce, + pub kem_ct: EncapsulationCiphertext, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct HelloReply { + pub nonce: Nonce, + pub kem_ct: EncapsulationCiphertext, + pub signature: Signature, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Confirm { + pub signature: Signature, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HandshakeKind { + Hello = 1, + HelloReply, + Confirm, +} + +impl TryFrom for HandshakeKind { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let tag: u8 = value.try_into()?; + match tag { + 1 => Ok(Self::Hello), + 2 => Ok(Self::HelloReply), + 3 => Ok(Self::Confirm), + _ => Err(dcbor::Error::msg("unknown message tag")), + } + } +} + +pub fn verify_transcript_signature( + signing_key: &SigningPublicKey, + signature: &Signature, + transcript: &[u8], +) -> Result<(), QlError> { + if signing_key.verify(signature, &transcript) { + Ok(()) + } else { + Err(QlError::InvalidSignature) + } +} + +impl From for CBOR { + fn from(value: HandshakeRecord) -> Self { + match value { + HandshakeRecord::Hello(message) => CBOR::from(vec![ + CBOR::from(HandshakeKind::Hello as u8), + CBOR::from(message.nonce), + CBOR::from(message.kem_ct), + ]), + HandshakeRecord::HelloReply(message) => CBOR::from(vec![ + CBOR::from(HandshakeKind::HelloReply as u8), + CBOR::from(message.nonce), + CBOR::from(message.kem_ct), + CBOR::from(message.signature), + ]), + HandshakeRecord::Confirm(message) => CBOR::from(vec![ + CBOR::from(HandshakeKind::Confirm as u8), + CBOR::from(message.signature), + ]), + } + } +} + +impl TryFrom for HandshakeRecord { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let mut iter = value.try_into_array()?.into_iter(); + let tag: HandshakeKind = iter + .next() + .ok_or_else(|| dcbor::Error::msg("missing handshake tag"))? + .try_into()?; + match tag { + HandshakeKind::Hello => { + let [nonce_cbor, kem_ct_cbor] = take_fields(iter)?; + Ok(HandshakeRecord::Hello(Hello { + nonce: nonce_cbor.try_into()?, + kem_ct: kem_ct_cbor.try_into()?, + })) + } + HandshakeKind::HelloReply => { + let [nonce_cbor, kem_ct_cbor, signature_cbor] = take_fields(iter)?; + Ok(HandshakeRecord::HelloReply(HelloReply { + nonce: nonce_cbor.try_into()?, + kem_ct: kem_ct_cbor.try_into()?, + signature: signature_cbor.try_into()?, + })) + } + HandshakeKind::Confirm => { + let [signature_cbor] = take_fields(iter)?; + Ok(HandshakeRecord::Confirm(Confirm { + signature: signature_cbor.try_into()?, + })) + } + } + } +} diff --git a/ql/src/wire/heartbeat.rs b/ql/src/wire/heartbeat.rs new file mode 100644 index 00000000..43f8c43a --- /dev/null +++ b/ql/src/wire/heartbeat.rs @@ -0,0 +1,32 @@ +use dcbor::CBOR; + +use super::take_fields; +use crate::MessageId; + +#[derive(Debug, Clone, PartialEq)] +pub struct HeartbeatBody { + pub message_id: MessageId, + pub valid_until: u64, +} + +impl From for CBOR { + fn from(value: HeartbeatBody) -> Self { + CBOR::from(vec![ + CBOR::from(value.message_id), + CBOR::from(value.valid_until), + ]) + } +} + +impl TryFrom for HeartbeatBody { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let iter = value.try_into_array()?.into_iter(); + let [message_id, valid_until] = take_fields(iter)?; + Ok(Self { + message_id: message_id.try_into()?, + valid_until: valid_until.try_into()?, + }) + } +} diff --git a/ql/src/wire/message.rs b/ql/src/wire/message.rs new file mode 100644 index 00000000..5f246d84 --- /dev/null +++ b/ql/src/wire/message.rs @@ -0,0 +1,141 @@ +use bc_components::XID; +use dcbor::CBOR; + +use super::take_fields; +use crate::{MessageId, RouteId}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MessageKind { + Request, + Response, + Event, + Nack, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Ack; + +#[derive(Debug, Clone, PartialEq)] +pub struct MessageBody { + pub message_id: MessageId, + pub valid_until: u64, + pub kind: MessageKind, + pub route_id: RouteId, + pub payload: CBOR, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DecryptedMessage { + pub sender: XID, + pub recipient: XID, + pub kind: MessageKind, + pub message_id: MessageId, + pub route_id: RouteId, + pub valid_until: u64, + pub payload: CBOR, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Nack { + Unknown, + UnknownRoute, + InvalidPayload, + Expired, +} + +impl From for CBOR { + fn from(value: MessageKind) -> Self { + let kind = match value { + MessageKind::Request => 1, + MessageKind::Response => 2, + MessageKind::Event => 3, + MessageKind::Nack => 6, + }; + CBOR::from(kind) + } +} + +impl TryFrom for MessageKind { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let kind: u64 = value.try_into()?; + match kind { + 1 => Ok(Self::Request), + 2 => Ok(Self::Response), + 3 => Ok(Self::Event), + 6 => Ok(Self::Nack), + _ => Err(dcbor::Error::msg("unknown record kind")), + } + } +} + +impl From for CBOR { + fn from(_value: Ack) -> Self { + CBOR::null() + } +} + +impl TryFrom for Ack { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + if value.is_null() { + Ok(Self) + } else { + Err(dcbor::Error::msg("expected null")) + } + } +} + +impl From for CBOR { + fn from(value: MessageBody) -> Self { + CBOR::from(vec![ + CBOR::from(value.message_id), + CBOR::from(value.valid_until), + CBOR::from(value.kind), + CBOR::from(value.route_id), + value.payload, + ]) + } +} + +impl TryFrom for MessageBody { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let iter = value.try_into_array()?.into_iter(); + let [message_id, valid_until, kind, route_id, payload] = take_fields(iter)?; + Ok(Self { + message_id: message_id.try_into()?, + valid_until: valid_until.try_into()?, + kind: kind.try_into()?, + route_id: route_id.try_into()?, + payload, + }) + } +} + +impl From for CBOR { + fn from(value: Nack) -> Self { + let value = match value { + Nack::Unknown => 0, + Nack::UnknownRoute => 1, + Nack::InvalidPayload => 2, + Nack::Expired => 3, + }; + CBOR::from(value) + } +} + +impl From for Nack { + fn from(value: CBOR) -> Self { + let value: u8 = value.try_into().unwrap_or_default(); + match value { + 1 => Nack::UnknownRoute, + 2 => Nack::InvalidPayload, + 3 => Nack::Expired, + _ => Nack::Unknown, + } + } +} diff --git a/ql/src/wire/mod.rs b/ql/src/wire/mod.rs new file mode 100644 index 00000000..296f9ac7 --- /dev/null +++ b/ql/src/wire/mod.rs @@ -0,0 +1,182 @@ +use dcbor::CBOR; + +pub mod handshake; +pub mod heartbeat; +pub mod message; +pub mod pair; + +use bc_components::{EncryptedMessage, XID}; + +use crate::wire::{handshake::HandshakeRecord, pair::PairRequestRecord}; + +#[derive(Debug, Clone, PartialEq)] +pub struct QlRecord { + pub header: QlHeader, + pub payload: QlPayload, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct QlHeader { + pub sender: XID, + pub recipient: XID, +} + +impl QlHeader { + pub fn aad(&self) -> Vec { + CBOR::from(self.clone()).to_cbor_data() + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum QlPayload { + Handshake(HandshakeRecord), + Pair(PairRequestRecord), + Message(EncryptedMessage), + Heartbeat(EncryptedMessage), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QlTag { + Handshake = 1, + Pairing = 2, + Record = 3, + Heartbeat = 4, +} + +impl From for CBOR { + fn from(value: QlTag) -> Self { + CBOR::from(value as u8) + } +} + +impl TryFrom for QlTag { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let tag: u8 = value.try_into()?; + match tag { + 1 => Ok(Self::Handshake), + 2 => Ok(Self::Pairing), + 3 => Ok(Self::Record), + 4 => Ok(Self::Heartbeat), + _ => Err(dcbor::Error::msg("unknown message tag")), + } + } +} + +impl From for CBOR { + fn from(value: QlRecord) -> Self { + let (tag, payload) = match value.payload { + QlPayload::Handshake(message) => (QlTag::Handshake, CBOR::from(message)), + QlPayload::Pair(message) => (QlTag::Pairing, CBOR::from(message)), + QlPayload::Message(message) => (QlTag::Record, CBOR::from(message)), + QlPayload::Heartbeat(message) => (QlTag::Heartbeat, CBOR::from(message)), + }; + CBOR::from(vec![ + CBOR::from(tag as u8), + CBOR::from(value.header), + payload, + ]) + } +} + +impl TryFrom for QlRecord { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let iter = value.try_into_array()?.into_iter(); + let [tag_cbor, header_cbor, payload] = take_fields(iter)?; + let tag = QlTag::try_from(tag_cbor)?; + let header = QlHeader::try_from(header_cbor)?; + match tag { + QlTag::Handshake => { + let message = HandshakeRecord::try_from(payload)?; + Ok(QlRecord { + header, + payload: QlPayload::Handshake(message), + }) + } + QlTag::Pairing => { + let message = PairRequestRecord::try_from(payload)?; + Ok(QlRecord { + header, + payload: QlPayload::Pair(message), + }) + } + QlTag::Record => { + let message = EncryptedMessage::try_from(payload)?; + Ok(QlRecord { + header, + payload: QlPayload::Message(message), + }) + } + QlTag::Heartbeat => { + let message = EncryptedMessage::try_from(payload)?; + Ok(QlRecord { + header, + payload: QlPayload::Heartbeat(message), + }) + } + } + } +} + +impl From for CBOR { + fn from(value: QlHeader) -> Self { + CBOR::from(vec![CBOR::from(value.sender), CBOR::from(value.recipient)]) + } +} + +impl TryFrom for QlHeader { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let iter = value.try_into_array()?.into_iter(); + let [sender_cbor, recipient_cbor] = take_fields(iter)?; + Ok(Self { + sender: sender_cbor.try_into()?, + recipient: recipient_cbor.try_into()?, + }) + } +} + +pub(crate) fn take_fields( + mut iter: impl Iterator, +) -> Result<[CBOR; N], dcbor::Error> { + use std::mem::MaybeUninit; + + let mut fields: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; + for (index, slot) in fields.iter_mut().enumerate() { + let Some(value) = iter.next() else { + for init in &mut fields[..index] { + unsafe { init.assume_init_drop() }; + } + return Err(dcbor::Error::msg("array too short")); + }; + slot.write(value); + } + let result = unsafe { std::ptr::read(&fields as *const _ as *const [CBOR; N]) }; + if iter.next().is_some() { + return Err(dcbor::Error::msg("array too long")); + } + Ok(result) +} + +#[test] +fn take_fields_reads_exact_count() { + let values = vec![CBOR::from(1u8), CBOR::from(2u8), CBOR::from(3u8)]; + let mut iter = values.into_iter(); + let [first, second, third] = take_fields(&mut iter).unwrap(); + assert_eq!(u8::try_from(first).unwrap(), 1); + assert_eq!(u8::try_from(second).unwrap(), 2); + assert_eq!(u8::try_from(third).unwrap(), 3); + assert!(iter.next().is_none()); +} + +#[test] +fn take_fields_rejects_short_arrays() { + let values = vec![CBOR::from(1u8)]; + let mut iter = values.into_iter(); + let result: Result<[CBOR; 2], _> = take_fields(&mut iter); + assert!(result.is_err()); +} diff --git a/ql/src/wire/pair.rs b/ql/src/wire/pair.rs new file mode 100644 index 00000000..276de804 --- /dev/null +++ b/ql/src/wire/pair.rs @@ -0,0 +1,68 @@ +use bc_components::{EncapsulationCiphertext, EncapsulationPublicKey, Signature, SigningPublicKey}; +use dcbor::CBOR; + +use super::take_fields; +use crate::MessageId; + +#[derive(Debug, Clone, PartialEq)] +pub struct PairRequestRecord { + pub kem_ct: EncapsulationCiphertext, + pub encrypted: bc_components::EncryptedMessage, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PairRequestBody { + pub message_id: MessageId, + pub valid_until: u64, + pub signing_pub_key: SigningPublicKey, + pub encapsulation_pub_key: EncapsulationPublicKey, + pub proof: Signature, +} + +impl From for CBOR { + fn from(value: PairRequestRecord) -> Self { + CBOR::from(vec![CBOR::from(value.kem_ct), CBOR::from(value.encrypted)]) + } +} + +impl TryFrom for PairRequestRecord { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let iter = value.try_into_array()?.into_iter(); + let [kem_ct_cbor, encrypted_cbor] = take_fields(iter)?; + Ok(Self { + kem_ct: kem_ct_cbor.try_into()?, + encrypted: encrypted_cbor.try_into()?, + }) + } +} + +impl From for CBOR { + fn from(value: PairRequestBody) -> Self { + CBOR::from(vec![ + CBOR::from(value.message_id), + CBOR::from(value.valid_until), + CBOR::from(value.signing_pub_key), + CBOR::from(value.encapsulation_pub_key), + CBOR::from(value.proof), + ]) + } +} + +impl TryFrom for PairRequestBody { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let iter = value.try_into_array()?.into_iter(); + let [message_id, valid_until, signing_pub_key, encapsulation_pub_key, proof] = + take_fields(iter)?; + Ok(Self { + message_id: message_id.try_into()?, + valid_until: valid_until.try_into()?, + signing_pub_key: signing_pub_key.try_into()?, + encapsulation_pub_key: encapsulation_pub_key.try_into()?, + proof: proof.try_into()?, + }) + } +} From 64c2fe77c2076687bbbd62e232b499e9b52520b2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:42 -0400 Subject: [PATCH 003/304] ql: stabilize ql runtime with tests, persistence, and crypto cleanup --- ql/README.md | 143 ++ ql/ql-v2.presenterm.md | 285 +++ ql/src/crypto/mod.rs | 25 - ql/src/id.rs | 48 +- ql/src/lib.rs | 13 +- ql/src/platform.rs | 20 +- ql/src/router.rs | 6 + ql/src/runtime/core.rs | 232 ++- ql/src/runtime/handle.rs | 8 +- ql/src/runtime/internal.rs | 81 +- ql/src/runtime/mod.rs | 4 +- ql/src/runtime/replay_cache.rs | 180 ++ ql/src/runtime/tests.rs | 1588 ----------------- ql/src/tests/handshake.rs | 292 +++ ql/src/tests/heartbeat.rs | 641 +++++++ ql/src/tests/mod.rs | 625 +++++++ ql/src/tests/persistence.rs | 228 +++ ql/src/tests/requests.rs | 445 +++++ .../handshake.rs => wire/handshake/crypto.rs} | 35 +- .../wire/{handshake.rs => handshake/mod.rs} | 24 +- .../heartbeat.rs => wire/heartbeat/crypto.rs} | 4 +- .../wire/{heartbeat.rs => heartbeat/mod.rs} | 3 + .../message.rs => wire/message/crypto.rs} | 16 +- ql/src/wire/{message.rs => message/mod.rs} | 3 + ql/src/wire/mod.rs | 24 +- .../{crypto/pair.rs => wire/pair/crypto.rs} | 31 +- ql/src/wire/{pair.rs => pair/mod.rs} | 13 +- 27 files changed, 3160 insertions(+), 1857 deletions(-) create mode 100644 ql/README.md create mode 100644 ql/ql-v2.presenterm.md delete mode 100644 ql/src/crypto/mod.rs create mode 100644 ql/src/runtime/replay_cache.rs delete mode 100644 ql/src/runtime/tests.rs create mode 100644 ql/src/tests/handshake.rs create mode 100644 ql/src/tests/heartbeat.rs create mode 100644 ql/src/tests/mod.rs create mode 100644 ql/src/tests/persistence.rs create mode 100644 ql/src/tests/requests.rs rename ql/src/{crypto/handshake.rs => wire/handshake/crypto.rs} (80%) rename ql/src/wire/{handshake.rs => handshake/mod.rs} (87%) rename ql/src/{crypto/heartbeat.rs => wire/heartbeat/crypto.rs} (91%) rename ql/src/wire/{heartbeat.rs => heartbeat/mod.rs} (96%) rename ql/src/{crypto/message.rs => wire/message/crypto.rs} (81%) rename ql/src/wire/{message.rs => message/mod.rs} (99%) rename ql/src/{crypto/pair.rs => wire/pair/crypto.rs} (80%) rename ql/src/wire/{pair.rs => pair/mod.rs} (87%) diff --git a/ql/README.md b/ql/README.md new file mode 100644 index 00000000..d39e4e90 --- /dev/null +++ b/ql/README.md @@ -0,0 +1,143 @@ +# QL Protocol (v2) + +QL is a compact, session-oriented protocol for authenticated and encrypted messaging +between peers over arbitrary transports. It targets low-bandwidth and high-latency +links while preserving strong cryptography, explicit request/response semantics, and +a clean developer-facing API. + +This crate (`ql`) implements the protocol stack: wire format, crypto, runtime state +machine, and routing. For a deeper comparison with v1, see `ql-protocol-v2.md`. + +## features +- Fixed CBOR wire format: `QlRecord` = `[tag, header, payload]`. +- Mutual-auth handshake (`Hello`, `HelloReply`, `Confirm`) with signed transcript. +- Session keys derived from KEM secrets; payloads use AEAD (ChaCha20-Poly1305). +- Sessions are ephemeral and scoped to a handshake; no long-lived symmetric keys. +- First-contact pairing request with KEM-wrapped payloads and proof signature. +- Encrypted messages with explicit `Request`, `Response`, `Event`, and `Nack`. +- `MessageId`, `RouteId`, and `valid_until` for routing and freshness. +- Heartbeats for keepalive and disconnect detection. +- Runtime state machine for sessions, timeouts, outbound queues, and correlation. +- Router for typed dispatch and automatic response wiring. +- Transport abstraction via `QlPlatform` for BLE, TCP, or other links. + +## overview +QL provides a full session protocol rather than isolated message sealing. It covers: +- Mutual authentication and end-to-end encryption above the transport. +- First-contact pairing for provisioning keys and establishing trust. +- Typed routing with explicit request/response/event semantics. +- Runtime lifecycle management (handshake, keepalive, timeouts, errors). +- Portability across transports via a minimal platform abstraction. + +### security +- Mutual authentication via a signed handshake transcript. +- Session keys derived from KEM secrets; payloads are protected with AEAD + and header AAD. +- End-to-end protection above the transport layer; pairing supports first-contact + key exchange and proof of key possession. +- Message freshness enforced via `valid_until`; replay caching is not built-in, + so applications can optionally track `MessageId` if needed. + +### session vs per-message sealing +- v1 (gstp + envelope) signs every message and then encrypts it to the recipient. + each message uses fresh encapsulation, so keys and signatures are per-message. +- v2 (ql) signs the handshake transcript once, derives a session key, then uses + AEAD for each message with the header as AAD. +- encryption strength uses the same primitive (ChaCha20-Poly1305). post-quantum + security depends on key schemes (ML-KEM + ML-DSA with `pqcrypto` enabled). +- tradeoffs: v2 is faster and smaller; v1 has per-message signature and key + isolation. v2's AEAD provides in-session integrity but is not publicly + verifiable and has a larger blast radius if a session key leaks. + +### performance +- Public-key operations are paid once per session; steady-state traffic is + symmetric AEAD. +- Compact CBOR record framing keeps headers and serialization overhead small. +- Optional heartbeats provide liveness detection without heavy traffic. + +### developer experience +- Typed routes via `RequestResponse` and `Event` traits with explicit `RouteId`. +- Router handles decode, dispatch, and response wiring automatically. +- Runtime manages sessions, timeouts, outbound queues, and request correlation. +- `QlPlatform` abstracts the transport for portability and testability. + +## message sizes +Sizes below are CBOR record sizes from `protocol_record_size_breakdown` in +`ql/src/tests/mod.rs`. + +| Record | Size (bytes) | +| :-- | --: | +| Handshake Hello | 132 | +| Handshake HelloReply | 2563 | +| Handshake Confirm | 2510 | +| Pair request | 4065 | +| Message (empty payload) | 199 | +| Heartbeat | 196 | + +Handshake total is 5205 bytes (132 + 2563 + 2510). At 20 kBps transport +throughput, raw transmit time is about 0.26 s. + +## protocol overview + +### record framing +All traffic is encoded as a `QlRecord` with a small, fixed shape: +- `tag` selects the payload type (handshake, pair, record, heartbeat). +- `header` is unencrypted but authenticated data (AEAD AAD) used for routing + (sender and recipient XIDs). +- `payload` is a CBOR-encoded handshake/pair body or an encrypted message. + +### handshake +The handshake is a three-message exchange: +- `Hello`: initiator sends a nonce and KEM ciphertext. +- `HelloReply`: responder returns its nonce, KEM ciphertext, and a signature + over the transcript. +- `Confirm`: initiator signs the transcript to confirm mutual authentication. + +Both sides derive the session key from the KEM secrets and transcript digest. +After the handshake, all records use symmetric AEAD with the header as AAD. + +### pairing (first-contact) +Pairing is a standalone request that KEM-encrypts a payload containing: +- a `MessageId` and `valid_until` timestamp +- the sender's signing and encapsulation public keys +- a proof signature binding those keys + +This enables establishing trust without an existing session. + +### message records +Steady-state messages are sent as encrypted records with a typed body: +- `MessageKind`: `Request`, `Response`, `Event`, or `Nack` +- `MessageId`, `RouteId`, `valid_until`, and CBOR payload + +Nacks communicate standard failure reasons (unknown route, invalid payload, +expired) so peers can recover consistently. + +### heartbeats +Heartbeats are lightweight encrypted records used by the runtime to maintain +session liveness and detect disconnects. + +### routing and dispatch +`RouteId` maps to concrete request/response or event types. The router decodes +payloads, dispatches handlers, and ensures each request receives a response or +a `Nack`. + +### sequence diagram +```mermaid +sequenceDiagram + participant A as Initiator + participant B as Responder + A->>B: Hello (nonce, KEM ct) + B->>A: HelloReply (nonce, KEM ct, signature) + A->>B: Confirm (signature) + Note over A,B: Session key derived, AEAD enabled + A->>B: Encrypted Record (Request) + B->>A: Encrypted Record (Response) + A-->>B: Encrypted Heartbeat (optional) +``` + +## code map +- Wire format: `ql/src/wire/*` +- Cryptography: `ql/src/crypto/*` +- Runtime state machine: `ql/src/runtime/*` +- Routing and traits: `ql/src/router.rs`, `ql/src/lib.rs` +- Transport abstraction: `ql/src/platform.rs` diff --git a/ql/ql-v2.presenterm.md b/ql/ql-v2.presenterm.md new file mode 100644 index 00000000..d4a0fff2 --- /dev/null +++ b/ql/ql-v2.presenterm.md @@ -0,0 +1,285 @@ +--- +theme: + name: gruvbox-dark +--- + +# quantumlink protocol v2 +post-quantum, session-based message protocol + + + +# ql v1: constraints +- no message id / sequence id +- no protocol-level request/response pairing +- each platform had to interpret + correlate by hand +- no ack/nack +- no notion of 'liveness'/'connected' status +- ~6.6KB min sealed event + - sender xid document (pq pubkeys) + - per-message signature + - recipient encryption (+ continuations) +- more a utility crate than a protocol + + + +# v1 vs v2 + + + + +## v1 +- gstp sealed envelope per message +- per-message sign+encrypt (envelope) +- implicit req/resp in enum variants +- app-owned pairing, timeouts, keepalive, connected status + + + +## v2 +- compact record + typed payloads +- handshake signatures + per‑message aead under symmetric session key +- explicit kind + ids + nack +- runtime handles pairing, timeouts, keepalive, connected status, request/response matching + + + +# design shift: per-message -> session +- v1 sealed each message +- v2 signs once, then aead per message + +```text +v1: seal(msg) = sign(msg) + encrypt(recipient) +v2: session_key = handshake() +v2: aead(msg, aad=header) +``` + + +_aead = authenticated encryption with associated data_ + +_aad = additional authenticated data (visible, integrity-protected)_ + + + +# configurable host platform +- same runtime across keyos / mobile / desktop +- host supplies pq keys, io, timers, callbacks + +```rust +pub trait QlPlatform { + // pq identity + fn signing_private_key(&self) -> &MLDSAPrivateKey; + fn signing_public_key(&self) -> &MLDSAPublicKey; + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey; + fn encapsulation_public_key(&self) -> &MLKEMPublicKey; + + // transport + runtime hooks + fn fill_bytes(&self, data: &mut [u8]); + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; + + // event handlers + fn handle_peer_status(&self, peer: XID, session: &PeerSession); + fn handle_inbound(&self, event: HandlerEvent); +} +``` + + + +# multi-peer runtime +- runtime tracks sessions per peer +- concurrent handshakes + keepalive per peer + +```rust +handle.register_peer(peer, signing_key, encapsulation_key); +handle.connect(peer)?; +``` + + + +# protocol breakdown +```mermaid +render +width:90% +sequenceDiagram + participant A as initiator + participant B as responder + + Note over A,B: pairing (first contact) + A->>B: pair request (kem + signed payload) + + Note over A,B: handshake (mutual auth) + A->>B: hello (nonce + kem ct) + B->>A: hello reply (nonce + kem ct + signature) + A->>B: confirm (signature) + + Note over A,B: session established + A->>B: request / event (aead + aad header) + B->>A: response / nack (aead + aad header) + A-->>B: heartbeat (optional) +``` + + + +# wire framing: routable header +- record = [tag, header, payload] +- header is unencrypted but authenticated (aad) + +```rust +pub struct QlRecord { + pub header: QlHeader, + pub payload: QlPayload, +} + +pub struct QlHeader { + pub sender: XID, + pub recipient: XID, +} +``` + + + +# handshake flow + records +- hello: nonce + mlkem ciphertext +- reply: nonce + mlkem ciphertext + mldsa signature +- confirm: mldsa signature, then session key + +```rust +pub struct Hello { + pub nonce: Nonce, + pub kem_ct: MLKEMCiphertext, +} + +pub struct HelloReply { + pub nonce: Nonce, + pub kem_ct: MLKEMCiphertext, + pub signature: MLDSASignature, +} + +pub struct Confirm { + pub signature: MLDSASignature, +} +``` + + + +# session key derivation +- transcript binds ids + nonces + kem ciphertexts +- session key = digest(initiator_secret, responder_secret, transcript) + +```rust +let transcript = cbor([ + initiator, responder, + hello.nonce, reply.nonce, + hello.kem_ct, reply.kem_ct, +]); +let payload = cbor([initiator_secret, responder_secret, transcript]); +let digest = Digest::from_image(payload); +let session_key = SymmetricKey::from_data(*digest.data()); +``` + + + +# message modalities +- request / response +- event: fire-and-forget or acked +- nack for structured failure + +```rust +pub enum MessageKind { + Request, + Response, + Event, + Nack, +} +``` + + + +# message body: routing + expiry +- message_id + route_id +- valid_until for freshness + +```rust +pub struct MessageBody { + pub message_id: MessageId, + pub valid_until: u64, + pub kind: MessageKind, + pub route_id: RouteId, + pub payload: CBOR, +} +``` + + + +# nack reasons +- unknown route / invalid payload / expired + +```rust +pub enum Nack { + Unknown, + UnknownRoute, + InvalidPayload, + Expired, +} +``` + + + +# type-safe routing +- route id is const per type +- compiler couples request -> response + +```rust +pub trait RequestResponse: QlCodec { + const ID: RouteId; + type Response: QlCodec; +} + +pub trait Event: QlCodec { + const ID: RouteId; +} +``` + + + +# router wiring +- builder ties route ids to handlers +- unknown routes auto-nack + +```rust +let router = Router::builder() + .add_request_handler::() + .add_event_handler::() + .build(state); +``` + + + +# runtime api flow +- request returns response or nack +- events are fire-and-forget (or acked) + +```rust +let reply = handle.request(msg, peer, RequestConfig::default()).await?; +handle.send_event(status, peer); +``` + + + +# performance snapshot (cbor sizes) +| proto | message | bytes | notes | +| :-- | :-- | --: | :-- | +| v1 | sealed msg (exchange_rate) | 6645 | sign+encrypt | +| v1 | sealed heartbeat | 6633 | sign+encrypt | +| v2 | hello | 132 | kem+nonce | +| v2 | hello reply | 2563 | sig+kem | +| v2 | confirm | 2510 | sig | +| v2 | pair request | 4065 | sig+kem | +| v2 | message (empty) | 199 | steady-state | +| v2 | heartbeat | 196 | steady-state | + +handshake total: 5205 bytes + + + +# close +- smaller packets, clearer flow, typed api +- ql v2 is the protocol, not just a crate diff --git a/ql/src/crypto/mod.rs b/ql/src/crypto/mod.rs deleted file mode 100644 index 5d6e6a4e..00000000 --- a/ql/src/crypto/mod.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::{wire::message::Nack, MessageId, QlError}; - -pub mod handshake; -pub mod heartbeat; -pub mod message; -pub mod pair; - -fn ensure_not_expired(id: MessageId, valid_until: u64) -> Result<(), QlError> { - let now = now_secs(); - if now > valid_until { - Err(QlError::Nack { - id, - nack: Nack::Expired, - }) - } else { - Ok(()) - } -} - -fn now_secs() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) -} diff --git a/ql/src/id.rs b/ql/src/id.rs index eaab1d55..bc90db15 100644 --- a/ql/src/id.rs +++ b/ql/src/id.rs @@ -3,17 +3,7 @@ use std::fmt; use dcbor::CBOR; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct MessageId(u64); - -impl MessageId { - pub const fn new(value: u64) -> Self { - Self(value) - } - - pub const fn value(self) -> u64 { - self.0 - } -} +pub struct MessageId(pub u64); impl fmt::Display for MessageId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -21,18 +11,6 @@ impl fmt::Display for MessageId { } } -impl From for MessageId { - fn from(value: u64) -> Self { - Self(value) - } -} - -impl From for u64 { - fn from(value: MessageId) -> Self { - value.0 - } -} - impl From for CBOR { fn from(value: MessageId) -> Self { CBOR::from(value.0) @@ -49,17 +27,7 @@ impl TryFrom for MessageId { } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct RouteId(u64); - -impl RouteId { - pub const fn new(value: u64) -> Self { - Self(value) - } - - pub const fn value(self) -> u64 { - self.0 - } -} +pub struct RouteId(pub u64); impl fmt::Display for RouteId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -67,18 +35,6 @@ impl fmt::Display for RouteId { } } -impl From for RouteId { - fn from(value: u64) -> Self { - Self(value) - } -} - -impl From for u64 { - fn from(value: RouteId) -> Self { - value.0 - } -} - impl From for CBOR { fn from(value: RouteId) -> Self { CBOR::from(value.0) diff --git a/ql/src/lib.rs b/ql/src/lib.rs index 2ef441ae..34c09ca4 100644 --- a/ql/src/lib.rs +++ b/ql/src/lib.rs @@ -1,12 +1,21 @@ -pub mod crypto; mod id; -pub mod router; pub mod platform; +pub mod router; pub mod runtime; pub mod wire; pub use id::*; +#[cfg(test)] +mod tests; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Peer { + pub peer: bc_components::XID, + pub signing_key: bc_components::MLDSAPublicKey, + pub encapsulation_key: bc_components::MLKEMPublicKey, +} + pub trait QlCodec: Into + TryFrom {} impl QlCodec for T where T: Into + TryFrom {} diff --git a/ql/src/platform.rs b/ql/src/platform.rs index 8e179435..be13e32f 100644 --- a/ql/src/platform.rs +++ b/ql/src/platform.rs @@ -1,32 +1,36 @@ use std::{future::Future, pin::Pin, time::Duration}; use bc_components::{ - EncapsulationPrivateKey, EncapsulationPublicKey, Signer, SigningPublicKey, XID, + MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, SigningPublicKey, XID, }; use crate::{ runtime::{HandlerEvent, PeerSession}, - QlError, + Peer, QlError, }; pub type PlatformFuture<'a, T> = Pin + 'a>>; pub trait QlPlatform { - fn signer(&self) -> &dyn Signer; - fn signing_public_key(&self) -> &SigningPublicKey; - fn encapsulation_private_key(&self) -> &EncapsulationPrivateKey; - fn encapsulation_public_key(&self) -> &EncapsulationPublicKey; + fn signing_private_key(&self) -> &MLDSAPrivateKey; + fn signing_public_key(&self) -> &MLDSAPublicKey; + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey; + fn encapsulation_public_key(&self) -> &MLKEMPublicKey; - fn fill_bytes(&self, data: &mut [u8]); + fn fill_random_bytes(&self, data: &mut [u8]); fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; + + fn load_peers(&self) -> PlatformFuture<'_, Vec>; + fn persist_peers(&self, peers: Vec); + fn handle_peer_status(&self, peer: XID, session: &PeerSession); fn handle_inbound(&self, event: HandlerEvent); } pub(crate) trait QlPlatformExt: QlPlatform { fn xid(&self) -> XID { - XID::new(&self.signing_public_key()) + XID::new(SigningPublicKey::MLDSA(self.signing_public_key().clone())) } } diff --git a/ql/src/router.rs b/ql/src/router.rs index 51a78af1..80c766b0 100644 --- a/ql/src/router.rs +++ b/ql/src/router.rs @@ -86,6 +86,12 @@ pub struct RouterBuilder { handlers: HashMap>, } +impl Default for RouterBuilder { + fn default() -> Self { + Self::new() + } +} + impl RouterBuilder { pub fn new() -> Self { Self { diff --git a/ql/src/runtime/core.rs b/ql/src/runtime/core.rs index 093a9719..d5de10d9 100644 --- a/ql/src/runtime/core.rs +++ b/ql/src/runtime/core.rs @@ -2,27 +2,27 @@ use std::{ cmp::Reverse, collections::binary_heap::PeekMut, future::Future, task::Poll, time::Instant, }; -use bc_components::{EncapsulationPublicKey, XID}; +use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SigningPublicKey, XID}; use dcbor::CBOR; use futures_lite::future::poll_fn; use crate::{ - crypto::{handshake, heartbeat, message, pair}, platform::{QlPlatform, QlPlatformExt}, runtime::{ internal::{ next_timeout_deadline, now_secs, peer_hello_wins, HelloAction, InFlightWrite, - KeepAliveState, LoopStep, OutboundMessage, PendingEntry, RuntimeCommand, RuntimeState, - TimeoutEntry, TimeoutKind, + KeepAliveState, LoopStep, OutboundMessage, OutboundPayload, PendingEntry, + RuntimeCommand, RuntimeState, TimeoutEntry, TimeoutKind, }, + replay_cache::{ReplayKey, ReplayNamespace}, HandlerEvent, InboundEvent, InboundRequest, InitiatorStage, KeepAliveConfig, PeerSession, Responder, Runtime, Token, }, wire::{ - handshake::HandshakeRecord, - heartbeat::HeartbeatBody, - message::{MessageBody, MessageKind, Nack}, - pair::PairRequestRecord, + handshake::{self, HandshakeRecord}, + heartbeat::{self, HeartbeatBody}, + message::{self, MessageBody, MessageKind, Nack}, + pair::{self, PairRequestRecord}, QlHeader, QlPayload, QlRecord, }, MessageId, QlError, RouteId, @@ -31,6 +31,11 @@ use crate::{ impl Runtime

{ pub async fn run(self) { let mut state = RuntimeState::new(); + for peer in self.platform.load_peers().await { + state + .peers + .upsert_peer(peer.peer, peer.signing_key, peer.encapsulation_key); + } let mut in_flight: Option> = None; while !self.rx.is_closed() { if in_flight.is_none() { @@ -97,15 +102,41 @@ impl Runtime

{ } fn start_next_write<'a>(&'a self, state: &mut RuntimeState) -> Option> { - let Some(message) = state.outbound.pop_front() else { - return None; - }; - Some(InFlightWrite { - peer: message.peer, - token: message.token, - message_id: message.message_id, - future: self.platform.write_message(message.bytes), - }) + while let Some(message) = state.outbound.pop_front() { + let bytes = match message.payload { + OutboundPayload::PreEncoded(bytes) => bytes, + OutboundPayload::DeferredMessage(body) => { + let Some(session_key) = state + .peers + .peer(message.peer) + .and_then(|entry| entry.session.session_key()) + else { + if let Some(id) = message.message_id { + if let Some(entry) = state.pending.remove(&id) { + let _ = entry.tx.send(Err(QlError::SendFailed)); + } + } + continue; + }; + let message = message::encrypt_message( + QlHeader { + sender: self.platform.xid(), + recipient: message.peer, + }, + session_key, + body, + ); + CBOR::from(message).to_cbor_data() + } + }; + return Some(InFlightWrite { + peer: message.peer, + token: message.token, + message_id: message.message_id, + future: self.platform.write_message(bytes), + }); + } + None } async fn next_step<'a>( @@ -198,15 +229,18 @@ impl Runtime

{ &self, state: &mut RuntimeState, peer: XID, - signing_key: bc_components::SigningPublicKey, - encapsulation_key: EncapsulationPublicKey, + signing_key: MLDSAPublicKey, + encapsulation_key: MLKEMPublicKey, ) { - let entry = state - .peers - .upsert_peer(peer, signing_key, encapsulation_key); - if let PeerSession::Disconnected = entry.session { - self.platform.handle_peer_status(peer, &entry.session); + { + let entry = state + .peers + .upsert_peer(peer, signing_key, encapsulation_key); + if let PeerSession::Disconnected = entry.session { + self.platform.handle_peer_status(peer, &entry.session); + } } + self.persist_peers(state); } fn handle_send_request( @@ -230,13 +264,10 @@ impl Runtime

{ let _ = respond_to.send(Err(QlError::UnknownPeer(recipient))); return; }; - let session_key = match &entry.session { - PeerSession::Connected { session_key, .. } => session_key, - _ => { - let _ = respond_to.send(Err(QlError::MissingSession(recipient))); - return; - } - }; + if !entry.session.is_connected() { + let _ = respond_to.send(Err(QlError::MissingSession(recipient))); + return; + } let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); let body = MessageBody { message_id: id, @@ -245,15 +276,6 @@ impl Runtime

{ route_id, payload, }; - let message = message::encrypt_message( - QlHeader { - sender: self.platform.xid(), - recipient, - }, - &session_key, - body, - ); - let bytes = CBOR::from(message).to_cbor_data(); state.pending.insert( id, PendingEntry { @@ -266,7 +288,13 @@ impl Runtime

{ kind: TimeoutKind::Request { id }, })); let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound(state, recipient, bytes, outbound_deadline, Some(id)); + self.enqueue_outbound( + state, + recipient, + OutboundPayload::DeferredMessage(body), + outbound_deadline, + Some(id), + ); } fn handle_send_event( @@ -277,13 +305,12 @@ impl Runtime

{ payload: CBOR, ) { let id = state.next_message_id(); - let Some(session_key) = state - .peers - .peer(recipient) - .and_then(|p| p.session.session_key()) - else { + let Some(entry) = state.peers.peer(recipient) else { return; }; + if !entry.session.is_connected() { + return; + } let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); let body = MessageBody { message_id: id, @@ -292,17 +319,14 @@ impl Runtime

{ route_id, payload, }; - let message = message::encrypt_message( - QlHeader { - sender: self.platform.xid(), - recipient, - }, - &session_key, - body, - ); - let bytes = CBOR::from(message).to_cbor_data(); let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound(state, recipient, bytes, outbound_deadline, None); + self.enqueue_outbound( + state, + recipient, + OutboundPayload::DeferredMessage(body), + outbound_deadline, + None, + ); } fn handle_send_response( @@ -317,33 +341,29 @@ impl Runtime

{ MessageKind::Response | MessageKind::Nack => kind, _ => return, }; - let Some(session_key) = state - .peers - .peer(recipient) - .and_then(|p| p.session.session_key()) - else { + let Some(entry) = state.peers.peer(recipient) else { return; }; + if !entry.session.is_connected() { + return; + } let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); let body = MessageBody { message_id: id, valid_until, kind, - route_id: RouteId::new(0), + route_id: RouteId(0), payload, }; - let message = message::encrypt_message( - QlHeader { - sender: self.platform.xid(), - recipient, - }, - &session_key, - body, - ); - let bytes = CBOR::from(message).to_cbor_data(); let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound(state, recipient, bytes, outbound_deadline, None); + self.enqueue_outbound( + state, + recipient, + OutboundPayload::DeferredMessage(body), + outbound_deadline, + None, + ); } fn handle_incoming(&self, state: &mut RuntimeState, bytes: Vec) { @@ -399,13 +419,18 @@ impl Runtime

{ Ok(payload) => payload, Err(_) => return, }; - let peer = XID::new(&payload.signing_pub_key); + let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); state .peers .upsert_peer(peer, payload.signing_pub_key, payload.encapsulation_pub_key); + self.persist_peers(state); self.handle_connect(state, peer); } + fn persist_peers(&self, state: &RuntimeState) { + self.platform.persist_peers(state.peers.all()); + } + fn handle_record( &self, state: &mut RuntimeState, @@ -422,10 +447,23 @@ impl Runtime

{ }; let record = match message::decrypt_message(&header, &encrypted, &session_key) { Ok(record) => record, - // TODO: fix this - Err(message::MessageError::Nack { .. }) => return, + Err(message::MessageError::Nack { id, nack, kind }) => { + self.handle_message_nack(state, peer, id, nack, kind); + return; + } Err(message::MessageError::Error(_)) => return, }; + let namespace = match record.kind { + MessageKind::Request | MessageKind::Event => ReplayNamespace::Peer, + MessageKind::Response | MessageKind::Nack => ReplayNamespace::Local, + }; + let replay_key = ReplayKey::new(peer, namespace, record.message_id); + if state + .replay_cache + .check_and_store_valid_until(replay_key, record.valid_until) + { + return; + } self.record_activity(state, peer); match record.kind { MessageKind::Response => { @@ -453,6 +491,20 @@ impl Runtime

{ } } + fn handle_message_nack( + &self, + state: &mut RuntimeState, + peer: XID, + id: MessageId, + nack: Nack, + kind: MessageKind, + ) { + if kind != MessageKind::Request { + return; + } + self.handle_send_response(state, id, peer, CBOR::from(nack), MessageKind::Nack); + } + fn handle_heartbeat( &self, state: &mut RuntimeState, @@ -502,7 +554,13 @@ impl Runtime

{ ); let bytes = CBOR::from(message).to_cbor_data(); let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound(state, peer, bytes, outbound_deadline, None); + self.enqueue_outbound( + state, + peer, + OutboundPayload::PreEncoded(bytes), + outbound_deadline, + None, + ); } fn keep_alive_config(&self) -> Option { @@ -637,7 +695,13 @@ impl Runtime

{ payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), }; let bytes = CBOR::from(message).to_cbor_data(); - self.enqueue_outbound(state, peer, bytes, deadline, None); + self.enqueue_outbound( + state, + peer, + OutboundPayload::PreEncoded(bytes), + deadline, + None, + ); } HelloAction::Ignore => {} } @@ -710,7 +774,13 @@ impl Runtime

{ }; let bytes = CBOR::from(message).to_cbor_data(); let deadline = Instant::now() + self.config.handshake_timeout; - self.enqueue_outbound(state, peer, bytes, deadline, None); + self.enqueue_outbound( + state, + peer, + OutboundPayload::PreEncoded(bytes), + deadline, + None, + ); } fn handle_confirm( @@ -829,7 +899,7 @@ impl Runtime

{ peer, token, message_id: None, - bytes, + payload: OutboundPayload::PreEncoded(bytes), }); state.timeouts.push(Reverse(TimeoutEntry { at: deadline, @@ -845,7 +915,7 @@ impl Runtime

{ &self, state: &mut RuntimeState, peer: XID, - bytes: Vec, + payload: OutboundPayload, deadline: Instant, message_id: Option, ) { @@ -854,7 +924,7 @@ impl Runtime

{ peer, token, message_id, - bytes, + payload, }); state.timeouts.push(Reverse(TimeoutEntry { at: deadline, diff --git a/ql/src/runtime/handle.rs b/ql/src/runtime/handle.rs index d08ef995..004d9e4f 100644 --- a/ql/src/runtime/handle.rs +++ b/ql/src/runtime/handle.rs @@ -5,7 +5,7 @@ use std::{ task::{Context, Poll}, }; -use bc_components::{EncapsulationPublicKey, SigningPublicKey, XID}; +use bc_components::{MLDSAPublicKey, MLKEMPublicKey, XID}; use dcbor::CBOR; use crate::{ @@ -39,7 +39,7 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { pin!(&mut self.rx).poll(cx).map(|result| { let payload = result.unwrap_or(Err(QlError::Cancelled))?; - Ok(T::try_from(payload).map_err(|_| QlError::InvalidPayload)?) + T::try_from(payload).map_err(|_| QlError::InvalidPayload) }) } } @@ -48,8 +48,8 @@ impl RuntimeHandle { pub fn register_peer( &self, peer: XID, - signing_key: SigningPublicKey, - encapsulation_key: EncapsulationPublicKey, + signing_key: MLDSAPublicKey, + encapsulation_key: MLKEMPublicKey, ) { self.send(RuntimeCommand::RegisterPeer { peer, diff --git a/ql/src/runtime/internal.rs b/ql/src/runtime/internal.rs index f763b47c..f7e8769c 100644 --- a/ql/src/runtime/internal.rs +++ b/ql/src/runtime/internal.rs @@ -5,28 +5,22 @@ use std::{ time::{Instant, SystemTime, UNIX_EPOCH}, }; -use bc_components::{EncapsulationPublicKey, SigningPublicKey, SymmetricKey, XID}; +use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey, XID}; use dcbor::CBOR; use crate::{ platform::PlatformFuture, - runtime::RequestConfig, + runtime::{replay_cache::ReplayCache, RequestConfig}, wire::{ handshake::{Hello, HelloReply}, - message::MessageKind, + message::{MessageBody, MessageKind}, }, - MessageId, QlError, RouteId, + MessageId, Peer, QlError, RouteId, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Token(u64); -impl Token { - pub(crate) fn next(self) -> Self { - Self(self.0.wrapping_add(1)) - } -} - #[derive(Debug, Clone)] pub struct KeepAliveState { pub token: Token, @@ -44,20 +38,22 @@ impl KeepAliveState { } } +impl Default for KeepAliveState { + fn default() -> Self { + Self::new() + } +} + #[derive(Debug, Clone)] pub struct PeerRecord { pub peer: XID, - pub signing_key: SigningPublicKey, - pub encapsulation_key: EncapsulationPublicKey, + pub signing_key: MLDSAPublicKey, + pub encapsulation_key: MLKEMPublicKey, pub session: PeerSession, } impl PeerRecord { - pub fn new( - peer: XID, - signing_key: SigningPublicKey, - encapsulation_key: EncapsulationPublicKey, - ) -> Self { + pub fn new(peer: XID, signing_key: MLDSAPublicKey, encapsulation_key: MLKEMPublicKey) -> Self { Self { peer, signing_key, @@ -88,8 +84,8 @@ impl PeerStore { pub fn upsert_peer( &mut self, peer: XID, - signing_key: SigningPublicKey, - encapsulation_key: EncapsulationPublicKey, + signing_key: MLDSAPublicKey, + encapsulation_key: MLKEMPublicKey, ) -> &mut PeerRecord { if let Some(index) = self.peers.iter().position(|record| record.peer == peer) { let record = &mut self.peers[index]; @@ -101,6 +97,17 @@ impl PeerStore { .push(PeerRecord::new(peer, signing_key, encapsulation_key)); self.peers.last_mut().expect("peer record just inserted") } + + pub fn all(&self) -> Vec { + self.peers + .iter() + .map(|record| Peer { + peer: record.peer, + signing_key: record.signing_key.clone(), + encapsulation_key: record.encapsulation_key.clone(), + }) + .collect() + } } #[derive(Debug, Clone)] @@ -117,7 +124,7 @@ pub enum PeerSession { handshake_token: Token, hello: Hello, reply: HelloReply, - secrets: crate::crypto::handshake::ResponderSecrets, + secrets: crate::wire::handshake::ResponderSecrets, deadline: Instant, }, Connected { @@ -129,10 +136,7 @@ pub enum PeerSession { impl PeerSession { #[inline] pub fn is_connected(&self) -> bool { - match self { - PeerSession::Connected { .. } => true, - _ => false, - } + matches!(self, PeerSession::Connected { .. }) } #[inline] @@ -153,8 +157,8 @@ pub enum InitiatorStage { pub(crate) enum RuntimeCommand { RegisterPeer { peer: XID, - signing_key: SigningPublicKey, - encapsulation_key: EncapsulationPublicKey, + signing_key: MLDSAPublicKey, + encapsulation_key: MLKEMPublicKey, }, Connect { peer: XID, @@ -186,31 +190,33 @@ pub struct RuntimeState { pub outbound: VecDeque, pub timeouts: BinaryHeap>, pub pending: HashMap, - pub next_message_id: u64, + pub next_message_id: Cell, + pub replay_cache: ReplayCache, } impl RuntimeState { pub fn new() -> Self { Self { peers: PeerStore::new(), - next_token: Cell::new(Token(0)), + next_token: Cell::new(Token(1)), outbound: VecDeque::new(), timeouts: BinaryHeap::new(), pending: HashMap::new(), - next_message_id: 1, + next_message_id: Cell::new(MessageId(1)), + replay_cache: ReplayCache::new(), } } pub fn next_token(&self) -> Token { let token = self.next_token.get(); - self.next_token.set(token.next()); + self.next_token.set(Token(token.0.wrapping_add(1))); token } - pub fn next_message_id(&mut self) -> MessageId { - let id = self.next_message_id; - self.next_message_id = id.wrapping_add(1); - MessageId::new(id) + pub fn next_message_id(&self) -> MessageId { + let id = self.next_message_id.get(); + self.next_message_id.set(MessageId(id.0.wrapping_add(1))); + id } } @@ -226,11 +232,16 @@ pub struct InFlightWrite<'a> { pub future: PlatformFuture<'a, Result<(), QlError>>, } +pub enum OutboundPayload { + PreEncoded(Vec), + DeferredMessage(MessageBody), +} + pub struct OutboundMessage { pub peer: XID, pub token: Token, pub message_id: Option, - pub bytes: Vec, + pub payload: OutboundPayload, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/ql/src/runtime/mod.rs b/ql/src/runtime/mod.rs index 8a57fa76..9a738968 100644 --- a/ql/src/runtime/mod.rs +++ b/ql/src/runtime/mod.rs @@ -4,9 +4,7 @@ pub use internal::{InitiatorStage, PeerSession, Token}; mod core; pub mod handle; pub(crate) mod internal; - -#[cfg(test)] -mod tests; +pub mod replay_cache; use std::time::Duration; diff --git a/ql/src/runtime/replay_cache.rs b/ql/src/runtime/replay_cache.rs new file mode 100644 index 00000000..876467c8 --- /dev/null +++ b/ql/src/runtime/replay_cache.rs @@ -0,0 +1,180 @@ +use std::{ + cmp::Reverse, + collections::{binary_heap::PeekMut, BinaryHeap, HashSet}, +}; + +use bc_components::XID; + +use crate::{runtime::internal::now_secs, MessageId}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ReplayNamespace { + Peer, + Local, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ReplayKey { + pub peer: XID, + pub namespace: ReplayNamespace, + pub message_id: MessageId, +} + +impl ReplayKey { + pub const fn new(peer: XID, namespace: ReplayNamespace, message_id: MessageId) -> Self { + Self { + peer, + namespace, + message_id, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ExpiryEntry { + expires_at: u64, + key: ReplayKey, +} + +impl Ord for ExpiryEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.expires_at + .cmp(&other.expires_at) + .then_with(|| self.key.cmp(&other.key)) + } +} + +impl PartialOrd for ExpiryEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +#[derive(Debug, Default)] +pub struct ReplayCache { + entries: HashSet, + expirations: BinaryHeap>, +} + +impl ReplayCache { + pub fn new() -> Self { + Self { + entries: HashSet::new(), + expirations: BinaryHeap::new(), + } + } + + pub fn len(&self) -> usize { + self.entries.len() + } + + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + pub fn add(&mut self, key: ReplayKey, expires_at: u64) { + if self.entries.insert(key) { + self.expirations + .push(Reverse(ExpiryEntry { expires_at, key })); + } + } + + pub fn check_and_store(&mut self, key: ReplayKey, expires_at: u64) -> bool { + let now_secs = now_secs(); + self.check_and_store_at(key, expires_at, now_secs) + } + + pub fn check_and_store_valid_until(&mut self, key: ReplayKey, valid_until: u64) -> bool { + let now_secs = now_secs(); + self.check_and_store_at(key, valid_until, now_secs) + } + + pub fn purge_expired(&mut self) { + let now_secs = now_secs(); + self.purge_expired_at(now_secs); + } + + pub fn clear_peer(&mut self, peer: XID) { + self.entries.retain(|entry| entry.peer != peer); + self.expirations.retain(|entry| entry.0.key.peer != peer); + } + + fn check_and_store_at(&mut self, key: ReplayKey, expires_at: u64, now_secs: u64) -> bool { + self.purge_expired_at(now_secs); + if self.entries.contains(&key) { + return true; + } + self.entries.insert(key); + self.expirations + .push(Reverse(ExpiryEntry { expires_at, key })); + false + } + + fn purge_expired_at(&mut self, now_secs: u64) { + while let Some(entry) = self.expirations.peek_mut() { + if entry.0.expires_at > now_secs { + break; + } + let entry = PeekMut::pop(entry).0; + self.entries.remove(&entry.key); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn peer_with_byte(byte: u8) -> XID { + XID::from_data([byte; XID::XID_SIZE]) + } + + #[test] + fn check_and_store_detects_replay() { + let mut cache = ReplayCache::new(); + let peer = peer_with_byte(1); + let key = ReplayKey::new(peer, ReplayNamespace::Peer, MessageId(1)); + let now_secs = 100; + let expires_at = 110; + + assert!(!cache.check_and_store_at(key, expires_at, now_secs)); + assert!(cache.check_and_store_at(key, expires_at, now_secs)); + } + + #[test] + fn purge_expired_removes_old_entries() { + let mut cache = ReplayCache::new(); + let now_secs = 100; + let expired_at = 99; + let future_at = 110; + + let key_old = ReplayKey::new(peer_with_byte(2), ReplayNamespace::Peer, MessageId(2)); + let key_new = ReplayKey::new(peer_with_byte(3), ReplayNamespace::Peer, MessageId(3)); + + cache.add(key_old, expired_at); + cache.add(key_new, future_at); + + cache.purge_expired_at(now_secs); + assert_eq!(cache.len(), 1); + assert!(!cache.check_and_store_at(key_old, future_at, now_secs)); + } + + #[test] + fn clear_peer_removes_peer_entries() { + let mut cache = ReplayCache::new(); + let now_secs = 100; + let expires_at = 110; + + let peer_a = peer_with_byte(4); + let peer_b = peer_with_byte(5); + let key_a = ReplayKey::new(peer_a, ReplayNamespace::Peer, MessageId(4)); + let key_b = ReplayKey::new(peer_b, ReplayNamespace::Peer, MessageId(5)); + + cache.add(key_a, expires_at); + cache.add(key_b, expires_at); + + cache.clear_peer(peer_a); + assert_eq!(cache.len(), 1); + assert!(!cache.check_and_store_at(key_a, expires_at, now_secs)); + } +} diff --git a/ql/src/runtime/tests.rs b/ql/src/runtime/tests.rs deleted file mode 100644 index ceec9294..00000000 --- a/ql/src/runtime/tests.rs +++ /dev/null @@ -1,1588 +0,0 @@ -use std::{ - future::Future, - sync::{ - atomic::{AtomicBool, AtomicU8, Ordering}, - Arc, - }, - time::Duration, -}; - -use async_channel::{Receiver, Sender}; -use bc_components::{ - EncapsulationPrivateKey, EncapsulationPublicKey, EncapsulationScheme, SignatureScheme, Signer, - SigningPrivateKey, SigningPublicKey, SymmetricKey, XID, -}; -use dcbor::CBOR; -use tokio::{sync::Semaphore, task::LocalSet}; - -use crate::{ - crypto::{handshake, heartbeat, pair}, - platform::{PlatformFuture, QlPlatform, QlPlatformExt}, - runtime::{ - internal::now_secs, new_runtime, HandlerEvent, KeepAliveConfig, PeerSession, RequestConfig, - RuntimeConfig, RuntimeHandle, - }, - wire::{ - handshake::HandshakeRecord, heartbeat::HeartbeatBody, message::Nack, QlHeader, QlPayload, - QlRecord, - }, - MessageId, QlError, RouteId, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum PeerStage { - Disconnected, - Initiator, - Responder, - Connected, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct StatusEvent { - peer: XID, - stage: PeerStage, -} - -struct TestPlatform { - signing_private: SigningPrivateKey, - signing_public: SigningPublicKey, - encapsulation_private: EncapsulationPrivateKey, - encapsulation_public: EncapsulationPublicKey, - outbound: Sender>, - status: Sender, - nonce_seed: u8, - nonce_counter: AtomicU8, -} - -impl TestPlatform { - fn new(seed: u8) -> (Self, Receiver>, Receiver) { - let (signing_private, signing_public) = SignatureScheme::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = - EncapsulationScheme::default().keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - }, - outbound_rx, - status_rx, - ) - } - - fn signing_public_key(&self) -> &SigningPublicKey { - &self.signing_public - } - - fn encapsulation_public_key(&self) -> &EncapsulationPublicKey { - &self.encapsulation_public - } -} - -impl QlPlatform for TestPlatform { - fn signer(&self) -> &dyn Signer { - &self.signing_private - } - - fn signing_public_key(&self) -> &SigningPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &EncapsulationPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &EncapsulationPublicKey { - &self.encapsulation_public - } - - fn fill_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } - - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - Box::pin(async move { - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, _event: crate::runtime::HandlerEvent) {} -} - -struct BlockingPlatform { - signing_private: SigningPrivateKey, - signing_public: SigningPublicKey, - encapsulation_private: EncapsulationPrivateKey, - encapsulation_public: EncapsulationPublicKey, - outbound: Sender>, - status: Sender, - nonce_seed: u8, - nonce_counter: AtomicU8, - write_gate: Arc, -} - -struct InboundPlatform { - signing_private: SigningPrivateKey, - signing_public: SigningPublicKey, - encapsulation_private: EncapsulationPrivateKey, - encapsulation_public: EncapsulationPublicKey, - outbound: Sender>, - status: Sender, - inbound: Sender, - nonce_seed: u8, - nonce_counter: AtomicU8, -} - -impl InboundPlatform { - fn new( - seed: u8, - ) -> ( - Self, - Receiver>, - Receiver, - Receiver, - ) { - let (signing_private, signing_public) = SignatureScheme::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = - EncapsulationScheme::default().keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - let (inbound, inbound_rx) = async_channel::unbounded(); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - inbound, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - }, - outbound_rx, - status_rx, - inbound_rx, - ) - } -} - -impl QlPlatform for InboundPlatform { - fn signer(&self) -> &dyn Signer { - &self.signing_private - } - - fn signing_public_key(&self) -> &SigningPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &EncapsulationPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &EncapsulationPublicKey { - &self.encapsulation_public - } - - fn fill_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } - - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - Box::pin(async move { - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, event: HandlerEvent) { - let _ = self.inbound.try_send(event); - } -} - -impl BlockingPlatform { - fn new( - seed: u8, - ) -> ( - Self, - Receiver>, - Receiver, - Arc, - ) { - let (signing_private, signing_public) = SignatureScheme::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = - EncapsulationScheme::default().keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - let write_gate = Arc::new(Semaphore::new(0)); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - write_gate: write_gate.clone(), - }, - outbound_rx, - status_rx, - write_gate, - ) - } -} - -impl QlPlatform for BlockingPlatform { - fn signer(&self) -> &dyn Signer { - &self.signing_private - } - - fn signing_public_key(&self) -> &SigningPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &EncapsulationPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &EncapsulationPublicKey { - &self.encapsulation_public - } - - fn fill_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } - - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - let write_gate = self.write_gate.clone(); - Box::pin(async move { - let _permit = write_gate.acquire().await.unwrap(); - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, _event: crate::runtime::HandlerEvent) {} -} - -async fn run_local_test(future: F) -where - F: Future, -{ - let local = LocalSet::new(); - local.run_until(future).await; -} - -fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - let _ = handle.send_incoming(bytes); - } - }); -} - -fn is_heartbeat(bytes: &[u8]) -> bool { - let Ok(record) = CBOR::try_from_data(bytes).and_then(QlRecord::try_from) else { - return false; - }; - matches!(record.payload, QlPayload::Heartbeat(_)) -} - -fn spawn_heartbeat_tap_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - heartbeat_tx: Sender<()>, -) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if is_heartbeat(&bytes) { - let _ = heartbeat_tx.send(()).await; - } - let _ = handle.send_incoming(bytes); - } - }); -} - -fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if is_heartbeat(&bytes) { - continue; - } - let _ = handle.send_incoming(bytes); - } - }); -} - -fn spawn_gated_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - drop_flag: Arc, -) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if drop_flag.load(Ordering::Relaxed) { - continue; - } - let _ = handle.send_incoming(bytes); - } - }); -} - -fn spawn_routed_forwarder(outbound: Receiver>, routes: Vec<(XID, RuntimeHandle)>) { - spawn_routed_forwarder_with_filter(outbound, routes, |_| true); -} - -fn spawn_routed_forwarder_with_filter( - outbound: Receiver>, - routes: Vec<(XID, RuntimeHandle)>, - filter: F, -) where - F: Fn(&QlRecord) -> bool + Send + Sync + 'static, -{ - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) else { - continue; - }; - if !filter(&record) { - continue; - } - if let Some((_, handle)) = routes - .iter() - .find(|(peer, _)| *peer == record.header.recipient) - { - let _ = handle.send_incoming(bytes); - } - } - }); -} - -#[derive(Clone)] -struct PeerIdentity { - xid: XID, - signing_key: SigningPublicKey, - encapsulation_key: EncapsulationPublicKey, -} - -fn peer_identity(platform: &impl QlPlatformExt) -> PeerIdentity { - PeerIdentity { - xid: platform.xid(), - signing_key: platform.signing_public_key().clone(), - encapsulation_key: platform.encapsulation_public_key().clone(), - } -} - -fn register_peers( - handle_a: &RuntimeHandle, - handle_b: &RuntimeHandle, - identity_a: &PeerIdentity, - identity_b: &PeerIdentity, -) -> (XID, XID) { - let peer_a = identity_a.xid; - let peer_b = identity_b.xid; - handle_a.register_peer( - peer_b, - identity_b.signing_key.clone(), - identity_b.encapsulation_key.clone(), - ); - handle_b.register_peer( - peer_a, - identity_a.signing_key.clone(), - identity_a.encapsulation_key.clone(), - ); - (peer_a, peer_b) -} - -async fn await_status( - receiver: &Receiver, - peer: XID, - stage: PeerStage, -) -> StatusEvent { - tokio::time::timeout(Duration::from_secs(1), async { - loop { - if let Ok(event) = receiver.recv().await { - if event.peer == peer && event.stage == stage { - return event; - } - } - } - }) - .await - .unwrap() -} - -#[tokio::test(flavor = "current_thread")] -async fn handshake_initiator_connects() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn handshake_timeout_disconnects() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(50)); - let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - - let peer_b = platform_b.xid(); - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer( - peer_b, - platform_b.signing_public_key().clone(), - platform_b.encapsulation_public_key().clone(), - ); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Disconnected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn simultaneous_handshakes_resolve() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - handle_b.connect(peer_a.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; - await_status(&status_b, peer_a.xid, PeerStage::Responder).await; - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn invalid_signature_disconnects() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, _status_b) = TestPlatform::new(2); - let (wrong_private, wrong_public) = SignatureScheme::MLDSA44.keypair(); - let _ = wrong_private; - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b.xid, wrong_public, peer_b.encapsulation_key.clone()); - handle_b.register_peer( - peer_a.xid, - peer_a.signing_key.clone(), - peer_a.encapsulation_key.clone(), - ); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn pairing_request_triggers_handshake() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let pairing_message = pair::build_pair_request( - &platform_a, - peer_b.xid, - &peer_b.encapsulation_key, - MessageId::new(1), - Duration::from_secs(1), - ) - .unwrap(); - let pairing_bytes = CBOR::from(pairing_message).to_cbor_data(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer( - peer_b.xid, - peer_b.signing_key.clone(), - peer_b.encapsulation_key.clone(), - ); - - handle_b.send_incoming(pairing_bytes); - - await_status(&status_b, peer_a.xid, PeerStage::Initiator).await; - await_status(&status_a, peer_b.xid, PeerStage::Responder).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_response_round_trip() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config.clone()); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let inbound_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let _ = request.respond_to.respond(99u8); - } - }); - - let response = handle_a.send_request_raw( - peer_b.xid, - RouteId::new(7), - CBOR::from(12u8), - RequestConfig::default(), - ); - - let response = response.recv().await.unwrap(); - let value: u8 = response.try_into().unwrap(); - assert_eq!(value, 99u8); - let _ = inbound_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_timeout_returns_error() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(30)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config.clone()); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let ticket = handle_a.send_request_raw( - peer_b.xid, - RouteId::new(1), - CBOR::from(1u8), - RequestConfig { - timeout: Some(Duration::from_millis(30)), - }, - ); - - let result = tokio::time::timeout(Duration::from_millis(200), ticket.recv()) - .await - .unwrap(); - assert!(matches!(result, Err(QlError::Timeout))); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_nack_resolves_pending() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config.clone()); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let inbound_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let _ = request.respond_to.respond_nack(Nack::InvalidPayload); - } - }); - - let response = handle_a.send_request_raw( - peer_b.xid, - RouteId::new(2), - CBOR::from(2u8), - RequestConfig::default(), - ); - - let result = response.recv().await; - assert!(matches!( - result, - Err(QlError::Nack { - nack: Nack::InvalidPayload, - .. - }) - )); - let _ = inbound_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_dispatches_to_platform_callback() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config.clone()); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let inbound_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let _ = request.respond_to.respond(7u8); - } - }); - - let ticket = handle_a.send_request_raw( - peer_b.xid, - RouteId::new(3), - CBOR::from(1u8), - RequestConfig::default(), - ); - - let response = ticket.recv().await.unwrap(); - let value: u8 = response.try_into().unwrap(); - assert_eq!(value, 7u8); - let _ = inbound_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn blocked_write_still_times_out() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(40)); - let (platform_a, _outbound_a, status_a, _write_gate) = BlockingPlatform::new(2); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(1); - - let signing_b = platform_b.signing_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_b = XID::new(&signing_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Initiator).await; - await_status(&status_a, peer_b, PeerStage::Disconnected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn handshake_timeout_drops_queued_messages() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(60)); - let (platform_a, outbound_a, status_a, write_gate) = BlockingPlatform::new(2); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(1); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer( - peer_b.xid, - peer_b.signing_key.clone(), - peer_b.encapsulation_key.clone(), - ); - - handle_a.connect(peer_b.xid).unwrap(); - await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; - - let (hello, _secret) = - handshake::build_hello(&platform_b, peer_b.xid, peer_a.xid, &peer_a.encapsulation_key) - .unwrap(); - let message = QlRecord { - header: QlHeader { - sender: peer_b.xid, - recipient: peer_a.xid, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), - }; - let bytes = CBOR::from(message).to_cbor_data(); - handle_a.send_incoming(bytes); - - await_status(&status_a, peer_b.xid, PeerStage::Responder).await; - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - - write_gate.add_permits(1); - let _ = tokio::time::timeout(Duration::from_millis(100), outbound_a.recv()) - .await - .unwrap() - .unwrap(); - - write_gate.add_permits(1); - let second = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; - assert!( - second.is_err(), - "expected queued handshake reply to be dropped" - ); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_ignored_without_session() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, _status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer( - peer_b, - platform_b.signing_public_key().clone(), - platform_b.encapsulation_public_key().clone(), - ); - - let heartbeat = heartbeat::encrypt_heartbeat( - QlHeader { - sender: peer_b, - recipient: peer_a, - }, - &SymmetricKey::new(), - HeartbeatBody { - message_id: MessageId::new(1), - valid_until: now_secs().saturating_add(60), - }, - ); - let bytes = CBOR::from(heartbeat).to_cbor_data(); - handle_a.send_incoming(bytes); - - let result = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; - assert!(result.is_err(), "expected heartbeat to be ignored"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn keepalive_disabled_no_heartbeat() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = XID::new(&signing_a); - let peer_b = XID::new(&signing_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - let result = tokio::time::timeout(Duration::from_millis(120), heartbeat_rx.recv()).await; - assert!(result.is_err(), "unexpected heartbeat while disabled"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_sent_after_idle() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(30), - timeout: Duration::from_millis(80), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = XID::new(&signing_a); - let peer_b = XID::new(&signing_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_reply_when_connected() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(30), - timeout: Duration::from_millis(80), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = XID::new(&signing_a); - let peer_b = XID::new(&signing_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); - let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); - spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_ab_rx.recv()) - .await - .unwrap() - .unwrap(); - let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn any_message_clears_pending() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(120), - timeout: Duration::from_millis(40), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = XID::new(&signing_a); - let peer_b = XID::new(&signing_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) - .await - .unwrap() - .unwrap(); - - handle_b.send_event_raw(peer_a, RouteId::new(99), CBOR::from(1u8)); - - let window = keep_alive.timeout + Duration::from_millis(20); - let disconnect = tokio::time::timeout(window, async { - loop { - if let Ok(event) = status_a.recv().await { - if event.peer == peer_b && event.stage == PeerStage::Disconnected { - return; - } - } - } - }) - .await; - assert!(disconnect.is_err(), "unexpected disconnect"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_timeout_disconnects_and_drops_outbound() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(80), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(2); - let (platform_b, outbound_b, status_b) = TestPlatform::new(1); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = XID::new(&signing_a); - let peer_b = XID::new(&signing_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let drop_flag = Arc::new(AtomicBool::new(false)); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - drop_flag.store(true, Ordering::Relaxed); - - let response = handle_a.send_request_raw( - peer_b, - RouteId::new(9), - CBOR::from(9u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ); - - await_status(&status_a, peer_b, PeerStage::Disconnected).await; - - let result = tokio::time::timeout(Duration::from_millis(300), response.recv()) - .await - .unwrap(); - assert!( - matches!(result, Err(QlError::SendFailed)), - "unexpected result: {result:?}" - ); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn no_ping_pong() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(200), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = XID::new(&signing_a); - let peer_b = XID::new(&signing_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); - let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); - spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - let _ = tokio::time::timeout(Duration::from_millis(300), heartbeat_ab_rx.recv()) - .await - .unwrap() - .unwrap(); - let _ = tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) - .await - .unwrap() - .unwrap(); - - let followup = - tokio::time::timeout(Duration::from_millis(50), heartbeat_ab_rx.recv()).await; - assert!(followup.is_err(), "unexpected heartbeat ping-pong"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn invalid_heartbeat_ignored() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = XID::new(&signing_a); - let peer_b = XID::new(&signing_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - let heartbeat = heartbeat::encrypt_heartbeat( - QlHeader { - sender: peer_b, - recipient: peer_a, - }, - &SymmetricKey::new(), - HeartbeatBody { - message_id: MessageId::new(42), - valid_until: now_secs().saturating_add(30), - }, - ); - let bytes = CBOR::from(heartbeat).to_cbor_data(); - handle_a.send_incoming(bytes); - - let result = tokio::time::timeout(Duration::from_millis(50), heartbeat_rx.recv()).await; - assert!(result.is_err(), "unexpected heartbeat reply"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn multi_peer_simultaneous_handshakes() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let (platform_c, outbound_c, status_c) = TestPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let peer_c = peer_identity(&platform_c); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - let (runtime_c, handle_c) = - new_runtime(platform_c, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - tokio::task::spawn_local(async move { runtime_c.run().await }); - - spawn_routed_forwarder( - outbound_a, - vec![(peer_b.xid, handle_b.clone()), (peer_c.xid, handle_c.clone())], - ); - spawn_routed_forwarder(outbound_b, vec![(peer_a.xid, handle_a.clone())]); - spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); - - let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); - - handle_a.connect(peer_b.xid).unwrap(); - handle_a.connect(peer_c.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_a, peer_c.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_c, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn multi_peer_keepalive_disconnect_isolated() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(40), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let config_c = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let (platform_c, outbound_c, status_c) = TestPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let peer_c = peer_identity(&platform_c); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - let (runtime_c, handle_c) = new_runtime(platform_c, config_c); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - tokio::task::spawn_local(async move { runtime_c.run().await }); - - let drop_b_to_a = Arc::new(AtomicBool::new(false)); - spawn_routed_forwarder( - outbound_a, - vec![(peer_b.xid, handle_b.clone()), (peer_c.xid, handle_c.clone())], - ); - spawn_routed_forwarder_with_filter(outbound_b, vec![(peer_a.xid, handle_a.clone())], { - let drop_b_to_a = drop_b_to_a.clone(); - move |record| { - !(drop_b_to_a.load(Ordering::Relaxed) && record.header.recipient == peer_a.xid) - } - }); - spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); - - let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); - - handle_a.connect(peer_b.xid).unwrap(); - handle_a.connect(peer_c.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_a, peer_c.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_c, peer_a.xid, PeerStage::Connected).await; - - drop_b_to_a.store(true, Ordering::Relaxed); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - - let disconnect = - tokio::time::timeout(keep_alive.timeout + Duration::from_millis(80), async { - loop { - if let Ok(event) = status_a.recv().await { - if event.peer == peer_c.xid && event.stage == PeerStage::Disconnected { - return; - } - } - } - }) - .await; - assert!(disconnect.is_err(), "unexpected disconnect for peer C"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn multi_peer_disconnect_drops_outbound_for_one() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(40), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let config_c = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let (platform_c, outbound_c, status_c, inbound_c) = InboundPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let peer_c = peer_identity(&platform_c); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - let (runtime_c, handle_c) = new_runtime(platform_c, config_c); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - tokio::task::spawn_local(async move { runtime_c.run().await }); - - let drop_b_to_a = Arc::new(AtomicBool::new(false)); - spawn_routed_forwarder( - outbound_a, - vec![(peer_b.xid, handle_b.clone()), (peer_c.xid, handle_c.clone())], - ); - spawn_routed_forwarder_with_filter(outbound_b, vec![(peer_a.xid, handle_a.clone())], { - let drop_b_to_a = drop_b_to_a.clone(); - move |record| { - !(drop_b_to_a.load(Ordering::Relaxed) && record.header.recipient == peer_a.xid) - } - }); - spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); - - let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); - - handle_a.connect(peer_b.xid).unwrap(); - handle_a.connect(peer_c.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_a, peer_c.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_c, peer_a.xid, PeerStage::Connected).await; - - let inbound_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_c.recv().await { - let _ = request.respond_to.respond(55u8); - } - }); - - drop_b_to_a.store(true, Ordering::Relaxed); - - let request_b = handle_a.send_request_raw( - peer_b.xid, - RouteId::new(10), - CBOR::from(10u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ); - let request_c = handle_a.send_request_raw( - peer_c.xid, - RouteId::new(11), - CBOR::from(11u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ); - - let response_c = tokio::time::timeout(Duration::from_millis(200), request_c.recv()) - .await - .expect("response wait") - .expect("response channel"); - let value: u8 = response_c.try_into().unwrap(); - assert_eq!(value, 55u8); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - - let result_b = tokio::time::timeout(Duration::from_millis(200), request_b.recv()) - .await - .expect("response wait"); - assert!(matches!(result_b, Err(QlError::SendFailed))); - - let _ = inbound_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn multi_peer_activity_is_per_peer() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(100), - timeout: Duration::from_millis(40), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let config_c = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let (platform_c, outbound_c, status_c) = TestPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let peer_c = peer_identity(&platform_c); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - let (runtime_c, handle_c) = new_runtime(platform_c, config_c); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - tokio::task::spawn_local(async move { runtime_c.run().await }); - - let drop_all_c = Arc::new(AtomicBool::new(false)); - spawn_routed_forwarder( - outbound_a, - vec![ - (peer_b.xid, handle_b.clone()), - (peer_c.xid, handle_c.clone()), - ], - ); - spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); - spawn_gated_forwarder(outbound_c, handle_a.clone(), drop_all_c.clone()); - - let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); - - handle_a.connect(peer_b.xid).unwrap(); - handle_a.connect(peer_c.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_a, peer_c.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_c, peer_a.xid, PeerStage::Connected).await; - - drop_all_c.store(true, Ordering::Relaxed); - - tokio::time::sleep(keep_alive.interval + Duration::from_millis(5)).await; - - handle_b.send_event_raw(peer_a.xid, RouteId::new(99), CBOR::from(1u8)); - - await_status(&status_a, peer_c.xid, PeerStage::Disconnected).await; - - let disconnect = - tokio::time::timeout(keep_alive.timeout + Duration::from_millis(30), async { - loop { - if let Ok(event) = status_a.recv().await { - if event.peer == peer_b.xid && event.stage == PeerStage::Disconnected { - return; - } - } - } - }) - .await; - assert!(disconnect.is_err(), "unexpected disconnect for peer B"); - }) - .await; -} diff --git a/ql/src/tests/handshake.rs b/ql/src/tests/handshake.rs new file mode 100644 index 00000000..34580edb --- /dev/null +++ b/ql/src/tests/handshake.rs @@ -0,0 +1,292 @@ +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn handshake_initiator_connects() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_timeout_disconnects() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(50)); + let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + + let peer_b = platform_b.xid(); + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer( + peer_b, + platform_b.signing_public_key().clone(), + platform_b.encapsulation_public_key().clone(), + ); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn simultaneous_handshakes_resolve() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + handle_b.connect(peer_a.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; + await_status(&status_b, peer_a.xid, PeerStage::Responder).await; + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn invalid_signature_disconnects() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, _status_b) = TestPlatform::new(2); + let (wrong_private, wrong_public) = MLDSA::MLDSA44.keypair(); + let _ = wrong_private; + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b.xid, wrong_public, peer_b.encapsulation_key.clone()); + handle_b.register_peer( + peer_a.xid, + peer_a.signing_key.clone(), + peer_a.encapsulation_key.clone(), + ); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn pairing_request_triggers_handshake() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let pairing_message = pair::build_pair_request( + &platform_a, + peer_b.xid, + &peer_b.encapsulation_key, + MessageId(1), + Duration::from_secs(1), + ) + .unwrap(); + let pairing_bytes = CBOR::from(pairing_message).to_cbor_data(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer( + peer_b.xid, + peer_b.signing_key.clone(), + peer_b.encapsulation_key.clone(), + ); + + handle_b.send_incoming(pairing_bytes); + + await_status(&status_b, peer_a.xid, PeerStage::Initiator).await; + await_status(&status_a, peer_b.xid, PeerStage::Responder).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn blocked_write_still_times_out() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(40)); + let (platform_a, _outbound_a, status_a, _write_gate) = BlockingPlatform::new(2); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(1); + + let signing_b = platform_b.signing_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Initiator).await; + await_status(&status_a, peer_b, PeerStage::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_timeout_drops_queued_messages() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(60)); + let (platform_a, outbound_a, status_a, write_gate) = BlockingPlatform::new(2); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(1); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer( + peer_b.xid, + peer_b.signing_key.clone(), + peer_b.encapsulation_key.clone(), + ); + + handle_a.connect(peer_b.xid).unwrap(); + await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; + + let (hello, _secret) = wire::handshake::build_hello( + &platform_b, + peer_b.xid, + peer_a.xid, + &peer_a.encapsulation_key, + ) + .unwrap(); + let message = QlRecord { + header: QlHeader { + sender: peer_b.xid, + recipient: peer_a.xid, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), + }; + let bytes = CBOR::from(message).to_cbor_data(); + handle_a.send_incoming(bytes); + + await_status(&status_a, peer_b.xid, PeerStage::Responder).await; + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + + write_gate.add_permits(1); + let _ = tokio::time::timeout(Duration::from_millis(100), outbound_a.recv()) + .await + .unwrap() + .unwrap(); + + write_gate.add_permits(1); + let second = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; + assert!( + second.is_err(), + "expected queued handshake reply to be dropped" + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_peer_simultaneous_handshakes() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let (platform_c, outbound_c, status_c) = TestPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let peer_c = peer_identity(&platform_c); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + let (runtime_c, handle_c) = + new_runtime(platform_c, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + tokio::task::spawn_local(async move { runtime_c.run().await }); + + spawn_routed_forwarder( + outbound_a, + vec![ + (peer_b.xid, handle_b.clone()), + (peer_c.xid, handle_c.clone()), + ], + ); + spawn_routed_forwarder(outbound_b, vec![(peer_a.xid, handle_a.clone())]); + spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); + + let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); + + handle_a.connect(peer_b.xid).unwrap(); + handle_a.connect(peer_c.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_a, peer_c.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_c, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} diff --git a/ql/src/tests/heartbeat.rs b/ql/src/tests/heartbeat.rs new file mode 100644 index 00000000..cc73b271 --- /dev/null +++ b/ql/src/tests/heartbeat.rs @@ -0,0 +1,641 @@ +use bc_components::SymmetricKey; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_ignored_without_session() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, _status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer( + peer_b, + platform_b.signing_public_key().clone(), + platform_b.encapsulation_public_key().clone(), + ); + + let heartbeat = wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: peer_b, + recipient: peer_a, + }, + &SymmetricKey::new(), + HeartbeatBody { + message_id: MessageId(1), + valid_until: now_secs().saturating_add(60), + }, + ); + let bytes = CBOR::from(heartbeat).to_cbor_data(); + handle_a.send_incoming(bytes); + + let result = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; + assert!(result.is_err(), "expected heartbeat to be ignored"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn keepalive_disabled_no_heartbeat() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + let result = tokio::time::timeout(Duration::from_millis(120), heartbeat_rx.recv()).await; + assert!(result.is_err(), "unexpected heartbeat while disabled"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_sent_after_idle() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(30), + timeout: Duration::from_millis(80), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_reply_when_connected() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(30), + timeout: Duration::from_millis(80), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); + let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); + spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(200), heartbeat_ab_rx.recv()) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn any_message_clears_pending() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(120), + timeout: Duration::from_millis(40), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) + .await + .unwrap() + .unwrap(); + + handle_b.send_event_raw(peer_a, RouteId(99), CBOR::from(1u8)); + + let window = keep_alive.timeout + Duration::from_millis(20); + let disconnect = tokio::time::timeout(window, async { + loop { + if let Ok(event) = status_a.recv().await { + if event.peer == peer_b && event.stage == PeerStage::Disconnected { + return; + } + } + } + }) + .await; + assert!(disconnect.is_err(), "unexpected disconnect"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_timeout_disconnects_and_drops_outbound() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(80), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(2); + let (platform_b, outbound_b, status_b) = TestPlatform::new(1); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let drop_flag = Arc::new(AtomicBool::new(false)); + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + drop_flag.store(true, Ordering::Relaxed); + + let response = handle_a.send_request_raw( + peer_b, + RouteId(9), + CBOR::from(9u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ); + + await_status(&status_a, peer_b, PeerStage::Disconnected).await; + + let result = tokio::time::timeout(Duration::from_millis(300), response.recv()) + .await + .unwrap(); + assert!( + matches!(result, Err(QlError::SendFailed)), + "unexpected result: {result:?}" + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn no_ping_pong() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(200), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); + let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); + spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(300), heartbeat_ab_rx.recv()) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) + .await + .unwrap() + .unwrap(); + + let followup = + tokio::time::timeout(Duration::from_millis(50), heartbeat_ab_rx.recv()).await; + assert!(followup.is_err(), "unexpected heartbeat ping-pong"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn invalid_heartbeat_ignored() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + + let signing_a = platform_a.signing_public_key().clone(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_a = platform_a.encapsulation_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); + + handle_a.connect(peer_b).unwrap(); + + await_status(&status_a, peer_b, PeerStage::Connected).await; + await_status(&status_b, peer_a, PeerStage::Connected).await; + + let heartbeat = wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: peer_b, + recipient: peer_a, + }, + &SymmetricKey::new(), + HeartbeatBody { + message_id: MessageId(42), + valid_until: now_secs().saturating_add(30), + }, + ); + let bytes = CBOR::from(heartbeat).to_cbor_data(); + handle_a.send_incoming(bytes); + + let result = tokio::time::timeout(Duration::from_millis(50), heartbeat_rx.recv()).await; + assert!(result.is_err(), "unexpected heartbeat reply"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_peer_keepalive_disconnect_isolated() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(40), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let config_c = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let (platform_c, outbound_c, status_c) = TestPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let peer_c = peer_identity(&platform_c); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + let (runtime_c, handle_c) = new_runtime(platform_c, config_c); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + tokio::task::spawn_local(async move { runtime_c.run().await }); + + let drop_b_to_a = Arc::new(AtomicBool::new(false)); + spawn_routed_forwarder( + outbound_a, + vec![ + (peer_b.xid, handle_b.clone()), + (peer_c.xid, handle_c.clone()), + ], + ); + spawn_routed_forwarder_with_filter(outbound_b, vec![(peer_a.xid, handle_a.clone())], { + let drop_b_to_a = drop_b_to_a.clone(); + move |record| { + !(drop_b_to_a.load(Ordering::Relaxed) && record.header.recipient == peer_a.xid) + } + }); + spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); + + let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); + + handle_a.connect(peer_b.xid).unwrap(); + handle_a.connect(peer_c.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_a, peer_c.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_c, peer_a.xid, PeerStage::Connected).await; + + drop_b_to_a.store(true, Ordering::Relaxed); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + + let disconnect = + tokio::time::timeout(keep_alive.timeout + Duration::from_millis(80), async { + loop { + if let Ok(event) = status_a.recv().await { + if event.peer == peer_c.xid && event.stage == PeerStage::Disconnected { + return; + } + } + } + }) + .await; + assert!(disconnect.is_err(), "unexpected disconnect for peer C"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_peer_disconnect_drops_outbound_for_one() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(40), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let config_c = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let (platform_c, outbound_c, status_c, inbound_c) = InboundPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let peer_c = peer_identity(&platform_c); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + let (runtime_c, handle_c) = new_runtime(platform_c, config_c); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + tokio::task::spawn_local(async move { runtime_c.run().await }); + + let drop_b_to_a = Arc::new(AtomicBool::new(false)); + spawn_routed_forwarder( + outbound_a, + vec![ + (peer_b.xid, handle_b.clone()), + (peer_c.xid, handle_c.clone()), + ], + ); + spawn_routed_forwarder_with_filter(outbound_b, vec![(peer_a.xid, handle_a.clone())], { + let drop_b_to_a = drop_b_to_a.clone(); + move |record| { + !(drop_b_to_a.load(Ordering::Relaxed) && record.header.recipient == peer_a.xid) + } + }); + spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); + + let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); + + handle_a.connect(peer_b.xid).unwrap(); + handle_a.connect(peer_c.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_a, peer_c.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_c, peer_a.xid, PeerStage::Connected).await; + + let inbound_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_c.recv().await { + let _ = request.respond_to.respond(55u8); + } + }); + + drop_b_to_a.store(true, Ordering::Relaxed); + + let request_b = handle_a.send_request_raw( + peer_b.xid, + RouteId(10), + CBOR::from(10u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ); + let request_c = handle_a.send_request_raw( + peer_c.xid, + RouteId(11), + CBOR::from(11u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ); + + let response_c = tokio::time::timeout(Duration::from_millis(200), request_c.recv()) + .await + .expect("response wait") + .expect("response channel"); + let value: u8 = response_c.try_into().unwrap(); + assert_eq!(value, 55u8); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + + let result_b = tokio::time::timeout(Duration::from_millis(200), request_b.recv()) + .await + .expect("response wait"); + assert!(matches!(result_b, Err(QlError::SendFailed))); + + let _ = inbound_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_peer_activity_is_per_peer() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(100), + timeout: Duration::from_millis(40), + }; + let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); + let config_b = RuntimeConfig::new(Duration::from_millis(200)); + let config_c = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let (platform_c, outbound_c, status_c) = TestPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let peer_c = peer_identity(&platform_c); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + let (runtime_c, handle_c) = new_runtime(platform_c, config_c); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + tokio::task::spawn_local(async move { runtime_c.run().await }); + + let drop_all_c = Arc::new(AtomicBool::new(false)); + spawn_routed_forwarder( + outbound_a, + vec![ + (peer_b.xid, handle_b.clone()), + (peer_c.xid, handle_c.clone()), + ], + ); + spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); + spawn_gated_forwarder(outbound_c, handle_a.clone(), drop_all_c.clone()); + + let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); + + handle_a.connect(peer_b.xid).unwrap(); + handle_a.connect(peer_c.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_a, peer_c.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + await_status(&status_c, peer_a.xid, PeerStage::Connected).await; + + drop_all_c.store(true, Ordering::Relaxed); + + tokio::time::sleep(keep_alive.interval + Duration::from_millis(5)).await; + + handle_b.send_event_raw(peer_a.xid, RouteId(99), CBOR::from(1u8)); + + await_status(&status_a, peer_c.xid, PeerStage::Disconnected).await; + + let disconnect = + tokio::time::timeout(keep_alive.timeout + Duration::from_millis(30), async { + loop { + if let Ok(event) = status_a.recv().await { + if event.peer == peer_b.xid && event.stage == PeerStage::Disconnected { + return; + } + } + } + }) + .await; + assert!(disconnect.is_err(), "unexpected disconnect for peer B"); + }) + .await; +} diff --git a/ql/src/tests/mod.rs b/ql/src/tests/mod.rs new file mode 100644 index 00000000..da48c1f5 --- /dev/null +++ b/ql/src/tests/mod.rs @@ -0,0 +1,625 @@ +use std::{ + future::Future, + sync::{ + atomic::{AtomicBool, AtomicU8, Ordering}, + Arc, + }, + time::Duration, +}; + +use async_channel::{Receiver, Sender}; +use bc_components::{ + MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, MLDSA, MLKEM, XID, +}; +use dcbor::CBOR; +use tokio::{sync::Semaphore, task::LocalSet}; + +use crate::{ + platform::{PlatformFuture, QlPlatform, QlPlatformExt}, + runtime::{ + internal::now_secs, new_runtime, HandlerEvent, KeepAliveConfig, PeerSession, RequestConfig, + RuntimeConfig, RuntimeHandle, + }, + wire::{ + self, + handshake::HandshakeRecord, + heartbeat::HeartbeatBody, + message::{encrypt_message, MessageBody, MessageKind, Nack}, + pair, QlHeader, QlPayload, QlRecord, + }, + MessageId, QlError, RouteId, +}; + +mod handshake; +mod heartbeat; +mod persistence; +mod requests; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PeerStage { + Disconnected, + Initiator, + Responder, + Connected, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct StatusEvent { + peer: XID, + stage: PeerStage, +} + +struct TestPlatform { + signing_private: MLDSAPrivateKey, + signing_public: MLDSAPublicKey, + encapsulation_private: MLKEMPrivateKey, + encapsulation_public: MLKEMPublicKey, + outbound: Sender>, + status: Sender, + nonce_seed: u8, + nonce_counter: AtomicU8, +} + +impl TestPlatform { + fn new(seed: u8) -> (Self, Receiver>, Receiver) { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + }, + outbound_rx, + status_rx, + ) + } + + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } +} + +impl QlPlatform for TestPlatform { + fn signing_private_key(&self) -> &MLDSAPrivateKey { + &self.signing_private + } + + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } + + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } + + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + Box::pin(async move { + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn load_peers(&self) -> PlatformFuture<'_, Vec> { + Box::pin(async { Vec::new() }) + } + + fn persist_peers(&self, _peers: Vec) {} + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, _event: crate::runtime::HandlerEvent) {} +} + +struct BlockingPlatform { + signing_private: MLDSAPrivateKey, + signing_public: MLDSAPublicKey, + encapsulation_private: MLKEMPrivateKey, + encapsulation_public: MLKEMPublicKey, + outbound: Sender>, + status: Sender, + nonce_seed: u8, + nonce_counter: AtomicU8, + write_gate: Arc, +} + +struct InboundPlatform { + signing_private: MLDSAPrivateKey, + signing_public: MLDSAPublicKey, + encapsulation_private: MLKEMPrivateKey, + encapsulation_public: MLKEMPublicKey, + outbound: Sender>, + status: Sender, + inbound: Sender, + nonce_seed: u8, + nonce_counter: AtomicU8, +} + +impl InboundPlatform { + fn new( + seed: u8, + ) -> ( + Self, + Receiver>, + Receiver, + Receiver, + ) { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + let (inbound, inbound_rx) = async_channel::unbounded(); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + inbound, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + }, + outbound_rx, + status_rx, + inbound_rx, + ) + } +} + +impl QlPlatform for InboundPlatform { + fn signing_private_key(&self) -> &MLDSAPrivateKey { + &self.signing_private + } + + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } + + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } + + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + Box::pin(async move { + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn load_peers(&self) -> PlatformFuture<'_, Vec> { + Box::pin(async { Vec::new() }) + } + + fn persist_peers(&self, _peers: Vec) {} + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, event: HandlerEvent) { + let _ = self.inbound.try_send(event); + } +} + +impl BlockingPlatform { + fn new( + seed: u8, + ) -> ( + Self, + Receiver>, + Receiver, + Arc, + ) { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + let write_gate = Arc::new(Semaphore::new(0)); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + write_gate: write_gate.clone(), + }, + outbound_rx, + status_rx, + write_gate, + ) + } +} + +impl QlPlatform for BlockingPlatform { + fn signing_private_key(&self) -> &MLDSAPrivateKey { + &self.signing_private + } + + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } + + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } + + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + let write_gate = self.write_gate.clone(); + Box::pin(async move { + let _permit = write_gate.acquire().await.unwrap(); + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn load_peers(&self) -> PlatformFuture<'_, Vec> { + Box::pin(async { Vec::new() }) + } + + fn persist_peers(&self, _peers: Vec) {} + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, _event: crate::runtime::HandlerEvent) {} +} + +async fn run_local_test(future: F) +where + F: Future, +{ + let local = LocalSet::new(); + local.run_until(future).await; +} + +fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + handle.send_incoming(bytes); + } + }); +} + +fn is_heartbeat(bytes: &[u8]) -> bool { + let Ok(record) = CBOR::try_from_data(bytes).and_then(QlRecord::try_from) else { + return false; + }; + matches!(record.payload, QlPayload::Heartbeat(_)) +} + +fn spawn_heartbeat_tap_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + heartbeat_tx: Sender<()>, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if is_heartbeat(&bytes) { + let _ = heartbeat_tx.send(()).await; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if is_heartbeat(&bytes) { + continue; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_gated_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + drop_flag: Arc, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if drop_flag.load(Ordering::Relaxed) { + continue; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_routed_forwarder(outbound: Receiver>, routes: Vec<(XID, RuntimeHandle)>) { + spawn_routed_forwarder_with_filter(outbound, routes, |_| true); +} + +fn spawn_routed_forwarder_with_filter( + outbound: Receiver>, + routes: Vec<(XID, RuntimeHandle)>, + filter: F, +) where + F: Fn(&QlRecord) -> bool + Send + Sync + 'static, +{ + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) else { + continue; + }; + if !filter(&record) { + continue; + } + if let Some((_, handle)) = routes + .iter() + .find(|(peer, _)| *peer == record.header.recipient) + { + handle.send_incoming(bytes); + } + } + }); +} + +#[derive(Clone)] +struct PeerIdentity { + xid: XID, + signing_key: MLDSAPublicKey, + encapsulation_key: MLKEMPublicKey, +} + +fn peer_identity(platform: &impl QlPlatformExt) -> PeerIdentity { + PeerIdentity { + xid: platform.xid(), + signing_key: platform.signing_public_key().clone(), + encapsulation_key: platform.encapsulation_public_key().clone(), + } +} + +fn register_peers( + handle_a: &RuntimeHandle, + handle_b: &RuntimeHandle, + identity_a: &PeerIdentity, + identity_b: &PeerIdentity, +) -> (XID, XID) { + let peer_a = identity_a.xid; + let peer_b = identity_b.xid; + handle_a.register_peer( + peer_b, + identity_b.signing_key.clone(), + identity_b.encapsulation_key.clone(), + ); + handle_b.register_peer( + peer_a, + identity_a.signing_key.clone(), + identity_a.encapsulation_key.clone(), + ); + (peer_a, peer_b) +} + +async fn await_status( + receiver: &Receiver, + peer: XID, + stage: PeerStage, +) -> StatusEvent { + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if let Ok(event) = receiver.recv().await { + if event.peer == peer && event.stage == stage { + return event; + } + } + } + }) + .await + .unwrap() +} + +#[test] +fn protocol_record_size_breakdown() { + let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + + let initiator = platform_a.xid(); + let responder = platform_b.xid(); + + let (hello, initiator_secret) = wire::handshake::build_hello( + &platform_a, + initiator, + responder, + platform_b.encapsulation_public_key(), + ) + .unwrap(); + let hello_record = QlRecord { + header: QlHeader { + sender: initiator, + recipient: responder, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(hello.clone())), + }; + let hello_size = CBOR::from(hello_record).to_cbor_data().len(); + + let (hello_reply, responder_secrets) = wire::handshake::respond_hello( + &platform_b, + initiator, + responder, + platform_a.encapsulation_public_key(), + &hello, + ) + .unwrap(); + let reply_record = QlRecord { + header: QlHeader { + sender: responder, + recipient: initiator, + }, + payload: QlPayload::Handshake(HandshakeRecord::HelloReply(hello_reply.clone())), + }; + let reply_size = CBOR::from(reply_record).to_cbor_data().len(); + + let (confirm, session_key) = wire::handshake::build_confirm( + &platform_a, + initiator, + responder, + platform_b.signing_public_key(), + &hello, + &hello_reply, + &initiator_secret, + ) + .unwrap(); + let _session_key_b = wire::handshake::finalize_confirm( + initiator, + responder, + platform_a.signing_public_key(), + &hello, + &hello_reply, + &confirm, + &responder_secrets, + ) + .unwrap(); + let confirm_record = QlRecord { + header: QlHeader { + sender: initiator, + recipient: responder, + }, + payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm)), + }; + let confirm_size = CBOR::from(confirm_record).to_cbor_data().len(); + + let pair_record = pair::build_pair_request( + &platform_a, + responder, + platform_b.encapsulation_public_key(), + MessageId(1), + Duration::from_secs(60), + ) + .unwrap(); + let pair_size = CBOR::from(pair_record).to_cbor_data().len(); + + let message_record = encrypt_message( + QlHeader { + sender: initiator, + recipient: responder, + }, + &session_key, + MessageBody { + message_id: MessageId(2), + valid_until: now_secs().saturating_add(60), + kind: MessageKind::Event, + route_id: RouteId(1), + payload: CBOR::null(), + }, + ); + let message_size = CBOR::from(message_record).to_cbor_data().len(); + + let heartbeat_record = wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: initiator, + recipient: responder, + }, + &session_key, + HeartbeatBody { + message_id: MessageId(3), + valid_until: now_secs().saturating_add(60), + }, + ); + let heartbeat_size = CBOR::from(heartbeat_record).to_cbor_data().len(); + + let print_size = |label: &str, size: usize| { + println!("{label:<21}: {size} bytes"); + }; + + print_size("ql size hello", hello_size); + print_size("ql size hello_reply", reply_size); + print_size("ql size confirm", confirm_size); + print_size("ql size pair_request", pair_size); + print_size("ql size message", message_size); + print_size("ql size heartbeat", heartbeat_size); +} diff --git a/ql/src/tests/persistence.rs b/ql/src/tests/persistence.rs new file mode 100644 index 00000000..98a4411c --- /dev/null +++ b/ql/src/tests/persistence.rs @@ -0,0 +1,228 @@ +use std::sync::atomic::{AtomicU8, Ordering}; + +use async_channel::{Receiver, Sender}; +use bc_components::{ + MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, MLDSA, MLKEM, XID, +}; + +use super::*; + +type PersistPlatformParts = ( + PersistPlatform, + Receiver>, + Receiver, + Receiver>, +); + +struct PersistPlatform { + signing_private: MLDSAPrivateKey, + signing_public: MLDSAPublicKey, + encapsulation_private: MLKEMPrivateKey, + encapsulation_public: MLKEMPublicKey, + outbound: Sender>, + status: Sender, + persisted: Sender>, + loaded_peers: Vec, + nonce_seed: u8, + nonce_counter: AtomicU8, +} + +impl PersistPlatform { + fn new(seed: u8, loaded_peers: Vec) -> PersistPlatformParts { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + let (persisted, persisted_rx) = async_channel::unbounded(); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + persisted, + loaded_peers, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + }, + outbound_rx, + status_rx, + persisted_rx, + ) + } +} + +impl QlPlatform for PersistPlatform { + fn signing_private_key(&self) -> &MLDSAPrivateKey { + &self.signing_private + } + + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } + + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } + + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + Box::pin(async move { + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn load_peers(&self) -> PlatformFuture<'_, Vec> { + let peers = self.loaded_peers.clone(); + Box::pin(async move { peers }) + } + + fn persist_peers(&self, peers: Vec) { + let _ = self.persisted.try_send(peers); + } + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, _event: HandlerEvent) {} +} + +#[tokio::test(flavor = "current_thread")] +async fn register_peer_persists_snapshot() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, _outbound_a, _status_a, persisted_a) = PersistPlatform::new(1, Vec::new()); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + let peer_b = platform_b.xid(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); + + let snapshot = tokio::time::timeout(Duration::from_secs(1), persisted_a.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!( + snapshot, + vec![crate::Peer { + peer: peer_b, + signing_key: signing_b, + encapsulation_key: encap_b, + }] + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn loaded_peers_can_connect_without_register() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_b = peer_identity(&platform_b); + + let (platform_a, outbound_a, status_a, _persisted_a) = PersistPlatform::new( + 1, + vec![crate::Peer { + peer: peer_b.xid, + signing_key: peer_b.signing_key.clone(), + encapsulation_key: peer_b.encapsulation_key.clone(), + }], + ); + let peer_a = peer_identity(&platform_a); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_b.register_peer( + peer_a.xid, + peer_a.signing_key.clone(), + peer_a.encapsulation_key.clone(), + ); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn pairing_persists_snapshot() { + run_local_test(async { + let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); + let peer_a = peer_identity(&platform_a); + + let (platform_b, _outbound_b, _status_b, persisted_b) = PersistPlatform::new(2, Vec::new()); + let peer_b = peer_identity(&platform_b); + + let pairing_message = pair::build_pair_request( + &platform_a, + peer_b.xid, + &peer_b.encapsulation_key, + MessageId(1), + Duration::from_secs(1), + ) + .unwrap(); + let pairing_bytes = CBOR::from(pairing_message).to_cbor_data(); + + let (runtime_b, handle_b) = + new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + handle_b.send_incoming(pairing_bytes); + + let snapshot = tokio::time::timeout(Duration::from_secs(1), persisted_b.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!( + snapshot, + vec![crate::Peer { + peer: peer_a.xid, + signing_key: peer_a.signing_key, + encapsulation_key: peer_a.encapsulation_key, + }] + ); + }) + .await; +} diff --git a/ql/src/tests/requests.rs b/ql/src/tests/requests.rs new file mode 100644 index 00000000..728d3127 --- /dev/null +++ b/ql/src/tests/requests.rs @@ -0,0 +1,445 @@ +use super::*; + +fn spawn_delayed_message_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + delay: Duration, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + let is_message = CBOR::try_from_data(&bytes) + .and_then(QlRecord::try_from) + .is_ok_and(|record| matches!(record.payload, QlPayload::Message(_))); + if is_message { + tokio::time::sleep(delay).await; + } + handle.send_incoming(bytes); + } + }); +} + +#[tokio::test(flavor = "current_thread")] +async fn request_response_round_trip() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let inbound_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let _ = request.respond_to.respond(99u8); + } + }); + + let response = handle_a.send_request_raw( + peer_b.xid, + RouteId(7), + CBOR::from(12u8), + RequestConfig::default(), + ); + + let response = response.recv().await.unwrap(); + let value: u8 = response.try_into().unwrap(); + assert_eq!(value, 99u8); + let _ = inbound_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_timeout_returns_error() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(30)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let ticket = handle_a.send_request_raw( + peer_b.xid, + RouteId(1), + CBOR::from(1u8), + RequestConfig { + timeout: Some(Duration::from_millis(30)), + }, + ); + + let result = tokio::time::timeout(Duration::from_millis(200), ticket.recv()) + .await + .unwrap(); + assert!(matches!(result, Err(QlError::Timeout))); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_nack_resolves_pending() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let inbound_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let _ = request.respond_to.respond_nack(Nack::InvalidPayload); + } + }); + + let response = handle_a.send_request_raw( + peer_b.xid, + RouteId(2), + CBOR::from(2u8), + RequestConfig::default(), + ); + + let result = response.recv().await; + assert!(matches!( + result, + Err(QlError::Nack { + nack: Nack::InvalidPayload, + .. + }) + )); + let _ = inbound_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_dispatches_to_platform_callback() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let inbound_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let _ = request.respond_to.respond(7u8); + } + }); + + let ticket = handle_a.send_request_raw( + peer_b.xid, + RouteId(3), + CBOR::from(1u8), + RequestConfig::default(), + ); + + let response = ticket.recv().await.unwrap(); + let value: u8 = response.try_into().unwrap(); + assert_eq!(value, 7u8); + let _ = inbound_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn replayed_message_is_ignored() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + tokio::task::spawn_local({ + let handle_b = handle_b.clone(); + async move { + while let Ok(bytes) = outbound_a.recv().await { + let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) + else { + handle_b.send_incoming(bytes); + continue; + }; + if matches!(record.payload, QlPayload::Message(_)) { + handle_b.send_incoming(bytes.clone()); + handle_b.send_incoming(bytes); + continue; + } + handle_b.send_incoming(bytes); + } + } + }); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + handle_a.send_event_raw(peer_b.xid, RouteId(9), CBOR::from(1u8)); + + let first = tokio::time::timeout(Duration::from_secs(1), inbound_b.recv()) + .await + .unwrap() + .unwrap(); + match first { + HandlerEvent::Event(event) => { + assert_eq!(event.message.route_id, RouteId(9)); + } + HandlerEvent::Request(_) => panic!("unexpected request"), + } + + let second = tokio::time::timeout(Duration::from_millis(50), inbound_b.recv()).await; + assert!(second.is_err(), "replay delivered a second event"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn expired_request_returns_expired_nack() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_message_expiration(Duration::from_secs(1)) + .with_request_timeout(Duration::from_secs(3)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_delayed_message_forwarder(outbound_a, handle_b.clone(), Duration::from_millis(2000)); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let ticket = handle_a.send_request_raw( + peer_b.xid, + RouteId(4), + CBOR::from(1u8), + RequestConfig::default(), + ); + + let result = tokio::time::timeout(Duration::from_secs(5), ticket.recv()) + .await + .unwrap(); + assert!(matches!( + result, + Err(QlError::Nack { + nack: Nack::Expired, + .. + }) + )); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn expired_event_does_not_send_nack() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_message_expiration(Duration::from_secs(1)) + .with_request_timeout(Duration::from_secs(3)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_delayed_message_forwarder(outbound_a, handle_b.clone(), Duration::from_millis(1500)); + + let (message_tx, message_rx) = async_channel::unbounded(); + tokio::task::spawn_local({ + let handle_a = handle_a.clone(); + async move { + while let Ok(bytes) = outbound_b.recv().await { + let is_message = CBOR::try_from_data(&bytes) + .and_then(QlRecord::try_from) + .is_ok_and(|record| matches!(record.payload, QlPayload::Message(_))); + if is_message { + let _ = message_tx.send(()).await; + } + handle_a.send_incoming(bytes); + } + } + }); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + handle_a.send_event_raw(peer_b.xid, RouteId(10), CBOR::from(2u8)); + + let unexpected = tokio::time::timeout(Duration::from_secs(3), message_rx.recv()).await; + assert!( + unexpected.is_err(), + "expired event should not generate nack" + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn session_reset_fails_queued_request() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(60)); + let (platform_a, outbound_a, status_a, write_gate) = BlockingPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let (reset_hello, _secret) = wire::handshake::build_hello( + &platform_b, + peer_b.xid, + peer_a.xid, + &peer_a.encapsulation_key, + ) + .unwrap(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + write_gate.add_permits(2); + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let blocked = handle_a.send_request_raw( + peer_b.xid, + RouteId(12), + CBOR::from(12u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ); + let queued = handle_a.send_request_raw( + peer_b.xid, + RouteId(13), + CBOR::from(13u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ); + + let hello_message = QlRecord { + header: QlHeader { + sender: peer_b.xid, + recipient: peer_a.xid, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(reset_hello)), + }; + handle_a.send_incoming(CBOR::from(hello_message).to_cbor_data()); + + await_status(&status_a, peer_b.xid, PeerStage::Responder).await; + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + + let queued_result = tokio::time::timeout(Duration::from_millis(300), queued.recv()) + .await + .unwrap(); + assert!(matches!(queued_result, Err(QlError::Timeout))); + + let blocked_result = tokio::time::timeout(Duration::from_millis(300), blocked.recv()) + .await + .unwrap(); + assert!(matches!(blocked_result, Err(QlError::Timeout))); + }) + .await; +} diff --git a/ql/src/crypto/handshake.rs b/ql/src/wire/handshake/crypto.rs similarity index 80% rename from ql/src/crypto/handshake.rs rename to ql/src/wire/handshake/crypto.rs index 74842230..e7ea43a5 100644 --- a/ql/src/crypto/handshake.rs +++ b/ql/src/wire/handshake/crypto.rs @@ -1,14 +1,10 @@ use bc_components::{ - Digest, EncapsulationCiphertext, EncapsulationPublicKey, Nonce, SigningPublicKey, SymmetricKey, - XID, + Digest, MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, Nonce, SymmetricKey, XID, }; use dcbor::CBOR; -use crate::{ - platform::QlPlatform, - wire::handshake::{verify_transcript_signature, Confirm, Hello, HelloReply}, - QlError, -}; +use super::{verify_transcript_signature, Confirm, Hello, HelloReply}; +use crate::{platform::QlPlatform, QlError}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct ResponderSecrets { @@ -20,7 +16,7 @@ pub fn build_hello( platform: &impl QlPlatform, _sender: XID, _recipient: XID, - recipient_encapsulation_key: &EncapsulationPublicKey, + recipient_encapsulation_key: &MLKEMPublicKey, ) -> Result<(Hello, SymmetricKey), QlError> { let nonce = next_nonce(platform); let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); @@ -31,7 +27,7 @@ pub fn respond_hello( platform: &impl QlPlatform, initiator: XID, responder: XID, - initiator_encapsulation_key: &EncapsulationPublicKey, + initiator_encapsulation_key: &MLKEMPublicKey, hello: &Hello, ) -> Result<(HelloReply, ResponderSecrets), QlError> { let initiator_secret = platform @@ -41,10 +37,7 @@ pub fn respond_hello( let nonce = next_nonce(platform); let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(); let transcript = handshake_transcript(initiator, responder, hello, &nonce, &kem_ct); - let signature = platform - .signer() - .sign(&transcript) - .map_err(|_| QlError::InvalidPayload)?; + let signature = platform.signing_private_key().sign(&transcript); let reply = HelloReply { nonce, kem_ct, @@ -63,7 +56,7 @@ pub fn build_confirm( platform: &impl QlPlatform, initiator: XID, responder: XID, - responder_signing_key: &SigningPublicKey, + responder_signing_key: &MLDSAPublicKey, hello: &Hello, reply: &HelloReply, initiator_secret: &SymmetricKey, @@ -74,10 +67,7 @@ pub fn build_confirm( .encapsulation_private_key() .decapsulate_shared_secret(&reply.kem_ct) .map_err(|_| QlError::InvalidPayload)?; - let signature = platform - .signer() - .sign(&transcript) - .map_err(|_| QlError::InvalidPayload)?; + let signature = platform.signing_private_key().sign(&transcript); let confirm = Confirm { signature }; let session_key = derive_session_key(initiator_secret, &responder_secret, &transcript); Ok((confirm, session_key)) @@ -86,7 +76,7 @@ pub fn build_confirm( pub fn finalize_confirm( initiator: XID, responder: XID, - initiator_signing_key: &SigningPublicKey, + initiator_signing_key: &MLDSAPublicKey, hello: &Hello, reply: &HelloReply, confirm: &Confirm, @@ -100,12 +90,13 @@ pub fn finalize_confirm( &transcript, )) } + fn handshake_transcript( initiator: XID, responder: XID, hello: &Hello, - responder_nonce: &bc_components::Nonce, - responder_kem_ct: &EncapsulationCiphertext, + responder_nonce: &Nonce, + responder_kem_ct: &MLKEMCiphertext, ) -> Vec { CBOR::from(vec![ CBOR::from(initiator), @@ -120,7 +111,7 @@ fn handshake_transcript( fn next_nonce(platform: &impl QlPlatform) -> Nonce { let mut data = [0u8; Nonce::NONCE_SIZE]; - platform.fill_bytes(&mut data); + platform.fill_random_bytes(&mut data); Nonce::from_data(data) } diff --git a/ql/src/wire/handshake.rs b/ql/src/wire/handshake/mod.rs similarity index 87% rename from ql/src/wire/handshake.rs rename to ql/src/wire/handshake/mod.rs index eafc2da4..9eebbea7 100644 --- a/ql/src/wire/handshake.rs +++ b/ql/src/wire/handshake/mod.rs @@ -1,9 +1,12 @@ -use bc_components::{EncapsulationCiphertext, Nonce, Signature, SigningPublicKey, Verifier}; +use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, Nonce}; use dcbor::CBOR; use super::take_fields; use crate::QlError; +mod crypto; +pub use crypto::*; + #[derive(Debug, Clone, PartialEq)] pub enum HandshakeRecord { Hello(Hello), @@ -14,19 +17,19 @@ pub enum HandshakeRecord { #[derive(Debug, Clone, PartialEq)] pub struct Hello { pub nonce: Nonce, - pub kem_ct: EncapsulationCiphertext, + pub kem_ct: MLKEMCiphertext, } #[derive(Debug, Clone, PartialEq)] pub struct HelloReply { pub nonce: Nonce, - pub kem_ct: EncapsulationCiphertext, - pub signature: Signature, + pub kem_ct: MLKEMCiphertext, + pub signature: MLDSASignature, } #[derive(Debug, Clone, PartialEq)] pub struct Confirm { - pub signature: Signature, + pub signature: MLDSASignature, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -51,14 +54,13 @@ impl TryFrom for HandshakeKind { } pub fn verify_transcript_signature( - signing_key: &SigningPublicKey, - signature: &Signature, + signing_key: &MLDSAPublicKey, + signature: &MLDSASignature, transcript: &[u8], ) -> Result<(), QlError> { - if signing_key.verify(signature, &transcript) { - Ok(()) - } else { - Err(QlError::InvalidSignature) + match signing_key.verify(signature, transcript) { + Ok(true) => Ok(()), + _ => Err(QlError::InvalidSignature), } } diff --git a/ql/src/crypto/heartbeat.rs b/ql/src/wire/heartbeat/crypto.rs similarity index 91% rename from ql/src/crypto/heartbeat.rs rename to ql/src/wire/heartbeat/crypto.rs index 0949e974..6f542d1b 100644 --- a/ql/src/crypto/heartbeat.rs +++ b/ql/src/wire/heartbeat/crypto.rs @@ -1,9 +1,9 @@ use bc_components::{Nonce, SymmetricKey}; use dcbor::CBOR; +use super::HeartbeatBody; use crate::{ - crypto::ensure_not_expired, - wire::{heartbeat::HeartbeatBody, QlHeader, QlPayload, QlRecord}, + wire::{ensure_not_expired, QlHeader, QlPayload, QlRecord}, QlError, }; diff --git a/ql/src/wire/heartbeat.rs b/ql/src/wire/heartbeat/mod.rs similarity index 96% rename from ql/src/wire/heartbeat.rs rename to ql/src/wire/heartbeat/mod.rs index 43f8c43a..bae5131a 100644 --- a/ql/src/wire/heartbeat.rs +++ b/ql/src/wire/heartbeat/mod.rs @@ -3,6 +3,9 @@ use dcbor::CBOR; use super::take_fields; use crate::MessageId; +mod crypto; +pub use crypto::*; + #[derive(Debug, Clone, PartialEq)] pub struct HeartbeatBody { pub message_id: MessageId, diff --git a/ql/src/crypto/message.rs b/ql/src/wire/message/crypto.rs similarity index 81% rename from ql/src/crypto/message.rs rename to ql/src/wire/message/crypto.rs index aabc5837..14613b8a 100644 --- a/ql/src/crypto/message.rs +++ b/ql/src/wire/message/crypto.rs @@ -1,12 +1,9 @@ use bc_components::{Nonce, SymmetricKey}; use dcbor::CBOR; +use super::{DecryptedMessage, MessageBody, MessageKind, Nack}; use crate::{ - crypto::ensure_not_expired, - wire::{ - message::{DecryptedMessage, MessageBody, MessageKind, Nack}, - QlHeader, QlPayload, QlRecord, - }, + wire::{ensure_not_expired, QlHeader, QlPayload, QlRecord}, MessageId, QlError, }; @@ -50,7 +47,14 @@ pub fn decrypt_message( return Err(QlError::InvalidPayload.into()); } let body = decrypt_body(session_key, encrypted)?; - ensure_not_expired(body.message_id, body.valid_until)?; + ensure_not_expired(body.message_id, body.valid_until).map_err(|err| match err { + QlError::Nack { id, nack } => MessageError::Nack { + id, + nack, + kind: body.kind, + }, + other => MessageError::Error(other), + })?; Ok(DecryptedMessage { sender: header.sender, recipient: header.recipient, diff --git a/ql/src/wire/message.rs b/ql/src/wire/message/mod.rs similarity index 99% rename from ql/src/wire/message.rs rename to ql/src/wire/message/mod.rs index 5f246d84..ea25b601 100644 --- a/ql/src/wire/message.rs +++ b/ql/src/wire/message/mod.rs @@ -4,6 +4,9 @@ use dcbor::CBOR; use super::take_fields; use crate::{MessageId, RouteId}; +mod crypto; +pub use crypto::*; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MessageKind { Request, diff --git a/ql/src/wire/mod.rs b/ql/src/wire/mod.rs index 296f9ac7..c4b5e90f 100644 --- a/ql/src/wire/mod.rs +++ b/ql/src/wire/mod.rs @@ -7,7 +7,8 @@ pub mod pair; use bc_components::{EncryptedMessage, XID}; -use crate::wire::{handshake::HandshakeRecord, pair::PairRequestRecord}; +use self::{handshake::HandshakeRecord, pair::PairRequestRecord}; +use crate::{MessageId, QlError}; #[derive(Debug, Clone, PartialEq)] pub struct QlRecord { @@ -145,7 +146,7 @@ pub(crate) fn take_fields( ) -> Result<[CBOR; N], dcbor::Error> { use std::mem::MaybeUninit; - let mut fields: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; + let mut fields: [MaybeUninit; N] = [const { MaybeUninit::uninit() }; N]; for (index, slot) in fields.iter_mut().enumerate() { let Some(value) = iter.next() else { for init in &mut fields[..index] { @@ -162,6 +163,25 @@ pub(crate) fn take_fields( Ok(result) } +pub(crate) fn ensure_not_expired(id: MessageId, valid_until: u64) -> Result<(), QlError> { + let now = now_secs(); + if now > valid_until { + Err(QlError::Nack { + id, + nack: message::Nack::Expired, + }) + } else { + Ok(()) + } +} + +pub(crate) fn now_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} + #[test] fn take_fields_reads_exact_count() { let values = vec![CBOR::from(1u8), CBOR::from(2u8), CBOR::from(3u8)]; diff --git a/ql/src/crypto/pair.rs b/ql/src/wire/pair/crypto.rs similarity index 80% rename from ql/src/crypto/pair.rs rename to ql/src/wire/pair/crypto.rs index 1baec75d..8aca5b6d 100644 --- a/ql/src/crypto/pair.rs +++ b/ql/src/wire/pair/crypto.rs @@ -1,22 +1,21 @@ use std::time::Duration; -use bc_components::{EncapsulationPublicKey, Nonce, SigningPublicKey, SymmetricKey, Verifier, XID}; +use bc_components::{ + MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, Nonce, SigningPublicKey, SymmetricKey, XID, +}; use dcbor::CBOR; +use super::{PairRequestBody, PairRequestRecord}; use crate::{ - crypto::ensure_not_expired, platform::{QlPlatform, QlPlatformExt}, - wire::{ - pair::{PairRequestBody, PairRequestRecord}, - QlHeader, QlPayload, QlRecord, - }, + wire::{ensure_not_expired, now_secs, QlHeader, QlPayload, QlRecord}, MessageId, QlError, }; pub fn build_pair_request( platform: &impl QlPlatform, recipient: XID, - recipient_encapsulation_key: &EncapsulationPublicKey, + recipient_encapsulation_key: &MLKEMPublicKey, message_id: MessageId, valid_for: Duration, ) -> Result { @@ -25,7 +24,7 @@ pub fn build_pair_request( sender: platform.xid(), recipient, }; - let valid_until = super::now_secs().saturating_add(valid_for.as_secs()); + let valid_until = now_secs().saturating_add(valid_for.as_secs()); let signing_pub_key = platform.signing_public_key().clone(); let sender_encapsulation_key = platform.encapsulation_public_key().clone(); let proof_data = pairing_proof_data( @@ -36,10 +35,7 @@ pub fn build_pair_request( &signing_pub_key, &sender_encapsulation_key, ); - let proof = platform - .signer() - .sign(&proof_data) - .map_err(|_| QlError::InvalidPayload)?; + let proof = platform.signing_private_key().sign(&proof_data); let body = PairRequestBody { message_id, valid_until, @@ -72,7 +68,7 @@ pub fn decrypt_pair_request( } let decrypted = decrypt_body(&session_key, &encrypted)?; ensure_not_expired(decrypted.message_id, decrypted.valid_until)?; - if XID::new(&decrypted.signing_pub_key) != header.sender { + if XID::new(SigningPublicKey::MLDSA(decrypted.signing_pub_key.clone())) != header.sender { return Err(QlError::InvalidPayload); } let proof_data = pairing_proof_data( @@ -86,6 +82,7 @@ pub fn decrypt_pair_request( if decrypted .signing_pub_key .verify(&decrypted.proof, &proof_data) + .unwrap_or(false) { Ok(decrypted) } else { @@ -95,11 +92,11 @@ pub fn decrypt_pair_request( fn pairing_proof_data( header: &QlHeader, - kem_ct: &bc_components::EncapsulationCiphertext, + kem_ct: &MLKEMCiphertext, message_id: MessageId, valid_until: u64, - signing_pub_key: &SigningPublicKey, - encapsulation_pub_key: &EncapsulationPublicKey, + signing_pub_key: &MLDSAPublicKey, + encapsulation_pub_key: &MLKEMPublicKey, ) -> Vec { CBOR::from(vec![ CBOR::from(pairing_aad(header, kem_ct)), @@ -122,6 +119,6 @@ fn decrypt_body( PairRequestBody::try_from(cbor).map_err(|_| QlError::InvalidPayload) } -fn pairing_aad(header: &QlHeader, kem_ct: &bc_components::EncapsulationCiphertext) -> Vec { +fn pairing_aad(header: &QlHeader, kem_ct: &MLKEMCiphertext) -> Vec { CBOR::from(vec![CBOR::from(header.clone()), CBOR::from(kem_ct.clone())]).to_cbor_data() } diff --git a/ql/src/wire/pair.rs b/ql/src/wire/pair/mod.rs similarity index 87% rename from ql/src/wire/pair.rs rename to ql/src/wire/pair/mod.rs index 276de804..b14045eb 100644 --- a/ql/src/wire/pair.rs +++ b/ql/src/wire/pair/mod.rs @@ -1,12 +1,15 @@ -use bc_components::{EncapsulationCiphertext, EncapsulationPublicKey, Signature, SigningPublicKey}; +use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey}; use dcbor::CBOR; use super::take_fields; use crate::MessageId; +mod crypto; +pub use crypto::*; + #[derive(Debug, Clone, PartialEq)] pub struct PairRequestRecord { - pub kem_ct: EncapsulationCiphertext, + pub kem_ct: MLKEMCiphertext, pub encrypted: bc_components::EncryptedMessage, } @@ -14,9 +17,9 @@ pub struct PairRequestRecord { pub struct PairRequestBody { pub message_id: MessageId, pub valid_until: u64, - pub signing_pub_key: SigningPublicKey, - pub encapsulation_pub_key: EncapsulationPublicKey, - pub proof: Signature, + pub signing_pub_key: MLDSAPublicKey, + pub encapsulation_pub_key: MLKEMPublicKey, + pub proof: MLDSASignature, } impl From for CBOR { From d918503b30e92cce9e16e9d3caaa6070c07122e8 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:42 -0400 Subject: [PATCH 004/304] ql: add unpairing and stream support to pre-duplex runtime --- ql/src/lib.rs | 14 + ql/src/router.rs | 165 +++- ql/src/runtime/core.rs | 1407 ++++++++++++++++++++++++++++---- ql/src/runtime/handle.rs | 337 +++++++- ql/src/runtime/internal.rs | 236 +++++- ql/src/runtime/mod.rs | 40 +- ql/src/runtime/replay_cache.rs | 1 + ql/src/tests/mod.rs | 37 +- ql/src/tests/requests.rs | 1 + ql/src/tests/streams.rs | 552 +++++++++++++ ql/src/tests/unpair.rs | 160 ++++ ql/src/wire/mod.rs | 26 +- ql/src/wire/transfer/crypto.rs | 42 + ql/src/wire/transfer/mod.rs | 194 +++++ ql/src/wire/unpair/crypto.rs | 58 ++ ql/src/wire/unpair/mod.rs | 39 + 16 files changed, 3153 insertions(+), 156 deletions(-) create mode 100644 ql/src/tests/streams.rs create mode 100644 ql/src/tests/unpair.rs create mode 100644 ql/src/wire/transfer/crypto.rs create mode 100644 ql/src/wire/transfer/mod.rs create mode 100644 ql/src/wire/unpair/crypto.rs create mode 100644 ql/src/wire/unpair/mod.rs diff --git a/ql/src/lib.rs b/ql/src/lib.rs index 34c09ca4..6fc69d21 100644 --- a/ql/src/lib.rs +++ b/ql/src/lib.rs @@ -28,6 +28,16 @@ pub trait Event: QlCodec { const ID: RouteId; } +pub trait QlStream: QlCodec { + const ID: RouteId; + type StreamMeta: QlCodec; +} + +pub trait QlUpload: QlCodec { + const ID: RouteId; + type Response: QlCodec; +} + #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] pub enum QlError { #[error("invalid payload")] @@ -51,4 +61,8 @@ pub enum QlError { }, #[error("cancelled")] Cancelled, + #[error("transfer cancelled")] + TransferCancelled { id: MessageId }, + #[error("transfer protocol error")] + TransferProtocol { id: MessageId }, } diff --git a/ql/src/router.rs b/ql/src/router.rs index 80c766b0..1f7c74a4 100644 --- a/ql/src/router.rs +++ b/ql/src/router.rs @@ -3,9 +3,9 @@ use std::collections::HashMap; use thiserror::Error; use crate::{ - runtime::{HandlerEvent, Responder}, + runtime::{HandlerEvent, InboundByteStream, OutboundTransfer, Responder}, wire::message::{Ack, Nack}, - Event, QlCodec, QlError, RequestResponse, RouteId, + Event, QlCodec, QlError, QlStream, QlUpload, RequestResponse, RouteId, }; pub trait RequestHandler @@ -16,6 +16,21 @@ where fn default_response() -> M::Response; } +pub trait StreamRequestHandler +where + M: QlStream, +{ + fn handle(&mut self, request: QlStreamRequest); +} + +pub trait UploadRequestHandler +where + M: QlUpload, +{ + fn handle(&mut self, request: QlUploadRequest); + fn default_response() -> M::Response; +} + pub trait EventHandler where M: Event, @@ -31,6 +46,23 @@ where pub responder: QlResponder, } +pub struct QlStreamRequest +where + M: QlStream, +{ + pub message: M, + pub responder: QlStreamResponder, +} + +pub struct QlUploadRequest +where + M: QlUpload, +{ + pub message: M, + pub body: InboundByteStream, + pub responder: QlResponder, +} + pub struct QlResponder where R: QlCodec, @@ -39,6 +71,14 @@ where default: fn() -> R, } +pub struct QlStreamResponder +where + M: QlCodec, +{ + responder: Option, + _meta: std::marker::PhantomData M>, +} + impl QlResponder where R: QlCodec, @@ -58,6 +98,21 @@ where } } +impl QlStreamResponder +where + M: QlCodec, +{ + pub fn respond_stream(mut self, meta: M) -> Result { + let responder = self.responder.take().unwrap(); + responder.respond_stream(meta) + } + + pub fn respond_nack(mut self, reason: Nack) -> Result<(), QlError> { + let responder = self.responder.take().unwrap(); + responder.respond_nack(reason) + } +} + impl Drop for QlResponder where R: QlCodec, @@ -70,6 +125,17 @@ where } } +impl Drop for QlStreamResponder +where + M: QlCodec, +{ + fn drop(&mut self) { + if let Some(responder) = self.responder.take() { + let _ = responder.respond_nack(Nack::Unknown); + } + } +} + #[derive(Debug, Error)] pub enum RouterError { #[error(transparent)] @@ -107,6 +173,22 @@ impl RouterBuilder { self.add_handler(M::ID, handle_request::) } + pub fn add_stream_request_handler(self) -> Self + where + M: QlStream, + S: StreamRequestHandler, + { + self.add_handler(M::ID, handle_stream_request::) + } + + pub fn add_upload_request_handler(self) -> Self + where + M: QlUpload, + S: UploadRequestHandler, + { + self.add_handler(M::ID, handle_upload_request::) + } + pub fn add_event_handler(self) -> Self where M: Event, @@ -154,6 +236,17 @@ impl Router { }; handler(&mut self.state, HandlerEvent::Request(request)) } + HandlerEvent::UploadRequest(request) => { + let route_id = request.route_id; + let handler = match self.handlers.get(&route_id) { + Some(handler) => handler, + None => { + let _ = request.respond_to.respond_nack(Nack::UnknownRoute); + return Ok(()); + } + }; + handler(&mut self.state, HandlerEvent::UploadRequest(request)) + } HandlerEvent::Event(event) => { let route_id = event.message.route_id; let handler = self @@ -173,6 +266,10 @@ where { let (payload, responder) = match event { HandlerEvent::Request(request) => (request.message.payload, request.respond_to), + HandlerEvent::UploadRequest(request) => { + let _ = request.respond_to.respond_nack(Nack::InvalidPayload); + return Err(RouterError::Runtime(QlError::InvalidPayload)); + } HandlerEvent::Event(_) => return Err(RouterError::Runtime(QlError::InvalidPayload)), }; let message = match M::try_from(payload) { @@ -190,6 +287,66 @@ where Ok(()) } +fn handle_stream_request(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> +where + M: QlStream, + S: StreamRequestHandler, +{ + let (payload, responder) = match event { + HandlerEvent::Request(request) => (request.message.payload, request.respond_to), + HandlerEvent::UploadRequest(request) => { + let _ = request.respond_to.respond_nack(Nack::InvalidPayload); + return Err(RouterError::Runtime(QlError::InvalidPayload)); + } + HandlerEvent::Event(_) => return Err(RouterError::Runtime(QlError::InvalidPayload)), + }; + let message = match M::try_from(payload) { + Ok(message) => message, + Err(error) => { + let _ = responder.respond_nack(Nack::InvalidPayload); + return Err(RouterError::Decode(error)); + } + }; + let responder = QlStreamResponder { + responder: Some(responder), + _meta: std::marker::PhantomData, + }; + state.handle(QlStreamRequest { message, responder }); + Ok(()) +} + +fn handle_upload_request(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> +where + M: QlUpload, + S: UploadRequestHandler, +{ + let (meta, body, responder) = match event { + HandlerEvent::UploadRequest(request) => (request.meta, request.body, request.respond_to), + HandlerEvent::Request(request) => { + let _ = request.respond_to.respond_nack(Nack::InvalidPayload); + return Err(RouterError::Runtime(QlError::InvalidPayload)); + } + HandlerEvent::Event(_) => return Err(RouterError::Runtime(QlError::InvalidPayload)), + }; + let message = match M::try_from(meta) { + Ok(message) => message, + Err(error) => { + let _ = responder.respond_nack(Nack::InvalidPayload); + return Err(RouterError::Decode(error)); + } + }; + let responder = QlResponder { + responder: Some(responder), + default: S::default_response, + }; + state.handle(QlUploadRequest { + message, + body, + responder, + }); + Ok(()) +} + fn handle_event(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> where M: Event, @@ -198,6 +355,10 @@ where let (payload, responder) = match event { HandlerEvent::Event(event) => (event.message.payload, None), HandlerEvent::Request(request) => (request.message.payload, Some(request.respond_to)), + HandlerEvent::UploadRequest(request) => { + let _ = request.respond_to.respond_nack(Nack::InvalidPayload); + return Err(RouterError::Runtime(QlError::InvalidPayload)); + } }; let message = match M::try_from(payload) { Ok(message) => message, diff --git a/ql/src/runtime/core.rs b/ql/src/runtime/core.rs index d5de10d9..6406905f 100644 --- a/ql/src/runtime/core.rs +++ b/ql/src/runtime/core.rs @@ -11,23 +11,29 @@ use crate::{ runtime::{ internal::{ next_timeout_deadline, now_secs, peer_hello_wins, HelloAction, InFlightWrite, - KeepAliveState, LoopStep, OutboundMessage, OutboundPayload, PendingEntry, - RuntimeCommand, RuntimeState, TimeoutEntry, TimeoutKind, + InboundStreamDelivery, InboundStreamItem, InboundTransferOpen, InboundTransferState, + KeepAliveState, LoopStep, OutboundAwaiting, OutboundMessage, OutboundPayload, + OutboundStreamInput, OutboundTransferStage, OutboundTransferState, PendingEntry, + PendingStreamEntry, RuntimeCommand, RuntimeState, TimeoutEntry, TimeoutKind, }, replay_cache::{ReplayKey, ReplayNamespace}, - HandlerEvent, InboundEvent, InboundRequest, InitiatorStage, KeepAliveConfig, PeerSession, - Responder, Runtime, Token, + HandlerEvent, InboundByteStream, InboundEvent, InboundRequest, InboundUploadRequest, + InitiatorStage, KeepAliveConfig, PeerSession, Responder, Runtime, Token, }, wire::{ handshake::{self, HandshakeRecord}, heartbeat::{self, HeartbeatBody}, message::{self, MessageBody, MessageKind, Nack}, pair::{self, PairRequestRecord}, + transfer::{self, TransferBody, TransferFrame}, + unpair::{self, UnpairRecord}, QlHeader, QlPayload, QlRecord, }, MessageId, QlError, RouteId, }; +const TRANSFER_RETRY_LIMIT: u8 = 5; + impl Runtime

{ pub async fn run(self) { let mut state = RuntimeState::new(); @@ -38,6 +44,7 @@ impl Runtime

{ } let mut in_flight: Option> = None; while !self.rx.is_closed() { + self.drive_outbound_transfers(&mut state); if in_flight.is_none() { in_flight = self.start_next_write(&mut state); } @@ -54,6 +61,9 @@ impl Runtime

{ RuntimeCommand::Connect { peer } => { self.handle_connect(&mut state, peer); } + RuntimeCommand::Unpair { peer } => { + self.handle_send_unpair(&mut state, peer); + } RuntimeCommand::SendRequest { recipient, route_id, @@ -65,6 +75,31 @@ impl Runtime

{ &mut state, recipient, route_id, payload, respond_to, config, ); } + RuntimeCommand::SendStreamRequest { + recipient, + route_id, + payload, + respond_to, + config, + } => { + self.handle_send_stream_request( + &mut state, recipient, route_id, payload, respond_to, config, + ); + } + RuntimeCommand::SendUploadRequest { + recipient, + route_id, + payload, + respond_to, + chunk_rx, + start, + config, + } => { + self.handle_send_upload_request( + &mut state, recipient, route_id, payload, respond_to, chunk_rx, start, + config, + ); + } RuntimeCommand::SendEvent { recipient, route_id, @@ -80,6 +115,34 @@ impl Runtime

{ } => { self.handle_send_response(&mut state, id, recipient, payload, kind); } + RuntimeCommand::StartResponseStream { + request_id, + recipient, + meta, + chunk_rx, + } => { + self.handle_start_response_stream( + &mut state, request_id, recipient, meta, chunk_rx, + ); + } + RuntimeCommand::PollOutboundTransfer { + recipient, + transfer_id, + } => { + self.drive_outbound_transfer(&mut state, recipient, transfer_id); + } + RuntimeCommand::CancelOutboundTransfer { + recipient, + transfer_id, + } => { + self.handle_cancel_outbound_transfer(&mut state, recipient, transfer_id); + } + RuntimeCommand::CancelInboundTransfer { + sender, + transfer_id, + } => { + self.handle_cancel_inbound_transfer(&mut state, sender, transfer_id); + } RuntimeCommand::Incoming(bytes) => { self.handle_incoming(&mut state, bytes); } @@ -115,6 +178,9 @@ impl Runtime

{ if let Some(entry) = state.pending.remove(&id) { let _ = entry.tx.send(Err(QlError::SendFailed)); } + if let Some(entry) = state.pending_stream.remove(&id) { + let _ = entry.tx.send(Err(QlError::SendFailed)); + } } continue; }; @@ -297,6 +363,126 @@ impl Runtime

{ ); } + fn handle_send_stream_request( + &self, + state: &mut RuntimeState, + recipient: XID, + route_id: RouteId, + payload: CBOR, + respond_to: oneshot::Sender>, + config: super::RequestConfig, + ) { + let id = state.next_message_id(); + let timeout = config + .timeout + .unwrap_or(self.config.default_request_timeout); + if timeout.is_zero() { + let _ = respond_to.send(Err(QlError::Timeout)); + return; + } + let Some(entry) = state.peers.peer(recipient) else { + let _ = respond_to.send(Err(QlError::UnknownPeer(recipient))); + return; + }; + if !entry.session.is_connected() { + let _ = respond_to.send(Err(QlError::MissingSession(recipient))); + return; + } + let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); + let body = MessageBody { + message_id: id, + valid_until, + kind: MessageKind::Request, + route_id, + payload, + }; + state.pending_stream.insert( + id, + PendingStreamEntry { + recipient, + tx: respond_to, + }, + ); + state.timeouts.push(Reverse(TimeoutEntry { + at: Instant::now() + timeout, + kind: TimeoutKind::Request { id }, + })); + let outbound_deadline = Instant::now() + self.config.message_expiration; + self.enqueue_outbound( + state, + recipient, + OutboundPayload::DeferredMessage(body), + outbound_deadline, + Some(id), + ); + } + + fn handle_send_upload_request( + &self, + state: &mut RuntimeState, + recipient: XID, + route_id: RouteId, + payload: CBOR, + respond_to: oneshot::Sender>, + chunk_rx: async_channel::Receiver, + start: oneshot::Sender>, + config: super::RequestConfig, + ) { + let timeout = config + .timeout + .unwrap_or(self.config.default_request_timeout); + if timeout.is_zero() { + let _ = start.send(Err(QlError::Timeout)); + return; + } + let Some(entry) = state.peers.peer(recipient) else { + let _ = start.send(Err(QlError::UnknownPeer(recipient))); + return; + }; + if !entry.session.is_connected() { + let _ = start.send(Err(QlError::MissingSession(recipient))); + return; + } + + let request_id = state.next_message_id(); + state.pending.insert( + request_id, + PendingEntry { + recipient, + tx: respond_to, + }, + ); + state.timeouts.push(Reverse(TimeoutEntry { + at: Instant::now() + timeout, + kind: TimeoutKind::Request { id: request_id }, + })); + + let transfer_id = request_id; + let key = (recipient, transfer_id); + if state.outbound_transfers.contains_key(&key) { + let _ = state.pending.remove(&request_id); + let _ = start.send(Err(QlError::SendFailed)); + return; + } + + state.outbound_transfers.insert( + key, + OutboundTransferState { + request_id, + peer: recipient, + transfer_id, + stage: OutboundTransferStage::Opening, + next_seq: 1, + open_route_id: Some(route_id), + open_meta: Some(payload), + chunk_rx, + awaiting: None, + }, + ); + + let _ = start.send(Ok(request_id)); + } + fn handle_send_event( &self, state: &mut RuntimeState, @@ -348,188 +534,1027 @@ impl Runtime

{ return; } - let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); - let body = MessageBody { - message_id: id, - valid_until, - kind, - route_id: RouteId(0), - payload, - }; - let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound( - state, - recipient, - OutboundPayload::DeferredMessage(body), - outbound_deadline, - None, - ); + let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); + let body = MessageBody { + message_id: id, + valid_until, + kind, + route_id: RouteId(0), + payload, + }; + let outbound_deadline = Instant::now() + self.config.message_expiration; + self.enqueue_outbound( + state, + recipient, + OutboundPayload::DeferredMessage(body), + outbound_deadline, + None, + ); + } + + fn handle_start_response_stream( + &self, + state: &mut RuntimeState, + request_id: MessageId, + recipient: XID, + meta: CBOR, + chunk_rx: async_channel::Receiver, + ) { + if !matches!( + state.peers.peer(recipient), + Some(entry) if entry.session.is_connected() + ) { + return; + } + + let transfer_id = request_id; + let key = (recipient, transfer_id); + if state.outbound_transfers.contains_key(&key) { + return; + } + + state.outbound_transfers.insert( + key, + OutboundTransferState { + request_id, + peer: recipient, + transfer_id, + stage: OutboundTransferStage::Opening, + next_seq: 1, + open_route_id: None, + open_meta: Some(meta), + chunk_rx, + awaiting: None, + }, + ); + } + + fn handle_cancel_outbound_transfer( + &self, + state: &mut RuntimeState, + recipient: XID, + transfer_id: MessageId, + ) { + let key = (recipient, transfer_id); + let mut found = false; + if let Some(transfer) = state.outbound_transfers.get_mut(&key) { + found = true; + transfer.stage = OutboundTransferStage::Cancelling; + transfer.awaiting = None; + transfer.chunk_rx.close(); + } + if found { + self.drive_outbound_transfer(state, recipient, transfer_id); + } + } + + fn handle_cancel_inbound_transfer( + &self, + state: &mut RuntimeState, + sender: XID, + transfer_id: MessageId, + ) { + if state + .inbound_transfers + .remove(&(sender, transfer_id)) + .is_some() + { + self.send_transfer_frame(state, sender, transfer_id, TransferFrame::Cancel, false); + } + } + + fn handle_send_unpair(&self, state: &mut RuntimeState, peer: XID) { + if state.peers.peer(peer).is_none() { + return; + } + let message = unpair::build_unpair_record( + &self.platform, + QlHeader { + sender: self.platform.xid(), + recipient: peer, + }, + state.next_message_id(), + now_secs().saturating_add(self.config.message_expiration.as_secs()), + ); + let bytes = CBOR::from(message).to_cbor_data(); + self.unpair_peer(state, peer); + let deadline = Instant::now() + self.config.message_expiration; + self.enqueue_outbound( + state, + peer, + OutboundPayload::PreEncoded(bytes), + deadline, + None, + ); + } + + fn handle_incoming(&self, state: &mut RuntimeState, bytes: Vec) { + let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) else { + return; + }; + let QlRecord { header, payload } = record; + if header.recipient != self.platform.xid() { + return; + } + match payload { + QlPayload::Handshake(message) => { + self.handle_handshake(state, header, message); + } + QlPayload::Pair(request) => { + self.handle_pairing(state, header, request); + } + QlPayload::Unpair(unpair) => { + self.handle_unpair(state, header, unpair); + } + QlPayload::Message(encrypted) => { + self.handle_record(state, header, encrypted); + } + QlPayload::Heartbeat(encrypted) => { + self.handle_heartbeat(state, header, encrypted); + } + QlPayload::Transfer(encrypted) => { + self.handle_transfer(state, header, encrypted); + } + } + } + + fn handle_handshake( + &self, + state: &mut RuntimeState, + header: QlHeader, + message: HandshakeRecord, + ) { + match message { + HandshakeRecord::Hello(hello) => { + self.handle_hello(state, header, hello); + } + HandshakeRecord::HelloReply(reply) => { + self.handle_hello_reply(state, header, reply); + } + HandshakeRecord::Confirm(confirm) => { + self.handle_confirm(state, header, confirm); + } + } + } + + fn handle_pairing( + &self, + state: &mut RuntimeState, + header: QlHeader, + request: PairRequestRecord, + ) { + let payload = match pair::decrypt_pair_request(&self.platform, &header, request) { + Ok(payload) => payload, + Err(_) => return, + }; + let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); + state + .peers + .upsert_peer(peer, payload.signing_pub_key, payload.encapsulation_pub_key); + self.persist_peers(state); + self.handle_connect(state, peer); + } + + fn handle_unpair(&self, state: &mut RuntimeState, header: QlHeader, record: UnpairRecord) { + let peer = header.sender; + let Some(signing_key) = state + .peers + .peer(peer) + .map(|entry| entry.signing_key.clone()) + else { + return; + }; + if unpair::verify_unpair_record(&header, &record, &signing_key).is_err() { + return; + } + let replay_key = ReplayKey::new(peer, ReplayNamespace::Peer, record.message_id); + if state + .replay_cache + .check_and_store_valid_until(replay_key, record.valid_until) + { + return; + } + self.unpair_peer(state, peer); + } + + fn unpair_peer(&self, state: &mut RuntimeState, peer: XID) { + if state.peers.remove_peer(peer).is_none() { + return; + } + self.drop_outbound_for_peer(state, peer); + self.fail_pending_for_peer(state, peer); + self.fail_pending_stream_for_peer(state, peer); + self.abort_transfers_for_peer(state, peer, QlError::SendFailed); + state.replay_cache.clear_peer(peer); + self.platform + .handle_peer_status(peer, &PeerSession::Disconnected); + self.persist_peers(state); + } + + fn persist_peers(&self, state: &RuntimeState) { + self.platform.persist_peers(state.peers.all()); + } + + fn handle_record( + &self, + state: &mut RuntimeState, + header: QlHeader, + encrypted: bc_components::EncryptedMessage, + ) { + let peer = header.sender; + let session_key = match state.peers.peer(peer) { + Some(entry) => match &entry.session { + PeerSession::Connected { session_key, .. } => session_key.clone(), + _ => return, + }, + None => return, + }; + let record = match message::decrypt_message(&header, &encrypted, &session_key) { + Ok(record) => record, + Err(message::MessageError::Nack { id, nack, kind }) => { + self.handle_message_nack(state, peer, id, nack, kind); + return; + } + Err(message::MessageError::Error(_)) => return, + }; + let namespace = match record.kind { + MessageKind::Request | MessageKind::Event => ReplayNamespace::Peer, + MessageKind::Response | MessageKind::Nack => ReplayNamespace::Local, + }; + let replay_key = ReplayKey::new(peer, namespace, record.message_id); + if state + .replay_cache + .check_and_store_valid_until(replay_key, record.valid_until) + { + return; + } + self.record_activity(state, peer); + match record.kind { + MessageKind::Response => { + self.resolve_pending_ok(state, peer, record.message_id, record.payload); + } + MessageKind::Nack => { + let nack = Nack::from(record.payload); + self.resolve_pending_nack(state, peer, record.message_id, nack); + } + MessageKind::Request => { + let Some(tx) = self.tx.upgrade() else { + return; + }; + let responder = Responder::new(record.message_id, record.sender, tx); + self.platform + .handle_inbound(HandlerEvent::Request(InboundRequest { + message: record, + respond_to: responder, + })); + } + MessageKind::Event => { + self.platform + .handle_inbound(HandlerEvent::Event(InboundEvent { message: record })); + } + } + } + + fn handle_message_nack( + &self, + state: &mut RuntimeState, + peer: XID, + id: MessageId, + nack: Nack, + kind: MessageKind, + ) { + if kind != MessageKind::Request { + return; + } + self.handle_send_response(state, id, peer, CBOR::from(nack), MessageKind::Nack); + } + + fn handle_heartbeat( + &self, + state: &mut RuntimeState, + header: QlHeader, + encrypted: bc_components::EncryptedMessage, + ) { + let peer = header.sender; + let (session_key, should_reply) = { + let Some(entry) = state.peers.peer(peer) else { + return; + }; + match &entry.session { + PeerSession::Connected { + session_key, + keepalive, + } => (session_key.clone(), !keepalive.pending), + _ => return, + } + }; + if heartbeat::decrypt_heartbeat(&header, &encrypted, &session_key).is_err() { + return; + } + self.record_activity(state, peer); + if should_reply { + self.send_heartbeat_message(state, peer, session_key); + } + } + + fn handle_transfer( + &self, + state: &mut RuntimeState, + header: QlHeader, + encrypted: bc_components::EncryptedMessage, + ) { + let peer = header.sender; + let session_key = match state.peers.peer(peer) { + Some(entry) => match &entry.session { + PeerSession::Connected { session_key, .. } => session_key.clone(), + _ => return, + }, + None => return, + }; + let body = match transfer::decrypt_transfer(&header, &encrypted, &session_key) { + Ok(body) => body, + Err(_) => return, + }; + + let replay_key = ReplayKey::new(peer, ReplayNamespace::Transfer, body.message_id); + if state + .replay_cache + .check_and_store_valid_until(replay_key, body.valid_until) + { + return; + } + + self.record_activity(state, peer); + self.handle_transfer_frame(state, peer, body.transfer_id, body.frame); + } + + fn handle_transfer_frame( + &self, + state: &mut RuntimeState, + peer: XID, + transfer_id: MessageId, + frame: TransferFrame, + ) { + match frame { + TransferFrame::OpenResponse { request_id, meta } => { + self.handle_transfer_open_response(state, peer, transfer_id, request_id, meta); + } + TransferFrame::OpenRequest { + request_id, + route_id, + meta, + } => { + self.handle_transfer_open_request( + state, + peer, + transfer_id, + request_id, + route_id, + meta, + ); + } + TransferFrame::Chunk { seq, data } => { + self.handle_transfer_chunk(state, peer, transfer_id, seq, data); + } + TransferFrame::Finish { seq } => { + self.handle_transfer_finish(state, peer, transfer_id, seq); + } + TransferFrame::Ack { next_seq } => { + self.handle_transfer_ack(state, peer, transfer_id, next_seq); + } + TransferFrame::Cancel => { + self.handle_transfer_cancel(state, peer, transfer_id); + } + TransferFrame::CancelAck => { + self.handle_transfer_cancel_ack(state, peer, transfer_id); + } + } + } + + fn handle_transfer_open_response( + &self, + state: &mut RuntimeState, + peer: XID, + transfer_id: MessageId, + request_id: MessageId, + meta: CBOR, + ) { + let open = InboundTransferOpen::Response { + request_id, + meta: meta.clone(), + }; + if self.handle_duplicate_transfer_open(state, peer, transfer_id, &open) { + return; + } + + let Some(pending) = state.pending_stream.remove(&request_id) else { + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); + return; + }; + if pending.recipient != peer { + let _ = pending.tx.send(Err(QlError::SendFailed)); + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); + return; + } + + let Some(tx) = self.tx.upgrade() else { + let _ = pending.tx.send(Err(QlError::Cancelled)); + return; + }; + + let (chunk_tx, chunk_rx) = async_channel::bounded(1); + + let delivery = InboundStreamDelivery { + peer, + transfer_id, + meta, + rx: chunk_rx, + tx, + }; + if pending.tx.send(Ok(delivery)).is_err() { + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); + return; + } + + state.inbound_transfers.insert( + (peer, transfer_id), + InboundTransferState { + open, + expected_seq: 1, + chunk_tx, + }, + ); + + self.send_transfer_frame( + state, + peer, + transfer_id, + TransferFrame::Ack { next_seq: 1 }, + true, + ); + } + + fn handle_transfer_open_request( + &self, + state: &mut RuntimeState, + peer: XID, + transfer_id: MessageId, + request_id: MessageId, + route_id: RouteId, + meta: CBOR, + ) { + let open = InboundTransferOpen::Request { + request_id, + route_id, + meta: meta.clone(), + }; + if self.handle_duplicate_transfer_open(state, peer, transfer_id, &open) { + return; + } + + let Some(tx) = self.tx.upgrade() else { + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); + return; + }; + + let (chunk_tx, chunk_rx) = async_channel::bounded(1); + let responder = Responder::new(request_id, peer, tx.clone()); + let body = InboundByteStream::new(peer, transfer_id, chunk_rx, tx); + self.platform + .handle_inbound(HandlerEvent::UploadRequest(InboundUploadRequest { + sender: peer, + recipient: self.platform.xid(), + route_id, + message_id: request_id, + meta, + body, + respond_to: responder, + })); + + state.inbound_transfers.insert( + (peer, transfer_id), + InboundTransferState { + open, + expected_seq: 1, + chunk_tx, + }, + ); + + self.send_transfer_frame( + state, + peer, + transfer_id, + TransferFrame::Ack { next_seq: 1 }, + true, + ); + } + + fn handle_duplicate_transfer_open( + &self, + state: &mut RuntimeState, + peer: XID, + transfer_id: MessageId, + open: &InboundTransferOpen, + ) -> bool { + let key = (peer, transfer_id); + let Some(existing) = state.inbound_transfers.get(&key) else { + return false; + }; + + let frame = if &existing.open == open { + TransferFrame::Ack { next_seq: 1 } + } else { + TransferFrame::Cancel + }; + self.send_transfer_frame(state, peer, transfer_id, frame, true); + true + } + + fn handle_transfer_chunk( + &self, + state: &mut RuntimeState, + peer: XID, + transfer_id: MessageId, + seq: u32, + data: Vec, + ) { + let key = (peer, transfer_id); + let Some(mut transfer_state) = state.inbound_transfers.remove(&key) else { + return; + }; + + if seq < transfer_state.expected_seq { + self.send_transfer_frame( + state, + peer, + transfer_id, + TransferFrame::Ack { + next_seq: transfer_state.expected_seq, + }, + true, + ); + state.inbound_transfers.insert(key, transfer_state); + return; + } + + if seq > transfer_state.expected_seq { + let _ = transfer_state.chunk_tx.try_send(InboundStreamItem::Error( + QlError::TransferProtocol { id: transfer_id }, + )); + transfer_state.chunk_tx.close(); + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); + return; + } + + match transfer_state + .chunk_tx + .try_send(InboundStreamItem::Chunk(data)) + { + Ok(()) => { + transfer_state.expected_seq = transfer_state.expected_seq.saturating_add(1); + self.send_transfer_frame( + state, + peer, + transfer_id, + TransferFrame::Ack { + next_seq: transfer_state.expected_seq, + }, + true, + ); + state.inbound_transfers.insert(key, transfer_state); + } + Err(async_channel::TrySendError::Full(_)) => { + state.inbound_transfers.insert(key, transfer_state); + } + Err(async_channel::TrySendError::Closed(_)) => { + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); + } + } + } + + fn handle_transfer_finish( + &self, + state: &mut RuntimeState, + peer: XID, + transfer_id: MessageId, + seq: u32, + ) { + let key = (peer, transfer_id); + let Some(mut transfer_state) = state.inbound_transfers.remove(&key) else { + return; + }; + + if seq < transfer_state.expected_seq { + self.send_transfer_frame( + state, + peer, + transfer_id, + TransferFrame::Ack { + next_seq: transfer_state.expected_seq, + }, + true, + ); + state.inbound_transfers.insert(key, transfer_state); + return; + } + + if seq > transfer_state.expected_seq { + let _ = transfer_state.chunk_tx.try_send(InboundStreamItem::Error( + QlError::TransferProtocol { id: transfer_id }, + )); + transfer_state.chunk_tx.close(); + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); + return; + } + + match transfer_state + .chunk_tx + .try_send(InboundStreamItem::Finished) + { + Ok(()) => { + transfer_state.expected_seq = transfer_state.expected_seq.saturating_add(1); + transfer_state.chunk_tx.close(); + self.send_transfer_frame( + state, + peer, + transfer_id, + TransferFrame::Ack { + next_seq: transfer_state.expected_seq, + }, + true, + ); + } + Err(async_channel::TrySendError::Full(_)) => { + state.inbound_transfers.insert(key, transfer_state); + } + Err(async_channel::TrySendError::Closed(_)) => { + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); + } + } + } + + fn handle_transfer_ack( + &self, + state: &mut RuntimeState, + peer: XID, + transfer_id: MessageId, + next_seq: u32, + ) { + let key = (peer, transfer_id); + let Some(mut transfer_state) = state.outbound_transfers.remove(&key) else { + return; + }; + + let matched = match transfer_state.awaiting.as_ref() { + Some(OutboundAwaiting::Open { .. }) => next_seq == 1, + Some(OutboundAwaiting::Chunk { seq, .. }) => next_seq == seq.saturating_add(1), + Some(OutboundAwaiting::Finish { seq }) => next_seq == seq.saturating_add(1), + Some(OutboundAwaiting::Cancel) | None => false, + }; + if !matched { + state.outbound_transfers.insert(key, transfer_state); + return; + } + + match transfer_state.awaiting.take() { + Some(OutboundAwaiting::Open { .. }) => { + transfer_state.stage = OutboundTransferStage::Streaming; + state.outbound_transfers.insert(key, transfer_state); + } + Some(OutboundAwaiting::Chunk { seq, .. }) => { + transfer_state.next_seq = seq.saturating_add(1); + transfer_state.stage = OutboundTransferStage::Streaming; + state.outbound_transfers.insert(key, transfer_state); + } + Some(OutboundAwaiting::Finish { .. }) => { + transfer_state.chunk_rx.close(); + } + Some(OutboundAwaiting::Cancel) | None => { + state.outbound_transfers.insert(key, transfer_state); + } + } + } + + fn handle_transfer_cancel(&self, state: &mut RuntimeState, peer: XID, transfer_id: MessageId) { + let key = (peer, transfer_id); + let mut acknowledged = false; + + if let Some(transfer_state) = state.outbound_transfers.remove(&key) { + transfer_state.chunk_rx.close(); + acknowledged = true; + } + + if let Some(transfer_state) = state.inbound_transfers.remove(&key) { + let error = QlError::TransferCancelled { id: transfer_id }; + let _ = transfer_state + .chunk_tx + .try_send(InboundStreamItem::Error(error)); + transfer_state.chunk_tx.close(); + acknowledged = true; + } + + if acknowledged { + self.send_transfer_frame(state, peer, transfer_id, TransferFrame::CancelAck, true); + } + } + + fn handle_transfer_cancel_ack( + &self, + state: &mut RuntimeState, + peer: XID, + transfer_id: MessageId, + ) { + let key = (peer, transfer_id); + let Some(transfer_state) = state.outbound_transfers.remove(&key) else { + return; + }; + if !matches!(transfer_state.awaiting, Some(OutboundAwaiting::Cancel)) { + state.outbound_transfers.insert(key, transfer_state); + return; + } + + transfer_state.chunk_rx.close(); } - fn handle_incoming(&self, state: &mut RuntimeState, bytes: Vec) { - let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) else { + fn drive_outbound_transfers(&self, state: &mut RuntimeState) { + let keys: Vec<(XID, MessageId)> = state.outbound_transfers.keys().copied().collect(); + for (peer, transfer_id) in keys { + self.drive_outbound_transfer(state, peer, transfer_id); + } + } + + fn drive_outbound_transfer(&self, state: &mut RuntimeState, peer: XID, transfer_id: MessageId) { + let key = (peer, transfer_id); + let Some(mut transfer_state) = state.outbound_transfers.remove(&key) else { return; }; - let QlRecord { header, payload } = record; - if header.recipient != self.platform.xid() { + + if transfer_state.awaiting.is_some() { + state.outbound_transfers.insert(key, transfer_state); return; } - match payload { - QlPayload::Handshake(message) => { - self.handle_handshake(state, header, message); - } - QlPayload::Pair(request) => { - self.handle_pairing(state, header, request); + + match transfer_state.stage { + OutboundTransferStage::Opening => { + let Some(meta) = transfer_state.open_meta.take() else { + transfer_state.chunk_rx.close(); + return; + }; + let awaiting = OutboundAwaiting::Open { + request_id: transfer_state.request_id, + route_id: transfer_state.open_route_id, + meta, + }; + if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { + state.outbound_transfers.insert(key, transfer_state); + } } - QlPayload::Message(encrypted) => { - self.handle_record(state, header, encrypted); + OutboundTransferStage::Streaming => match transfer_state.chunk_rx.try_recv() { + Ok(OutboundStreamInput::Chunk(data)) => { + let seq = transfer_state.next_seq; + let awaiting = OutboundAwaiting::Chunk { seq, data }; + if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { + state.outbound_transfers.insert(key, transfer_state); + } + } + Ok(OutboundStreamInput::Finish) => { + let seq = transfer_state.next_seq; + transfer_state.stage = OutboundTransferStage::Finishing; + let awaiting = OutboundAwaiting::Finish { seq }; + if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { + state.outbound_transfers.insert(key, transfer_state); + } + } + Err(async_channel::TryRecvError::Empty) => { + state.outbound_transfers.insert(key, transfer_state); + } + Err(async_channel::TryRecvError::Closed) => { + transfer_state.stage = OutboundTransferStage::Cancelling; + let awaiting = OutboundAwaiting::Cancel; + if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { + state.outbound_transfers.insert(key, transfer_state); + } + } + }, + OutboundTransferStage::Finishing => { + state.outbound_transfers.insert(key, transfer_state); } - QlPayload::Heartbeat(encrypted) => { - self.handle_heartbeat(state, header, encrypted); + OutboundTransferStage::Cancelling => { + let awaiting = OutboundAwaiting::Cancel; + if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { + state.outbound_transfers.insert(key, transfer_state); + } } } } - fn handle_handshake( + fn send_outbound_awaiting( &self, state: &mut RuntimeState, - header: QlHeader, - message: HandshakeRecord, - ) { - match message { - HandshakeRecord::Hello(hello) => { - self.handle_hello(state, header, hello); - } - HandshakeRecord::HelloReply(reply) => { - self.handle_hello_reply(state, header, reply); - } - HandshakeRecord::Confirm(confirm) => { - self.handle_confirm(state, header, confirm); + transfer_state: &mut OutboundTransferState, + awaiting: OutboundAwaiting, + attempt: u8, + ) -> bool { + let frame = match &awaiting { + OutboundAwaiting::Open { + request_id, + route_id, + meta, + } => match route_id { + Some(route_id) => TransferFrame::OpenRequest { + request_id: *request_id, + route_id: *route_id, + meta: meta.clone(), + }, + None => TransferFrame::OpenResponse { + request_id: *request_id, + meta: meta.clone(), + }, + }, + OutboundAwaiting::Chunk { seq, data } => TransferFrame::Chunk { + seq: *seq, + data: data.clone(), + }, + OutboundAwaiting::Finish { seq } => TransferFrame::Finish { seq: *seq }, + OutboundAwaiting::Cancel => TransferFrame::Cancel, + }; + + let priority = matches!(awaiting, OutboundAwaiting::Cancel); + if !self.send_transfer_frame( + state, + transfer_state.peer, + transfer_state.transfer_id, + frame, + priority, + ) { + transfer_state.chunk_rx.close(); + return false; + } + + transfer_state.awaiting = Some(awaiting); + let at = Instant::now() + self.transfer_ack_timeout(); + match transfer_state.awaiting.as_ref() { + Some(OutboundAwaiting::Open { .. }) => state.timeouts.push(Reverse(TimeoutEntry { + at, + kind: TimeoutKind::TransferAck { + peer: transfer_state.peer, + transfer_id: transfer_state.transfer_id, + next_seq: 1, + attempt, + }, + })), + Some(OutboundAwaiting::Chunk { seq, .. }) => { + state.timeouts.push(Reverse(TimeoutEntry { + at, + kind: TimeoutKind::TransferAck { + peer: transfer_state.peer, + transfer_id: transfer_state.transfer_id, + next_seq: seq.saturating_add(1), + attempt, + }, + })) } + Some(OutboundAwaiting::Finish { seq }) => state.timeouts.push(Reverse(TimeoutEntry { + at, + kind: TimeoutKind::TransferAck { + peer: transfer_state.peer, + transfer_id: transfer_state.transfer_id, + next_seq: seq.saturating_add(1), + attempt, + }, + })), + Some(OutboundAwaiting::Cancel) => state.timeouts.push(Reverse(TimeoutEntry { + at, + kind: TimeoutKind::TransferCancelAck { + peer: transfer_state.peer, + transfer_id: transfer_state.transfer_id, + attempt, + }, + })), + None => {} } + + true } - fn handle_pairing( + fn send_transfer_frame( &self, state: &mut RuntimeState, - header: QlHeader, - request: PairRequestRecord, - ) { - let payload = match pair::decrypt_pair_request(&self.platform, &header, request) { - Ok(payload) => payload, - Err(_) => return, - }; - let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); - state + peer: XID, + transfer_id: MessageId, + frame: TransferFrame, + priority: bool, + ) -> bool { + let Some(session_key) = state .peers - .upsert_peer(peer, payload.signing_pub_key, payload.encapsulation_pub_key); - self.persist_peers(state); - self.handle_connect(state, peer); + .peer(peer) + .and_then(|entry| entry.session.session_key()) + .cloned() + else { + return false; + }; + + let body = TransferBody { + message_id: state.next_message_id(), + valid_until: now_secs().saturating_add(self.config.message_expiration.as_secs()), + transfer_id, + frame, + }; + let record = transfer::encrypt_transfer( + QlHeader { + sender: self.platform.xid(), + recipient: peer, + }, + &session_key, + body, + ); + let bytes = CBOR::from(record).to_cbor_data(); + self.enqueue_outbound_preencoded( + state, + peer, + bytes, + Instant::now() + self.config.message_expiration, + priority, + ); + true } - fn persist_peers(&self, state: &RuntimeState) { - self.platform.persist_peers(state.peers.all()); + fn transfer_ack_timeout(&self) -> std::time::Duration { + if self.config.default_request_timeout.is_zero() { + std::time::Duration::from_millis(200) + } else { + self.config.default_request_timeout + } } - fn handle_record( + fn handle_transfer_ack_timeout( &self, state: &mut RuntimeState, - header: QlHeader, - encrypted: bc_components::EncryptedMessage, + peer: XID, + transfer_id: MessageId, + next_seq: u32, + attempt: u8, ) { - let peer = header.sender; - let session_key = match state.peers.peer(peer) { - Some(entry) => match &entry.session { - PeerSession::Connected { session_key, .. } => session_key.clone(), - _ => return, - }, - None => return, - }; - let record = match message::decrypt_message(&header, &encrypted, &session_key) { - Ok(record) => record, - Err(message::MessageError::Nack { id, nack, kind }) => { - self.handle_message_nack(state, peer, id, nack, kind); - return; - } - Err(message::MessageError::Error(_)) => return, + let key = (peer, transfer_id); + let Some(mut transfer_state) = state.outbound_transfers.remove(&key) else { + return; }; - let namespace = match record.kind { - MessageKind::Request | MessageKind::Event => ReplayNamespace::Peer, - MessageKind::Response | MessageKind::Nack => ReplayNamespace::Local, + + let expected = match transfer_state.awaiting.as_ref() { + Some(OutboundAwaiting::Open { .. }) => Some(1), + Some(OutboundAwaiting::Chunk { seq, .. }) => Some(seq.saturating_add(1)), + Some(OutboundAwaiting::Finish { seq }) => Some(seq.saturating_add(1)), + _ => None, }; - let replay_key = ReplayKey::new(peer, namespace, record.message_id); - if state - .replay_cache - .check_and_store_valid_until(replay_key, record.valid_until) - { + if expected != Some(next_seq) { + state.outbound_transfers.insert(key, transfer_state); return; } - self.record_activity(state, peer); - match record.kind { - MessageKind::Response => { - self.resolve_pending_ok(state, peer, record.message_id, record.payload); - } - MessageKind::Nack => { - let nack = Nack::from(record.payload); - self.resolve_pending_nack(state, peer, record.message_id, nack); - } - MessageKind::Request => { - let Some(tx) = self.tx.upgrade() else { - return; - }; - let responder = Responder::new(record.message_id, record.sender, tx); - self.platform - .handle_inbound(HandlerEvent::Request(InboundRequest { - message: record, - respond_to: responder, - })); - } - MessageKind::Event => { - self.platform - .handle_inbound(HandlerEvent::Event(InboundEvent { message: record })); - } + + if attempt >= TRANSFER_RETRY_LIMIT { + transfer_state.chunk_rx.close(); + return; } - } - fn handle_message_nack( - &self, - state: &mut RuntimeState, - peer: XID, - id: MessageId, - nack: Nack, - kind: MessageKind, - ) { - if kind != MessageKind::Request { + let Some(awaiting) = transfer_state.awaiting.take() else { + state.outbound_transfers.insert(key, transfer_state); return; + }; + if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, attempt + 1) { + state.outbound_transfers.insert(key, transfer_state); } - self.handle_send_response(state, id, peer, CBOR::from(nack), MessageKind::Nack); } - fn handle_heartbeat( + fn handle_transfer_cancel_ack_timeout( &self, state: &mut RuntimeState, - header: QlHeader, - encrypted: bc_components::EncryptedMessage, + peer: XID, + transfer_id: MessageId, + attempt: u8, ) { - let peer = header.sender; - let (session_key, should_reply) = { - let Some(entry) = state.peers.peer(peer) else { - return; - }; - match &entry.session { - PeerSession::Connected { - session_key, - keepalive, - } => (session_key.clone(), !keepalive.pending), - _ => return, - } + let key = (peer, transfer_id); + let Some(mut transfer_state) = state.outbound_transfers.remove(&key) else { + return; }; - if heartbeat::decrypt_heartbeat(&header, &encrypted, &session_key).is_err() { + + if !matches!(transfer_state.awaiting, Some(OutboundAwaiting::Cancel)) { + state.outbound_transfers.insert(key, transfer_state); return; } - self.record_activity(state, peer); - if should_reply { - self.send_heartbeat_message(state, peer, session_key); + + if attempt >= TRANSFER_RETRY_LIMIT { + transfer_state.chunk_rx.close(); + return; + } + + transfer_state.awaiting = None; + if self.send_outbound_awaiting( + state, + &mut transfer_state, + OutboundAwaiting::Cancel, + attempt + 1, + ) { + state.outbound_transfers.insert(key, transfer_state); } } @@ -597,6 +1622,9 @@ impl Runtime

{ if let Some(entry) = state.pending.remove(&id) { let _ = entry.tx.send(Err(QlError::SendFailed)); } + if let Some(entry) = state.pending_stream.remove(&id) { + let _ = entry.tx.send(Err(QlError::SendFailed)); + } } false } else { @@ -614,6 +1642,34 @@ impl Runtime

{ }); } + fn fail_pending_stream_for_peer(&self, state: &mut RuntimeState, peer: XID) { + state + .pending_stream + .extract_if(|_id, entry| entry.recipient == peer) + .for_each(|(_, entry)| { + let _ = entry.tx.send(Err(QlError::SendFailed)); + }); + } + + fn abort_transfers_for_peer(&self, state: &mut RuntimeState, peer: XID, error: QlError) { + state + .outbound_transfers + .extract_if(|(transfer_peer, _), _| *transfer_peer == peer) + .for_each(|(_, transfer_state)| { + transfer_state.chunk_rx.close(); + }); + + state + .inbound_transfers + .extract_if(|(transfer_peer, _), _| *transfer_peer == peer) + .for_each(|(_, transfer_state)| { + let _ = transfer_state + .chunk_tx + .try_send(InboundStreamItem::Error(error.clone())); + transfer_state.chunk_tx.close(); + }); + } + fn resolve_pending_ok( &self, state: &mut RuntimeState, @@ -625,6 +1681,12 @@ impl Runtime

{ if entry.recipient == sender { let _ = entry.tx.send(Ok(payload)); } + return; + } + if let Some(entry) = state.pending_stream.remove(&id) { + if entry.recipient == sender { + let _ = entry.tx.send(Err(QlError::InvalidPayload)); + } } } @@ -639,6 +1701,12 @@ impl Runtime

{ if entry.recipient == sender { let _ = entry.tx.send(Err(QlError::Nack { id, nack })); } + return; + } + if let Some(entry) = state.pending_stream.remove(&id) { + if entry.recipient == sender { + let _ = entry.tx.send(Err(QlError::Nack { id, nack })); + } } } @@ -932,6 +2000,32 @@ impl Runtime

{ })); } + fn enqueue_outbound_preencoded( + &self, + state: &mut RuntimeState, + peer: XID, + bytes: Vec, + deadline: Instant, + priority: bool, + ) { + let token = state.next_token(); + let message = OutboundMessage { + peer, + token, + message_id: None, + payload: OutboundPayload::PreEncoded(bytes), + }; + if priority { + state.outbound.push_front(message); + } else { + state.outbound.push_back(message); + } + state.timeouts.push(Reverse(TimeoutEntry { + at: deadline, + kind: TimeoutKind::Outbound { token }, + })); + } + fn handle_timeouts(&self, state: &mut RuntimeState) { let now = Instant::now(); loop { @@ -954,6 +2048,9 @@ impl Runtime

{ if let Some(entry) = state.pending.remove(&id) { let _ = entry.tx.send(Err(QlError::SendFailed)); } + if let Some(entry) = state.pending_stream.remove(&id) { + let _ = entry.tx.send(Err(QlError::SendFailed)); + } } } TimeoutKind::Handshake { peer, token } => { @@ -981,6 +2078,9 @@ impl Runtime

{ if let Some(entry) = state.pending.remove(&id) { let _ = entry.tx.send(Err(QlError::Timeout)); } + if let Some(entry) = state.pending_stream.remove(&id) { + let _ = entry.tx.send(Err(QlError::Timeout)); + } } TimeoutKind::KeepAliveSend { peer, token } => { let Some(config) = self.keep_alive_config() else { @@ -1035,8 +2135,25 @@ impl Runtime

{ } self.drop_outbound_for_peer(state, peer); self.fail_pending_for_peer(state, peer); + self.fail_pending_stream_for_peer(state, peer); + self.abort_transfers_for_peer(state, peer, QlError::SendFailed); } } + TimeoutKind::TransferAck { + peer, + transfer_id, + next_seq, + attempt, + } => { + self.handle_transfer_ack_timeout(state, peer, transfer_id, next_seq, attempt); + } + TimeoutKind::TransferCancelAck { + peer, + transfer_id, + attempt, + } => { + self.handle_transfer_cancel_ack_timeout(state, peer, transfer_id, attempt); + } } } } @@ -1057,6 +2174,9 @@ impl Runtime

{ if let Some(entry) = state.pending.remove(&id) { let _ = entry.tx.send(Err(QlError::SendFailed)); } + if let Some(entry) = state.pending_stream.remove(&id) { + let _ = entry.tx.send(Err(QlError::SendFailed)); + } } let should_disconnect = match state.peers.peer(peer).map(|entry| &entry.session) { Some(PeerSession::Initiator { @@ -1073,6 +2193,9 @@ impl Runtime

{ self.platform.handle_peer_status(peer, &entry.session); } state.outbound.retain(|message| message.peer != peer); + self.fail_pending_for_peer(state, peer); + self.fail_pending_stream_for_peer(state, peer); + self.abort_transfers_for_peer(state, peer, QlError::SendFailed); } } } diff --git a/ql/src/runtime/handle.rs b/ql/src/runtime/handle.rs index 004d9e4f..90176191 100644 --- a/ql/src/runtime/handle.rs +++ b/ql/src/runtime/handle.rs @@ -5,13 +5,17 @@ use std::{ task::{Context, Poll}, }; +use async_channel::Sender; use bc_components::{MLDSAPublicKey, MLKEMPublicKey, XID}; use dcbor::CBOR; use crate::{ - runtime::{internal::RuntimeCommand, RequestConfig}, + runtime::{ + internal::{InboundStreamDelivery, InboundStreamItem, OutboundStreamInput, RuntimeCommand}, + RequestConfig, + }, wire::message::Ack, - Event, QlCodec, QlError, RequestResponse, RouteId, + Event, MessageId, QlCodec, QlError, QlStream, QlUpload, RequestResponse, RouteId, }; #[derive(Clone)] @@ -24,6 +28,46 @@ pub struct Response { _type: PhantomData T>, } +pub struct StreamResponse { + rx: oneshot::Receiver>, + _type: PhantomData T>, +} + +pub struct InboundStream { + pub meta: T, + pub body: InboundByteStream, +} + +pub struct InboundByteStream { + sender: XID, + transfer_id: MessageId, + rx: async_channel::Receiver, + tx: Sender, + finished: bool, +} + +impl std::fmt::Debug for InboundByteStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InboundByteStream") + .field("sender", &self.sender) + .field("transfer_id", &self.transfer_id) + .field("finished", &self.finished) + .finish_non_exhaustive() + } +} + +pub struct OutboundTransfer { + recipient: XID, + transfer_id: MessageId, + chunk_tx: Option>, + tx: Sender, +} + +pub struct UploadRequest { + pub transfer: OutboundTransfer, + pub response: Response, +} + impl Response { pub async fn recv(self) -> Result { self.rx.await.unwrap_or(Err(QlError::Cancelled)) @@ -44,6 +88,189 @@ where } } +impl StreamResponse { + pub async fn recv(self) -> Result, QlError> { + let delivery = self.rx.await.unwrap_or(Err(QlError::Cancelled))?; + let InboundStreamDelivery { + peer, + transfer_id, + meta, + rx, + tx, + } = delivery; + Ok(InboundStream { + meta, + body: InboundByteStream::new(peer, transfer_id, rx, tx), + }) + } +} + +impl Future for StreamResponse +where + T: QlCodec, +{ + type Output = Result, QlError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + pin!(&mut self.rx).poll(cx).map(|result| { + let delivery = result.unwrap_or(Err(QlError::Cancelled))?; + let InboundStreamDelivery { + peer, + transfer_id, + meta, + rx, + tx, + } = delivery; + let meta = T::try_from(meta).map_err(|_| QlError::InvalidPayload)?; + Ok(InboundStream { + meta, + body: InboundByteStream::new(peer, transfer_id, rx, tx), + }) + }) + } +} + +impl UploadRequest +where + R: QlCodec, +{ + pub async fn finish(self) -> Result { + let Self { transfer, response } = self; + transfer.finish().await?; + response.await + } +} + +impl InboundByteStream { + pub(crate) fn new( + sender: XID, + transfer_id: MessageId, + rx: async_channel::Receiver, + tx: Sender, + ) -> Self { + Self { + sender, + transfer_id, + rx, + tx, + finished: false, + } + } + + pub async fn next_chunk(&mut self) -> Result>, QlError> { + if self.finished { + return Ok(None); + } + match self.rx.recv().await { + Ok(InboundStreamItem::Chunk(chunk)) => Ok(Some(chunk)), + Ok(InboundStreamItem::Finished) => { + self.finished = true; + Ok(None) + } + Ok(InboundStreamItem::Error(error)) => { + self.finished = true; + Err(error) + } + Err(_) => { + self.finished = true; + Err(QlError::TransferCancelled { + id: self.transfer_id, + }) + } + } + } +} + +impl Drop for InboundByteStream { + fn drop(&mut self) { + if self.finished { + return; + } + let _ = self.tx.try_send(RuntimeCommand::CancelInboundTransfer { + sender: self.sender, + transfer_id: self.transfer_id, + }); + } +} + +impl OutboundTransfer { + pub(crate) fn new( + recipient: XID, + transfer_id: MessageId, + chunk_tx: Sender, + tx: Sender, + ) -> Self { + Self { + recipient, + transfer_id, + chunk_tx: Some(chunk_tx), + tx, + } + } + + pub async fn write_next(&mut self, chunk: Vec) -> Result<(), QlError> { + let chunk_tx = self + .chunk_tx + .as_ref() + .expect("transfer not finished or cancelled"); + chunk_tx + .send(OutboundStreamInput::Chunk(chunk)) + .await + .map_err(|_| QlError::TransferCancelled { + id: self.transfer_id, + })?; + self.tx + .send(RuntimeCommand::PollOutboundTransfer { + recipient: self.recipient, + transfer_id: self.transfer_id, + }) + .await + .map_err(|_| QlError::Cancelled)?; + Ok(()) + } + + pub async fn finish(mut self) -> Result<(), QlError> { + let Some(chunk_tx) = self.chunk_tx.take() else { + return Ok(()); + }; + if chunk_tx.send(OutboundStreamInput::Finish).await.is_err() { + return Ok(()); + } + self.tx + .send(RuntimeCommand::PollOutboundTransfer { + recipient: self.recipient, + transfer_id: self.transfer_id, + }) + .await + .map_err(|_| QlError::Cancelled)?; + chunk_tx.closed().await; + Ok(()) + } + + pub async fn cancel(mut self) -> Result<(), QlError> { + self.chunk_tx.take(); + self.tx + .send(RuntimeCommand::CancelOutboundTransfer { + recipient: self.recipient, + transfer_id: self.transfer_id, + }) + .await + .map_err(|_| QlError::Cancelled) + } +} + +impl Drop for OutboundTransfer { + fn drop(&mut self) { + if self.chunk_tx.take().is_none() { + return; + } + let _ = self.tx.try_send(RuntimeCommand::CancelOutboundTransfer { + recipient: self.recipient, + transfer_id: self.transfer_id, + }); + } +} + impl RuntimeHandle { pub fn register_peer( &self, @@ -64,6 +291,12 @@ impl RuntimeHandle { .map_err(|_| QlError::Cancelled) } + pub fn unpair(&self, peer: XID) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Unpair { peer }) + .map_err(|_| QlError::Cancelled) + } + pub fn send_incoming(&self, bytes: Vec) { self.send(RuntimeCommand::Incoming(bytes)) } @@ -91,6 +324,50 @@ impl RuntimeHandle { } } + pub fn request_stream( + &self, + message: M, + recipient: XID, + config: RequestConfig, + ) -> StreamResponse + where + M: QlStream, + { + let (tx, rx) = oneshot::channel(); + self.send(RuntimeCommand::SendStreamRequest { + recipient, + route_id: M::ID, + payload: message.into(), + respond_to: tx, + config, + }); + StreamResponse { + rx, + _type: PhantomData, + } + } + + pub async fn request_upload( + &self, + message: M, + recipient: XID, + config: RequestConfig, + ) -> Result, QlError> + where + M: QlUpload, + { + let upload = self + .send_request_upload_raw(recipient, M::ID, message.into(), config) + .await?; + Ok(UploadRequest { + transfer: upload.transfer, + response: Response { + rx: upload.response.rx, + _type: PhantomData, + }, + }) + } + pub fn send_event(&self, message: M, recipient: XID) where M: Event, @@ -149,13 +426,67 @@ impl RuntimeHandle { _type: PhantomData, } } + + pub fn send_request_stream_raw( + &self, + recipient: XID, + route_id: RouteId, + payload: CBOR, + config: RequestConfig, + ) -> StreamResponse { + let (tx, rx) = oneshot::channel(); + self.send(RuntimeCommand::SendStreamRequest { + recipient, + route_id, + payload, + respond_to: tx, + config, + }); + StreamResponse { + rx, + _type: PhantomData, + } + } + + pub async fn send_request_upload_raw( + &self, + recipient: XID, + route_id: RouteId, + payload: CBOR, + config: RequestConfig, + ) -> Result, QlError> { + let (response_tx, response_rx) = oneshot::channel(); + let (chunk_tx, chunk_rx) = async_channel::bounded(1); + let (start_tx, start_rx) = oneshot::channel(); + self.tx + .send(RuntimeCommand::SendUploadRequest { + recipient, + route_id, + payload, + respond_to: response_tx, + chunk_rx, + start: start_tx, + config, + }) + .await + .map_err(|_| QlError::Cancelled)?; + + let transfer_id = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; + + Ok(UploadRequest { + transfer: OutboundTransfer::new(recipient, transfer_id, chunk_tx, self.tx.clone()), + response: Response { + rx: response_rx, + _type: PhantomData, + }, + }) + } } impl RuntimeHandle { #[inline] #[track_caller] fn send(&self, cmd: RuntimeCommand) { - // send_blocking is ok bc queue is unbounded self.tx.send_blocking(cmd).expect("runtime is alive") } } diff --git a/ql/src/runtime/internal.rs b/ql/src/runtime/internal.rs index f7e8769c..0a94ec6b 100644 --- a/ql/src/runtime/internal.rs +++ b/ql/src/runtime/internal.rs @@ -5,6 +5,7 @@ use std::{ time::{Instant, SystemTime, UNIX_EPOCH}, }; +use async_channel::{Receiver, Sender}; use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey, XID}; use dcbor::CBOR; @@ -19,9 +20,11 @@ use crate::{ }; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +// Monotonic token for timeout correlation. pub struct Token(u64); #[derive(Debug, Clone)] +// Per-peer keepalive timers and ping state. pub struct KeepAliveState { pub token: Token, pub pending: bool, @@ -45,6 +48,7 @@ impl Default for KeepAliveState { } #[derive(Debug, Clone)] +// Registered peer identity and current session. pub struct PeerRecord { pub peer: XID, pub signing_key: MLDSAPublicKey, @@ -64,6 +68,7 @@ impl PeerRecord { } #[derive(Debug, Clone)] +// In-memory registry of known peers. pub struct PeerStore { peers: Vec, } @@ -108,11 +113,19 @@ impl PeerStore { }) .collect() } + + pub fn remove_peer(&mut self, peer: XID) -> Option { + let index = self.peers.iter().position(|record| record.peer == peer)?; + Some(self.peers.remove(index)) + } } #[derive(Debug, Clone)] +// Session state machine for a peer. pub enum PeerSession { + // No active handshake or session. Disconnected, + // Local side initiated the handshake. Initiator { handshake_token: Token, hello: Hello, @@ -120,6 +133,7 @@ pub enum PeerSession { deadline: Instant, stage: InitiatorStage, }, + // Local side is responding to a handshake. Responder { handshake_token: Token, hello: Hello, @@ -127,6 +141,7 @@ pub enum PeerSession { secrets: crate::wire::handshake::ResponderSecrets, deadline: Instant, }, + // Encrypted session is established. Connected { session_key: SymmetricKey, keepalive: KeepAliveState, @@ -149,20 +164,127 @@ impl PeerSession { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] +// Initiator-side handshake progression. pub enum InitiatorStage { + // Waiting for hello reply. WaitingHelloReply, + // Waiting for confirm completion. WaitingConfirmAck, } +// Producer messages for outbound transfer data. +pub(crate) enum OutboundStreamInput { + // Emit one data chunk. + Chunk(Vec), + // Mark stream end. + Finish, +} + +// Consumer messages for inbound transfer reads. +pub(crate) enum InboundStreamItem { + // Next received data chunk. + Chunk(Vec), + // Clean stream completion. + Finished, + // Terminal stream failure. + Error(QlError), +} + +// Identity of an accepted inbound transfer open frame. +#[derive(Debug, Clone, PartialEq)] +pub enum InboundTransferOpen { + // Streamed response correlated to a prior request. + Response { + request_id: MessageId, + meta: CBOR, + }, + // Streamed upload request with route metadata. + Request { + request_id: MessageId, + route_id: RouteId, + meta: CBOR, + }, +} + +// Runtime-delivered stream metadata and receiver. +pub(crate) struct InboundStreamDelivery { + pub peer: XID, + pub transfer_id: MessageId, + pub meta: CBOR, + pub rx: Receiver, + pub tx: Sender, +} + +// Last sender frame currently awaiting ack. +pub enum OutboundAwaiting { + // Open frame with request correlation. + Open { + request_id: MessageId, + route_id: Option, + meta: CBOR, + }, + // Data frame at a specific sequence. + Chunk { + seq: u32, + data: Vec, + }, + // Finish frame at a specific sequence. + Finish { + seq: u32, + }, + // Cancel frame awaiting cancel-ack. + Cancel, +} + +// Coarse sender-side transfer lifecycle. +pub enum OutboundTransferStage { + // Opening frame not yet acknowledged. + Opening, + // Streaming chunks frame-by-frame. + Streaming, + // Finish frame sent, waiting for ack. + Finishing, + // Cancellation in progress. + Cancelling, +} + +// Runtime state for one outbound transfer. +pub struct OutboundTransferState { + pub request_id: MessageId, + pub peer: XID, + pub transfer_id: MessageId, + pub stage: OutboundTransferStage, + pub next_seq: u32, + pub open_route_id: Option, + pub open_meta: Option, + pub chunk_rx: Receiver, + pub awaiting: Option, +} + +// Runtime state for one inbound transfer. +pub struct InboundTransferState { + pub open: InboundTransferOpen, + pub expected_seq: u32, + pub chunk_tx: Sender, +} + +// Commands consumed by the runtime loop. pub(crate) enum RuntimeCommand { + // Upsert a peer record. RegisterPeer { peer: XID, signing_key: MLDSAPublicKey, encapsulation_key: MLKEMPublicKey, }, + // Start handshake with a peer. Connect { peer: XID, }, + // Send unpair and remove peer. + Unpair { + peer: XID, + }, + // Send unary request and await unary response. SendRequest { recipient: XID, route_id: RouteId, @@ -170,26 +292,73 @@ pub(crate) enum RuntimeCommand { respond_to: oneshot::Sender>, config: RequestConfig, }, + // Send unary request and await streamed response. + SendStreamRequest { + recipient: XID, + route_id: RouteId, + payload: CBOR, + respond_to: oneshot::Sender>, + config: RequestConfig, + }, + // Send streamed request and await unary response. + SendUploadRequest { + recipient: XID, + route_id: RouteId, + payload: CBOR, + respond_to: oneshot::Sender>, + chunk_rx: Receiver, + start: oneshot::Sender>, + config: RequestConfig, + }, + // Send fire-and-forget event. SendEvent { recipient: XID, route_id: RouteId, payload: CBOR, }, + // Send unary response or nack. SendResponse { id: MessageId, recipient: XID, payload: CBOR, kind: MessageKind, }, + // Start sender-side streamed response. + StartResponseStream { + request_id: MessageId, + recipient: XID, + meta: CBOR, + chunk_rx: Receiver, + }, + // Prompt immediate outbound transfer polling. + PollOutboundTransfer { + recipient: XID, + transfer_id: MessageId, + }, + // Cancel sender-side active transfer. + CancelOutboundTransfer { + recipient: XID, + transfer_id: MessageId, + }, + // Cancel receiver-side active transfer. + CancelInboundTransfer { + sender: XID, + transfer_id: MessageId, + }, + // Process raw incoming bytes. Incoming(Vec), } +// Mutable state owned by the runtime loop. pub struct RuntimeState { pub peers: PeerStore, pub next_token: Cell, pub outbound: VecDeque, pub timeouts: BinaryHeap>, pub pending: HashMap, + pub pending_stream: HashMap, + pub outbound_transfers: HashMap<(XID, MessageId), OutboundTransferState>, + pub inbound_transfers: HashMap<(XID, MessageId), InboundTransferState>, pub next_message_id: Cell, pub replay_cache: ReplayCache, } @@ -202,6 +371,9 @@ impl RuntimeState { outbound: VecDeque::new(), timeouts: BinaryHeap::new(), pending: HashMap::new(), + pending_stream: HashMap::new(), + outbound_transfers: HashMap::new(), + inbound_transfers: HashMap::new(), next_message_id: Cell::new(MessageId(1)), replay_cache: ReplayCache::new(), } @@ -220,11 +392,19 @@ impl RuntimeState { } } +// Pending unary response waiter. pub struct PendingEntry { pub recipient: XID, pub tx: oneshot::Sender>, } +// Pending streamed response opener waiter. +pub struct PendingStreamEntry { + pub recipient: XID, + pub tx: oneshot::Sender>, +} + +// Currently executing platform write. pub struct InFlightWrite<'a> { pub peer: XID, pub token: Token, @@ -232,11 +412,15 @@ pub struct InFlightWrite<'a> { pub future: PlatformFuture<'a, Result<(), QlError>>, } +// Queued payload representation. pub enum OutboundPayload { + // Payload already encoded into bytes. PreEncoded(Vec), + // Payload to encrypt at send time. DeferredMessage(MessageBody), } +// Outbound queue item with timeout token. pub struct OutboundMessage { pub peer: XID, pub token: Token, @@ -245,15 +429,48 @@ pub struct OutboundMessage { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] +// Runtime timeout categories. pub enum TimeoutKind { - Outbound { token: Token }, - Handshake { peer: XID, token: Token }, - Request { id: MessageId }, - KeepAliveSend { peer: XID, token: Token }, - KeepAliveTimeout { peer: XID, token: Token }, + // Outbound queue item expired. + Outbound { + token: Token, + }, + // Handshake stage expired. + Handshake { + peer: XID, + token: Token, + }, + // Request waiting for reply expired. + Request { + id: MessageId, + }, + // Send keepalive ping now. + KeepAliveSend { + peer: XID, + token: Token, + }, + // Keepalive pong timeout. + KeepAliveTimeout { + peer: XID, + token: Token, + }, + // Transfer data/open/finish ack timeout. + TransferAck { + peer: XID, + transfer_id: MessageId, + next_seq: u32, + attempt: u8, + }, + // Transfer cancel-ack timeout. + TransferCancelAck { + peer: XID, + transfer_id: MessageId, + attempt: u8, + }, } #[derive(Debug, Clone, PartialEq, Eq)] +// One scheduled timeout entry. pub struct TimeoutEntry { pub at: Instant, pub kind: TimeoutKind, @@ -271,24 +488,33 @@ impl PartialOrd for TimeoutEntry { } } +// Outcome of one runtime loop poll cycle. pub enum LoopStep { + // Received a runtime command. Event(RuntimeCommand), + // One or more timeouts fired. Timeout, + // In-flight write completed. WriteDone { peer: XID, token: Token, message_id: Option, result: Result<(), QlError>, }, + // Runtime should exit loop. Quit, } +// Decision for inbound hello handling. pub enum HelloAction { + // Become responder for this hello. StartResponder, + // Re-send existing hello reply. ResendReply { reply: HelloReply, deadline: Instant, }, + // Ignore this hello. Ignore, } diff --git a/ql/src/runtime/mod.rs b/ql/src/runtime/mod.rs index 9a738968..490a73b6 100644 --- a/ql/src/runtime/mod.rs +++ b/ql/src/runtime/mod.rs @@ -1,4 +1,7 @@ -pub use handle::{Response, RuntimeHandle}; +pub use handle::{ + InboundByteStream, InboundStream, OutboundTransfer, Response, RuntimeHandle, StreamResponse, + UploadRequest, +}; pub use internal::{InitiatorStage, PeerSession, Token}; mod core; @@ -13,7 +16,7 @@ use dcbor::CBOR; use crate::{ wire::message::{DecryptedMessage, MessageKind, Nack}, - MessageId, QlCodec, QlError, + MessageId, QlCodec, QlError, RouteId, }; #[derive(Debug, Clone, Default)] @@ -64,6 +67,7 @@ impl RuntimeConfig { #[derive(Debug)] pub enum HandlerEvent { Request(InboundRequest), + UploadRequest(InboundUploadRequest), Event(InboundEvent), } @@ -73,6 +77,17 @@ pub struct InboundRequest { pub respond_to: Responder, } +#[derive(Debug)] +pub struct InboundUploadRequest { + pub sender: XID, + pub recipient: XID, + pub route_id: RouteId, + pub message_id: MessageId, + pub meta: CBOR, + pub body: InboundByteStream, + pub respond_to: Responder, +} + #[derive(Debug)] pub struct InboundEvent { pub message: DecryptedMessage, @@ -118,6 +133,27 @@ impl Responder { }) .map_err(|_| QlError::Cancelled) } + + pub fn respond_stream(self, meta: M) -> Result + where + M: QlCodec, + { + let (chunk_tx, chunk_rx) = async_channel::bounded(1); + self.tx + .send_blocking(internal::RuntimeCommand::StartResponseStream { + request_id: self.id, + recipient: self.recipient, + meta: meta.into(), + chunk_rx, + }) + .map_err(|_| QlError::Cancelled)?; + Ok(handle::OutboundTransfer::new( + self.recipient, + self.id, + chunk_tx, + self.tx, + )) + } } pub struct Runtime

{ diff --git a/ql/src/runtime/replay_cache.rs b/ql/src/runtime/replay_cache.rs index 876467c8..80e60d21 100644 --- a/ql/src/runtime/replay_cache.rs +++ b/ql/src/runtime/replay_cache.rs @@ -11,6 +11,7 @@ use crate::{runtime::internal::now_secs, MessageId}; pub enum ReplayNamespace { Peer, Local, + Transfer, } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] diff --git a/ql/src/tests/mod.rs b/ql/src/tests/mod.rs index da48c1f5..eae65d8b 100644 --- a/ql/src/tests/mod.rs +++ b/ql/src/tests/mod.rs @@ -1,7 +1,7 @@ use std::{ future::Future, sync::{ - atomic::{AtomicBool, AtomicU8, Ordering}, + atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, Arc, }, time::Duration, @@ -34,6 +34,8 @@ mod handshake; mod heartbeat; mod persistence; mod requests; +mod streams; +mod unpair; #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum PeerStage { @@ -379,6 +381,13 @@ fn is_heartbeat(bytes: &[u8]) -> bool { matches!(record.payload, QlPayload::Heartbeat(_)) } +fn is_transfer(bytes: &[u8]) -> bool { + let Ok(record) = CBOR::try_from_data(bytes).and_then(QlRecord::try_from) else { + return false; + }; + matches!(record.payload, QlPayload::Transfer(_)) +} + fn spawn_heartbeat_tap_forwarder( outbound: Receiver>, handle: RuntimeHandle, @@ -405,6 +414,32 @@ fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHa }); } +fn spawn_drop_first_transfer_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + let mut dropped = false; + while let Ok(bytes) = outbound.recv().await { + if !dropped && is_transfer(&bytes) { + dropped = true; + continue; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_duplicate_first_transfer_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + let mut duplicated = false; + while let Ok(bytes) = outbound.recv().await { + if !duplicated && is_transfer(&bytes) { + duplicated = true; + handle.send_incoming(bytes.clone()); + } + handle.send_incoming(bytes); + } + }); +} + fn spawn_gated_forwarder( outbound: Receiver>, handle: RuntimeHandle, diff --git a/ql/src/tests/requests.rs b/ql/src/tests/requests.rs index 728d3127..23bab759 100644 --- a/ql/src/tests/requests.rs +++ b/ql/src/tests/requests.rs @@ -260,6 +260,7 @@ async fn replayed_message_is_ignored() { assert_eq!(event.message.route_id, RouteId(9)); } HandlerEvent::Request(_) => panic!("unexpected request"), + HandlerEvent::UploadRequest(_) => panic!("unexpected upload request"), } let second = tokio::time::timeout(Duration::from_millis(50), inbound_b.recv()).await; diff --git a/ql/src/tests/streams.rs b/ql/src/tests/streams.rs new file mode 100644 index 00000000..0b04ec8a --- /dev/null +++ b/ql/src/tests/streams.rs @@ -0,0 +1,552 @@ +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn request_stream_round_trip() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let mut stream = request.respond_to.respond_stream(7u8).unwrap(); + stream.write_next(vec![1, 2, 3]).await.unwrap(); + stream.write_next(vec![4, 5]).await.unwrap(); + stream.finish().await.unwrap(); + } + }); + + let mut response = handle_a + .send_request_stream_raw( + peer_b.xid, + RouteId(201), + CBOR::from(1u8), + RequestConfig::default(), + ) + .recv() + .await + .unwrap(); + + assert_eq!(response.meta, CBOR::from(7u8)); + assert_eq!( + response.body.next_chunk().await.unwrap(), + Some(vec![1, 2, 3]) + ); + assert_eq!(response.body.next_chunk().await.unwrap(), Some(vec![4, 5])); + assert_eq!(response.body.next_chunk().await.unwrap(), None); + + let _ = responder_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_inbound_stream_cancels_sender() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let mut stream = request.respond_to.respond_stream(1u8).unwrap(); + stream.write_next(vec![9]).await.unwrap(); + stream.finish().await + } else { + Err(QlError::Cancelled) + } + }); + + let mut response = handle_a + .send_request_stream_raw( + peer_b.xid, + RouteId(202), + CBOR::from(2u8), + RequestConfig::default(), + ) + .recv() + .await + .unwrap(); + + assert_eq!(response.body.next_chunk().await.unwrap(), Some(vec![9])); + drop(response); + + let result = tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn sender_cancel_surfaces_error_on_receiver() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let mut stream = request.respond_to.respond_stream(1u8).unwrap(); + stream.write_next(vec![7]).await.unwrap(); + stream.cancel().await.unwrap(); + } + }); + + let mut response = handle_a + .send_request_stream_raw( + peer_b.xid, + RouteId(203), + CBOR::from(3u8), + RequestConfig::default(), + ) + .recv() + .await + .unwrap(); + + let first = response.body.next_chunk().await; + match first { + Ok(Some(_)) => { + let second = response.body.next_chunk().await; + assert!(matches!(second, Err(QlError::TransferCancelled { .. }))); + } + Err(QlError::TransferCancelled { .. }) => {} + other => panic!("unexpected first chunk result: {other:?}"), + } + + let _ = responder_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_upload_round_trip() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::UploadRequest(request)) = inbound_b.recv().await { + assert_eq!(request.route_id, RouteId(204)); + assert_eq!(request.meta, CBOR::from("meta")); + let mut body = request.body; + let mut bytes = Vec::new(); + while let Some(chunk) = body.next_chunk().await.unwrap() { + bytes.extend(chunk); + } + assert_eq!(bytes, vec![1, 2, 3, 4]); + request.respond_to.respond(4u8).unwrap(); + } + }); + + let mut upload = handle_a + .send_request_upload_raw( + peer_b.xid, + RouteId(204), + CBOR::from("meta"), + RequestConfig::default(), + ) + .await + .unwrap(); + upload.transfer.write_next(vec![1, 2]).await.unwrap(); + upload.transfer.write_next(vec![3, 4]).await.unwrap(); + upload.transfer.finish().await.unwrap(); + let response = upload.response.recv().await.unwrap(); + + assert_eq!(response, CBOR::from(4u8)); + + let _ = responder_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn duplicate_open_response_resends_ack_without_cancelling_stream() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(30)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_drop_first_transfer_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let mut stream = request.respond_to.respond_stream(7u8).unwrap(); + stream.write_next(vec![1, 2, 3]).await.unwrap(); + stream.write_next(vec![4, 5]).await.unwrap(); + stream.finish().await.unwrap(); + } + }); + + let mut response = tokio::time::timeout( + Duration::from_secs(1), + handle_a + .send_request_stream_raw( + peer_b.xid, + RouteId(205), + CBOR::from(1u8), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ) + .recv(), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(response.meta, CBOR::from(7u8)); + assert_eq!( + response.body.next_chunk().await.unwrap(), + Some(vec![1, 2, 3]) + ); + assert_eq!(response.body.next_chunk().await.unwrap(), Some(vec![4, 5])); + assert_eq!(response.body.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn duplicate_open_request_retries_without_redispatching_upload() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(30)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_drop_first_transfer_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::UploadRequest(request)) = inbound_b.recv().await { + assert_eq!(request.route_id, RouteId(206)); + assert_eq!(request.meta, CBOR::from("meta")); + let mut body = request.body; + let mut bytes = Vec::new(); + while let Some(chunk) = body.next_chunk().await.unwrap() { + bytes.extend(chunk); + } + assert_eq!(bytes, vec![1, 2, 3, 4]); + request.respond_to.respond(4u8).unwrap(); + } + + let second = tokio::time::timeout(Duration::from_millis(150), inbound_b.recv()).await; + assert!(second.is_err(), "duplicate upload request dispatched"); + }); + + let mut upload = tokio::time::timeout( + Duration::from_secs(1), + handle_a.send_request_upload_raw( + peer_b.xid, + RouteId(206), + CBOR::from("meta"), + RequestConfig { + timeout: Some(Duration::from_millis(200)), + }, + ), + ) + .await + .unwrap() + .unwrap(); + upload.transfer.write_next(vec![1, 2]).await.unwrap(); + upload.transfer.write_next(vec![3, 4]).await.unwrap(); + upload.transfer.finish().await.unwrap(); + let response = tokio::time::timeout(Duration::from_secs(1), upload.response.recv()) + .await + .unwrap() + .unwrap(); + + assert_eq!(response, CBOR::from(4u8)); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn replayed_transfer_open_request_is_silently_ignored() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let transfer_count = Arc::new(AtomicUsize::new(0)); + spawn_duplicate_first_transfer_forwarder(outbound_a, handle_b.clone()); + tokio::task::spawn_local({ + let handle_a = handle_a.clone(); + let transfer_count = transfer_count.clone(); + async move { + while let Ok(bytes) = outbound_b.recv().await { + if is_transfer(&bytes) { + transfer_count.fetch_add(1, Ordering::Relaxed); + } + handle_a.send_incoming(bytes); + } + } + }); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let upload = handle_a + .send_request_upload_raw( + peer_b.xid, + RouteId(207), + CBOR::from("meta"), + RequestConfig::default(), + ) + .await + .unwrap(); + + let request = match tokio::time::timeout(Duration::from_secs(1), inbound_b.recv()) + .await + .unwrap() + .unwrap() + { + HandlerEvent::UploadRequest(request) => request, + other => panic!("unexpected inbound event: {other:?}"), + }; + + assert_eq!(request.route_id, RouteId(207)); + assert_eq!(request.meta, CBOR::from("meta")); + + let second = tokio::time::timeout(Duration::from_millis(50), inbound_b.recv()).await; + assert!(second.is_err(), "replayed transfer redispatched upload"); + + tokio::time::timeout(Duration::from_secs(1), async { + while transfer_count.load(Ordering::Relaxed) == 0 { + tokio::task::yield_now().await; + } + }) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!( + transfer_count.load(Ordering::Relaxed), + 1, + "replayed transfer produced extra ack" + ); + + drop(upload); + drop(request); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn replayed_transfer_open_response_is_silently_ignored() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let transfer_count = Arc::new(AtomicUsize::new(0)); + tokio::task::spawn_local({ + let handle_b = handle_b.clone(); + let transfer_count = transfer_count.clone(); + async move { + while let Ok(bytes) = outbound_a.recv().await { + if is_transfer(&bytes) { + transfer_count.fetch_add(1, Ordering::Relaxed); + } + handle_b.send_incoming(bytes); + } + } + }); + spawn_duplicate_first_transfer_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { + let stream = request.respond_to.respond_stream(7u8).unwrap(); + tokio::time::sleep(Duration::from_millis(250)).await; + drop(stream); + } + }); + + let response = tokio::time::timeout( + Duration::from_secs(1), + handle_a + .send_request_stream_raw( + peer_b.xid, + RouteId(208), + CBOR::from(1u8), + RequestConfig::default(), + ) + .recv(), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(response.meta, CBOR::from(7u8)); + + tokio::time::timeout(Duration::from_secs(1), async { + while transfer_count.load(Ordering::Relaxed) == 0 { + tokio::task::yield_now().await; + } + }) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!( + transfer_count.load(Ordering::Relaxed), + 1, + "replayed transfer produced extra ack" + ); + + drop(response); + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} diff --git a/ql/src/tests/unpair.rs b/ql/src/tests/unpair.rs new file mode 100644 index 00000000..612f7cbb --- /dev/null +++ b/ql/src/tests/unpair.rs @@ -0,0 +1,160 @@ +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn connected_unpair_removes_peer_on_both_sides() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + handle_a.connect(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + handle_a.unpair(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + + let result_a = handle_a + .send_request_raw( + peer_b.xid, + RouteId(90), + CBOR::from(1u8), + RequestConfig::default(), + ) + .recv() + .await; + assert!(matches!(result_a, Err(QlError::UnknownPeer(peer)) if peer == peer_b.xid)); + + let result_b = handle_b + .send_request_raw( + peer_a.xid, + RouteId(91), + CBOR::from(1u8), + RequestConfig::default(), + ) + .recv() + .await; + assert!(matches!(result_b, Err(QlError::UnknownPeer(peer)) if peer == peer_a.xid)); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn unpair_works_without_session() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)) + .with_request_timeout(Duration::from_millis(200)); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + + handle_a.unpair(peer_b.xid).unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + + let result_a = handle_a + .send_request_raw( + peer_b.xid, + RouteId(92), + CBOR::from(1u8), + RequestConfig::default(), + ) + .recv() + .await; + assert!(matches!(result_a, Err(QlError::UnknownPeer(peer)) if peer == peer_b.xid)); + + let result_b = handle_b + .send_request_raw( + peer_a.xid, + RouteId(93), + CBOR::from(1u8), + RequestConfig::default(), + ) + .recv() + .await; + assert!(matches!(result_b, Err(QlError::UnknownPeer(peer)) if peer == peer_a.xid)); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn invalid_unpair_signature_is_ignored() { + run_local_test(async { + let config = RuntimeConfig::new(Duration::from_millis(200)); + let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, status_b) = TestPlatform::new(2); + let (fake_signer, _fake_outbound, _fake_status) = TestPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let forged_unpair = wire::unpair::build_unpair_record( + &fake_signer, + QlHeader { + sender: peer_a.xid, + recipient: peer_b.xid, + }, + MessageId(777), + now_secs().saturating_add(60), + ); + let forged_bytes = CBOR::from(forged_unpair).to_cbor_data(); + + let (runtime_b, handle_b) = new_runtime(platform_b, config); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + handle_b.register_peer( + peer_a.xid, + peer_a.signing_key.clone(), + peer_a.encapsulation_key.clone(), + ); + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + + handle_b.send_incoming(forged_bytes); + + tokio::time::sleep(Duration::from_millis(20)).await; + + let result = handle_b + .send_request_raw( + peer_a.xid, + RouteId(94), + CBOR::from(1u8), + RequestConfig::default(), + ) + .recv() + .await; + assert!(matches!(result, Err(QlError::MissingSession(peer)) if peer == peer_a.xid)); + }) + .await; +} diff --git a/ql/src/wire/mod.rs b/ql/src/wire/mod.rs index c4b5e90f..663a329c 100644 --- a/ql/src/wire/mod.rs +++ b/ql/src/wire/mod.rs @@ -4,10 +4,12 @@ pub mod handshake; pub mod heartbeat; pub mod message; pub mod pair; +pub mod transfer; +pub mod unpair; use bc_components::{EncryptedMessage, XID}; -use self::{handshake::HandshakeRecord, pair::PairRequestRecord}; +use self::{handshake::HandshakeRecord, pair::PairRequestRecord, unpair::UnpairRecord}; use crate::{MessageId, QlError}; #[derive(Debug, Clone, PartialEq)] @@ -32,8 +34,10 @@ impl QlHeader { pub enum QlPayload { Handshake(HandshakeRecord), Pair(PairRequestRecord), + Unpair(UnpairRecord), Message(EncryptedMessage), Heartbeat(EncryptedMessage), + Transfer(EncryptedMessage), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -42,6 +46,8 @@ pub enum QlTag { Pairing = 2, Record = 3, Heartbeat = 4, + Unpair = 5, + Transfer = 6, } impl From for CBOR { @@ -60,6 +66,8 @@ impl TryFrom for QlTag { 2 => Ok(Self::Pairing), 3 => Ok(Self::Record), 4 => Ok(Self::Heartbeat), + 5 => Ok(Self::Unpair), + 6 => Ok(Self::Transfer), _ => Err(dcbor::Error::msg("unknown message tag")), } } @@ -72,6 +80,8 @@ impl From for CBOR { QlPayload::Pair(message) => (QlTag::Pairing, CBOR::from(message)), QlPayload::Message(message) => (QlTag::Record, CBOR::from(message)), QlPayload::Heartbeat(message) => (QlTag::Heartbeat, CBOR::from(message)), + QlPayload::Unpair(message) => (QlTag::Unpair, CBOR::from(message)), + QlPayload::Transfer(message) => (QlTag::Transfer, CBOR::from(message)), }; CBOR::from(vec![ CBOR::from(tag as u8), @@ -118,6 +128,20 @@ impl TryFrom for QlRecord { payload: QlPayload::Heartbeat(message), }) } + QlTag::Unpair => { + let message = UnpairRecord::try_from(payload)?; + Ok(QlRecord { + header, + payload: QlPayload::Unpair(message), + }) + } + QlTag::Transfer => { + let message = EncryptedMessage::try_from(payload)?; + Ok(QlRecord { + header, + payload: QlPayload::Transfer(message), + }) + } } } } diff --git a/ql/src/wire/transfer/crypto.rs b/ql/src/wire/transfer/crypto.rs new file mode 100644 index 00000000..dec752d2 --- /dev/null +++ b/ql/src/wire/transfer/crypto.rs @@ -0,0 +1,42 @@ +use bc_components::{Nonce, SymmetricKey}; +use dcbor::CBOR; + +use super::TransferBody; +use crate::{ + wire::{now_secs, QlHeader, QlPayload, QlRecord}, + QlError, +}; + +pub fn encrypt_transfer( + header: QlHeader, + session_key: &SymmetricKey, + body: TransferBody, +) -> QlRecord { + let aad = header.aad(); + let body_bytes = CBOR::from(body).to_cbor_data(); + let encrypted = session_key.encrypt(body_bytes, Some(aad), None::); + QlRecord { + header, + payload: QlPayload::Transfer(encrypted), + } +} + +pub fn decrypt_transfer( + header: &QlHeader, + encrypted: &bc_components::EncryptedMessage, + session_key: &SymmetricKey, +) -> Result { + let aad = header.aad(); + if encrypted.aad() != aad { + return Err(QlError::InvalidPayload); + } + let plaintext = session_key + .decrypt(encrypted) + .map_err(|_| QlError::InvalidPayload)?; + let cbor = CBOR::try_from_data(plaintext).map_err(|_| QlError::InvalidPayload)?; + let body = TransferBody::try_from(cbor).map_err(|_| QlError::InvalidPayload)?; + if now_secs() > body.valid_until { + return Err(QlError::InvalidPayload); + } + Ok(body) +} diff --git a/ql/src/wire/transfer/mod.rs b/ql/src/wire/transfer/mod.rs new file mode 100644 index 00000000..f6b874b0 --- /dev/null +++ b/ql/src/wire/transfer/mod.rs @@ -0,0 +1,194 @@ +use dcbor::CBOR; + +use super::take_fields; +use crate::{MessageId, RouteId}; + +mod crypto; +pub use crypto::*; + +#[derive(Debug, Clone, PartialEq)] +pub struct TransferBody { + pub message_id: MessageId, + pub valid_until: u64, + pub transfer_id: MessageId, + pub frame: TransferFrame, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum TransferFrame { + OpenResponse { + request_id: MessageId, + meta: CBOR, + }, + OpenRequest { + request_id: MessageId, + route_id: RouteId, + meta: CBOR, + }, + Chunk { + seq: u32, + data: Vec, + }, + Finish { + seq: u32, + }, + Ack { + next_seq: u32, + }, + Cancel, + CancelAck, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TransferKind { + OpenResponse = 1, + OpenRequest, + Chunk, + Finish, + Ack, + Cancel, + CancelAck, +} + +impl From for CBOR { + fn from(value: TransferBody) -> Self { + CBOR::from(vec![ + CBOR::from(value.message_id), + CBOR::from(value.valid_until), + CBOR::from(value.transfer_id), + CBOR::from(value.frame), + ]) + } +} + +impl TryFrom for TransferBody { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let iter = value.try_into_array()?.into_iter(); + let [message_id, valid_until, transfer_id, frame] = take_fields(iter)?; + Ok(Self { + message_id: message_id.try_into()?, + valid_until: valid_until.try_into()?, + transfer_id: transfer_id.try_into()?, + frame: frame.try_into()?, + }) + } +} + +impl From for CBOR { + fn from(value: TransferFrame) -> Self { + match value { + TransferFrame::OpenResponse { request_id, meta } => CBOR::from(vec![ + CBOR::from(TransferKind::OpenResponse as u8), + CBOR::from(request_id), + meta, + ]), + TransferFrame::OpenRequest { + request_id, + route_id, + meta, + } => CBOR::from(vec![ + CBOR::from(TransferKind::OpenRequest as u8), + CBOR::from(request_id), + CBOR::from(route_id), + meta, + ]), + TransferFrame::Chunk { seq, data } => CBOR::from(vec![ + CBOR::from(TransferKind::Chunk as u8), + CBOR::from(seq), + CBOR::from(data), + ]), + TransferFrame::Finish { seq } => CBOR::from(vec![ + CBOR::from(TransferKind::Finish as u8), + CBOR::from(seq), + ]), + TransferFrame::Ack { next_seq } => CBOR::from(vec![ + CBOR::from(TransferKind::Ack as u8), + CBOR::from(next_seq), + ]), + TransferFrame::Cancel => CBOR::from(vec![CBOR::from(TransferKind::Cancel as u8)]), + TransferFrame::CancelAck => CBOR::from(vec![CBOR::from(TransferKind::CancelAck as u8)]), + } + } +} + +impl TryFrom for TransferFrame { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let mut iter = value.try_into_array()?.into_iter(); + let tag: TransferKind = iter + .next() + .ok_or_else(|| dcbor::Error::msg("missing transfer frame tag"))? + .try_into()?; + match tag { + TransferKind::OpenResponse => { + let [request_id, meta] = take_fields(iter)?; + Ok(Self::OpenResponse { + request_id: request_id.try_into()?, + meta, + }) + } + TransferKind::OpenRequest => { + let [request_id, route_id, meta] = take_fields(iter)?; + Ok(Self::OpenRequest { + request_id: request_id.try_into()?, + route_id: route_id.try_into()?, + meta, + }) + } + TransferKind::Chunk => { + let [seq, data] = take_fields(iter)?; + Ok(Self::Chunk { + seq: seq.try_into()?, + data: data.try_into()?, + }) + } + TransferKind::Finish => { + let [seq] = take_fields(iter)?; + Ok(Self::Finish { + seq: seq.try_into()?, + }) + } + TransferKind::Ack => { + let [next_seq] = take_fields(iter)?; + Ok(Self::Ack { + next_seq: next_seq.try_into()?, + }) + } + TransferKind::Cancel => { + if iter.next().is_some() { + Err(dcbor::Error::msg("array too long")) + } else { + Ok(Self::Cancel) + } + } + TransferKind::CancelAck => { + if iter.next().is_some() { + Err(dcbor::Error::msg("array too long")) + } else { + Ok(Self::CancelAck) + } + } + } + } +} + +impl TryFrom for TransferKind { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let tag: u8 = value.try_into()?; + match tag { + 1 => Ok(Self::OpenResponse), + 2 => Ok(Self::OpenRequest), + 3 => Ok(Self::Chunk), + 4 => Ok(Self::Finish), + 5 => Ok(Self::Ack), + 6 => Ok(Self::Cancel), + 7 => Ok(Self::CancelAck), + _ => Err(dcbor::Error::msg("unknown transfer frame tag")), + } + } +} diff --git a/ql/src/wire/unpair/crypto.rs b/ql/src/wire/unpair/crypto.rs new file mode 100644 index 00000000..ca319ff6 --- /dev/null +++ b/ql/src/wire/unpair/crypto.rs @@ -0,0 +1,58 @@ +use bc_components::MLDSAPublicKey; +use dcbor::CBOR; + +use super::UnpairRecord; +use crate::{ + platform::QlPlatform, + wire::{now_secs, QlHeader, QlPayload, QlRecord}, + MessageId, QlError, +}; + +pub fn build_unpair_record( + platform: &impl QlPlatform, + header: QlHeader, + message_id: MessageId, + valid_until: u64, +) -> QlRecord { + let signature = + platform + .signing_private_key() + .sign(&unpair_proof_data(&header, message_id, valid_until)); + QlRecord { + header, + payload: QlPayload::Unpair(UnpairRecord { + message_id, + valid_until, + signature, + }), + } +} + +pub fn verify_unpair_record( + header: &QlHeader, + record: &UnpairRecord, + signing_key: &MLDSAPublicKey, +) -> Result<(), QlError> { + if now_secs() > record.valid_until { + return Err(QlError::InvalidPayload); + } + let proof_data = unpair_proof_data(header, record.message_id, record.valid_until); + if signing_key + .verify(&record.signature, &proof_data) + .unwrap_or(false) + { + Ok(()) + } else { + Err(QlError::InvalidSignature) + } +} + +fn unpair_proof_data(header: &QlHeader, message_id: MessageId, valid_until: u64) -> Vec { + CBOR::from(vec![ + CBOR::from("ql-unpair-v1"), + CBOR::from(header.clone()), + CBOR::from(message_id), + CBOR::from(valid_until), + ]) + .to_cbor_data() +} diff --git a/ql/src/wire/unpair/mod.rs b/ql/src/wire/unpair/mod.rs new file mode 100644 index 00000000..cc81bab4 --- /dev/null +++ b/ql/src/wire/unpair/mod.rs @@ -0,0 +1,39 @@ +use bc_components::MLDSASignature; +use dcbor::CBOR; + +use super::take_fields; +use crate::MessageId; + +mod crypto; +pub use crypto::*; + +#[derive(Debug, Clone, PartialEq)] +pub struct UnpairRecord { + pub message_id: MessageId, + pub valid_until: u64, + pub signature: MLDSASignature, +} + +impl From for CBOR { + fn from(value: UnpairRecord) -> Self { + CBOR::from(vec![ + CBOR::from(value.message_id), + CBOR::from(value.valid_until), + CBOR::from(value.signature), + ]) + } +} + +impl TryFrom for UnpairRecord { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let iter = value.try_into_array()?.into_iter(); + let [message_id, valid_until, signature] = take_fields(iter)?; + Ok(Self { + message_id: message_id.try_into()?, + valid_until: valid_until.try_into()?, + signature: signature.try_into()?, + }) + } +} From c01ed1dac5735c6c098627b9cac4415047a202e2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:43 -0400 Subject: [PATCH 005/304] ql: introduce duplex runtime, sans-io engine, and engine/runtime split --- Cargo.lock | 22 + Cargo.toml | 2 +- ql2/Cargo.toml | 22 + ql2/README.md | 143 ++ ql2/ql-v2.presenterm.md | 285 ++++ ql2/src/engine/mod.rs | 2254 +++++++++++++++++++++++++++++ ql2/src/engine/replay_cache.rs | 189 +++ ql2/src/engine/state.rs | 507 +++++++ ql2/src/engine/stream.rs | 302 ++++ ql2/src/id.rs | 58 + ql2/src/lib.rs | 51 + ql2/src/platform.rs | 37 + ql2/src/rpc/client.rs | 70 + ql2/src/rpc/mod.rs | 153 ++ ql2/src/rpc/modality.rs | 35 + ql2/src/rpc/server.rs | 1 + ql2/src/runtime/command.rs | 55 + ql2/src/runtime/driver.rs | 723 +++++++++ ql2/src/runtime/handle.rs | 406 ++++++ ql2/src/runtime/mod.rs | 82 ++ ql2/src/runtime/pipe.rs | 772 ++++++++++ ql2/src/tests/handshake.rs | 99 ++ ql2/src/tests/heartbeat.rs | 455 ++++++ ql2/src/tests/mod.rs | 1027 +++++++++++++ ql2/src/tests/persistence.rs | 139 ++ ql2/src/tests/rpc.rs | 264 ++++ ql2/src/tests/stream.rs | 1685 +++++++++++++++++++++ ql2/src/tests/unpair.rs | 137 ++ ql2/src/wire/codec.rs | 308 ++++ ql2/src/wire/encrypted_message.rs | 63 + ql2/src/wire/handshake/crypto.rs | 188 +++ ql2/src/wire/handshake/mod.rs | 50 + ql2/src/wire/heartbeat/crypto.rs | 39 + ql2/src/wire/heartbeat/mod.rs | 12 + ql2/src/wire/mod.rs | 128 ++ ql2/src/wire/pair/crypto.rs | 147 ++ ql2/src/wire/pair/mod.rs | 30 + ql2/src/wire/stream/crypto.rs | 39 + ql2/src/wire/stream/mod.rs | 247 ++++ ql2/src/wire/unpair/crypto.rs | 65 + ql2/src/wire/unpair/mod.rs | 16 + 41 files changed, 11306 insertions(+), 1 deletion(-) create mode 100644 ql2/Cargo.toml create mode 100644 ql2/README.md create mode 100644 ql2/ql-v2.presenterm.md create mode 100644 ql2/src/engine/mod.rs create mode 100644 ql2/src/engine/replay_cache.rs create mode 100644 ql2/src/engine/state.rs create mode 100644 ql2/src/engine/stream.rs create mode 100644 ql2/src/id.rs create mode 100644 ql2/src/lib.rs create mode 100644 ql2/src/platform.rs create mode 100644 ql2/src/rpc/client.rs create mode 100644 ql2/src/rpc/mod.rs create mode 100644 ql2/src/rpc/modality.rs create mode 100644 ql2/src/rpc/server.rs create mode 100644 ql2/src/runtime/command.rs create mode 100644 ql2/src/runtime/driver.rs create mode 100644 ql2/src/runtime/handle.rs create mode 100644 ql2/src/runtime/mod.rs create mode 100644 ql2/src/runtime/pipe.rs create mode 100644 ql2/src/tests/handshake.rs create mode 100644 ql2/src/tests/heartbeat.rs create mode 100644 ql2/src/tests/mod.rs create mode 100644 ql2/src/tests/persistence.rs create mode 100644 ql2/src/tests/rpc.rs create mode 100644 ql2/src/tests/stream.rs create mode 100644 ql2/src/tests/unpair.rs create mode 100644 ql2/src/wire/codec.rs create mode 100644 ql2/src/wire/encrypted_message.rs create mode 100644 ql2/src/wire/handshake/crypto.rs create mode 100644 ql2/src/wire/handshake/mod.rs create mode 100644 ql2/src/wire/heartbeat/crypto.rs create mode 100644 ql2/src/wire/heartbeat/mod.rs create mode 100644 ql2/src/wire/mod.rs create mode 100644 ql2/src/wire/pair/crypto.rs create mode 100644 ql2/src/wire/pair/mod.rs create mode 100644 ql2/src/wire/stream/crypto.rs create mode 100644 ql2/src/wire/stream/mod.rs create mode 100644 ql2/src/wire/unpair/crypto.rs create mode 100644 ql2/src/wire/unpair/mod.rs diff --git a/Cargo.lock b/Cargo.lock index e6733b10..ff756917 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,6 +127,12 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba" +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.5.0" @@ -1955,6 +1961,22 @@ dependencies = [ "tokio", ] +[[package]] +name = "ql2" +version = "0.1.0" +dependencies = [ + "async-channel", + "atomic-waker", + "bc-components", + "chacha20poly1305", + "dcbor", + "futures-lite", + "oneshot", + "rkyv", + "thiserror", + "tokio", +] + [[package]] name = "quantum-link-macros" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 80d9c698..f5f3fbec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["api", "backup-shard", "btp", "ql", "quantum-link-macros"] +members = ["api", "backup-shard", "btp", "ql", "ql2", "quantum-link-macros"] [workspace.package] homepage = "https://github.com/Foundation-Devices/foundation-api" diff --git a/ql2/Cargo.toml b/ql2/Cargo.toml new file mode 100644 index 00000000..d268ca9b --- /dev/null +++ b/ql2/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "ql2" +version = "0.1.0" +edition = "2021" +description = "Quantum Link v2 duplex stream prototype" +license = "Proprietary" + +[dependencies] +async-channel = { version = "2.5" } +atomic-waker = { version = "1.1" } +bc-components = { version = "0.28.0", default-features = false, features = [ + "pqcrypto", +] } +chacha20poly1305 = { version = "0.10.1" } +dcbor = { version = "0.23.3" } +futures-lite = { version = "2.5" } +oneshot = { version = "0.1.11" } +rkyv = { version = "0.8", default-features = false, features = ["std", "bytecheck", "little_endian", "unaligned", "pointer_width_32"] } +thiserror = { version = "2" } + +[dev-dependencies] +tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } diff --git a/ql2/README.md b/ql2/README.md new file mode 100644 index 00000000..d39e4e90 --- /dev/null +++ b/ql2/README.md @@ -0,0 +1,143 @@ +# QL Protocol (v2) + +QL is a compact, session-oriented protocol for authenticated and encrypted messaging +between peers over arbitrary transports. It targets low-bandwidth and high-latency +links while preserving strong cryptography, explicit request/response semantics, and +a clean developer-facing API. + +This crate (`ql`) implements the protocol stack: wire format, crypto, runtime state +machine, and routing. For a deeper comparison with v1, see `ql-protocol-v2.md`. + +## features +- Fixed CBOR wire format: `QlRecord` = `[tag, header, payload]`. +- Mutual-auth handshake (`Hello`, `HelloReply`, `Confirm`) with signed transcript. +- Session keys derived from KEM secrets; payloads use AEAD (ChaCha20-Poly1305). +- Sessions are ephemeral and scoped to a handshake; no long-lived symmetric keys. +- First-contact pairing request with KEM-wrapped payloads and proof signature. +- Encrypted messages with explicit `Request`, `Response`, `Event`, and `Nack`. +- `MessageId`, `RouteId`, and `valid_until` for routing and freshness. +- Heartbeats for keepalive and disconnect detection. +- Runtime state machine for sessions, timeouts, outbound queues, and correlation. +- Router for typed dispatch and automatic response wiring. +- Transport abstraction via `QlPlatform` for BLE, TCP, or other links. + +## overview +QL provides a full session protocol rather than isolated message sealing. It covers: +- Mutual authentication and end-to-end encryption above the transport. +- First-contact pairing for provisioning keys and establishing trust. +- Typed routing with explicit request/response/event semantics. +- Runtime lifecycle management (handshake, keepalive, timeouts, errors). +- Portability across transports via a minimal platform abstraction. + +### security +- Mutual authentication via a signed handshake transcript. +- Session keys derived from KEM secrets; payloads are protected with AEAD + and header AAD. +- End-to-end protection above the transport layer; pairing supports first-contact + key exchange and proof of key possession. +- Message freshness enforced via `valid_until`; replay caching is not built-in, + so applications can optionally track `MessageId` if needed. + +### session vs per-message sealing +- v1 (gstp + envelope) signs every message and then encrypts it to the recipient. + each message uses fresh encapsulation, so keys and signatures are per-message. +- v2 (ql) signs the handshake transcript once, derives a session key, then uses + AEAD for each message with the header as AAD. +- encryption strength uses the same primitive (ChaCha20-Poly1305). post-quantum + security depends on key schemes (ML-KEM + ML-DSA with `pqcrypto` enabled). +- tradeoffs: v2 is faster and smaller; v1 has per-message signature and key + isolation. v2's AEAD provides in-session integrity but is not publicly + verifiable and has a larger blast radius if a session key leaks. + +### performance +- Public-key operations are paid once per session; steady-state traffic is + symmetric AEAD. +- Compact CBOR record framing keeps headers and serialization overhead small. +- Optional heartbeats provide liveness detection without heavy traffic. + +### developer experience +- Typed routes via `RequestResponse` and `Event` traits with explicit `RouteId`. +- Router handles decode, dispatch, and response wiring automatically. +- Runtime manages sessions, timeouts, outbound queues, and request correlation. +- `QlPlatform` abstracts the transport for portability and testability. + +## message sizes +Sizes below are CBOR record sizes from `protocol_record_size_breakdown` in +`ql/src/tests/mod.rs`. + +| Record | Size (bytes) | +| :-- | --: | +| Handshake Hello | 132 | +| Handshake HelloReply | 2563 | +| Handshake Confirm | 2510 | +| Pair request | 4065 | +| Message (empty payload) | 199 | +| Heartbeat | 196 | + +Handshake total is 5205 bytes (132 + 2563 + 2510). At 20 kBps transport +throughput, raw transmit time is about 0.26 s. + +## protocol overview + +### record framing +All traffic is encoded as a `QlRecord` with a small, fixed shape: +- `tag` selects the payload type (handshake, pair, record, heartbeat). +- `header` is unencrypted but authenticated data (AEAD AAD) used for routing + (sender and recipient XIDs). +- `payload` is a CBOR-encoded handshake/pair body or an encrypted message. + +### handshake +The handshake is a three-message exchange: +- `Hello`: initiator sends a nonce and KEM ciphertext. +- `HelloReply`: responder returns its nonce, KEM ciphertext, and a signature + over the transcript. +- `Confirm`: initiator signs the transcript to confirm mutual authentication. + +Both sides derive the session key from the KEM secrets and transcript digest. +After the handshake, all records use symmetric AEAD with the header as AAD. + +### pairing (first-contact) +Pairing is a standalone request that KEM-encrypts a payload containing: +- a `MessageId` and `valid_until` timestamp +- the sender's signing and encapsulation public keys +- a proof signature binding those keys + +This enables establishing trust without an existing session. + +### message records +Steady-state messages are sent as encrypted records with a typed body: +- `MessageKind`: `Request`, `Response`, `Event`, or `Nack` +- `MessageId`, `RouteId`, `valid_until`, and CBOR payload + +Nacks communicate standard failure reasons (unknown route, invalid payload, +expired) so peers can recover consistently. + +### heartbeats +Heartbeats are lightweight encrypted records used by the runtime to maintain +session liveness and detect disconnects. + +### routing and dispatch +`RouteId` maps to concrete request/response or event types. The router decodes +payloads, dispatches handlers, and ensures each request receives a response or +a `Nack`. + +### sequence diagram +```mermaid +sequenceDiagram + participant A as Initiator + participant B as Responder + A->>B: Hello (nonce, KEM ct) + B->>A: HelloReply (nonce, KEM ct, signature) + A->>B: Confirm (signature) + Note over A,B: Session key derived, AEAD enabled + A->>B: Encrypted Record (Request) + B->>A: Encrypted Record (Response) + A-->>B: Encrypted Heartbeat (optional) +``` + +## code map +- Wire format: `ql/src/wire/*` +- Cryptography: `ql/src/crypto/*` +- Runtime state machine: `ql/src/runtime/*` +- Routing and traits: `ql/src/router.rs`, `ql/src/lib.rs` +- Transport abstraction: `ql/src/platform.rs` diff --git a/ql2/ql-v2.presenterm.md b/ql2/ql-v2.presenterm.md new file mode 100644 index 00000000..d4a0fff2 --- /dev/null +++ b/ql2/ql-v2.presenterm.md @@ -0,0 +1,285 @@ +--- +theme: + name: gruvbox-dark +--- + +# quantumlink protocol v2 +post-quantum, session-based message protocol + + + +# ql v1: constraints +- no message id / sequence id +- no protocol-level request/response pairing +- each platform had to interpret + correlate by hand +- no ack/nack +- no notion of 'liveness'/'connected' status +- ~6.6KB min sealed event + - sender xid document (pq pubkeys) + - per-message signature + - recipient encryption (+ continuations) +- more a utility crate than a protocol + + + +# v1 vs v2 + + + + +## v1 +- gstp sealed envelope per message +- per-message sign+encrypt (envelope) +- implicit req/resp in enum variants +- app-owned pairing, timeouts, keepalive, connected status + + + +## v2 +- compact record + typed payloads +- handshake signatures + per‑message aead under symmetric session key +- explicit kind + ids + nack +- runtime handles pairing, timeouts, keepalive, connected status, request/response matching + + + +# design shift: per-message -> session +- v1 sealed each message +- v2 signs once, then aead per message + +```text +v1: seal(msg) = sign(msg) + encrypt(recipient) +v2: session_key = handshake() +v2: aead(msg, aad=header) +``` + + +_aead = authenticated encryption with associated data_ + +_aad = additional authenticated data (visible, integrity-protected)_ + + + +# configurable host platform +- same runtime across keyos / mobile / desktop +- host supplies pq keys, io, timers, callbacks + +```rust +pub trait QlPlatform { + // pq identity + fn signing_private_key(&self) -> &MLDSAPrivateKey; + fn signing_public_key(&self) -> &MLDSAPublicKey; + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey; + fn encapsulation_public_key(&self) -> &MLKEMPublicKey; + + // transport + runtime hooks + fn fill_bytes(&self, data: &mut [u8]); + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; + + // event handlers + fn handle_peer_status(&self, peer: XID, session: &PeerSession); + fn handle_inbound(&self, event: HandlerEvent); +} +``` + + + +# multi-peer runtime +- runtime tracks sessions per peer +- concurrent handshakes + keepalive per peer + +```rust +handle.register_peer(peer, signing_key, encapsulation_key); +handle.connect(peer)?; +``` + + + +# protocol breakdown +```mermaid +render +width:90% +sequenceDiagram + participant A as initiator + participant B as responder + + Note over A,B: pairing (first contact) + A->>B: pair request (kem + signed payload) + + Note over A,B: handshake (mutual auth) + A->>B: hello (nonce + kem ct) + B->>A: hello reply (nonce + kem ct + signature) + A->>B: confirm (signature) + + Note over A,B: session established + A->>B: request / event (aead + aad header) + B->>A: response / nack (aead + aad header) + A-->>B: heartbeat (optional) +``` + + + +# wire framing: routable header +- record = [tag, header, payload] +- header is unencrypted but authenticated (aad) + +```rust +pub struct QlRecord { + pub header: QlHeader, + pub payload: QlPayload, +} + +pub struct QlHeader { + pub sender: XID, + pub recipient: XID, +} +``` + + + +# handshake flow + records +- hello: nonce + mlkem ciphertext +- reply: nonce + mlkem ciphertext + mldsa signature +- confirm: mldsa signature, then session key + +```rust +pub struct Hello { + pub nonce: Nonce, + pub kem_ct: MLKEMCiphertext, +} + +pub struct HelloReply { + pub nonce: Nonce, + pub kem_ct: MLKEMCiphertext, + pub signature: MLDSASignature, +} + +pub struct Confirm { + pub signature: MLDSASignature, +} +``` + + + +# session key derivation +- transcript binds ids + nonces + kem ciphertexts +- session key = digest(initiator_secret, responder_secret, transcript) + +```rust +let transcript = cbor([ + initiator, responder, + hello.nonce, reply.nonce, + hello.kem_ct, reply.kem_ct, +]); +let payload = cbor([initiator_secret, responder_secret, transcript]); +let digest = Digest::from_image(payload); +let session_key = SymmetricKey::from_data(*digest.data()); +``` + + + +# message modalities +- request / response +- event: fire-and-forget or acked +- nack for structured failure + +```rust +pub enum MessageKind { + Request, + Response, + Event, + Nack, +} +``` + + + +# message body: routing + expiry +- message_id + route_id +- valid_until for freshness + +```rust +pub struct MessageBody { + pub message_id: MessageId, + pub valid_until: u64, + pub kind: MessageKind, + pub route_id: RouteId, + pub payload: CBOR, +} +``` + + + +# nack reasons +- unknown route / invalid payload / expired + +```rust +pub enum Nack { + Unknown, + UnknownRoute, + InvalidPayload, + Expired, +} +``` + + + +# type-safe routing +- route id is const per type +- compiler couples request -> response + +```rust +pub trait RequestResponse: QlCodec { + const ID: RouteId; + type Response: QlCodec; +} + +pub trait Event: QlCodec { + const ID: RouteId; +} +``` + + + +# router wiring +- builder ties route ids to handlers +- unknown routes auto-nack + +```rust +let router = Router::builder() + .add_request_handler::() + .add_event_handler::() + .build(state); +``` + + + +# runtime api flow +- request returns response or nack +- events are fire-and-forget (or acked) + +```rust +let reply = handle.request(msg, peer, RequestConfig::default()).await?; +handle.send_event(status, peer); +``` + + + +# performance snapshot (cbor sizes) +| proto | message | bytes | notes | +| :-- | :-- | --: | :-- | +| v1 | sealed msg (exchange_rate) | 6645 | sign+encrypt | +| v1 | sealed heartbeat | 6633 | sign+encrypt | +| v2 | hello | 132 | kem+nonce | +| v2 | hello reply | 2563 | sig+kem | +| v2 | confirm | 2510 | sig | +| v2 | pair request | 4065 | sig+kem | +| v2 | message (empty) | 199 | steady-state | +| v2 | heartbeat | 196 | steady-state | + +handshake total: 5205 bytes + + + +# close +- smaller packets, clearer flow, typed api +- ql v2 is the protocol, not just a crate diff --git a/ql2/src/engine/mod.rs b/ql2/src/engine/mod.rs new file mode 100644 index 00000000..b167478b --- /dev/null +++ b/ql2/src/engine/mod.rs @@ -0,0 +1,2254 @@ +pub mod replay_cache; +mod state; +mod stream; + +use std::{ + cmp::Reverse, + collections::HashMap, + mem, + time::{Duration, Instant}, +}; + +use bc_components::{SigningPublicKey, XID}; +use rkyv::access_mut; +pub use state::{ + Engine, EngineInput, EngineOutput, EngineState, InitiatorStage, KeepAliveState, OpenId, + OutputFn, PeerRecord, PeerSession, Token, TrackedWrite, +}; + +use self::{ + replay_cache::{ReplayKey, ReplayNamespace}, + state::*, + stream::*, +}; +use crate::{ + platform::QlCrypto, + runtime::StreamConfig, + wire::{ + self, + encrypted_message::{ArchivedEncryptedMessage, NONCE_SIZE}, + handshake::{self, HandshakeRecord, Hello}, + heartbeat::{self, HeartbeatBody}, + stream::{ + decrypt_stream, encrypt_stream, Direction, PacketAck, RejectCode, ResetCode, + ResetTarget, StreamBody, StreamFrame, StreamFrameAccept, StreamFrameCredit, + StreamFrameData, StreamFrameFinish, StreamFrameOpen, StreamFrameReject, + StreamFrameReset, + }, + unpair::{self}, + QlHeader, QlPayload, QlRecord, + }, + PacketId, Peer, QlError, StreamId, +}; + +#[derive(Debug, Clone, Copy)] +pub struct KeepAliveConfig { + pub interval: Duration, + pub timeout: Duration, +} + +#[derive(Debug, Clone, Copy)] +pub struct EngineConfig { + pub handshake_timeout: Duration, + pub default_open_timeout: Duration, + pub packet_expiration: Duration, + pub packet_ack_timeout: Duration, + pub stream_retry_limit: u8, + pub max_payload_bytes: usize, + pub initial_credit: u64, + pub keep_alive: Option, +} + +impl Default for EngineConfig { + fn default() -> Self { + Self { + handshake_timeout: Duration::from_secs(5), + default_open_timeout: Duration::from_secs(5), + packet_expiration: Duration::from_secs(30), + packet_ack_timeout: Duration::from_millis(150), + stream_retry_limit: 5, + max_payload_bytes: 1024, + initial_credit: 1024, + keep_alive: None, + } + } +} + +impl EngineConfig { + pub(crate) fn normalized(mut self) -> Self { + self.max_payload_bytes = self.max_payload_bytes.max(1); + self.initial_credit = self.initial_credit.max(self.max_payload_bytes as u64); + self + } +} + +impl Engine { + pub fn new(config: EngineConfig, peer: Option) -> Self { + Self { + config: config.normalized(), + state: EngineState::new(peer), + streams: HashMap::new(), + } + } + + pub fn run_tick( + &mut self, + now: Instant, + input: EngineInput, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + match input { + EngineInput::BindPeer(peer) => self.handle_bind_peer(peer, emit), + EngineInput::Pair => self.handle_pair_local(now, crypto), + EngineInput::Connect => self.handle_connect(now, crypto, emit), + EngineInput::Unpair => self.handle_unpair_local(now, crypto, emit), + EngineInput::OpenStream { + open_id, + request_head, + config, + } => self.handle_open_stream(now, open_id, request_head, config, emit), + EngineInput::AcceptStream { + stream_id, + response_head, + } => self.handle_accept_stream(now, stream_id, response_head), + EngineInput::RejectStream { stream_id, code } => { + self.handle_reject_stream(now, stream_id, code) + } + EngineInput::OutboundData { + stream_id, + dir, + offset, + bytes, + } => self.handle_outbound_data(stream_id, dir, offset, bytes), + EngineInput::OutboundFinished { + stream_id, + dir, + final_offset, + } => self.handle_outbound_finished(stream_id, dir, final_offset), + EngineInput::InboundConsumed { + stream_id, + dir, + amount, + } => self.handle_inbound_consumed(now, stream_id, dir, amount), + EngineInput::ResetOutbound { + stream_id, + dir, + code, + } => self.handle_reset_outbound(now, stream_id, dir, code), + EngineInput::ResetInbound { + stream_id, + dir, + code, + } => self.handle_reset_inbound(now, stream_id, dir, code), + EngineInput::PendingAcceptDropped { stream_id } => { + self.handle_pending_accept_dropped(stream_id, emit) + } + EngineInput::ResponderDropped { stream_id } => { + self.handle_responder_dropped(now, stream_id) + } + EngineInput::Incoming(bytes) => self.handle_incoming(now, bytes, crypto, emit), + EngineInput::WriteCompleted { + token, + tracked, + result, + } => self.handle_write_done(now, token, tracked, result, emit), + EngineInput::TimerExpired => self.handle_timeouts(now, crypto, emit), + } + + self.drive_streams(now, emit); + self.maybe_start_next_write(crypto, emit); + emit(EngineOutput::SetTimer(self.state.next_deadline())); + } + + fn emit_peer_status(&self, emit: &mut impl OutputFn) { + if let Some(peer) = self.state.peer.as_ref() { + emit(EngineOutput::PeerStatusChanged { + peer: peer.peer, + session: peer.session.clone(), + }); + } + } + + fn bind_peer_record(&mut self, peer: Peer, emit: &mut impl OutputFn) { + self.reset_runtime(QlError::Cancelled, emit); + self.state.peer = Some(PeerRecord::new( + peer.peer, + peer.signing_key, + peer.encapsulation_key, + )); + self.emit_peer_status(emit); + if let Some(peer) = self.state.peer.as_ref() { + emit(EngineOutput::PersistPeer(peer.snapshot())); + } + } + + fn reset_runtime(&mut self, error: QlError, emit: &mut impl OutputFn) { + let streams = mem::take(&mut self.streams); + for (stream_id, stream) in streams { + self.fail_stream(stream_id, stream, error.clone(), emit); + } + self.state.outbound.clear(); + self.state.timeouts.clear(); + self.state.write_in_flight = None; + if let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) { + self.state.replay_cache.clear_peer(peer); + } + } + + fn handle_bind_peer(&mut self, peer: Peer, emit: &mut impl OutputFn) { + if let Some(existing) = self.state.peer.as_ref() { + emit(EngineOutput::PeerStatusChanged { + peer: existing.peer, + session: PeerSession::Disconnected, + }); + } + self.bind_peer_record(peer, emit); + } + + fn handle_pair_local(&mut self, now: Instant, crypto: &impl QlCrypto) { + let Some(peer) = self.state.peer.as_ref() else { + return; + }; + let Ok(record) = wire::pair::build_pair_request( + crypto, + peer.peer, + &peer.encapsulation_key, + self.state.next_packet_id(), + self.config.packet_expiration, + ) else { + return; + }; + let token = self.state.next_token(); + self.enqueue_handshake_message( + token, + now + self.config.packet_expiration, + wire::encode_record(&record), + ); + } + + fn handle_connect(&mut self, now: Instant, crypto: &impl QlCrypto, emit: &mut impl OutputFn) { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + let peer = peer_record.peer; + let (hello, session_key) = match &peer_record.session { + PeerSession::Connected { .. } + | PeerSession::Initiator { .. } + | PeerSession::Responder { .. } => { + return; + } + PeerSession::Disconnected => { + match handshake::build_hello( + crypto, + crypto.xid(), + peer, + &peer_record.encapsulation_key, + ) { + Ok(result) => result, + Err(_) => return, + } + } + }; + + let deadline = now + self.config.handshake_timeout; + let token = self.state.next_token(); + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Initiator { + handshake_token: token, + hello: hello.clone(), + session_key, + deadline, + stage: InitiatorStage::WaitingHelloReply, + }; + } + self.emit_peer_status(emit); + + let record = QlRecord { + header: QlHeader { + sender: crypto.xid(), + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), + }; + self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); + } + + fn handle_unpair_local( + &mut self, + now: Instant, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { + return; + }; + let record = unpair::build_unpair_record( + crypto, + QlHeader { + sender: crypto.xid(), + recipient: peer, + }, + self.state.next_packet_id(), + wire::now_secs().saturating_add(self.config.packet_expiration.as_secs()), + ); + self.unpair_peer(emit); + let token = self.state.next_token(); + self.enqueue_handshake_message( + token, + now + self.config.packet_expiration, + wire::encode_record(&record), + ); + } + + fn handle_open_stream( + &mut self, + now: Instant, + open_id: OpenId, + request_head: Vec, + config: StreamConfig, + emit: &mut impl OutputFn, + ) { + let Some(entry) = self.state.peer.as_ref() else { + emit(EngineOutput::OpenFailed { + open_id, + stream_id: StreamId(0), + error: QlError::NoPeerBound, + }); + return; + }; + if !entry.session.is_connected() { + emit(EngineOutput::OpenFailed { + open_id, + stream_id: StreamId(0), + error: QlError::MissingSession, + }); + return; + } + + let stream_id = self.state.next_stream_id(); + let open_timeout = config + .open_timeout + .unwrap_or(self.config.default_open_timeout); + let token = self.state.next_token(); + let frame = StreamFrameOpen { + stream_id, + request_head: request_head.clone(), + response_max_offset: self.config.initial_credit, + }; + let stream = StreamState::Initiator(InitiatorStream { + meta: StreamMeta { + key: StreamKey { stream_id }, + request_head, + last_activity: now, + }, + control: StreamControl { + pending: PendingFrames { + setup: Some(SetupFrame::Open(frame)), + credit: None, + reset: None, + }, + awaiting: None, + }, + request: OutboundState::new(Direction::Request, self.config.initial_credit, true), + response: InboundState::new(self.config.initial_credit), + accept: InitiatorAccept::Opening(OpenWaiter { + open_id: Some(open_id), + open_timeout_token: token, + }), + }); + self.streams.insert(stream_id, stream); + self.state.timeouts.push(Reverse(TimeoutEntry { + at: now + open_timeout, + kind: TimeoutKind::StreamOpen { stream_id, token }, + })); + emit(EngineOutput::OpenStarted { open_id, stream_id }); + } + + fn handle_accept_stream(&mut self, now: Instant, stream_id: StreamId, response_head: Vec) { + let Some(StreamState::Responder(stream)) = self.streams.get_mut(&stream_id) else { + return; + }; + let ResponderResponse::Pending { initial_credit } = stream.response else { + return; + }; + stream + .control + .pending + .set_setup(SetupFrame::Accept(StreamFrameAccept { + stream_id, + response_head, + request_max_offset: self.config.initial_credit, + })); + stream.request.max_offset = self.config.initial_credit; + stream.response = ResponderResponse::Accepted { + initial_credit, + body: OutboundState::new(Direction::Response, initial_credit, false), + }; + stream.meta.last_activity = now; + } + + fn handle_reject_stream(&mut self, now: Instant, stream_id: StreamId, code: RejectCode) { + let Some(StreamState::Responder(stream)) = self.streams.get_mut(&stream_id) else { + return; + }; + let ResponderResponse::Pending { initial_credit } = stream.response else { + return; + }; + stream + .control + .pending + .set_setup(SetupFrame::Reject(StreamFrameReject { stream_id, code })); + stream.response = ResponderResponse::Rejecting { initial_credit }; + stream.meta.last_activity = now; + } + + fn handle_outbound_data( + &mut self, + stream_id: StreamId, + dir: Direction, + offset: u64, + bytes: Vec, + ) { + if bytes.is_empty() { + return; + } + let (streams, state) = (&mut self.streams, &mut self.state); + let Some(stream) = streams.get_mut(&stream_id) else { + return; + }; + let Some(outbound) = stream.outbound_mut(dir) else { + return; + }; + let Some(pull) = outbound.pending_pull.take() else { + return; + }; + if pull.offset != offset { + outbound.pending_pull = Some(pull); + return; + } + if bytes.len() > pull.max_len { + outbound.pending_pull = Some(pull); + return; + } + outbound.sent_offset = outbound + .sent_offset + .max(offset.saturating_add(bytes.len() as u64)); + let key = stream.key(); + let control = stream.control_mut(); + state.enqueue_data_frame(&self.config, key, control, dir, offset, bytes, 0); + } + + fn handle_outbound_finished(&mut self, stream_id: StreamId, dir: Direction, final_offset: u64) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let Some(outbound) = stream.outbound_mut(dir) else { + return; + }; + if final_offset < outbound.sent_offset { + return; + } + outbound.pending_pull = None; + outbound.final_offset = Some(final_offset); + } + + fn handle_inbound_consumed( + &mut self, + now: Instant, + stream_id: StreamId, + dir: Direction, + amount: u64, + ) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let Some(inbound) = stream.inbound_mut(dir) else { + return; + }; + if inbound.closed { + return; + } + inbound.max_offset = inbound.max_offset.saturating_add(amount); + Self::queue_credit(stream, dir); + *stream.last_activity_mut() = now; + } + + fn handle_reset_outbound( + &mut self, + now: Instant, + stream_id: StreamId, + dir: Direction, + code: ResetCode, + ) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let Some(outbound) = stream.outbound_mut(dir) else { + return; + }; + if outbound.closed { + return; + } + outbound.closed = true; + outbound.pending_pull = None; + stream + .control_mut() + .pending + .set_reset(reset_target_for_dir(dir), code); + *stream.last_activity_mut() = now; + } + + fn handle_reset_inbound( + &mut self, + now: Instant, + stream_id: StreamId, + dir: Direction, + code: ResetCode, + ) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let Some(inbound) = stream.inbound_mut(dir) else { + return; + }; + if inbound.closed { + return; + } + inbound.closed = true; + stream + .control_mut() + .pending + .set_reset(reset_target_for_dir(dir), code); + *stream.last_activity_mut() = now; + } + + fn handle_responder_dropped(&mut self, now: Instant, stream_id: StreamId) { + self.handle_reject_stream(now, stream_id, RejectCode::Unhandled); + } + + fn handle_pending_accept_dropped(&mut self, stream_id: StreamId, emit: &mut impl OutputFn) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + if let StreamState::Initiator(stream) = stream { + match &mut stream.accept { + InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { + waiter.open_id = None; + } + InitiatorAccept::Open { .. } => {} + } + } + self.maybe_reap_stream(stream_id, emit); + } + + fn handle_incoming( + &mut self, + now: Instant, + mut bytes: Vec, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let Ok(record) = access_mut::(&mut bytes) + else { + return; + }; + let record = unsafe { record.unseal_unchecked() }; + let sender = wire::xid_from_archived(&record.header.sender); + let recipient = wire::xid_from_archived(&record.header.recipient); + if recipient != crypto.xid() { + return; + } + if !matches!(&record.payload, wire::ArchivedQlPayload::Pair(_)) { + let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { + return; + }; + if sender != peer { + return; + } + } + let Ok(header) = wire::deserialize_value(&record.header) else { + return; + }; + match &mut record.payload { + wire::ArchivedQlPayload::Handshake(message) => { + self.handle_handshake(now, sender, message, crypto, emit) + } + wire::ArchivedQlPayload::Stream(encrypted) => { + self.handle_stream(now, sender, &header, encrypted, emit) + } + wire::ArchivedQlPayload::Heartbeat(encrypted) => { + self.handle_heartbeat(now, &header, encrypted, crypto, emit) + } + wire::ArchivedQlPayload::Pair(request) => { + self.handle_pairing(now, &header, request, crypto, emit) + } + wire::ArchivedQlPayload::Unpair(unpair_record) => { + self.handle_unpair(sender, &header, unpair_record, emit) + } + } + } + + fn handle_handshake( + &mut self, + now: Instant, + peer: XID, + message: &wire::handshake::ArchivedHandshakeRecord, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + match message { + wire::handshake::ArchivedHandshakeRecord::Hello(hello) => { + self.handle_hello(now, peer, hello, crypto, emit) + } + wire::handshake::ArchivedHandshakeRecord::HelloReply(reply) => { + self.handle_hello_reply(now, peer, reply, crypto, emit) + } + wire::handshake::ArchivedHandshakeRecord::Confirm(confirm) => { + self.handle_confirm(now, peer, confirm, crypto, emit) + } + } + } + + fn handle_pairing( + &mut self, + now: Instant, + header: &QlHeader, + request: &mut wire::pair::ArchivedPairRequestRecord, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let payload = match wire::pair::decrypt_pair_request(crypto, header, request) { + Ok(payload) => payload, + Err(_) => return, + }; + let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); + if let Some(existing) = self.state.peer.as_ref() { + if existing.peer != peer + || existing.signing_key != payload.signing_pub_key + || existing.encapsulation_key != payload.encapsulation_pub_key + { + return; + } + } else { + self.bind_peer_record( + Peer { + peer, + signing_key: payload.signing_pub_key, + encapsulation_key: payload.encapsulation_pub_key, + }, + emit, + ); + } + self.handle_connect(now, crypto, emit); + } + + fn handle_unpair( + &mut self, + peer: XID, + header: &QlHeader, + record: &wire::unpair::ArchivedUnpairRecord, + emit: &mut impl OutputFn, + ) { + { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + if unpair::verify_unpair_record(header, record, &peer_record.signing_key).is_err() { + return; + } + } + let packet_id: PacketId = (&record.packet_id).into(); + let valid_until = record.valid_until.to_native(); + let replay_key = ReplayKey::new(peer, ReplayNamespace::Peer, packet_id); + if self + .state + .replay_cache + .check_and_store_valid_until(replay_key, valid_until) + { + return; + } + self.unpair_peer(emit); + } + + fn handle_heartbeat( + &mut self, + now: Instant, + header: &QlHeader, + encrypted: &mut ArchivedEncryptedMessage, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let should_reply = { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + let PeerSession::Connected { + session_key, + keepalive, + } = &peer_record.session + else { + return; + }; + if heartbeat::decrypt_heartbeat(header, encrypted, session_key).is_err() { + return; + } + !keepalive.pending + }; + self.record_activity(now); + if should_reply { + self.send_heartbeat_message(now, crypto); + } + self.emit_peer_status(emit); + } + + fn handle_stream( + &mut self, + now: Instant, + peer: XID, + header: &QlHeader, + encrypted: &mut ArchivedEncryptedMessage, + emit: &mut impl OutputFn, + ) { + let body = { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + let PeerSession::Connected { session_key, .. } = &peer_record.session else { + return; + }; + match decrypt_stream(header, encrypted, session_key) { + Ok(body) => body, + Err(_) => return, + } + }; + + if let Some(ack) = body.packet_ack { + self.process_packet_ack(ack.packet_id, emit); + } + + let Some(frame) = body.frame else { + return; + }; + + let replay_key = ReplayKey::new(peer, ReplayNamespace::Transfer, body.packet_id); + if self + .state + .replay_cache + .check_and_store_valid_until(replay_key, body.valid_until) + { + return; + } + + self.record_activity(now); + self.record_stream_activity(stream_id_from_frame(&frame), now); + self.send_packet_ack(body.packet_id); + + match frame { + StreamFrame::Open(frame) => self.handle_stream_open(now, frame, emit), + StreamFrame::Accept(frame) => self.handle_stream_accept_from_peer(now, frame, emit), + StreamFrame::Reject(frame) => self.handle_stream_reject_from_peer(frame, emit), + StreamFrame::Data(frame) => self.handle_stream_data(now, frame, emit), + StreamFrame::Credit(frame) => self.handle_stream_credit(now, frame, emit), + StreamFrame::Finish(frame) => self.handle_stream_finish(now, frame, emit), + StreamFrame::Reset(frame) => self.handle_stream_reset(now, frame, emit), + } + } + + fn handle_stream_open( + &mut self, + now: Instant, + frame: StreamFrameOpen, + emit: &mut impl OutputFn, + ) { + let StreamFrameOpen { + stream_id, + request_head, + response_max_offset, + } = frame; + if let Some(stream) = self.streams.get(&stream_id) { + if self.stream_matches_open(stream, &request_head, response_max_offset) { + return; + } + self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); + return; + } + + let stream = StreamState::Responder(ResponderStream { + meta: StreamMeta { + key: StreamKey { stream_id }, + request_head: request_head.clone(), + last_activity: now, + }, + control: StreamControl::default(), + request: InboundState::new(0), + response: ResponderResponse::Pending { + initial_credit: response_max_offset, + }, + }); + self.streams.insert(stream_id, stream); + emit(EngineOutput::InboundStreamOpened { + stream_id, + request_head, + }); + } + + fn handle_stream_accept_from_peer( + &mut self, + now: Instant, + frame: StreamFrameAccept, + emit: &mut impl OutputFn, + ) { + let StreamFrameAccept { + stream_id, + response_head, + request_max_offset, + } = frame; + let mut protocol = false; + { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + match stream { + StreamState::Initiator(stream) => match &mut stream.accept { + InitiatorAccept::Opening(waiter) => { + if matches!( + stream + .control + .awaiting + .as_ref() + .map(|awaiting| &awaiting.frame), + Some(AwaitingFrame::Control(StreamFrame::Open(_))) + ) { + stream.control.awaiting = None; + } + stream.request.remote_max_offset = request_max_offset; + stream.request.data_enabled = true; + if let Some(open_id) = waiter.open_id.take() { + emit(EngineOutput::OpenAccepted { + open_id, + stream_id, + response_head: response_head.clone(), + }); + } else { + stream.response.closed = true; + stream + .control + .pending + .set_reset(ResetTarget::Response, ResetCode::Cancelled); + } + stream.accept = InitiatorAccept::Open { response_head }; + stream.meta.last_activity = now; + } + InitiatorAccept::WaitingAccept(waiter) => { + stream.request.remote_max_offset = request_max_offset; + stream.request.data_enabled = true; + if let Some(open_id) = waiter.open_id.take() { + emit(EngineOutput::OpenAccepted { + open_id, + stream_id, + response_head: response_head.clone(), + }); + } else { + stream.response.closed = true; + stream + .control + .pending + .set_reset(ResetTarget::Response, ResetCode::Cancelled); + } + stream.accept = InitiatorAccept::Open { response_head }; + stream.meta.last_activity = now; + } + InitiatorAccept::Open { + response_head: stored, + } => { + if *stored != response_head + || stream.request.remote_max_offset != request_max_offset + { + protocol = true; + } + } + }, + _ => protocol = true, + } + } + + if protocol { + self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); + } + } + + fn handle_stream_reject_from_peer( + &mut self, + frame: StreamFrameReject, + emit: &mut impl OutputFn, + ) { + let StreamFrameReject { stream_id, code } = frame; + let mut protocol = false; + let mut remove_after = false; + { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + match stream { + StreamState::Initiator(stream) => match &mut stream.accept { + InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { + if let Some(open_id) = waiter.open_id.take() { + emit(EngineOutput::OpenFailed { + open_id, + stream_id, + error: QlError::StreamRejected { code }, + }); + } + emit(EngineOutput::OutboundClosed { + stream_id, + dir: Direction::Request, + }); + emit(EngineOutput::InboundFailed { + stream_id, + dir: Direction::Response, + error: QlError::StreamRejected { code }, + }); + stream.request.closed = true; + stream.response.closed = true; + remove_after = true; + } + InitiatorAccept::Open { .. } => protocol = true, + }, + _ => protocol = true, + } + } + if remove_after { + self.streams.remove(&stream_id); + emit(EngineOutput::StreamReaped { stream_id }); + } + if protocol { + self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); + } + } + + fn handle_stream_data( + &mut self, + now: Instant, + frame: StreamFrameData, + emit: &mut impl OutputFn, + ) { + let StreamFrameData { + stream_id, + dir, + offset, + bytes, + } = frame; + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + Self::note_setup_seen_from_remote(stream); + if dir == Direction::Response + && matches!( + stream, + StreamState::Initiator(InitiatorStream { + accept: InitiatorAccept::Opening(_) | InitiatorAccept::WaitingAccept(_), + .. + }) + ) + { + Self::queue_protocol_reset(stream, emit); + *stream.last_activity_mut() = now; + return; + } + let Some(inbound) = stream.inbound_mut(dir) else { + Self::queue_protocol_reset(stream, emit); + return; + }; + if inbound.closed { + Self::queue_protocol_reset(stream, emit); + } else if offset < inbound.next_offset { + Self::queue_credit(stream, dir); + } else { + let end = offset.saturating_add(bytes.len() as u64); + if offset != inbound.next_offset || end > inbound.max_offset { + Self::queue_protocol_reset(stream, emit); + } else { + inbound.next_offset = end; + emit(EngineOutput::InboundData { + stream_id, + dir, + bytes, + }); + Self::queue_credit(stream, dir); + } + } + *stream.last_activity_mut() = now; + } + + fn handle_stream_credit( + &mut self, + now: Instant, + frame: StreamFrameCredit, + emit: &mut impl OutputFn, + ) { + let StreamFrameCredit { + stream_id, + dir, + recv_offset, + max_offset, + } = frame; + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + Self::note_setup_seen_from_remote(stream); + let Some(outbound) = stream.outbound_mut(dir) else { + Self::queue_protocol_reset(stream, emit); + return; + }; + let released_offset = outbound.released_offset; + let sent_offset = outbound.sent_offset; + if recv_offset < released_offset || recv_offset > sent_offset || max_offset < recv_offset { + Self::queue_protocol_reset(stream, emit); + } else { + outbound.released_offset = recv_offset; + outbound.remote_max_offset = outbound.remote_max_offset.max(max_offset); + emit(EngineOutput::ReleaseOutboundThrough { + stream_id, + dir, + recv_offset, + }); + if matches!( + stream.control().awaiting.as_ref().map(|awaiting| &awaiting.frame), + Some(AwaitingFrame::Data { offset, len, .. }) + if recv_offset >= offset.saturating_add(*len as u64) + ) { + stream.control_mut().awaiting = None; + } + } + *stream.last_activity_mut() = now; + } + + fn handle_stream_finish( + &mut self, + now: Instant, + frame: StreamFrameFinish, + emit: &mut impl OutputFn, + ) { + let StreamFrameFinish { stream_id, dir } = frame; + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + Self::note_setup_seen_from_remote(stream); + let Some(inbound) = stream.inbound_mut(dir) else { + Self::queue_protocol_reset(stream, emit); + return; + }; + if !inbound.closed { + inbound.closed = true; + emit(EngineOutput::InboundFinished { stream_id, dir }); + } + *stream.last_activity_mut() = now; + self.maybe_reap_stream(stream_id, emit); + } + + fn handle_stream_reset( + &mut self, + now: Instant, + frame: StreamFrameReset, + emit: &mut impl OutputFn, + ) { + let StreamFrameReset { + stream_id, + dir, + code, + } = frame; + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + Self::note_setup_seen_from_remote(stream); + Self::apply_remote_reset(stream, dir, code, emit); + *stream.last_activity_mut() = now; + self.maybe_reap_stream(stream_id, emit); + } + + fn process_packet_ack(&mut self, packet_id: PacketId, emit: &mut impl OutputFn) { + let key = self.streams.iter().find_map(|(key, stream)| { + stream + .control() + .awaiting + .as_ref() + .is_some_and(|awaiting| awaiting.packet_id == packet_id) + .then_some(*key) + }); + let Some(key) = key else { + return; + }; + let Some(stream) = self.streams.get_mut(&key) else { + return; + }; + let Some(awaiting) = stream.control_mut().awaiting.take() else { + return; + }; + + let mut reap = false; + match awaiting.frame { + AwaitingFrame::Control(StreamFrame::Open(_)) => { + if let StreamState::Initiator(stream) = stream { + if let InitiatorAccept::Opening(waiter) = &stream.accept { + stream.accept = InitiatorAccept::WaitingAccept(OpenWaiter { + open_id: waiter.open_id, + open_timeout_token: waiter.open_timeout_token, + }); + } + } + } + AwaitingFrame::Control(StreamFrame::Accept(_)) => { + if let StreamState::Responder(stream) = stream { + if let ResponderResponse::Accepted { body, .. } = &mut stream.response { + body.data_enabled = true; + } + } + } + AwaitingFrame::Control(StreamFrame::Reject(_)) => { + reap = true; + } + AwaitingFrame::Control(StreamFrame::Finish(StreamFrameFinish { dir, .. })) => { + if let Some(outbound) = stream.outbound_mut(dir) { + outbound.closed = true; + emit(EngineOutput::OutboundClosed { + stream_id: key, + dir, + }); + } + } + AwaitingFrame::Control(StreamFrame::Reset(StreamFrameReset { dir, code, .. })) => { + for outbound_dir in [Direction::Request, Direction::Response] { + let affects_outbound = matches!( + (dir, outbound_dir), + (ResetTarget::Request, Direction::Request) + | (ResetTarget::Response, Direction::Response) + | (ResetTarget::Both, _) + ); + if affects_outbound { + if let Some(outbound) = stream.outbound_mut(outbound_dir) { + outbound.closed = true; + emit(EngineOutput::OutboundFailed { + stream_id: key, + dir: outbound_dir, + error: QlError::StreamReset { + dir: outbound_dir, + code, + }, + }); + } + } + } + } + AwaitingFrame::Control(StreamFrame::Data(_) | StreamFrame::Credit(_)) => {} + AwaitingFrame::Data { .. } => {} + } + + if reap { + self.maybe_reap_stream(key, emit); + } + } + + fn drive_streams(&mut self, now: Instant, emit: &mut impl OutputFn) { + let config = &self.config; + let state = &mut self.state; + for stream in self.streams.values_mut() { + Self::drive_stream(config, state, now, stream, emit); + } + } + + fn drive_stream( + config: &EngineConfig, + state: &mut EngineState, + _now: Instant, + stream: &mut StreamState, + emit: &mut impl OutputFn, + ) { + match stream { + StreamState::Initiator(stream) => { + let action = Self::plan_drive_outbound( + config, + stream.meta.key, + &mut stream.control, + Some(&mut stream.request), + emit, + ); + if let Some(frame) = action { + state.enqueue_control_frame( + config, + stream.meta.key, + &mut stream.control, + frame, + 0, + ); + } + } + StreamState::Responder(stream) => { + let key = stream.meta.key; + match &mut stream.response { + ResponderResponse::Accepted { body, .. } => { + let action = Self::plan_drive_outbound( + config, + key, + &mut stream.control, + Some(body), + emit, + ); + if let Some(frame) = action { + state.enqueue_control_frame(config, key, &mut stream.control, frame, 0); + } + } + _ => { + let action = + Self::plan_drive_outbound(config, key, &mut stream.control, None, emit); + if let Some(frame) = action { + state.enqueue_control_frame(config, key, &mut stream.control, frame, 0); + } + } + } + } + } + } + + fn plan_drive_outbound( + config: &EngineConfig, + key: StreamKey, + control: &mut StreamControl, + outbound: Option<&mut OutboundState>, + emit: &mut impl OutputFn, + ) -> Option { + let stream_id = key.stream_id; + if control.awaiting.is_some() { + return None; + } + if let Some(frame) = control.pending.take_next_control(stream_id) { + return Some(frame); + } + let outbound = outbound?; + if outbound.can_request_data() { + let max_len = (outbound.remote_max_offset - outbound.sent_offset) + .min(config.max_payload_bytes as u64) as usize; + if max_len > 0 { + outbound.pending_pull = Some(PendingPull { + offset: outbound.sent_offset, + max_len, + }); + emit(EngineOutput::NeedOutboundData { + stream_id, + dir: outbound.dir, + offset: outbound.sent_offset, + max_len, + }); + } + return None; + } + if outbound.data_enabled + && !outbound.closed + && outbound.pending_pull.is_none() + && outbound + .final_offset + .is_some_and(|final_offset| final_offset == outbound.sent_offset) + { + outbound.closed = true; + return Some(StreamFrame::Finish(StreamFrameFinish { + stream_id, + dir: outbound.dir, + })); + } + None + } + + fn queue_credit(stream: &mut StreamState, dir: Direction) { + let stream_id = stream.key().stream_id; + let (recv_offset, max_offset) = { + let Some(inbound) = stream.inbound_mut(dir) else { + return; + }; + (inbound.next_offset, inbound.max_offset) + }; + stream.control_mut().pending.set_credit(StreamFrameCredit { + stream_id, + dir, + recv_offset, + max_offset, + }); + } + + fn queue_protocol_reset(stream: &mut StreamState, emit: &mut impl OutputFn) { + let stream_id = stream.key().stream_id; + stream + .control_mut() + .pending + .set_reset(ResetTarget::Both, ResetCode::Protocol); + for dir in [Direction::Request, Direction::Response] { + if let Some(outbound) = stream.outbound_mut(dir) { + outbound.closed = true; + outbound.pending_pull = None; + emit(EngineOutput::OutboundFailed { + stream_id, + dir, + error: QlError::StreamProtocol, + }); + } + if let Some(inbound) = stream.inbound_mut(dir) { + if !inbound.closed { + inbound.closed = true; + emit(EngineOutput::InboundFailed { + stream_id, + dir, + error: QlError::StreamProtocol, + }); + } + } + } + if let StreamState::Initiator(stream) = stream { + match &mut stream.accept { + InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { + if let Some(open_id) = waiter.open_id.take() { + emit(EngineOutput::OpenFailed { + open_id, + stream_id, + error: QlError::StreamProtocol, + }); + } + } + InitiatorAccept::Open { .. } => {} + } + } + } + + fn note_setup_seen_from_remote(stream: &mut StreamState) { + if let StreamState::Responder(stream) = stream { + if matches!( + stream + .control + .awaiting + .as_ref() + .map(|awaiting| &awaiting.frame), + Some(AwaitingFrame::Control(StreamFrame::Accept(_))) + ) { + stream.control.awaiting = None; + if let ResponderResponse::Accepted { body, .. } = &mut stream.response { + body.data_enabled = true; + } + } + if matches!( + stream + .control + .awaiting + .as_ref() + .map(|awaiting| &awaiting.frame), + Some(AwaitingFrame::Control(StreamFrame::Reject(_))) + ) { + stream.control.awaiting = None; + } + } + } + + fn apply_remote_reset( + stream: &mut StreamState, + dir: ResetTarget, + code: ResetCode, + emit: &mut impl OutputFn, + ) { + let stream_id = stream.key().stream_id; + let request_error = QlError::StreamReset { + dir: Direction::Request, + code, + }; + let response_error = QlError::StreamReset { + dir: Direction::Response, + code, + }; + + if matches!(dir, ResetTarget::Request | ResetTarget::Both) { + if let Some(inbound) = stream.inbound_mut(Direction::Request) { + if !inbound.closed { + inbound.closed = true; + emit(EngineOutput::InboundFailed { + stream_id, + dir: Direction::Request, + error: request_error.clone(), + }); + } + } + if let Some(outbound) = stream.outbound_mut(Direction::Request) { + outbound.closed = true; + outbound.pending_pull = None; + emit(EngineOutput::OutboundFailed { + stream_id, + dir: Direction::Request, + error: request_error.clone(), + }); + } + } + if matches!(dir, ResetTarget::Response | ResetTarget::Both) { + if let Some(inbound) = stream.inbound_mut(Direction::Response) { + if !inbound.closed { + inbound.closed = true; + emit(EngineOutput::InboundFailed { + stream_id, + dir: Direction::Response, + error: response_error.clone(), + }); + } + } + if let Some(outbound) = stream.outbound_mut(Direction::Response) { + outbound.closed = true; + outbound.pending_pull = None; + emit(EngineOutput::OutboundFailed { + stream_id, + dir: Direction::Response, + error: response_error.clone(), + }); + } + } + + if let StreamState::Initiator(stream) = stream { + match &mut stream.accept { + InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { + if let Some(open_id) = waiter.open_id.take() { + emit(EngineOutput::OpenFailed { + open_id, + stream_id, + error: match dir { + ResetTarget::Request => request_error, + _ => response_error, + }, + }); + } + } + InitiatorAccept::Open { .. } => {} + } + } + } + + fn maybe_reap_stream(&mut self, stream_id: StreamId, emit: &mut impl OutputFn) { + if self + .streams + .get(&stream_id) + .is_some_and(StreamState::can_reap) + { + self.streams.remove(&stream_id); + emit(EngineOutput::StreamReaped { stream_id }); + } + } + + fn stream_matches_open( + &self, + stream: &StreamState, + request_head: &[u8], + response_max_offset: u64, + ) -> bool { + match stream { + StreamState::Responder(state) => match &state.response { + ResponderResponse::Pending { initial_credit } + | ResponderResponse::Accepted { initial_credit, .. } + | ResponderResponse::Rejecting { initial_credit } => { + state.meta.request_head == request_head + && *initial_credit == response_max_offset + } + }, + _ => false, + } + } + + fn send_packet_ack(&mut self, acked_packet: PacketId) { + let packet_id = self.state.next_packet_id(); + let valid_until = wire::now_secs().saturating_add(self.config.packet_expiration.as_secs()); + self.enqueue_stream_body( + None, + None, + false, + true, + StreamBody { + packet_id, + valid_until, + packet_ack: Some(PacketAck { + packet_id: acked_packet, + }), + frame: None, + }, + ); + } + + fn send_ephemeral_reset(&mut self, stream_id: StreamId, dir: ResetTarget, code: ResetCode) { + let packet_id = self.state.next_packet_id(); + let valid_until = wire::now_secs().saturating_add(self.config.packet_expiration.as_secs()); + self.enqueue_stream_body( + None, + None, + false, + true, + StreamBody { + packet_id, + valid_until, + packet_ack: None, + frame: Some(StreamFrame::Reset(StreamFrameReset { + stream_id, + dir, + code, + })), + }, + ); + } + + fn enqueue_handshake_message(&mut self, token: Token, deadline: Instant, bytes: Vec) { + self.state + .enqueue_handshake_message(&self.config, token, deadline, bytes); + } + + fn enqueue_stream_body( + &mut self, + stream_id: Option, + packet_id: Option, + track_ack: bool, + priority: bool, + body: StreamBody, + ) { + self.state.enqueue_stream_body( + &self.config, + stream_id, + packet_id, + track_ack, + priority, + body, + ); + } + + fn handle_hello( + &mut self, + now: Instant, + peer: XID, + hello: &wire::handshake::ArchivedHello, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let action = match self.state.peer.as_ref() { + Some(entry) => match &entry.session { + PeerSession::Initiator { + hello: local_hello, .. + } => { + if peer_hello_wins(local_hello, crypto.xid(), hello, peer) { + HelloAction::StartResponder + } else { + HelloAction::Ignore + } + } + PeerSession::Responder { + hello: stored, + reply, + deadline, + .. + } => { + if stored.nonce == wire::nonce_from_archived(&hello.nonce) { + HelloAction::ResendReply { + reply: reply.clone(), + deadline: *deadline, + } + } else { + HelloAction::StartResponder + } + } + PeerSession::Disconnected | PeerSession::Connected { .. } => { + HelloAction::StartResponder + } + }, + None => return, + }; + + match action { + HelloAction::StartResponder => { + self.start_responder_handshake(now, peer, hello, crypto, emit) + } + HelloAction::ResendReply { reply, deadline } => { + let record = QlRecord { + header: QlHeader { + sender: crypto.xid(), + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), + }; + let token = self.state.next_token(); + self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); + } + HelloAction::Ignore => {} + } + } + + fn handle_hello_reply( + &mut self, + now: Instant, + peer: XID, + reply: &wire::handshake::ArchivedHelloReply, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let token = self.state.next_token(); + let deadline = now + self.config.handshake_timeout; + let res = { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + let PeerSession::Initiator { + hello, + session_key, + stage, + .. + } = &peer_record.session + else { + return; + }; + if *stage != InitiatorStage::WaitingHelloReply { + return; + } + handshake::build_confirm( + crypto, + crypto.xid(), + peer, + &peer_record.signing_key, + hello, + reply, + session_key, + ) + .map(|(confirm, session_key)| (hello.clone(), confirm, session_key)) + }; + let confirm = match res { + Ok((hello, confirm, session_key)) => { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Initiator { + handshake_token: token, + hello, + session_key, + deadline, + stage: InitiatorStage::SendingConfirm, + }; + } + confirm + } + Err(_) => { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + return; + } + }; + + let record = QlRecord { + header: QlHeader { + sender: crypto.xid(), + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm)), + }; + self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); + } + + fn handle_confirm( + &mut self, + now: Instant, + peer: XID, + confirm: &wire::handshake::ArchivedConfirm, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + let PeerSession::Responder { + hello, + reply, + secrets, + .. + } = &peer_record.session + else { + return; + }; + + match handshake::finalize_confirm( + peer, + crypto.xid(), + &peer_record.signing_key, + hello, + reply, + confirm, + secrets, + ) { + Ok(session_key) => { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + }; + } + self.record_activity(now); + self.emit_peer_status(emit); + } + Err(_) => { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + } + } + } + + fn start_responder_handshake( + &mut self, + now: Instant, + peer: XID, + hello: &wire::handshake::ArchivedHello, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let res = { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + handshake::respond_hello( + crypto, + peer, + crypto.xid(), + &peer_record.encapsulation_key, + hello, + ) + }; + let (reply, secrets) = match res { + Ok(result) => result, + Err(_) => { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + return; + } + }; + let Ok(hello) = wire::deserialize_value(hello) else { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + return; + }; + + let deadline = now + self.config.handshake_timeout; + let token = self.state.next_token(); + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Responder { + handshake_token: token, + hello, + reply: reply.clone(), + secrets, + deadline, + }; + } + self.emit_peer_status(emit); + + let record = QlRecord { + header: QlHeader { + sender: crypto.xid(), + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), + }; + self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); + } + + fn send_heartbeat_message(&mut self, now: Instant, crypto: &impl QlCrypto) { + let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { + return; + }; + let packet_id = self.state.next_packet_id(); + let token = self.state.next_token(); + let deadline = now + self.config.packet_expiration; + let message = { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + let PeerSession::Connected { session_key, .. } = &peer_record.session else { + return; + }; + heartbeat::encrypt_heartbeat( + QlHeader { + sender: crypto.xid(), + recipient: peer, + }, + session_key, + HeartbeatBody { + packet_id, + valid_until: wire::now_secs() + .saturating_add(self.config.packet_expiration.as_secs()), + }, + next_encrypted_message_nonce(crypto), + ) + }; + self.enqueue_handshake_message(token, deadline, wire::encode_record(&message)); + } + + fn keep_alive_config(&self) -> Option { + self.config + .keep_alive + .filter(|config| !config.interval.is_zero() && !config.timeout.is_zero()) + } + + fn record_activity(&mut self, now: Instant) { + let Some(config) = self.keep_alive_config() else { + return; + }; + let token = self.state.next_token(); + let Some(entry) = self.state.peer.as_mut() else { + return; + }; + let PeerSession::Connected { keepalive, .. } = &mut entry.session else { + return; + }; + keepalive.last_activity = Some(now); + keepalive.pending = false; + keepalive.token = token; + self.state.timeouts.push(Reverse(TimeoutEntry { + at: now + config.interval, + kind: TimeoutKind::KeepAliveSend { token }, + })); + } + + fn record_stream_activity(&mut self, stream_id: StreamId, now: Instant) { + if let Some(stream) = self.streams.get_mut(&stream_id) { + *stream.last_activity_mut() = now; + } + } + + fn drop_outbound(&mut self, emit: &mut impl OutputFn) { + while let Some(message) = self.state.outbound.pop_front() { + if let Some(stream_id) = message.stream_id { + self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); + } + } + } + + fn abort_streams(&mut self, error: QlError, emit: &mut impl OutputFn) { + let streams = mem::take(&mut self.streams); + for (stream_id, stream) in streams { + self.fail_stream(stream_id, stream, error.clone(), emit); + } + } + + fn fail_stream_by_id(&mut self, stream_id: StreamId, error: QlError, emit: &mut impl OutputFn) { + let Some(stream) = self.streams.remove(&stream_id) else { + return; + }; + self.fail_stream(stream_id, stream, error, emit); + } + + fn fail_stream( + &mut self, + stream_id: StreamId, + stream: StreamState, + error: QlError, + emit: &mut impl OutputFn, + ) { + match stream { + StreamState::Initiator(stream) => { + match stream.accept { + InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { + if let Some(open_id) = waiter.open_id { + emit(EngineOutput::OpenFailed { + open_id, + stream_id, + error: error.clone(), + }); + } + } + InitiatorAccept::Open { .. } => {} + } + emit(EngineOutput::OutboundFailed { + stream_id, + dir: Direction::Request, + error: error.clone(), + }); + emit(EngineOutput::InboundFailed { + stream_id, + dir: Direction::Response, + error, + }); + } + StreamState::Responder(stream) => { + emit(EngineOutput::InboundFailed { + stream_id, + dir: Direction::Request, + error: error.clone(), + }); + if matches!(stream.response, ResponderResponse::Accepted { .. }) { + emit(EngineOutput::OutboundFailed { + stream_id, + dir: Direction::Response, + error, + }); + } + } + } + emit(EngineOutput::StreamReaped { stream_id }); + } + + fn unpair_peer(&mut self, emit: &mut impl OutputFn) { + let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { + return; + }; + self.drop_outbound(emit); + self.abort_streams(QlError::SendFailed, emit); + self.state.replay_cache.clear_peer(peer); + self.state.peer = None; + emit(EngineOutput::PeerStatusChanged { + peer, + session: PeerSession::Disconnected, + }); + emit(EngineOutput::ClearPeer); + } + + fn handle_timeouts(&mut self, now: Instant, crypto: &impl QlCrypto, emit: &mut impl OutputFn) { + loop { + let Some(entry) = self + .state + .timeouts + .peek_mut() + .filter(|entry| entry.0.at <= now) + else { + break; + }; + let entry = std::collections::binary_heap::PeekMut::pop(entry).0; + match entry.kind { + TimeoutKind::Outbound { token } => { + let mut timed_out_stream = None; + self.state.outbound.retain(|message| { + if message.token == token { + timed_out_stream = message.stream_id; + false + } else { + true + } + }); + if let Some(stream_id) = timed_out_stream { + self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); + } + } + TimeoutKind::Handshake { token } => { + let Some(entry) = self.state.peer.as_ref() else { + continue; + }; + let should_disconnect = matches!( + &entry.session, + PeerSession::Initiator { handshake_token, .. } | PeerSession::Responder { handshake_token, .. } + if *handshake_token == token + ); + if should_disconnect { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + self.drop_outbound(emit); + self.abort_streams(QlError::SendFailed, emit); + } + } + TimeoutKind::KeepAliveSend { token } => { + let Some(config) = self.keep_alive_config() else { + continue; + }; + let should_send = { + let Some(entry) = self.state.peer.as_ref() else { + continue; + }; + let PeerSession::Connected { keepalive, .. } = &entry.session else { + continue; + }; + keepalive.token == token && !keepalive.pending + }; + if should_send { + self.send_heartbeat_message(now, crypto); + } + if let Some(entry) = self.state.peer.as_mut() { + if let PeerSession::Connected { keepalive, .. } = &mut entry.session { + if keepalive.token == token { + keepalive.pending = true; + } + } + } + self.state.timeouts.push(Reverse(TimeoutEntry { + at: now + config.timeout, + kind: TimeoutKind::KeepAliveTimeout { token }, + })); + } + TimeoutKind::KeepAliveTimeout { token } => { + let Some(entry) = self.state.peer.as_ref() else { + continue; + }; + let should_disconnect = matches!(&entry.session, PeerSession::Connected { keepalive, .. } if keepalive.token == token && keepalive.pending); + if should_disconnect { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + self.drop_outbound(emit); + self.abort_streams(QlError::SendFailed, emit); + } + } + TimeoutKind::StreamOpen { stream_id, token } => { + let should_fail = self + .streams + .get(&stream_id) + .and_then(StreamState::open_timeout_token) + .is_some_and(|stream_token| stream_token == token); + if should_fail { + self.fail_stream_by_id(stream_id, QlError::Timeout, emit); + } + } + TimeoutKind::StreamPacket { + stream_id, + packet_id, + attempt, + } => { + let mut timed_out = false; + let mut retransmit_control = None; + let mut retransmit_data = None; + { + let Some(stream) = self.streams.get_mut(&stream_id) else { + continue; + }; + let Some(retransmit) = + stream.control().awaiting.as_ref().and_then(|awaiting| { + if awaiting.packet_id != packet_id || awaiting.attempt != attempt { + return None; + } + Some(match &awaiting.frame { + AwaitingFrame::Control(frame) => { + EitherRetransmit::Control(frame.clone()) + } + AwaitingFrame::Data { dir, offset, len } => { + EitherRetransmit::Data { + dir: *dir, + offset: *offset, + len: *len, + } + } + }) + }) + else { + continue; + }; + + if attempt >= self.config.stream_retry_limit { + timed_out = true; + } else { + match retransmit { + EitherRetransmit::Control(frame) => { + retransmit_control = Some(frame) + } + EitherRetransmit::Data { dir, offset, len } => { + retransmit_data = Some((dir, offset, len)) + } + } + } + } + if timed_out { + self.fail_stream_by_id(stream_id, QlError::Timeout, emit); + } else if let Some(frame) = retransmit_control { + let (streams, state) = (&mut self.streams, &mut self.state); + if let Some(stream) = streams.get_mut(&stream_id) { + let key = stream.key(); + state.enqueue_control_frame( + &self.config, + key, + stream.control_mut(), + frame, + attempt.saturating_add(1), + ); + } + } else if let Some((dir, offset, len)) = retransmit_data { + if let Some(stream) = self.streams.get_mut(&stream_id) { + if let Some(outbound) = stream.outbound_mut(dir) { + outbound.pending_pull = Some(PendingPull { + offset, + max_len: len, + }); + emit(EngineOutput::NeedOutboundData { + stream_id, + dir, + offset, + max_len: len, + }); + } + } + } + } + } + } + } + + fn handle_write_done( + &mut self, + now: Instant, + token: Token, + tracked: Option, + result: Result<(), QlError>, + emit: &mut impl OutputFn, + ) { + if self.state.write_in_flight == Some(token) { + self.state.write_in_flight = None; + } + if let Err(error) = result { + if let Some(tracked) = tracked { + self.fail_stream_by_id(tracked.stream_id, error.clone(), emit); + } + let should_disconnect = matches!(self.state.peer.as_ref().map(|entry| &entry.session), + Some(PeerSession::Initiator { handshake_token, .. }) if *handshake_token == token) + || matches!(self.state.peer.as_ref().map(|entry| &entry.session), + Some(PeerSession::Responder { handshake_token, .. }) if *handshake_token == token); + if should_disconnect { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + self.drop_outbound(emit); + self.abort_streams(error, emit); + } + return; + } + + let connected = self + .state + .peer + .as_ref() + .and_then(|entry| match &entry.session { + PeerSession::Initiator { + session_key, + handshake_token, + stage: InitiatorStage::SendingConfirm, + .. + } if *handshake_token == token => Some(session_key.clone()), + _ => None, + }); + if let Some(session_key) = connected { + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + }; + } + self.emit_peer_status(emit); + self.record_activity(now); + } + + if let Some(tracked) = tracked { + let attempt = self + .streams + .get(&tracked.stream_id) + .and_then(|stream| stream.control().awaiting.as_ref()) + .and_then(|awaiting| { + (awaiting.packet_id == tracked.packet_id).then_some(awaiting.attempt) + }) + .unwrap_or(0); + self.state.timeouts.push(Reverse(TimeoutEntry { + at: now + self.config.packet_ack_timeout, + kind: TimeoutKind::StreamPacket { + stream_id: tracked.stream_id, + packet_id: tracked.packet_id, + attempt, + }, + })); + } + } + + fn maybe_start_next_write(&mut self, crypto: &impl QlCrypto, emit: &mut impl OutputFn) { + if self.state.write_in_flight.is_some() { + return; + } + while let Some(message) = self.state.outbound.pop_front() { + let bytes = match message.payload { + QueuedPayload::PreEncoded(bytes) => bytes, + QueuedPayload::StreamBody(body) => { + let Some(peer) = self.state.peer.as_ref() else { + if let Some(stream_id) = message.stream_id { + self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); + } + continue; + }; + let Some(session_key) = peer.session.session_key() else { + if let Some(stream_id) = message.stream_id { + self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); + } + continue; + }; + let record = encrypt_stream( + QlHeader { + sender: crypto.xid(), + recipient: peer.peer, + }, + session_key, + body, + next_encrypted_message_nonce(crypto), + ); + wire::encode_record(&record) + } + }; + + let tracked = if message.track_ack { + message + .stream_id + .zip(message.packet_id) + .map(|(stream_id, packet_id)| TrackedWrite { + stream_id, + packet_id, + }) + } else { + None + }; + self.state.write_in_flight = Some(message.token); + emit(EngineOutput::WriteMessage { + token: message.token, + tracked, + bytes, + }); + break; + } + } +} + +fn next_encrypted_message_nonce(crypto: &impl QlCrypto) -> [u8; NONCE_SIZE] { + let mut nonce = [0u8; NONCE_SIZE]; + crypto.fill_random_bytes(&mut nonce); + nonce +} + +fn peer_hello_wins( + local_hello: &Hello, + local_sender: XID, + peer_hello: &wire::handshake::ArchivedHello, + peer_sender: XID, +) -> bool { + use std::cmp::Ordering; + + let peer_nonce = wire::nonce_from_archived(&peer_hello.nonce); + match peer_nonce.data().cmp(local_hello.nonce.data()) { + Ordering::Less => true, + Ordering::Greater => false, + Ordering::Equal => peer_sender.data().cmp(local_sender.data()) == Ordering::Less, + } +} + +fn stream_id_from_frame(frame: &StreamFrame) -> StreamId { + match frame { + StreamFrame::Open(frame) => frame.stream_id, + StreamFrame::Accept(frame) => frame.stream_id, + StreamFrame::Reject(frame) => frame.stream_id, + StreamFrame::Data(frame) => frame.stream_id, + StreamFrame::Credit(frame) => frame.stream_id, + StreamFrame::Finish(frame) => frame.stream_id, + StreamFrame::Reset(frame) => frame.stream_id, + } +} + +fn reset_target_for_dir(dir: Direction) -> ResetTarget { + match dir { + Direction::Request => ResetTarget::Request, + Direction::Response => ResetTarget::Response, + } +} diff --git a/ql2/src/engine/replay_cache.rs b/ql2/src/engine/replay_cache.rs new file mode 100644 index 00000000..292f1740 --- /dev/null +++ b/ql2/src/engine/replay_cache.rs @@ -0,0 +1,189 @@ +use std::{ + cmp::Reverse, + collections::{binary_heap::PeekMut, BinaryHeap, HashSet}, + time::{SystemTime, UNIX_EPOCH}, +}; + +use bc_components::XID; + +use crate::PacketId; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ReplayNamespace { + Peer, + Local, + Transfer, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ReplayKey { + pub peer: XID, + pub namespace: ReplayNamespace, + pub packet_id: PacketId, +} + +impl ReplayKey { + pub const fn new(peer: XID, namespace: ReplayNamespace, packet_id: PacketId) -> Self { + Self { + peer, + namespace, + packet_id, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ExpiryEntry { + expires_at: u64, + key: ReplayKey, +} + +impl Ord for ExpiryEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.expires_at + .cmp(&other.expires_at) + .then_with(|| self.key.cmp(&other.key)) + } +} + +impl PartialOrd for ExpiryEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +#[derive(Debug, Default)] +pub struct ReplayCache { + entries: HashSet, + expirations: BinaryHeap>, +} + +impl ReplayCache { + pub fn new() -> Self { + Self { + entries: HashSet::new(), + expirations: BinaryHeap::new(), + } + } + + pub fn len(&self) -> usize { + self.entries.len() + } + + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + pub fn add(&mut self, key: ReplayKey, expires_at: u64) { + if self.entries.insert(key) { + self.expirations + .push(Reverse(ExpiryEntry { expires_at, key })); + } + } + + pub fn check_and_store(&mut self, key: ReplayKey, expires_at: u64) -> bool { + let now_secs = now_secs(); + self.check_and_store_at(key, expires_at, now_secs) + } + + pub fn check_and_store_valid_until(&mut self, key: ReplayKey, valid_until: u64) -> bool { + let now_secs = now_secs(); + self.check_and_store_at(key, valid_until, now_secs) + } + + pub fn purge_expired(&mut self) { + let now_secs = now_secs(); + self.purge_expired_at(now_secs); + } + + pub fn clear_peer(&mut self, peer: XID) { + self.entries.retain(|entry| entry.peer != peer); + self.expirations.retain(|entry| entry.0.key.peer != peer); + } + + fn check_and_store_at(&mut self, key: ReplayKey, expires_at: u64, now_secs: u64) -> bool { + self.purge_expired_at(now_secs); + if self.entries.contains(&key) { + return true; + } + self.entries.insert(key); + self.expirations + .push(Reverse(ExpiryEntry { expires_at, key })); + false + } + + fn purge_expired_at(&mut self, now_secs: u64) { + while let Some(entry) = self.expirations.peek_mut() { + if entry.0.expires_at > now_secs { + break; + } + let entry = PeekMut::pop(entry).0; + self.entries.remove(&entry.key); + } + } +} + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn peer_with_byte(byte: u8) -> XID { + XID::from_data([byte; XID::XID_SIZE]) + } + + #[test] + fn check_and_store_detects_replay() { + let mut cache = ReplayCache::new(); + let peer = peer_with_byte(1); + let key = ReplayKey::new(peer, ReplayNamespace::Peer, PacketId(1)); + let now_secs = 100; + let expires_at = 110; + + assert!(!cache.check_and_store_at(key, expires_at, now_secs)); + assert!(cache.check_and_store_at(key, expires_at, now_secs)); + } + + #[test] + fn purge_expired_removes_old_entries() { + let mut cache = ReplayCache::new(); + let now_secs = 100; + let expired_at = 99; + let future_at = 110; + + let key_old = ReplayKey::new(peer_with_byte(2), ReplayNamespace::Peer, PacketId(2)); + let key_new = ReplayKey::new(peer_with_byte(3), ReplayNamespace::Peer, PacketId(3)); + + cache.add(key_old, expired_at); + cache.add(key_new, future_at); + + cache.purge_expired_at(now_secs); + assert_eq!(cache.len(), 1); + assert!(!cache.check_and_store_at(key_old, future_at, now_secs)); + } + + #[test] + fn clear_peer_removes_peer_entries() { + let mut cache = ReplayCache::new(); + let now_secs = 100; + let expires_at = 110; + + let peer_a = peer_with_byte(4); + let peer_b = peer_with_byte(5); + let key_a = ReplayKey::new(peer_a, ReplayNamespace::Peer, PacketId(4)); + let key_b = ReplayKey::new(peer_b, ReplayNamespace::Peer, PacketId(5)); + + cache.add(key_a, expires_at); + cache.add(key_b, expires_at); + + cache.clear_peer(peer_a); + assert_eq!(cache.len(), 1); + assert!(!cache.check_and_store_at(key_a, expires_at, now_secs)); + } +} diff --git a/ql2/src/engine/state.rs b/ql2/src/engine/state.rs new file mode 100644 index 00000000..33055174 --- /dev/null +++ b/ql2/src/engine/state.rs @@ -0,0 +1,507 @@ +use std::{ + cell::Cell, + cmp::Reverse, + collections::{BinaryHeap, HashMap, VecDeque}, + time::Instant, +}; + +use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey, XID}; + +use super::{ + replay_cache::ReplayCache, + stream::{AwaitingFrame, AwaitingPacket, QueuedWrite, StreamControl, StreamKey, StreamState}, + EngineConfig, +}; +use crate::{ + runtime::StreamConfig, + wire::{ + handshake::{Hello, HelloReply, ResponderSecrets}, + stream::{Direction, RejectCode, ResetCode, StreamBody, StreamFrame, StreamFrameData}, + }, + PacketId, Peer, QlError, StreamId, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Token(pub u64); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct OpenId(pub u64); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TrackedWrite { + pub stream_id: StreamId, + pub packet_id: PacketId, +} + +#[derive(Debug, Clone)] +pub struct KeepAliveState { + pub token: Token, + pub pending: bool, + pub last_activity: Option, +} + +impl Default for KeepAliveState { + fn default() -> Self { + Self { + token: Token(0), + pending: false, + last_activity: None, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InitiatorStage { + WaitingHelloReply, + SendingConfirm, +} + +#[derive(Debug, Clone)] +pub enum PeerSession { + Disconnected, + Initiator { + handshake_token: Token, + hello: Hello, + session_key: SymmetricKey, + deadline: Instant, + stage: InitiatorStage, + }, + Responder { + handshake_token: Token, + hello: Hello, + reply: HelloReply, + secrets: ResponderSecrets, + deadline: Instant, + }, + Connected { + session_key: SymmetricKey, + keepalive: KeepAliveState, + }, +} + +impl PeerSession { + pub fn is_connected(&self) -> bool { + matches!(self, Self::Connected { .. }) + } + + pub fn session_key(&self) -> Option<&SymmetricKey> { + match self { + Self::Connected { session_key, .. } => Some(session_key), + _ => None, + } + } +} + +#[derive(Debug, Clone)] +pub struct PeerRecord { + pub peer: XID, + pub signing_key: MLDSAPublicKey, + pub encapsulation_key: MLKEMPublicKey, + pub session: PeerSession, +} + +impl PeerRecord { + pub fn new(peer: XID, signing_key: MLDSAPublicKey, encapsulation_key: MLKEMPublicKey) -> Self { + Self { + peer, + signing_key, + encapsulation_key, + session: PeerSession::Disconnected, + } + } + + pub fn snapshot(&self) -> Peer { + Peer { + peer: self.peer, + signing_key: self.signing_key.clone(), + encapsulation_key: self.encapsulation_key.clone(), + } + } +} + +#[derive(Debug)] +pub enum EngineInput { + BindPeer(Peer), + Pair, + Connect, + Unpair, + + OpenStream { + open_id: OpenId, + request_head: Vec, + config: StreamConfig, + }, + AcceptStream { + stream_id: StreamId, + response_head: Vec, + }, + RejectStream { + stream_id: StreamId, + code: RejectCode, + }, + + OutboundData { + stream_id: StreamId, + dir: Direction, + offset: u64, + bytes: Vec, + }, + OutboundFinished { + stream_id: StreamId, + dir: Direction, + final_offset: u64, + }, + InboundConsumed { + stream_id: StreamId, + dir: Direction, + amount: u64, + }, + + ResetOutbound { + stream_id: StreamId, + dir: Direction, + code: ResetCode, + }, + ResetInbound { + stream_id: StreamId, + dir: Direction, + code: ResetCode, + }, + PendingAcceptDropped { + stream_id: StreamId, + }, + ResponderDropped { + stream_id: StreamId, + }, + + Incoming(Vec), + WriteCompleted { + token: Token, + tracked: Option, + result: Result<(), QlError>, + }, + TimerExpired, +} + +#[derive(Debug)] +pub enum EngineOutput { + SetTimer(Option), + WriteMessage { + token: Token, + tracked: Option, + bytes: Vec, + }, + + PeerStatusChanged { + peer: XID, + session: PeerSession, + }, + PersistPeer(Peer), + ClearPeer, + + OpenStarted { + open_id: OpenId, + stream_id: StreamId, + }, + OpenAccepted { + open_id: OpenId, + stream_id: StreamId, + response_head: Vec, + }, + OpenFailed { + open_id: OpenId, + stream_id: StreamId, + error: QlError, + }, + + InboundStreamOpened { + stream_id: StreamId, + request_head: Vec, + }, + InboundData { + stream_id: StreamId, + dir: Direction, + bytes: Vec, + }, + InboundFinished { + stream_id: StreamId, + dir: Direction, + }, + InboundFailed { + stream_id: StreamId, + dir: Direction, + error: QlError, + }, + + NeedOutboundData { + stream_id: StreamId, + dir: Direction, + offset: u64, + max_len: usize, + }, + ReleaseOutboundThrough { + stream_id: StreamId, + dir: Direction, + recv_offset: u64, + }, + OutboundClosed { + stream_id: StreamId, + dir: Direction, + }, + OutboundFailed { + stream_id: StreamId, + dir: Direction, + error: QlError, + }, + + StreamReaped { + stream_id: StreamId, + }, +} + +pub trait OutputFn: FnMut(EngineOutput) {} + +impl OutputFn for T where T: FnMut(EngineOutput) {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TimeoutKind { + Outbound { + token: Token, + }, + Handshake { + token: Token, + }, + KeepAliveSend { + token: Token, + }, + KeepAliveTimeout { + token: Token, + }, + StreamOpen { + stream_id: StreamId, + token: Token, + }, + StreamPacket { + stream_id: StreamId, + packet_id: PacketId, + attempt: u8, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TimeoutEntry { + pub at: Instant, + pub kind: TimeoutKind, +} + +impl Ord for TimeoutEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.at.cmp(&other.at) + } +} + +impl PartialOrd for TimeoutEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +#[derive(Debug)] +pub enum HelloAction { + StartResponder, + ResendReply { + reply: HelloReply, + deadline: Instant, + }, + Ignore, +} + +pub struct Engine { + pub config: EngineConfig, + pub state: EngineState, + pub streams: HashMap, +} + +pub struct EngineState { + pub peer: Option, + pub replay_cache: ReplayCache, + + pub next_token: Cell, + pub next_packet_id: Cell, + pub next_stream_id: Cell, + pub outbound: VecDeque, + pub timeouts: BinaryHeap>, + pub write_in_flight: Option, +} + +impl EngineState { + pub fn new(peer: Option) -> Self { + Self { + peer: peer + .map(|peer| PeerRecord::new(peer.peer, peer.signing_key, peer.encapsulation_key)), + replay_cache: ReplayCache::new(), + next_token: Cell::new(1), + next_packet_id: Cell::new(1), + next_stream_id: Cell::new(1), + outbound: VecDeque::new(), + timeouts: BinaryHeap::new(), + write_in_flight: None, + } + } + + pub fn next_deadline(&self) -> Option { + self.timeouts.peek().map(|entry| entry.0.at) + } + + pub fn next_token(&self) -> Token { + let token = self.next_token.get(); + self.next_token.set(token.wrapping_add(1)); + Token(token) + } + + pub fn next_packet_id(&self) -> PacketId { + let id = self.next_packet_id.get(); + self.next_packet_id.set(id.wrapping_add(1)); + PacketId(id) + } + + pub fn next_stream_id(&self) -> StreamId { + let id = self.next_stream_id.get(); + self.next_stream_id.set(id.wrapping_add(1)); + StreamId(id) + } + + pub fn enqueue_handshake_message( + &mut self, + _config: &EngineConfig, + token: Token, + deadline: Instant, + bytes: Vec, + ) { + self.outbound.push_back(QueuedWrite { + token, + stream_id: None, + packet_id: None, + track_ack: false, + payload: super::stream::QueuedPayload::PreEncoded(bytes), + }); + self.timeouts.push(Reverse(TimeoutEntry { + at: deadline, + kind: TimeoutKind::Handshake { token }, + })); + self.timeouts.push(Reverse(TimeoutEntry { + at: deadline, + kind: TimeoutKind::Outbound { token }, + })); + } + + pub fn enqueue_stream_body( + &mut self, + config: &EngineConfig, + stream_id: Option, + packet_id: Option, + track_ack: bool, + priority: bool, + body: StreamBody, + ) { + let token = self.next_token(); + let message = QueuedWrite { + token, + stream_id, + packet_id, + track_ack, + payload: super::stream::QueuedPayload::StreamBody(body), + }; + if priority { + self.outbound.push_front(message); + } else { + self.outbound.push_back(message); + } + self.timeouts.push(Reverse(TimeoutEntry { + at: Instant::now() + config.packet_expiration, + kind: TimeoutKind::Outbound { token }, + })); + } + + pub fn enqueue_control_frame( + &mut self, + config: &EngineConfig, + key: StreamKey, + control: &mut StreamControl, + frame: StreamFrame, + attempt: u8, + ) { + let packet_id = self.next_packet_id(); + control.awaiting = Some(AwaitingPacket { + packet_id, + frame: AwaitingFrame::Control(frame.clone()), + attempt, + }); + let valid_until = + crate::wire::now_secs().saturating_add(config.packet_expiration.as_secs()); + self.enqueue_stream_body( + config, + Some(key.stream_id), + Some(packet_id), + true, + false, + StreamBody { + packet_id, + valid_until, + packet_ack: None, + frame: Some(frame), + }, + ); + } + + pub fn enqueue_data_frame( + &mut self, + config: &EngineConfig, + key: StreamKey, + control: &mut StreamControl, + dir: Direction, + offset: u64, + bytes: Vec, + attempt: u8, + ) { + let packet_id = self.next_packet_id(); + control.awaiting = Some(AwaitingPacket { + packet_id, + frame: AwaitingFrame::Data { + dir, + offset, + len: bytes.len(), + }, + attempt, + }); + let valid_until = + crate::wire::now_secs().saturating_add(config.packet_expiration.as_secs()); + self.enqueue_stream_body( + config, + Some(key.stream_id), + Some(packet_id), + true, + false, + StreamBody { + packet_id, + valid_until, + packet_ack: None, + frame: Some(StreamFrame::Data(StreamFrameData { + stream_id: key.stream_id, + dir, + offset, + bytes, + })), + }, + ); + } +} + +pub enum EitherRetransmit { + Control(StreamFrame), + Data { + dir: Direction, + offset: u64, + len: usize, + }, +} diff --git a/ql2/src/engine/stream.rs b/ql2/src/engine/stream.rs new file mode 100644 index 00000000..5a371b96 --- /dev/null +++ b/ql2/src/engine/stream.rs @@ -0,0 +1,302 @@ +use std::time::Instant; + +use super::{OpenId, Token}; +use crate::{ + wire::stream::{ + Direction, ResetCode, ResetTarget, StreamBody, StreamFrame, StreamFrameAccept, + StreamFrameCredit, StreamFrameOpen, StreamFrameReject, StreamFrameReset, + }, + PacketId, StreamId, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamKey { + pub stream_id: StreamId, +} + +#[derive(Debug)] +pub struct StreamMeta { + pub key: StreamKey, + pub request_head: Vec, + pub last_activity: Instant, +} + +#[derive(Debug)] +pub struct PendingPull { + pub offset: u64, + pub max_len: usize, +} + +#[derive(Debug)] +pub struct OutboundState { + pub dir: Direction, + pub remote_max_offset: u64, + pub sent_offset: u64, + pub released_offset: u64, + pub final_offset: Option, + pub data_enabled: bool, + pub closed: bool, + pub pending_pull: Option, +} + +impl OutboundState { + pub fn new(dir: Direction, remote_max_offset: u64, data_enabled: bool) -> Self { + Self { + dir, + remote_max_offset, + sent_offset: 0, + released_offset: 0, + final_offset: None, + data_enabled, + closed: false, + pending_pull: None, + } + } + + pub fn can_request_data(&self) -> bool { + self.data_enabled + && !self.closed + && self.pending_pull.is_none() + && self.sent_offset < self.remote_max_offset + && self + .final_offset + .is_none_or(|final_offset| self.sent_offset < final_offset) + } +} + +#[derive(Debug)] +pub struct InboundState { + pub next_offset: u64, + pub max_offset: u64, + pub closed: bool, +} + +impl InboundState { + pub fn new(max_offset: u64) -> Self { + Self { + next_offset: 0, + max_offset, + closed: false, + } + } +} + +#[derive(Debug)] +pub struct OpenWaiter { + pub open_id: Option, + pub open_timeout_token: Token, +} + +#[derive(Debug)] +pub enum InitiatorAccept { + Opening(OpenWaiter), + WaitingAccept(OpenWaiter), + Open { response_head: Vec }, +} + +#[derive(Debug)] +pub struct InitiatorStream { + pub meta: StreamMeta, + pub control: StreamControl, + pub request: OutboundState, + pub response: InboundState, + pub accept: InitiatorAccept, +} + +#[derive(Debug)] +pub enum ResponderResponse { + Pending { + initial_credit: u64, + }, + Accepted { + initial_credit: u64, + body: OutboundState, + }, + Rejecting { + initial_credit: u64, + }, +} + +#[derive(Debug)] +pub struct ResponderStream { + pub meta: StreamMeta, + pub control: StreamControl, + pub request: InboundState, + pub response: ResponderResponse, +} + +#[derive(Debug)] +pub enum StreamState { + Initiator(InitiatorStream), + Responder(ResponderStream), +} + +impl StreamState { + pub fn key(&self) -> StreamKey { + match self { + Self::Initiator(state) => state.meta.key, + Self::Responder(state) => state.meta.key, + } + } + + pub fn last_activity_mut(&mut self) -> &mut Instant { + match self { + Self::Initiator(state) => &mut state.meta.last_activity, + Self::Responder(state) => &mut state.meta.last_activity, + } + } + + pub fn control(&self) -> &StreamControl { + match self { + Self::Initiator(state) => &state.control, + Self::Responder(state) => &state.control, + } + } + + pub fn control_mut(&mut self) -> &mut StreamControl { + match self { + Self::Initiator(state) => &mut state.control, + Self::Responder(state) => &mut state.control, + } + } + + pub fn outbound_mut(&mut self, dir: Direction) -> Option<&mut OutboundState> { + match self { + Self::Initiator(state) if dir == Direction::Request => Some(&mut state.request), + Self::Responder(state) if dir == Direction::Response => match &mut state.response { + ResponderResponse::Accepted { body, .. } => Some(body), + _ => None, + }, + _ => None, + } + } + + pub fn inbound_mut(&mut self, dir: Direction) -> Option<&mut InboundState> { + match self { + Self::Initiator(state) if dir == Direction::Response => Some(&mut state.response), + Self::Responder(state) if dir == Direction::Request => Some(&mut state.request), + _ => None, + } + } + + pub fn open_timeout_token(&self) -> Option { + match self { + Self::Initiator(state) => match &state.accept { + InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { + Some(waiter.open_timeout_token) + } + InitiatorAccept::Open { .. } => None, + }, + _ => None, + } + } + + pub fn can_reap(&self) -> bool { + if self.control().awaiting.is_some() || !self.control().pending.is_empty() { + return false; + } + match self { + Self::Initiator(state) => { + matches!(state.accept, InitiatorAccept::Open { .. }) + && state.request.closed + && state.response.closed + } + Self::Responder(state) => match &state.response { + ResponderResponse::Accepted { body, .. } => state.request.closed && body.closed, + ResponderResponse::Rejecting { .. } => true, + ResponderResponse::Pending { .. } => false, + }, + } + } +} + +#[derive(Debug)] +pub struct AwaitingPacket { + pub packet_id: PacketId, + pub frame: AwaitingFrame, + pub attempt: u8, +} + +#[derive(Debug, Clone)] +pub enum AwaitingFrame { + Control(StreamFrame), + Data { + dir: Direction, + offset: u64, + len: usize, + }, +} + +#[derive(Debug)] +pub enum SetupFrame { + Open(StreamFrameOpen), + Accept(StreamFrameAccept), + Reject(StreamFrameReject), +} + +#[derive(Debug, Default)] +pub struct PendingFrames { + pub setup: Option, + pub credit: Option, + pub reset: Option, +} + +impl PendingFrames { + pub fn take_next_control(&mut self, stream_id: StreamId) -> Option { + if let Some(setup) = self.setup.take() { + return Some(match setup { + SetupFrame::Open(frame) => StreamFrame::Open(frame), + SetupFrame::Accept(frame) => StreamFrame::Accept(frame), + SetupFrame::Reject(frame) => StreamFrame::Reject(frame), + }); + } + if let Some(reset) = self.reset.take() { + return Some(StreamFrame::Reset(StreamFrameReset { stream_id, ..reset })); + } + self.credit.take().map(StreamFrame::Credit) + } + + pub fn set_setup(&mut self, setup: SetupFrame) { + self.setup = Some(setup); + } + + pub fn set_credit(&mut self, frame: StreamFrameCredit) { + if self.reset.is_none() { + self.credit = Some(frame); + } + } + + pub fn set_reset(&mut self, dir: ResetTarget, code: ResetCode) { + self.credit = None; + self.reset = Some(StreamFrameReset { + stream_id: StreamId(0), + dir, + code, + }); + } + + pub fn is_empty(&self) -> bool { + self.setup.is_none() && self.credit.is_none() && self.reset.is_none() + } +} + +#[derive(Debug, Default)] +pub struct StreamControl { + pub pending: PendingFrames, + pub awaiting: Option, +} + +#[derive(Debug)] +pub enum QueuedPayload { + PreEncoded(Vec), + StreamBody(StreamBody), +} + +#[derive(Debug)] +pub struct QueuedWrite { + pub token: Token, + pub stream_id: Option, + pub packet_id: Option, + pub track_ack: bool, + pub payload: QueuedPayload, +} diff --git a/ql2/src/id.rs b/ql2/src/id.rs new file mode 100644 index 00000000..d4398df5 --- /dev/null +++ b/ql2/src/id.rs @@ -0,0 +1,58 @@ +use std::fmt; + +use dcbor::CBOR; +use rkyv::{Archive, Deserialize, Serialize}; + +macro_rules! define_id { + ($name:ident, $ty:ty) => { + #[derive( + Archive, + Serialize, + Deserialize, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + )] + pub struct $name(pub $ty); + + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } + } + + impl From<$name> for CBOR { + fn from(value: $name) -> Self { + CBOR::from(value.0) + } + } + + impl TryFrom for $name { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + Ok(Self(<$ty>::try_from(value)?)) + } + } + }; +} + +define_id!(PacketId, u32); +define_id!(StreamId, u64); + +impl From<&ArchivedPacketId> for PacketId { + fn from(value: &ArchivedPacketId) -> Self { + Self(value.0.to_native()) + } +} + +impl From<&ArchivedStreamId> for StreamId { + fn from(value: &ArchivedStreamId) -> Self { + Self(value.0.to_native()) + } +} diff --git a/ql2/src/lib.rs b/ql2/src/lib.rs new file mode 100644 index 00000000..e07f84a3 --- /dev/null +++ b/ql2/src/lib.rs @@ -0,0 +1,51 @@ +pub mod engine; +mod id; +pub mod platform; +pub mod rpc; +pub mod runtime; +pub mod wire; + +pub use id::*; + +#[cfg(test)] +mod tests; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Peer { + pub peer: bc_components::XID, + pub signing_key: bc_components::MLDSAPublicKey, + pub encapsulation_key: bc_components::MLKEMPublicKey, +} + +#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] +pub enum QlError { + #[error("invalid payload")] + InvalidPayload, + #[error("invalid signature")] + InvalidSignature, + #[error("missing session")] + MissingSession, + #[error("no peer bound")] + NoPeerBound, + #[error("timeout")] + Timeout, + #[error("send failed")] + SendFailed, + #[error("stream rejected {code:?}")] + StreamRejected { code: wire::stream::RejectCode }, + #[error("stream reset {code:?}")] + StreamReset { + dir: wire::stream::Direction, + code: wire::stream::ResetCode, + }, + #[error("stream protocol error")] + StreamProtocol, + #[error("cancelled")] + Cancelled, +} + +impl From for QlError { + fn from(_: crate::runtime::pipe::PipeClosed) -> Self { + Self::Cancelled + } +} diff --git a/ql2/src/platform.rs b/ql2/src/platform.rs new file mode 100644 index 00000000..4472d927 --- /dev/null +++ b/ql2/src/platform.rs @@ -0,0 +1,37 @@ +use std::{future::Future, pin::Pin, time::Duration}; + +use bc_components::{ + MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, SigningPublicKey, XID, +}; + +use crate::{ + runtime::{HandlerEvent, PeerSession}, + Peer, QlError, +}; + +pub type PlatformFuture<'a, T> = Pin + 'a>>; + +pub trait QlCrypto { + fn signing_private_key(&self) -> &MLDSAPrivateKey; + fn signing_public_key(&self) -> &MLDSAPublicKey; + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey; + fn encapsulation_public_key(&self) -> &MLKEMPublicKey; + + fn fill_random_bytes(&self, data: &mut [u8]); + + fn xid(&self) -> XID { + XID::new(SigningPublicKey::MLDSA(self.signing_public_key().clone())) + } +} + +pub trait QlPlatform: QlCrypto { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; + + fn load_peer(&self) -> PlatformFuture<'_, Option>; + fn persist_peer(&self, peer: Peer); + fn clear_peer(&self); + + fn handle_peer_status(&self, peer: XID, session: &PeerSession); + fn handle_inbound(&self, event: HandlerEvent); +} diff --git a/ql2/src/rpc/client.rs b/ql2/src/rpc/client.rs new file mode 100644 index 00000000..0ff532ce --- /dev/null +++ b/ql2/src/rpc/client.rs @@ -0,0 +1,70 @@ +use dcbor::CBOR; + +use super::{modality::RequestResponse, RpcError, RpcRequestHead, RpcResponseHead}; +use crate::runtime::{RuntimeHandle, StreamConfig}; + +#[derive(Clone)] +pub struct RpcHandle { + inner: RuntimeHandle, +} + +impl RpcHandle { + pub fn new(inner: RuntimeHandle) -> Self { + Self { inner } + } + + pub fn runtime(&self) -> &RuntimeHandle { + &self.inner + } + + pub async fn request( + &self, + request: M, + config: StreamConfig, + ) -> Result { + let request_body = Into::::into(request).to_cbor_data(); + let request_head = CBOR::from(RpcRequestHead::new( + M::METHOD, + Some(request_body.len() as u64), + )) + .to_cbor_data(); + + let crate::runtime::PendingStream { + mut request, + accepted, + } = self.inner.open_stream(request_head, config).await?; + let accepted = accepted.await?; + request.write_all(&request_body).await?; + request.finish().await?; + + let response_head = + RpcResponseHead::try_from(CBOR::try_from_data(&accepted.response_head)?)?; + if response_head.version != super::RPC_VERSION { + return Err(RpcError::BadVersion(response_head.version)); + } + + let response_body = + read_stream_to_end(accepted.response, response_head.content_length).await?; + Ok(CBOR::try_from_data(&response_body)?.try_into()?) + } +} + +async fn read_stream_to_end( + mut stream: crate::runtime::InboundByteStream, + content_length: Option, +) -> Result, RpcError> { + let mut body = match content_length.and_then(|length| usize::try_from(length).ok()) { + Some(length) => Vec::with_capacity(length), + None => Vec::new(), + }; + while let Some(chunk) = stream.next_chunk().await? { + body.extend_from_slice(&chunk); + } + if let Some(expected) = content_length { + let actual = body.len() as u64; + if actual != expected { + return Err(RpcError::ContentLengthMismatch { expected, actual }); + } + } + Ok(body) +} diff --git a/ql2/src/rpc/mod.rs b/ql2/src/rpc/mod.rs new file mode 100644 index 00000000..f19b95b4 --- /dev/null +++ b/ql2/src/rpc/mod.rs @@ -0,0 +1,153 @@ +mod server; + +pub mod client; +pub mod modality; + +pub use client::RpcHandle; +use dcbor::CBOR; +pub use modality::{MethodId, QlCodec, RequestResponse}; + +use crate::QlError; + +pub(crate) const RPC_VERSION: u16 = 1; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RpcRequestHead { + pub version: u16, + pub method: MethodId, + pub content_length: Option, +} + +impl RpcRequestHead { + pub fn new(method: MethodId, content_length: Option) -> Self { + Self { + version: RPC_VERSION, + method, + content_length, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RpcResponseHead { + pub version: u16, + pub content_length: Option, +} + +impl RpcResponseHead { + pub fn new(content_length: Option) -> Self { + Self { + version: RPC_VERSION, + content_length, + } + } +} + +impl Default for RpcResponseHead { + fn default() -> Self { + Self::new(None) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RpcError { + #[error(transparent)] + Transport(#[from] QlError), + #[error(transparent)] + Decode(#[from] dcbor::Error), + #[error("unsupported rpc version {0}")] + BadVersion(u16), + #[error("rpc content length mismatch: expected {expected}, got {actual}")] + ContentLengthMismatch { expected: u64, actual: u64 }, +} + +impl From for CBOR { + fn from(value: RpcRequestHead) -> Self { + CBOR::from(vec![ + CBOR::from(value.version), + CBOR::from(value.method), + value + .content_length + .map(CBOR::from) + .unwrap_or_else(CBOR::null), + ]) + } +} + +impl TryFrom for RpcRequestHead { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let [version, method, content_length] = take_fields(value.try_into_array()?.into_iter())?; + Ok(Self { + version: version.try_into()?, + method: method.try_into()?, + content_length: if content_length.is_null() { + None + } else { + Some(content_length.try_into()?) + }, + }) + } +} + +impl From for CBOR { + fn from(value: RpcResponseHead) -> Self { + CBOR::from(vec![ + CBOR::from(value.version), + value + .content_length + .map(CBOR::from) + .unwrap_or_else(CBOR::null), + ]) + } +} + +impl TryFrom for RpcResponseHead { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + let [version, content_length] = take_fields(value.try_into_array()?.into_iter())?; + Ok(Self { + version: version.try_into()?, + content_length: if content_length.is_null() { + None + } else { + Some(content_length.try_into()?) + }, + }) + } +} + +fn take_fields( + mut iter: impl Iterator, +) -> Result<[CBOR; N], dcbor::Error> { + use std::mem::MaybeUninit; + + let mut fields: [MaybeUninit; N] = [const { MaybeUninit::uninit() }; N]; + for (index, slot) in fields.iter_mut().enumerate() { + let Some(value) = iter.next() else { + for init in &mut fields[..index] { + unsafe { init.assume_init_drop() }; + } + return Err(dcbor::Error::msg("array too short")); + }; + slot.write(value); + } + let result = unsafe { std::ptr::read(&fields as *const _ as *const [CBOR; N]) }; + if iter.next().is_some() { + return Err(dcbor::Error::msg("array too long")); + } + Ok(result) +} + +#[test] +fn take_fields_reads_exact_count() { + let values = vec![CBOR::from(1u8), CBOR::from(2u8), CBOR::from(3u8)]; + let mut iter = values.into_iter(); + let [first, second, third] = take_fields(&mut iter).unwrap(); + assert_eq!(u8::try_from(first).unwrap(), 1); + assert_eq!(u8::try_from(second).unwrap(), 2); + assert_eq!(u8::try_from(third).unwrap(), 3); + assert!(iter.next().is_none()); +} diff --git a/ql2/src/rpc/modality.rs b/ql2/src/rpc/modality.rs new file mode 100644 index 00000000..533ece93 --- /dev/null +++ b/ql2/src/rpc/modality.rs @@ -0,0 +1,35 @@ +use std::fmt; + +use dcbor::CBOR; + +pub trait QlCodec: Into + TryFrom {} + +impl QlCodec for T where T: Into + TryFrom {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct MethodId(pub u64); + +impl fmt::Display for MethodId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for CBOR { + fn from(value: MethodId) -> Self { + CBOR::from(value.0) + } +} + +impl TryFrom for MethodId { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + Ok(Self(u64::try_from(value)?)) + } +} + +pub trait RequestResponse: QlCodec { + const METHOD: MethodId; + type Response: QlCodec; +} diff --git a/ql2/src/rpc/server.rs b/ql2/src/rpc/server.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/ql2/src/rpc/server.rs @@ -0,0 +1 @@ + diff --git a/ql2/src/runtime/command.rs b/ql2/src/runtime/command.rs new file mode 100644 index 00000000..952db98b --- /dev/null +++ b/ql2/src/runtime/command.rs @@ -0,0 +1,55 @@ +use crate::{ + runtime::{pipe, AcceptedStreamDelivery, StreamConfig}, + wire::stream::{Direction, RejectCode, ResetCode}, + Peer, QlError, StreamId, +}; + +pub(crate) enum RuntimeCommand { + BindPeer { + peer: Peer, + }, + Pair, + Connect, + Unpair, + OpenStream { + request_head: Vec, + request_pipe: pipe::PipeReader, + accepted: oneshot::Sender>, + start: oneshot::Sender>, + config: StreamConfig, + }, + AcceptStream { + stream_id: StreamId, + response_head: Vec, + response_pipe: pipe::PipeReader, + }, + RejectStream { + stream_id: StreamId, + code: RejectCode, + }, + PollStream { + stream_id: StreamId, + }, + AdvanceInboundCredit { + stream_id: StreamId, + dir: Direction, + amount: u64, + }, + ResetOutbound { + stream_id: StreamId, + dir: Direction, + code: ResetCode, + }, + ResetInbound { + stream_id: StreamId, + dir: Direction, + code: ResetCode, + }, + ResponderDropped { + stream_id: StreamId, + }, + PendingAcceptDropped { + stream_id: StreamId, + }, + Incoming(Vec), +} diff --git a/ql2/src/runtime/driver.rs b/ql2/src/runtime/driver.rs new file mode 100644 index 00000000..4a9e2945 --- /dev/null +++ b/ql2/src/runtime/driver.rs @@ -0,0 +1,723 @@ +use std::{ + collections::{HashMap, VecDeque}, + future::Future, + io::Read, + task::Poll, + time::Instant, +}; + +use futures_lite::future::poll_fn; + +use crate::{ + engine::{self, Engine, EngineInput, EngineOutput, OpenId}, + platform::{PlatformFuture, QlPlatform}, + runtime::{ + command::RuntimeCommand, + handle::{InboundByteStream, InboundStream, StreamResponder}, + pipe, AcceptedStreamDelivery, HandlerEvent, Runtime, + }, + wire::stream::{Direction, ResetCode}, + QlError, StreamId, +}; + +struct InFlightWrite<'a> { + token: engine::Token, + tracked: Option, + future: PlatformFuture<'a, Result<(), QlError>>, +} + +enum DriverEvent { + Command(RuntimeCommand), + WriteCompleted { + token: engine::Token, + tracked: Option, + result: Result<(), QlError>, + }, + TimerExpired, + Closed, +} + +struct PendingOpen { + request_pipe: pipe::PipeReader, + start_tx: oneshot::Sender>, + accepted_tx: oneshot::Sender>, +} + +struct PendingAcceptDelivery { + tx: oneshot::Sender>, + response_reader: pipe::PipeReader, +} + +#[derive(Debug, Clone, Copy)] +struct PendingPull { + offset: u64, + max_len: usize, +} + +enum OutboundIo { + Open { + dir: Direction, + pipe: pipe::PipeReader, + pending_pull: Option, + finish_queued: bool, + }, + Closed, +} + +impl OutboundIo { + fn new(dir: Direction, pipe: pipe::PipeReader) -> Self { + Self::Open { + dir, + pipe, + pending_pull: None, + finish_queued: false, + } + } + + fn set_pending_pull(&mut self, offset: u64, max_len: usize) { + if let Self::Open { pending_pull, .. } = self { + *pending_pull = Some(PendingPull { offset, max_len }); + } + } + + fn release_to(&mut self, recv_offset: u64) { + if let Self::Open { pipe, .. } = self { + pipe.release_to(recv_offset); + } + } + + fn close(&mut self) { + if let Self::Open { pipe, .. } = self { + pipe.close(); + } + *self = Self::Closed; + } + + fn poll_pending(&mut self, stream_id: StreamId, pending: &mut VecDeque) { + let Self::Open { + dir, + pipe, + pending_pull, + finish_queued, + } = self + else { + return; + }; + if let Some(pull) = pending_pull.take() { + if let Some(mut grant) = pipe.reserve_at(pull.offset, pull.max_len) { + let mut bytes = vec![0; grant.len()]; + let _ = grant.read_exact(&mut bytes); + pending.push_back(EngineInput::OutboundData { + stream_id, + dir: *dir, + offset: grant.offset(), + bytes, + }); + return; + } + if pipe.writer_finished() && pipe.all_sent() { + if !*finish_queued { + *finish_queued = true; + pending.push_back(EngineInput::OutboundFinished { + stream_id, + dir: *dir, + final_offset: pipe.sent_offset(), + }); + } + return; + } + *pending_pull = Some(pull); + return; + } + + if pipe.writer_finished() && pipe.all_sent() && !*finish_queued { + *finish_queued = true; + pending.push_back(EngineInput::OutboundFinished { + stream_id, + dir: *dir, + final_offset: pipe.sent_offset(), + }); + } + } +} + +enum InboundIo { + Open(pipe::PipeWriter), + Closed, +} + +impl InboundIo { + fn new(pipe: pipe::PipeWriter) -> Self { + Self::Open(pipe) + } + + fn write_or_cancel( + &mut self, + stream_id: StreamId, + dir: Direction, + bytes: &[u8], + pending: &mut VecDeque, + ) { + let Self::Open(pipe) = self else { + pending.push_back(EngineInput::ResetInbound { + stream_id, + dir, + code: ResetCode::Cancelled, + }); + return; + }; + match pipe.try_write(bytes) { + Ok(n) if n == bytes.len() => {} + Ok(_) | Err(_) => { + pipe.close(); + *self = Self::Closed; + pending.push_back(EngineInput::ResetInbound { + stream_id, + dir, + code: ResetCode::Cancelled, + }); + } + } + } + + fn finish(&mut self) { + if let Self::Open(pipe) = self { + pipe.finish(); + } + *self = Self::Closed; + } + + fn fail(&mut self, error: QlError) { + if let Self::Open(pipe) = self { + pipe.fail(error); + } + *self = Self::Closed; + } + + fn close(&mut self) { + if let Self::Open(pipe) = self { + pipe.close(); + } + *self = Self::Closed; + } +} + +enum PendingAcceptState { + Waiting(PendingAcceptDelivery), + Dropped, + Resolved, +} + +enum ResponderResponseIo { + Pending, + Streaming(OutboundIo), + Rejected, +} + +enum DriverStreamIo { + Initiator { + request: OutboundIo, + response: InboundIo, + pending_accept: PendingAcceptState, + }, + Responder { + request: InboundIo, + response: ResponderResponseIo, + }, +} + +impl DriverStreamIo { + fn outbound_mut(&mut self, dir: Direction) -> Option<&mut OutboundIo> { + match self { + Self::Initiator { request, .. } if dir == Direction::Request => Some(request), + Self::Responder { + response: ResponderResponseIo::Streaming(outbound), + .. + } if dir == Direction::Response => Some(outbound), + _ => None, + } + } + + fn inbound_mut(&mut self, dir: Direction) -> Option<&mut InboundIo> { + match self { + Self::Initiator { response, .. } if dir == Direction::Response => Some(response), + Self::Responder { request, .. } if dir == Direction::Request => Some(request), + _ => None, + } + } + + fn close_all(&mut self) { + match self { + Self::Initiator { + request, + response, + pending_accept, + } => { + request.close(); + response.close(); + *pending_accept = PendingAcceptState::Resolved; + } + Self::Responder { request, response } => { + request.close(); + if let ResponderResponseIo::Streaming(outbound) = response { + outbound.close(); + } + *response = ResponderResponseIo::Rejected; + } + } + } +} + +struct DriverState { + engine: Engine, + pending_inputs: VecDeque, + next_timer: Option, + next_open_id: u64, + pending_opens: HashMap, + streams: HashMap, +} + +impl DriverState { + fn new(config: engine::EngineConfig, peer: Option) -> Self { + let engine = Engine::new(config, peer); + Self { + engine, + pending_inputs: VecDeque::new(), + next_timer: None, + next_open_id: 1, + pending_opens: HashMap::new(), + streams: HashMap::new(), + } + } + + fn push_input(&mut self, input: EngineInput) { + self.pending_inputs.push_back(input); + } + + fn translate_command(&mut self, command: RuntimeCommand) { + match command { + RuntimeCommand::BindPeer { peer } => self.push_input(EngineInput::BindPeer(peer)), + RuntimeCommand::Pair => self.push_input(EngineInput::Pair), + RuntimeCommand::Connect => self.push_input(EngineInput::Connect), + RuntimeCommand::Unpair => self.push_input(EngineInput::Unpair), + RuntimeCommand::Incoming(bytes) => self.push_input(EngineInput::Incoming(bytes)), + RuntimeCommand::OpenStream { + request_head, + request_pipe, + accepted, + start, + config, + } => { + let open_id = OpenId(self.next_open_id); + self.next_open_id = self.next_open_id.wrapping_add(1); + self.pending_opens.insert( + open_id, + PendingOpen { + request_pipe, + start_tx: start, + accepted_tx: accepted, + }, + ); + self.push_input(EngineInput::OpenStream { + open_id, + request_head, + config, + }); + } + RuntimeCommand::AcceptStream { + stream_id, + response_head, + response_pipe, + } => { + if let Some(DriverStreamIo::Responder { response, .. }) = + self.streams.get_mut(&stream_id) + { + *response = ResponderResponseIo::Streaming(OutboundIo::new( + Direction::Response, + response_pipe, + )); + } + self.push_input(EngineInput::AcceptStream { + stream_id, + response_head, + }); + } + RuntimeCommand::RejectStream { stream_id, code } => { + if let Some(DriverStreamIo::Responder { response, .. }) = + self.streams.get_mut(&stream_id) + { + *response = ResponderResponseIo::Rejected; + } + self.push_input(EngineInput::RejectStream { stream_id, code }); + } + RuntimeCommand::PollStream { stream_id } => self.poll_stream(stream_id), + RuntimeCommand::AdvanceInboundCredit { + stream_id, + dir, + amount, + } => self.push_input(EngineInput::InboundConsumed { + stream_id, + dir, + amount, + }), + RuntimeCommand::ResetOutbound { + stream_id, + dir, + code, + } => self.push_input(EngineInput::ResetOutbound { + stream_id, + dir, + code, + }), + RuntimeCommand::ResetInbound { + stream_id, + dir, + code, + } => self.push_input(EngineInput::ResetInbound { + stream_id, + dir, + code, + }), + RuntimeCommand::ResponderDropped { stream_id } => { + self.push_input(EngineInput::ResponderDropped { stream_id }); + } + RuntimeCommand::PendingAcceptDropped { stream_id } => { + if let Some(DriverStreamIo::Initiator { pending_accept, .. }) = + self.streams.get_mut(&stream_id) + { + if matches!(pending_accept, PendingAcceptState::Waiting(_)) { + *pending_accept = PendingAcceptState::Dropped; + } + } + self.push_input(EngineInput::PendingAcceptDropped { stream_id }); + self.push_input(EngineInput::ResetInbound { + stream_id, + dir: Direction::Response, + code: ResetCode::Cancelled, + }); + } + } + } + + fn poll_stream(&mut self, stream_id: StreamId) { + if let Some(stream) = self.streams.get_mut(&stream_id) { + match stream { + DriverStreamIo::Initiator { request, .. } => { + request.poll_pending(stream_id, &mut self.pending_inputs) + } + DriverStreamIo::Responder { response, .. } => { + if let ResponderResponseIo::Streaming(outbound) = response { + outbound.poll_pending(stream_id, &mut self.pending_inputs); + } + } + } + } + } +} + +impl Runtime

{ + pub async fn run(self) { + let runtime_tx = self.tx.upgrade().expect("runtime tx"); + let mut state = DriverState::new(self.config.engine, self.platform.load_peer().await); + let mut in_flight: Option> = None; + + loop { + if let Some(input) = state.pending_inputs.pop_front() { + let now = Instant::now(); + let pending_inputs = &mut state.pending_inputs; + let next_timer = &mut state.next_timer; + let pending_opens = &mut state.pending_opens; + let streams = &mut state.streams; + state + .engine + .run_tick(now, input, &self.platform, &mut |output| { + self.apply_output( + pending_inputs, + next_timer, + pending_opens, + streams, + &runtime_tx, + &mut in_flight, + output, + ); + }); + continue; + } + + if self.rx.is_closed() { + break; + } + + match self + .next_driver_event(state.next_timer, in_flight.as_mut()) + .await + { + DriverEvent::Command(command) => state.translate_command(command), + DriverEvent::WriteCompleted { + token, + tracked, + result, + } => { + in_flight = None; + state.push_input(EngineInput::WriteCompleted { + token, + tracked, + result, + }); + } + DriverEvent::TimerExpired => state.push_input(EngineInput::TimerExpired), + DriverEvent::Closed => break, + } + } + } + + async fn next_driver_event<'a>( + &'a self, + next_timer: Option, + mut in_flight: Option<&mut InFlightWrite<'a>>, + ) -> DriverEvent { + let recv_future = self.rx.recv(); + futures_lite::pin!(recv_future); + + let mut sleep_future = next_timer.map(|deadline| { + let timeout = deadline.saturating_duration_since(Instant::now()); + self.platform.sleep(timeout) + }); + + poll_fn(|cx| { + if let Some(in_flight) = in_flight.as_mut() { + if let Poll::Ready(result) = in_flight.future.as_mut().poll(cx) { + return Poll::Ready(DriverEvent::WriteCompleted { + token: in_flight.token, + tracked: in_flight.tracked, + result, + }); + } + } + + if let Some(future) = sleep_future.as_mut() { + if let Poll::Ready(()) = future.as_mut().poll(cx) { + return Poll::Ready(DriverEvent::TimerExpired); + } + } + + recv_future.as_mut().poll(cx).map(|res| match res { + Ok(command) => DriverEvent::Command(command), + Err(_) => DriverEvent::Closed, + }) + }) + .await + } + + fn apply_output<'a>( + &'a self, + pending_inputs: &mut VecDeque, + next_timer: &mut Option, + pending_opens: &mut HashMap, + streams: &mut HashMap, + runtime_tx: &async_channel::Sender, + in_flight: &mut Option>, + output: EngineOutput, + ) { + match output { + EngineOutput::SetTimer(deadline) => *next_timer = deadline, + EngineOutput::WriteMessage { + token, + tracked, + bytes, + } => { + *in_flight = Some(InFlightWrite { + token, + tracked, + future: self.platform.write_message(bytes), + }); + } + EngineOutput::PeerStatusChanged { peer, session } => { + self.platform.handle_peer_status(peer, &session); + } + EngineOutput::PersistPeer(peer) => self.platform.persist_peer(peer), + EngineOutput::ClearPeer => self.platform.clear_peer(), + EngineOutput::OpenStarted { open_id, stream_id } => { + let Some(pending) = pending_opens.remove(&open_id) else { + return; + }; + let _ = pending.start_tx.send(Ok(stream_id)); + let (response_reader, response_writer) = pipe::pipe(self.config.pipe_size_bytes); + streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::new(Direction::Request, pending.request_pipe), + response: InboundIo::new(response_writer), + pending_accept: PendingAcceptState::Waiting(PendingAcceptDelivery { + tx: pending.accepted_tx, + response_reader, + }), + }, + ); + } + EngineOutput::OpenAccepted { + stream_id, + response_head, + .. + } => { + let Some(DriverStreamIo::Initiator { pending_accept, .. }) = + streams.get_mut(&stream_id) + else { + return; + }; + match std::mem::replace(pending_accept, PendingAcceptState::Resolved) { + PendingAcceptState::Waiting(delivery) => { + let _ = delivery.tx.send(Ok(AcceptedStreamDelivery { + stream_id, + response_head, + response: delivery.response_reader, + tx: runtime_tx.clone(), + })); + } + PendingAcceptState::Dropped => { + *pending_accept = PendingAcceptState::Dropped; + } + PendingAcceptState::Resolved => {} + } + } + EngineOutput::OpenFailed { + open_id, + stream_id, + error, + } => { + if let Some(pending) = pending_opens.remove(&open_id) { + let _ = pending.start_tx.send(Err(error)); + return; + } + let Some(DriverStreamIo::Initiator { pending_accept, .. }) = + streams.get_mut(&stream_id) + else { + return; + }; + match std::mem::replace(pending_accept, PendingAcceptState::Resolved) { + PendingAcceptState::Waiting(delivery) => { + let _ = delivery.tx.send(Err(error)); + } + PendingAcceptState::Dropped => { + *pending_accept = PendingAcceptState::Dropped; + } + PendingAcceptState::Resolved => {} + } + } + EngineOutput::InboundStreamOpened { + stream_id, + request_head, + } => { + let (request_reader, request_writer) = pipe::pipe(self.config.pipe_size_bytes); + streams.insert( + stream_id, + DriverStreamIo::Responder { + request: InboundIo::new(request_writer), + response: ResponderResponseIo::Pending, + }, + ); + self.platform + .handle_inbound(HandlerEvent::Stream(InboundStream { + stream_id, + request_head, + request: InboundByteStream::new( + stream_id, + Direction::Request, + request_reader, + runtime_tx.clone(), + ), + respond_to: StreamResponder::new( + stream_id, + self.config.pipe_size_bytes, + runtime_tx.clone(), + ), + })); + } + EngineOutput::InboundData { + stream_id, + dir, + bytes, + } => { + if let Some(stream) = streams.get_mut(&stream_id) { + if let Some(inbound) = stream.inbound_mut(dir) { + inbound.write_or_cancel(stream_id, dir, &bytes, pending_inputs); + } + } + } + EngineOutput::InboundFinished { stream_id, dir } => { + if let Some(stream) = streams.get_mut(&stream_id) { + if let Some(inbound) = stream.inbound_mut(dir) { + inbound.finish(); + } + } + } + EngineOutput::InboundFailed { + stream_id, + dir, + error, + } => { + if let Some(stream) = streams.get_mut(&stream_id) { + if let Some(inbound) = stream.inbound_mut(dir) { + inbound.fail(error); + } + } + } + EngineOutput::NeedOutboundData { + stream_id, + dir, + offset, + max_len, + } => { + if let Some(stream) = streams.get_mut(&stream_id) { + if let Some(outbound) = stream.outbound_mut(dir) { + outbound.set_pending_pull(offset, max_len); + } + } + poll_stream(streams, pending_inputs, stream_id); + } + EngineOutput::ReleaseOutboundThrough { + stream_id, + dir, + recv_offset, + } => { + if let Some(stream) = streams.get_mut(&stream_id) { + if let Some(outbound) = stream.outbound_mut(dir) { + outbound.release_to(recv_offset); + } + } + } + EngineOutput::OutboundClosed { stream_id, dir } + | EngineOutput::OutboundFailed { stream_id, dir, .. } => { + if let Some(stream) = streams.get_mut(&stream_id) { + if let Some(outbound) = stream.outbound_mut(dir) { + outbound.close(); + } + } + } + EngineOutput::StreamReaped { stream_id } => { + if let Some(mut stream) = streams.remove(&stream_id) { + stream.close_all(); + } + } + } + } +} + +fn poll_stream( + streams: &mut HashMap, + pending_inputs: &mut VecDeque, + stream_id: StreamId, +) { + if let Some(stream) = streams.get_mut(&stream_id) { + match stream { + DriverStreamIo::Initiator { request, .. } => { + request.poll_pending(stream_id, pending_inputs) + } + DriverStreamIo::Responder { response, .. } => { + if let ResponderResponseIo::Streaming(outbound) = response { + outbound.poll_pending(stream_id, pending_inputs); + } + } + } + } +} diff --git a/ql2/src/runtime/handle.rs b/ql2/src/runtime/handle.rs new file mode 100644 index 00000000..43312dae --- /dev/null +++ b/ql2/src/runtime/handle.rs @@ -0,0 +1,406 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use async_channel::Sender; + +use crate::{ + runtime::{ + command::RuntimeCommand, + pipe::{self, ReadReady}, + AcceptedStreamDelivery, StreamConfig, + }, + wire::stream::{Direction, RejectCode, ResetCode}, + Peer, QlError, StreamId, +}; + +#[derive(Clone)] +pub struct RuntimeHandle { + pub(crate) tx: async_channel::Sender, + pub(crate) pipe_size_bytes: usize, +} + +pub struct PendingStream { + pub request: OutboundByteStream, + pub accepted: PendingAccept, +} + +#[derive(Debug)] +pub struct AcceptedStream { + pub stream_id: StreamId, + pub response_head: Vec, + pub response: InboundByteStream, +} + +#[derive(Debug)] +pub struct InboundStream { + pub stream_id: StreamId, + pub request_head: Vec, + pub request: InboundByteStream, + pub respond_to: StreamResponder, +} + +#[derive(Debug)] +pub struct StreamResponder { + stream_id: StreamId, + pipe_size_bytes: usize, + tx: async_channel::Sender, + armed: bool, +} + +pub struct InboundByteStream { + stream_id: StreamId, + dir: Direction, + pipe: pipe::PipeReader, + tx: Sender, + finished: bool, +} + +impl std::fmt::Debug for InboundByteStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InboundByteStream") + .field("stream_id", &self.stream_id) + .field("dir", &self.dir) + .field("finished", &self.finished) + .finish_non_exhaustive() + } +} + +pub struct OutboundByteStream { + stream_id: StreamId, + dir: Direction, + pipe: Option>, + tx: Sender, +} + +pub struct PendingAccept { + stream_id: StreamId, + rx: Option>>, + tx: Sender, +} + +impl Future for PendingAccept { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let Some(rx) = this.rx.as_mut() else { + return Poll::Ready(Err(QlError::Cancelled)); + }; + Pin::new(rx).poll(cx).map(|result| match result { + Ok(Ok(delivery)) => { + let AcceptedStreamDelivery { + stream_id, + response_head, + response, + tx, + } = delivery; + this.rx = None; + Ok(AcceptedStream { + stream_id, + response_head, + response: InboundByteStream::new(stream_id, Direction::Response, response, tx), + }) + } + Ok(Err(error)) => { + this.rx = None; + Err(error) + } + Err(_) => { + this.rx = None; + Err(QlError::Cancelled) + } + }) + } +} + +impl Drop for PendingAccept { + fn drop(&mut self) { + if self.rx.take().is_none() { + return; + } + let _ = self.tx.try_send(RuntimeCommand::PendingAcceptDropped { + stream_id: self.stream_id, + }); + } +} + +impl InboundByteStream { + pub(crate) fn new( + stream_id: StreamId, + dir: Direction, + pipe: pipe::PipeReader, + tx: Sender, + ) -> Self { + Self { + stream_id, + dir, + pipe, + tx, + finished: false, + } + } + + pub async fn next_chunk(&mut self) -> Result>, QlError> { + if self.finished { + return Ok(None); + } + match self.pipe.ready().await { + ReadReady::Data => { + let chunk = self.pipe.peek_buf().to_vec(); + let len = chunk.len(); + self.pipe.consume(len); + if len > 0 { + let _ = self + .tx + .send(RuntimeCommand::AdvanceInboundCredit { + stream_id: self.stream_id, + dir: self.dir, + amount: len as u64, + }) + .await; + } + Ok(Some(chunk)) + } + ReadReady::Eof => { + self.finished = true; + Ok(None) + } + ReadReady::Error(error) => { + self.finished = true; + Err(error) + } + } + } + + pub async fn reset(mut self, code: ResetCode) -> Result<(), QlError> { + self.finished = true; + self.tx + .send(RuntimeCommand::ResetInbound { + stream_id: self.stream_id, + dir: self.dir, + code, + }) + .await + .map_err(|_| QlError::Cancelled) + } +} + +impl Drop for InboundByteStream { + fn drop(&mut self) { + if self.finished { + return; + } + let _ = self.tx.try_send(RuntimeCommand::ResetInbound { + stream_id: self.stream_id, + dir: self.dir, + code: ResetCode::Cancelled, + }); + } +} + +impl OutboundByteStream { + pub(crate) fn new( + stream_id: StreamId, + dir: Direction, + pipe: pipe::PipeWriter, + tx: Sender, + ) -> Self { + Self { + stream_id, + dir, + pipe: Some(pipe), + tx, + } + } + + pub async fn write(&mut self, bytes: &[u8]) -> Result { + let pipe = self.pipe.as_mut().expect("stream not finished or reset"); + let written = pipe.write(bytes).await?; + self.tx + .try_send(RuntimeCommand::PollStream { + stream_id: self.stream_id, + }) + .map_err(|_| QlError::Cancelled)?; + Ok(written) + } + + pub async fn write_all(&mut self, mut bytes: &[u8]) -> Result<(), QlError> { + while !bytes.is_empty() { + let written = self.write(bytes).await?; + if written == 0 { + return Err(QlError::Cancelled); + } + bytes = &bytes[written..]; + } + Ok(()) + } + + pub async fn finish(mut self) -> Result<(), QlError> { + let Some(mut pipe) = self.pipe.take() else { + return Ok(()); + }; + pipe.finish(); + self.tx + .try_send(RuntimeCommand::PollStream { + stream_id: self.stream_id, + }) + .map_err(|_| QlError::Cancelled)?; + pipe.closed().await; + Ok(()) + } + + pub async fn reset(mut self, code: ResetCode) -> Result<(), QlError> { + self.pipe.take(); + self.tx + .send(RuntimeCommand::ResetOutbound { + stream_id: self.stream_id, + dir: self.dir, + code, + }) + .await + .map_err(|_| QlError::Cancelled) + } +} + +impl Drop for OutboundByteStream { + fn drop(&mut self) { + if self.pipe.take().is_none() { + return; + } + let _ = self.tx.try_send(RuntimeCommand::ResetOutbound { + stream_id: self.stream_id, + dir: self.dir, + code: ResetCode::Cancelled, + }); + } +} + +impl StreamResponder { + pub(crate) fn new( + stream_id: StreamId, + pipe_size_bytes: usize, + tx: async_channel::Sender, + ) -> Self { + Self { + stream_id, + pipe_size_bytes, + tx, + armed: true, + } + } + + pub fn accept(mut self, response_head: Vec) -> Result { + self.armed = false; + let (response_pipe, response_writer) = pipe::pipe(self.pipe_size_bytes); + self.tx + .send_blocking(RuntimeCommand::AcceptStream { + stream_id: self.stream_id, + response_head, + response_pipe, + }) + .map_err(|_| QlError::Cancelled)?; + Ok(OutboundByteStream::new( + self.stream_id, + Direction::Response, + response_writer, + self.tx.clone(), + )) + } + + pub fn reject(mut self, code: RejectCode) -> Result<(), QlError> { + self.armed = false; + self.tx + .try_send(RuntimeCommand::RejectStream { + stream_id: self.stream_id, + code, + }) + .map_err(|_| QlError::Cancelled) + } +} + +impl Drop for StreamResponder { + fn drop(&mut self) { + if !self.armed { + return; + } + let _ = self.tx.try_send(RuntimeCommand::ResponderDropped { + stream_id: self.stream_id, + }); + } +} + +impl RuntimeHandle { + pub fn bind_peer(&self, peer: Peer) { + self.send(RuntimeCommand::BindPeer { peer }) + } + + pub fn pair(&self) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Pair) + .map_err(|_| QlError::Cancelled) + } + + pub fn connect(&self) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Connect) + .map_err(|_| QlError::Cancelled) + } + + pub fn unpair(&self) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Unpair) + .map_err(|_| QlError::Cancelled) + } + + pub fn send_incoming(&self, bytes: Vec) { + self.send(RuntimeCommand::Incoming(bytes)) + } + + pub async fn open_stream( + &self, + request_head: Vec, + config: StreamConfig, + ) -> Result { + let (request_pipe, request_writer) = pipe::pipe(self.pipe_size_bytes); + let (accepted_tx, accepted_rx) = oneshot::channel(); + let (start_tx, start_rx) = oneshot::channel(); + + self.tx + .send(RuntimeCommand::OpenStream { + request_head, + request_pipe, + accepted: accepted_tx, + start: start_tx, + config, + }) + .await + .map_err(|_| QlError::Cancelled)?; + + let stream_id = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; + + Ok(PendingStream { + request: OutboundByteStream::new( + stream_id, + Direction::Request, + request_writer, + self.tx.clone(), + ), + accepted: PendingAccept { + stream_id, + rx: Some(accepted_rx), + tx: self.tx.clone(), + }, + }) + } +} + +impl RuntimeHandle { + #[inline] + #[track_caller] + fn send(&self, cmd: RuntimeCommand) { + self.tx.send_blocking(cmd).expect("runtime is alive") + } +} diff --git a/ql2/src/runtime/mod.rs b/ql2/src/runtime/mod.rs new file mode 100644 index 00000000..dac4bb46 --- /dev/null +++ b/ql2/src/runtime/mod.rs @@ -0,0 +1,82 @@ +pub use handle::{ + AcceptedStream, InboundByteStream, InboundStream, OutboundByteStream, PendingAccept, + PendingStream, RuntimeHandle, StreamResponder, +}; + +pub use crate::engine::{EngineConfig, InitiatorStage, KeepAliveConfig, PeerSession, Token}; + +pub(crate) mod command; +pub(crate) mod driver; +pub mod handle; +pub(crate) mod pipe; + +use std::time::Duration; + +use crate::{platform::QlPlatform, StreamId}; + +#[derive(Debug, Clone, Copy, Default)] +pub struct StreamConfig { + pub open_timeout: Option, +} + +#[derive(Debug, Clone, Copy)] +pub struct RuntimeConfig { + pub engine: EngineConfig, + pub pipe_size_bytes: usize, +} + +impl Default for RuntimeConfig { + fn default() -> Self { + Self { + engine: EngineConfig::default(), + pipe_size_bytes: 2048, + } + } +} + +impl RuntimeConfig { + pub(crate) fn normalized(mut self) -> Self { + self.engine = self.engine.normalized(); + self.pipe_size_bytes = self.pipe_size_bytes.max(self.engine.max_payload_bytes); + self + } +} + +#[derive(Debug)] +pub enum HandlerEvent { + Stream(InboundStream), +} + +pub(crate) struct AcceptedStreamDelivery { + pub stream_id: StreamId, + pub response_head: Vec, + pub response: crate::runtime::pipe::PipeReader, + pub tx: async_channel::Sender, +} + +pub struct Runtime

{ + platform: P, + config: RuntimeConfig, + rx: async_channel::Receiver, + tx: async_channel::WeakSender, +} + +pub fn new_runtime

(platform: P, config: RuntimeConfig) -> (Runtime

, RuntimeHandle) +where + P: QlPlatform, +{ + let config = config.normalized(); + let (tx, rx) = async_channel::unbounded(); + ( + Runtime { + platform, + config, + rx, + tx: tx.downgrade(), + }, + RuntimeHandle { + tx, + pipe_size_bytes: config.pipe_size_bytes, + }, + ) +} diff --git a/ql2/src/runtime/pipe.rs b/ql2/src/runtime/pipe.rs new file mode 100644 index 00000000..4c3bb65c --- /dev/null +++ b/ql2/src/runtime/pipe.rs @@ -0,0 +1,772 @@ +use std::{ + cell::UnsafeCell, + io::{self, Read}, + mem::{self, MaybeUninit}, + ptr, + sync::{ + atomic::{AtomicU64, AtomicU8, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use atomic_waker::AtomicWaker; +use futures_lite::future::poll_fn; + +const PIPE_OPEN: u8 = 0; +const PIPE_FINISHED: u8 = 1; +const PIPE_FAILED: u8 = 2; +const PIPE_FAILED_TAKEN: u8 = 3; +const PIPE_CLOSED: u8 = 4; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PipeState { + Open, + Finished, + Failed, + FailedTaken, + Closed, +} + +impl PipeState { + fn from_u8(value: u8) -> Self { + match value { + PIPE_OPEN => Self::Open, + PIPE_FINISHED => Self::Finished, + PIPE_FAILED => Self::Failed, + PIPE_FAILED_TAKEN => Self::FailedTaken, + PIPE_CLOSED => Self::Closed, + _ => unreachable!("invalid pipe state"), + } + } + + fn as_u8(self) -> u8 { + match self { + Self::Open => PIPE_OPEN, + Self::Finished => PIPE_FINISHED, + Self::Failed => PIPE_FAILED, + Self::FailedTaken => PIPE_FAILED_TAKEN, + Self::Closed => PIPE_CLOSED, + } + } +} + +pub fn pipe(cap: usize) -> (PipeReader, PipeWriter) { + assert!(cap > 0, "pipe capacity must be positive"); + + let mut storage = Vec::::with_capacity(cap); + let buffer = storage.as_mut_ptr(); + mem::forget(storage); + + let inner = Arc::new(PipeInner { + released: AtomicU64::new(0), + produced: AtomicU64::new(0), + state: AtomicU8::new(PIPE_OPEN), + error: UnsafeCell::new(MaybeUninit::uninit()), + readable: AtomicWaker::new(), + writable: AtomicWaker::new(), + closed: AtomicWaker::new(), + buffer, + cap, + }); + + ( + PipeReader { + inner: inner.clone(), + released: 0, + produced: 0, + sent: 0, + }, + PipeWriter { + inner, + released: 0, + produced: 0, + sealed: false, + }, + ) +} + +struct PipeInner { + released: AtomicU64, + produced: AtomicU64, + state: AtomicU8, + error: UnsafeCell>, + readable: AtomicWaker, + writable: AtomicWaker, + closed: AtomicWaker, + buffer: *mut u8, + cap: usize, +} + +unsafe impl Send for PipeInner {} +unsafe impl Sync for PipeInner {} + +impl Drop for PipeInner { + fn drop(&mut self) { + if PipeState::from_u8(self.state.load(Ordering::Acquire)) == PipeState::Failed { + unsafe { + self.error.get_mut().assume_init_drop(); + } + } + unsafe { + drop(Vec::from_raw_parts(self.buffer, 0, self.cap)); + } + } +} + +pub struct PipeWriter { + inner: Arc>, + released: u64, + produced: u64, + sealed: bool, +} + +pub struct PipeReader { + inner: Arc>, + released: u64, + produced: u64, + sent: u64, +} + +impl std::fmt::Debug for PipeWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PipeWriter").finish_non_exhaustive() + } +} + +impl std::fmt::Debug for PipeReader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PipeReader").finish_non_exhaustive() + } +} + +pub struct SendGrant<'a, E> { + inner: &'a PipeInner, + offset: u64, + len: usize, + position: usize, +} + +pub enum ReadReady { + Data, + Eof, + Error(E), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PipeClosed; + +impl PipeWriter { + pub fn try_write(&mut self, src: &[u8]) -> Result { + if src.is_empty() { + return Ok(0); + } + if self.sealed || self.is_closed() { + return Err(PipeClosed); + } + self.released = self.inner.released.load(Ordering::Acquire); + let n = self.available_capacity().min(src.len()); + if n == 0 { + return Ok(0); + } + unsafe { + write_bytes(self.inner.buffer, self.inner.cap, self.produced, &src[..n]); + } + self.produced = self.produced.saturating_add(n as u64); + self.inner.produced.store(self.produced, Ordering::Release); + self.inner.readable.wake(); + Ok(n) + } + + pub async fn write(&mut self, src: &[u8]) -> Result { + poll_fn(|cx| self.poll_write(cx, src)).await + } + + pub fn finish(&mut self) { + if self.sealed { + return; + } + self.sealed = true; + self.publish_state(PipeState::Finished); + } + + pub fn fail(&mut self, error: E) { + if self.sealed { + return; + } + self.sealed = true; + unsafe { + (*self.inner.error.get()).write(error); + } + match self.inner.state.compare_exchange( + PIPE_OPEN, + PIPE_FAILED, + Ordering::Release, + Ordering::Acquire, + ) { + Ok(_) => { + self.inner.readable.wake(); + } + Err(_) => unsafe { + (*self.inner.error.get()).assume_init_drop(); + }, + } + } + + pub fn close(&mut self) { + if self.sealed { + return; + } + self.sealed = true; + loop { + let current = PipeState::from_u8(self.inner.state.load(Ordering::Acquire)); + match current { + PipeState::Closed => return, + PipeState::Failed => { + if self + .inner + .state + .compare_exchange( + PIPE_FAILED, + PIPE_CLOSED, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + unsafe { + (*self.inner.error.get()).assume_init_drop(); + } + self.inner.readable.wake(); + self.inner.writable.wake(); + self.inner.closed.wake(); + return; + } + } + _ => { + if self + .inner + .state + .compare_exchange( + current.as_u8(), + PIPE_CLOSED, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + self.inner.readable.wake(); + self.inner.writable.wake(); + self.inner.closed.wake(); + return; + } + } + } + } + } + + pub async fn closed(&mut self) { + poll_fn(|cx| self.poll_closed(cx)).await + } + + fn poll_write(&mut self, cx: &mut Context<'_>, src: &[u8]) -> Poll> { + if src.is_empty() { + return Poll::Ready(Ok(0)); + } + if self.sealed || self.is_closed() { + return Poll::Ready(Err(PipeClosed)); + } + + let n = match self.poll_reserve(cx, src.len()) { + Poll::Ready(Ok(n)) => n, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + + unsafe { + write_bytes(self.inner.buffer, self.inner.cap, self.produced, &src[..n]); + } + self.produced = self.produced.saturating_add(n as u64); + self.inner.produced.store(self.produced, Ordering::Release); + self.inner.readable.wake(); + Poll::Ready(Ok(n)) + } + + fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if self.is_closed() { + return Poll::Ready(()); + } + self.inner.closed.register(cx.waker()); + if self.is_closed() { + self.inner.closed.take(); + Poll::Ready(()) + } else { + Poll::Pending + } + } + + fn poll_reserve( + &mut self, + cx: &mut Context<'_>, + want: usize, + ) -> Poll> { + self.released = self.inner.released.load(Ordering::Acquire); + let available = self.available_capacity(); + if available > 0 { + return Poll::Ready(Ok(available.min(want))); + } + + self.inner.writable.register(cx.waker()); + self.released = self.inner.released.load(Ordering::Acquire); + if self.is_closed() { + self.inner.writable.take(); + return Poll::Ready(Err(PipeClosed)); + } + let available = self.available_capacity(); + if available > 0 { + self.inner.writable.take(); + Poll::Ready(Ok(available.min(want))) + } else { + Poll::Pending + } + } + + fn available_capacity(&self) -> usize { + let used = self.produced.saturating_sub(self.released) as usize; + self.inner.cap.saturating_sub(used) + } + + fn publish_state(&mut self, next: PipeState) { + let _ = self.inner.state.compare_exchange( + PIPE_OPEN, + next.as_u8(), + Ordering::Release, + Ordering::Acquire, + ); + self.inner.readable.wake(); + } + + fn is_closed(&self) -> bool { + PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) == PipeState::Closed + } + + #[cfg(test)] + fn state(&self) -> PipeState { + PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) + } + + #[cfg(test)] + fn is_drained(&self) -> bool { + self.inner.released.load(Ordering::Acquire) >= self.inner.produced.load(Ordering::Acquire) + } +} + +impl Drop for PipeWriter { + fn drop(&mut self) { + if self.sealed { + return; + } + self.sealed = true; + self.publish_state(PipeState::Finished); + } +} + +impl PipeReader { + pub async fn ready(&mut self) -> ReadReady { + poll_fn(|cx| self.poll_ready(cx)).await + } + + pub fn peek_buf(&self) -> &[u8] { + let len = self + .available_data() + .min(self.inner.cap - ((self.released as usize) % self.inner.cap)); + unsafe { + ptr::slice_from_raw_parts( + self.inner + .buffer + .add((self.released as usize) % self.inner.cap), + len, + ) + .as_ref() + .unwrap() + } + } + + pub fn consume(&mut self, amt: usize) { + assert!( + amt <= self.available_data(), + "cannot consume more bytes than available" + ); + self.released = self.released.saturating_add(amt as u64); + self.inner.released.store(self.released, Ordering::Release); + if self.sent < self.released { + self.sent = self.released; + } + self.inner.writable.wake(); + } + + pub fn reserve_send( + &mut self, + remote_max_offset: u64, + max_len: usize, + ) -> Option> { + self.produced = self.inner.produced.load(Ordering::Acquire); + let limit = self.produced.min(remote_max_offset); + if self.sent >= limit { + return None; + } + let len = ((limit - self.sent) as usize).min(max_len); + let offset = self.sent; + self.sent = self.sent.saturating_add(len as u64); + Some(SendGrant { + inner: self.inner.as_ref(), + offset, + len, + position: 0, + }) + } + + pub fn retry_send(&self, offset: u64, len: usize) -> Option> { + let released = self.inner.released.load(Ordering::Acquire); + let produced = self.inner.produced.load(Ordering::Acquire); + if offset < released || offset.saturating_add(len as u64) > produced { + return None; + } + Some(SendGrant { + inner: self.inner.as_ref(), + offset, + len, + position: 0, + }) + } + + pub fn reserve_at(&mut self, offset: u64, max_len: usize) -> Option> { + if offset < self.sent { + return self.retry_send(offset, max_len); + } + if offset == self.sent { + return self.reserve_send(offset.saturating_add(max_len as u64), max_len); + } + None + } + + pub fn release_to(&mut self, released: u64) { + self.released = released; + self.inner.released.store(released, Ordering::Release); + self.inner.writable.wake(); + } + + pub fn sent_offset(&self) -> u64 { + self.sent + } + + pub fn writer_finished(&self) -> bool { + PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) == PipeState::Finished + } + + pub fn all_sent(&mut self) -> bool { + self.produced = self.inner.produced.load(Ordering::Acquire); + self.sent >= self.produced + } + + pub fn close(&mut self) { + loop { + match PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) { + PipeState::Closed => return, + PipeState::Failed => { + if self + .inner + .state + .compare_exchange( + PIPE_FAILED, + PIPE_CLOSED, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + unsafe { + (*self.inner.error.get()).assume_init_drop(); + } + self.inner.writable.wake(); + self.inner.closed.wake(); + return; + } + } + current => { + if self + .inner + .state + .compare_exchange( + current.as_u8(), + PIPE_CLOSED, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + self.inner.writable.wake(); + self.inner.closed.wake(); + return; + } + } + } + } + } + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.produced = self.inner.produced.load(Ordering::Acquire); + if self.available_data() > 0 { + return Poll::Ready(ReadReady::Data); + } + + loop { + match PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) { + PipeState::Open => { + self.inner.readable.register(cx.waker()); + self.produced = self.inner.produced.load(Ordering::Acquire); + if self.available_data() > 0 { + self.inner.readable.take(); + return Poll::Ready(ReadReady::Data); + } + if PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) + == PipeState::Open + { + return Poll::Pending; + } + self.inner.readable.take(); + } + PipeState::Finished | PipeState::Closed => return Poll::Ready(ReadReady::Eof), + PipeState::Failed => { + let err = match self.inner.state.compare_exchange( + PIPE_FAILED, + PIPE_FAILED_TAKEN, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => unsafe { (*self.inner.error.get()).assume_init_read() }, + Err(_) => continue, + }; + return Poll::Ready(ReadReady::Error(err)); + } + PipeState::FailedTaken => return Poll::Ready(ReadReady::Eof), + } + } + } + + fn available_data(&self) -> usize { + self.produced.saturating_sub(self.released) as usize + } + + #[cfg(test)] + pub fn state(&self) -> PipeState { + PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) + } +} + +impl Drop for PipeReader { + fn drop(&mut self) { + self.close(); + } +} + +impl SendGrant<'_, E> { + pub fn offset(&self) -> u64 { + self.offset + } + + pub fn len(&self) -> usize { + self.len + } +} + +impl Read for SendGrant<'_, E> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let remaining = self.len.saturating_sub(self.position); + if remaining == 0 || buf.is_empty() { + return Ok(0); + } + let n = remaining.min(buf.len()); + unsafe { + copy_bytes( + self.inner.buffer, + self.inner.cap, + self.offset.saturating_add(self.position as u64), + &mut buf[..n], + ); + } + self.position += n; + Ok(n) + } +} + +unsafe fn write_bytes(buffer: *mut u8, cap: usize, offset: u64, src: &[u8]) { + let start = (offset as usize) % cap; + let first = src.len().min(cap - start); + ptr::copy_nonoverlapping(src.as_ptr(), buffer.add(start), first); + if first < src.len() { + ptr::copy_nonoverlapping(src[first..].as_ptr(), buffer, src.len() - first); + } +} + +unsafe fn copy_bytes(buffer: *mut u8, cap: usize, offset: u64, dst: &mut [u8]) { + let len = dst.len(); + let start = (offset as usize) % cap; + let first = len.min(cap - start); + ptr::copy_nonoverlapping(buffer.add(start), dst.as_mut_ptr(), first); + if first < len { + ptr::copy_nonoverlapping(buffer, dst.as_mut_ptr().add(first), len - first); + } +} + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use futures_lite::future::poll_fn; + use tokio::task::yield_now; + + use super::*; + + #[tokio::test(flavor = "current_thread")] + async fn pipe_writes_reads_and_releases() { + let (mut reader, mut writer) = pipe::(8); + assert_eq!( + poll_fn(|cx| writer.poll_write(cx, b"abcd")).await.unwrap(), + 4 + ); + + let mut send = reader.reserve_send(8, 8).unwrap(); + assert_eq!(send.offset(), 0); + assert_eq!(send.len(), 4); + let mut bytes = vec![0; send.len()]; + send.read_exact(&mut bytes).unwrap(); + assert_eq!(bytes, b"abcd"); + + reader.release_to(4); + assert!(writer.is_drained()); + assert_eq!(poll_fn(|cx| writer.poll_write(cx, b"ef")).await.unwrap(), 2); + let mut send = reader.reserve_send(8, 8).unwrap(); + let mut bytes = vec![0; send.len()]; + send.read_exact(&mut bytes).unwrap(); + assert_eq!(bytes, b"ef"); + } + + #[tokio::test(flavor = "current_thread")] + async fn pipe_blocks_until_released() { + let (mut reader, mut writer) = pipe::(4); + assert_eq!( + poll_fn(|cx| writer.poll_write(cx, b"abcd")).await.unwrap(), + 4 + ); + + let mut blocked = false; + let poll = poll_fn(|cx| match writer.poll_write(cx, b"e") { + Poll::Ready(result) => Poll::Ready(result), + Poll::Pending => { + blocked = true; + Poll::Ready(Ok(0)) + } + }) + .await + .unwrap(); + assert_eq!(poll, 0); + assert!(blocked); + + reader.release_to(4); + yield_now().await; + assert_eq!(poll_fn(|cx| writer.poll_write(cx, b"e")).await.unwrap(), 1); + } + + #[tokio::test(flavor = "current_thread")] + async fn pipe_closed_waits_for_reader_close() { + let (mut reader, mut writer) = pipe::(8); + writer.finish(); + assert_eq!(writer.state(), PipeState::Finished); + + let waiter = tokio::spawn(async move { + writer.closed().await; + }); + + yield_now().await; + assert!(!waiter.is_finished()); + reader.close(); + waiter.await.unwrap(); + } + + #[tokio::test(flavor = "current_thread")] + async fn pipe_wraparound_reads_correctly() { + let (mut reader, mut writer) = pipe::(8); + assert_eq!( + poll_fn(|cx| writer.poll_write(cx, b"abcdef")) + .await + .unwrap(), + 6 + ); + let mut send = reader.reserve_send(8, 6).unwrap(); + let mut bytes = vec![0; send.len()]; + send.read_exact(&mut bytes).unwrap(); + assert_eq!(bytes, b"abcdef"); + reader.release_to(6); + + assert_eq!( + poll_fn(|cx| writer.poll_write(cx, b"ghijkl")) + .await + .unwrap(), + 6 + ); + let mut send = reader.reserve_send(12, 6).unwrap(); + let mut bytes = vec![0; send.len()]; + send.read_exact(&mut bytes).unwrap(); + assert_eq!(bytes, b"ghijkl"); + } + + #[tokio::test(flavor = "current_thread")] + async fn closing_reader_wakes_writer() { + let (mut reader, mut writer) = pipe::(4); + assert_eq!( + poll_fn(|cx| writer.poll_write(cx, b"abcd")).await.unwrap(), + 4 + ); + reader.close(); + assert_eq!(reader.state(), PipeState::Closed); + let err = poll_fn(|cx| writer.poll_write(cx, b"e")).await.unwrap_err(); + assert_eq!(err, PipeClosed); + } + + #[tokio::test(flavor = "current_thread")] + async fn buffered_bytes_drain_before_eof() { + let (mut reader, mut writer) = pipe::(8); + poll_fn(|cx| writer.poll_write(cx, b"abc")).await.unwrap(); + writer.finish(); + + assert!(matches!( + poll_fn(|cx| reader.poll_ready(cx)).await, + ReadReady::Data + )); + assert_eq!(reader.peek_buf(), b"abc"); + reader.consume(3); + assert!(matches!( + poll_fn(|cx| reader.poll_ready(cx)).await, + ReadReady::Eof + )); + } + + #[tokio::test(flavor = "current_thread")] + async fn buffered_bytes_drain_before_error() { + let (mut reader, mut writer) = pipe::<&'static str>(8); + poll_fn(|cx| writer.poll_write(cx, b"abc")).await.unwrap(); + writer.fail("boom"); + + assert!(matches!( + poll_fn(|cx| reader.poll_ready(cx)).await, + ReadReady::Data + )); + assert_eq!(reader.peek_buf(), b"abc"); + reader.consume(3); + match poll_fn(|cx| reader.poll_ready(cx)).await { + ReadReady::Error(err) => assert_eq!(err, "boom"), + _ => panic!("expected pipe error"), + } + } +} diff --git a/ql2/src/tests/handshake.rs b/ql2/src/tests/handshake.rs new file mode 100644 index 00000000..ab656a02 --- /dev/null +++ b/ql2/src/tests/handshake.rs @@ -0,0 +1,99 @@ +use std::time::Duration; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn handshake_initiator_connects() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_timeout_disconnects() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(60), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn confirm_write_failure_disconnects_initiator() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new_with_failed_write(1, 2); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; + await_status(&status_b, peer_a.xid, PeerStage::Responder).await; + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + }) + .await; +} diff --git a/ql2/src/tests/heartbeat.rs b/ql2/src/tests/heartbeat.rs new file mode 100644 index 00000000..3f4dc0c7 --- /dev/null +++ b/ql2/src/tests/heartbeat.rs @@ -0,0 +1,455 @@ +use bc_components::SymmetricKey; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_ignored_without_session() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, _status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + + let peer_a = platform_a.xid(); + let peer_b = platform_b.xid(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.bind_peer(Peer { + peer: peer_b, + signing_key: platform_b.signing_public_key().clone(), + encapsulation_key: platform_b.encapsulation_public_key().clone(), + }); + + let heartbeat = wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: peer_b, + recipient: peer_a, + }, + &SymmetricKey::new(), + HeartbeatBody { + packet_id: PacketId(1), + valid_until: now_secs().saturating_add(60), + }, + test_encryption_nonce(1), + ); + handle_a.send_incoming(wire::encode_record(&heartbeat)); + + let result = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; + assert!(result.is_err(), "expected heartbeat to be ignored"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn keepalive_disabled_no_heartbeat() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let result = tokio::time::timeout(Duration::from_millis(120), heartbeat_rx.recv()).await; + assert!(result.is_err(), "unexpected heartbeat while disabled"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_sent_after_idle() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(30), + timeout: Duration::from_millis(80), + }; + let config_a = RuntimeConfig { + engine: EngineConfig { + keep_alive: Some(keep_alive), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let config_b = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_reply_when_connected() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(30), + timeout: Duration::from_millis(80), + }; + let config_a = RuntimeConfig { + engine: EngineConfig { + keep_alive: Some(keep_alive), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let config_b = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); + let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); + spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(200), heartbeat_ab_rx.recv()) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn any_stream_clears_pending() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(120), + timeout: Duration::from_millis(40), + }; + let config_a = RuntimeConfig { + engine: EngineConfig { + keep_alive: Some(keep_alive), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let config_b = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a, inbound_a) = InboundPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) + .await + .unwrap() + .unwrap(); + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_a.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + response.finish().await.unwrap(); + }); + + let pending = handle_b + .open_stream(Vec::new(), Default::default()) + .await + .unwrap(); + pending.request.finish().await.unwrap(); + let _ = pending.accepted.await.unwrap(); + + let window = keep_alive.timeout + Duration::from_millis(20); + let disconnect = tokio::time::timeout(window, async { + loop { + if let Ok(event) = status_a.recv().await { + if event.peer == peer_b.xid && event.stage == PeerStage::Disconnected { + return; + } + } + } + }) + .await; + assert!(disconnect.is_err(), "unexpected disconnect"); + + let _ = responder_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_timeout_disconnects_and_drops_outbound() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(80), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + keep_alive: Some(keep_alive), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let config_b = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(2); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(1); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let drop_flag = Arc::new(AtomicBool::new(false)); + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + response.finish().await.unwrap(); + }); + + drop_flag.store(true, Ordering::Relaxed); + + let pending = handle_a + .open_stream(Vec::new(), Default::default()) + .await + .unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + + let result = tokio::time::timeout(Duration::from_millis(300), pending.accepted) + .await + .unwrap(); + assert!(matches!(result, Err(QlError::SendFailed))); + + responder_task.abort(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn no_ping_pong() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(200), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig { + engine: EngineConfig { + keep_alive: Some(keep_alive), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let config_b = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); + let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); + spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(300), heartbeat_ab_rx.recv()) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) + .await + .unwrap() + .unwrap(); + + let followup = + tokio::time::timeout(Duration::from_millis(50), heartbeat_ab_rx.recv()).await; + assert!(followup.is_err(), "unexpected heartbeat ping-pong"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn invalid_heartbeat_ignored() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let heartbeat = wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: peer_b.xid, + recipient: peer_a.xid, + }, + &SymmetricKey::new(), + HeartbeatBody { + packet_id: PacketId(42), + valid_until: now_secs().saturating_add(30), + }, + test_encryption_nonce(42), + ); + handle_a.send_incoming(wire::encode_record(&heartbeat)); + + let result = tokio::time::timeout(Duration::from_millis(50), heartbeat_rx.recv()).await; + assert!(result.is_err(), "unexpected heartbeat reply"); + }) + .await; +} diff --git a/ql2/src/tests/mod.rs b/ql2/src/tests/mod.rs new file mode 100644 index 00000000..f50a4d10 --- /dev/null +++ b/ql2/src/tests/mod.rs @@ -0,0 +1,1027 @@ +use std::{ + future::Future, + sync::{ + atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, + Arc, Mutex, + }, + time::Duration, +}; + +use async_channel::{Receiver, Sender}; +use bc_components::{ + Digest, MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, SymmetricKey, MLDSA, + MLKEM, XID, +}; +use rkyv::{Archive, Serialize}; +use tokio::task::LocalSet; + +use crate::{ + platform::{PlatformFuture, QlCrypto, QlPlatform}, + runtime::{ + new_runtime, EngineConfig, HandlerEvent, KeepAliveConfig, PeerSession, RuntimeConfig, + RuntimeHandle, + }, + wire::{ + self, handshake::HandshakeRecord, heartbeat::HeartbeatBody, now_secs, pair, + AsWireMlKemCiphertext, AsWireNonce, AsWireXid, QlHeader, QlPayload, QlRecord, + }, + PacketId, Peer, QlError, +}; + +mod handshake; +mod heartbeat; +mod persistence; +mod rpc; +mod stream; +mod unpair; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PeerStage { + Disconnected, + Initiator, + Responder, + Connected, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct StatusEvent { + peer: XID, + stage: PeerStage, +} + +struct TestPlatform { + signing_private: MLDSAPrivateKey, + signing_public: MLDSAPublicKey, + encapsulation_private: MLKEMPrivateKey, + encapsulation_public: MLKEMPublicKey, + outbound: Sender>, + status: Sender, + nonce_seed: u8, + nonce_counter: AtomicU8, + fail_on_write: Option, + write_counter: AtomicUsize, +} + +impl TestPlatform { + fn new(seed: u8) -> (Self, Receiver>, Receiver) { + Self::new_with_fail_on_write(seed, None) + } + + fn new_with_failed_write( + seed: u8, + fail_on_write: usize, + ) -> (Self, Receiver>, Receiver) { + Self::new_with_fail_on_write(seed, Some(fail_on_write)) + } + + fn new_with_fail_on_write( + seed: u8, + fail_on_write: Option, + ) -> (Self, Receiver>, Receiver) { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + fail_on_write, + write_counter: AtomicUsize::new(0), + }, + outbound_rx, + status_rx, + ) + } + + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } +} + +impl QlCrypto for TestPlatform { + fn signing_private_key(&self) -> &MLDSAPrivateKey { + &self.signing_private + } + + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } + + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } +} + +impl QlPlatform for TestPlatform { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let fail_on_write = self.fail_on_write; + let write_index = self.write_counter.fetch_add(1, Ordering::Relaxed) + 1; + let outbound = self.outbound.clone(); + Box::pin(async move { + if fail_on_write == Some(write_index) { + return Err(QlError::SendFailed); + } + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn load_peer(&self) -> PlatformFuture<'_, Option> { + Box::pin(async { None }) + } + + fn persist_peer(&self, _peer: Peer) {} + + fn clear_peer(&self) {} + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, _event: HandlerEvent) {} +} + +struct InboundPlatform { + signing_private: MLDSAPrivateKey, + signing_public: MLDSAPublicKey, + encapsulation_private: MLKEMPrivateKey, + encapsulation_public: MLKEMPublicKey, + outbound: Sender>, + status: Sender, + inbound: Sender, + nonce_seed: u8, + nonce_counter: AtomicU8, +} + +impl InboundPlatform { + fn new( + seed: u8, + ) -> ( + Self, + Receiver>, + Receiver, + Receiver, + ) { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + let (inbound, inbound_rx) = async_channel::unbounded(); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + inbound, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + }, + outbound_rx, + status_rx, + inbound_rx, + ) + } +} + +impl QlCrypto for InboundPlatform { + fn signing_private_key(&self) -> &MLDSAPrivateKey { + &self.signing_private + } + + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { + &self.encapsulation_private + } + + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } + + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } +} + +impl QlPlatform for InboundPlatform { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + Box::pin(async move { + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn load_peer(&self) -> PlatformFuture<'_, Option> { + Box::pin(async { None }) + } + + fn persist_peer(&self, _peer: Peer) {} + + fn clear_peer(&self) {} + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, event: HandlerEvent) { + let _ = self.inbound.try_send(event); + } +} + +async fn run_local_test(future: F) +where + F: Future, +{ + let local = LocalSet::new(); + local.run_until(future).await; +} + +fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + handle.send_incoming(bytes); + } + }); +} + +fn is_stream(bytes: &[u8]) -> bool { + let Ok(record) = wire::decode_record(bytes) else { + return false; + }; + matches!(record.payload, QlPayload::Stream(_)) +} + +fn is_heartbeat(bytes: &[u8]) -> bool { + let Ok(record) = wire::decode_record(bytes) else { + return false; + }; + matches!(record.payload, QlPayload::Heartbeat(_)) +} + +fn spawn_drop_first_stream_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + let mut dropped = false; + while let Ok(bytes) = outbound.recv().await { + if !dropped && is_stream(&bytes) { + dropped = true; + continue; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_drop_first_stream_when( + outbound: Receiver>, + handle: RuntimeHandle, + armed: Arc, +) { + tokio::task::spawn_local(async move { + let mut dropped = false; + while let Ok(bytes) = outbound.recv().await { + if armed.load(Ordering::Relaxed) && !dropped && is_stream(&bytes) { + dropped = true; + continue; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_duplicate_first_stream_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + let mut duplicated = false; + while let Ok(bytes) = outbound.recv().await { + if !duplicated && is_stream(&bytes) { + duplicated = true; + handle.send_incoming(bytes.clone()); + } + handle.send_incoming(bytes); + } + }); +} + +#[derive(Clone)] +struct SessionKeyMaterial { + initiator_encapsulation_private: MLKEMPrivateKey, + responder_encapsulation_private: MLKEMPrivateKey, +} + +fn session_key_material( + initiator: &TestPlatform, + responder: &InboundPlatform, +) -> SessionKeyMaterial { + SessionKeyMaterial { + initiator_encapsulation_private: initiator.encapsulation_private.clone(), + responder_encapsulation_private: responder.encapsulation_private.clone(), + } +} + +#[derive(Default)] +struct SessionTrace { + hello_header: Option, + hello: Option, + reply: Option, + session_key: Option, +} + +#[derive(Archive, Serialize)] +struct TestHandshakeTranscript { + #[rkyv(with = AsWireXid)] + initiator: XID, + #[rkyv(with = AsWireXid)] + responder: XID, + #[rkyv(with = AsWireNonce)] + initiator_nonce: bc_components::Nonce, + #[rkyv(with = AsWireNonce)] + responder_nonce: bc_components::Nonce, + #[rkyv(with = AsWireMlKemCiphertext)] + initiator_kem_ct: bc_components::MLKEMCiphertext, + #[rkyv(with = AsWireMlKemCiphertext)] + responder_kem_ct: bc_components::MLKEMCiphertext, +} + +#[derive(Archive, Serialize)] +struct TestSessionKeyMaterial { + initiator_secret: Vec, + responder_secret: Vec, + transcript: Vec, +} + +fn derive_session_key( + trace: &SessionTrace, + key_material: &SessionKeyMaterial, +) -> Option { + let header = trace.hello_header.as_ref()?; + let hello = trace.hello.as_ref()?; + let reply = trace.reply.as_ref()?; + let initiator_secret = key_material + .responder_encapsulation_private + .decapsulate_shared_secret(&hello.kem_ct) + .ok()?; + let responder_secret = key_material + .initiator_encapsulation_private + .decapsulate_shared_secret(&reply.kem_ct) + .ok()?; + let transcript = wire::encode_value(&TestHandshakeTranscript { + initiator: header.sender, + responder: header.recipient, + initiator_nonce: hello.nonce.clone(), + responder_nonce: reply.nonce.clone(), + initiator_kem_ct: hello.kem_ct.clone(), + responder_kem_ct: reply.kem_ct.clone(), + }); + let payload = wire::encode_value(&TestSessionKeyMaterial { + initiator_secret: initiator_secret.as_bytes().to_vec(), + responder_secret: responder_secret.as_bytes().to_vec(), + transcript, + }); + let digest = Digest::from_image(payload); + Some(SymmetricKey::from_data(*digest.data())) +} + +fn test_encryption_nonce(seed: u8) -> [u8; wire::encrypted_message::NONCE_SIZE] { + [seed; wire::encrypted_message::NONCE_SIZE] +} + +fn spawn_stream_mutating_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + key_material: SessionKeyMaterial, + trace: Arc>, + mutator: F, +) where + F: FnMut(&QlHeader, &mut wire::stream::StreamBody) -> bool + 'static, +{ + tokio::task::spawn_local(async move { + let mut mutator = mutator; + while let Ok(bytes) = outbound.recv().await { + let Ok(record) = wire::access_record(&bytes) else { + handle.send_incoming(bytes); + continue; + }; + + { + let mut trace = trace.lock().unwrap(); + match &record.payload { + wire::ArchivedQlPayload::Handshake( + wire::handshake::ArchivedHandshakeRecord::Hello(hello), + ) => { + trace.hello_header = Some(wire::deserialize_value(&record.header).unwrap()); + trace.hello = Some(wire::deserialize_value(hello).unwrap()); + } + wire::ArchivedQlPayload::Handshake( + wire::handshake::ArchivedHandshakeRecord::HelloReply(reply), + ) => { + trace.reply = Some(wire::deserialize_value(reply).unwrap()); + } + _ => {} + } + if trace.session_key.is_none() { + trace.session_key = derive_session_key(&trace, &key_material); + } + } + + let session_key = trace.lock().unwrap().session_key.clone(); + if let (Some(session_key), wire::ArchivedQlPayload::Stream(encrypted)) = + (session_key, &record.payload) + { + let header = wire::deserialize_value(&record.header).unwrap(); + let encrypted = wire::deserialize_value(encrypted).unwrap(); + let plaintext = encrypted.decrypt(&session_key, &header.aad()); + if let Ok(plaintext) = plaintext { + let body = wire::access_value::(&plaintext) + .and_then(wire::deserialize_value); + if let Ok(mut body) = body { + if mutator(&header, &mut body) { + let mutated = wire::stream::encrypt_stream( + header, + &session_key, + body.clone(), + test_encryption_nonce(body.packet_id.0 as u8), + ); + handle.send_incoming(wire::encode_record(&mutated)); + continue; + } + } + } + } + + handle.send_incoming(bytes); + } + }); +} + +fn spawn_drop_every_nth_stream_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + nth: usize, +) { + tokio::task::spawn_local(async move { + let mut stream_count = 0usize; + while let Ok(bytes) = outbound.recv().await { + if nth > 0 && is_stream(&bytes) { + stream_count = stream_count.saturating_add(1); + if stream_count % nth == 0 { + continue; + } + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_heartbeat_tap_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + heartbeat_tx: Sender<()>, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if is_heartbeat(&bytes) { + let _ = heartbeat_tx.send(()).await; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if is_heartbeat(&bytes) { + continue; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_gated_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + drop_flag: Arc, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if drop_flag.load(Ordering::Relaxed) { + continue; + } + handle.send_incoming(bytes); + } + }); +} + +#[derive(Clone)] +struct PeerIdentity { + xid: XID, + signing_key: MLDSAPublicKey, + encapsulation_key: MLKEMPublicKey, +} + +fn peer_identity(platform: &impl QlCrypto) -> PeerIdentity { + PeerIdentity { + xid: platform.xid(), + signing_key: platform.signing_public_key().clone(), + encapsulation_key: platform.encapsulation_public_key().clone(), + } +} + +fn register_peers( + handle_a: &RuntimeHandle, + handle_b: &RuntimeHandle, + identity_a: &PeerIdentity, + identity_b: &PeerIdentity, +) { + handle_a.bind_peer(Peer { + peer: identity_b.xid, + signing_key: identity_b.signing_key.clone(), + encapsulation_key: identity_b.encapsulation_key.clone(), + }); + handle_b.bind_peer(Peer { + peer: identity_a.xid, + signing_key: identity_a.signing_key.clone(), + encapsulation_key: identity_a.encapsulation_key.clone(), + }); +} + +type PersistPlatformParts = ( + PersistPlatform, + Receiver>, + Receiver, + Receiver>, +); + +struct PersistPlatform { + signing_private: MLDSAPrivateKey, + signing_public: MLDSAPublicKey, + encapsulation_private: MLKEMPrivateKey, + encapsulation_public: MLKEMPublicKey, + outbound: Sender>, + status: Sender, + persisted: Sender>, + loaded_peer: Option, + nonce_seed: u8, + nonce_counter: AtomicU8, +} + +impl PersistPlatform { + fn new(seed: u8, loaded_peer: Option) -> PersistPlatformParts { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + let (persisted, persisted_rx) = async_channel::unbounded(); + ( + Self { + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + outbound, + status, + persisted, + loaded_peer, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + }, + outbound_rx, + status_rx, + persisted_rx, + ) + } +} + +impl QlCrypto for PersistPlatform { + fn signing_private_key(&self) -> &MLDSAPrivateKey { + &self.signing_private + } + fn signing_public_key(&self) -> &MLDSAPublicKey { + &self.signing_public + } + fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { + &self.encapsulation_private + } + fn encapsulation_public_key(&self) -> &MLKEMPublicKey { + &self.encapsulation_public + } + + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } +} + +impl QlPlatform for PersistPlatform { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + Box::pin(async move { + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn load_peer(&self) -> PlatformFuture<'_, Option> { + let peer = self.loaded_peer.clone(); + Box::pin(async move { peer }) + } + + fn persist_peer(&self, peer: crate::Peer) { + let _ = self.persisted.try_send(Some(peer)); + } + + fn clear_peer(&self) { + let _ = self.persisted.try_send(None); + } + + fn handle_peer_status(&self, peer: XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, _event: HandlerEvent) {} +} + +async fn await_status( + receiver: &Receiver, + peer: XID, + stage: PeerStage, +) -> StatusEvent { + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if let Ok(event) = receiver.recv().await { + if event.peer == peer && event.stage == stage { + return event; + } + } + } + }) + .await + .unwrap() +} + +#[test] +fn protocol_record_size_breakdown() { + let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + + let initiator = platform_a.xid(); + let responder = platform_b.xid(); + + let (hello, initiator_secret) = wire::handshake::build_hello( + &platform_a, + initiator, + responder, + platform_b.encapsulation_public_key(), + ) + .unwrap(); + let hello_record = QlRecord { + header: QlHeader { + sender: initiator, + recipient: responder, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(hello.clone())), + }; + let hello_size = wire::encode_record(&hello_record).len(); + let hello_bytes = wire::encode_value(&hello); + let hello_view = wire::access_value::(&hello_bytes).unwrap(); + + let (hello_reply, responder_secrets) = wire::handshake::respond_hello( + &platform_b, + initiator, + responder, + platform_a.encapsulation_public_key(), + hello_view, + ) + .unwrap(); + let reply_record = QlRecord { + header: QlHeader { + sender: responder, + recipient: initiator, + }, + payload: QlPayload::Handshake(HandshakeRecord::HelloReply(hello_reply.clone())), + }; + let reply_size = wire::encode_record(&reply_record).len(); + let reply_bytes = wire::encode_value(&hello_reply); + let reply_view = + wire::access_value::(&reply_bytes).unwrap(); + + let (confirm, session_key) = wire::handshake::build_confirm( + &platform_a, + initiator, + responder, + platform_b.signing_public_key(), + &hello, + reply_view, + &initiator_secret, + ) + .unwrap(); + let confirm_bytes = wire::encode_value(&confirm); + let confirm_view = + wire::access_value::(&confirm_bytes).unwrap(); + let confirm_record = QlRecord { + header: QlHeader { + sender: initiator, + recipient: responder, + }, + payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm.clone())), + }; + let confirm_size = wire::encode_record(&confirm_record).len(); + let _session_key_b = wire::handshake::finalize_confirm( + initiator, + responder, + platform_a.signing_public_key(), + &hello, + &hello_reply, + confirm_view, + &responder_secrets, + ) + .unwrap(); + + let pair_size = wire::encode_record( + &pair::build_pair_request( + &platform_a, + responder, + platform_b.encapsulation_public_key(), + PacketId(11), + Duration::from_secs(60), + ) + .unwrap(), + ) + .len(); + + let heartbeat_size = wire::encode_record(&wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: initiator, + recipient: responder, + }, + &session_key, + HeartbeatBody { + packet_id: PacketId(12), + valid_until: wire::now_secs().saturating_add(60), + }, + test_encryption_nonce(12), + )) + .len(); + + let unpair_size = wire::encode_record(&wire::unpair::build_unpair_record( + &platform_a, + QlHeader { + sender: initiator, + recipient: responder, + }, + PacketId(13), + wire::now_secs().saturating_add(60), + )) + .len(); + + let stream_record_size = + |packet_id: PacketId, + packet_ack: Option, + frame: Option| { + wire::encode_record(&wire::stream::encrypt_stream( + QlHeader { + sender: initiator, + recipient: responder, + }, + &session_key, + wire::stream::StreamBody { + packet_id, + valid_until: wire::now_secs().saturating_add(60), + packet_ack, + frame, + }, + test_encryption_nonce(packet_id.0 as u8), + )) + .len() + }; + + let stream_header = QlHeader { + sender: initiator, + recipient: responder, + }; + let stream_ack_body = wire::stream::StreamBody { + packet_id: PacketId(20), + valid_until: wire::now_secs().saturating_add(60), + packet_ack: Some(wire::stream::PacketAck { + packet_id: PacketId(19), + }), + frame: None, + }; + let stream_ack_record = wire::stream::encrypt_stream( + stream_header.clone(), + &session_key, + stream_ack_body.clone(), + test_encryption_nonce(20), + ); + let stream_ack_encrypted = match &stream_ack_record.payload { + QlPayload::Stream(encrypted) => encrypted, + _ => unreachable!(), + }; + let stream_ack_header_size = wire::encode_value(&stream_header).len(); + let stream_ack_body_size = wire::encode_value(&stream_ack_body).len(); + let stream_ack_envelope_size = wire::encode_value(stream_ack_encrypted).len(); + let stream_ack_payload_size = wire::encode_value(&stream_ack_record.payload).len(); + + let stream_open_body = wire::stream::StreamBody { + packet_id: PacketId(21), + valid_until: wire::now_secs().saturating_add(60), + packet_ack: None, + frame: Some(wire::stream::StreamFrame::Open( + wire::stream::StreamFrameOpen { + stream_id: crate::StreamId(2), + request_head: vec![1, 2, 3], + response_max_offset: 1024, + }, + )), + }; + let stream_open_body_size = wire::encode_value(&stream_open_body).len(); + + let stream_ack_size = stream_record_size( + PacketId(20), + Some(wire::stream::PacketAck { + packet_id: PacketId(19), + }), + None, + ); + let stream_open_size = stream_record_size( + PacketId(21), + None, + Some(wire::stream::StreamFrame::Open( + wire::stream::StreamFrameOpen { + stream_id: crate::StreamId(2), + request_head: vec![1, 2, 3], + response_max_offset: 1024, + }, + )), + ); + let stream_accept_size = stream_record_size( + PacketId(22), + None, + Some(wire::stream::StreamFrame::Accept( + wire::stream::StreamFrameAccept { + stream_id: crate::StreamId(2), + response_head: vec![4, 5, 6], + request_max_offset: 2048, + }, + )), + ); + let stream_reject_size = stream_record_size( + PacketId(23), + None, + Some(wire::stream::StreamFrame::Reject( + wire::stream::StreamFrameReject { + stream_id: crate::StreamId(2), + code: wire::stream::RejectCode::InvalidHead, + }, + )), + ); + let stream_data_size = stream_record_size( + PacketId(24), + None, + Some(wire::stream::StreamFrame::Data( + wire::stream::StreamFrameData { + stream_id: crate::StreamId(2), + dir: wire::stream::Direction::Request, + offset: 128, + bytes: vec![7, 8, 9, 10], + }, + )), + ); + let stream_credit_size = stream_record_size( + PacketId(25), + None, + Some(wire::stream::StreamFrame::Credit( + wire::stream::StreamFrameCredit { + stream_id: crate::StreamId(2), + dir: wire::stream::Direction::Response, + recv_offset: 256, + max_offset: 4096, + }, + )), + ); + let stream_finish_size = stream_record_size( + PacketId(26), + None, + Some(wire::stream::StreamFrame::Finish( + wire::stream::StreamFrameFinish { + stream_id: crate::StreamId(2), + dir: wire::stream::Direction::Response, + }, + )), + ); + let stream_reset_size = stream_record_size( + PacketId(27), + None, + Some(wire::stream::StreamFrame::Reset( + wire::stream::StreamFrameReset { + stream_id: crate::StreamId(2), + dir: wire::stream::ResetTarget::Both, + code: wire::stream::ResetCode::Protocol, + }, + )), + ); + + let print_size = |label: &str, size: usize| { + println!("{label:<23}: {size} bytes"); + }; + + print_size("ql2 size hello", hello_size); + print_size("ql2 size hello_reply", reply_size); + print_size("ql2 size confirm", confirm_size); + print_size("ql2 size pair", pair_size); + print_size("ql2 size heartbeat", heartbeat_size); + print_size("ql2 size unpair", unpair_size); + print_size("ql2 size stream ack", stream_ack_size); + print_size("ql2 size stream open", stream_open_size); + print_size("ql2 size stream accept", stream_accept_size); + print_size("ql2 size stream reject", stream_reject_size); + print_size("ql2 size stream data", stream_data_size); + print_size("ql2 size stream credit", stream_credit_size); + print_size("ql2 size stream finish", stream_finish_size); + print_size("ql2 size stream reset", stream_reset_size); + println!( + "ql2 stream ack breakdown : header={} derived_aad={} plaintext={} ciphertext={} envelope(no aad)={} payload={} full={}", + stream_ack_header_size, + stream_header.aad().len(), + stream_ack_body_size, + stream_ack_body_size, + stream_ack_envelope_size, + stream_ack_payload_size, + stream_ack_size, + ); + println!( + "ql2 stream open delta : open_body={} ack_body={} (+{} request_head bytes)", + stream_open_body_size, + stream_ack_body_size, + stream_open_body_size.saturating_sub(stream_ack_body_size), + ); +} diff --git a/ql2/src/tests/persistence.rs b/ql2/src/tests/persistence.rs new file mode 100644 index 00000000..8fc5cf9f --- /dev/null +++ b/ql2/src/tests/persistence.rs @@ -0,0 +1,139 @@ +use std::time::Duration; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn register_peer_persists_snapshot() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, _outbound_a, _status_a, persisted_a) = PersistPlatform::new(1, None); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + let peer_b = platform_b.xid(); + let signing_b = platform_b.signing_public_key().clone(); + let encap_b = platform_b.encapsulation_public_key().clone(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + tokio::task::spawn_local(async move { runtime_a.run().await }); + + handle_a.bind_peer(crate::Peer { + peer: peer_b, + signing_key: signing_b.clone(), + encapsulation_key: encap_b.clone(), + }); + + let snapshot = tokio::time::timeout(Duration::from_secs(1), persisted_a.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!( + snapshot, + Some(crate::Peer { + peer: peer_b, + signing_key: signing_b, + encapsulation_key: encap_b, + }) + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn loaded_peers_can_connect_without_register() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_b = peer_identity(&platform_b); + + let (platform_a, outbound_a, status_a, _persisted_a) = PersistPlatform::new( + 1, + Some(crate::Peer { + peer: peer_b.xid, + signing_key: peer_b.signing_key.clone(), + encapsulation_key: peer_b.encapsulation_key.clone(), + }), + ); + let peer_a = peer_identity(&platform_a); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_b.bind_peer(crate::Peer { + peer: peer_a.xid, + signing_key: peer_a.signing_key.clone(), + encapsulation_key: peer_a.encapsulation_key.clone(), + }); + + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn pairing_persists_snapshot() { + run_local_test(async { + let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); + let peer_a = peer_identity(&platform_a); + + let (platform_b, _outbound_b, _status_b, persisted_b) = PersistPlatform::new(2, None); + let peer_b = peer_identity(&platform_b); + + let pairing_message = pair::build_pair_request( + &platform_a, + peer_b.xid, + &peer_b.encapsulation_key, + PacketId(1), + Duration::from_secs(1), + ) + .unwrap(); + let pairing_bytes = wire::encode_record(&pairing_message); + + let (runtime_b, handle_b) = new_runtime( + platform_b, + RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }, + ); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + handle_b.send_incoming(pairing_bytes); + + let snapshot = tokio::time::timeout(Duration::from_secs(1), persisted_b.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!( + snapshot, + Some(crate::Peer { + peer: peer_a.xid, + signing_key: peer_a.signing_key, + encapsulation_key: peer_a.encapsulation_key, + }) + ); + }) + .await; +} diff --git a/ql2/src/tests/rpc.rs b/ql2/src/tests/rpc.rs new file mode 100644 index 00000000..7308c688 --- /dev/null +++ b/ql2/src/tests/rpc.rs @@ -0,0 +1,264 @@ +use std::time::Duration; + +use dcbor::CBOR; + +use super::*; +use crate::{ + rpc::{MethodId, RequestResponse, RpcHandle, RpcRequestHead, RpcResponseHead}, + runtime::StreamConfig, + wire::stream::RejectCode, + QlError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +struct AddOne(u64); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct AddOneResponse(u64); + +impl From for CBOR { + fn from(value: AddOne) -> Self { + CBOR::from(value.0) + } +} + +impl TryFrom for AddOne { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + Ok(Self(value.try_into()?)) + } +} + +impl From for CBOR { + fn from(value: AddOneResponse) -> Self { + CBOR::from(value.0) + } +} + +impl TryFrom for AddOneResponse { + type Error = dcbor::Error; + + fn try_from(value: CBOR) -> Result { + Ok(Self(value.try_into()?)) + } +} + +impl RequestResponse for AddOne { + const METHOD: MethodId = MethodId(1); + type Response = AddOneResponse; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_request_response_round_trip() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + let rpc_a = RpcHandle::new(handle_a.clone()); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let request_body = CBOR::from(AddOne(41)).to_cbor_data(); + let response_body = CBOR::from(AddOneResponse(42)).to_cbor_data(); + let request_head = + RpcRequestHead::try_from(CBOR::try_from_data(&stream.request_head).unwrap()) + .unwrap(); + assert_eq!(request_head.method, AddOne::METHOD); + assert_eq!(request_head.content_length, Some(request_body.len() as u64)); + + let mut response = stream + .respond_to + .accept( + CBOR::from(RpcResponseHead::new(Some(response_body.len() as u64))) + .to_cbor_data(), + ) + .unwrap(); + + let request_body = read_body(stream.request).await.unwrap(); + let request = AddOne::try_from(CBOR::try_from_data(&request_body).unwrap()).unwrap(); + + response + .write_all(&CBOR::from(AddOneResponse(request.0 + 1)).to_cbor_data()) + .await + .unwrap(); + response.finish().await.unwrap(); + }); + + let response = rpc_a + .request(AddOne(41), StreamConfig::default()) + .await + .unwrap(); + assert_eq!(response, AddOneResponse(42)); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_request_response_reject_propagates() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + let rpc_a = RpcHandle::new(handle_a.clone()); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let request_head = + RpcRequestHead::try_from(CBOR::try_from_data(&stream.request_head).unwrap()) + .unwrap(); + assert_eq!(request_head.method, AddOne::METHOD); + stream.respond_to.reject(RejectCode::UnknownRoute).unwrap(); + }); + + let err = rpc_a + .request(AddOne(1), StreamConfig::default()) + .await + .unwrap_err(); + assert!(matches!( + err, + crate::rpc::RpcError::Transport(QlError::StreamRejected { + code: RejectCode::UnknownRoute + }) + )); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_request_response_content_length_mismatch_errors() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + let rpc_a = RpcHandle::new(handle_a.clone()); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut response = stream + .respond_to + .accept(CBOR::from(RpcResponseHead::new(Some(99))).to_cbor_data()) + .unwrap(); + let _request_body = read_body(stream.request).await.unwrap(); + response + .write_all(&CBOR::from(AddOneResponse(2)).to_cbor_data()) + .await + .unwrap(); + response.finish().await.unwrap(); + }); + + let err = rpc_a + .request(AddOne(1), StreamConfig::default()) + .await + .unwrap_err(); + assert!(matches!( + err, + crate::rpc::RpcError::ContentLengthMismatch { + expected: 99, + actual: 1, + } + )); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +async fn read_body(mut stream: crate::runtime::InboundByteStream) -> Result, QlError> { + let mut body = Vec::new(); + while let Some(chunk) = stream.next_chunk().await? { + body.extend_from_slice(&chunk); + } + Ok(body) +} diff --git a/ql2/src/tests/stream.rs b/ql2/src/tests/stream.rs new file mode 100644 index 00000000..5692f8fa --- /dev/null +++ b/ql2/src/tests/stream.rs @@ -0,0 +1,1685 @@ +use std::{sync::atomic::Ordering, time::Duration}; + +use super::*; +use crate::{ + runtime::{PendingStream, StreamConfig}, + wire::stream::{ + Direction, RejectCode, ResetCode, StreamFrame, StreamFrameCredit, StreamFrameData, + }, +}; + +#[tokio::test(flavor = "current_thread")] +async fn duplex_stream_round_trip() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + packet_ack_timeout: Duration::from_millis(40), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + assert_eq!(stream.request_head, b"req-head".to_vec()); + + let mut request = stream.request; + let mut response = stream.respond_to.accept(b"resp-head".to_vec()).unwrap(); + + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2])); + response.write_all(&[9]).await.unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![3, 4])); + response.write_all(&[8, 7]).await.unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), None); + response.finish().await.unwrap(); + }); + + let pending = handle_a + .open_stream(b"req-head".to_vec(), StreamConfig::default()) + .await + .unwrap(); + let PendingStream { + mut request, + accepted, + } = pending; + request.write_all(&[1, 2]).await.unwrap(); + let mut accepted = accepted.await.unwrap(); + assert_eq!(accepted.response_head, b"resp-head".to_vec()); + assert_eq!(accepted.response.next_chunk().await.unwrap(), Some(vec![9])); + request.write_all(&[3, 4]).await.unwrap(); + request.finish().await.unwrap(); + assert_eq!( + accepted.response.next_chunk().await.unwrap(), + Some(vec![8, 7]) + ); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn duplicate_open_is_idempotent() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_drop_first_stream_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + tokio::time::sleep(Duration::from_millis(120)).await; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + let second = tokio::time::timeout(Duration::from_millis(120), inbound_b.recv()).await; + assert!(second.is_err(), "duplicate open redispatched stream"); + response.finish().await.unwrap(); + }); + + let pending = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + let PendingStream { request, accepted } = pending; + let mut accepted = accepted.await.unwrap(); + request.finish().await.unwrap(); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn duplicate_accept_is_idempotent() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let arm_drop = Arc::new(AtomicBool::new(false)); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_drop_first_stream_when(outbound_a, handle_b.clone(), arm_drop.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + arm_drop.store(true, Ordering::Relaxed); + let response = stream.respond_to.accept(b"accepted".to_vec()).unwrap(); + tokio::time::sleep(Duration::from_millis(150)).await; + response.finish().await.unwrap(); + }); + + let pending = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + let PendingStream { request, accepted } = pending; + let mut accepted = accepted.await.unwrap(); + assert_eq!(accepted.response_head, b"accepted".to_vec()); + tokio::time::sleep(Duration::from_millis(120)).await; + request.finish().await.unwrap(); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn replayed_open_packet_is_ignored() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + packet_ack_timeout: Duration::from_millis(40), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_duplicate_first_stream_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let second = tokio::time::timeout(Duration::from_millis(80), inbound_b.recv()).await; + assert!(second.is_err(), "replayed open redispatched stream"); + let response = stream.respond_to.accept(Vec::new()).unwrap(); + response.finish().await.unwrap(); + }); + + let pending = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + let PendingStream { request, accepted } = pending; + let mut accepted = accepted.await.unwrap(); + request.finish().await.unwrap(); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_reset_can_keep_response_alive() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(40), + max_payload_bytes: 16, + initial_credit: 16, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let mut response = stream.respond_to.accept(b"err".to_vec()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2])); + request.reset(ResetCode::InvalidData).await.unwrap(); + response.write_all(b"invalid").await.unwrap(); + response.finish().await.unwrap(); + }); + + let pending = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + let PendingStream { + mut request, + accepted, + } = pending; + request.write_all(&[1, 2]).await.unwrap(); + let mut accepted = accepted.await.unwrap(); + assert_eq!(accepted.response_head, b"err".to_vec()); + assert_eq!( + accepted.response.next_chunk().await.unwrap(), + Some(b"invalid".to_vec()) + ); + let err = request.write_all(&[3, 4]).await.unwrap_err(); + assert!(matches!(err, QlError::Cancelled)); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn open_timeout_returns_error() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(120), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let pending = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + + let _stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + + let err = pending.accepted.await.unwrap_err(); + assert!(matches!(err, QlError::Timeout)); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn reject_surfaces_stream_rejected() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + stream.respond_to.reject(RejectCode::UnknownRoute).unwrap(); + }); + + let pending = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + let err = pending.accepted.await.unwrap_err(); + assert!(matches!( + err, + QlError::StreamRejected { + code: RejectCode::UnknownRoute + } + )); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_responder_rejects_unhandled() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + drop(stream.respond_to); + assert!(matches!( + request.next_chunk().await, + Ok(None) | Err(QlError::Cancelled) + )); + }); + + let PendingStream { request, accepted } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.finish().await.unwrap(); + + let err = tokio::time::timeout(Duration::from_secs(1), accepted) + .await + .unwrap() + .unwrap_err(); + assert!(matches!( + err, + QlError::StreamRejected { + code: RejectCode::Unhandled + } + )); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_larger_than_ring_buffer_streams_with_backpressure() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + pipe_size_bytes: 4, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let payload: Vec = (0..24).collect(); + let (done_tx, done_rx) = async_channel::bounded(1); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + let mut received = Vec::new(); + while let Some(chunk) = request.next_chunk().await.unwrap() { + received.extend_from_slice(&chunk); + } + done_tx.send(received).await.unwrap(); + response.finish().await.unwrap(); + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.write_all(&payload).await.unwrap(); + request.finish().await.unwrap(); + + let mut accepted = tokio::time::timeout(Duration::from_secs(1), accepted) + .await + .unwrap() + .unwrap(); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + let received = tokio::time::timeout(Duration::from_secs(1), done_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(received, payload); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn response_larger_than_ring_buffer_streams_with_backpressure() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + pipe_size_bytes: 4, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let payload: Vec = (50..74).collect(); + let expected = payload.clone(); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let mut response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), None); + response.write_all(&payload).await.unwrap(); + response.finish().await.unwrap(); + }); + + let PendingStream { request, accepted } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.finish().await.unwrap(); + + let mut accepted = tokio::time::timeout(Duration::from_secs(1), accepted) + .await + .unwrap() + .unwrap(); + let mut received = Vec::new(); + while let Some(chunk) = accepted.response.next_chunk().await.unwrap() { + received.extend_from_slice(&chunk); + } + assert_eq!(received, expected); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_pending_accept_cancels_response_side() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + pipe_size_bytes: 4, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let mut response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), None); + let err = response + .write_all(&[1, 2, 3, 4, 5, 6, 7, 8]) + .await + .unwrap_err(); + assert!(matches!(err, QlError::Cancelled)); + }); + + let PendingStream { request, accepted } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + drop(accepted); + request.finish().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_request_writer_sends_cancel() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); + let err = request.next_chunk().await.unwrap_err(); + assert!(matches!( + err, + QlError::StreamReset { + dir: Direction::Request, + code: ResetCode::Cancelled, + } + )); + response.finish().await.unwrap(); + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.write_all(&[1, 2, 3, 4]).await.unwrap(); + let mut accepted = accepted.await.unwrap(); + drop(request); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_response_writer_sends_cancel() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let mut stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(stream.request.next_chunk().await.unwrap(), None); + response.write_all(&[9, 8, 7, 6]).await.unwrap(); + drop(response); + }); + + let PendingStream { request, accepted } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.finish().await.unwrap(); + let mut accepted = accepted.await.unwrap(); + assert_eq!( + accepted.response.next_chunk().await.unwrap(), + Some(vec![9, 8, 7, 6]) + ); + let err = accepted.response.next_chunk().await.unwrap_err(); + assert!(matches!( + err, + QlError::StreamReset { + dir: Direction::Response, + code: ResetCode::Cancelled, + } + )); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_request_reader_sends_cancel() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); + drop(request); + response.finish().await.unwrap(); + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.write_all(&[1, 2, 3, 4]).await.unwrap(); + let mut accepted = accepted.await.unwrap(); + // ensure that the runtime can process the drop + tokio::time::sleep(Duration::from_millis(20)).await; + let err = request.write_all(&[5, 6, 7, 8]).await.unwrap_err(); + assert!(matches!(err, QlError::Cancelled)); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_response_reader_sends_cancel() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + packet_ack_timeout: Duration::from_millis(30), + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + pipe_size_bytes: 4, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let (go_tx, go_rx) = async_channel::bounded(1); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let mut stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(stream.request.next_chunk().await.unwrap(), None); + go_rx.recv().await.unwrap(); + let err = response + .write_all(&[1, 2, 3, 4, 5, 6, 7, 8]) + .await + .unwrap_err(); + assert!(matches!(err, QlError::Cancelled)); + }); + + let PendingStream { request, accepted } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.finish().await.unwrap(); + let accepted = accepted.await.unwrap(); + drop(accepted.response); + go_tx.send(()).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn empty_request_finishes_cleanly() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let mut response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), None); + response.write_all(b"ok").await.unwrap(); + response.finish().await.unwrap(); + }); + + let PendingStream { request, accepted } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.finish().await.unwrap(); + let mut accepted = accepted.await.unwrap(); + assert_eq!( + accepted.response.next_chunk().await.unwrap(), + Some(b"ok".to_vec()) + ); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn empty_response_finishes_cleanly() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(300), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1])); + assert_eq!(request.next_chunk().await.unwrap(), None); + response.finish().await.unwrap(); + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.write_all(&[1]).await.unwrap(); + request.finish().await.unwrap(); + let mut accepted = accepted.await.unwrap(); + assert_eq!(accepted.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn stream_survives_every_third_packet_drop() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(500), + packet_ack_timeout: Duration::from_millis(20), + stream_retry_limit: 12, + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + pipe_size_bytes: 4, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let request_payload: Vec = (0..32).collect(); + let response_payload: Vec = (100..132).collect(); + let expected_response = response_payload.clone(); + let (done_tx, done_rx) = async_channel::bounded(1); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_drop_every_nth_stream_forwarder(outbound_a, handle_b.clone(), 3); + spawn_drop_every_nth_stream_forwarder(outbound_b, handle_a.clone(), 3); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let mut response = stream.respond_to.accept(Vec::new()).unwrap(); + let mut received = Vec::new(); + while let Some(chunk) = request.next_chunk().await.unwrap() { + received.extend_from_slice(&chunk); + } + response.write_all(&response_payload).await.unwrap(); + response.finish().await.unwrap(); + done_tx.send(received).await.unwrap(); + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.write_all(&request_payload).await.unwrap(); + request.finish().await.unwrap(); + + let mut accepted = tokio::time::timeout(Duration::from_secs(3), accepted) + .await + .unwrap() + .unwrap(); + let mut received_response = Vec::new(); + while let Some(chunk) = accepted.response.next_chunk().await.unwrap() { + received_response.extend_from_slice(&chunk); + } + assert_eq!(received_response, expected_response); + + let received_request = tokio::time::timeout(Duration::from_secs(3), done_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(received_request, request_payload); + + tokio::time::timeout(Duration::from_secs(3), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn response_data_before_accept_is_protocol_error() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + stream_retry_limit: 8, + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let key_material = session_key_material(&platform_a, &platform_b); + let trace = Arc::new(Mutex::new(SessionTrace::default())); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_stream_mutating_forwarder( + outbound_a, + handle_b.clone(), + key_material.clone(), + trace.clone(), + |_header, _body| false, + ); + spawn_stream_mutating_forwarder(outbound_b, handle_a.clone(), key_material, trace, { + let mut mutated = false; + move |_header, body| { + if mutated { + return false; + } + if let Some(StreamFrame::Accept(frame)) = body.frame.take() { + mutated = true; + body.frame = Some(StreamFrame::Data(StreamFrameData { + stream_id: frame.stream_id, + dir: Direction::Response, + offset: 0, + bytes: vec![9], + })); + true + } else { + false + } + } + }); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut response = stream.respond_to.accept(Vec::new()).unwrap(); + response.write_all(&[9]).await.unwrap(); + let _ = response.finish().await; + }); + + let PendingStream { request, accepted } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.finish().await.unwrap(); + let err = tokio::time::timeout(Duration::from_secs(1), accepted) + .await + .unwrap() + .unwrap_err(); + assert!(matches!(err, QlError::StreamProtocol)); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn data_offset_gap_is_protocol_error() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + stream_retry_limit: 8, + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let key_material = session_key_material(&platform_a, &platform_b); + let trace = Arc::new(Mutex::new(SessionTrace::default())); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_stream_mutating_forwarder( + outbound_a, + handle_b.clone(), + key_material.clone(), + trace.clone(), + { + let mut mutated = false; + move |_header, body| { + if mutated { + return false; + } + if let Some(StreamFrame::Data(frame)) = body.frame.as_mut() { + mutated = true; + frame.offset = 2; + true + } else { + false + } + } + }, + ); + spawn_stream_mutating_forwarder( + outbound_b, + handle_a.clone(), + key_material, + trace, + |_header, _body| false, + ); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + let err = request.next_chunk().await.unwrap_err(); + assert!(matches!(err, QlError::StreamProtocol)); + let _ = response.finish().await; + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + let _accepted = accepted.await.unwrap(); + request.write_all(&[1, 2, 3, 4]).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn data_beyond_credit_is_protocol_error() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + stream_retry_limit: 8, + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let key_material = session_key_material(&platform_a, &platform_b); + let trace = Arc::new(Mutex::new(SessionTrace::default())); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_stream_mutating_forwarder( + outbound_a, + handle_b.clone(), + key_material.clone(), + trace.clone(), + { + let mut mutated = false; + move |_header, body| { + if mutated { + return false; + } + if let Some(StreamFrame::Data(frame)) = body.frame.as_mut() { + mutated = true; + frame.offset = 4; + true + } else { + false + } + } + }, + ); + spawn_stream_mutating_forwarder( + outbound_b, + handle_a.clone(), + key_material, + trace, + |_header, _body| false, + ); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + let err = request.next_chunk().await.unwrap_err(); + assert!(matches!(err, QlError::StreamProtocol)); + let _ = response.finish().await; + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + let _accepted = accepted.await.unwrap(); + request.write_all(&[1, 2, 3, 4]).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn credit_regression_is_protocol_error() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + stream_retry_limit: 8, + max_payload_bytes: 4, + initial_credit: 4, + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + let key_material = session_key_material(&platform_a, &platform_b); + let trace = Arc::new(Mutex::new(SessionTrace::default())); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_stream_mutating_forwarder( + outbound_a, + handle_b.clone(), + key_material.clone(), + trace.clone(), + |_header, _body| false, + ); + spawn_stream_mutating_forwarder(outbound_b, handle_a.clone(), key_material, trace, { + let mut mutated = false; + move |_header, body| { + if mutated { + return false; + } + if let Some(StreamFrame::Credit(StreamFrameCredit { + dir: Direction::Request, + recv_offset, + max_offset, + .. + })) = body.frame.as_mut() + { + mutated = true; + *recv_offset = 99; + *max_offset = 99; + true + } else { + false + } + } + }); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); + let err = request.next_chunk().await.unwrap_err(); + assert!(matches!( + err, + QlError::StreamReset { + code: ResetCode::Protocol, + dir: Direction::Request, + } + )); + let _ = response.finish().await; + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + let mut accepted = accepted.await.unwrap(); + request.write_all(&[1, 2, 3, 4]).await.unwrap(); + tokio::time::sleep(Duration::from_millis(20)).await; + let err = request.write_all(&[5, 6, 7, 8]).await.unwrap_err(); + assert!(matches!(err, QlError::Cancelled)); + assert!(matches!( + accepted.response.next_chunk().await, + Ok(None) | Err(_) + )); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn disconnect_during_active_stream_aborts_both_halves() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + default_open_timeout: Duration::from_millis(400), + packet_ack_timeout: Duration::from_millis(30), + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + let handle_b_for_disconnect = handle_b.clone(); + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let _response = stream.respond_to.accept(Vec::new()).unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); + let request_outcome = request.next_chunk().await; + assert!(matches!( + request_outcome, + Ok(None) + | Err(QlError::Cancelled) + | Err(QlError::SendFailed) + | Err(QlError::StreamReset { .. }) + | Err(QlError::StreamProtocol) + )); + handle_b_for_disconnect.unpair().unwrap(); + }); + + let PendingStream { + mut request, + accepted, + } = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + request.write_all(&[1, 2, 3, 4]).await.unwrap(); + let mut accepted = accepted.await.unwrap(); + handle_a.unpair().unwrap(); + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + tokio::time::sleep(Duration::from_millis(20)).await; + + let write_err = request.write_all(&[5, 6, 7, 8]).await.unwrap_err(); + assert!(matches!(write_err, QlError::Cancelled)); + assert!(matches!( + accepted.response.next_chunk().await, + Ok(None) | Err(_) + )); + + tokio::time::timeout(Duration::from_secs(1), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} diff --git a/ql2/src/tests/unpair.rs b/ql2/src/tests/unpair.rs new file mode 100644 index 00000000..7f6b8a79 --- /dev/null +++ b/ql2/src/tests/unpair.rs @@ -0,0 +1,137 @@ +use std::time::Duration; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn connected_unpair_removes_peer_on_both_sides() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + handle_a.connect().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Connected).await; + await_status(&status_b, peer_a.xid, PeerStage::Connected).await; + + handle_a.unpair().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + + let result_a = handle_a.open_stream(Vec::new(), Default::default()).await; + assert!(matches!(result_a, Err(QlError::NoPeerBound))); + + let result_b = handle_b.open_stream(Vec::new(), Default::default()).await; + assert!(matches!(result_b, Err(QlError::NoPeerBound))); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn unpair_works_without_session() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let (runtime_a, handle_a) = new_runtime(platform_a, config); + let (runtime_b, handle_b) = new_runtime(platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &peer_a, &peer_b); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + + handle_a.unpair().unwrap(); + + await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + + let result_a = handle_a.open_stream(Vec::new(), Default::default()).await; + assert!(matches!(result_a, Err(QlError::NoPeerBound))); + + let result_b = handle_b.open_stream(Vec::new(), Default::default()).await; + assert!(matches!(result_b, Err(QlError::NoPeerBound))); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn invalid_unpair_signature_is_ignored() { + run_local_test(async { + let config = RuntimeConfig { + engine: EngineConfig { + handshake_timeout: Duration::from_millis(200), + ..Default::default() + }, + ..Default::default() + }; + let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, status_b) = TestPlatform::new(2); + let (fake_signer, _fake_outbound, _fake_status) = TestPlatform::new(3); + let peer_a = peer_identity(&platform_a); + let peer_b = peer_identity(&platform_b); + + let forged_unpair = wire::unpair::build_unpair_record( + &fake_signer, + QlHeader { + sender: peer_a.xid, + recipient: peer_b.xid, + }, + PacketId(777), + now_secs().saturating_add(60), + ); + let forged_bytes = wire::encode_record(&forged_unpair); + + let (runtime_b, handle_b) = new_runtime(platform_b, config); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + handle_b.bind_peer(Peer { + peer: peer_a.xid, + signing_key: peer_a.signing_key.clone(), + encapsulation_key: peer_a.encapsulation_key.clone(), + }); + await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; + + handle_b.send_incoming(forged_bytes); + + tokio::time::sleep(Duration::from_millis(20)).await; + + let result = handle_b.open_stream(Vec::new(), Default::default()).await; + assert!(matches!(result, Err(QlError::MissingSession))); + }) + .await; +} diff --git a/ql2/src/wire/codec.rs b/ql2/src/wire/codec.rs new file mode 100644 index 00000000..4eb8d3ec --- /dev/null +++ b/ql2/src/wire/codec.rs @@ -0,0 +1,308 @@ +use bc_components::{ + MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey, Nonce, MLDSA, MLKEM, XID, +}; +use rkyv::{ + rancor::{Fallible, Source}, + with::{ArchiveWith, DeserializeWith, SerializeWith}, + Archive, Archived, Deserialize, Place, Resolver, Serialize, +}; + +use crate::QlError; + +macro_rules! impl_wire_wrapper { + ($marker:ident, $external:ty, $wire:ty) => { + pub(crate) struct $marker; + + impl ArchiveWith<$external> for $marker { + type Archived = Archived<$wire>; + type Resolver = Resolver<$wire>; + + fn resolve_with( + field: &$external, + resolver: Self::Resolver, + out: Place, + ) { + <$wire>::from(field).resolve(resolver, out); + } + } + + impl SerializeWith<$external, S> for $marker + where + S: Fallible + ?Sized, + $wire: Serialize, + { + fn serialize_with( + field: &$external, + serializer: &mut S, + ) -> Result { + <$wire>::from(field).serialize(serializer) + } + } + + impl DeserializeWith, $external, D> for $marker + where + D: Fallible + ?Sized, + D::Error: Source, + Archived<$wire>: Deserialize<$wire, D>, + $wire: TryInto<$external, Error = QlError>, + { + fn deserialize_with( + field: &Archived<$wire>, + deserializer: &mut D, + ) -> Result<$external, D::Error> { + field + .deserialize(deserializer)? + .try_into() + .map_err(D::Error::new) + } + } + }; +} + +#[derive( + Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, +)] +pub(crate) struct WireXid(pub(crate) [u8; XID::XID_SIZE]); + +impl From<&XID> for WireXid { + fn from(value: &XID) -> Self { + Self(*value.data()) + } +} + +impl TryFrom for XID { + type Error = QlError; + + fn try_from(value: WireXid) -> Result { + Ok(XID::from_data(value.0)) + } +} + +pub(crate) fn xid_from_archived(value: &ArchivedWireXid) -> XID { + XID::from_data(value.0) +} + +impl_wire_wrapper!(AsWireXid, XID, WireXid); + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct WireNonce(pub(crate) [u8; Nonce::NONCE_SIZE]); + +impl From<&Nonce> for WireNonce { + fn from(value: &Nonce) -> Self { + Self(*value.data()) + } +} + +impl TryFrom for Nonce { + type Error = QlError; + + fn try_from(value: WireNonce) -> Result { + Ok(Nonce::from_data(value.0)) + } +} + +pub(crate) fn nonce_from_archived(value: &ArchivedWireNonce) -> Nonce { + Nonce::from_data(value.0) +} + +impl_wire_wrapper!(AsWireNonce, Nonce, WireNonce); + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] +pub(crate) enum WireMlDsaLevel { + MlDsa44 = 2, + MlDsa65 = 3, + MlDsa87 = 5, +} + +impl TryFrom for MLDSA { + type Error = QlError; + + fn try_from(value: WireMlDsaLevel) -> Result { + Ok(match value { + WireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, + WireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, + WireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, + }) + } +} + +impl From for WireMlDsaLevel { + fn from(value: MLDSA) -> Self { + match value { + MLDSA::MLDSA44 => Self::MlDsa44, + MLDSA::MLDSA65 => Self::MlDsa65, + MLDSA::MLDSA87 => Self::MlDsa87, + } + } +} + +pub(crate) fn mldsa_level_from_archived(value: &ArchivedWireMlDsaLevel) -> MLDSA { + match value { + ArchivedWireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, + ArchivedWireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, + ArchivedWireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] +pub(crate) enum WireMlKemLevel { + MlKem512 = 1, + MlKem768 = 2, + MlKem1024 = 3, +} + +impl TryFrom for MLKEM { + type Error = QlError; + + fn try_from(value: WireMlKemLevel) -> Result { + Ok(match value { + WireMlKemLevel::MlKem512 => MLKEM::MLKEM512, + WireMlKemLevel::MlKem768 => MLKEM::MLKEM768, + WireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, + }) + } +} + +impl From for WireMlKemLevel { + fn from(value: MLKEM) -> Self { + match value { + MLKEM::MLKEM512 => Self::MlKem512, + MLKEM::MLKEM768 => Self::MlKem768, + MLKEM::MLKEM1024 => Self::MlKem1024, + } + } +} + +pub(crate) fn mlkem_level_from_archived(value: &ArchivedWireMlKemLevel) -> MLKEM { + match value { + ArchivedWireMlKemLevel::MlKem512 => MLKEM::MLKEM512, + ArchivedWireMlKemLevel::MlKem768 => MLKEM::MLKEM768, + ArchivedWireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct WireMlDsaPublicKey { + pub(crate) level: WireMlDsaLevel, + pub(crate) bytes: Vec, +} + +impl TryFrom for MLDSAPublicKey { + type Error = QlError; + + fn try_from(value: WireMlDsaPublicKey) -> Result { + MLDSAPublicKey::from_bytes(value.level.try_into()?, &value.bytes) + .map_err(|_| QlError::InvalidPayload) + } +} + +impl From<&MLDSAPublicKey> for WireMlDsaPublicKey { + fn from(value: &MLDSAPublicKey) -> Self { + Self { + level: value.level().into(), + bytes: value.as_bytes().to_vec(), + } + } +} + +impl_wire_wrapper!(AsWireMlDsaPublicKey, MLDSAPublicKey, WireMlDsaPublicKey); + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct WireMlDsaSignature { + pub(crate) level: WireMlDsaLevel, + pub(crate) bytes: Vec, +} + +impl TryFrom for MLDSASignature { + type Error = QlError; + + fn try_from(value: WireMlDsaSignature) -> Result { + MLDSASignature::from_bytes(value.level.try_into()?, &value.bytes) + .map_err(|_| QlError::InvalidPayload) + } +} + +impl From<&MLDSASignature> for WireMlDsaSignature { + fn from(value: &MLDSASignature) -> Self { + Self { + level: value.level().into(), + bytes: value.as_bytes().to_vec(), + } + } +} + +pub(crate) fn mldsa_signature_from_archived( + value: &ArchivedWireMlDsaSignature, +) -> Result { + MLDSASignature::from_bytes( + mldsa_level_from_archived(&value.level), + value.bytes.as_slice(), + ) + .map_err(|_| QlError::InvalidPayload) +} + +impl_wire_wrapper!(AsWireMlDsaSignature, MLDSASignature, WireMlDsaSignature); + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct WireMlKemPublicKey { + pub(crate) level: WireMlKemLevel, + pub(crate) bytes: Vec, +} + +impl TryFrom for MLKEMPublicKey { + type Error = QlError; + + fn try_from(value: WireMlKemPublicKey) -> Result { + MLKEMPublicKey::from_bytes(value.level.try_into()?, &value.bytes) + .map_err(|_| QlError::InvalidPayload) + } +} + +impl From<&MLKEMPublicKey> for WireMlKemPublicKey { + fn from(value: &MLKEMPublicKey) -> Self { + Self { + level: value.level().into(), + bytes: value.as_bytes().to_vec(), + } + } +} + +impl_wire_wrapper!(AsWireMlKemPublicKey, MLKEMPublicKey, WireMlKemPublicKey); + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct WireMlKemCiphertext { + pub(crate) level: WireMlKemLevel, + pub(crate) bytes: Vec, +} + +impl TryFrom for MLKEMCiphertext { + type Error = QlError; + + fn try_from(value: WireMlKemCiphertext) -> Result { + MLKEMCiphertext::from_bytes(value.level.try_into()?, &value.bytes) + .map_err(|_| QlError::InvalidPayload) + } +} + +impl From<&MLKEMCiphertext> for WireMlKemCiphertext { + fn from(value: &MLKEMCiphertext) -> Self { + Self { + level: value.level().into(), + bytes: value.as_bytes().to_vec(), + } + } +} + +pub(crate) fn mlkem_ciphertext_from_archived( + value: &ArchivedWireMlKemCiphertext, +) -> Result { + MLKEMCiphertext::from_bytes( + mlkem_level_from_archived(&value.level), + value.bytes.as_slice(), + ) + .map_err(|_| QlError::InvalidPayload) +} + +impl_wire_wrapper!(AsWireMlKemCiphertext, MLKEMCiphertext, WireMlKemCiphertext); diff --git a/ql2/src/wire/encrypted_message.rs b/ql2/src/wire/encrypted_message.rs new file mode 100644 index 00000000..f79e7a8d --- /dev/null +++ b/ql2/src/wire/encrypted_message.rs @@ -0,0 +1,63 @@ +use bc_components::SymmetricKey; +use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit}; +use rkyv::{seal::Seal, vec::ArchivedVec, Archive, Deserialize, Serialize}; + +use crate::QlError; + +pub const NONCE_SIZE: usize = 12; +pub const AUTH_SIZE: usize = 16; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct EncryptedMessage { + ciphertext: Vec, + nonce: [u8; NONCE_SIZE], + auth: [u8; AUTH_SIZE], +} + +impl EncryptedMessage { + pub fn encrypt( + key: &SymmetricKey, + mut plaintext: Vec, + aad: &[u8], + nonce: [u8; NONCE_SIZE], + ) -> Self { + let cipher = ChaCha20Poly1305::new(key.data().into()); + let auth = cipher + .encrypt_in_place_detached((&nonce).into(), aad, &mut plaintext) + .expect("chacha20poly1305 encryption should succeed"); + Self { + ciphertext: plaintext, + nonce, + auth: auth.into(), + } + } + + pub fn decrypt(&self, key: &SymmetricKey, aad: &[u8]) -> Result, QlError> { + let cipher = ChaCha20Poly1305::new(key.data().into()); + let mut plaintext = self.ciphertext.clone(); + cipher + .decrypt_in_place_detached( + (&self.nonce).into(), + aad, + &mut plaintext, + (&self.auth).into(), + ) + .map_err(|_| QlError::InvalidPayload)?; + Ok(plaintext) + } +} + +impl ArchivedEncryptedMessage { + pub fn decrypt(&mut self, key: &SymmetricKey, aad: &[u8]) -> Result<&[u8], QlError> { + let cipher = ChaCha20Poly1305::new(key.data().into()); + let nonce = self.nonce; + let auth = self.auth; + let ciphertext = ArchivedVec::as_slice_seal(Seal::new(&mut self.ciphertext)); + // SAFETY: decryption only overwrites initialized u8 bytes in place. + let ciphertext = unsafe { ciphertext.unseal_unchecked() }; + cipher + .decrypt_in_place_detached((&nonce).into(), aad, ciphertext, (&auth).into()) + .map_err(|_| QlError::InvalidPayload)?; + Ok(ciphertext) + } +} diff --git a/ql2/src/wire/handshake/crypto.rs b/ql2/src/wire/handshake/crypto.rs new file mode 100644 index 00000000..4f45ef9a --- /dev/null +++ b/ql2/src/wire/handshake/crypto.rs @@ -0,0 +1,188 @@ +use bc_components::{Digest, MLDSAPublicKey, MLKEMPublicKey, Nonce, SymmetricKey, XID}; +use rkyv::{Archive, Serialize}; + +use super::{ + verify_transcript_signature, ArchivedConfirm, ArchivedHello, ArchivedHelloReply, Confirm, + Hello, HelloReply, +}; +use crate::{ + platform::QlCrypto, + wire::{ + encode_value, mldsa_signature_from_archived, mlkem_ciphertext_from_archived, + nonce_from_archived, AsWireMlKemCiphertext, AsWireNonce, AsWireXid, + }, + QlError, +}; + +#[derive(Archive, Serialize)] +struct HandshakeTranscript { + #[rkyv(with = AsWireXid)] + initiator: XID, + #[rkyv(with = AsWireXid)] + responder: XID, + #[rkyv(with = AsWireNonce)] + initiator_nonce: Nonce, + #[rkyv(with = AsWireNonce)] + responder_nonce: Nonce, + #[rkyv(with = AsWireMlKemCiphertext)] + initiator_kem_ct: bc_components::MLKEMCiphertext, + #[rkyv(with = AsWireMlKemCiphertext)] + responder_kem_ct: bc_components::MLKEMCiphertext, +} + +#[derive(Archive, Serialize)] +struct SessionKeyMaterial { + initiator_secret: Vec, + responder_secret: Vec, + transcript: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ResponderSecrets { + pub initiator_secret: SymmetricKey, + pub responder_secret: SymmetricKey, +} + +pub fn build_hello( + platform: &impl QlCrypto, + _sender: XID, + _recipient: XID, + recipient_encapsulation_key: &MLKEMPublicKey, +) -> Result<(Hello, SymmetricKey), QlError> { + let nonce = next_nonce(platform); + let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); + Ok((Hello { nonce, kem_ct }, session_key)) +} + +pub fn respond_hello( + platform: &impl QlCrypto, + initiator: XID, + responder: XID, + initiator_encapsulation_key: &MLKEMPublicKey, + hello: &ArchivedHello, +) -> Result<(HelloReply, ResponderSecrets), QlError> { + let initiator_nonce = nonce_from_archived(&hello.nonce); + let initiator_kem_ct = mlkem_ciphertext_from_archived(&hello.kem_ct)?; + let initiator_secret = platform + .encapsulation_private_key() + .decapsulate_shared_secret(&initiator_kem_ct) + .map_err(|_| QlError::InvalidPayload)?; + let nonce = next_nonce(platform); + let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(); + let transcript = handshake_transcript( + initiator, + responder, + &initiator_nonce, + &nonce, + &initiator_kem_ct, + &kem_ct, + ); + let signature = platform.signing_private_key().sign(&transcript); + let reply = HelloReply { + nonce, + kem_ct, + signature, + }; + Ok(( + reply, + ResponderSecrets { + initiator_secret, + responder_secret, + }, + )) +} + +pub fn build_confirm( + platform: &impl QlCrypto, + initiator: XID, + responder: XID, + responder_signing_key: &MLDSAPublicKey, + hello: &Hello, + reply: &ArchivedHelloReply, + initiator_secret: &SymmetricKey, +) -> Result<(Confirm, SymmetricKey), QlError> { + let reply_nonce = nonce_from_archived(&reply.nonce); + let reply_kem_ct = mlkem_ciphertext_from_archived(&reply.kem_ct)?; + let reply_signature = mldsa_signature_from_archived(&reply.signature)?; + let transcript = handshake_transcript( + initiator, + responder, + &hello.nonce, + &reply_nonce, + &hello.kem_ct, + &reply_kem_ct, + ); + verify_transcript_signature(responder_signing_key, &reply_signature, &transcript)?; + let responder_secret = platform + .encapsulation_private_key() + .decapsulate_shared_secret(&reply_kem_ct) + .map_err(|_| QlError::InvalidPayload)?; + let signature = platform.signing_private_key().sign(&transcript); + let confirm = Confirm { signature }; + let session_key = derive_session_key(initiator_secret, &responder_secret, &transcript); + Ok((confirm, session_key)) +} + +pub fn finalize_confirm( + initiator: XID, + responder: XID, + initiator_signing_key: &MLDSAPublicKey, + hello: &Hello, + reply: &super::HelloReply, + confirm: &ArchivedConfirm, + secrets: &ResponderSecrets, +) -> Result { + let confirm_signature = mldsa_signature_from_archived(&confirm.signature)?; + let transcript = handshake_transcript( + initiator, + responder, + &hello.nonce, + &reply.nonce, + &hello.kem_ct, + &reply.kem_ct, + ); + verify_transcript_signature(initiator_signing_key, &confirm_signature, &transcript)?; + Ok(derive_session_key( + &secrets.initiator_secret, + &secrets.responder_secret, + &transcript, + )) +} + +fn handshake_transcript( + initiator: XID, + responder: XID, + initiator_nonce: &Nonce, + responder_nonce: &Nonce, + initiator_kem_ct: &bc_components::MLKEMCiphertext, + responder_kem_ct: &bc_components::MLKEMCiphertext, +) -> Vec { + encode_value(&HandshakeTranscript { + initiator, + responder, + initiator_nonce: initiator_nonce.clone(), + responder_nonce: responder_nonce.clone(), + initiator_kem_ct: initiator_kem_ct.clone(), + responder_kem_ct: responder_kem_ct.clone(), + }) +} + +fn next_nonce(platform: &impl QlCrypto) -> Nonce { + let mut data = [0u8; Nonce::NONCE_SIZE]; + platform.fill_random_bytes(&mut data); + Nonce::from_data(data) +} + +fn derive_session_key( + initiator_secret: &SymmetricKey, + responder_secret: &SymmetricKey, + transcript: &[u8], +) -> SymmetricKey { + let payload = encode_value(&SessionKeyMaterial { + initiator_secret: initiator_secret.as_bytes().to_vec(), + responder_secret: responder_secret.as_bytes().to_vec(), + transcript: transcript.to_vec(), + }); + let digest = Digest::from_image(payload); + SymmetricKey::from_data(*digest.data()) +} diff --git a/ql2/src/wire/handshake/mod.rs b/ql2/src/wire/handshake/mod.rs new file mode 100644 index 00000000..949e528f --- /dev/null +++ b/ql2/src/wire/handshake/mod.rs @@ -0,0 +1,50 @@ +use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, Nonce}; +use rkyv::{Archive, Deserialize, Serialize}; + +use super::{AsWireMlDsaSignature, AsWireMlKemCiphertext, AsWireNonce}; +use crate::QlError; + +mod crypto; +pub use crypto::*; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum HandshakeRecord { + Hello(Hello), + HelloReply(HelloReply), + Confirm(Confirm), +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Hello { + #[rkyv(with = AsWireNonce)] + pub nonce: Nonce, + #[rkyv(with = AsWireMlKemCiphertext)] + pub kem_ct: MLKEMCiphertext, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct HelloReply { + #[rkyv(with = AsWireNonce)] + pub nonce: Nonce, + #[rkyv(with = AsWireMlKemCiphertext)] + pub kem_ct: MLKEMCiphertext, + #[rkyv(with = AsWireMlDsaSignature)] + pub signature: MLDSASignature, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Confirm { + #[rkyv(with = AsWireMlDsaSignature)] + pub signature: MLDSASignature, +} + +pub fn verify_transcript_signature( + signing_key: &MLDSAPublicKey, + signature: &MLDSASignature, + transcript: &[u8], +) -> Result<(), QlError> { + match signing_key.verify(signature, transcript) { + Ok(true) => Ok(()), + _ => Err(QlError::InvalidSignature), + } +} diff --git a/ql2/src/wire/heartbeat/crypto.rs b/ql2/src/wire/heartbeat/crypto.rs new file mode 100644 index 00000000..ccc92ae9 --- /dev/null +++ b/ql2/src/wire/heartbeat/crypto.rs @@ -0,0 +1,39 @@ +use bc_components::SymmetricKey; + +use super::HeartbeatBody; +use crate::{ + wire::{ + access_value, deserialize_value, encode_value, + encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage, NONCE_SIZE}, + ensure_not_expired, QlHeader, QlPayload, QlRecord, + }, + QlError, +}; + +pub fn encrypt_heartbeat( + header: QlHeader, + session_key: &SymmetricKey, + body: HeartbeatBody, + nonce: [u8; NONCE_SIZE], +) -> QlRecord { + let aad = header.aad(); + let body_bytes = encode_value(&body); + let encrypted = EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce); + QlRecord { + header, + payload: QlPayload::Heartbeat(encrypted), + } +} + +pub(crate) fn decrypt_heartbeat( + header: &QlHeader, + encrypted: &mut ArchivedEncryptedMessage, + session_key: &SymmetricKey, +) -> Result { + let aad = header.aad(); + let plaintext = encrypted.decrypt(session_key, &aad)?; + let body = access_value::(plaintext)?; + let body = deserialize_value(body)?; + ensure_not_expired(body.valid_until)?; + Ok(body) +} diff --git a/ql2/src/wire/heartbeat/mod.rs b/ql2/src/wire/heartbeat/mod.rs new file mode 100644 index 00000000..f5e75950 --- /dev/null +++ b/ql2/src/wire/heartbeat/mod.rs @@ -0,0 +1,12 @@ +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::PacketId; + +mod crypto; +pub use crypto::*; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct HeartbeatBody { + pub packet_id: PacketId, + pub valid_until: u64, +} diff --git a/ql2/src/wire/mod.rs b/ql2/src/wire/mod.rs new file mode 100644 index 00000000..358dbe19 --- /dev/null +++ b/ql2/src/wire/mod.rs @@ -0,0 +1,128 @@ +use bc_components::XID; +use rkyv::{ + api::{ + high::{to_bytes_in, HighSerializer, HighValidator}, + low::{self, LowDeserializer}, + }, + bytecheck::CheckBytes, + ser::allocator::ArenaHandle, + Archive, Deserialize, Portable, Serialize, +}; + +pub mod encrypted_message; +pub mod handshake; +pub mod heartbeat; +pub mod pair; +pub mod stream; +pub mod unpair; + +mod codec; + +pub(crate) use codec::*; + +use self::{ + encrypted_message::EncryptedMessage, handshake::HandshakeRecord, pair::PairRequestRecord, + unpair::UnpairRecord, +}; +use crate::QlError; + +pub(crate) type WireArchiveError = rkyv::rancor::Error; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct QlRecord { + pub header: QlHeader, + pub payload: QlPayload, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct QlHeader { + #[rkyv(with = AsWireXid)] + pub sender: XID, + #[rkyv(with = AsWireXid)] + pub recipient: XID, +} + +impl QlHeader { + pub fn aad(&self) -> Vec { + encode_value(self) + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum QlPayload { + Handshake(HandshakeRecord), + Pair(PairRequestRecord), + Unpair(UnpairRecord), + Heartbeat(EncryptedMessage), + Stream(EncryptedMessage), +} + +pub fn encode_record(record: &QlRecord) -> Vec { + encode_value(record) +} + +pub fn access_record(bytes: &[u8]) -> Result<&ArchivedQlRecord, QlError> { + access_value(bytes) +} + +pub fn decode_record(bytes: &[u8]) -> Result { + deserialize_value(access_record(bytes)?) +} + +pub(crate) fn encode_value( + value: &impl for<'a> Serialize, ArenaHandle<'a>, WireArchiveError>>, +) -> Vec { + to_bytes_in::<_, WireArchiveError>(value, Vec::new()) + .expect("wire serialization should not fail") +} + +pub(crate) fn access_value(bytes: &[u8]) -> Result<&T, QlError> +where + T: Portable + for<'a> CheckBytes>, +{ + rkyv::access::(bytes).map_err(|_| QlError::InvalidPayload) +} + +pub(crate) fn deserialize_value( + value: &impl rkyv::Deserialize>, +) -> Result { + low::deserialize::(value).map_err(|_| QlError::InvalidPayload) +} + +pub(crate) fn ensure_not_expired(valid_until: u64) -> Result<(), QlError> { + if now_secs() > valid_until { + Err(QlError::Timeout) + } else { + Ok(()) + } +} + +pub(crate) fn now_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} + +#[test] +fn ql_record_round_trip() { + let record = QlRecord { + header: QlHeader { + sender: XID::from_data([1; XID::XID_SIZE]), + recipient: XID::from_data([2; XID::XID_SIZE]), + }, + payload: QlPayload::Heartbeat(encrypted_message::EncryptedMessage::encrypt( + &bc_components::SymmetricKey::from_data( + [7; bc_components::SymmetricKey::SYMMETRIC_KEY_SIZE], + ), + vec![3u8, 4, 5], + b"roundtrip", + [8; encrypted_message::NONCE_SIZE], + )), + }; + + let bytes = encode_record(&record); + let decoded = decode_record(&bytes).unwrap(); + + assert_eq!(decoded, record); +} diff --git a/ql2/src/wire/pair/crypto.rs b/ql2/src/wire/pair/crypto.rs new file mode 100644 index 00000000..6b087bc3 --- /dev/null +++ b/ql2/src/wire/pair/crypto.rs @@ -0,0 +1,147 @@ +use std::time::Duration; + +use bc_components::{ + MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, SigningPublicKey, SymmetricKey, XID, +}; +use rkyv::{Archive, Serialize}; + +use super::{PairRequestBody, PairRequestRecord}; +use crate::{ + platform::QlCrypto, + wire::{ + access_value, deserialize_value, encode_value, + encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage, NONCE_SIZE}, + ensure_not_expired, mlkem_ciphertext_from_archived, now_secs, AsWireMlDsaPublicKey, + AsWireMlKemCiphertext, AsWireMlKemPublicKey, QlHeader, QlPayload, QlRecord, + }, + PacketId, QlError, +}; + +#[derive(Archive, Serialize)] +struct PairingAad { + header: QlHeader, + #[rkyv(with = AsWireMlKemCiphertext)] + kem_ct: MLKEMCiphertext, +} + +#[derive(Archive, Serialize)] +struct PairingProofData { + aad: Vec, + packet_id: PacketId, + valid_until: u64, + #[rkyv(with = AsWireMlDsaPublicKey)] + signing_pub_key: MLDSAPublicKey, + #[rkyv(with = AsWireMlKemPublicKey)] + encapsulation_pub_key: MLKEMPublicKey, +} + +pub fn build_pair_request( + platform: &impl QlCrypto, + recipient: XID, + recipient_encapsulation_key: &MLKEMPublicKey, + packet_id: PacketId, + valid_for: Duration, +) -> Result { + let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); + let header = QlHeader { + sender: platform.xid(), + recipient, + }; + let valid_until = now_secs().saturating_add(valid_for.as_secs()); + let signing_pub_key = platform.signing_public_key().clone(); + let sender_encapsulation_key = platform.encapsulation_public_key().clone(); + let proof_data = pairing_proof_data( + &header, + &kem_ct, + packet_id, + valid_until, + &signing_pub_key, + &sender_encapsulation_key, + ); + let proof = platform.signing_private_key().sign(&proof_data); + let body = PairRequestBody { + packet_id, + valid_until, + signing_pub_key, + encapsulation_pub_key: sender_encapsulation_key, + proof, + }; + let body_bytes = encode_value(&body); + let aad = pairing_aad(&header, &kem_ct); + let mut nonce = [0u8; NONCE_SIZE]; + platform.fill_random_bytes(&mut nonce); + let encrypted = EncryptedMessage::encrypt(&session_key, body_bytes, &aad, nonce); + Ok(QlRecord { + header, + payload: QlPayload::Pair(PairRequestRecord { kem_ct, encrypted }), + }) +} + +pub fn decrypt_pair_request( + platform: &impl QlCrypto, + header: &QlHeader, + request: &mut super::ArchivedPairRequestRecord, +) -> Result { + let kem_ct = mlkem_ciphertext_from_archived(&request.kem_ct)?; + let aad = pairing_aad(header, &kem_ct); + let session_key = platform + .encapsulation_private_key() + .decapsulate_shared_secret(&kem_ct) + .map_err(|_| QlError::InvalidPayload)?; + let decrypted = decrypt_body(&session_key, &mut request.encrypted, &aad)?; + ensure_not_expired(decrypted.valid_until)?; + if XID::new(SigningPublicKey::MLDSA(decrypted.signing_pub_key.clone())) != header.sender { + return Err(QlError::InvalidPayload); + } + let proof_data = pairing_proof_data( + header, + &kem_ct, + decrypted.packet_id, + decrypted.valid_until, + &decrypted.signing_pub_key, + &decrypted.encapsulation_pub_key, + ); + if decrypted + .signing_pub_key + .verify(&decrypted.proof, &proof_data) + .unwrap_or(false) + { + Ok(decrypted) + } else { + Err(QlError::InvalidSignature) + } +} + +fn pairing_proof_data( + header: &QlHeader, + kem_ct: &MLKEMCiphertext, + packet_id: PacketId, + valid_until: u64, + signing_pub_key: &MLDSAPublicKey, + encapsulation_pub_key: &MLKEMPublicKey, +) -> Vec { + encode_value(&PairingProofData { + aad: pairing_aad(header, kem_ct), + packet_id, + valid_until, + signing_pub_key: signing_pub_key.clone(), + encapsulation_pub_key: encapsulation_pub_key.clone(), + }) +} + +fn decrypt_body( + key: &SymmetricKey, + encrypted: &mut ArchivedEncryptedMessage, + aad: &[u8], +) -> Result { + let plaintext = encrypted.decrypt(key, aad)?; + let body = access_value::(plaintext)?; + deserialize_value(body) +} + +pub(crate) fn pairing_aad(header: &QlHeader, kem_ct: &MLKEMCiphertext) -> Vec { + encode_value(&PairingAad { + header: header.clone(), + kem_ct: kem_ct.clone(), + }) +} diff --git a/ql2/src/wire/pair/mod.rs b/ql2/src/wire/pair/mod.rs new file mode 100644 index 00000000..958462cb --- /dev/null +++ b/ql2/src/wire/pair/mod.rs @@ -0,0 +1,30 @@ +use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey}; +use rkyv::{Archive, Deserialize, Serialize}; + +use super::{ + encrypted_message::EncryptedMessage, AsWireMlDsaPublicKey, AsWireMlDsaSignature, + AsWireMlKemCiphertext, AsWireMlKemPublicKey, +}; +use crate::PacketId; + +mod crypto; +pub use crypto::*; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct PairRequestRecord { + #[rkyv(with = AsWireMlKemCiphertext)] + pub kem_ct: MLKEMCiphertext, + pub encrypted: EncryptedMessage, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct PairRequestBody { + pub packet_id: PacketId, + pub valid_until: u64, + #[rkyv(with = AsWireMlDsaPublicKey)] + pub signing_pub_key: MLDSAPublicKey, + #[rkyv(with = AsWireMlKemPublicKey)] + pub encapsulation_pub_key: MLKEMPublicKey, + #[rkyv(with = AsWireMlDsaSignature)] + pub proof: MLDSASignature, +} diff --git a/ql2/src/wire/stream/crypto.rs b/ql2/src/wire/stream/crypto.rs new file mode 100644 index 00000000..620db25f --- /dev/null +++ b/ql2/src/wire/stream/crypto.rs @@ -0,0 +1,39 @@ +use bc_components::SymmetricKey; + +use super::StreamBody; +use crate::{ + wire::{ + access_value, deserialize_value, encode_value, + encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage, NONCE_SIZE}, + ensure_not_expired, QlHeader, QlPayload, QlRecord, + }, + QlError, +}; + +pub fn encrypt_stream( + header: QlHeader, + session_key: &SymmetricKey, + body: StreamBody, + nonce: [u8; NONCE_SIZE], +) -> QlRecord { + let aad = header.aad(); + let body_bytes = encode_value(&body); + let encrypted = EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce); + QlRecord { + header, + payload: QlPayload::Stream(encrypted), + } +} + +pub(crate) fn decrypt_stream( + header: &QlHeader, + encrypted: &mut ArchivedEncryptedMessage, + session_key: &SymmetricKey, +) -> Result { + let aad = header.aad(); + let plaintext = encrypted.decrypt(session_key, &aad)?; + let body = access_value::(plaintext)?; + let body = deserialize_value(body)?; + ensure_not_expired(body.valid_until)?; + Ok(body) +} diff --git a/ql2/src/wire/stream/mod.rs b/ql2/src/wire/stream/mod.rs new file mode 100644 index 00000000..8db6fbb5 --- /dev/null +++ b/ql2/src/wire/stream/mod.rs @@ -0,0 +1,247 @@ +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::{PacketId, StreamId}; + +mod crypto; +pub use crypto::*; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct StreamBody { + pub packet_id: PacketId, + pub valid_until: u64, + pub packet_ack: Option, + pub frame: Option, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct PacketAck { + pub packet_id: PacketId, +} + +impl From<&ArchivedPacketAck> for PacketAck { + fn from(value: &ArchivedPacketAck) -> Self { + Self { + packet_id: (&value.packet_id).into(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub enum StreamFrame { + Open(StreamFrameOpen), + Accept(StreamFrameAccept), + Reject(StreamFrameReject), + Data(StreamFrameData), + Credit(StreamFrameCredit), + Finish(StreamFrameFinish), + Reset(StreamFrameReset), +} + +impl StreamFrame { + pub fn stream_id(&self) -> StreamId { + match self { + StreamFrame::Open(StreamFrameOpen { stream_id, .. }) + | StreamFrame::Accept(StreamFrameAccept { stream_id, .. }) + | StreamFrame::Reject(StreamFrameReject { stream_id, .. }) + | StreamFrame::Data(StreamFrameData { stream_id, .. }) + | StreamFrame::Credit(StreamFrameCredit { stream_id, .. }) + | StreamFrame::Finish(StreamFrameFinish { stream_id, .. }) + | StreamFrame::Reset(StreamFrameReset { stream_id, .. }) => *stream_id, + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct StreamFrameOpen { + pub stream_id: StreamId, + pub request_head: Vec, + pub response_max_offset: u64, +} + +impl From<&ArchivedStreamFrameOpen> for StreamFrameOpen { + fn from(value: &ArchivedStreamFrameOpen) -> Self { + Self { + stream_id: (&value.stream_id).into(), + request_head: value.request_head.as_slice().to_vec(), + response_max_offset: value.response_max_offset.to_native(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct StreamFrameAccept { + pub stream_id: StreamId, + pub response_head: Vec, + pub request_max_offset: u64, +} + +impl From<&ArchivedStreamFrameAccept> for StreamFrameAccept { + fn from(value: &ArchivedStreamFrameAccept) -> Self { + Self { + stream_id: (&value.stream_id).into(), + response_head: value.response_head.as_slice().to_vec(), + request_max_offset: value.request_max_offset.to_native(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamFrameReject { + pub stream_id: StreamId, + pub code: RejectCode, +} + +impl From<&ArchivedStreamFrameReject> for StreamFrameReject { + fn from(value: &ArchivedStreamFrameReject) -> Self { + Self { + stream_id: (&value.stream_id).into(), + code: (&value.code).into(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamFrameCredit { + pub stream_id: StreamId, + pub dir: Direction, + pub recv_offset: u64, + pub max_offset: u64, +} + +impl From<&ArchivedStreamFrameCredit> for StreamFrameCredit { + fn from(value: &ArchivedStreamFrameCredit) -> Self { + Self { + stream_id: (&value.stream_id).into(), + dir: (&value.dir).into(), + recv_offset: value.recv_offset.to_native(), + max_offset: value.max_offset.to_native(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct StreamFrameData { + pub stream_id: StreamId, + pub dir: Direction, + pub offset: u64, + pub bytes: Vec, +} + +impl From<&ArchivedStreamFrameData> for StreamFrameData { + fn from(value: &ArchivedStreamFrameData) -> Self { + Self { + stream_id: (&value.stream_id).into(), + dir: (&value.dir).into(), + offset: value.offset.to_native(), + bytes: value.bytes.as_slice().to_vec(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamFrameFinish { + pub stream_id: StreamId, + pub dir: Direction, +} + +impl From<&ArchivedStreamFrameFinish> for StreamFrameFinish { + fn from(value: &ArchivedStreamFrameFinish) -> Self { + Self { + stream_id: (&value.stream_id).into(), + dir: (&value.dir).into(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamFrameReset { + pub stream_id: StreamId, + pub dir: ResetTarget, + pub code: ResetCode, +} + +impl From<&ArchivedStreamFrameReset> for StreamFrameReset { + fn from(value: &ArchivedStreamFrameReset) -> Self { + Self { + stream_id: (&value.stream_id).into(), + dir: (&value.dir).into(), + code: (&value.code).into(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum Direction { + Request = 1, + Response = 2, +} + +impl From<&ArchivedDirection> for Direction { + fn from(value: &ArchivedDirection) -> Self { + match value { + ArchivedDirection::Request => Self::Request, + ArchivedDirection::Response => Self::Response, + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum ResetTarget { + Request = 1, + Response = 2, + Both = 3, +} + +impl From<&ArchivedResetTarget> for ResetTarget { + fn from(value: &ArchivedResetTarget) -> Self { + match value { + ArchivedResetTarget::Request => Self::Request, + ArchivedResetTarget::Response => Self::Response, + ArchivedResetTarget::Both => Self::Both, + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum RejectCode { + Unknown = 0, + UnknownRoute = 1, + InvalidHead = 2, + Busy = 3, + Unhandled = 4, +} + +impl From<&ArchivedRejectCode> for RejectCode { + fn from(value: &ArchivedRejectCode) -> Self { + match value { + ArchivedRejectCode::Unknown => Self::Unknown, + ArchivedRejectCode::UnknownRoute => Self::UnknownRoute, + ArchivedRejectCode::InvalidHead => Self::InvalidHead, + ArchivedRejectCode::Busy => Self::Busy, + ArchivedRejectCode::Unhandled => Self::Unhandled, + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum ResetCode { + Cancelled = 0, + InvalidData = 1, + Protocol = 2, + Timeout = 3, +} + +impl From<&ArchivedResetCode> for ResetCode { + fn from(value: &ArchivedResetCode) -> Self { + match value { + ArchivedResetCode::Cancelled => Self::Cancelled, + ArchivedResetCode::InvalidData => Self::InvalidData, + ArchivedResetCode::Protocol => Self::Protocol, + ArchivedResetCode::Timeout => Self::Timeout, + } + } +} diff --git a/ql2/src/wire/unpair/crypto.rs b/ql2/src/wire/unpair/crypto.rs new file mode 100644 index 00000000..4339accd --- /dev/null +++ b/ql2/src/wire/unpair/crypto.rs @@ -0,0 +1,65 @@ +use bc_components::MLDSAPublicKey; +use rkyv::{Archive, Serialize}; + +use super::UnpairRecord; +use crate::{ + platform::QlCrypto, + wire::{encode_value, mldsa_signature_from_archived, now_secs, QlHeader, QlPayload, QlRecord}, + PacketId, QlError, +}; + +#[derive(Archive, Serialize)] +struct UnpairProofData { + domain: Vec, + header: QlHeader, + packet_id: PacketId, + valid_until: u64, +} + +pub fn build_unpair_record( + platform: &impl QlCrypto, + header: QlHeader, + packet_id: PacketId, + valid_until: u64, +) -> QlRecord { + let signature = + platform + .signing_private_key() + .sign(unpair_proof_data(&header, packet_id, valid_until)); + QlRecord { + header, + payload: QlPayload::Unpair(UnpairRecord { + packet_id, + valid_until, + signature, + }), + } +} + +pub fn verify_unpair_record( + header: &QlHeader, + record: &super::ArchivedUnpairRecord, + signing_key: &MLDSAPublicKey, +) -> Result<(), QlError> { + let packet_id = (&record.packet_id).into(); + let valid_until = record.valid_until.to_native(); + let signature = mldsa_signature_from_archived(&record.signature)?; + if now_secs() > valid_until { + return Err(QlError::InvalidPayload); + } + let proof_data = unpair_proof_data(header, packet_id, valid_until); + if signing_key.verify(&signature, &proof_data).unwrap_or(false) { + Ok(()) + } else { + Err(QlError::InvalidSignature) + } +} + +fn unpair_proof_data(header: &QlHeader, packet_id: PacketId, valid_until: u64) -> Vec { + encode_value(&UnpairProofData { + domain: b"ql-unpair-v1".to_vec(), + header: header.clone(), + packet_id, + valid_until, + }) +} diff --git a/ql2/src/wire/unpair/mod.rs b/ql2/src/wire/unpair/mod.rs new file mode 100644 index 00000000..e17bcf16 --- /dev/null +++ b/ql2/src/wire/unpair/mod.rs @@ -0,0 +1,16 @@ +use bc_components::MLDSASignature; +use rkyv::{Archive, Deserialize, Serialize}; + +use super::AsWireMlDsaSignature; +use crate::PacketId; + +mod crypto; +pub use crypto::*; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct UnpairRecord { + pub packet_id: PacketId, + pub valid_until: u64, + #[rkyv(with = AsWireMlDsaSignature)] + pub signature: MLDSASignature, +} From d8b1d1ef603bacbcfaa5ef12547e4e9d9b1b46ea Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:43 -0400 Subject: [PATCH 006/304] ql: redesign transport around stop-and-wait then SACK sliding windows --- ql2/src/engine/mod.rs | 1495 ++++++++++++++++-------------- ql2/src/engine/replay_cache.rs | 26 +- ql2/src/engine/ring.rs | 394 ++++++++ ql2/src/engine/state.rs | 181 ++-- ql2/src/engine/stream.rs | 425 ++++++--- ql2/src/engine/tests.rs | 1360 +++++++++++++++++++++++++++ ql2/src/lib.rs | 14 +- ql2/src/platform.rs | 42 +- ql2/src/runtime/command.rs | 11 +- ql2/src/runtime/driver.rs | 246 +++-- ql2/src/runtime/handle.rs | 94 +- ql2/src/runtime/mod.rs | 29 +- ql2/src/runtime/pipe.rs | 772 --------------- ql2/src/wire/codec.rs | 84 +- ql2/src/wire/handshake/crypto.rs | 176 +++- ql2/src/wire/handshake/mod.rs | 13 +- ql2/src/wire/heartbeat/crypto.rs | 2 +- ql2/src/wire/heartbeat/mod.rs | 5 +- ql2/src/wire/mod.rs | 27 +- ql2/src/wire/pair/crypto.rs | 55 +- ql2/src/wire/pair/mod.rs | 6 +- ql2/src/wire/seq.rs | 97 ++ ql2/src/wire/stream/crypto.rs | 6 +- ql2/src/wire/stream/mod.rs | 132 +-- ql2/src/wire/unpair/crypto.rs | 47 +- ql2/src/wire/unpair/mod.rs | 6 +- 26 files changed, 3539 insertions(+), 2206 deletions(-) create mode 100644 ql2/src/engine/ring.rs create mode 100644 ql2/src/engine/tests.rs delete mode 100644 ql2/src/runtime/pipe.rs create mode 100644 ql2/src/wire/seq.rs diff --git a/ql2/src/engine/mod.rs b/ql2/src/engine/mod.rs index b167478b..ac1029da 100644 --- a/ql2/src/engine/mod.rs +++ b/ql2/src/engine/mod.rs @@ -1,7 +1,11 @@ pub mod replay_cache; +mod ring; mod state; mod stream; +#[cfg(test)] +mod tests; + use std::{ cmp::Reverse, collections::HashMap, @@ -16,29 +20,23 @@ pub use state::{ OutputFn, PeerRecord, PeerSession, Token, TrackedWrite, }; -use self::{ - replay_cache::{ReplayKey, ReplayNamespace}, - state::*, - stream::*, -}; +use self::{replay_cache::ReplayKey, state::*, stream::*}; use crate::{ - platform::QlCrypto, - runtime::StreamConfig, + platform::{QlCrypto, QlIdentity}, wire::{ self, encrypted_message::{ArchivedEncryptedMessage, NONCE_SIZE}, handshake::{self, HandshakeRecord, Hello}, heartbeat::{self, HeartbeatBody}, stream::{ - decrypt_stream, encrypt_stream, Direction, PacketAck, RejectCode, ResetCode, - ResetTarget, StreamBody, StreamFrame, StreamFrameAccept, StreamFrameCredit, - StreamFrameData, StreamFrameFinish, StreamFrameOpen, StreamFrameReject, - StreamFrameReset, + decrypt_stream, encrypt_stream, BodyChunk, Direction, RejectCode, ResetCode, + ResetTarget, StreamAck, StreamAckBody, StreamBody, StreamFrame, StreamFrameAccept, + StreamFrameData, StreamFrameOpen, StreamFrameReject, StreamFrameReset, StreamMessage, }, unpair::{self}, - QlHeader, QlPayload, QlRecord, + ControlMeta, QlHeader, QlPayload, QlRecord, StreamSeq, }, - PacketId, Peer, QlError, StreamId, + Peer, QlError, StreamId, }; #[derive(Debug, Clone, Copy)] @@ -47,15 +45,19 @@ pub struct KeepAliveConfig { pub timeout: Duration, } +#[derive(Debug, Clone, Copy, Default)] +pub struct StreamConfig { + pub open_timeout: Option, +} + #[derive(Debug, Clone, Copy)] pub struct EngineConfig { pub handshake_timeout: Duration, pub default_open_timeout: Duration, pub packet_expiration: Duration, - pub packet_ack_timeout: Duration, + pub stream_ack_delay: Duration, + pub stream_ack_timeout: Duration, pub stream_retry_limit: u8, - pub max_payload_bytes: usize, - pub initial_credit: u64, pub keep_alive: Option, } @@ -65,27 +67,19 @@ impl Default for EngineConfig { handshake_timeout: Duration::from_secs(5), default_open_timeout: Duration::from_secs(5), packet_expiration: Duration::from_secs(30), - packet_ack_timeout: Duration::from_millis(150), + stream_ack_delay: Duration::from_millis(5), + stream_ack_timeout: Duration::from_millis(150), stream_retry_limit: 5, - max_payload_bytes: 1024, - initial_credit: 1024, keep_alive: None, } } } -impl EngineConfig { - pub(crate) fn normalized(mut self) -> Self { - self.max_payload_bytes = self.max_payload_bytes.max(1); - self.initial_credit = self.initial_credit.max(self.max_payload_bytes as u64); - self - } -} - impl Engine { - pub fn new(config: EngineConfig, peer: Option) -> Self { + pub fn new(config: EngineConfig, identity: QlIdentity, peer: Option) -> Self { Self { - config: config.normalized(), + config: config, + identity, state: EngineState::new(peer), streams: HashMap::new(), } @@ -102,35 +96,29 @@ impl Engine { EngineInput::BindPeer(peer) => self.handle_bind_peer(peer, emit), EngineInput::Pair => self.handle_pair_local(now, crypto), EngineInput::Connect => self.handle_connect(now, crypto, emit), - EngineInput::Unpair => self.handle_unpair_local(now, crypto, emit), + EngineInput::Unpair => self.handle_unpair_local(now, emit), EngineInput::OpenStream { open_id, request_head, + request_prefix, config, - } => self.handle_open_stream(now, open_id, request_head, config, emit), + } => self.handle_open_stream(now, open_id, request_head, request_prefix, config, emit), EngineInput::AcceptStream { stream_id, response_head, - } => self.handle_accept_stream(now, stream_id, response_head), + response_prefix, + } => self.handle_accept_stream(now, stream_id, response_head, response_prefix), EngineInput::RejectStream { stream_id, code } => { self.handle_reject_stream(now, stream_id, code) } EngineInput::OutboundData { stream_id, dir, - offset, bytes, - } => self.handle_outbound_data(stream_id, dir, offset, bytes), - EngineInput::OutboundFinished { - stream_id, - dir, - final_offset, - } => self.handle_outbound_finished(stream_id, dir, final_offset), - EngineInput::InboundConsumed { - stream_id, - dir, - amount, - } => self.handle_inbound_consumed(now, stream_id, dir, amount), + } => self.handle_outbound_data(stream_id, dir, bytes), + EngineInput::OutboundFinished { stream_id, dir } => { + self.handle_outbound_finished(stream_id, dir) + } EngineInput::ResetOutbound { stream_id, dir, @@ -170,6 +158,19 @@ impl Engine { } } + fn next_control_meta(&self, valid_for: Duration) -> ControlMeta { + ControlMeta { + packet_id: self.state.next_packet_id(), + valid_until: wire::now_secs() + valid_for.as_secs(), + } + } + + fn is_replayed_control(&mut self, peer: XID, meta: ControlMeta) -> bool { + self.state + .replay_cache + .check_and_store_valid_until(ReplayKey::new(peer, meta.packet_id), meta.valid_until) + } + fn bind_peer_record(&mut self, peer: Peer, emit: &mut impl OutputFn) { self.reset_runtime(QlError::Cancelled, emit); self.state.peer = Some(PeerRecord::new( @@ -191,9 +192,6 @@ impl Engine { self.state.outbound.clear(); self.state.timeouts.clear(); self.state.write_in_flight = None; - if let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) { - self.state.replay_cache.clear_peer(peer); - } } fn handle_bind_peer(&mut self, peer: Peer, emit: &mut impl OutputFn) { @@ -210,12 +208,13 @@ impl Engine { let Some(peer) = self.state.peer.as_ref() else { return; }; + let meta = self.next_control_meta(self.config.packet_expiration); let Ok(record) = wire::pair::build_pair_request( + &self.identity, crypto, peer.peer, &peer.encapsulation_key, - self.state.next_packet_id(), - self.config.packet_expiration, + meta, ) else { return; }; @@ -232,6 +231,7 @@ impl Engine { return; }; let peer = peer_record.peer; + let meta = self.next_control_meta(self.config.handshake_timeout); let (hello, session_key) = match &peer_record.session { PeerSession::Connected { .. } | PeerSession::Initiator { .. } @@ -240,10 +240,11 @@ impl Engine { } PeerSession::Disconnected => { match handshake::build_hello( + &self.identity, crypto, - crypto.xid(), peer, &peer_record.encapsulation_key, + meta, ) { Ok(result) => result, Err(_) => return, @@ -266,7 +267,7 @@ impl Engine { let record = QlRecord { header: QlHeader { - sender: crypto.xid(), + sender: self.identity.xid, recipient: peer, }, payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), @@ -274,23 +275,18 @@ impl Engine { self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); } - fn handle_unpair_local( - &mut self, - now: Instant, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { + fn handle_unpair_local(&mut self, now: Instant, emit: &mut impl OutputFn) { let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { return; }; + let meta = self.next_control_meta(self.config.packet_expiration); let record = unpair::build_unpair_record( - crypto, + &self.identity, QlHeader { - sender: crypto.xid(), + sender: self.identity.xid, recipient: peer, }, - self.state.next_packet_id(), - wire::now_secs().saturating_add(self.config.packet_expiration.as_secs()), + meta, ); self.unpair_peer(emit); let token = self.state.next_token(); @@ -306,6 +302,7 @@ impl Engine { now: Instant, open_id: OpenId, request_head: Vec, + request_prefix: Option, config: StreamConfig, emit: &mut impl OutputFn, ) { @@ -326,32 +323,29 @@ impl Engine { return; } - let stream_id = self.state.next_stream_id(); + let stream_namespace = StreamNamespace::for_local(self.identity.xid, entry.peer); + let stream_id = self.state.next_stream_id(stream_namespace); let open_timeout = config .open_timeout .unwrap_or(self.config.default_open_timeout); let token = self.state.next_token(); + let request_prefix_fin = request_prefix.as_ref().is_some_and(|chunk| chunk.fin); let frame = StreamFrameOpen { stream_id, - request_head: request_head.clone(), - response_max_offset: self.config.initial_credit, + request_head, + request_prefix, }; let stream = StreamState::Initiator(InitiatorStream { meta: StreamMeta { - key: StreamKey { stream_id }, - request_head, + stream_id, last_activity: now, }, control: StreamControl { - pending: PendingFrames { - setup: Some(SetupFrame::Open(frame)), - credit: None, - reset: None, - }, - awaiting: None, + pending: std::collections::VecDeque::from([StreamFrame::Open(frame)]), + ..Default::default() }, - request: OutboundState::new(Direction::Request, self.config.initial_credit, true), - response: InboundState::new(self.config.initial_credit), + request: OutboundState::from_prefix(Direction::Request, request_prefix_fin), + response: InboundState::new(), accept: InitiatorAccept::Opening(OpenWaiter { open_id: Some(open_id), open_timeout_token: token, @@ -365,25 +359,30 @@ impl Engine { emit(EngineOutput::OpenStarted { open_id, stream_id }); } - fn handle_accept_stream(&mut self, now: Instant, stream_id: StreamId, response_head: Vec) { + fn handle_accept_stream( + &mut self, + now: Instant, + stream_id: StreamId, + response_head: Vec, + response_prefix: Option, + ) { let Some(StreamState::Responder(stream)) = self.streams.get_mut(&stream_id) else { return; }; - let ResponderResponse::Pending { initial_credit } = stream.response else { + let ResponderResponse::Pending = stream.response else { return; }; + let response_prefix_fin = response_prefix.as_ref().is_some_and(|chunk| chunk.fin); stream .control .pending - .set_setup(SetupFrame::Accept(StreamFrameAccept { + .push_back(StreamFrame::Accept(StreamFrameAccept { stream_id, response_head, - request_max_offset: self.config.initial_credit, + response_prefix, })); - stream.request.max_offset = self.config.initial_credit; stream.response = ResponderResponse::Accepted { - initial_credit, - body: OutboundState::new(Direction::Response, initial_credit, false), + body: OutboundState::from_prefix(Direction::Response, response_prefix_fin), }; stream.meta.last_activity = now; } @@ -392,86 +391,48 @@ impl Engine { let Some(StreamState::Responder(stream)) = self.streams.get_mut(&stream_id) else { return; }; - let ResponderResponse::Pending { initial_credit } = stream.response else { + let ResponderResponse::Pending = stream.response else { return; }; stream .control .pending - .set_setup(SetupFrame::Reject(StreamFrameReject { stream_id, code })); - stream.response = ResponderResponse::Rejecting { initial_credit }; + .push_back(StreamFrame::Reject(StreamFrameReject { stream_id, code })); + stream.response = ResponderResponse::Rejecting; stream.meta.last_activity = now; } - fn handle_outbound_data( - &mut self, - stream_id: StreamId, - dir: Direction, - offset: u64, - bytes: Vec, - ) { + fn handle_outbound_data(&mut self, stream_id: StreamId, dir: Direction, bytes: Vec) { if bytes.is_empty() { return; } - let (streams, state) = (&mut self.streams, &mut self.state); - let Some(stream) = streams.get_mut(&stream_id) else { - return; - }; - let Some(outbound) = stream.outbound_mut(dir) else { - return; - }; - let Some(pull) = outbound.pending_pull.take() else { - return; - }; - if pull.offset != offset { - outbound.pending_pull = Some(pull); - return; - } - if bytes.len() > pull.max_len { - outbound.pending_pull = Some(pull); - return; - } - outbound.sent_offset = outbound - .sent_offset - .max(offset.saturating_add(bytes.len() as u64)); - let key = stream.key(); - let control = stream.control_mut(); - state.enqueue_data_frame(&self.config, key, control, dir, offset, bytes, 0); - } - - fn handle_outbound_finished(&mut self, stream_id: StreamId, dir: Direction, final_offset: u64) { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; let Some(outbound) = stream.outbound_mut(dir) else { return; }; - if final_offset < outbound.sent_offset { + if !outbound.take_pending_pull() { return; } - outbound.pending_pull = None; - outbound.final_offset = Some(final_offset); + let chunk = BodyChunk { bytes, fin: false }; + stream + .control_mut() + .queue_frame_back(StreamFrame::Data(StreamFrameData { + stream_id, + dir, + chunk, + })); } - fn handle_inbound_consumed( - &mut self, - now: Instant, - stream_id: StreamId, - dir: Direction, - amount: u64, - ) { + fn handle_outbound_finished(&mut self, stream_id: StreamId, dir: Direction) { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - let Some(inbound) = stream.inbound_mut(dir) else { + let Some(outbound) = stream.outbound_mut(dir) else { return; }; - if inbound.closed { - return; - } - inbound.max_offset = inbound.max_offset.saturating_add(amount); - Self::queue_credit(stream, dir); - *stream.last_activity_mut() = now; + outbound.finish(); } fn handle_reset_outbound( @@ -487,15 +448,15 @@ impl Engine { let Some(outbound) = stream.outbound_mut(dir) else { return; }; - if outbound.closed { + if outbound.is_closed() { return; } - outbound.closed = true; - outbound.pending_pull = None; - stream - .control_mut() - .pending - .set_reset(reset_target_for_dir(dir), code); + outbound.close(); + stream.control_mut().queue_frame_front(reset_frame( + stream_id, + reset_target_for_dir(dir), + code, + )); *stream.last_activity_mut() = now; } @@ -516,10 +477,11 @@ impl Engine { return; } inbound.closed = true; - stream - .control_mut() - .pending - .set_reset(reset_target_for_dir(dir), code); + stream.control_mut().queue_frame_front(reset_frame( + stream_id, + reset_target_for_dir(dir), + code, + )); *stream.last_activity_mut() = now; } @@ -554,9 +516,9 @@ impl Engine { return; }; let record = unsafe { record.unseal_unchecked() }; - let sender = wire::xid_from_archived(&record.header.sender); - let recipient = wire::xid_from_archived(&record.header.recipient); - if recipient != crypto.xid() { + let sender: XID = (&record.header.sender).into(); + let recipient: XID = (&record.header.recipient).into(); + if recipient != self.identity.xid { return; } if !matches!(&record.payload, wire::ArchivedQlPayload::Pair(_)) { @@ -602,10 +564,10 @@ impl Engine { self.handle_hello(now, peer, hello, crypto, emit) } wire::handshake::ArchivedHandshakeRecord::HelloReply(reply) => { - self.handle_hello_reply(now, peer, reply, crypto, emit) + self.handle_hello_reply(now, peer, reply, emit) } wire::handshake::ArchivedHandshakeRecord::Confirm(confirm) => { - self.handle_confirm(now, peer, confirm, crypto, emit) + self.handle_confirm(now, peer, confirm, emit) } } } @@ -618,11 +580,14 @@ impl Engine { crypto: &impl QlCrypto, emit: &mut impl OutputFn, ) { - let payload = match wire::pair::decrypt_pair_request(crypto, header, request) { + let payload = match wire::pair::decrypt_pair_request(&self.identity, header, request) { Ok(payload) => payload, Err(_) => return, }; let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); + if self.is_replayed_control(peer, payload.meta) { + return; + } if let Some(existing) = self.state.peer.as_ref() { if existing.peer != peer || existing.signing_key != payload.signing_pub_key @@ -658,14 +623,8 @@ impl Engine { return; } } - let packet_id: PacketId = (&record.packet_id).into(); - let valid_until = record.valid_until.to_native(); - let replay_key = ReplayKey::new(peer, ReplayNamespace::Peer, packet_id); - if self - .state - .replay_cache - .check_and_store_valid_until(replay_key, valid_until) - { + let meta: ControlMeta = (&record.meta).into(); + if self.is_replayed_control(peer, meta) { return; } self.unpair_peer(emit); @@ -679,22 +638,26 @@ impl Engine { crypto: &impl QlCrypto, emit: &mut impl OutputFn, ) { - let should_reply = { + let (body, should_reply) = { let Some(peer_record) = self.state.peer.as_ref() else { return; }; let PeerSession::Connected { session_key, keepalive, + .. } = &peer_record.session else { return; }; - if heartbeat::decrypt_heartbeat(header, encrypted, session_key).is_err() { + let Ok(body) = heartbeat::decrypt_heartbeat(header, encrypted, session_key) else { return; - } - !keepalive.pending + }; + (body, !keepalive.pending) }; + if self.is_replayed_control(header.sender, body.meta) { + return; + } self.record_activity(now); if should_reply { self.send_heartbeat_message(now, crypto); @@ -705,7 +668,7 @@ impl Engine { fn handle_stream( &mut self, now: Instant, - peer: XID, + _peer: XID, header: &QlHeader, encrypted: &mut ArchivedEncryptedMessage, emit: &mut impl OutputFn, @@ -723,36 +686,160 @@ impl Engine { } }; - if let Some(ack) = body.packet_ack { - self.process_packet_ack(ack.packet_id, emit); + let message = match body { + StreamBody::Ack(StreamAckBody { stream_id, ack, .. }) => { + self.process_stream_ack(stream_id, ack, emit); + self.record_activity(now); + if self.streams.contains_key(&stream_id) { + self.record_stream_activity(stream_id, now); + self.maybe_reap_stream(stream_id, emit); + } + return; + } + StreamBody::Message(message) => message, + }; + + let stream_id = message.frame.stream_id(); + if let Some(ack) = message.ack { + self.process_stream_ack(stream_id, ack, emit); } - let Some(frame) = body.frame else { - return; + if !self.streams.contains_key(&stream_id) { + let Some(peer_record) = self.state.peer.as_ref() else { + return; + }; + let local_namespace = StreamNamespace::for_local(self.identity.xid, peer_record.peer); + if !local_namespace.remote().matches(stream_id) { + return; + } + let token = self.state.next_token(); + self.streams.insert( + stream_id, + StreamState::Provisional(ProvisionalStream { + meta: StreamMeta { + stream_id, + last_activity: now, + }, + control: StreamControl::default(), + timeout_token: token, + }), + ); + self.state.timeouts.push(Reverse(TimeoutEntry { + at: now + self.config.default_open_timeout, + kind: TimeoutKind::StreamProvisional { stream_id, token }, + })); + } + + let buffer_result = { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + *stream.last_activity_mut() = now; + stream + .control_mut() + .buffer_incoming(message.tx_seq, message.frame) }; - let replay_key = ReplayKey::new(peer, ReplayNamespace::Transfer, body.packet_id); - if self - .state - .replay_cache - .check_and_store_valid_until(replay_key, body.valid_until) - { - return; + match buffer_result { + BufferIncomingResult::OutOfWindow => { + if self + .streams + .get(&stream_id) + .is_some_and(StreamState::is_provisional) + { + self.streams.remove(&stream_id); + self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); + } else if let Some(stream) = self.streams.get_mut(&stream_id) { + Self::queue_protocol_reset(stream, emit); + *stream.last_activity_mut() = now; + } + return; + } + BufferIncomingResult::Duplicate | BufferIncomingResult::AlreadyBuffered => { + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.control_mut().note_ack(true); + } + self.schedule_stream_ack(stream_id, now); + self.record_activity(now); + self.record_stream_activity(stream_id, now); + return; + } + BufferIncomingResult::Buffered { out_of_order } => { + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.control_mut().note_ack(out_of_order); + } + } } - self.record_activity(now); - self.record_stream_activity(stream_id_from_frame(&frame), now); - self.send_packet_ack(body.packet_id); + self.record_stream_activity(stream_id, now); + self.drain_committed_stream_frames(now, stream_id, emit); + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.control_mut().maybe_force_ack_for_progress(); + } + self.schedule_stream_ack(stream_id, now); + } - match frame { - StreamFrame::Open(frame) => self.handle_stream_open(now, frame, emit), - StreamFrame::Accept(frame) => self.handle_stream_accept_from_peer(now, frame, emit), - StreamFrame::Reject(frame) => self.handle_stream_reject_from_peer(frame, emit), - StreamFrame::Data(frame) => self.handle_stream_data(now, frame, emit), - StreamFrame::Credit(frame) => self.handle_stream_credit(now, frame, emit), - StreamFrame::Finish(frame) => self.handle_stream_finish(now, frame, emit), - StreamFrame::Reset(frame) => self.handle_stream_reset(now, frame, emit), + fn schedule_stream_ack(&mut self, stream_id: StreamId, now: Instant) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let control = stream.control_mut(); + if !control.ack_dirty { + return; + } + if control.ack_immediate || self.config.stream_ack_delay.is_zero() { + control.ack_delay_token = None; + return; } + if control.ack_delay_token.is_some() { + return; + } + let token = self.state.next_token(); + control.ack_delay_token = Some(token); + self.state.timeouts.push(Reverse(TimeoutEntry { + at: now + self.config.stream_ack_delay, + kind: TimeoutKind::StreamAckDelay { stream_id, token }, + })); + } + + fn drain_committed_stream_frames( + &mut self, + now: Instant, + stream_id: StreamId, + emit: &mut impl OutputFn, + ) { + loop { + let next = { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.control_mut().pop_next_committable() + }; + let Some((_tx_seq, frame)) = next else { + break; + }; + if self + .streams + .get(&stream_id) + .is_some_and(StreamState::is_provisional) + && !matches!(frame, StreamFrame::Open(_)) + { + self.streams.remove(&stream_id); + self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); + return; + } + match frame { + StreamFrame::Open(frame) => self.handle_stream_open(now, frame, emit), + StreamFrame::Accept(frame) => self.handle_stream_accept_from_peer(now, frame, emit), + StreamFrame::Reject(frame) => self.handle_stream_reject_from_peer(frame, emit), + StreamFrame::Data(frame) => self.handle_stream_data(now, frame, emit), + StreamFrame::Reset(frame) => self.handle_stream_reset(now, frame, emit), + } + if !self.streams.contains_key(&stream_id) { + return; + } + } + self.maybe_reap_stream(stream_id, emit); } fn handle_stream_open( @@ -764,32 +851,40 @@ impl Engine { let StreamFrameOpen { stream_id, request_head, - response_max_offset, + request_prefix, } = frame; - if let Some(stream) = self.streams.get(&stream_id) { - if self.stream_matches_open(stream, &request_head, response_max_offset) { + let control = match self.streams.remove(&stream_id) { + Some(StreamState::Provisional(stream)) => stream.control, + Some(mut stream) => { + Self::queue_protocol_reset(&mut stream, emit); + self.streams.insert(stream_id, stream); return; } - self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); - return; - } + None => StreamControl::default(), + }; - let stream = StreamState::Responder(ResponderStream { + let mut stream = StreamState::Responder(ResponderStream { meta: StreamMeta { - key: StreamKey { stream_id }, - request_head: request_head.clone(), + stream_id, last_activity: now, }, - control: StreamControl::default(), - request: InboundState::new(0), - response: ResponderResponse::Pending { - initial_credit: response_max_offset, - }, + control, + request: InboundState::new(), + response: ResponderResponse::Pending, }); + if let Some(chunk) = request_prefix.as_ref() { + let Some(inbound) = stream.inbound_mut(Direction::Request) else { + return; + }; + if chunk.fin { + inbound.closed = true; + } + } self.streams.insert(stream_id, stream); emit(EngineOutput::InboundStreamOpened { stream_id, request_head, + request_prefix, }); } @@ -802,9 +897,10 @@ impl Engine { let StreamFrameAccept { stream_id, response_head, - request_max_offset, + response_prefix, } = frame; let mut protocol = false; + let mut response_prefix_output = None; { let Some(stream) = self.streams.get_mut(&stream_id) else { return; @@ -812,59 +908,49 @@ impl Engine { match stream { StreamState::Initiator(stream) => match &mut stream.accept { InitiatorAccept::Opening(waiter) => { - if matches!( - stream - .control - .awaiting - .as_ref() - .map(|awaiting| &awaiting.frame), - Some(AwaitingFrame::Control(StreamFrame::Open(_))) - ) { - stream.control.awaiting = None; - } - stream.request.remote_max_offset = request_max_offset; - stream.request.data_enabled = true; if let Some(open_id) = waiter.open_id.take() { emit(EngineOutput::OpenAccepted { open_id, stream_id, response_head: response_head.clone(), + response_prefix: response_prefix.clone(), }); } else { stream.response.closed = true; - stream - .control - .pending - .set_reset(ResetTarget::Response, ResetCode::Cancelled); + stream.control.queue_frame_front(reset_frame( + stream_id, + ResetTarget::Response, + ResetCode::Cancelled, + )); } stream.accept = InitiatorAccept::Open { response_head }; stream.meta.last_activity = now; + response_prefix_output = response_prefix.clone(); } InitiatorAccept::WaitingAccept(waiter) => { - stream.request.remote_max_offset = request_max_offset; - stream.request.data_enabled = true; if let Some(open_id) = waiter.open_id.take() { emit(EngineOutput::OpenAccepted { open_id, stream_id, response_head: response_head.clone(), + response_prefix: response_prefix.clone(), }); } else { stream.response.closed = true; - stream - .control - .pending - .set_reset(ResetTarget::Response, ResetCode::Cancelled); + stream.control.queue_frame_front(reset_frame( + stream_id, + ResetTarget::Response, + ResetCode::Cancelled, + )); } stream.accept = InitiatorAccept::Open { response_head }; stream.meta.last_activity = now; + response_prefix_output = response_prefix.clone(); } InitiatorAccept::Open { response_head: stored, } => { - if *stored != response_head - || stream.request.remote_max_offset != request_max_offset - { + if *stored != response_head { protocol = true; } } @@ -875,6 +961,21 @@ impl Engine { if protocol { self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); + return; + } + + if let Some(chunk) = response_prefix_output.as_ref() { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let Some(inbound) = stream.inbound_mut(Direction::Response) else { + Self::queue_protocol_reset(stream, emit); + return; + }; + if chunk.fin && !inbound.closed { + inbound.closed = true; + self.maybe_reap_stream(stream_id, emit); + } } } @@ -909,7 +1010,7 @@ impl Engine { dir: Direction::Response, error: QlError::StreamRejected { code }, }); - stream.request.closed = true; + stream.request.close(); stream.response.closed = true; remove_after = true; } @@ -936,13 +1037,11 @@ impl Engine { let StreamFrameData { stream_id, dir, - offset, - bytes, + chunk, } = frame; let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - Self::note_setup_seen_from_remote(stream); if dir == Direction::Response && matches!( stream, @@ -962,88 +1061,20 @@ impl Engine { }; if inbound.closed { Self::queue_protocol_reset(stream, emit); - } else if offset < inbound.next_offset { - Self::queue_credit(stream, dir); } else { - let end = offset.saturating_add(bytes.len() as u64); - if offset != inbound.next_offset || end > inbound.max_offset { - Self::queue_protocol_reset(stream, emit); - } else { - inbound.next_offset = end; + if !chunk.bytes.is_empty() { emit(EngineOutput::InboundData { stream_id, dir, - bytes, + bytes: chunk.bytes, }); - Self::queue_credit(stream, dir); } - } - *stream.last_activity_mut() = now; - } - - fn handle_stream_credit( - &mut self, - now: Instant, - frame: StreamFrameCredit, - emit: &mut impl OutputFn, - ) { - let StreamFrameCredit { - stream_id, - dir, - recv_offset, - max_offset, - } = frame; - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - Self::note_setup_seen_from_remote(stream); - let Some(outbound) = stream.outbound_mut(dir) else { - Self::queue_protocol_reset(stream, emit); - return; - }; - let released_offset = outbound.released_offset; - let sent_offset = outbound.sent_offset; - if recv_offset < released_offset || recv_offset > sent_offset || max_offset < recv_offset { - Self::queue_protocol_reset(stream, emit); - } else { - outbound.released_offset = recv_offset; - outbound.remote_max_offset = outbound.remote_max_offset.max(max_offset); - emit(EngineOutput::ReleaseOutboundThrough { - stream_id, - dir, - recv_offset, - }); - if matches!( - stream.control().awaiting.as_ref().map(|awaiting| &awaiting.frame), - Some(AwaitingFrame::Data { offset, len, .. }) - if recv_offset >= offset.saturating_add(*len as u64) - ) { - stream.control_mut().awaiting = None; + if chunk.fin && !inbound.closed { + inbound.closed = true; + emit(EngineOutput::InboundFinished { stream_id, dir }); } } *stream.last_activity_mut() = now; - } - - fn handle_stream_finish( - &mut self, - now: Instant, - frame: StreamFrameFinish, - emit: &mut impl OutputFn, - ) { - let StreamFrameFinish { stream_id, dir } = frame; - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - Self::note_setup_seen_from_remote(stream); - let Some(inbound) = stream.inbound_mut(dir) else { - Self::queue_protocol_reset(stream, emit); - return; - }; - if !inbound.closed { - inbound.closed = true; - emit(EngineOutput::InboundFinished { stream_id, dir }); - } - *stream.last_activity_mut() = now; self.maybe_reap_stream(stream_id, emit); } @@ -1055,98 +1086,120 @@ impl Engine { ) { let StreamFrameReset { stream_id, - dir, + target, code, } = frame; let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - Self::note_setup_seen_from_remote(stream); - Self::apply_remote_reset(stream, dir, code, emit); + Self::apply_remote_reset(stream, target, code, emit); *stream.last_activity_mut() = now; self.maybe_reap_stream(stream_id, emit); } - fn process_packet_ack(&mut self, packet_id: PacketId, emit: &mut impl OutputFn) { - let key = self.streams.iter().find_map(|(key, stream)| { - stream - .control() - .awaiting - .as_ref() - .is_some_and(|awaiting| awaiting.packet_id == packet_id) - .then_some(*key) - }); - let Some(key) = key else { - return; - }; - let Some(stream) = self.streams.get_mut(&key) else { + fn process_stream_ack( + &mut self, + stream_id: StreamId, + ack: StreamAck, + emit: &mut impl OutputFn, + ) { + let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - let Some(awaiting) = stream.control_mut().awaiting.take() else { + let acked: Vec<_> = stream + .control() + .in_flight + .iter() + .map(|(tx_seq, _)| tx_seq) + .filter(|tx_seq| StreamControl::ack_covers(ack, *tx_seq)) + .collect(); + if acked.is_empty() { return; - }; + } + let mut acked_frames = Vec::with_capacity(acked.len()); + for tx_seq in acked { + if let Some(in_flight) = stream.control_mut().remove_in_flight(tx_seq) { + acked_frames.push(in_flight.frame); + } + } + let _ = stream; - let mut reap = false; - match awaiting.frame { - AwaitingFrame::Control(StreamFrame::Open(_)) => { - if let StreamState::Initiator(stream) = stream { - if let InitiatorAccept::Opening(waiter) = &stream.accept { - stream.accept = InitiatorAccept::WaitingAccept(OpenWaiter { - open_id: waiter.open_id, - open_timeout_token: waiter.open_timeout_token, - }); + for frame in acked_frames { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + match frame { + StreamFrame::Open(StreamFrameOpen { request_prefix, .. }) => { + if let StreamState::Initiator(stream) = stream { + if let InitiatorAccept::Opening(waiter) = &stream.accept { + stream.accept = InitiatorAccept::WaitingAccept(OpenWaiter { + open_id: waiter.open_id, + open_timeout_token: waiter.open_timeout_token, + }); + } + if request_prefix.as_ref().is_some_and(|chunk| chunk.fin) { + stream.request.close(); + emit(EngineOutput::OutboundClosed { + stream_id, + dir: Direction::Request, + }); + } } } - } - AwaitingFrame::Control(StreamFrame::Accept(_)) => { - if let StreamState::Responder(stream) = stream { - if let ResponderResponse::Accepted { body, .. } = &mut stream.response { - body.data_enabled = true; + StreamFrame::Accept(StreamFrameAccept { + response_prefix, .. + }) => { + if let StreamState::Responder(stream) = stream { + if response_prefix.as_ref().is_some_and(|chunk| chunk.fin) { + if let ResponderResponse::Accepted { body } = &mut stream.response { + body.close(); + emit(EngineOutput::OutboundClosed { + stream_id, + dir: Direction::Response, + }); + } + } } } - } - AwaitingFrame::Control(StreamFrame::Reject(_)) => { - reap = true; - } - AwaitingFrame::Control(StreamFrame::Finish(StreamFrameFinish { dir, .. })) => { - if let Some(outbound) = stream.outbound_mut(dir) { - outbound.closed = true; - emit(EngineOutput::OutboundClosed { - stream_id: key, - dir, - }); + StreamFrame::Reject(_) => {} + StreamFrame::Data(StreamFrameData { + dir, + chunk: BodyChunk { fin: true, .. }, + .. + }) => { + if let Some(outbound) = stream.outbound_mut(dir) { + outbound.close(); + emit(EngineOutput::OutboundClosed { stream_id, dir }); + } } - } - AwaitingFrame::Control(StreamFrame::Reset(StreamFrameReset { dir, code, .. })) => { - for outbound_dir in [Direction::Request, Direction::Response] { - let affects_outbound = matches!( - (dir, outbound_dir), - (ResetTarget::Request, Direction::Request) - | (ResetTarget::Response, Direction::Response) - | (ResetTarget::Both, _) - ); - if affects_outbound { - if let Some(outbound) = stream.outbound_mut(outbound_dir) { - outbound.closed = true; - emit(EngineOutput::OutboundFailed { - stream_id: key, - dir: outbound_dir, - error: QlError::StreamReset { + StreamFrame::Reset(StreamFrameReset { target, code, .. }) => { + for outbound_dir in [Direction::Request, Direction::Response] { + let affects_outbound = matches!( + (target, outbound_dir), + (ResetTarget::Request, Direction::Request) + | (ResetTarget::Response, Direction::Response) + | (ResetTarget::Both, _) + ); + if affects_outbound { + if let Some(outbound) = stream.outbound_mut(outbound_dir) { + outbound.close(); + emit(EngineOutput::OutboundFailed { + stream_id, dir: outbound_dir, - code, - }, - }); + error: QlError::StreamReset { + dir: outbound_dir, + code, + }, + }); + } } } } + StreamFrame::Data(_) => {} } - AwaitingFrame::Control(StreamFrame::Data(_) | StreamFrame::Credit(_)) => {} - AwaitingFrame::Data { .. } => {} } - if reap { - self.maybe_reap_stream(key, emit); - } + self.maybe_reap_stream(stream_id, emit); } fn drive_streams(&mut self, now: Instant, emit: &mut impl OutputFn) { @@ -1166,124 +1219,188 @@ impl Engine { ) { match stream { StreamState::Initiator(stream) => { - let action = Self::plan_drive_outbound( + Self::drive_stream_outbound( config, - stream.meta.key, + state, + stream.meta.stream_id, &mut stream.control, Some(&mut stream.request), emit, ); - if let Some(frame) = action { - state.enqueue_control_frame( - config, - stream.meta.key, - &mut stream.control, - frame, - 0, - ); - } } StreamState::Responder(stream) => { - let key = stream.meta.key; + let stream_id = stream.meta.stream_id; match &mut stream.response { ResponderResponse::Accepted { body, .. } => { - let action = Self::plan_drive_outbound( + Self::drive_stream_outbound( config, - key, + state, + stream_id, &mut stream.control, Some(body), emit, ); - if let Some(frame) = action { - state.enqueue_control_frame(config, key, &mut stream.control, frame, 0); - } } _ => { - let action = - Self::plan_drive_outbound(config, key, &mut stream.control, None, emit); - if let Some(frame) = action { - state.enqueue_control_frame(config, key, &mut stream.control, frame, 0); - } + Self::drive_stream_outbound( + config, + state, + stream_id, + &mut stream.control, + None, + emit, + ); } } } + StreamState::Provisional(stream) => Self::drive_stream_outbound( + config, + state, + stream.meta.stream_id, + &mut stream.control, + None, + emit, + ), } } - fn plan_drive_outbound( + fn drive_stream_outbound( config: &EngineConfig, - key: StreamKey, + state: &mut EngineState, + stream_id: StreamId, control: &mut StreamControl, - outbound: Option<&mut OutboundState>, + mut outbound: Option<&mut OutboundState>, emit: &mut impl OutputFn, - ) -> Option { - let stream_id = key.stream_id; - if control.awaiting.is_some() { - return None; - } - if let Some(frame) = control.pending.take_next_control(stream_id) { - return Some(frame); - } - let outbound = outbound?; - if outbound.can_request_data() { - let max_len = (outbound.remote_max_offset - outbound.sent_offset) - .min(config.max_payload_bytes as u64) as usize; - if max_len > 0 { - outbound.pending_pull = Some(PendingPull { - offset: outbound.sent_offset, - max_len, - }); + ) { + loop { + if control.send_window_has_space() { + if let Some(frame) = control.pending.pop_front() { + Self::enqueue_stream_frame(config, state, control, frame, 0, false); + continue; + } + } + if control.ack_dirty && control.ack_immediate && control.ack_outbound_token.is_none() { + Self::enqueue_stream_ack_body(config, state, control, stream_id, false); + continue; + } + if !control.send_window_has_space() { + return; + } + + let Some(outbound) = outbound.as_deref_mut() else { + return; + }; + if outbound.request_data() { emit(EngineOutput::NeedOutboundData { stream_id, dir: outbound.dir, - offset: outbound.sent_offset, - max_len, }); + return; } - return None; - } - if outbound.data_enabled - && !outbound.closed - && outbound.pending_pull.is_none() - && outbound - .final_offset - .is_some_and(|final_offset| final_offset == outbound.sent_offset) - { - outbound.closed = true; - return Some(StreamFrame::Finish(StreamFrameFinish { - stream_id, - dir: outbound.dir, - })); + if outbound.queue_fin() { + Self::enqueue_stream_frame( + config, + state, + control, + StreamFrame::Data(StreamFrameData { + stream_id, + dir: outbound.dir, + chunk: BodyChunk { + bytes: Vec::new(), + fin: true, + }, + }), + 0, + false, + ); + continue; + } + return; } - None } - fn queue_credit(stream: &mut StreamState, dir: Direction) { - let stream_id = stream.key().stream_id; - let (recv_offset, max_offset) = { - let Some(inbound) = stream.inbound_mut(dir) else { - return; - }; - (inbound.next_offset, inbound.max_offset) - }; - stream.control_mut().pending.set_credit(StreamFrameCredit { - stream_id, - dir, - recv_offset, - max_offset, + fn enqueue_stream_frame( + config: &EngineConfig, + state: &mut EngineState, + control: &mut StreamControl, + frame: StreamFrame, + attempt: u8, + priority: bool, + ) { + let tx_seq = control.take_tx_seq(); + Self::enqueue_stream_frame_with_seq( + config, state, control, tx_seq, frame, attempt, priority, + ); + } + + fn enqueue_stream_frame_with_seq( + config: &EngineConfig, + state: &mut EngineState, + control: &mut StreamControl, + tx_seq: StreamSeq, + frame: StreamFrame, + attempt: u8, + priority: bool, + ) { + control.insert_in_flight(InFlightFrame { + tx_seq, + frame: frame.clone(), + attempt, }); + let ack = control.ack_dirty.then(|| control.current_ack()); + if ack.is_some() { + control.clear_ack_schedule(); + } + let valid_until = wire::now_secs().saturating_add(config.packet_expiration.as_secs()); + state.enqueue_stream_body( + config, + priority, + StreamBody::Message(StreamMessage { + tx_seq, + ack, + valid_until, + frame, + }), + ); + } + + fn enqueue_stream_ack_body( + config: &EngineConfig, + state: &mut EngineState, + control: &mut StreamControl, + stream_id: StreamId, + priority: bool, + ) { + if !control.ack_dirty { + return; + } + let ack = control.current_ack(); + control.clear_ack_schedule(); + let valid_until = wire::now_secs().saturating_add(config.packet_expiration.as_secs()); + let token = state.enqueue_stream_body( + config, + priority, + StreamBody::Ack(StreamAckBody { + stream_id, + ack, + valid_until, + }), + ); + control.ack_outbound_token = Some(token); } fn queue_protocol_reset(stream: &mut StreamState, emit: &mut impl OutputFn) { - let stream_id = stream.key().stream_id; - stream - .control_mut() - .pending - .set_reset(ResetTarget::Both, ResetCode::Protocol); + let stream_id = stream.stream_id(); + let control = stream.control_mut(); + control.clear_transient_buffers(); + control.queue_frame_front(reset_frame( + stream_id, + ResetTarget::Both, + ResetCode::Protocol, + )); for dir in [Direction::Request, Direction::Response] { if let Some(outbound) = stream.outbound_mut(dir) { - outbound.closed = true; - outbound.pending_pull = None; + outbound.close(); emit(EngineOutput::OutboundFailed { stream_id, dir, @@ -1317,41 +1434,13 @@ impl Engine { } } - fn note_setup_seen_from_remote(stream: &mut StreamState) { - if let StreamState::Responder(stream) = stream { - if matches!( - stream - .control - .awaiting - .as_ref() - .map(|awaiting| &awaiting.frame), - Some(AwaitingFrame::Control(StreamFrame::Accept(_))) - ) { - stream.control.awaiting = None; - if let ResponderResponse::Accepted { body, .. } = &mut stream.response { - body.data_enabled = true; - } - } - if matches!( - stream - .control - .awaiting - .as_ref() - .map(|awaiting| &awaiting.frame), - Some(AwaitingFrame::Control(StreamFrame::Reject(_))) - ) { - stream.control.awaiting = None; - } - } - } - fn apply_remote_reset( stream: &mut StreamState, - dir: ResetTarget, + target: ResetTarget, code: ResetCode, emit: &mut impl OutputFn, ) { - let stream_id = stream.key().stream_id; + let stream_id = stream.stream_id(); let request_error = QlError::StreamReset { dir: Direction::Request, code, @@ -1361,7 +1450,7 @@ impl Engine { code, }; - if matches!(dir, ResetTarget::Request | ResetTarget::Both) { + if matches!(target, ResetTarget::Request | ResetTarget::Both) { if let Some(inbound) = stream.inbound_mut(Direction::Request) { if !inbound.closed { inbound.closed = true; @@ -1373,8 +1462,7 @@ impl Engine { } } if let Some(outbound) = stream.outbound_mut(Direction::Request) { - outbound.closed = true; - outbound.pending_pull = None; + outbound.close(); emit(EngineOutput::OutboundFailed { stream_id, dir: Direction::Request, @@ -1382,7 +1470,7 @@ impl Engine { }); } } - if matches!(dir, ResetTarget::Response | ResetTarget::Both) { + if matches!(target, ResetTarget::Response | ResetTarget::Both) { if let Some(inbound) = stream.inbound_mut(Direction::Response) { if !inbound.closed { inbound.closed = true; @@ -1394,8 +1482,7 @@ impl Engine { } } if let Some(outbound) = stream.outbound_mut(Direction::Response) { - outbound.closed = true; - outbound.pending_pull = None; + outbound.close(); emit(EngineOutput::OutboundFailed { stream_id, dir: Direction::Response, @@ -1411,7 +1498,7 @@ impl Engine { emit(EngineOutput::OpenFailed { open_id, stream_id, - error: match dir { + error: match target { ResetTarget::Request => request_error, _ => response_error, }, @@ -1434,62 +1521,48 @@ impl Engine { } } - fn stream_matches_open( - &self, - stream: &StreamState, - request_head: &[u8], - response_max_offset: u64, - ) -> bool { - match stream { - StreamState::Responder(state) => match &state.response { - ResponderResponse::Pending { initial_credit } - | ResponderResponse::Accepted { initial_credit, .. } - | ResponderResponse::Rejecting { initial_credit } => { - state.meta.request_head == request_head - && *initial_credit == response_max_offset + fn clear_ack_outbound_token(&mut self, token: Token, retry: bool) { + for stream in self.streams.values_mut() { + let control = stream.control_mut(); + if control.ack_outbound_token == Some(token) { + control.ack_outbound_token = None; + if retry { + control.note_ack(true); } - }, - _ => false, + break; + } } } - fn send_packet_ack(&mut self, acked_packet: PacketId) { - let packet_id = self.state.next_packet_id(); - let valid_until = wire::now_secs().saturating_add(self.config.packet_expiration.as_secs()); - self.enqueue_stream_body( - None, - None, - false, - true, - StreamBody { - packet_id, - valid_until, - packet_ack: Some(PacketAck { - packet_id: acked_packet, - }), - frame: None, - }, - ); + fn note_sent_stream_ack(&mut self, body: &StreamBody) { + let (stream_id, ack) = match body { + StreamBody::Ack(StreamAckBody { stream_id, ack, .. }) => (*stream_id, *ack), + StreamBody::Message(StreamMessage { + frame, + ack: Some(ack), + .. + }) => (frame.stream_id(), *ack), + StreamBody::Message(_) => return, + }; + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.control_mut().note_ack_sent(ack); + } } fn send_ephemeral_reset(&mut self, stream_id: StreamId, dir: ResetTarget, code: ResetCode) { - let packet_id = self.state.next_packet_id(); let valid_until = wire::now_secs().saturating_add(self.config.packet_expiration.as_secs()); self.enqueue_stream_body( - None, - None, - false, true, - StreamBody { - packet_id, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq::START, + ack: None, valid_until, - packet_ack: None, - frame: Some(StreamFrame::Reset(StreamFrameReset { + frame: StreamFrame::Reset(StreamFrameReset { stream_id, - dir, + target: dir, code, - })), - }, + }), + }), ); } @@ -1498,22 +1571,8 @@ impl Engine { .enqueue_handshake_message(&self.config, token, deadline, bytes); } - fn enqueue_stream_body( - &mut self, - stream_id: Option, - packet_id: Option, - track_ack: bool, - priority: bool, - body: StreamBody, - ) { - self.state.enqueue_stream_body( - &self.config, - stream_id, - packet_id, - track_ack, - priority, - body, - ); + fn enqueue_stream_body(&mut self, priority: bool, body: StreamBody) -> Token { + self.state.enqueue_stream_body(&self.config, priority, body) } fn handle_hello( @@ -1525,37 +1584,48 @@ impl Engine { emit: &mut impl OutputFn, ) { let action = match self.state.peer.as_ref() { - Some(entry) => match &entry.session { - PeerSession::Initiator { - hello: local_hello, .. - } => { - if peer_hello_wins(local_hello, crypto.xid(), hello, peer) { - HelloAction::StartResponder - } else { - HelloAction::Ignore - } + Some(entry) => { + if handshake::verify_hello(peer, self.identity.xid, &entry.signing_key, hello) + .is_err() + { + return; } - PeerSession::Responder { - hello: stored, - reply, - deadline, - .. - } => { - if stored.nonce == wire::nonce_from_archived(&hello.nonce) { - HelloAction::ResendReply { - reply: reply.clone(), - deadline: *deadline, + match &entry.session { + PeerSession::Initiator { + hello: local_hello, .. + } => { + if peer_hello_wins(local_hello, self.identity.xid, hello, peer) { + HelloAction::StartResponder + } else { + HelloAction::Ignore } - } else { + } + PeerSession::Responder { + hello: stored, + reply, + deadline, + .. + } => { + if stored.nonce == (&hello.nonce).into() { + HelloAction::ResendReply { + reply: reply.clone(), + deadline: *deadline, + } + } else { + HelloAction::StartResponder + } + } + PeerSession::Disconnected | PeerSession::Connected { .. } => { HelloAction::StartResponder } } - PeerSession::Disconnected | PeerSession::Connected { .. } => { - HelloAction::StartResponder - } - }, + } None => return, }; + let meta: ControlMeta = (&hello.meta).into(); + if self.is_replayed_control(peer, meta) { + return; + } match action { HelloAction::StartResponder => { @@ -1564,7 +1634,7 @@ impl Engine { HelloAction::ResendReply { reply, deadline } => { let record = QlRecord { header: QlHeader { - sender: crypto.xid(), + sender: self.identity.xid, recipient: peer, }, payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), @@ -1581,11 +1651,10 @@ impl Engine { now: Instant, peer: XID, reply: &wire::handshake::ArchivedHelloReply, - crypto: &impl QlCrypto, emit: &mut impl OutputFn, ) { - let token = self.state.next_token(); let deadline = now + self.config.handshake_timeout; + let confirm_meta = self.next_control_meta(self.config.handshake_timeout); let res = { let Some(peer_record) = self.state.peer.as_ref() else { return; @@ -1603,29 +1672,18 @@ impl Engine { return; } handshake::build_confirm( - crypto, - crypto.xid(), + &self.identity, peer, &peer_record.signing_key, hello, reply, session_key, + confirm_meta, ) .map(|(confirm, session_key)| (hello.clone(), confirm, session_key)) }; - let confirm = match res { - Ok((hello, confirm, session_key)) => { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Initiator { - handshake_token: token, - hello, - session_key, - deadline, - stage: InitiatorStage::SendingConfirm, - }; - } - confirm - } + let (hello, confirm, session_key) = match res { + Ok(result) => result, Err(_) => { if let Some(entry) = self.state.peer.as_mut() { entry.session = PeerSession::Disconnected; @@ -1634,10 +1692,24 @@ impl Engine { return; } }; + let reply_meta: ControlMeta = (&reply.meta).into(); + if self.is_replayed_control(peer, reply_meta) { + return; + } + let token = self.state.next_token(); + if let Some(entry) = self.state.peer.as_mut() { + entry.session = PeerSession::Initiator { + handshake_token: token, + hello, + session_key, + deadline, + stage: InitiatorStage::SendingConfirm, + }; + } let record = QlRecord { header: QlHeader { - sender: crypto.xid(), + sender: self.identity.xid, recipient: peer, }, payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm)), @@ -1650,7 +1722,6 @@ impl Engine { now: Instant, peer: XID, confirm: &wire::handshake::ArchivedConfirm, - crypto: &impl QlCrypto, emit: &mut impl OutputFn, ) { let Some(peer_record) = self.state.peer.as_ref() else { @@ -1668,7 +1739,7 @@ impl Engine { match handshake::finalize_confirm( peer, - crypto.xid(), + self.identity.xid, &peer_record.signing_key, hello, reply, @@ -1676,6 +1747,10 @@ impl Engine { secrets, ) { Ok(session_key) => { + let meta: ControlMeta = (&confirm.meta).into(); + if self.is_replayed_control(peer, meta) { + return; + } if let Some(entry) = self.state.peer.as_mut() { entry.session = PeerSession::Connected { session_key, @@ -1702,16 +1777,19 @@ impl Engine { crypto: &impl QlCrypto, emit: &mut impl OutputFn, ) { + let reply_meta = self.next_control_meta(self.config.handshake_timeout); let res = { let Some(peer_record) = self.state.peer.as_ref() else { return; }; handshake::respond_hello( + &self.identity, crypto, peer, - crypto.xid(), + &peer_record.signing_key, &peer_record.encapsulation_key, hello, + reply_meta, ) }; let (reply, secrets) = match res { @@ -1747,7 +1825,7 @@ impl Engine { let record = QlRecord { header: QlHeader { - sender: crypto.xid(), + sender: self.identity.xid, recipient: peer, }, payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), @@ -1759,7 +1837,7 @@ impl Engine { let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { return; }; - let packet_id = self.state.next_packet_id(); + let meta = self.next_control_meta(self.config.packet_expiration); let token = self.state.next_token(); let deadline = now + self.config.packet_expiration; let message = { @@ -1771,15 +1849,11 @@ impl Engine { }; heartbeat::encrypt_heartbeat( QlHeader { - sender: crypto.xid(), + sender: self.identity.xid, recipient: peer, }, session_key, - HeartbeatBody { - packet_id, - valid_until: wire::now_secs() - .saturating_add(self.config.packet_expiration.as_secs()), - }, + HeartbeatBody { meta }, next_encrypted_message_nonce(crypto), ) }; @@ -1820,8 +1894,14 @@ impl Engine { fn drop_outbound(&mut self, emit: &mut impl OutputFn) { while let Some(message) = self.state.outbound.pop_front() { - if let Some(stream_id) = message.stream_id { - self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); + if let QueuedPayload::Stream { body } = message.payload { + match body { + StreamBody::Ack(_) => self.clear_ack_outbound_token(message.token, false), + StreamBody::Message(message) => { + let stream_id = message.frame.stream_id(); + self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); + } + } } } } @@ -1886,6 +1966,7 @@ impl Engine { }); } } + StreamState::Provisional(_) => {} } emit(EngineOutput::StreamReaped { stream_id }); } @@ -1896,7 +1977,6 @@ impl Engine { }; self.drop_outbound(emit); self.abort_streams(QlError::SendFailed, emit); - self.state.replay_cache.clear_peer(peer); self.state.peer = None; emit(EngineOutput::PeerStatusChanged { peer, @@ -1919,9 +1999,17 @@ impl Engine { match entry.kind { TimeoutKind::Outbound { token } => { let mut timed_out_stream = None; + let mut timed_out_ack = false; self.state.outbound.retain(|message| { if message.token == token { - timed_out_stream = message.stream_id; + if let QueuedPayload::Stream { body } = &message.payload { + match body { + StreamBody::Ack(_) => timed_out_ack = true, + StreamBody::Message(message) => { + timed_out_stream = Some(message.frame.stream_id()) + } + } + } false } else { true @@ -1929,6 +2017,8 @@ impl Engine { }); if let Some(stream_id) = timed_out_stream { self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); + } else if timed_out_ack { + self.clear_ack_outbound_token(token, true); } } TimeoutKind::Handshake { token } => { @@ -2001,82 +2091,66 @@ impl Engine { self.fail_stream_by_id(stream_id, QlError::Timeout, emit); } } - TimeoutKind::StreamPacket { + TimeoutKind::StreamAckDelay { stream_id, token } => { + let should_flush = self + .streams + .get(&stream_id) + .and_then(|stream| stream.control().ack_delay_token) + .is_some_and(|ack_token| ack_token == token); + if should_flush { + if let Some(stream) = self.streams.get_mut(&stream_id) { + let control = stream.control_mut(); + control.ack_delay_token = None; + control.ack_immediate = true; + } + } + } + TimeoutKind::StreamProvisional { stream_id, token } => { + let should_reset = self + .streams + .get(&stream_id) + .and_then(StreamState::provisional_timeout_token) + .is_some_and(|stream_token| stream_token == token); + if should_reset { + self.streams.remove(&stream_id); + self.send_ephemeral_reset( + stream_id, + ResetTarget::Both, + ResetCode::Protocol, + ); + } + } + TimeoutKind::StreamMessage { stream_id, - packet_id, + tx_seq, attempt, } => { - let mut timed_out = false; - let mut retransmit_control = None; - let mut retransmit_data = None; - { - let Some(stream) = self.streams.get_mut(&stream_id) else { - continue; - }; - let Some(retransmit) = - stream.control().awaiting.as_ref().and_then(|awaiting| { - if awaiting.packet_id != packet_id || awaiting.attempt != attempt { - return None; - } - Some(match &awaiting.frame { - AwaitingFrame::Control(frame) => { - EitherRetransmit::Control(frame.clone()) - } - AwaitingFrame::Data { dir, offset, len } => { - EitherRetransmit::Data { - dir: *dir, - offset: *offset, - len: *len, - } - } - }) + let Some(frame) = self.streams.get(&stream_id).and_then(|stream| { + stream + .control() + .in_flight + .get(&tx_seq) + .and_then(|in_flight| { + (in_flight.attempt == attempt).then_some(in_flight.frame.clone()) }) - else { - continue; - }; + }) else { + continue; + }; - if attempt >= self.config.stream_retry_limit { - timed_out = true; - } else { - match retransmit { - EitherRetransmit::Control(frame) => { - retransmit_control = Some(frame) - } - EitherRetransmit::Data { dir, offset, len } => { - retransmit_data = Some((dir, offset, len)) - } - } - } - } - if timed_out { + if attempt >= self.config.stream_retry_limit { self.fail_stream_by_id(stream_id, QlError::Timeout, emit); - } else if let Some(frame) = retransmit_control { - let (streams, state) = (&mut self.streams, &mut self.state); - if let Some(stream) = streams.get_mut(&stream_id) { - let key = stream.key(); - state.enqueue_control_frame( + } else { + if let Some(stream) = self.streams.get_mut(&stream_id) { + Self::enqueue_stream_frame_with_seq( &self.config, - key, + &mut self.state, stream.control_mut(), + tx_seq, frame, attempt.saturating_add(1), + true, ); } - } else if let Some((dir, offset, len)) = retransmit_data { - if let Some(stream) = self.streams.get_mut(&stream_id) { - if let Some(outbound) = stream.outbound_mut(dir) { - outbound.pending_pull = Some(PendingPull { - offset, - max_len: len, - }); - emit(EngineOutput::NeedOutboundData { - stream_id, - dir, - offset, - max_len: len, - }); - } - } } } } @@ -2094,6 +2168,7 @@ impl Engine { if self.state.write_in_flight == Some(token) { self.state.write_in_flight = None; } + self.clear_ack_outbound_token(token, result.is_err()); if let Err(error) = result { if let Some(tracked) = tracked { self.fail_stream_by_id(tracked.stream_id, error.clone(), emit); @@ -2141,16 +2216,14 @@ impl Engine { let attempt = self .streams .get(&tracked.stream_id) - .and_then(|stream| stream.control().awaiting.as_ref()) - .and_then(|awaiting| { - (awaiting.packet_id == tracked.packet_id).then_some(awaiting.attempt) - }) + .and_then(|stream| stream.control().in_flight.get(&tracked.tx_seq)) + .map(|in_flight| in_flight.attempt) .unwrap_or(0); self.state.timeouts.push(Reverse(TimeoutEntry { - at: now + self.config.packet_ack_timeout, - kind: TimeoutKind::StreamPacket { + at: now + self.config.stream_ack_timeout, + kind: TimeoutKind::StreamMessage { stream_id: tracked.stream_id, - packet_id: tracked.packet_id, + tx_seq: tracked.tx_seq, attempt, }, })); @@ -2162,27 +2235,37 @@ impl Engine { return; } while let Some(message) = self.state.outbound.pop_front() { - let bytes = match message.payload { - QueuedPayload::PreEncoded(bytes) => bytes, - QueuedPayload::StreamBody(body) => { - let Some(peer) = self.state.peer.as_ref() else { - if let Some(stream_id) = message.stream_id { - self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); - } - continue; - }; - let Some(session_key) = peer.session.session_key() else { - if let Some(stream_id) = message.stream_id { - self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); + let bytes = match &message.payload { + QueuedPayload::PreEncoded(bytes) => bytes.clone(), + QueuedPayload::Stream { body } => { + let Some((recipient, session_key)) = + self.state.peer.as_ref().and_then(|peer| { + peer.session + .session_key() + .map(|key| (peer.peer, key.clone())) + }) + else { + match body { + StreamBody::Ack(_) => { + self.clear_ack_outbound_token(message.token, false) + } + StreamBody::Message(stream_message) => { + self.fail_stream_by_id( + stream_message.frame.stream_id(), + QlError::SendFailed, + emit, + ); + } } continue; }; + self.note_sent_stream_ack(body); let record = encrypt_stream( QlHeader { - sender: crypto.xid(), - recipient: peer.peer, + sender: self.identity.xid, + recipient, }, - session_key, + &session_key, body, next_encrypted_message_nonce(crypto), ); @@ -2190,16 +2273,14 @@ impl Engine { } }; - let tracked = if message.track_ack { - message - .stream_id - .zip(message.packet_id) - .map(|(stream_id, packet_id)| TrackedWrite { - stream_id, - packet_id, - }) - } else { - None + let tracked = match &message.payload { + QueuedPayload::Stream { + body: StreamBody::Message(stream_message), + } => Some(TrackedWrite { + stream_id: stream_message.frame.stream_id(), + tx_seq: stream_message.tx_seq, + }), + _ => None, }; self.state.write_in_flight = Some(message.token); emit(EngineOutput::WriteMessage { @@ -2226,7 +2307,7 @@ fn peer_hello_wins( ) -> bool { use std::cmp::Ordering; - let peer_nonce = wire::nonce_from_archived(&peer_hello.nonce); + let peer_nonce: bc_components::Nonce = (&peer_hello.nonce).into(); match peer_nonce.data().cmp(local_hello.nonce.data()) { Ordering::Less => true, Ordering::Greater => false, @@ -2234,18 +2315,6 @@ fn peer_hello_wins( } } -fn stream_id_from_frame(frame: &StreamFrame) -> StreamId { - match frame { - StreamFrame::Open(frame) => frame.stream_id, - StreamFrame::Accept(frame) => frame.stream_id, - StreamFrame::Reject(frame) => frame.stream_id, - StreamFrame::Data(frame) => frame.stream_id, - StreamFrame::Credit(frame) => frame.stream_id, - StreamFrame::Finish(frame) => frame.stream_id, - StreamFrame::Reset(frame) => frame.stream_id, - } -} - fn reset_target_for_dir(dir: Direction) -> ResetTarget { match dir { Direction::Request => ResetTarget::Request, diff --git a/ql2/src/engine/replay_cache.rs b/ql2/src/engine/replay_cache.rs index 292f1740..7643c5c8 100644 --- a/ql2/src/engine/replay_cache.rs +++ b/ql2/src/engine/replay_cache.rs @@ -8,27 +8,15 @@ use bc_components::XID; use crate::PacketId; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ReplayNamespace { - Peer, - Local, - Transfer, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ReplayKey { pub peer: XID, - pub namespace: ReplayNamespace, pub packet_id: PacketId, } impl ReplayKey { - pub const fn new(peer: XID, namespace: ReplayNamespace, packet_id: PacketId) -> Self { - Self { - peer, - namespace, - packet_id, - } + pub const fn new(peer: XID, packet_id: PacketId) -> Self { + Self { peer, packet_id } } } @@ -142,7 +130,7 @@ mod tests { fn check_and_store_detects_replay() { let mut cache = ReplayCache::new(); let peer = peer_with_byte(1); - let key = ReplayKey::new(peer, ReplayNamespace::Peer, PacketId(1)); + let key = ReplayKey::new(peer, PacketId(1)); let now_secs = 100; let expires_at = 110; @@ -157,8 +145,8 @@ mod tests { let expired_at = 99; let future_at = 110; - let key_old = ReplayKey::new(peer_with_byte(2), ReplayNamespace::Peer, PacketId(2)); - let key_new = ReplayKey::new(peer_with_byte(3), ReplayNamespace::Peer, PacketId(3)); + let key_old = ReplayKey::new(peer_with_byte(2), PacketId(2)); + let key_new = ReplayKey::new(peer_with_byte(3), PacketId(3)); cache.add(key_old, expired_at); cache.add(key_new, future_at); @@ -176,8 +164,8 @@ mod tests { let peer_a = peer_with_byte(4); let peer_b = peer_with_byte(5); - let key_a = ReplayKey::new(peer_a, ReplayNamespace::Peer, PacketId(4)); - let key_b = ReplayKey::new(peer_b, ReplayNamespace::Peer, PacketId(5)); + let key_a = ReplayKey::new(peer_a, PacketId(4)); + let key_b = ReplayKey::new(peer_b, PacketId(5)); cache.add(key_a, expires_at); cache.add(key_b, expires_at); diff --git a/ql2/src/engine/ring.rs b/ql2/src/engine/ring.rs new file mode 100644 index 00000000..d1f4bf64 --- /dev/null +++ b/ql2/src/engine/ring.rs @@ -0,0 +1,394 @@ +use std::array; + +use crate::wire::StreamSeq; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SeqRingInsertError { + OutOfWindow, + Occupied, +} + +#[derive(Debug)] +pub struct SeqRing { + base_seq: StreamSeq, + head: usize, + len: usize, + slots: [Option; N], +} + +impl SeqRing { + pub fn new(base_seq: StreamSeq) -> Self { + Self { + base_seq, + head: 0, + len: 0, + slots: array::from_fn(|_| None), + } + } + + pub fn base_seq(&self) -> StreamSeq { + self.base_seq + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn clear_with_base(&mut self, base_seq: StreamSeq) { + for slot in &mut self.slots { + let _ = slot.take(); + } + self.base_seq = base_seq; + self.head = 0; + self.len = 0; + } + + pub fn contains_key(&self, seq: &StreamSeq) -> bool { + self.get(seq).is_some() + } + + pub fn accepts_seq(&self, seq: StreamSeq) -> bool { + self.offset_for(seq).is_some() + } + + pub fn get(&self, seq: &StreamSeq) -> Option<&T> { + let index = self.index_for(*seq)?; + self.slots[index].as_ref() + } + + pub fn get_mut(&mut self, seq: &StreamSeq) -> Option<&mut T> { + let index = self.index_for(*seq)?; + self.slots[index].as_mut() + } + + pub fn insert(&mut self, seq: StreamSeq, value: T) -> Result<(), SeqRingInsertError> { + let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; + if self.slots[index].is_some() { + return Err(SeqRingInsertError::Occupied); + } + self.slots[index] = Some(value); + self.len += 1; + Ok(()) + } + + pub fn set(&mut self, seq: StreamSeq, value: T) -> Result, SeqRingInsertError> { + let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; + let previous = self.slots[index].replace(value); + if previous.is_none() { + self.len += 1; + } + Ok(previous) + } + + pub fn remove(&mut self, seq: &StreamSeq) -> Option { + let index = self.index_for(*seq)?; + let value = self.slots[index].take(); + if value.is_some() { + self.len -= 1; + } + value + } + + pub fn take_front(&mut self) -> Option<(StreamSeq, T)> { + let value = self.slots[self.head].take()?; + let seq = self.base_seq; + self.len -= 1; + self.head = self.next_index(self.head); + self.base_seq = self.base_seq.next(); + Some((seq, value)) + } + + pub fn advance_empty_front_until(&mut self, limit_seq: StreamSeq) { + while self.base_seq.serial_lt(limit_seq) && self.slots[self.head].is_none() { + self.head = self.next_index(self.head); + self.base_seq = self.base_seq.next(); + } + } + + pub fn drain_front(&mut self) -> SeqRingDrain<'_, N, T> { + SeqRingDrain { ring: self } + } + + pub fn iter(&self) -> SeqRingIter<'_, N, T> { + SeqRingIter { + ring: self, + offset: 0, + } + } + + pub fn bitmap(&self) -> u8 { + debug_assert!(N <= 8); + let mut bitmap = 0u8; + for offset in 0..N { + let index = self.index_for_offset(offset); + if self.slots[index].is_some() { + bitmap |= 1u8 << offset; + } + } + bitmap + } + + fn index_for(&self, seq: StreamSeq) -> Option { + let offset = self.offset_for(seq)?; + Some(self.index_for_offset(offset)) + } + + fn offset_for(&self, seq: StreamSeq) -> Option { + let offset = self.base_seq.forward_distance_to(seq)? as usize; + (offset < N).then_some(offset) + } + + fn index_for_offset(&self, offset: usize) -> usize { + (self.head + offset) % N + } + + fn next_index(&self, index: usize) -> usize { + (index + 1) % N + } +} + +pub struct SeqRingIter<'a, const N: usize, T> { + ring: &'a SeqRing, + offset: usize, +} + +impl<'a, const N: usize, T> Iterator for SeqRingIter<'a, N, T> { + type Item = (StreamSeq, &'a T); + + fn next(&mut self) -> Option { + while self.offset < N { + let offset = self.offset; + self.offset += 1; + let index = self.ring.index_for_offset(offset); + if let Some(value) = self.ring.slots[index].as_ref() { + let seq = self.ring.base_seq.add(offset as u32); + return Some((seq, value)); + } + } + None + } +} + +pub struct SeqRingDrain<'a, const N: usize, T> { + ring: &'a mut SeqRing, +} + +impl<'a, const N: usize, T> Iterator for SeqRingDrain<'a, N, T> { + type Item = (StreamSeq, T); + + fn next(&mut self) -> Option { + self.ring.take_front() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + engine::stream::{BufferIncomingResult, InFlightFrame, StreamControl}, + wire::stream::{ + BodyChunk, Direction, StreamAck, StreamFrame, StreamFrameData, StreamFrameOpen, + }, + StreamId, + }; + + fn data_frame(stream_id: StreamId, tx_seq: u32, byte: u8) -> (StreamSeq, StreamFrame) { + ( + StreamSeq(tx_seq), + StreamFrame::Data(StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: vec![byte], + fin: false, + }, + }), + ) + } + + #[test] + fn seq_ring_drain_front_takes_contiguous_items_in_order() { + let mut ring = SeqRing::<8, u64>::new(StreamSeq(1)); + ring.insert(StreamSeq(2), 20).unwrap(); + ring.insert(StreamSeq(1), 10).unwrap(); + ring.insert(StreamSeq(3), 30).unwrap(); + + let drained: Vec<_> = ring.drain_front().collect(); + assert_eq!( + drained, + vec![(StreamSeq(1), 10), (StreamSeq(2), 20), (StreamSeq(3), 30)] + ); + assert_eq!(ring.base_seq(), StreamSeq(4)); + assert!(ring.is_empty()); + } + + #[test] + fn seq_ring_wraps_and_reuses_slots() { + let mut ring = SeqRing::<4, u64>::new(StreamSeq(1)); + ring.insert(StreamSeq(1), 1).unwrap(); + ring.insert(StreamSeq(2), 2).unwrap(); + ring.insert(StreamSeq(3), 3).unwrap(); + + assert_eq!(ring.take_front(), Some((StreamSeq(1), 1))); + assert_eq!(ring.take_front(), Some((StreamSeq(2), 2))); + + ring.insert(StreamSeq(4), 4).unwrap(); + ring.insert(StreamSeq(5), 5).unwrap(); + + let remaining: Vec<_> = ring.iter().map(|(seq, value)| (seq, *value)).collect(); + assert_eq!( + remaining, + vec![(StreamSeq(3), 3), (StreamSeq(4), 4), (StreamSeq(5), 5)] + ); + } + + #[test] + fn seq_ring_selective_take_can_slide_past_empty_front() { + let mut ring = SeqRing::<8, u64>::new(StreamSeq(1)); + for value in 1..=4 { + ring.insert(StreamSeq(value), value as u64).unwrap(); + } + + assert_eq!(ring.remove(&StreamSeq(2)), Some(2)); + assert_eq!(ring.remove(&StreamSeq(3)), Some(3)); + ring.advance_empty_front_until(StreamSeq(5)); + assert_eq!(ring.base_seq(), StreamSeq(1)); + + assert_eq!(ring.remove(&StreamSeq(1)), Some(1)); + ring.advance_empty_front_until(StreamSeq(5)); + assert_eq!(ring.base_seq(), StreamSeq(4)); + + assert_eq!(ring.remove(&StreamSeq(4)), Some(4)); + ring.advance_empty_front_until(StreamSeq(5)); + assert_eq!(ring.base_seq(), StreamSeq(5)); + assert!(ring.is_empty()); + } + + #[test] + fn stream_control_recv_buffer_preserves_ack_bitmap_and_drain_order() { + let stream_id = StreamId(7); + let mut control = StreamControl::default(); + + let (seq2, frame2) = data_frame(stream_id, 2, b'b'); + let (seq1, frame1) = data_frame(stream_id, 1, b'a'); + let (seq3, frame3) = data_frame(stream_id, 3, b'c'); + + assert!(matches!( + control.buffer_incoming(seq2, frame2), + BufferIncomingResult::Buffered { out_of_order: true } + )); + assert_eq!(control.current_ack().base, StreamSeq(0)); + assert_eq!(control.current_ack().bitmap, 0b0000_0010); + + assert!(matches!( + control.buffer_incoming(seq1, frame1), + BufferIncomingResult::Buffered { + out_of_order: false + } + )); + assert!(matches!( + control.buffer_incoming(seq3, frame3), + BufferIncomingResult::Buffered { out_of_order: true } + )); + + let committed: Vec<_> = std::iter::from_fn(|| control.pop_next_committable()).collect(); + assert_eq!( + committed.iter().map(|(seq, _)| *seq).collect::>(), + vec![StreamSeq(1), StreamSeq(2), StreamSeq(3)] + ); + assert_eq!(control.committed_rx_seq(), StreamSeq(3)); + assert_eq!(control.current_ack().base, StreamSeq(3)); + assert_eq!(control.current_ack().bitmap, 0); + } + + #[test] + fn stream_control_send_window_respects_sequence_range_not_count() { + let stream_id = StreamId(11); + let mut control = StreamControl::default(); + for tx_seq in 1..=8 { + let frame = InFlightFrame { + tx_seq: StreamSeq(tx_seq), + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: vec![tx_seq as u8], + request_prefix: None, + }), + attempt: 0, + }; + control.insert_in_flight(frame); + control.next_tx_seq = StreamSeq(tx_seq + 1); + } + + assert!(!control.send_window_has_space()); + let _ = control.remove_in_flight(StreamSeq(8)); + assert!(!control.send_window_has_space()); + let _ = control.remove_in_flight(StreamSeq(1)); + assert!(control.send_window_has_space()); + assert_eq!(control.in_flight.base_seq(), StreamSeq(2)); + } + + #[test] + fn ack_coverage_handles_wraparound_bitmap() { + let ack = StreamAck { + base: StreamSeq(u32::MAX), + bitmap: 0b0000_0011, + }; + + assert!(StreamControl::ack_covers(ack, StreamSeq(u32::MAX - 1))); + assert!(StreamControl::ack_covers(ack, StreamSeq(u32::MAX))); + assert!(StreamControl::ack_covers(ack, StreamSeq(0))); + assert!(StreamControl::ack_covers(ack, StreamSeq(1))); + assert!(!StreamControl::ack_covers(ack, StreamSeq(2))); + } + + #[test] + fn seq_ring_accepts_window_across_sequence_overflow() { + let mut ring = SeqRing::<4, u64>::new(StreamSeq(u32::MAX - 1)); + ring.insert(StreamSeq(u32::MAX - 1), 1).unwrap(); + ring.insert(StreamSeq(u32::MAX), 2).unwrap(); + ring.insert(StreamSeq(0), 3).unwrap(); + + assert_eq!(ring.take_front(), Some((StreamSeq(u32::MAX - 1), 1))); + assert_eq!(ring.take_front(), Some((StreamSeq(u32::MAX), 2))); + + ring.insert(StreamSeq(1), 4).unwrap(); + ring.insert(StreamSeq(2), 5).unwrap(); + + let remaining: Vec<_> = ring.iter().map(|(seq, value)| (seq, *value)).collect(); + assert_eq!( + remaining, + vec![(StreamSeq(0), 3), (StreamSeq(1), 4), (StreamSeq(2), 5)] + ); + } + + #[test] + fn seq_ring_selective_take_slides_across_sequence_overflow() { + let mut ring = SeqRing::<8, u64>::new(StreamSeq(u32::MAX - 1)); + for (seq, value) in [ + (StreamSeq(u32::MAX - 1), 1u64), + (StreamSeq(u32::MAX), 2u64), + (StreamSeq(0), 3u64), + (StreamSeq(1), 4u64), + ] { + ring.insert(seq, value).unwrap(); + } + + assert_eq!(ring.remove(&StreamSeq(u32::MAX)), Some(2)); + assert_eq!(ring.remove(&StreamSeq(0)), Some(3)); + ring.advance_empty_front_until(StreamSeq(2)); + assert_eq!(ring.base_seq(), StreamSeq(u32::MAX - 1)); + + assert_eq!(ring.remove(&StreamSeq(u32::MAX - 1)), Some(1)); + ring.advance_empty_front_until(StreamSeq(2)); + assert_eq!(ring.base_seq(), StreamSeq(1)); + + assert_eq!(ring.remove(&StreamSeq(1)), Some(4)); + ring.advance_empty_front_until(StreamSeq(2)); + assert_eq!(ring.base_seq(), StreamSeq(2)); + assert!(ring.is_empty()); + } +} diff --git a/ql2/src/engine/state.rs b/ql2/src/engine/state.rs index 33055174..7debf60f 100644 --- a/ql2/src/engine/state.rs +++ b/ql2/src/engine/state.rs @@ -9,14 +9,15 @@ use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey, XID}; use super::{ replay_cache::ReplayCache, - stream::{AwaitingFrame, AwaitingPacket, QueuedWrite, StreamControl, StreamKey, StreamState}, - EngineConfig, + stream::{QueuedWrite, StreamState}, + EngineConfig, StreamConfig, }; use crate::{ - runtime::StreamConfig, + platform::QlIdentity, wire::{ handshake::{Hello, HelloReply, ResponderSecrets}, - stream::{Direction, RejectCode, ResetCode, StreamBody, StreamFrame, StreamFrameData}, + stream::{BodyChunk, Direction, RejectCode, ResetCode, StreamBody}, + StreamSeq, }, PacketId, Peer, QlError, StreamId, }; @@ -30,7 +31,7 @@ pub struct OpenId(pub u64); #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct TrackedWrite { pub stream_id: StreamId, - pub packet_id: PacketId, + pub tx_seq: StreamSeq, } #[derive(Debug, Clone)] @@ -56,6 +57,41 @@ pub enum InitiatorStage { SendingConfirm, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamNamespace { + Low, + High, +} + +impl StreamNamespace { + const BIT: u64 = 1 << 63; + + pub fn bit(self) -> u64 { + match self { + Self::Low => 0, + Self::High => Self::BIT, + } + } + + pub fn for_local(local: XID, peer: XID) -> Self { + match local.data().cmp(peer.data()) { + std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Low, + std::cmp::Ordering::Greater => Self::High, + } + } + + pub fn matches(self, stream_id: StreamId) -> bool { + (stream_id.0 & Self::BIT) == self.bit() + } + + pub fn remote(self) -> Self { + match self { + Self::Low => Self::High, + Self::High => Self::Low, + } + } +} + #[derive(Debug, Clone)] pub enum PeerSession { Disconnected, @@ -129,11 +165,13 @@ pub enum EngineInput { OpenStream { open_id: OpenId, request_head: Vec, + request_prefix: Option, config: StreamConfig, }, AcceptStream { stream_id: StreamId, response_head: Vec, + response_prefix: Option, }, RejectStream { stream_id: StreamId, @@ -143,18 +181,11 @@ pub enum EngineInput { OutboundData { stream_id: StreamId, dir: Direction, - offset: u64, bytes: Vec, }, OutboundFinished { stream_id: StreamId, dir: Direction, - final_offset: u64, - }, - InboundConsumed { - stream_id: StreamId, - dir: Direction, - amount: u64, }, ResetOutbound { @@ -207,6 +238,7 @@ pub enum EngineOutput { open_id: OpenId, stream_id: StreamId, response_head: Vec, + response_prefix: Option, }, OpenFailed { open_id: OpenId, @@ -217,6 +249,7 @@ pub enum EngineOutput { InboundStreamOpened { stream_id: StreamId, request_head: Vec, + request_prefix: Option, }, InboundData { stream_id: StreamId, @@ -236,13 +269,6 @@ pub enum EngineOutput { NeedOutboundData { stream_id: StreamId, dir: Direction, - offset: u64, - max_len: usize, - }, - ReleaseOutboundThrough { - stream_id: StreamId, - dir: Direction, - recv_offset: u64, }, OutboundClosed { stream_id: StreamId, @@ -281,11 +307,19 @@ pub enum TimeoutKind { stream_id: StreamId, token: Token, }, - StreamPacket { + StreamMessage { stream_id: StreamId, - packet_id: PacketId, + tx_seq: StreamSeq, attempt: u8, }, + StreamAckDelay { + stream_id: StreamId, + token: Token, + }, + StreamProvisional { + stream_id: StreamId, + token: Token, + }, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -318,6 +352,7 @@ pub enum HelloAction { pub struct Engine { pub config: EngineConfig, + pub identity: QlIdentity, pub state: EngineState, pub streams: HashMap, } @@ -365,10 +400,10 @@ impl EngineState { PacketId(id) } - pub fn next_stream_id(&self) -> StreamId { - let id = self.next_stream_id.get(); - self.next_stream_id.set(id.wrapping_add(1)); - StreamId(id) + pub fn next_stream_id(&self, namespace: StreamNamespace) -> StreamId { + let seq = self.next_stream_id.get(); + self.next_stream_id.set(seq.wrapping_add(1)); + StreamId((seq & !StreamNamespace::BIT) | namespace.bit()) } pub fn enqueue_handshake_message( @@ -380,9 +415,6 @@ impl EngineState { ) { self.outbound.push_back(QueuedWrite { token, - stream_id: None, - packet_id: None, - track_ack: false, payload: super::stream::QueuedPayload::PreEncoded(bytes), }); self.timeouts.push(Reverse(TimeoutEntry { @@ -398,19 +430,13 @@ impl EngineState { pub fn enqueue_stream_body( &mut self, config: &EngineConfig, - stream_id: Option, - packet_id: Option, - track_ack: bool, priority: bool, body: StreamBody, - ) { + ) -> Token { let token = self.next_token(); let message = QueuedWrite { token, - stream_id, - packet_id, - track_ack, - payload: super::stream::QueuedPayload::StreamBody(body), + payload: super::stream::QueuedPayload::Stream { body }, }; if priority { self.outbound.push_front(message); @@ -421,87 +447,6 @@ impl EngineState { at: Instant::now() + config.packet_expiration, kind: TimeoutKind::Outbound { token }, })); + token } - - pub fn enqueue_control_frame( - &mut self, - config: &EngineConfig, - key: StreamKey, - control: &mut StreamControl, - frame: StreamFrame, - attempt: u8, - ) { - let packet_id = self.next_packet_id(); - control.awaiting = Some(AwaitingPacket { - packet_id, - frame: AwaitingFrame::Control(frame.clone()), - attempt, - }); - let valid_until = - crate::wire::now_secs().saturating_add(config.packet_expiration.as_secs()); - self.enqueue_stream_body( - config, - Some(key.stream_id), - Some(packet_id), - true, - false, - StreamBody { - packet_id, - valid_until, - packet_ack: None, - frame: Some(frame), - }, - ); - } - - pub fn enqueue_data_frame( - &mut self, - config: &EngineConfig, - key: StreamKey, - control: &mut StreamControl, - dir: Direction, - offset: u64, - bytes: Vec, - attempt: u8, - ) { - let packet_id = self.next_packet_id(); - control.awaiting = Some(AwaitingPacket { - packet_id, - frame: AwaitingFrame::Data { - dir, - offset, - len: bytes.len(), - }, - attempt, - }); - let valid_until = - crate::wire::now_secs().saturating_add(config.packet_expiration.as_secs()); - self.enqueue_stream_body( - config, - Some(key.stream_id), - Some(packet_id), - true, - false, - StreamBody { - packet_id, - valid_until, - packet_ack: None, - frame: Some(StreamFrame::Data(StreamFrameData { - stream_id: key.stream_id, - dir, - offset, - bytes, - })), - }, - ); - } -} - -pub enum EitherRetransmit { - Control(StreamFrame), - Data { - dir: Direction, - offset: u64, - len: usize, - }, } diff --git a/ql2/src/engine/stream.rs b/ql2/src/engine/stream.rs index 5a371b96..c90d9e09 100644 --- a/ql2/src/engine/stream.rs +++ b/ql2/src/engine/stream.rs @@ -1,83 +1,104 @@ -use std::time::Instant; +use std::{collections::VecDeque, time::Instant}; -use super::{OpenId, Token}; +use super::{ring::SeqRing, OpenId, Token}; use crate::{ - wire::stream::{ - Direction, ResetCode, ResetTarget, StreamBody, StreamFrame, StreamFrameAccept, - StreamFrameCredit, StreamFrameOpen, StreamFrameReject, StreamFrameReset, + wire::{ + stream::{ + Direction, ResetCode, ResetTarget, StreamAck, StreamBody, StreamFrame, StreamFrameReset, + }, + StreamSeq, }, - PacketId, StreamId, + StreamId, }; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamKey { - pub stream_id: StreamId, -} +pub const STREAM_WINDOW_CAPACITY: usize = 8; +pub const STREAM_WINDOW_SIZE: u32 = STREAM_WINDOW_CAPACITY as u32; +pub const STREAM_ACK_EAGER_THRESHOLD: u32 = STREAM_WINDOW_SIZE / 2; #[derive(Debug)] pub struct StreamMeta { - pub key: StreamKey, - pub request_head: Vec, + pub stream_id: StreamId, pub last_activity: Instant, } -#[derive(Debug)] -pub struct PendingPull { - pub offset: u64, - pub max_len: usize, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OutboundPhase { + Ready, + PendingPull, + FinPending, + FinQueued, + Closed, } #[derive(Debug)] pub struct OutboundState { pub dir: Direction, - pub remote_max_offset: u64, - pub sent_offset: u64, - pub released_offset: u64, - pub final_offset: Option, - pub data_enabled: bool, - pub closed: bool, - pub pending_pull: Option, + pub phase: OutboundPhase, } impl OutboundState { - pub fn new(dir: Direction, remote_max_offset: u64, data_enabled: bool) -> Self { + pub fn from_prefix(dir: Direction, fin: bool) -> Self { Self { dir, - remote_max_offset, - sent_offset: 0, - released_offset: 0, - final_offset: None, - data_enabled, - closed: false, - pending_pull: None, + phase: if fin { + OutboundPhase::FinQueued + } else { + OutboundPhase::Ready + }, + } + } + + pub fn is_closed(&self) -> bool { + self.phase == OutboundPhase::Closed + } + + pub fn request_data(&mut self) -> bool { + if self.phase != OutboundPhase::Ready { + return false; + } + self.phase = OutboundPhase::PendingPull; + true + } + + pub fn take_pending_pull(&mut self) -> bool { + if self.phase != OutboundPhase::PendingPull { + return false; } + self.phase = OutboundPhase::Ready; + true } - pub fn can_request_data(&self) -> bool { - self.data_enabled - && !self.closed - && self.pending_pull.is_none() - && self.sent_offset < self.remote_max_offset - && self - .final_offset - .is_none_or(|final_offset| self.sent_offset < final_offset) + pub fn finish(&mut self) { + self.phase = match self.phase { + OutboundPhase::Ready | OutboundPhase::PendingPull | OutboundPhase::FinPending => { + OutboundPhase::FinPending + } + OutboundPhase::FinQueued => OutboundPhase::FinQueued, + OutboundPhase::Closed => OutboundPhase::Closed, + }; + } + + pub fn queue_fin(&mut self) -> bool { + if self.phase != OutboundPhase::FinPending { + return false; + } + self.phase = OutboundPhase::FinQueued; + true + } + + pub fn close(&mut self) { + self.phase = OutboundPhase::Closed; } } #[derive(Debug)] pub struct InboundState { - pub next_offset: u64, - pub max_offset: u64, pub closed: bool, } impl InboundState { - pub fn new(max_offset: u64) -> Self { - Self { - next_offset: 0, - max_offset, - closed: false, - } + pub fn new() -> Self { + Self { closed: false } } } @@ -94,6 +115,167 @@ pub enum InitiatorAccept { Open { response_head: Vec }, } +#[derive(Debug)] +pub struct InFlightFrame { + pub tx_seq: StreamSeq, + pub frame: StreamFrame, + pub attempt: u8, +} + +#[derive(Debug)] +pub enum BufferIncomingResult { + Duplicate, + AlreadyBuffered, + Buffered { out_of_order: bool }, + OutOfWindow, +} + +#[derive(Debug)] +pub struct StreamControl { + pub pending: VecDeque, + pub in_flight: SeqRing, + pub next_tx_seq: StreamSeq, + pub recv_buffer: SeqRing, + pub ack_dirty: bool, + pub ack_immediate: bool, + pub ack_delay_token: Option, + pub ack_outbound_token: Option, + pub last_sent_ack_base: StreamSeq, +} + +impl Default for StreamControl { + fn default() -> Self { + Self { + pending: VecDeque::new(), + in_flight: SeqRing::new(StreamSeq::START), + next_tx_seq: StreamSeq::START, + recv_buffer: SeqRing::new(StreamSeq::START), + ack_dirty: false, + ack_immediate: false, + ack_delay_token: None, + ack_outbound_token: None, + last_sent_ack_base: StreamSeq(0), + } + } +} + +impl StreamControl { + pub fn take_tx_seq(&mut self) -> StreamSeq { + let tx_seq = self.next_tx_seq; + self.next_tx_seq = self.next_tx_seq.next(); + tx_seq + } + + pub fn send_window_has_space(&self) -> bool { + self.in_flight.accepts_seq(self.next_tx_seq) + } + + pub fn committed_rx_seq(&self) -> StreamSeq { + self.recv_buffer.base_seq().prev() + } + + pub fn queue_frame_back(&mut self, frame: StreamFrame) { + self.pending.push_back(frame); + } + + pub fn queue_frame_front(&mut self, frame: StreamFrame) { + self.pending.push_front(frame); + } + + pub fn note_ack(&mut self, immediate: bool) { + self.ack_dirty = true; + self.ack_immediate |= immediate; + } + + pub fn clear_ack_schedule(&mut self) { + self.ack_dirty = false; + self.ack_immediate = false; + self.ack_delay_token = None; + } + + pub fn maybe_force_ack_for_progress(&mut self) { + if !self.ack_dirty { + return; + } + let committed = self.committed_rx_seq(); + let progressed = self + .last_sent_ack_base + .forward_distance_to(committed) + .unwrap_or(0); + if progressed >= STREAM_ACK_EAGER_THRESHOLD { + self.ack_immediate = true; + } + } + + pub fn note_ack_sent(&mut self, ack: StreamAck) { + if ack.base.serial_gt(self.last_sent_ack_base) { + self.last_sent_ack_base = ack.base; + } + } + + pub fn current_ack(&self) -> StreamAck { + StreamAck { + base: self.committed_rx_seq(), + bitmap: self.recv_buffer.bitmap(), + } + } + + pub fn buffer_incoming( + &mut self, + tx_seq: StreamSeq, + frame: StreamFrame, + ) -> BufferIncomingResult { + if tx_seq.serial_lt(self.recv_buffer.base_seq()) { + return BufferIncomingResult::Duplicate; + } + if !self.recv_buffer.accepts_seq(tx_seq) { + return BufferIncomingResult::OutOfWindow; + } + if self.recv_buffer.contains_key(&tx_seq) { + return BufferIncomingResult::AlreadyBuffered; + } + + let out_of_order = tx_seq != self.recv_buffer.base_seq(); + let _ = self.recv_buffer.insert(tx_seq, frame); + BufferIncomingResult::Buffered { out_of_order } + } + + pub fn pop_next_committable(&mut self) -> Option<(StreamSeq, StreamFrame)> { + self.recv_buffer.take_front() + } + + pub fn insert_in_flight(&mut self, frame: InFlightFrame) { + let _ = self.in_flight.set(frame.tx_seq, frame); + } + + pub fn remove_in_flight(&mut self, tx_seq: StreamSeq) -> Option { + let removed = self.in_flight.remove(&tx_seq); + self.in_flight.advance_empty_front_until(self.next_tx_seq); + removed + } + + pub fn clear_transient_buffers(&mut self) { + self.pending.clear(); + self.in_flight.clear_with_base(self.next_tx_seq); + self.recv_buffer + .clear_with_base(self.committed_rx_seq().next()); + self.clear_ack_schedule(); + } + + pub fn ack_covers(ack: StreamAck, tx_seq: StreamSeq) -> bool { + if tx_seq.serial_lte(ack.base) { + return true; + } + let Some(delta) = ack.base.forward_distance_to(tx_seq) else { + return false; + }; + if !(1..=STREAM_WINDOW_SIZE).contains(&delta) { + return false; + } + (ack.bitmap & (1u8 << (delta - 1))) != 0 + } +} + #[derive(Debug)] pub struct InitiatorStream { pub meta: StreamMeta, @@ -105,16 +287,9 @@ pub struct InitiatorStream { #[derive(Debug)] pub enum ResponderResponse { - Pending { - initial_credit: u64, - }, - Accepted { - initial_credit: u64, - body: OutboundState, - }, - Rejecting { - initial_credit: u64, - }, + Pending, + Accepted { body: OutboundState }, + Rejecting, } #[derive(Debug)] @@ -125,17 +300,26 @@ pub struct ResponderStream { pub response: ResponderResponse, } +#[derive(Debug)] +pub struct ProvisionalStream { + pub meta: StreamMeta, + pub control: StreamControl, + pub timeout_token: Token, +} + #[derive(Debug)] pub enum StreamState { Initiator(InitiatorStream), Responder(ResponderStream), + Provisional(ProvisionalStream), } impl StreamState { - pub fn key(&self) -> StreamKey { + pub fn stream_id(&self) -> StreamId { match self { - Self::Initiator(state) => state.meta.key, - Self::Responder(state) => state.meta.key, + Self::Initiator(state) => state.meta.stream_id, + Self::Responder(state) => state.meta.stream_id, + Self::Provisional(state) => state.meta.stream_id, } } @@ -143,6 +327,7 @@ impl StreamState { match self { Self::Initiator(state) => &mut state.meta.last_activity, Self::Responder(state) => &mut state.meta.last_activity, + Self::Provisional(state) => &mut state.meta.last_activity, } } @@ -150,6 +335,7 @@ impl StreamState { match self { Self::Initiator(state) => &state.control, Self::Responder(state) => &state.control, + Self::Provisional(state) => &state.control, } } @@ -157,6 +343,7 @@ impl StreamState { match self { Self::Initiator(state) => &mut state.control, Self::Responder(state) => &mut state.control, + Self::Provisional(state) => &mut state.control, } } @@ -164,7 +351,7 @@ impl StreamState { match self { Self::Initiator(state) if dir == Direction::Request => Some(&mut state.request), Self::Responder(state) if dir == Direction::Response => match &mut state.response { - ResponderResponse::Accepted { body, .. } => Some(body), + ResponderResponse::Accepted { body } => Some(body), _ => None, }, _ => None, @@ -191,112 +378,58 @@ impl StreamState { } } + pub fn provisional_timeout_token(&self) -> Option { + match self { + Self::Provisional(state) => Some(state.timeout_token), + _ => None, + } + } + + pub fn is_provisional(&self) -> bool { + matches!(self, Self::Provisional(_)) + } + pub fn can_reap(&self) -> bool { - if self.control().awaiting.is_some() || !self.control().pending.is_empty() { + if !self.control().pending.is_empty() + || !self.control().in_flight.is_empty() + || !self.control().recv_buffer.is_empty() + || self.control().ack_dirty + || self.control().ack_outbound_token.is_some() + { return false; } match self { Self::Initiator(state) => { matches!(state.accept, InitiatorAccept::Open { .. }) - && state.request.closed + && state.request.is_closed() && state.response.closed } Self::Responder(state) => match &state.response { - ResponderResponse::Accepted { body, .. } => state.request.closed && body.closed, - ResponderResponse::Rejecting { .. } => true, - ResponderResponse::Pending { .. } => false, + ResponderResponse::Accepted { body } => state.request.closed && body.is_closed(), + ResponderResponse::Rejecting => true, + ResponderResponse::Pending => false, }, + Self::Provisional(_) => false, } } } -#[derive(Debug)] -pub struct AwaitingPacket { - pub packet_id: PacketId, - pub frame: AwaitingFrame, - pub attempt: u8, -} - -#[derive(Debug, Clone)] -pub enum AwaitingFrame { - Control(StreamFrame), - Data { - dir: Direction, - offset: u64, - len: usize, - }, -} - -#[derive(Debug)] -pub enum SetupFrame { - Open(StreamFrameOpen), - Accept(StreamFrameAccept), - Reject(StreamFrameReject), -} - -#[derive(Debug, Default)] -pub struct PendingFrames { - pub setup: Option, - pub credit: Option, - pub reset: Option, -} - -impl PendingFrames { - pub fn take_next_control(&mut self, stream_id: StreamId) -> Option { - if let Some(setup) = self.setup.take() { - return Some(match setup { - SetupFrame::Open(frame) => StreamFrame::Open(frame), - SetupFrame::Accept(frame) => StreamFrame::Accept(frame), - SetupFrame::Reject(frame) => StreamFrame::Reject(frame), - }); - } - if let Some(reset) = self.reset.take() { - return Some(StreamFrame::Reset(StreamFrameReset { stream_id, ..reset })); - } - self.credit.take().map(StreamFrame::Credit) - } - - pub fn set_setup(&mut self, setup: SetupFrame) { - self.setup = Some(setup); - } - - pub fn set_credit(&mut self, frame: StreamFrameCredit) { - if self.reset.is_none() { - self.credit = Some(frame); - } - } - - pub fn set_reset(&mut self, dir: ResetTarget, code: ResetCode) { - self.credit = None; - self.reset = Some(StreamFrameReset { - stream_id: StreamId(0), - dir, - code, - }); - } - - pub fn is_empty(&self) -> bool { - self.setup.is_none() && self.credit.is_none() && self.reset.is_none() - } -} - -#[derive(Debug, Default)] -pub struct StreamControl { - pub pending: PendingFrames, - pub awaiting: Option, -} - #[derive(Debug)] pub enum QueuedPayload { PreEncoded(Vec), - StreamBody(StreamBody), + Stream { body: StreamBody }, } #[derive(Debug)] pub struct QueuedWrite { pub token: Token, - pub stream_id: Option, - pub packet_id: Option, - pub track_ack: bool, pub payload: QueuedPayload, } + +pub fn reset_frame(stream_id: StreamId, target: ResetTarget, code: ResetCode) -> StreamFrame { + StreamFrame::Reset(StreamFrameReset { + stream_id, + target, + code, + }) +} diff --git a/ql2/src/engine/tests.rs b/ql2/src/engine/tests.rs new file mode 100644 index 00000000..494b1ae1 --- /dev/null +++ b/ql2/src/engine/tests.rs @@ -0,0 +1,1360 @@ +use std::{cell::Cell, mem, time::Instant}; + +use bc_components::{SymmetricKey, MLDSA, MLKEM}; + +use super::*; +use crate::{ + platform::{QlCrypto, QlIdentity}, + wire::{ + self, + stream::{ + BodyChunk, StreamAck, StreamAckBody, StreamBody, StreamFrame, StreamFrameAccept, + StreamFrameData, StreamFrameOpen, StreamMessage, + }, + QlHeader, QlPayload, + }, + PacketId, Peer, +}; + +struct TestCrypto { + identity: QlIdentity, + nonce_seed: u8, + nonce_counter: Cell, +} + +impl TestCrypto { + fn new(seed: u8) -> Self { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + Self { + identity: QlIdentity::from_keys( + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + ), + nonce_seed: seed, + nonce_counter: Cell::new(0), + } + } + + fn xid(&self) -> XID { + self.identity.xid + } + + fn peer(&self) -> Peer { + Peer { + peer: self.xid(), + signing_key: self.identity.signing_public_key.clone(), + encapsulation_key: self.identity.encapsulation_public_key.clone(), + } + } +} + +impl QlCrypto for TestCrypto { + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self.nonce_seed.wrapping_add(self.nonce_counter.get()); + self.nonce_counter + .set(self.nonce_counter.get().wrapping_add(1)); + data.fill(value); + } +} + +#[derive(Clone, Copy)] +enum Side { + A, + B, +} + +impl Side { + fn other(self) -> Self { + match self { + Side::A => Side::B, + Side::B => Side::A, + } + } +} + +struct Harness { + now: Instant, + a: Engine, + b: Engine, + crypto_a: TestCrypto, + crypto_b: TestCrypto, + outputs_a: Vec, + outputs_b: Vec, +} + +fn run_engine( + engine: &mut Engine, + now: Instant, + input: EngineInput, + crypto: &TestCrypto, +) -> Vec { + let mut outputs = Vec::new(); + engine.run_tick(now, input, crypto, &mut |output| outputs.push(output)); + outputs +} + +fn take_single_write(outputs: &[EngineOutput]) -> (Token, Option, Vec) { + let writes: Vec<_> = outputs + .iter() + .filter_map(|output| match output { + EngineOutput::WriteMessage { + token, + tracked, + bytes, + } => Some((*token, *tracked, bytes.clone())), + _ => None, + }) + .collect(); + assert_eq!(writes.len(), 1); + writes.into_iter().next().unwrap() +} + +fn decode_stream_body(bytes: &[u8], session_key: &SymmetricKey) -> (QlHeader, StreamBody) { + let record = wire::decode_record(bytes).unwrap(); + let aad = record.header.aad(); + let QlPayload::Stream(encrypted) = record.payload else { + panic!("expected stream payload"); + }; + let plaintext = encrypted.decrypt(session_key, &aad).unwrap(); + let body = wire::access_value::(&plaintext) + .and_then(wire::deserialize_value) + .unwrap(); + (record.header, body) +} + +fn connected_engine(local: &TestCrypto, peer: Peer, session_key: SymmetricKey) -> Engine { + let mut engine = Engine::new(EngineConfig::default(), local.identity.clone(), Some(peer)); + engine.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + }; + engine +} + +impl Harness { + fn new(config: EngineConfig) -> Self { + let crypto_a = TestCrypto::new(1); + let crypto_b = TestCrypto::new(2); + let peer_a = crypto_a.peer(); + let peer_b = crypto_b.peer(); + let session_key = SymmetricKey::from_data([7; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); + let mut b = Engine::new(config, crypto_b.identity.clone(), Some(peer_a)); + a.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key: session_key.clone(), + keepalive: KeepAliveState::default(), + }; + b.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + }; + Self { + now: Instant::now(), + a, + b, + crypto_a, + crypto_b, + outputs_a: Vec::new(), + outputs_b: Vec::new(), + } + } + + fn send_a(&mut self, input: EngineInput) { + self.run_side(Side::A, input); + } + + fn send_b(&mut self, input: EngineInput) { + self.run_side(Side::B, input); + } + + fn drain_a(&mut self) -> Vec { + mem::take(&mut self.outputs_a) + } + + fn drain_b(&mut self) -> Vec { + mem::take(&mut self.outputs_b) + } + + fn run_side(&mut self, side: Side, input: EngineInput) { + let mut outputs = Vec::new(); + match side { + Side::A => self + .a + .run_tick(self.now, input, &self.crypto_a, &mut |output| { + outputs.push(output) + }), + Side::B => self + .b + .run_tick(self.now, input, &self.crypto_b, &mut |output| { + outputs.push(output) + }), + } + + let writes: Vec<(Token, Option, Vec)> = outputs + .iter() + .filter_map(|output| match output { + EngineOutput::WriteMessage { + token, + tracked, + bytes, + } => Some((*token, *tracked, bytes.clone())), + _ => None, + }) + .collect(); + + match side { + Side::A => self.outputs_a.extend(outputs), + Side::B => self.outputs_b.extend(outputs), + } + + for (token, tracked, bytes) in writes { + self.run_side( + side, + EngineInput::WriteCompleted { + token, + tracked, + result: Ok(()), + }, + ); + self.run_side(side.other(), EngineInput::Incoming(bytes)); + } + } +} + +#[test] +fn open_prefix_is_delivered_on_setup_output() { + let mut harness = Harness::new(EngineConfig::default()); + let request_prefix = BodyChunk { + bytes: b"req".to_vec(), + fin: true, + }; + + harness.send_a(EngineInput::OpenStream { + open_id: OpenId(1), + request_head: b"open-head".to_vec(), + request_prefix: Some(request_prefix.clone()), + config: StreamConfig::default(), + }); + + harness.now += EngineConfig::default().stream_ack_delay; + harness.send_b(EngineInput::TimerExpired); + + let outputs_a = harness.drain_a(); + let outputs_b = harness.drain_b(); + let stream_id = outputs_a + .iter() + .find_map(|output| match output { + EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), + _ => None, + }) + .unwrap(); + + assert!(outputs_a.iter().any(|output| matches!( + output, + EngineOutput::OpenStarted { + open_id: OpenId(1), + stream_id: id, + } if *id == stream_id + ))); + assert!( + StreamNamespace::for_local(harness.crypto_a.xid(), harness.crypto_b.xid()) + .matches(stream_id) + ); + assert!(outputs_a.iter().any(|output| matches!( + output, + EngineOutput::OutboundClosed { + stream_id: id, + dir: Direction::Request, + } if *id == stream_id + ))); + + let opened = outputs_b.iter().find_map(|output| match output { + EngineOutput::InboundStreamOpened { + stream_id, + request_head, + request_prefix, + } => Some((*stream_id, request_head.clone(), request_prefix.clone())), + _ => None, + }); + assert_eq!( + opened, + Some(( + stream_id, + b"open-head".to_vec(), + Some(request_prefix.clone()), + )) + ); + assert!(!outputs_b + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + assert!(!outputs_b + .iter() + .any(|output| matches!(output, EngineOutput::InboundFinished { .. }))); +} + +#[test] +fn unary_exchange_uses_open_and_accept_prefixes() { + let mut harness = Harness::new(EngineConfig::default()); + let request_prefix = BodyChunk { + bytes: b"req".to_vec(), + fin: true, + }; + let response_prefix = BodyChunk { + bytes: b"resp".to_vec(), + fin: true, + }; + + harness.send_a(EngineInput::OpenStream { + open_id: OpenId(7), + request_head: b"request-head".to_vec(), + request_prefix: Some(request_prefix.clone()), + config: StreamConfig::default(), + }); + + let outputs_a_open = harness.drain_a(); + let outputs_b = harness.drain_b(); + let started_stream_id = outputs_a_open + .iter() + .find_map(|output| match output { + EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), + _ => None, + }) + .unwrap(); + let stream_id = outputs_b + .iter() + .find_map(|output| match output { + EngineOutput::InboundStreamOpened { stream_id, .. } => Some(*stream_id), + _ => None, + }) + .unwrap(); + assert_eq!(stream_id, started_stream_id); + + harness.send_b(EngineInput::AcceptStream { + stream_id, + response_head: b"response-head".to_vec(), + response_prefix: Some(response_prefix.clone()), + }); + + harness.now += EngineConfig::default().stream_ack_delay; + harness.send_a(EngineInput::TimerExpired); + + let outputs_a = harness.drain_a(); + let outputs_b = harness.drain_b(); + + let accepted = outputs_a.iter().find_map(|output| match output { + EngineOutput::OpenAccepted { + open_id, + stream_id, + response_head, + response_prefix, + } => Some(( + *open_id, + *stream_id, + response_head.clone(), + response_prefix.clone(), + )), + _ => None, + }); + assert_eq!( + accepted, + Some(( + OpenId(7), + stream_id, + b"response-head".to_vec(), + Some(response_prefix.clone()), + )) + ); + assert!(!outputs_a + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + assert!(!outputs_a + .iter() + .any(|output| matches!(output, EngineOutput::InboundFinished { .. }))); + assert!(outputs_b.iter().any(|output| matches!( + output, + EngineOutput::OutboundClosed { + stream_id: id, + dir: Direction::Response, + } if *id == stream_id + ))); +} + +#[test] +fn simultaneous_opens_use_disjoint_stream_id_namespaces() { + let config = EngineConfig::default(); + let crypto_a = TestCrypto::new(11); + let crypto_b = TestCrypto::new(22); + let peer_a = crypto_a.peer(); + let peer_b = crypto_b.peer(); + let session_key = SymmetricKey::from_data([9; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); + let mut b = Engine::new(config, crypto_b.identity.clone(), Some(peer_a)); + a.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key: session_key.clone(), + keepalive: KeepAliveState::default(), + }; + b.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + }; + let now = Instant::now(); + + let outputs_a_open = run_engine( + &mut a, + now, + EngineInput::OpenStream { + open_id: OpenId(1), + request_head: b"a-open".to_vec(), + request_prefix: None, + config: StreamConfig::default(), + }, + &crypto_a, + ); + let outputs_b_open = run_engine( + &mut b, + now, + EngineInput::OpenStream { + open_id: OpenId(2), + request_head: b"b-open".to_vec(), + request_prefix: None, + config: StreamConfig::default(), + }, + &crypto_b, + ); + + let stream_id_a = outputs_a_open + .iter() + .find_map(|output| match output { + EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), + _ => None, + }) + .unwrap(); + let stream_id_b = outputs_b_open + .iter() + .find_map(|output| match output { + EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), + _ => None, + }) + .unwrap(); + + assert_ne!(stream_id_a, stream_id_b); + assert!(StreamNamespace::for_local(crypto_a.xid(), crypto_b.xid()).matches(stream_id_a)); + assert!(StreamNamespace::for_local(crypto_b.xid(), crypto_a.xid()).matches(stream_id_b)); + + let (token_a, tracked_a, bytes_a) = take_single_write(&outputs_a_open); + let (token_b, tracked_b, bytes_b) = take_single_write(&outputs_b_open); + + let _ = run_engine( + &mut a, + now, + EngineInput::WriteCompleted { + token: token_a, + tracked: tracked_a, + result: Ok(()), + }, + &crypto_a, + ); + let _ = run_engine( + &mut b, + now, + EngineInput::WriteCompleted { + token: token_b, + tracked: tracked_b, + result: Ok(()), + }, + &crypto_b, + ); + + let outputs_a_incoming = run_engine(&mut a, now, EngineInput::Incoming(bytes_b), &crypto_a); + let outputs_b_incoming = run_engine(&mut b, now, EngineInput::Incoming(bytes_a), &crypto_b); + + assert!(outputs_a_incoming.iter().any(|output| matches!( + output, + EngineOutput::InboundStreamOpened { + stream_id, + request_head, + .. + } if *stream_id == stream_id_b && request_head == b"b-open" + ))); + assert!(outputs_b_incoming.iter().any(|output| matches!( + output, + EngineOutput::InboundStreamOpened { + stream_id, + request_head, + .. + } if *stream_id == stream_id_a && request_head == b"a-open" + ))); + assert_eq!(a.streams.len(), 2); + assert_eq!(b.streams.len(), 2); +} + +#[test] +fn invalid_future_frame_does_not_ack_outstanding_open() { + let config = EngineConfig::default(); + let crypto_a = TestCrypto::new(31); + let crypto_b = TestCrypto::new(32); + let peer_a = crypto_a.peer(); + let peer_b = crypto_b.peer(); + let session_key = SymmetricKey::from_data([5; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); + let mut _b = Engine::new(config, crypto_b.identity.clone(), Some(peer_a)); + a.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key: session_key.clone(), + keepalive: KeepAliveState::default(), + }; + + let now = Instant::now(); + let outputs_open = run_engine( + &mut a, + now, + EngineInput::OpenStream { + open_id: OpenId(9), + request_head: b"open".to_vec(), + request_prefix: None, + config: StreamConfig::default(), + }, + &crypto_a, + ); + let stream_id = outputs_open + .iter() + .find_map(|output| match output { + EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), + _ => None, + }) + .unwrap(); + + let message = StreamMessage { + tx_seq: StreamSeq(2), + ack: Some(crate::wire::stream::StreamAck { + base: StreamSeq(0), + bitmap: 0, + }), + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Accept(StreamFrameAccept { + stream_id, + response_head: Vec::new(), + response_prefix: None, + }), + }; + + let body = StreamBody::Message(message); + let record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &body, + [9; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_incoming = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&record)), + &crypto_a, + ); + + assert!(!outputs_incoming + .iter() + .any(|output| matches!(output, EngineOutput::OpenAccepted { .. }))); + + let stream = a.streams.get(&stream_id).unwrap(); + assert!(stream.control().in_flight.contains_key(&StreamSeq::START)); + match stream { + StreamState::Initiator(state) => { + assert!(matches!(state.accept, InitiatorAccept::Opening(_))); + } + _ => panic!("expected initiator stream"), + } +} + +#[test] +fn out_of_order_remote_stream_buffers_until_open_arrives() { + let config = EngineConfig::default(); + let crypto_a = TestCrypto::new(41); + let crypto_b = TestCrypto::new(42); + let peer_b = crypto_b.peer(); + let session_key = SymmetricKey::from_data([6; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); + a.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key: session_key.clone(), + keepalive: KeepAliveState::default(), + }; + + let now = Instant::now(); + let stream_id = StreamId(StreamNamespace::for_local(crypto_b.xid(), crypto_a.xid()).bit() | 1); + + let data_message = StreamMessage { + tx_seq: StreamSeq(2), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: b"hello".to_vec(), + fin: false, + }, + }), + }; + let data_body = StreamBody::Message(data_message); + let data_record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &data_body, + [11; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_data = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&data_record)), + &crypto_a, + ); + + assert!(!outputs_data + .iter() + .any(|output| matches!(output, EngineOutput::InboundStreamOpened { .. }))); + assert!(!outputs_data + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + assert!(outputs_data + .iter() + .any(|output| matches!(output, EngineOutput::WriteMessage { .. }))); + assert!(matches!( + a.streams.get(&stream_id), + Some(StreamState::Provisional(_)) + )); + + let open_message = StreamMessage { + tx_seq: StreamSeq(1), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(crate::wire::stream::StreamFrameOpen { + stream_id, + request_head: b"late-open".to_vec(), + request_prefix: None, + }), + }; + let open_body = StreamBody::Message(open_message); + let open_record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &open_body, + [12; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_open = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&open_record)), + &crypto_a, + ); + + assert!(outputs_open.iter().any(|output| matches!( + output, + EngineOutput::InboundStreamOpened { + stream_id: id, + request_head, + request_prefix: None, + } if *id == stream_id && request_head == b"late-open" + ))); + assert!(outputs_open.iter().any(|output| matches!( + output, + EngineOutput::InboundData { + stream_id: id, + dir: Direction::Request, + bytes, + } if *id == stream_id && bytes == b"hello" + ))); +} + +#[test] +fn out_of_order_response_data_waits_for_accept() { + let config = EngineConfig::default(); + let crypto_a = TestCrypto::new(51); + let crypto_b = TestCrypto::new(52); + let peer_b = crypto_b.peer(); + let session_key = SymmetricKey::from_data([4; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); + a.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key: session_key.clone(), + keepalive: KeepAliveState::default(), + }; + + let now = Instant::now(); + let outputs_open = run_engine( + &mut a, + now, + EngineInput::OpenStream { + open_id: OpenId(12), + request_head: b"req".to_vec(), + request_prefix: None, + config: StreamConfig::default(), + }, + &crypto_a, + ); + let stream_id = outputs_open + .iter() + .find_map(|output| match output { + EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), + _ => None, + }) + .unwrap(); + + let data_message = StreamMessage { + tx_seq: StreamSeq(2), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + dir: Direction::Response, + chunk: BodyChunk { + bytes: b"resp".to_vec(), + fin: false, + }, + }), + }; + let data_body = StreamBody::Message(data_message); + let data_record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &data_body, + [21; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_data = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&data_record)), + &crypto_a, + ); + assert!(!outputs_data + .iter() + .any(|output| matches!(output, EngineOutput::OpenAccepted { .. }))); + assert!(!outputs_data + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + + let accept_message = StreamMessage { + tx_seq: StreamSeq(1), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Accept(StreamFrameAccept { + stream_id, + response_head: b"resp-head".to_vec(), + response_prefix: None, + }), + }; + let accept_body = StreamBody::Message(accept_message); + let accept_record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &accept_body, + [22; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_accept = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&accept_record)), + &crypto_a, + ); + + assert!(outputs_accept.iter().any(|output| matches!( + output, + EngineOutput::OpenAccepted { + open_id: OpenId(12), + stream_id: id, + response_head, + response_prefix: None, + } if *id == stream_id && response_head == b"resp-head" + ))); + assert!(outputs_accept.iter().any(|output| matches!( + output, + EngineOutput::InboundData { + stream_id: id, + dir: Direction::Response, + bytes, + } if *id == stream_id && bytes == b"resp" + ))); +} + +#[test] +fn delayed_ack_only_does_not_consume_sequence_space() { + let mut harness = Harness::new(EngineConfig::default()); + + harness.send_a(EngineInput::OpenStream { + open_id: OpenId(21), + request_head: b"open-head".to_vec(), + request_prefix: None, + config: StreamConfig::default(), + }); + + let outputs_a = harness.drain_a(); + let _outputs_b = harness.drain_b(); + let stream_id = outputs_a + .iter() + .find_map(|output| match output { + EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), + _ => None, + }) + .unwrap(); + + harness.now += EngineConfig::default().stream_ack_delay; + harness.send_b(EngineInput::TimerExpired); + + let outputs_b = harness.drain_b(); + assert!(outputs_b + .iter() + .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); + + let stream = harness.b.streams.get(&stream_id).unwrap(); + assert!(stream.control().in_flight.is_empty()); + assert_eq!(stream.control().next_tx_seq, StreamSeq::START); +} + +#[test] +fn half_window_progress_flushes_ack_before_timer() { + let config = EngineConfig::default(); + let crypto_a = TestCrypto::new(61); + let crypto_b = TestCrypto::new(62); + let peer_b = crypto_b.peer(); + let session_key = SymmetricKey::from_data([8; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); + a.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key: session_key.clone(), + keepalive: KeepAliveState::default(), + }; + + let now = Instant::now(); + let stream_id = StreamId(StreamNamespace::for_local(crypto_b.xid(), crypto_a.xid()).bit() | 1); + let messages = [ + StreamMessage { + tx_seq: StreamSeq(1), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(crate::wire::stream::StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + }, + StreamMessage { + tx_seq: StreamSeq(2), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: b"a".to_vec(), + fin: false, + }, + }), + }, + StreamMessage { + tx_seq: StreamSeq(3), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: b"b".to_vec(), + fin: false, + }, + }), + }, + StreamMessage { + tx_seq: StreamSeq(4), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: b"c".to_vec(), + fin: false, + }, + }), + }, + ]; + + for message in messages.iter().take(3) { + let body = StreamBody::Message(message.clone()); + let record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &body, + [message.tx_seq.0 as u8; wire::encrypted_message::NONCE_SIZE], + ); + let outputs = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&record)), + &crypto_a, + ); + assert!(!outputs + .iter() + .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); + } + + let body = StreamBody::Message(messages[3].clone()); + let record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &body, + [4; wire::encrypted_message::NONCE_SIZE], + ); + let outputs = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&record)), + &crypto_a, + ); + + assert!(outputs + .iter() + .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); +} + +#[test] +fn out_of_order_loss_reports_selective_ack_bitmap() { + let crypto_a = TestCrypto::new(71); + let crypto_b = TestCrypto::new(72); + let session_key = SymmetricKey::from_data([3; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let peer_b = crypto_b.peer(); + let mut a = connected_engine(&crypto_a, peer_b, session_key.clone()); + + let now = Instant::now(); + let stream_id = StreamId(StreamNamespace::for_local(crypto_b.xid(), crypto_a.xid()).bit() | 1); + let messages = [ + StreamMessage { + tx_seq: StreamSeq(1), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + }, + StreamMessage { + tx_seq: StreamSeq(2), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: b"a".to_vec(), + fin: false, + }, + }), + }, + StreamMessage { + tx_seq: StreamSeq(4), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: b"c".to_vec(), + fin: false, + }, + }), + }, + StreamMessage { + tx_seq: StreamSeq(5), + ack: None, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: b"d".to_vec(), + fin: false, + }, + }), + }, + ]; + + for message in &messages[..2] { + let record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &StreamBody::Message(message.clone()), + [message.tx_seq.0 as u8; wire::encrypted_message::NONCE_SIZE], + ); + let outputs = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&record)), + &crypto_a, + ); + assert!(!outputs + .iter() + .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); + } + + let record4 = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &StreamBody::Message(messages[2].clone()), + [4; wire::encrypted_message::NONCE_SIZE], + ); + let outputs4 = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&record4)), + &crypto_a, + ); + let (ack_token4, ack_tracked4, ack_bytes4) = take_single_write(&outputs4); + assert_eq!(ack_tracked4, None); + let (_, ack_body4) = decode_stream_body(&ack_bytes4, &session_key); + assert!(matches!( + ack_body4, + StreamBody::Ack(StreamAckBody { + stream_id: id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0010, + }, + .. + }) if id == stream_id + )); + assert!(!outputs4 + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + // the engine only starts a new outbound write after the previous one reports + // `WriteCompleted`. We need to retire the ACK-only write for seq 4 here so the + // follow-up out-of-order receive for seq 5 can emit its own updated ACK body. + let _ = run_engine( + &mut a, + now, + EngineInput::WriteCompleted { + token: ack_token4, + tracked: ack_tracked4, + result: Ok(()), + }, + &crypto_a, + ); + + let record5 = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &StreamBody::Message(messages[3].clone()), + [5; wire::encrypted_message::NONCE_SIZE], + ); + let outputs5 = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&record5)), + &crypto_a, + ); + let (_, _, ack_bytes5) = take_single_write(&outputs5); + let (_, ack_body5) = decode_stream_body(&ack_bytes5, &session_key); + assert!(matches!( + ack_body5, + StreamBody::Ack(StreamAckBody { + stream_id: id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0110, + }, + .. + }) if id == stream_id + )); + assert!(!outputs5 + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); +} + +#[test] +fn selective_ack_only_body_retires_acked_gap_tail() { + let crypto_a = TestCrypto::new(81); + let crypto_b = TestCrypto::new(82); + let session_key = SymmetricKey::from_data([2; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let peer_b = crypto_b.peer(); + let mut a = connected_engine(&crypto_a, peer_b, session_key.clone()); + + let now = Instant::now(); + let stream_id = a + .state + .next_stream_id(StreamNamespace::for_local(crypto_a.xid(), crypto_b.xid())); + let mut stream = StreamState::Initiator(InitiatorStream { + meta: StreamMeta { + stream_id, + last_activity: now, + }, + control: StreamControl::default(), + request: OutboundState::from_prefix(Direction::Request, false), + response: InboundState::new(), + accept: InitiatorAccept::Opening(OpenWaiter { + open_id: Some(OpenId(1)), + open_timeout_token: Token(999), + }), + }); + let control = stream.control_mut(); + control.next_tx_seq = StreamSeq(6); + control.insert_in_flight(InFlightFrame { + tx_seq: StreamSeq(1), + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + attempt: 0, + }); + for (tx_seq, byte) in [(2, b'a'), (3, b'b'), (4, b'c'), (5, b'd')] { + control.insert_in_flight(InFlightFrame { + tx_seq: StreamSeq(tx_seq), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + dir: Direction::Request, + chunk: BodyChunk { + bytes: vec![byte], + fin: false, + }, + }), + attempt: 0, + }); + } + a.streams.insert(stream_id, stream); + + let ack_record = wire::stream::encrypt_stream( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + &StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0110, + }, + valid_until: wire::now_secs().saturating_add(60), + }), + [9; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs = run_engine( + &mut a, + now, + EngineInput::Incoming(wire::encode_record(&ack_record)), + &crypto_a, + ); + + assert!(!outputs + .iter() + .any(|output| matches!(output, EngineOutput::OutboundFailed { .. }))); + let stream = a.streams.get(&stream_id).unwrap(); + let remaining: Vec<_> = stream + .control() + .in_flight + .iter() + .map(|(seq, _)| seq) + .collect(); + assert_eq!(remaining, vec![StreamSeq(3)]); + assert_eq!(stream.control().next_tx_seq, StreamSeq(6)); +} + +#[test] +fn timeout_retransmit_reuses_original_tx_seq_and_slot() { + let config = EngineConfig::default(); + let crypto_a = TestCrypto::new(91); + let crypto_b = TestCrypto::new(92); + let peer_b = crypto_b.peer(); + let session_key = SymmetricKey::from_data([1; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = connected_engine(&crypto_a, peer_b, session_key.clone()); + + let now = Instant::now(); + let outputs_open = run_engine( + &mut a, + now, + EngineInput::OpenStream { + open_id: OpenId(44), + request_head: b"open".to_vec(), + request_prefix: None, + config: StreamConfig::default(), + }, + &crypto_a, + ); + let (token, tracked, bytes) = take_single_write(&outputs_open); + let tracked = tracked.unwrap(); + let (_, initial_body) = decode_stream_body(&bytes, &session_key); + assert!(matches!( + initial_body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(1), + frame: StreamFrame::Open(_), + .. + }) + )); + + let _outputs_written = run_engine( + &mut a, + now, + EngineInput::WriteCompleted { + token, + tracked: Some(tracked), + result: Ok(()), + }, + &crypto_a, + ); + + let stream = a.streams.get(&tracked.stream_id).unwrap(); + assert_eq!(stream.control().in_flight.len(), 1); + assert!(stream.control().in_flight.contains_key(&StreamSeq::START)); + assert_eq!(stream.control().next_tx_seq, StreamSeq(2)); + + let outputs_timeout = run_engine( + &mut a, + now + config.stream_ack_timeout, + EngineInput::TimerExpired, + &crypto_a, + ); + let (_, retransmit_tracked, retransmit_bytes) = take_single_write(&outputs_timeout); + assert_eq!(retransmit_tracked, Some(tracked)); + let (_, retransmit_body) = decode_stream_body(&retransmit_bytes, &session_key); + assert!(matches!( + retransmit_body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(1), + frame: StreamFrame::Open(_), + .. + }) + )); + + let stream = a.streams.get(&tracked.stream_id).unwrap(); + assert_eq!(stream.control().in_flight.len(), 1); + assert!(stream.control().in_flight.contains_key(&StreamSeq::START)); + assert_eq!(stream.control().next_tx_seq, StreamSeq(2)); + assert_eq!( + stream + .control() + .in_flight + .get(&StreamSeq::START) + .unwrap() + .attempt, + 1 + ); +} + +#[test] +fn replayed_heartbeat_is_ignored() { + let crypto_a = TestCrypto::new(101); + let crypto_b = TestCrypto::new(102); + let session_key = SymmetricKey::from_data([4; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let peer_b = crypto_b.peer(); + let mut a = connected_engine(&crypto_a, peer_b, session_key.clone()); + let now = Instant::now(); + let heartbeat = wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + &session_key, + wire::heartbeat::HeartbeatBody { + meta: wire::ControlMeta { + packet_id: PacketId(7), + valid_until: wire::now_secs().saturating_add(60), + }, + }, + [3; wire::encrypted_message::NONCE_SIZE], + ); + let bytes = wire::encode_record(&heartbeat); + + let first = run_engine(&mut a, now, EngineInput::Incoming(bytes.clone()), &crypto_a); + assert!(first + .iter() + .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); + + let second = run_engine(&mut a, now, EngineInput::Incoming(bytes), &crypto_a); + assert!(!second + .iter() + .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); +} + +#[test] +fn replayed_unpair_is_ignored_after_rebind() { + let config = EngineConfig::default(); + let crypto_a = TestCrypto::new(111); + let crypto_b = TestCrypto::new(112); + let peer_b = crypto_b.peer(); + let session_key = SymmetricKey::from_data([5; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b.clone())); + a.state.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + }; + let now = Instant::now(); + let bytes = wire::encode_record(&wire::unpair::build_unpair_record( + &crypto_b.identity, + QlHeader { + sender: crypto_b.xid(), + recipient: crypto_a.xid(), + }, + wire::ControlMeta { + packet_id: PacketId(9), + valid_until: wire::now_secs().saturating_add(60), + }, + )); + + let first = run_engine(&mut a, now, EngineInput::Incoming(bytes.clone()), &crypto_a); + assert!(first + .iter() + .any(|output| matches!(output, EngineOutput::ClearPeer))); + assert!(a.state.peer.is_none()); + + let _ = run_engine( + &mut a, + now, + EngineInput::BindPeer(peer_b.clone()), + &crypto_a, + ); + assert!(a.state.peer.is_some()); + + let second = run_engine(&mut a, now, EngineInput::Incoming(bytes), &crypto_a); + assert!(!second + .iter() + .any(|output| matches!(output, EngineOutput::ClearPeer))); + assert_eq!( + a.state.peer.as_ref().map(|peer| peer.peer), + Some(peer_b.peer) + ); +} diff --git a/ql2/src/lib.rs b/ql2/src/lib.rs index e07f84a3..f89b06d4 100644 --- a/ql2/src/lib.rs +++ b/ql2/src/lib.rs @@ -1,14 +1,14 @@ pub mod engine; mod id; pub mod platform; -pub mod rpc; -pub mod runtime; +// pub mod rpc; +// pub mod runtime; pub mod wire; pub use id::*; -#[cfg(test)] -mod tests; +// #[cfg(test)] +// mod tests; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Peer { @@ -43,9 +43,3 @@ pub enum QlError { #[error("cancelled")] Cancelled, } - -impl From for QlError { - fn from(_: crate::runtime::pipe::PipeClosed) -> Self { - Self::Cancelled - } -} diff --git a/ql2/src/platform.rs b/ql2/src/platform.rs index 4472d927..168944d2 100644 --- a/ql2/src/platform.rs +++ b/ql2/src/platform.rs @@ -4,26 +4,40 @@ use bc_components::{ MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, SigningPublicKey, XID, }; -use crate::{ - runtime::{HandlerEvent, PeerSession}, - Peer, QlError, -}; +use crate::{engine::PeerSession, Peer, QlError}; pub type PlatformFuture<'a, T> = Pin + 'a>>; -pub trait QlCrypto { - fn signing_private_key(&self) -> &MLDSAPrivateKey; - fn signing_public_key(&self) -> &MLDSAPublicKey; - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey; - fn encapsulation_public_key(&self) -> &MLKEMPublicKey; - - fn fill_random_bytes(&self, data: &mut [u8]); +#[derive(Debug, Clone)] +pub struct QlIdentity { + pub xid: XID, + pub signing_private_key: MLDSAPrivateKey, + pub signing_public_key: MLDSAPublicKey, + pub encapsulation_private_key: MLKEMPrivateKey, + pub encapsulation_public_key: MLKEMPublicKey, +} - fn xid(&self) -> XID { - XID::new(SigningPublicKey::MLDSA(self.signing_public_key().clone())) +impl QlIdentity { + pub fn from_keys( + signing_private_key: MLDSAPrivateKey, + signing_public_key: MLDSAPublicKey, + encapsulation_private_key: MLKEMPrivateKey, + encapsulation_public_key: MLKEMPublicKey, + ) -> Self { + Self { + xid: XID::new(SigningPublicKey::MLDSA(signing_public_key.clone())), + signing_private_key, + signing_public_key, + encapsulation_private_key, + encapsulation_public_key, + } } } +pub trait QlCrypto { + fn fill_random_bytes(&self, data: &mut [u8]); +} + pub trait QlPlatform: QlCrypto { fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; @@ -33,5 +47,5 @@ pub trait QlPlatform: QlCrypto { fn clear_peer(&self); fn handle_peer_status(&self, peer: XID, session: &PeerSession); - fn handle_inbound(&self, event: HandlerEvent); + // fn handle_inbound(&self, event: crate::runtime::HandlerEvent); } diff --git a/ql2/src/runtime/command.rs b/ql2/src/runtime/command.rs index 952db98b..e9eafd33 100644 --- a/ql2/src/runtime/command.rs +++ b/ql2/src/runtime/command.rs @@ -1,5 +1,5 @@ use crate::{ - runtime::{pipe, AcceptedStreamDelivery, StreamConfig}, + runtime::{AcceptedStreamDelivery, StreamConfig}, wire::stream::{Direction, RejectCode, ResetCode}, Peer, QlError, StreamId, }; @@ -13,7 +13,7 @@ pub(crate) enum RuntimeCommand { Unpair, OpenStream { request_head: Vec, - request_pipe: pipe::PipeReader, + request_rx: async_channel::Receiver>, accepted: oneshot::Sender>, start: oneshot::Sender>, config: StreamConfig, @@ -21,7 +21,7 @@ pub(crate) enum RuntimeCommand { AcceptStream { stream_id: StreamId, response_head: Vec, - response_pipe: pipe::PipeReader, + response_rx: async_channel::Receiver>, }, RejectStream { stream_id: StreamId, @@ -30,11 +30,6 @@ pub(crate) enum RuntimeCommand { PollStream { stream_id: StreamId, }, - AdvanceInboundCredit { - stream_id: StreamId, - dir: Direction, - amount: u64, - }, ResetOutbound { stream_id: StreamId, dir: Direction, diff --git a/ql2/src/runtime/driver.rs b/ql2/src/runtime/driver.rs index 4a9e2945..285b7709 100644 --- a/ql2/src/runtime/driver.rs +++ b/ql2/src/runtime/driver.rs @@ -1,7 +1,6 @@ use std::{ collections::{HashMap, VecDeque}, future::Future, - io::Read, task::Poll, time::Instant, }; @@ -14,9 +13,9 @@ use crate::{ runtime::{ command::RuntimeCommand, handle::{InboundByteStream, InboundStream, StreamResponder}, - pipe, AcceptedStreamDelivery, HandlerEvent, Runtime, + AcceptedStreamDelivery, HandlerEvent, InboundEvent, Runtime, }, - wire::stream::{Direction, ResetCode}, + wire::stream::{BodyChunk, Direction, ResetCode}, QlError, StreamId, }; @@ -38,127 +37,117 @@ enum DriverEvent { } struct PendingOpen { - request_pipe: pipe::PipeReader, + request_rx: async_channel::Receiver>, start_tx: oneshot::Sender>, accepted_tx: oneshot::Sender>, } struct PendingAcceptDelivery { tx: oneshot::Sender>, - response_reader: pipe::PipeReader, -} - -#[derive(Debug, Clone, Copy)] -struct PendingPull { - offset: u64, - max_len: usize, + response_rx: async_channel::Receiver, } enum OutboundIo { Open { dir: Direction, - pipe: pipe::PipeReader, - pending_pull: Option, + rx: async_channel::Receiver>, + pending_pull: bool, finish_queued: bool, }, Closed, } impl OutboundIo { - fn new(dir: Direction, pipe: pipe::PipeReader) -> Self { + fn new(dir: Direction, rx: async_channel::Receiver>) -> Self { Self::Open { dir, - pipe, - pending_pull: None, + rx, + pending_pull: false, finish_queued: false, } } - fn set_pending_pull(&mut self, offset: u64, max_len: usize) { + fn set_pending_pull(&mut self) { if let Self::Open { pending_pull, .. } = self { - *pending_pull = Some(PendingPull { offset, max_len }); - } - } - - fn release_to(&mut self, recv_offset: u64) { - if let Self::Open { pipe, .. } = self { - pipe.release_to(recv_offset); + *pending_pull = true; } } fn close(&mut self) { - if let Self::Open { pipe, .. } = self { - pipe.close(); - } *self = Self::Closed; } fn poll_pending(&mut self, stream_id: StreamId, pending: &mut VecDeque) { let Self::Open { dir, - pipe, + rx, pending_pull, finish_queued, } = self else { return; }; - if let Some(pull) = pending_pull.take() { - if let Some(mut grant) = pipe.reserve_at(pull.offset, pull.max_len) { - let mut bytes = vec![0; grant.len()]; - let _ = grant.read_exact(&mut bytes); + + if !*pending_pull { + if rx.is_closed() && !*finish_queued { + *finish_queued = true; + pending.push_back(EngineInput::OutboundFinished { stream_id, dir: *dir }); + } + return; + } + + match rx.try_recv() { + Ok(bytes) => { + if bytes.is_empty() { + return; + } + *pending_pull = false; pending.push_back(EngineInput::OutboundData { stream_id, dir: *dir, - offset: grant.offset(), bytes, }); - return; + if rx.is_closed() && rx.is_empty() && !*finish_queued { + *finish_queued = true; + pending.push_back(EngineInput::OutboundFinished { stream_id, dir: *dir }); + } } - if pipe.writer_finished() && pipe.all_sent() { + Err(async_channel::TryRecvError::Empty) => { + if rx.is_closed() && !*finish_queued { + *pending_pull = false; + *finish_queued = true; + pending.push_back(EngineInput::OutboundFinished { stream_id, dir: *dir }); + } + } + Err(async_channel::TryRecvError::Closed) => { + *pending_pull = false; if !*finish_queued { *finish_queued = true; - pending.push_back(EngineInput::OutboundFinished { - stream_id, - dir: *dir, - final_offset: pipe.sent_offset(), - }); + pending.push_back(EngineInput::OutboundFinished { stream_id, dir: *dir }); } - return; } - *pending_pull = Some(pull); - return; - } - - if pipe.writer_finished() && pipe.all_sent() && !*finish_queued { - *finish_queued = true; - pending.push_back(EngineInput::OutboundFinished { - stream_id, - dir: *dir, - final_offset: pipe.sent_offset(), - }); } } } enum InboundIo { - Open(pipe::PipeWriter), + Open(async_channel::Sender), Closed, } impl InboundIo { - fn new(pipe: pipe::PipeWriter) -> Self { - Self::Open(pipe) + fn new(tx: async_channel::Sender) -> Self { + Self::Open(tx) } fn write_or_cancel( &mut self, stream_id: StreamId, dir: Direction, - bytes: &[u8], + bytes: Vec, pending: &mut VecDeque, ) { - let Self::Open(pipe) = self else { + let Self::Open(tx) = self else { pending.push_back(EngineInput::ResetInbound { stream_id, dir, @@ -166,40 +155,55 @@ impl InboundIo { }); return; }; - match pipe.try_write(bytes) { - Ok(n) if n == bytes.len() => {} - Ok(_) | Err(_) => { - pipe.close(); - *self = Self::Closed; - pending.push_back(EngineInput::ResetInbound { - stream_id, - dir, - code: ResetCode::Cancelled, - }); - } + if tx.try_send(InboundEvent::Data(bytes)).is_err() { + tx.close(); + *self = Self::Closed; + pending.push_back(EngineInput::ResetInbound { + stream_id, + dir, + code: ResetCode::Cancelled, + }); } } fn finish(&mut self) { - if let Self::Open(pipe) = self { - pipe.finish(); + if let Self::Open(tx) = self { + let _ = tx.try_send(InboundEvent::Finished); + tx.close(); } *self = Self::Closed; } fn fail(&mut self, error: QlError) { - if let Self::Open(pipe) = self { - pipe.fail(error); + if let Self::Open(tx) = self { + let _ = tx.try_send(InboundEvent::Failed(error)); + tx.close(); } *self = Self::Closed; } fn close(&mut self) { - if let Self::Open(pipe) = self { - pipe.close(); + if let Self::Open(tx) = self { + let _ = tx.try_send(InboundEvent::Failed(QlError::Cancelled)); + tx.close(); } *self = Self::Closed; } + + fn apply_prefix( + &mut self, + stream_id: StreamId, + dir: Direction, + prefix: &BodyChunk, + pending: &mut VecDeque, + ) { + if !prefix.bytes.is_empty() { + self.write_or_cancel(stream_id, dir, prefix.bytes.clone(), pending); + } + if prefix.fin { + self.finish(); + } + } } enum PendingAcceptState { @@ -278,8 +282,12 @@ struct DriverState { } impl DriverState { - fn new(config: engine::EngineConfig, peer: Option) -> Self { - let engine = Engine::new(config, peer); + fn new( + config: engine::EngineConfig, + local_xid: bc_components::XID, + peer: Option, + ) -> Self { + let engine = Engine::new(config, local_xid, peer); Self { engine, pending_inputs: VecDeque::new(), @@ -303,7 +311,7 @@ impl DriverState { RuntimeCommand::Incoming(bytes) => self.push_input(EngineInput::Incoming(bytes)), RuntimeCommand::OpenStream { request_head, - request_pipe, + request_rx, accepted, start, config, @@ -313,7 +321,7 @@ impl DriverState { self.pending_opens.insert( open_id, PendingOpen { - request_pipe, + request_rx, start_tx: start, accepted_tx: accepted, }, @@ -321,25 +329,27 @@ impl DriverState { self.push_input(EngineInput::OpenStream { open_id, request_head, + request_prefix: None, config, }); } RuntimeCommand::AcceptStream { stream_id, response_head, - response_pipe, + response_rx, } => { if let Some(DriverStreamIo::Responder { response, .. }) = self.streams.get_mut(&stream_id) { *response = ResponderResponseIo::Streaming(OutboundIo::new( Direction::Response, - response_pipe, + response_rx, )); } self.push_input(EngineInput::AcceptStream { stream_id, response_head, + response_prefix: None, }); } RuntimeCommand::RejectStream { stream_id, code } => { @@ -351,15 +361,6 @@ impl DriverState { self.push_input(EngineInput::RejectStream { stream_id, code }); } RuntimeCommand::PollStream { stream_id } => self.poll_stream(stream_id), - RuntimeCommand::AdvanceInboundCredit { - stream_id, - dir, - amount, - } => self.push_input(EngineInput::InboundConsumed { - stream_id, - dir, - amount, - }), RuntimeCommand::ResetOutbound { stream_id, dir, @@ -418,7 +419,12 @@ impl DriverState { impl Runtime

{ pub async fn run(self) { let runtime_tx = self.tx.upgrade().expect("runtime tx"); - let mut state = DriverState::new(self.config.engine, self.platform.load_peer().await); + let local_xid = self.platform.xid(); + let mut state = DriverState::new( + self.config.engine, + local_xid, + self.platform.load_peer().await, + ); let mut in_flight: Option> = None; loop { @@ -542,15 +548,15 @@ impl Runtime

{ return; }; let _ = pending.start_tx.send(Ok(stream_id)); - let (response_reader, response_writer) = pipe::pipe(self.config.pipe_size_bytes); + let (response_tx, response_rx) = async_channel::unbounded(); streams.insert( stream_id, DriverStreamIo::Initiator { - request: OutboundIo::new(Direction::Request, pending.request_pipe), - response: InboundIo::new(response_writer), + request: OutboundIo::new(Direction::Request, pending.request_rx), + response: InboundIo::new(response_tx), pending_accept: PendingAcceptState::Waiting(PendingAcceptDelivery { tx: pending.accepted_tx, - response_reader, + response_rx, }), }, ); @@ -558,19 +564,26 @@ impl Runtime

{ EngineOutput::OpenAccepted { stream_id, response_head, + response_prefix, .. } => { - let Some(DriverStreamIo::Initiator { pending_accept, .. }) = - streams.get_mut(&stream_id) + let Some(DriverStreamIo::Initiator { + response, + pending_accept, + .. + }) = streams.get_mut(&stream_id) else { return; }; + if let Some(prefix) = response_prefix.as_ref() { + response.apply_prefix(stream_id, Direction::Response, prefix, pending_inputs); + } match std::mem::replace(pending_accept, PendingAcceptState::Resolved) { PendingAcceptState::Waiting(delivery) => { let _ = delivery.tx.send(Ok(AcceptedStreamDelivery { stream_id, response_head, - response: delivery.response_reader, + response: delivery.response_rx, tx: runtime_tx.clone(), })); } @@ -607,12 +620,17 @@ impl Runtime

{ EngineOutput::InboundStreamOpened { stream_id, request_head, + request_prefix, } => { - let (request_reader, request_writer) = pipe::pipe(self.config.pipe_size_bytes); + let (request_tx, request_rx) = async_channel::unbounded(); + let mut request = InboundIo::new(request_tx); + if let Some(prefix) = request_prefix.as_ref() { + request.apply_prefix(stream_id, Direction::Request, prefix, pending_inputs); + } streams.insert( stream_id, DriverStreamIo::Responder { - request: InboundIo::new(request_writer), + request, response: ResponderResponseIo::Pending, }, ); @@ -623,14 +641,10 @@ impl Runtime

{ request: InboundByteStream::new( stream_id, Direction::Request, - request_reader, - runtime_tx.clone(), - ), - respond_to: StreamResponder::new( - stream_id, - self.config.pipe_size_bytes, + request_rx, runtime_tx.clone(), ), + respond_to: StreamResponder::new(stream_id, runtime_tx.clone()), })); } EngineOutput::InboundData { @@ -640,7 +654,7 @@ impl Runtime

{ } => { if let Some(stream) = streams.get_mut(&stream_id) { if let Some(inbound) = stream.inbound_mut(dir) { - inbound.write_or_cancel(stream_id, dir, &bytes, pending_inputs); + inbound.write_or_cancel(stream_id, dir, bytes, pending_inputs); } } } @@ -662,30 +676,14 @@ impl Runtime

{ } } } - EngineOutput::NeedOutboundData { - stream_id, - dir, - offset, - max_len, - } => { + EngineOutput::NeedOutboundData { stream_id, dir } => { if let Some(stream) = streams.get_mut(&stream_id) { if let Some(outbound) = stream.outbound_mut(dir) { - outbound.set_pending_pull(offset, max_len); + outbound.set_pending_pull(); } } poll_stream(streams, pending_inputs, stream_id); } - EngineOutput::ReleaseOutboundThrough { - stream_id, - dir, - recv_offset, - } => { - if let Some(stream) = streams.get_mut(&stream_id) { - if let Some(outbound) = stream.outbound_mut(dir) { - outbound.release_to(recv_offset); - } - } - } EngineOutput::OutboundClosed { stream_id, dir } | EngineOutput::OutboundFailed { stream_id, dir, .. } => { if let Some(stream) = streams.get_mut(&stream_id) { diff --git a/ql2/src/runtime/handle.rs b/ql2/src/runtime/handle.rs index 43312dae..443d81ab 100644 --- a/ql2/src/runtime/handle.rs +++ b/ql2/src/runtime/handle.rs @@ -4,22 +4,17 @@ use std::{ task::{Context, Poll}, }; -use async_channel::Sender; +use async_channel::{Receiver, Sender}; use crate::{ - runtime::{ - command::RuntimeCommand, - pipe::{self, ReadReady}, - AcceptedStreamDelivery, StreamConfig, - }, + runtime::{command::RuntimeCommand, AcceptedStreamDelivery, InboundEvent, StreamConfig}, wire::stream::{Direction, RejectCode, ResetCode}, Peer, QlError, StreamId, }; #[derive(Clone)] pub struct RuntimeHandle { - pub(crate) tx: async_channel::Sender, - pub(crate) pipe_size_bytes: usize, + pub(crate) tx: Sender, } pub struct PendingStream { @@ -45,15 +40,14 @@ pub struct InboundStream { #[derive(Debug)] pub struct StreamResponder { stream_id: StreamId, - pipe_size_bytes: usize, - tx: async_channel::Sender, + tx: Sender, armed: bool, } pub struct InboundByteStream { stream_id: StreamId, dir: Direction, - pipe: pipe::PipeReader, + rx: Receiver, tx: Sender, finished: bool, } @@ -71,7 +65,7 @@ impl std::fmt::Debug for InboundByteStream { pub struct OutboundByteStream { stream_id: StreamId, dir: Direction, - pipe: Option>, + chunks: Option>>, tx: Sender, } @@ -131,13 +125,13 @@ impl InboundByteStream { pub(crate) fn new( stream_id: StreamId, dir: Direction, - pipe: pipe::PipeReader, + rx: Receiver, tx: Sender, ) -> Self { Self { stream_id, dir, - pipe, + rx, tx, finished: false, } @@ -147,31 +141,20 @@ impl InboundByteStream { if self.finished { return Ok(None); } - match self.pipe.ready().await { - ReadReady::Data => { - let chunk = self.pipe.peek_buf().to_vec(); - let len = chunk.len(); - self.pipe.consume(len); - if len > 0 { - let _ = self - .tx - .send(RuntimeCommand::AdvanceInboundCredit { - stream_id: self.stream_id, - dir: self.dir, - amount: len as u64, - }) - .await; - } - Ok(Some(chunk)) - } - ReadReady::Eof => { + match self.rx.recv().await { + Ok(InboundEvent::Data(bytes)) => Ok(Some(bytes)), + Ok(InboundEvent::Finished) => { self.finished = true; Ok(None) } - ReadReady::Error(error) => { + Ok(InboundEvent::Failed(error)) => { self.finished = true; Err(error) } + Err(_) => { + self.finished = true; + Err(QlError::Cancelled) + } } } @@ -205,26 +188,32 @@ impl OutboundByteStream { pub(crate) fn new( stream_id: StreamId, dir: Direction, - pipe: pipe::PipeWriter, + chunks: Sender>, tx: Sender, ) -> Self { Self { stream_id, dir, - pipe: Some(pipe), + chunks: Some(chunks), tx, } } pub async fn write(&mut self, bytes: &[u8]) -> Result { - let pipe = self.pipe.as_mut().expect("stream not finished or reset"); - let written = pipe.write(bytes).await?; + if bytes.is_empty() { + return Ok(0); + } + let sender = self.chunks.as_ref().expect("stream not finished or reset"); + sender + .send(bytes.to_vec()) + .await + .map_err(|_| QlError::Cancelled)?; self.tx .try_send(RuntimeCommand::PollStream { stream_id: self.stream_id, }) .map_err(|_| QlError::Cancelled)?; - Ok(written) + Ok(bytes.len()) } pub async fn write_all(&mut self, mut bytes: &[u8]) -> Result<(), QlError> { @@ -239,21 +228,19 @@ impl OutboundByteStream { } pub async fn finish(mut self) -> Result<(), QlError> { - let Some(mut pipe) = self.pipe.take() else { + if self.chunks.take().is_none() { return Ok(()); - }; - pipe.finish(); + } self.tx .try_send(RuntimeCommand::PollStream { stream_id: self.stream_id, }) .map_err(|_| QlError::Cancelled)?; - pipe.closed().await; Ok(()) } pub async fn reset(mut self, code: ResetCode) -> Result<(), QlError> { - self.pipe.take(); + self.chunks.take(); self.tx .send(RuntimeCommand::ResetOutbound { stream_id: self.stream_id, @@ -267,7 +254,7 @@ impl OutboundByteStream { impl Drop for OutboundByteStream { fn drop(&mut self) { - if self.pipe.take().is_none() { + if self.chunks.take().is_none() { return; } let _ = self.tx.try_send(RuntimeCommand::ResetOutbound { @@ -279,14 +266,9 @@ impl Drop for OutboundByteStream { } impl StreamResponder { - pub(crate) fn new( - stream_id: StreamId, - pipe_size_bytes: usize, - tx: async_channel::Sender, - ) -> Self { + pub(crate) fn new(stream_id: StreamId, tx: Sender) -> Self { Self { stream_id, - pipe_size_bytes, tx, armed: true, } @@ -294,18 +276,18 @@ impl StreamResponder { pub fn accept(mut self, response_head: Vec) -> Result { self.armed = false; - let (response_pipe, response_writer) = pipe::pipe(self.pipe_size_bytes); + let (response_tx, response_rx) = async_channel::bounded(1); self.tx .send_blocking(RuntimeCommand::AcceptStream { stream_id: self.stream_id, response_head, - response_pipe, + response_rx, }) .map_err(|_| QlError::Cancelled)?; Ok(OutboundByteStream::new( self.stream_id, Direction::Response, - response_writer, + response_tx, self.tx.clone(), )) } @@ -364,14 +346,14 @@ impl RuntimeHandle { request_head: Vec, config: StreamConfig, ) -> Result { - let (request_pipe, request_writer) = pipe::pipe(self.pipe_size_bytes); + let (request_tx, request_rx) = async_channel::bounded(1); let (accepted_tx, accepted_rx) = oneshot::channel(); let (start_tx, start_rx) = oneshot::channel(); self.tx .send(RuntimeCommand::OpenStream { request_head, - request_pipe, + request_rx, accepted: accepted_tx, start: start_tx, config, @@ -385,7 +367,7 @@ impl RuntimeHandle { request: OutboundByteStream::new( stream_id, Direction::Request, - request_writer, + request_tx, self.tx.clone(), ), accepted: PendingAccept { diff --git a/ql2/src/runtime/mod.rs b/ql2/src/runtime/mod.rs index dac4bb46..670cc8ae 100644 --- a/ql2/src/runtime/mod.rs +++ b/ql2/src/runtime/mod.rs @@ -3,33 +3,25 @@ pub use handle::{ PendingStream, RuntimeHandle, StreamResponder, }; -pub use crate::engine::{EngineConfig, InitiatorStage, KeepAliveConfig, PeerSession, Token}; +pub use crate::engine::{ + EngineConfig, InitiatorStage, KeepAliveConfig, PeerSession, StreamConfig, Token, +}; pub(crate) mod command; pub(crate) mod driver; pub mod handle; -pub(crate) mod pipe; - -use std::time::Duration; use crate::{platform::QlPlatform, StreamId}; -#[derive(Debug, Clone, Copy, Default)] -pub struct StreamConfig { - pub open_timeout: Option, -} - #[derive(Debug, Clone, Copy)] pub struct RuntimeConfig { pub engine: EngineConfig, - pub pipe_size_bytes: usize, } impl Default for RuntimeConfig { fn default() -> Self { Self { engine: EngineConfig::default(), - pipe_size_bytes: 2048, } } } @@ -37,7 +29,6 @@ impl Default for RuntimeConfig { impl RuntimeConfig { pub(crate) fn normalized(mut self) -> Self { self.engine = self.engine.normalized(); - self.pipe_size_bytes = self.pipe_size_bytes.max(self.engine.max_payload_bytes); self } } @@ -47,10 +38,17 @@ pub enum HandlerEvent { Stream(InboundStream), } +#[derive(Debug)] +pub(crate) enum InboundEvent { + Data(Vec), + Finished, + Failed(crate::QlError), +} + pub(crate) struct AcceptedStreamDelivery { pub stream_id: StreamId, pub response_head: Vec, - pub response: crate::runtime::pipe::PipeReader, + pub response: async_channel::Receiver, pub tx: async_channel::Sender, } @@ -74,9 +72,6 @@ where rx, tx: tx.downgrade(), }, - RuntimeHandle { - tx, - pipe_size_bytes: config.pipe_size_bytes, - }, + RuntimeHandle { tx }, ) } diff --git a/ql2/src/runtime/pipe.rs b/ql2/src/runtime/pipe.rs deleted file mode 100644 index 4c3bb65c..00000000 --- a/ql2/src/runtime/pipe.rs +++ /dev/null @@ -1,772 +0,0 @@ -use std::{ - cell::UnsafeCell, - io::{self, Read}, - mem::{self, MaybeUninit}, - ptr, - sync::{ - atomic::{AtomicU64, AtomicU8, Ordering}, - Arc, - }, - task::{Context, Poll}, -}; - -use atomic_waker::AtomicWaker; -use futures_lite::future::poll_fn; - -const PIPE_OPEN: u8 = 0; -const PIPE_FINISHED: u8 = 1; -const PIPE_FAILED: u8 = 2; -const PIPE_FAILED_TAKEN: u8 = 3; -const PIPE_CLOSED: u8 = 4; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum PipeState { - Open, - Finished, - Failed, - FailedTaken, - Closed, -} - -impl PipeState { - fn from_u8(value: u8) -> Self { - match value { - PIPE_OPEN => Self::Open, - PIPE_FINISHED => Self::Finished, - PIPE_FAILED => Self::Failed, - PIPE_FAILED_TAKEN => Self::FailedTaken, - PIPE_CLOSED => Self::Closed, - _ => unreachable!("invalid pipe state"), - } - } - - fn as_u8(self) -> u8 { - match self { - Self::Open => PIPE_OPEN, - Self::Finished => PIPE_FINISHED, - Self::Failed => PIPE_FAILED, - Self::FailedTaken => PIPE_FAILED_TAKEN, - Self::Closed => PIPE_CLOSED, - } - } -} - -pub fn pipe(cap: usize) -> (PipeReader, PipeWriter) { - assert!(cap > 0, "pipe capacity must be positive"); - - let mut storage = Vec::::with_capacity(cap); - let buffer = storage.as_mut_ptr(); - mem::forget(storage); - - let inner = Arc::new(PipeInner { - released: AtomicU64::new(0), - produced: AtomicU64::new(0), - state: AtomicU8::new(PIPE_OPEN), - error: UnsafeCell::new(MaybeUninit::uninit()), - readable: AtomicWaker::new(), - writable: AtomicWaker::new(), - closed: AtomicWaker::new(), - buffer, - cap, - }); - - ( - PipeReader { - inner: inner.clone(), - released: 0, - produced: 0, - sent: 0, - }, - PipeWriter { - inner, - released: 0, - produced: 0, - sealed: false, - }, - ) -} - -struct PipeInner { - released: AtomicU64, - produced: AtomicU64, - state: AtomicU8, - error: UnsafeCell>, - readable: AtomicWaker, - writable: AtomicWaker, - closed: AtomicWaker, - buffer: *mut u8, - cap: usize, -} - -unsafe impl Send for PipeInner {} -unsafe impl Sync for PipeInner {} - -impl Drop for PipeInner { - fn drop(&mut self) { - if PipeState::from_u8(self.state.load(Ordering::Acquire)) == PipeState::Failed { - unsafe { - self.error.get_mut().assume_init_drop(); - } - } - unsafe { - drop(Vec::from_raw_parts(self.buffer, 0, self.cap)); - } - } -} - -pub struct PipeWriter { - inner: Arc>, - released: u64, - produced: u64, - sealed: bool, -} - -pub struct PipeReader { - inner: Arc>, - released: u64, - produced: u64, - sent: u64, -} - -impl std::fmt::Debug for PipeWriter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PipeWriter").finish_non_exhaustive() - } -} - -impl std::fmt::Debug for PipeReader { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PipeReader").finish_non_exhaustive() - } -} - -pub struct SendGrant<'a, E> { - inner: &'a PipeInner, - offset: u64, - len: usize, - position: usize, -} - -pub enum ReadReady { - Data, - Eof, - Error(E), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct PipeClosed; - -impl PipeWriter { - pub fn try_write(&mut self, src: &[u8]) -> Result { - if src.is_empty() { - return Ok(0); - } - if self.sealed || self.is_closed() { - return Err(PipeClosed); - } - self.released = self.inner.released.load(Ordering::Acquire); - let n = self.available_capacity().min(src.len()); - if n == 0 { - return Ok(0); - } - unsafe { - write_bytes(self.inner.buffer, self.inner.cap, self.produced, &src[..n]); - } - self.produced = self.produced.saturating_add(n as u64); - self.inner.produced.store(self.produced, Ordering::Release); - self.inner.readable.wake(); - Ok(n) - } - - pub async fn write(&mut self, src: &[u8]) -> Result { - poll_fn(|cx| self.poll_write(cx, src)).await - } - - pub fn finish(&mut self) { - if self.sealed { - return; - } - self.sealed = true; - self.publish_state(PipeState::Finished); - } - - pub fn fail(&mut self, error: E) { - if self.sealed { - return; - } - self.sealed = true; - unsafe { - (*self.inner.error.get()).write(error); - } - match self.inner.state.compare_exchange( - PIPE_OPEN, - PIPE_FAILED, - Ordering::Release, - Ordering::Acquire, - ) { - Ok(_) => { - self.inner.readable.wake(); - } - Err(_) => unsafe { - (*self.inner.error.get()).assume_init_drop(); - }, - } - } - - pub fn close(&mut self) { - if self.sealed { - return; - } - self.sealed = true; - loop { - let current = PipeState::from_u8(self.inner.state.load(Ordering::Acquire)); - match current { - PipeState::Closed => return, - PipeState::Failed => { - if self - .inner - .state - .compare_exchange( - PIPE_FAILED, - PIPE_CLOSED, - Ordering::AcqRel, - Ordering::Acquire, - ) - .is_ok() - { - unsafe { - (*self.inner.error.get()).assume_init_drop(); - } - self.inner.readable.wake(); - self.inner.writable.wake(); - self.inner.closed.wake(); - return; - } - } - _ => { - if self - .inner - .state - .compare_exchange( - current.as_u8(), - PIPE_CLOSED, - Ordering::AcqRel, - Ordering::Acquire, - ) - .is_ok() - { - self.inner.readable.wake(); - self.inner.writable.wake(); - self.inner.closed.wake(); - return; - } - } - } - } - } - - pub async fn closed(&mut self) { - poll_fn(|cx| self.poll_closed(cx)).await - } - - fn poll_write(&mut self, cx: &mut Context<'_>, src: &[u8]) -> Poll> { - if src.is_empty() { - return Poll::Ready(Ok(0)); - } - if self.sealed || self.is_closed() { - return Poll::Ready(Err(PipeClosed)); - } - - let n = match self.poll_reserve(cx, src.len()) { - Poll::Ready(Ok(n)) => n, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - - unsafe { - write_bytes(self.inner.buffer, self.inner.cap, self.produced, &src[..n]); - } - self.produced = self.produced.saturating_add(n as u64); - self.inner.produced.store(self.produced, Ordering::Release); - self.inner.readable.wake(); - Poll::Ready(Ok(n)) - } - - fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { - if self.is_closed() { - return Poll::Ready(()); - } - self.inner.closed.register(cx.waker()); - if self.is_closed() { - self.inner.closed.take(); - Poll::Ready(()) - } else { - Poll::Pending - } - } - - fn poll_reserve( - &mut self, - cx: &mut Context<'_>, - want: usize, - ) -> Poll> { - self.released = self.inner.released.load(Ordering::Acquire); - let available = self.available_capacity(); - if available > 0 { - return Poll::Ready(Ok(available.min(want))); - } - - self.inner.writable.register(cx.waker()); - self.released = self.inner.released.load(Ordering::Acquire); - if self.is_closed() { - self.inner.writable.take(); - return Poll::Ready(Err(PipeClosed)); - } - let available = self.available_capacity(); - if available > 0 { - self.inner.writable.take(); - Poll::Ready(Ok(available.min(want))) - } else { - Poll::Pending - } - } - - fn available_capacity(&self) -> usize { - let used = self.produced.saturating_sub(self.released) as usize; - self.inner.cap.saturating_sub(used) - } - - fn publish_state(&mut self, next: PipeState) { - let _ = self.inner.state.compare_exchange( - PIPE_OPEN, - next.as_u8(), - Ordering::Release, - Ordering::Acquire, - ); - self.inner.readable.wake(); - } - - fn is_closed(&self) -> bool { - PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) == PipeState::Closed - } - - #[cfg(test)] - fn state(&self) -> PipeState { - PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) - } - - #[cfg(test)] - fn is_drained(&self) -> bool { - self.inner.released.load(Ordering::Acquire) >= self.inner.produced.load(Ordering::Acquire) - } -} - -impl Drop for PipeWriter { - fn drop(&mut self) { - if self.sealed { - return; - } - self.sealed = true; - self.publish_state(PipeState::Finished); - } -} - -impl PipeReader { - pub async fn ready(&mut self) -> ReadReady { - poll_fn(|cx| self.poll_ready(cx)).await - } - - pub fn peek_buf(&self) -> &[u8] { - let len = self - .available_data() - .min(self.inner.cap - ((self.released as usize) % self.inner.cap)); - unsafe { - ptr::slice_from_raw_parts( - self.inner - .buffer - .add((self.released as usize) % self.inner.cap), - len, - ) - .as_ref() - .unwrap() - } - } - - pub fn consume(&mut self, amt: usize) { - assert!( - amt <= self.available_data(), - "cannot consume more bytes than available" - ); - self.released = self.released.saturating_add(amt as u64); - self.inner.released.store(self.released, Ordering::Release); - if self.sent < self.released { - self.sent = self.released; - } - self.inner.writable.wake(); - } - - pub fn reserve_send( - &mut self, - remote_max_offset: u64, - max_len: usize, - ) -> Option> { - self.produced = self.inner.produced.load(Ordering::Acquire); - let limit = self.produced.min(remote_max_offset); - if self.sent >= limit { - return None; - } - let len = ((limit - self.sent) as usize).min(max_len); - let offset = self.sent; - self.sent = self.sent.saturating_add(len as u64); - Some(SendGrant { - inner: self.inner.as_ref(), - offset, - len, - position: 0, - }) - } - - pub fn retry_send(&self, offset: u64, len: usize) -> Option> { - let released = self.inner.released.load(Ordering::Acquire); - let produced = self.inner.produced.load(Ordering::Acquire); - if offset < released || offset.saturating_add(len as u64) > produced { - return None; - } - Some(SendGrant { - inner: self.inner.as_ref(), - offset, - len, - position: 0, - }) - } - - pub fn reserve_at(&mut self, offset: u64, max_len: usize) -> Option> { - if offset < self.sent { - return self.retry_send(offset, max_len); - } - if offset == self.sent { - return self.reserve_send(offset.saturating_add(max_len as u64), max_len); - } - None - } - - pub fn release_to(&mut self, released: u64) { - self.released = released; - self.inner.released.store(released, Ordering::Release); - self.inner.writable.wake(); - } - - pub fn sent_offset(&self) -> u64 { - self.sent - } - - pub fn writer_finished(&self) -> bool { - PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) == PipeState::Finished - } - - pub fn all_sent(&mut self) -> bool { - self.produced = self.inner.produced.load(Ordering::Acquire); - self.sent >= self.produced - } - - pub fn close(&mut self) { - loop { - match PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) { - PipeState::Closed => return, - PipeState::Failed => { - if self - .inner - .state - .compare_exchange( - PIPE_FAILED, - PIPE_CLOSED, - Ordering::AcqRel, - Ordering::Acquire, - ) - .is_ok() - { - unsafe { - (*self.inner.error.get()).assume_init_drop(); - } - self.inner.writable.wake(); - self.inner.closed.wake(); - return; - } - } - current => { - if self - .inner - .state - .compare_exchange( - current.as_u8(), - PIPE_CLOSED, - Ordering::AcqRel, - Ordering::Acquire, - ) - .is_ok() - { - self.inner.writable.wake(); - self.inner.closed.wake(); - return; - } - } - } - } - } - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.produced = self.inner.produced.load(Ordering::Acquire); - if self.available_data() > 0 { - return Poll::Ready(ReadReady::Data); - } - - loop { - match PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) { - PipeState::Open => { - self.inner.readable.register(cx.waker()); - self.produced = self.inner.produced.load(Ordering::Acquire); - if self.available_data() > 0 { - self.inner.readable.take(); - return Poll::Ready(ReadReady::Data); - } - if PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) - == PipeState::Open - { - return Poll::Pending; - } - self.inner.readable.take(); - } - PipeState::Finished | PipeState::Closed => return Poll::Ready(ReadReady::Eof), - PipeState::Failed => { - let err = match self.inner.state.compare_exchange( - PIPE_FAILED, - PIPE_FAILED_TAKEN, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => unsafe { (*self.inner.error.get()).assume_init_read() }, - Err(_) => continue, - }; - return Poll::Ready(ReadReady::Error(err)); - } - PipeState::FailedTaken => return Poll::Ready(ReadReady::Eof), - } - } - } - - fn available_data(&self) -> usize { - self.produced.saturating_sub(self.released) as usize - } - - #[cfg(test)] - pub fn state(&self) -> PipeState { - PipeState::from_u8(self.inner.state.load(Ordering::Acquire)) - } -} - -impl Drop for PipeReader { - fn drop(&mut self) { - self.close(); - } -} - -impl SendGrant<'_, E> { - pub fn offset(&self) -> u64 { - self.offset - } - - pub fn len(&self) -> usize { - self.len - } -} - -impl Read for SendGrant<'_, E> { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let remaining = self.len.saturating_sub(self.position); - if remaining == 0 || buf.is_empty() { - return Ok(0); - } - let n = remaining.min(buf.len()); - unsafe { - copy_bytes( - self.inner.buffer, - self.inner.cap, - self.offset.saturating_add(self.position as u64), - &mut buf[..n], - ); - } - self.position += n; - Ok(n) - } -} - -unsafe fn write_bytes(buffer: *mut u8, cap: usize, offset: u64, src: &[u8]) { - let start = (offset as usize) % cap; - let first = src.len().min(cap - start); - ptr::copy_nonoverlapping(src.as_ptr(), buffer.add(start), first); - if first < src.len() { - ptr::copy_nonoverlapping(src[first..].as_ptr(), buffer, src.len() - first); - } -} - -unsafe fn copy_bytes(buffer: *mut u8, cap: usize, offset: u64, dst: &mut [u8]) { - let len = dst.len(); - let start = (offset as usize) % cap; - let first = len.min(cap - start); - ptr::copy_nonoverlapping(buffer.add(start), dst.as_mut_ptr(), first); - if first < len { - ptr::copy_nonoverlapping(buffer, dst.as_mut_ptr().add(first), len - first); - } -} - -#[cfg(test)] -mod tests { - use std::convert::Infallible; - - use futures_lite::future::poll_fn; - use tokio::task::yield_now; - - use super::*; - - #[tokio::test(flavor = "current_thread")] - async fn pipe_writes_reads_and_releases() { - let (mut reader, mut writer) = pipe::(8); - assert_eq!( - poll_fn(|cx| writer.poll_write(cx, b"abcd")).await.unwrap(), - 4 - ); - - let mut send = reader.reserve_send(8, 8).unwrap(); - assert_eq!(send.offset(), 0); - assert_eq!(send.len(), 4); - let mut bytes = vec![0; send.len()]; - send.read_exact(&mut bytes).unwrap(); - assert_eq!(bytes, b"abcd"); - - reader.release_to(4); - assert!(writer.is_drained()); - assert_eq!(poll_fn(|cx| writer.poll_write(cx, b"ef")).await.unwrap(), 2); - let mut send = reader.reserve_send(8, 8).unwrap(); - let mut bytes = vec![0; send.len()]; - send.read_exact(&mut bytes).unwrap(); - assert_eq!(bytes, b"ef"); - } - - #[tokio::test(flavor = "current_thread")] - async fn pipe_blocks_until_released() { - let (mut reader, mut writer) = pipe::(4); - assert_eq!( - poll_fn(|cx| writer.poll_write(cx, b"abcd")).await.unwrap(), - 4 - ); - - let mut blocked = false; - let poll = poll_fn(|cx| match writer.poll_write(cx, b"e") { - Poll::Ready(result) => Poll::Ready(result), - Poll::Pending => { - blocked = true; - Poll::Ready(Ok(0)) - } - }) - .await - .unwrap(); - assert_eq!(poll, 0); - assert!(blocked); - - reader.release_to(4); - yield_now().await; - assert_eq!(poll_fn(|cx| writer.poll_write(cx, b"e")).await.unwrap(), 1); - } - - #[tokio::test(flavor = "current_thread")] - async fn pipe_closed_waits_for_reader_close() { - let (mut reader, mut writer) = pipe::(8); - writer.finish(); - assert_eq!(writer.state(), PipeState::Finished); - - let waiter = tokio::spawn(async move { - writer.closed().await; - }); - - yield_now().await; - assert!(!waiter.is_finished()); - reader.close(); - waiter.await.unwrap(); - } - - #[tokio::test(flavor = "current_thread")] - async fn pipe_wraparound_reads_correctly() { - let (mut reader, mut writer) = pipe::(8); - assert_eq!( - poll_fn(|cx| writer.poll_write(cx, b"abcdef")) - .await - .unwrap(), - 6 - ); - let mut send = reader.reserve_send(8, 6).unwrap(); - let mut bytes = vec![0; send.len()]; - send.read_exact(&mut bytes).unwrap(); - assert_eq!(bytes, b"abcdef"); - reader.release_to(6); - - assert_eq!( - poll_fn(|cx| writer.poll_write(cx, b"ghijkl")) - .await - .unwrap(), - 6 - ); - let mut send = reader.reserve_send(12, 6).unwrap(); - let mut bytes = vec![0; send.len()]; - send.read_exact(&mut bytes).unwrap(); - assert_eq!(bytes, b"ghijkl"); - } - - #[tokio::test(flavor = "current_thread")] - async fn closing_reader_wakes_writer() { - let (mut reader, mut writer) = pipe::(4); - assert_eq!( - poll_fn(|cx| writer.poll_write(cx, b"abcd")).await.unwrap(), - 4 - ); - reader.close(); - assert_eq!(reader.state(), PipeState::Closed); - let err = poll_fn(|cx| writer.poll_write(cx, b"e")).await.unwrap_err(); - assert_eq!(err, PipeClosed); - } - - #[tokio::test(flavor = "current_thread")] - async fn buffered_bytes_drain_before_eof() { - let (mut reader, mut writer) = pipe::(8); - poll_fn(|cx| writer.poll_write(cx, b"abc")).await.unwrap(); - writer.finish(); - - assert!(matches!( - poll_fn(|cx| reader.poll_ready(cx)).await, - ReadReady::Data - )); - assert_eq!(reader.peek_buf(), b"abc"); - reader.consume(3); - assert!(matches!( - poll_fn(|cx| reader.poll_ready(cx)).await, - ReadReady::Eof - )); - } - - #[tokio::test(flavor = "current_thread")] - async fn buffered_bytes_drain_before_error() { - let (mut reader, mut writer) = pipe::<&'static str>(8); - poll_fn(|cx| writer.poll_write(cx, b"abc")).await.unwrap(); - writer.fail("boom"); - - assert!(matches!( - poll_fn(|cx| reader.poll_ready(cx)).await, - ReadReady::Data - )); - assert_eq!(reader.peek_buf(), b"abc"); - reader.consume(3); - match poll_fn(|cx| reader.poll_ready(cx)).await { - ReadReady::Error(err) => assert_eq!(err, "boom"), - _ => panic!("expected pipe error"), - } - } -} diff --git a/ql2/src/wire/codec.rs b/ql2/src/wire/codec.rs index 4eb8d3ec..c9edf009 100644 --- a/ql2/src/wire/codec.rs +++ b/ql2/src/wire/codec.rs @@ -78,8 +78,10 @@ impl TryFrom for XID { } } -pub(crate) fn xid_from_archived(value: &ArchivedWireXid) -> XID { - XID::from_data(value.0) +impl From<&ArchivedWireXid> for XID { + fn from(value: &ArchivedWireXid) -> Self { + XID::from_data(value.0) + } } impl_wire_wrapper!(AsWireXid, XID, WireXid); @@ -101,8 +103,10 @@ impl TryFrom for Nonce { } } -pub(crate) fn nonce_from_archived(value: &ArchivedWireNonce) -> Nonce { - Nonce::from_data(value.0) +impl From<&ArchivedWireNonce> for Nonce { + fn from(value: &ArchivedWireNonce) -> Self { + Nonce::from_data(value.0) + } } impl_wire_wrapper!(AsWireNonce, Nonce, WireNonce); @@ -137,11 +141,13 @@ impl From for WireMlDsaLevel { } } -pub(crate) fn mldsa_level_from_archived(value: &ArchivedWireMlDsaLevel) -> MLDSA { - match value { - ArchivedWireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, - ArchivedWireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, - ArchivedWireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, +impl From<&ArchivedWireMlDsaLevel> for MLDSA { + fn from(value: &ArchivedWireMlDsaLevel) -> Self { + match value { + ArchivedWireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, + ArchivedWireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, + ArchivedWireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, + } } } @@ -175,11 +181,13 @@ impl From for WireMlKemLevel { } } -pub(crate) fn mlkem_level_from_archived(value: &ArchivedWireMlKemLevel) -> MLKEM { - match value { - ArchivedWireMlKemLevel::MlKem512 => MLKEM::MLKEM512, - ArchivedWireMlKemLevel::MlKem768 => MLKEM::MLKEM768, - ArchivedWireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, +impl From<&ArchivedWireMlKemLevel> for MLKEM { + fn from(value: &ArchivedWireMlKemLevel) -> Self { + match value { + ArchivedWireMlKemLevel::MlKem512 => MLKEM::MLKEM512, + ArchivedWireMlKemLevel::MlKem768 => MLKEM::MLKEM768, + ArchivedWireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, + } } } @@ -207,6 +215,15 @@ impl From<&MLDSAPublicKey> for WireMlDsaPublicKey { } } +impl TryFrom<&ArchivedWireMlDsaPublicKey> for MLDSAPublicKey { + type Error = QlError; + + fn try_from(value: &ArchivedWireMlDsaPublicKey) -> Result { + MLDSAPublicKey::from_bytes((&value.level).into(), value.bytes.as_slice()) + .map_err(|_| QlError::InvalidPayload) + } +} + impl_wire_wrapper!(AsWireMlDsaPublicKey, MLDSAPublicKey, WireMlDsaPublicKey); #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] @@ -233,14 +250,13 @@ impl From<&MLDSASignature> for WireMlDsaSignature { } } -pub(crate) fn mldsa_signature_from_archived( - value: &ArchivedWireMlDsaSignature, -) -> Result { - MLDSASignature::from_bytes( - mldsa_level_from_archived(&value.level), - value.bytes.as_slice(), - ) - .map_err(|_| QlError::InvalidPayload) +impl TryFrom<&ArchivedWireMlDsaSignature> for MLDSASignature { + type Error = QlError; + + fn try_from(value: &ArchivedWireMlDsaSignature) -> Result { + MLDSASignature::from_bytes((&value.level).into(), value.bytes.as_slice()) + .map_err(|_| QlError::InvalidPayload) + } } impl_wire_wrapper!(AsWireMlDsaSignature, MLDSASignature, WireMlDsaSignature); @@ -269,6 +285,15 @@ impl From<&MLKEMPublicKey> for WireMlKemPublicKey { } } +impl TryFrom<&ArchivedWireMlKemPublicKey> for MLKEMPublicKey { + type Error = QlError; + + fn try_from(value: &ArchivedWireMlKemPublicKey) -> Result { + MLKEMPublicKey::from_bytes((&value.level).into(), value.bytes.as_slice()) + .map_err(|_| QlError::InvalidPayload) + } +} + impl_wire_wrapper!(AsWireMlKemPublicKey, MLKEMPublicKey, WireMlKemPublicKey); #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] @@ -295,14 +320,13 @@ impl From<&MLKEMCiphertext> for WireMlKemCiphertext { } } -pub(crate) fn mlkem_ciphertext_from_archived( - value: &ArchivedWireMlKemCiphertext, -) -> Result { - MLKEMCiphertext::from_bytes( - mlkem_level_from_archived(&value.level), - value.bytes.as_slice(), - ) - .map_err(|_| QlError::InvalidPayload) +impl TryFrom<&ArchivedWireMlKemCiphertext> for MLKEMCiphertext { + type Error = QlError; + + fn try_from(value: &ArchivedWireMlKemCiphertext) -> Result { + MLKEMCiphertext::from_bytes((&value.level).into(), value.bytes.as_slice()) + .map_err(|_| QlError::InvalidPayload) + } } impl_wire_wrapper!(AsWireMlKemCiphertext, MLKEMCiphertext, WireMlKemCiphertext); diff --git a/ql2/src/wire/handshake/crypto.rs b/ql2/src/wire/handshake/crypto.rs index 4f45ef9a..7e78ba40 100644 --- a/ql2/src/wire/handshake/crypto.rs +++ b/ql2/src/wire/handshake/crypto.rs @@ -1,35 +1,59 @@ -use bc_components::{Digest, MLDSAPublicKey, MLKEMPublicKey, Nonce, SymmetricKey, XID}; +use bc_components::{ + Digest, MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey, Nonce, SymmetricKey, + XID, +}; use rkyv::{Archive, Serialize}; use super::{ - verify_transcript_signature, ArchivedConfirm, ArchivedHello, ArchivedHelloReply, Confirm, - Hello, HelloReply, + verify_signature, ArchivedConfirm, ArchivedHello, ArchivedHelloReply, Confirm, Hello, + HelloReply, }; use crate::{ - platform::QlCrypto, + platform::{QlCrypto, QlIdentity}, wire::{ - encode_value, mldsa_signature_from_archived, mlkem_ciphertext_from_archived, - nonce_from_archived, AsWireMlKemCiphertext, AsWireNonce, AsWireXid, + encode_value, ensure_not_expired, AsWireMlKemCiphertext, AsWireNonce, AsWireXid, + ControlMeta, }, QlError, }; +#[derive(Archive, Serialize)] +struct HelloProofData { + #[rkyv(with = AsWireXid)] + initiator: XID, + #[rkyv(with = AsWireXid)] + responder: XID, + meta: ControlMeta, + #[rkyv(with = AsWireNonce)] + nonce: Nonce, + #[rkyv(with = AsWireMlKemCiphertext)] + kem_ct: bc_components::MLKEMCiphertext, +} + #[derive(Archive, Serialize)] struct HandshakeTranscript { #[rkyv(with = AsWireXid)] initiator: XID, #[rkyv(with = AsWireXid)] responder: XID, + hello_meta: ControlMeta, #[rkyv(with = AsWireNonce)] initiator_nonce: Nonce, #[rkyv(with = AsWireNonce)] responder_nonce: Nonce, + reply_meta: ControlMeta, #[rkyv(with = AsWireMlKemCiphertext)] initiator_kem_ct: bc_components::MLKEMCiphertext, #[rkyv(with = AsWireMlKemCiphertext)] responder_kem_ct: bc_components::MLKEMCiphertext, } +#[derive(Archive, Serialize)] +struct ConfirmProofData { + meta: ControlMeta, + transcript: Vec, +} + #[derive(Archive, Serialize)] struct SessionKeyMaterial { initiator_secret: Vec, @@ -44,41 +68,79 @@ pub struct ResponderSecrets { } pub fn build_hello( - platform: &impl QlCrypto, - _sender: XID, - _recipient: XID, + identity: &QlIdentity, + crypto: &impl QlCrypto, + recipient: XID, recipient_encapsulation_key: &MLKEMPublicKey, + meta: ControlMeta, ) -> Result<(Hello, SymmetricKey), QlError> { - let nonce = next_nonce(platform); + let nonce = next_nonce(crypto); let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); - Ok((Hello { nonce, kem_ct }, session_key)) + let signature = identity.signing_private_key.sign(hello_proof_data( + identity.xid, + recipient, + &meta, + &nonce, + &kem_ct, + )); + Ok(( + Hello { + meta, + nonce, + kem_ct, + signature, + }, + session_key, + )) } -pub fn respond_hello( - platform: &impl QlCrypto, +pub fn verify_hello( initiator: XID, responder: XID, + initiator_signing_key: &MLDSAPublicKey, + hello: &ArchivedHello, +) -> Result<(), QlError> { + let meta: ControlMeta = (&hello.meta).into(); + ensure_not_expired(meta.valid_until)?; + let signature = MLDSASignature::try_from(&hello.signature)?; + let nonce: Nonce = (&hello.nonce).into(); + let kem_ct = MLKEMCiphertext::try_from(&hello.kem_ct)?; + let proof_data = hello_proof_data(initiator, responder, &meta, &nonce, &kem_ct); + verify_signature(initiator_signing_key, &signature, &proof_data) +} + +pub fn respond_hello( + identity: &QlIdentity, + crypto: &impl QlCrypto, + initiator: XID, + initiator_signing_key: &MLDSAPublicKey, initiator_encapsulation_key: &MLKEMPublicKey, hello: &ArchivedHello, + meta: ControlMeta, ) -> Result<(HelloReply, ResponderSecrets), QlError> { - let initiator_nonce = nonce_from_archived(&hello.nonce); - let initiator_kem_ct = mlkem_ciphertext_from_archived(&hello.kem_ct)?; - let initiator_secret = platform - .encapsulation_private_key() + verify_hello(initiator, identity.xid, initiator_signing_key, hello)?; + let hello_meta: ControlMeta = (&hello.meta).into(); + let initiator_nonce: Nonce = (&hello.nonce).into(); + let initiator_kem_ct = MLKEMCiphertext::try_from(&hello.kem_ct)?; + let initiator_secret = identity + .encapsulation_private_key .decapsulate_shared_secret(&initiator_kem_ct) .map_err(|_| QlError::InvalidPayload)?; - let nonce = next_nonce(platform); + let nonce = next_nonce(crypto); let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(); let transcript = handshake_transcript( initiator, - responder, + identity.xid, + &hello_meta, &initiator_nonce, - &nonce, &initiator_kem_ct, + &meta, + &nonce, &kem_ct, ); - let signature = platform.signing_private_key().sign(&transcript); + let signature = identity.signing_private_key.sign(&transcript); let reply = HelloReply { + meta, nonce, kem_ct, signature, @@ -93,32 +155,38 @@ pub fn respond_hello( } pub fn build_confirm( - platform: &impl QlCrypto, - initiator: XID, + identity: &QlIdentity, responder: XID, responder_signing_key: &MLDSAPublicKey, hello: &Hello, reply: &ArchivedHelloReply, initiator_secret: &SymmetricKey, + meta: ControlMeta, ) -> Result<(Confirm, SymmetricKey), QlError> { - let reply_nonce = nonce_from_archived(&reply.nonce); - let reply_kem_ct = mlkem_ciphertext_from_archived(&reply.kem_ct)?; - let reply_signature = mldsa_signature_from_archived(&reply.signature)?; + let reply_meta: ControlMeta = (&reply.meta).into(); + ensure_not_expired(reply_meta.valid_until)?; + let reply_nonce: Nonce = (&reply.nonce).into(); + let reply_kem_ct = MLKEMCiphertext::try_from(&reply.kem_ct)?; + let reply_signature = MLDSASignature::try_from(&reply.signature)?; let transcript = handshake_transcript( - initiator, + identity.xid, responder, + &hello.meta, &hello.nonce, - &reply_nonce, &hello.kem_ct, + &reply_meta, + &reply_nonce, &reply_kem_ct, ); - verify_transcript_signature(responder_signing_key, &reply_signature, &transcript)?; - let responder_secret = platform - .encapsulation_private_key() + verify_signature(responder_signing_key, &reply_signature, &transcript)?; + let responder_secret = identity + .encapsulation_private_key .decapsulate_shared_secret(&reply_kem_ct) .map_err(|_| QlError::InvalidPayload)?; - let signature = platform.signing_private_key().sign(&transcript); - let confirm = Confirm { signature }; + let signature = identity + .signing_private_key + .sign(confirm_proof_data(&meta, &transcript)); + let confirm = Confirm { meta, signature }; let session_key = derive_session_key(initiator_secret, &responder_secret, &transcript); Ok((confirm, session_key)) } @@ -128,20 +196,25 @@ pub fn finalize_confirm( responder: XID, initiator_signing_key: &MLDSAPublicKey, hello: &Hello, - reply: &super::HelloReply, + reply: &HelloReply, confirm: &ArchivedConfirm, secrets: &ResponderSecrets, ) -> Result { - let confirm_signature = mldsa_signature_from_archived(&confirm.signature)?; + let confirm_meta: ControlMeta = (&confirm.meta).into(); + ensure_not_expired(confirm_meta.valid_until)?; + let confirm_signature = MLDSASignature::try_from(&confirm.signature)?; let transcript = handshake_transcript( initiator, responder, + &hello.meta, &hello.nonce, - &reply.nonce, &hello.kem_ct, + &reply.meta, + &reply.nonce, &reply.kem_ct, ); - verify_transcript_signature(initiator_signing_key, &confirm_signature, &transcript)?; + let proof_data = confirm_proof_data(&confirm_meta, &transcript); + verify_signature(initiator_signing_key, &confirm_signature, &proof_data)?; Ok(derive_session_key( &secrets.initiator_secret, &secrets.responder_secret, @@ -152,21 +225,48 @@ pub fn finalize_confirm( fn handshake_transcript( initiator: XID, responder: XID, + hello_meta: &ControlMeta, initiator_nonce: &Nonce, - responder_nonce: &Nonce, initiator_kem_ct: &bc_components::MLKEMCiphertext, + reply_meta: &ControlMeta, + responder_nonce: &Nonce, responder_kem_ct: &bc_components::MLKEMCiphertext, ) -> Vec { encode_value(&HandshakeTranscript { initiator, responder, + hello_meta: *hello_meta, initiator_nonce: initiator_nonce.clone(), responder_nonce: responder_nonce.clone(), + reply_meta: *reply_meta, initiator_kem_ct: initiator_kem_ct.clone(), responder_kem_ct: responder_kem_ct.clone(), }) } +fn hello_proof_data( + initiator: XID, + responder: XID, + meta: &ControlMeta, + nonce: &Nonce, + kem_ct: &bc_components::MLKEMCiphertext, +) -> Vec { + encode_value(&HelloProofData { + initiator, + responder, + meta: *meta, + nonce: nonce.clone(), + kem_ct: kem_ct.clone(), + }) +} + +fn confirm_proof_data(meta: &ControlMeta, transcript: &[u8]) -> Vec { + encode_value(&ConfirmProofData { + meta: *meta, + transcript: transcript.to_vec(), + }) +} + fn next_nonce(platform: &impl QlCrypto) -> Nonce { let mut data = [0u8; Nonce::NONCE_SIZE]; platform.fill_random_bytes(&mut data); diff --git a/ql2/src/wire/handshake/mod.rs b/ql2/src/wire/handshake/mod.rs index 949e528f..62b3f43f 100644 --- a/ql2/src/wire/handshake/mod.rs +++ b/ql2/src/wire/handshake/mod.rs @@ -1,7 +1,7 @@ use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, Nonce}; use rkyv::{Archive, Deserialize, Serialize}; -use super::{AsWireMlDsaSignature, AsWireMlKemCiphertext, AsWireNonce}; +use super::{AsWireMlDsaSignature, AsWireMlKemCiphertext, AsWireNonce, ControlMeta}; use crate::QlError; mod crypto; @@ -16,14 +16,18 @@ pub enum HandshakeRecord { #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Hello { + pub meta: ControlMeta, #[rkyv(with = AsWireNonce)] pub nonce: Nonce, #[rkyv(with = AsWireMlKemCiphertext)] pub kem_ct: MLKEMCiphertext, + #[rkyv(with = AsWireMlDsaSignature)] + pub signature: MLDSASignature, } #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct HelloReply { + pub meta: ControlMeta, #[rkyv(with = AsWireNonce)] pub nonce: Nonce, #[rkyv(with = AsWireMlKemCiphertext)] @@ -34,16 +38,17 @@ pub struct HelloReply { #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Confirm { + pub meta: ControlMeta, #[rkyv(with = AsWireMlDsaSignature)] pub signature: MLDSASignature, } -pub fn verify_transcript_signature( +pub fn verify_signature( signing_key: &MLDSAPublicKey, signature: &MLDSASignature, - transcript: &[u8], + proof_data: &[u8], ) -> Result<(), QlError> { - match signing_key.verify(signature, transcript) { + match signing_key.verify(signature, proof_data) { Ok(true) => Ok(()), _ => Err(QlError::InvalidSignature), } diff --git a/ql2/src/wire/heartbeat/crypto.rs b/ql2/src/wire/heartbeat/crypto.rs index ccc92ae9..0002e979 100644 --- a/ql2/src/wire/heartbeat/crypto.rs +++ b/ql2/src/wire/heartbeat/crypto.rs @@ -34,6 +34,6 @@ pub(crate) fn decrypt_heartbeat( let plaintext = encrypted.decrypt(session_key, &aad)?; let body = access_value::(plaintext)?; let body = deserialize_value(body)?; - ensure_not_expired(body.valid_until)?; + ensure_not_expired(body.meta.valid_until)?; Ok(body) } diff --git a/ql2/src/wire/heartbeat/mod.rs b/ql2/src/wire/heartbeat/mod.rs index f5e75950..8a7810f6 100644 --- a/ql2/src/wire/heartbeat/mod.rs +++ b/ql2/src/wire/heartbeat/mod.rs @@ -1,12 +1,11 @@ use rkyv::{Archive, Deserialize, Serialize}; -use crate::PacketId; +use super::ControlMeta; mod crypto; pub use crypto::*; #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct HeartbeatBody { - pub packet_id: PacketId, - pub valid_until: u64, + pub meta: ControlMeta, } diff --git a/ql2/src/wire/mod.rs b/ql2/src/wire/mod.rs index 358dbe19..7052eba5 100644 --- a/ql2/src/wire/mod.rs +++ b/ql2/src/wire/mod.rs @@ -1,3 +1,10 @@ +//! quantum link protocol wire format +//! +//! naming conventions: +//! - *Record - unencrypted messages +//! - *Body - message content after decrypting +//! + use bc_components::XID; use rkyv::{ api::{ @@ -13,9 +20,12 @@ pub mod encrypted_message; pub mod handshake; pub mod heartbeat; pub mod pair; +pub mod seq; pub mod stream; pub mod unpair; +pub use seq::StreamSeq; + mod codec; pub(crate) use codec::*; @@ -24,7 +34,7 @@ use self::{ encrypted_message::EncryptedMessage, handshake::HandshakeRecord, pair::PairRequestRecord, unpair::UnpairRecord, }; -use crate::QlError; +use crate::{PacketId, QlError}; pub(crate) type WireArchiveError = rkyv::rancor::Error; @@ -48,6 +58,21 @@ impl QlHeader { } } +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct ControlMeta { + pub packet_id: PacketId, + pub valid_until: u64, +} + +impl From<&ArchivedControlMeta> for ControlMeta { + fn from(value: &ArchivedControlMeta) -> Self { + Self { + packet_id: (&value.packet_id).into(), + valid_until: value.valid_until.to_native(), + } + } +} + #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] pub enum QlPayload { Handshake(HandshakeRecord), diff --git a/ql2/src/wire/pair/crypto.rs b/ql2/src/wire/pair/crypto.rs index 6b087bc3..8e4748df 100644 --- a/ql2/src/wire/pair/crypto.rs +++ b/ql2/src/wire/pair/crypto.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use bc_components::{ MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, SigningPublicKey, SymmetricKey, XID, }; @@ -7,14 +5,14 @@ use rkyv::{Archive, Serialize}; use super::{PairRequestBody, PairRequestRecord}; use crate::{ - platform::QlCrypto, + platform::{QlCrypto, QlIdentity}, wire::{ access_value, deserialize_value, encode_value, encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage, NONCE_SIZE}, - ensure_not_expired, mlkem_ciphertext_from_archived, now_secs, AsWireMlDsaPublicKey, - AsWireMlKemCiphertext, AsWireMlKemPublicKey, QlHeader, QlPayload, QlRecord, + ensure_not_expired, AsWireMlDsaPublicKey, AsWireMlKemCiphertext, AsWireMlKemPublicKey, + ControlMeta, QlHeader, QlPayload, QlRecord, }, - PacketId, QlError, + QlError, }; #[derive(Archive, Serialize)] @@ -27,8 +25,7 @@ struct PairingAad { #[derive(Archive, Serialize)] struct PairingProofData { aad: Vec, - packet_id: PacketId, - valid_until: u64, + meta: ControlMeta, #[rkyv(with = AsWireMlDsaPublicKey)] signing_pub_key: MLDSAPublicKey, #[rkyv(with = AsWireMlKemPublicKey)] @@ -36,32 +33,29 @@ struct PairingProofData { } pub fn build_pair_request( - platform: &impl QlCrypto, + identity: &QlIdentity, + crypto: &impl QlCrypto, recipient: XID, recipient_encapsulation_key: &MLKEMPublicKey, - packet_id: PacketId, - valid_for: Duration, + meta: ControlMeta, ) -> Result { let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); let header = QlHeader { - sender: platform.xid(), + sender: identity.xid, recipient, }; - let valid_until = now_secs().saturating_add(valid_for.as_secs()); - let signing_pub_key = platform.signing_public_key().clone(); - let sender_encapsulation_key = platform.encapsulation_public_key().clone(); + let signing_pub_key = identity.signing_public_key.clone(); + let sender_encapsulation_key = identity.encapsulation_public_key.clone(); let proof_data = pairing_proof_data( &header, &kem_ct, - packet_id, - valid_until, + &meta, &signing_pub_key, &sender_encapsulation_key, ); - let proof = platform.signing_private_key().sign(&proof_data); + let proof = identity.signing_private_key.sign(&proof_data); let body = PairRequestBody { - packet_id, - valid_until, + meta, signing_pub_key, encapsulation_pub_key: sender_encapsulation_key, proof, @@ -69,7 +63,7 @@ pub fn build_pair_request( let body_bytes = encode_value(&body); let aad = pairing_aad(&header, &kem_ct); let mut nonce = [0u8; NONCE_SIZE]; - platform.fill_random_bytes(&mut nonce); + crypto.fill_random_bytes(&mut nonce); let encrypted = EncryptedMessage::encrypt(&session_key, body_bytes, &aad, nonce); Ok(QlRecord { header, @@ -78,26 +72,25 @@ pub fn build_pair_request( } pub fn decrypt_pair_request( - platform: &impl QlCrypto, + identity: &QlIdentity, header: &QlHeader, request: &mut super::ArchivedPairRequestRecord, ) -> Result { - let kem_ct = mlkem_ciphertext_from_archived(&request.kem_ct)?; + let kem_ct = MLKEMCiphertext::try_from(&request.kem_ct)?; let aad = pairing_aad(header, &kem_ct); - let session_key = platform - .encapsulation_private_key() + let session_key = identity + .encapsulation_private_key .decapsulate_shared_secret(&kem_ct) .map_err(|_| QlError::InvalidPayload)?; let decrypted = decrypt_body(&session_key, &mut request.encrypted, &aad)?; - ensure_not_expired(decrypted.valid_until)?; + ensure_not_expired(decrypted.meta.valid_until)?; if XID::new(SigningPublicKey::MLDSA(decrypted.signing_pub_key.clone())) != header.sender { return Err(QlError::InvalidPayload); } let proof_data = pairing_proof_data( header, &kem_ct, - decrypted.packet_id, - decrypted.valid_until, + &decrypted.meta, &decrypted.signing_pub_key, &decrypted.encapsulation_pub_key, ); @@ -115,15 +108,13 @@ pub fn decrypt_pair_request( fn pairing_proof_data( header: &QlHeader, kem_ct: &MLKEMCiphertext, - packet_id: PacketId, - valid_until: u64, + meta: &ControlMeta, signing_pub_key: &MLDSAPublicKey, encapsulation_pub_key: &MLKEMPublicKey, ) -> Vec { encode_value(&PairingProofData { aad: pairing_aad(header, kem_ct), - packet_id, - valid_until, + meta: *meta, signing_pub_key: signing_pub_key.clone(), encapsulation_pub_key: encapsulation_pub_key.clone(), }) diff --git a/ql2/src/wire/pair/mod.rs b/ql2/src/wire/pair/mod.rs index 958462cb..7bb5f488 100644 --- a/ql2/src/wire/pair/mod.rs +++ b/ql2/src/wire/pair/mod.rs @@ -3,9 +3,8 @@ use rkyv::{Archive, Deserialize, Serialize}; use super::{ encrypted_message::EncryptedMessage, AsWireMlDsaPublicKey, AsWireMlDsaSignature, - AsWireMlKemCiphertext, AsWireMlKemPublicKey, + AsWireMlKemCiphertext, AsWireMlKemPublicKey, ControlMeta, }; -use crate::PacketId; mod crypto; pub use crypto::*; @@ -19,8 +18,7 @@ pub struct PairRequestRecord { #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct PairRequestBody { - pub packet_id: PacketId, - pub valid_until: u64, + pub meta: ControlMeta, #[rkyv(with = AsWireMlDsaPublicKey)] pub signing_pub_key: MLDSAPublicKey, #[rkyv(with = AsWireMlKemPublicKey)] diff --git a/ql2/src/wire/seq.rs b/ql2/src/wire/seq.rs new file mode 100644 index 00000000..c3cc1dd9 --- /dev/null +++ b/ql2/src/wire/seq.rs @@ -0,0 +1,97 @@ +use std::{cmp::Ordering, fmt}; + +use rkyv::{Archive, Deserialize, Serialize}; + +#[derive( + Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, +)] +#[repr(transparent)] +pub struct StreamSeq(pub u32); + +impl fmt::Display for StreamSeq { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&ArchivedStreamSeq> for StreamSeq { + fn from(value: &ArchivedStreamSeq) -> Self { + Self(value.0.to_native()) + } +} + +impl StreamSeq { + const HALF_RANGE: u32 = 1 << 31; + pub const START: Self = Self(1); + + pub fn next(self) -> Self { + Self(self.0.wrapping_add(1)) + } + + pub fn prev(self) -> Self { + Self(self.0.wrapping_sub(1)) + } + + pub fn add(self, delta: u32) -> Self { + Self(self.0.wrapping_add(delta)) + } + + pub fn serial_cmp(self, other: Self) -> Ordering { + if self == other { + return Ordering::Equal; + } + + let delta = self.0.wrapping_sub(other.0); + if delta < Self::HALF_RANGE { + Ordering::Greater + } else { + Ordering::Less + } + } + + pub fn serial_lt(self, other: Self) -> bool { + self.serial_cmp(other) == Ordering::Less + } + + pub fn serial_lte(self, other: Self) -> bool { + !self.serial_gt(other) + } + + pub fn serial_gt(self, other: Self) -> bool { + self.serial_cmp(other) == Ordering::Greater + } + + pub fn forward_distance_to(self, other: Self) -> Option { + match other.serial_cmp(self) { + Ordering::Less => None, + Ordering::Equal => Some(0), + Ordering::Greater => Some(other.0.wrapping_sub(self.0)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn stream_seq_serial_order_wraps() { + assert!(StreamSeq(0).serial_gt(StreamSeq(u32::MAX))); + assert!(StreamSeq(1).serial_gt(StreamSeq(u32::MAX))); + assert!(StreamSeq(u32::MAX).serial_lt(StreamSeq(0))); + assert!(StreamSeq(u32::MAX - 1).serial_lt(StreamSeq(1))); + } + + #[test] + fn stream_seq_forward_distance_wraps() { + assert_eq!( + StreamSeq(u32::MAX - 1).forward_distance_to(StreamSeq(1)), + Some(3) + ); + assert_eq!( + StreamSeq(u32::MAX).forward_distance_to(StreamSeq(2)), + Some(3) + ); + assert_eq!(StreamSeq(1).forward_distance_to(StreamSeq(u32::MAX)), None); + } +} diff --git a/ql2/src/wire/stream/crypto.rs b/ql2/src/wire/stream/crypto.rs index 620db25f..48ea3522 100644 --- a/ql2/src/wire/stream/crypto.rs +++ b/ql2/src/wire/stream/crypto.rs @@ -13,11 +13,11 @@ use crate::{ pub fn encrypt_stream( header: QlHeader, session_key: &SymmetricKey, - body: StreamBody, + body: &StreamBody, nonce: [u8; NONCE_SIZE], ) -> QlRecord { let aad = header.aad(); - let body_bytes = encode_value(&body); + let body_bytes = encode_value(body); let encrypted = EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce); QlRecord { header, @@ -34,6 +34,6 @@ pub(crate) fn decrypt_stream( let plaintext = encrypted.decrypt(session_key, &aad)?; let body = access_value::(plaintext)?; let body = deserialize_value(body)?; - ensure_not_expired(body.valid_until)?; + ensure_not_expired(body.valid_until())?; Ok(body) } diff --git a/ql2/src/wire/stream/mod.rs b/ql2/src/wire/stream/mod.rs index 8db6fbb5..8a2ff7d7 100644 --- a/ql2/src/wire/stream/mod.rs +++ b/ql2/src/wire/stream/mod.rs @@ -1,27 +1,68 @@ use rkyv::{Archive, Deserialize, Serialize}; -use crate::{PacketId, StreamId}; +use crate::{wire::StreamSeq, StreamId}; mod crypto; pub use crypto::*; #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct StreamBody { - pub packet_id: PacketId, +pub enum StreamBody { + Ack(StreamAckBody), + Message(StreamMessage), +} + +impl StreamBody { + pub fn stream_id(&self) -> StreamId { + match self { + Self::Ack(StreamAckBody { stream_id, .. }) => *stream_id, + Self::Message(message) => message.frame.stream_id(), + } + } + + pub fn valid_until(&self) -> u64 { + match self { + Self::Ack(body) => body.valid_until, + Self::Message(message) => message.valid_until, + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamAckBody { + pub stream_id: StreamId, + pub ack: StreamAck, pub valid_until: u64, - pub packet_ack: Option, - pub frame: Option, +} + +impl From<&ArchivedStreamAckBody> for StreamAckBody { + fn from(value: &ArchivedStreamAckBody) -> Self { + Self { + stream_id: (&value.stream_id).into(), + ack: (&value.ack).into(), + valid_until: value.valid_until.to_native(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct StreamMessage { + pub tx_seq: StreamSeq, + pub ack: Option, + pub valid_until: u64, + pub frame: StreamFrame, } #[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct PacketAck { - pub packet_id: PacketId, +pub struct StreamAck { + pub base: StreamSeq, + pub bitmap: u8, } -impl From<&ArchivedPacketAck> for PacketAck { - fn from(value: &ArchivedPacketAck) -> Self { +impl From<&ArchivedStreamAck> for StreamAck { + fn from(value: &ArchivedStreamAck) -> Self { Self { - packet_id: (&value.packet_id).into(), + base: (&value.base).into(), + bitmap: value.bitmap, } } } @@ -32,8 +73,6 @@ pub enum StreamFrame { Accept(StreamFrameAccept), Reject(StreamFrameReject), Data(StreamFrameData), - Credit(StreamFrameCredit), - Finish(StreamFrameFinish), Reset(StreamFrameReset), } @@ -44,18 +83,31 @@ impl StreamFrame { | StreamFrame::Accept(StreamFrameAccept { stream_id, .. }) | StreamFrame::Reject(StreamFrameReject { stream_id, .. }) | StreamFrame::Data(StreamFrameData { stream_id, .. }) - | StreamFrame::Credit(StreamFrameCredit { stream_id, .. }) - | StreamFrame::Finish(StreamFrameFinish { stream_id, .. }) | StreamFrame::Reset(StreamFrameReset { stream_id, .. }) => *stream_id, } } } +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct BodyChunk { + pub bytes: Vec, + pub fin: bool, +} + +impl From<&ArchivedBodyChunk> for BodyChunk { + fn from(value: &ArchivedBodyChunk) -> Self { + Self { + bytes: value.bytes.as_slice().to_vec(), + fin: value.fin, + } + } +} + #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct StreamFrameOpen { pub stream_id: StreamId, pub request_head: Vec, - pub response_max_offset: u64, + pub request_prefix: Option, } impl From<&ArchivedStreamFrameOpen> for StreamFrameOpen { @@ -63,7 +115,7 @@ impl From<&ArchivedStreamFrameOpen> for StreamFrameOpen { Self { stream_id: (&value.stream_id).into(), request_head: value.request_head.as_slice().to_vec(), - response_max_offset: value.response_max_offset.to_native(), + request_prefix: value.request_prefix.as_ref().map(Into::into), } } } @@ -72,7 +124,7 @@ impl From<&ArchivedStreamFrameOpen> for StreamFrameOpen { pub struct StreamFrameAccept { pub stream_id: StreamId, pub response_head: Vec, - pub request_max_offset: u64, + pub response_prefix: Option, } impl From<&ArchivedStreamFrameAccept> for StreamFrameAccept { @@ -80,7 +132,7 @@ impl From<&ArchivedStreamFrameAccept> for StreamFrameAccept { Self { stream_id: (&value.stream_id).into(), response_head: value.response_head.as_slice().to_vec(), - request_max_offset: value.request_max_offset.to_native(), + response_prefix: value.response_prefix.as_ref().map(Into::into), } } } @@ -100,31 +152,11 @@ impl From<&ArchivedStreamFrameReject> for StreamFrameReject { } } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamFrameCredit { - pub stream_id: StreamId, - pub dir: Direction, - pub recv_offset: u64, - pub max_offset: u64, -} - -impl From<&ArchivedStreamFrameCredit> for StreamFrameCredit { - fn from(value: &ArchivedStreamFrameCredit) -> Self { - Self { - stream_id: (&value.stream_id).into(), - dir: (&value.dir).into(), - recv_offset: value.recv_offset.to_native(), - max_offset: value.max_offset.to_native(), - } - } -} - #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct StreamFrameData { pub stream_id: StreamId, pub dir: Direction, - pub offset: u64, - pub bytes: Vec, + pub chunk: BodyChunk, } impl From<&ArchivedStreamFrameData> for StreamFrameData { @@ -132,23 +164,7 @@ impl From<&ArchivedStreamFrameData> for StreamFrameData { Self { stream_id: (&value.stream_id).into(), dir: (&value.dir).into(), - offset: value.offset.to_native(), - bytes: value.bytes.as_slice().to_vec(), - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamFrameFinish { - pub stream_id: StreamId, - pub dir: Direction, -} - -impl From<&ArchivedStreamFrameFinish> for StreamFrameFinish { - fn from(value: &ArchivedStreamFrameFinish) -> Self { - Self { - stream_id: (&value.stream_id).into(), - dir: (&value.dir).into(), + chunk: (&value.chunk).into(), } } } @@ -156,7 +172,7 @@ impl From<&ArchivedStreamFrameFinish> for StreamFrameFinish { #[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] pub struct StreamFrameReset { pub stream_id: StreamId, - pub dir: ResetTarget, + pub target: ResetTarget, pub code: ResetCode, } @@ -164,7 +180,7 @@ impl From<&ArchivedStreamFrameReset> for StreamFrameReset { fn from(value: &ArchivedStreamFrameReset) -> Self { Self { stream_id: (&value.stream_id).into(), - dir: (&value.dir).into(), + target: (&value.target).into(), code: (&value.code).into(), } } diff --git a/ql2/src/wire/unpair/crypto.rs b/ql2/src/wire/unpair/crypto.rs index 4339accd..3225d1a7 100644 --- a/ql2/src/wire/unpair/crypto.rs +++ b/ql2/src/wire/unpair/crypto.rs @@ -1,38 +1,27 @@ -use bc_components::MLDSAPublicKey; +use bc_components::{MLDSAPublicKey, MLDSASignature}; use rkyv::{Archive, Serialize}; use super::UnpairRecord; use crate::{ - platform::QlCrypto, - wire::{encode_value, mldsa_signature_from_archived, now_secs, QlHeader, QlPayload, QlRecord}, - PacketId, QlError, + platform::QlIdentity, + wire::{encode_value, ensure_not_expired, ControlMeta, QlHeader, QlPayload, QlRecord}, + QlError, }; #[derive(Archive, Serialize)] struct UnpairProofData { domain: Vec, header: QlHeader, - packet_id: PacketId, - valid_until: u64, + meta: ControlMeta, } -pub fn build_unpair_record( - platform: &impl QlCrypto, - header: QlHeader, - packet_id: PacketId, - valid_until: u64, -) -> QlRecord { - let signature = - platform - .signing_private_key() - .sign(unpair_proof_data(&header, packet_id, valid_until)); +pub fn build_unpair_record(identity: &QlIdentity, header: QlHeader, meta: ControlMeta) -> QlRecord { + let signature = identity + .signing_private_key + .sign(unpair_proof_data(&header, &meta)); QlRecord { header, - payload: QlPayload::Unpair(UnpairRecord { - packet_id, - valid_until, - signature, - }), + payload: QlPayload::Unpair(UnpairRecord { meta, signature }), } } @@ -41,13 +30,10 @@ pub fn verify_unpair_record( record: &super::ArchivedUnpairRecord, signing_key: &MLDSAPublicKey, ) -> Result<(), QlError> { - let packet_id = (&record.packet_id).into(); - let valid_until = record.valid_until.to_native(); - let signature = mldsa_signature_from_archived(&record.signature)?; - if now_secs() > valid_until { - return Err(QlError::InvalidPayload); - } - let proof_data = unpair_proof_data(header, packet_id, valid_until); + let meta: ControlMeta = (&record.meta).into(); + let signature = MLDSASignature::try_from(&record.signature)?; + ensure_not_expired(meta.valid_until)?; + let proof_data = unpair_proof_data(header, &meta); if signing_key.verify(&signature, &proof_data).unwrap_or(false) { Ok(()) } else { @@ -55,11 +41,10 @@ pub fn verify_unpair_record( } } -fn unpair_proof_data(header: &QlHeader, packet_id: PacketId, valid_until: u64) -> Vec { +fn unpair_proof_data(header: &QlHeader, meta: &ControlMeta) -> Vec { encode_value(&UnpairProofData { domain: b"ql-unpair-v1".to_vec(), header: header.clone(), - packet_id, - valid_until, + meta: *meta, }) } diff --git a/ql2/src/wire/unpair/mod.rs b/ql2/src/wire/unpair/mod.rs index e17bcf16..62781e8f 100644 --- a/ql2/src/wire/unpair/mod.rs +++ b/ql2/src/wire/unpair/mod.rs @@ -1,16 +1,14 @@ use bc_components::MLDSASignature; use rkyv::{Archive, Deserialize, Serialize}; -use super::AsWireMlDsaSignature; -use crate::PacketId; +use super::{AsWireMlDsaSignature, ControlMeta}; mod crypto; pub use crypto::*; #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct UnpairRecord { - pub packet_id: PacketId, - pub valid_until: u64, + pub meta: ControlMeta, #[rkyv(with = AsWireMlDsaSignature)] pub signature: MLDSASignature, } From 6ea8f03631caaf72468ffeeaf6ba8a8b1380de50 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:43 -0400 Subject: [PATCH 007/304] ql: modularize engine internals, improve session reliability, and split crates --- Cargo.lock | 28 +- Cargo.toml | 2 +- {ql2 => ql-engine}/Cargo.toml | 9 +- .../src/engine/implementation/handshake.rs | 616 +++++ ql-engine/src/engine/implementation/mod.rs | 893 +++++++ ql-engine/src/engine/implementation/peer.rs | 153 ++ ql-engine/src/engine/implementation/stream.rs | 822 ++++++ ql-engine/src/engine/mod.rs | 171 ++ {ql2 => ql-engine}/src/engine/replay_cache.rs | 1 + {ql2 => ql-engine}/src/engine/ring.rs | 8 +- {ql2 => ql-engine}/src/engine/state.rs | 373 ++- ql-engine/src/engine/stream.rs | 563 ++++ ql-engine/src/engine/tests/handshake.rs | 824 ++++++ ql-engine/src/engine/tests/liveness.rs | 87 + ql-engine/src/engine/tests/mod.rs | 410 +++ ql-engine/src/engine/tests/peer.rs | 42 + ql-engine/src/engine/tests/stream.rs | 1569 +++++++++++ .../platform.rs => ql-engine/src/identity.rs | 22 - {ql2 => ql-engine}/src/lib.rs | 20 +- {ql2 => ql-engine}/src/wire/codec.rs | 0 .../src/wire/encrypted_message.rs | 0 .../src/wire/handshake/crypto.rs | 77 +- {ql2 => ql-engine}/src/wire/handshake/mod.rs | 16 +- .../src/wire/heartbeat/crypto.rs | 0 {ql2 => ql-engine}/src/wire/heartbeat/mod.rs | 0 {ql2/src => ql-engine/src/wire}/id.rs | 18 +- ql-engine/src/wire/mod.rs | 522 ++++ {ql2 => ql-engine}/src/wire/pair/crypto.rs | 3 +- {ql2 => ql-engine}/src/wire/pair/mod.rs | 0 {ql2 => ql-engine}/src/wire/seq.rs | 0 {ql2 => ql-engine}/src/wire/stream/crypto.rs | 0 {ql2 => ql-engine}/src/wire/stream/mod.rs | 144 +- {ql2 => ql-engine}/src/wire/unpair/crypto.rs | 2 +- {ql2 => ql-engine}/src/wire/unpair/mod.rs | 0 {ql => ql-runtime}/Cargo.toml | 8 +- ql-runtime/src/command.rs | 30 + ql-runtime/src/driver.rs | 564 ++++ ql-runtime/src/handle.rs | 293 +++ .../runtime/mod.rs => ql-runtime/src/lib.rs | 37 +- ql-runtime/src/platform.rs | 22 + {ql2 => ql-runtime}/src/rpc/client.rs | 0 {ql2 => ql-runtime}/src/rpc/mod.rs | 0 {ql2 => ql-runtime}/src/rpc/modality.rs | 0 {ql2 => ql-runtime}/src/rpc/server.rs | 0 ql-runtime/src/tests/handshake.rs | 125 + ql-runtime/src/tests/heartbeat.rs | 217 ++ ql-runtime/src/tests/mod.rs | 389 +++ ql-runtime/src/tests/stream.rs | 386 +++ ql-runtime/src/tests/unpair.rs | 76 + ql/README.md | 143 - ql/ql-v2.presenterm.md | 285 -- ql/src/id.rs | 51 - ql/src/lib.rs | 68 - ql/src/platform.rs | 37 - ql/src/router.rs | 377 --- ql/src/runtime/core.rs | 2201 ---------------- ql/src/runtime/handle.rs | 492 ---- ql/src/runtime/internal.rs | 545 ---- ql/src/runtime/mod.rs | 180 -- ql/src/runtime/replay_cache.rs | 181 -- ql/src/tests/handshake.rs | 292 --- ql/src/tests/heartbeat.rs | 641 ----- ql/src/tests/mod.rs | 660 ----- ql/src/tests/persistence.rs | 228 -- ql/src/tests/requests.rs | 446 ---- ql/src/tests/streams.rs | 552 ---- ql/src/tests/unpair.rs | 160 -- ql/src/wire/handshake/crypto.rs | 131 - ql/src/wire/handshake/mod.rs | 122 - ql/src/wire/heartbeat/crypto.rs | 40 - ql/src/wire/heartbeat/mod.rs | 35 - ql/src/wire/message/crypto.rs | 78 - ql/src/wire/message/mod.rs | 144 - ql/src/wire/mod.rs | 226 -- ql/src/wire/pair/crypto.rs | 124 - ql/src/wire/pair/mod.rs | 71 - ql/src/wire/transfer/crypto.rs | 42 - ql/src/wire/transfer/mod.rs | 194 -- ql/src/wire/unpair/crypto.rs | 58 - ql/src/wire/unpair/mod.rs | 39 - ql2/README.md | 143 - ql2/ql-v2.presenterm.md | 285 -- ql2/src/engine/mod.rs | 2323 ----------------- ql2/src/engine/stream.rs | 435 --- ql2/src/engine/tests.rs | 1360 ---------- ql2/src/runtime/command.rs | 50 - ql2/src/runtime/driver.rs | 721 ----- ql2/src/runtime/handle.rs | 388 --- ql2/src/tests/handshake.rs | 99 - ql2/src/tests/heartbeat.rs | 455 ---- ql2/src/tests/mod.rs | 1027 -------- ql2/src/tests/persistence.rs | 139 - ql2/src/tests/rpc.rs | 264 -- ql2/src/tests/stream.rs | 1685 ------------ ql2/src/tests/unpair.rs | 137 - ql2/src/wire/mod.rs | 153 -- 96 files changed, 9118 insertions(+), 18931 deletions(-) rename {ql2 => ql-engine}/Cargo.toml (65%) create mode 100644 ql-engine/src/engine/implementation/handshake.rs create mode 100644 ql-engine/src/engine/implementation/mod.rs create mode 100644 ql-engine/src/engine/implementation/peer.rs create mode 100644 ql-engine/src/engine/implementation/stream.rs create mode 100644 ql-engine/src/engine/mod.rs rename {ql2 => ql-engine}/src/engine/replay_cache.rs (98%) rename {ql2 => ql-engine}/src/engine/ring.rs (98%) rename {ql2 => ql-engine}/src/engine/state.rs (55%) create mode 100644 ql-engine/src/engine/stream.rs create mode 100644 ql-engine/src/engine/tests/handshake.rs create mode 100644 ql-engine/src/engine/tests/liveness.rs create mode 100644 ql-engine/src/engine/tests/mod.rs create mode 100644 ql-engine/src/engine/tests/peer.rs create mode 100644 ql-engine/src/engine/tests/stream.rs rename ql2/src/platform.rs => ql-engine/src/identity.rs (55%) rename {ql2 => ql-engine}/src/lib.rs (69%) rename {ql2 => ql-engine}/src/wire/codec.rs (100%) rename {ql2 => ql-engine}/src/wire/encrypted_message.rs (100%) rename {ql2 => ql-engine}/src/wire/handshake/crypto.rs (82%) rename {ql2 => ql-engine}/src/wire/handshake/mod.rs (78%) rename {ql2 => ql-engine}/src/wire/heartbeat/crypto.rs (100%) rename {ql2 => ql-engine}/src/wire/heartbeat/mod.rs (100%) rename {ql2/src => ql-engine/src/wire}/id.rs (68%) create mode 100644 ql-engine/src/wire/mod.rs rename {ql2 => ql-engine}/src/wire/pair/crypto.rs (98%) rename {ql2 => ql-engine}/src/wire/pair/mod.rs (100%) rename {ql2 => ql-engine}/src/wire/seq.rs (100%) rename {ql2 => ql-engine}/src/wire/stream/crypto.rs (100%) rename {ql2 => ql-engine}/src/wire/stream/mod.rs (53%) rename {ql2 => ql-engine}/src/wire/unpair/crypto.rs (98%) rename {ql2 => ql-engine}/src/wire/unpair/mod.rs (100%) rename {ql => ql-runtime}/Cargo.toml (75%) create mode 100644 ql-runtime/src/command.rs create mode 100644 ql-runtime/src/driver.rs create mode 100644 ql-runtime/src/handle.rs rename ql2/src/runtime/mod.rs => ql-runtime/src/lib.rs (52%) create mode 100644 ql-runtime/src/platform.rs rename {ql2 => ql-runtime}/src/rpc/client.rs (100%) rename {ql2 => ql-runtime}/src/rpc/mod.rs (100%) rename {ql2 => ql-runtime}/src/rpc/modality.rs (100%) rename {ql2 => ql-runtime}/src/rpc/server.rs (100%) create mode 100644 ql-runtime/src/tests/handshake.rs create mode 100644 ql-runtime/src/tests/heartbeat.rs create mode 100644 ql-runtime/src/tests/mod.rs create mode 100644 ql-runtime/src/tests/stream.rs create mode 100644 ql-runtime/src/tests/unpair.rs delete mode 100644 ql/README.md delete mode 100644 ql/ql-v2.presenterm.md delete mode 100644 ql/src/id.rs delete mode 100644 ql/src/lib.rs delete mode 100644 ql/src/platform.rs delete mode 100644 ql/src/router.rs delete mode 100644 ql/src/runtime/core.rs delete mode 100644 ql/src/runtime/handle.rs delete mode 100644 ql/src/runtime/internal.rs delete mode 100644 ql/src/runtime/mod.rs delete mode 100644 ql/src/runtime/replay_cache.rs delete mode 100644 ql/src/tests/handshake.rs delete mode 100644 ql/src/tests/heartbeat.rs delete mode 100644 ql/src/tests/mod.rs delete mode 100644 ql/src/tests/persistence.rs delete mode 100644 ql/src/tests/requests.rs delete mode 100644 ql/src/tests/streams.rs delete mode 100644 ql/src/tests/unpair.rs delete mode 100644 ql/src/wire/handshake/crypto.rs delete mode 100644 ql/src/wire/handshake/mod.rs delete mode 100644 ql/src/wire/heartbeat/crypto.rs delete mode 100644 ql/src/wire/heartbeat/mod.rs delete mode 100644 ql/src/wire/message/crypto.rs delete mode 100644 ql/src/wire/message/mod.rs delete mode 100644 ql/src/wire/mod.rs delete mode 100644 ql/src/wire/pair/crypto.rs delete mode 100644 ql/src/wire/pair/mod.rs delete mode 100644 ql/src/wire/transfer/crypto.rs delete mode 100644 ql/src/wire/transfer/mod.rs delete mode 100644 ql/src/wire/unpair/crypto.rs delete mode 100644 ql/src/wire/unpair/mod.rs delete mode 100644 ql2/README.md delete mode 100644 ql2/ql-v2.presenterm.md delete mode 100644 ql2/src/engine/mod.rs delete mode 100644 ql2/src/engine/stream.rs delete mode 100644 ql2/src/engine/tests.rs delete mode 100644 ql2/src/runtime/command.rs delete mode 100644 ql2/src/runtime/driver.rs delete mode 100644 ql2/src/runtime/handle.rs delete mode 100644 ql2/src/tests/handshake.rs delete mode 100644 ql2/src/tests/heartbeat.rs delete mode 100644 ql2/src/tests/mod.rs delete mode 100644 ql2/src/tests/persistence.rs delete mode 100644 ql2/src/tests/rpc.rs delete mode 100644 ql2/src/tests/stream.rs delete mode 100644 ql2/src/tests/unpair.rs delete mode 100644 ql2/src/wire/mod.rs diff --git a/Cargo.lock b/Cargo.lock index ff756917..d04be45b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1789,6 +1789,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "piper" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c835479a4443ded371d6c535cbfd8d31ad92c5d23ae9770a61bc155e4992a3c1" +dependencies = [ + "atomic-waker", + "fastrand", + "futures-io", +] + [[package]] name = "pkcs1" version = "0.7.5" @@ -1949,31 +1960,26 @@ dependencies = [ ] [[package]] -name = "ql" +name = "ql-engine" version = "0.1.0" dependencies = [ - "async-channel", "bc-components", + "chacha20poly1305", "dcbor", - "futures-lite", - "oneshot", + "rkyv", "thiserror", - "tokio", ] [[package]] -name = "ql2" +name = "ql-runtime" version = "0.1.0" dependencies = [ "async-channel", - "atomic-waker", "bc-components", - "chacha20poly1305", - "dcbor", "futures-lite", "oneshot", - "rkyv", - "thiserror", + "piper", + "ql-engine", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index f5f3fbec..73284cea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["api", "backup-shard", "btp", "ql", "ql2", "quantum-link-macros"] +members = ["api", "backup-shard", "btp", "ql-engine", "ql-runtime", "quantum-link-macros"] [workspace.package] homepage = "https://github.com/Foundation-Devices/foundation-api" diff --git a/ql2/Cargo.toml b/ql-engine/Cargo.toml similarity index 65% rename from ql2/Cargo.toml rename to ql-engine/Cargo.toml index d268ca9b..b6a9d09d 100644 --- a/ql2/Cargo.toml +++ b/ql-engine/Cargo.toml @@ -1,22 +1,15 @@ [package] -name = "ql2" +name = "ql-engine" version = "0.1.0" edition = "2021" description = "Quantum Link v2 duplex stream prototype" license = "Proprietary" [dependencies] -async-channel = { version = "2.5" } -atomic-waker = { version = "1.1" } bc-components = { version = "0.28.0", default-features = false, features = [ "pqcrypto", ] } chacha20poly1305 = { version = "0.10.1" } dcbor = { version = "0.23.3" } -futures-lite = { version = "2.5" } -oneshot = { version = "0.1.11" } rkyv = { version = "0.8", default-features = false, features = ["std", "bytecheck", "little_endian", "unaligned", "pointer_width_32"] } thiserror = { version = "2" } - -[dev-dependencies] -tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } diff --git a/ql-engine/src/engine/implementation/handshake.rs b/ql-engine/src/engine/implementation/handshake.rs new file mode 100644 index 00000000..3e3db21a --- /dev/null +++ b/ql-engine/src/engine/implementation/handshake.rs @@ -0,0 +1,616 @@ +use super::*; +use crate::{ + engine::{EngineConfig, EngineState, KeepAliveState}, + identity::QlIdentity, + wire::{handshake::HandshakeRecord, QlPayload, QlRecord}, +}; + +#[derive(Debug)] +enum HelloAction { + StartResponder, + ResendReply { + token: Token, + reply: wire::handshake::HelloReply, + deadline: Instant, + }, + Ignore, +} + +enum HelloReplyAction { + Advance { + hello: wire::handshake::Hello, + responder_signing_key: bc_components::MLDSAPublicKey, + initiator_secret: SymmetricKey, + }, + ResendConfirm { + token: Token, + confirm: wire::handshake::Confirm, + deadline: Instant, + }, +} + +pub fn handle_connect( + engine: &mut Engine, + now: Instant, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, +) { + let Some(_) = engine.peer.as_ref() else { + return; + }; + let started = { + let config = &engine.config; + let identity = &engine.identity; + let state = &mut engine.state; + let Some(peer_record) = engine.peer.as_mut() else { + return; + }; + start_initiator_handshake(config, identity, state, peer_record, now, crypto) + }; + if started { + engine.emit_peer_status(emit); + } +} + +pub fn handle_hello( + engine: &mut Engine, + now: Instant, + peer: XID, + hello: &wire::handshake::ArchivedHello, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, +) { + let action = match engine.peer.as_ref() { + Some(entry) => { + if wire::handshake::verify_hello(peer, engine.identity.xid, &entry.signing_key, hello) + .is_err() + { + return; + } + match &entry.session { + PeerSession::Initiator { + hello: local_hello, .. + } => { + if peer_hello_wins(local_hello, engine.identity.xid, hello, peer) { + HelloAction::StartResponder + } else { + HelloAction::Ignore + } + } + PeerSession::Responder { + handshake_token, + hello: stored, + reply, + deadline, + stage: HandshakeResponder::WaitingConfirm { .. }, + } => { + if same_hello(stored, hello) { + HelloAction::ResendReply { + token: *handshake_token, + reply: reply.clone(), + deadline: *deadline, + } + } else { + HelloAction::StartResponder + } + } + PeerSession::Responder { .. } + | PeerSession::Disconnected + | PeerSession::Connected { .. } => HelloAction::StartResponder, + } + } + None => return, + }; + + match action { + HelloAction::StartResponder => { + let meta: ControlMeta = (&hello.meta).into(); + if engine.is_replayed_control(peer, meta) { + return; + } + let changed = { + let config = &engine.config; + let identity = &engine.identity; + let state = &mut engine.state; + let Some(peer_record) = engine.peer.as_mut() else { + return; + }; + start_responder_handshake( + config, + identity, + state, + peer_record, + now, + peer, + hello, + crypto, + ) + }; + if changed { + engine.emit_peer_status(emit); + } + } + HelloAction::ResendReply { + token, + reply, + deadline, + } => { + if engine.handshake_write_pending(token) { + return; + } + engine.clear_handshake_retry_at(token); + enqueue_handshake_record( + engine, + token, + deadline, + peer, + HandshakeRecord::HelloReply(reply), + ); + } + HelloAction::Ignore => {} + } +} + +pub fn handle_hello_reply( + engine: &mut Engine, + now: Instant, + peer: XID, + reply: &wire::handshake::ArchivedHelloReply, + _emit: &mut impl OutputFn, +) { + let action = { + let Some(peer_record) = engine.peer.as_ref() else { + return; + }; + let PeerSession::Initiator { + handshake_token, + hello, + session_key, + stage, + deadline, + .. + } = &peer_record.session + else { + return; + }; + match stage { + HandshakeInitiator::WaitingHelloReply { .. } => HelloReplyAction::Advance { + hello: hello.clone(), + responder_signing_key: peer_record.signing_key.clone(), + initiator_secret: session_key.clone(), + }, + HandshakeInitiator::WaitingReady { + reply: stored_reply, + confirm, + .. + } if same_reply(stored_reply, reply) => HelloReplyAction::ResendConfirm { + token: *handshake_token, + confirm: confirm.clone(), + deadline: *deadline, + }, + HandshakeInitiator::WaitingReady { .. } => return, + } + }; + match action { + HelloReplyAction::Advance { + hello, + responder_signing_key, + initiator_secret, + } => { + let confirm_meta = engine.next_control_meta(engine.config.handshake_timeout); + let (confirm, session_key) = match wire::handshake::build_confirm( + &engine.identity, + peer, + &responder_signing_key, + &hello, + reply, + &initiator_secret, + confirm_meta, + ) { + Ok(result) => result, + Err(_) => return, + }; + let reply_meta: ControlMeta = (&reply.meta).into(); + if engine.is_replayed_control(peer, reply_meta) { + return; + } + let Ok(reply) = wire::deserialize_value(reply) else { + return; + }; + let deadline = now + engine.config.handshake_timeout; + let token = engine.state.next_token(); + let Some(peer_record) = engine.peer.as_mut() else { + return; + }; + peer_record.session = PeerSession::Initiator { + handshake_token: token, + hello, + session_key, + deadline, + stage: HandshakeInitiator::WaitingReady { + reply, + confirm: confirm.clone(), + retry_count: 0, + retry_at: None, + }, + }; + enqueue_handshake_record( + engine, + token, + deadline, + peer, + HandshakeRecord::Confirm(confirm), + ); + } + HelloReplyAction::ResendConfirm { + token, + confirm, + deadline, + } => { + if engine.handshake_write_pending(token) { + return; + } + engine.clear_handshake_retry_at(token); + enqueue_handshake_record( + engine, + token, + deadline, + peer, + HandshakeRecord::Confirm(confirm), + ); + } + } +} + +pub fn handle_confirm( + engine: &mut Engine, + now: Instant, + peer: XID, + confirm: &wire::handshake::ArchivedConfirm, + crypto: &impl QlCrypto, + _emit: &mut impl OutputFn, +) { + if let Some((ready, deadline, token)) = current_ready_resend(engine, now, peer, confirm) { + if engine.handshake_write_pending(token) { + return; + } + enqueue_handshake_record(engine, token, deadline, peer, HandshakeRecord::Ready(ready)); + return; + } + if let Some(ready) = recent_ready_resend(engine, now, peer, confirm) { + let record = QlRecord { + header: QlHeader { + sender: engine.identity.xid, + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::Ready(ready)), + }; + engine + .state + .enqueue_control(&engine.config, true, wire::encode_record(&record)); + return; + } + + let res = { + let Some(peer_record) = engine.peer.as_ref() else { + return; + }; + let PeerSession::Responder { + hello, + reply, + stage, + .. + } = &peer_record.session + else { + return; + }; + let HandshakeResponder::WaitingConfirm { secrets, .. } = stage else { + return; + }; + + wire::handshake::finalize_confirm( + peer, + engine.identity.xid, + &peer_record.signing_key, + hello, + reply, + confirm, + secrets, + ) + .map(|session_key| (hello.clone(), reply.clone(), session_key)) + }; + + match res { + Ok((hello, reply, session_key)) => { + let meta: ControlMeta = (&confirm.meta).into(); + if engine.is_replayed_control(peer, meta) { + return; + } + let deadline = now + engine.config.handshake_timeout; + let ready_meta = engine.next_control_meta(engine.config.handshake_timeout); + let ready = wire::handshake::build_ready( + QlHeader { + sender: engine.identity.xid, + recipient: peer, + }, + &session_key, + ready_meta, + encrypted_message_nonce(crypto), + ); + let token = engine.state.next_token(); + if let Some(peer_record) = engine.peer.as_mut() { + peer_record.session = PeerSession::Responder { + handshake_token: token, + hello, + reply, + deadline, + stage: HandshakeResponder::SendingReady { + session_key, + ready: ready.clone(), + }, + }; + } + enqueue_handshake_record(engine, token, deadline, peer, HandshakeRecord::Ready(ready)); + } + Err(_) => {} + } +} + +pub fn handle_ready( + engine: &mut Engine, + now: Instant, + peer: XID, + header: &QlHeader, + ready: &mut wire::handshake::ArchivedReady, + emit: &mut impl OutputFn, +) { + let session_key = { + let Some(peer_record) = engine.peer.as_ref() else { + return; + }; + let PeerSession::Initiator { + session_key, stage, .. + } = &peer_record.session + else { + return; + }; + let HandshakeInitiator::WaitingReady { .. } = stage else { + return; + }; + session_key.clone() + }; + + let Ok(body) = wire::handshake::decrypt_ready(header, ready, &session_key) else { + return; + }; + if engine.is_replayed_control(peer, body.meta) { + return; + } + + if let Some(peer_record) = engine.peer.as_mut() { + peer_record.session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + recent_ready: None, + }; + } + engine.record_activity(now); + engine.emit_peer_status(emit); +} + +fn start_initiator_handshake( + config: &EngineConfig, + identity: &QlIdentity, + state: &mut EngineState, + peer_record: &mut PeerRecord, + now: Instant, + crypto: &impl QlCrypto, +) -> bool { + if !matches!(peer_record.session, PeerSession::Disconnected) { + return false; + } + + let meta = ControlMeta { + packet_id: state.next_packet_id(), + valid_until: wire::now_secs() + config.handshake_timeout.as_secs(), + }; + let peer = peer_record.peer; + let Ok((hello, session_key)) = + wire::handshake::build_hello(identity, crypto, peer, &peer_record.encapsulation_key, meta) + else { + return false; + }; + + let deadline = now + config.handshake_timeout; + let token = state.next_token(); + peer_record.session = PeerSession::Initiator { + handshake_token: token, + hello: hello.clone(), + session_key, + deadline, + stage: HandshakeInitiator::WaitingHelloReply { + retry_count: 0, + retry_at: None, + }, + }; + let record = QlRecord { + header: QlHeader { + sender: identity.xid, + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), + }; + state.enqueue_handshake_message(config, token, deadline, wire::encode_record(&record)); + true +} + +fn start_responder_handshake( + config: &EngineConfig, + identity: &QlIdentity, + state: &mut EngineState, + peer_record: &mut PeerRecord, + now: Instant, + peer: XID, + hello: &wire::handshake::ArchivedHello, + crypto: &impl QlCrypto, +) -> bool { + let reply_meta = ControlMeta { + packet_id: state.next_packet_id(), + valid_until: wire::now_secs() + config.handshake_timeout.as_secs(), + }; + let (reply, secrets) = match wire::handshake::respond_hello( + identity, + crypto, + peer, + &peer_record.signing_key, + &peer_record.encapsulation_key, + hello, + reply_meta, + ) { + Ok(result) => result, + Err(_) => { + peer_record.session = PeerSession::Disconnected; + return true; + } + }; + let Ok(hello) = wire::deserialize_value(hello) else { + peer_record.session = PeerSession::Disconnected; + return true; + }; + + let deadline = now + config.handshake_timeout; + let token = state.next_token(); + peer_record.session = PeerSession::Responder { + handshake_token: token, + hello, + reply: reply.clone(), + deadline, + stage: HandshakeResponder::WaitingConfirm { + secrets, + retry_count: 0, + retry_at: None, + }, + }; + + let record = QlRecord { + header: QlHeader { + sender: identity.xid, + recipient: peer, + }, + payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), + }; + state.enqueue_handshake_message(config, token, deadline, wire::encode_record(&record)); + true +} + +pub(super) fn enqueue_handshake_record( + engine: &mut Engine, + token: Token, + deadline: Instant, + peer: XID, + record: HandshakeRecord, +) { + let record = QlRecord { + header: QlHeader { + sender: engine.identity.xid, + recipient: peer, + }, + payload: QlPayload::Handshake(record), + }; + engine.state.enqueue_handshake_message( + &engine.config, + token, + deadline, + wire::encode_record(&record), + ); +} + +fn same_hello(stored: &wire::handshake::Hello, incoming: &wire::handshake::ArchivedHello) -> bool { + let meta: ControlMeta = (&incoming.meta).into(); + stored.meta.packet_id == meta.packet_id && stored.nonce == (&incoming.nonce).into() +} + +fn same_reply( + stored: &wire::handshake::HelloReply, + incoming: &wire::handshake::ArchivedHelloReply, +) -> bool { + let meta: ControlMeta = (&incoming.meta).into(); + stored.meta.packet_id == meta.packet_id && stored.nonce == (&incoming.nonce).into() +} + +fn current_ready_resend( + engine: &Engine, + now: Instant, + peer: XID, + confirm: &wire::handshake::ArchivedConfirm, +) -> Option<(wire::handshake::Ready, Instant, Token)> { + let peer_record = engine.peer.as_ref()?; + let PeerSession::Responder { + handshake_token, + hello, + reply, + deadline, + stage: HandshakeResponder::SendingReady { ready, .. }, + } = &peer_record.session + else { + return None; + }; + if *deadline <= now { + return None; + } + wire::handshake::verify_confirm( + peer, + engine.identity.xid, + &peer_record.signing_key, + hello, + reply, + confirm, + ) + .ok()?; + Some((ready.clone(), *deadline, *handshake_token)) +} + +fn recent_ready_resend( + engine: &Engine, + now: Instant, + peer: XID, + confirm: &wire::handshake::ArchivedConfirm, +) -> Option { + let peer_record = engine.peer.as_ref()?; + let PeerSession::Connected { + recent_ready: Some(recent_ready), + .. + } = &peer_record.session + else { + return None; + }; + if recent_ready.expires_at <= now { + return None; + } + wire::handshake::verify_confirm( + peer, + engine.identity.xid, + &peer_record.signing_key, + &recent_ready.hello, + &recent_ready.reply, + confirm, + ) + .ok()?; + Some(recent_ready.ready.clone()) +} + +fn peer_hello_wins( + local_hello: &wire::handshake::Hello, + local_sender: XID, + peer_hello: &wire::handshake::ArchivedHello, + peer_sender: XID, +) -> bool { + use std::cmp::Ordering; + + let peer_nonce: bc_components::Nonce = (&peer_hello.nonce).into(); + match peer_nonce.data().cmp(local_hello.nonce.data()) { + Ordering::Less => true, + Ordering::Greater => false, + Ordering::Equal => peer_sender.data().cmp(local_sender.data()) == Ordering::Less, + } +} diff --git a/ql-engine/src/engine/implementation/mod.rs b/ql-engine/src/engine/implementation/mod.rs new file mode 100644 index 00000000..4549a243 --- /dev/null +++ b/ql-engine/src/engine/implementation/mod.rs @@ -0,0 +1,893 @@ +pub mod handshake; +pub mod peer; +pub mod stream; + +use std::time::{Duration, Instant}; + +use bc_components::{SigningPublicKey, SymmetricKey, XID}; +use rkyv::access_mut; + +use crate::{ + engine::{ + replay_cache::ReplayKey, + state::{ActiveWrite, ControlWritePayload, OutboundWriteKind, TimeoutKind}, + stream::{InFlightWriteState, StreamRole, StreamState}, + Engine, EngineInput, EngineOutput, HandshakeInitiator, HandshakeResponder, KeepAliveConfig, + KeepAliveState, OutboundWrite, OutputFn, PeerRecord, PeerSession, QlCrypto, RecentReady, + StreamConfig, Token, WriteId, + }, + wire::{ + self, + encrypted_message::{ArchivedEncryptedMessage, NONCE_SIZE}, + stream::{ + encrypt_stream, BodyChunk, CloseCode, CloseTarget, StreamAck, StreamBody, StreamFrame, + StreamFrameClose, StreamMessage, + }, + ControlMeta, QlHeader, StreamSeq, + }, + Peer, QlError, StreamId, +}; + +impl Engine { + pub fn open_stream( + &mut self, + now: Instant, + request_head: Vec, + request_prefix: Option, + config: StreamConfig, + ) -> Result { + self.state.now = now; + stream::open_stream(self, now, request_head, request_prefix, config) + } + + pub fn run_tick_inner( + &mut self, + now: Instant, + input: EngineInput, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + self.state.now = now; + match input { + EngineInput::BindPeer(peer) => peer::handle_bind_peer(self, peer, emit), + EngineInput::Pair => peer::handle_pair_local(self, now, crypto), + EngineInput::Connect => handshake::handle_connect(self, now, crypto, emit), + EngineInput::Unpair => peer::handle_unpair_local(self, now, emit), + EngineInput::CloseStream { + stream_id, + target, + code, + payload, + } => stream::handle_close_stream(self, now, stream_id, target, code, payload), + EngineInput::OutboundData { stream_id, bytes } => { + stream::handle_outbound_data(self, stream_id, bytes) + } + EngineInput::OutboundFinished { stream_id } => { + stream::handle_outbound_finished(self, stream_id) + } + EngineInput::Incoming(bytes) => self.handle_incoming(now, bytes, crypto, emit), + EngineInput::TimerExpired => self.handle_timeouts(now, crypto, emit), + } + + self.handle_ready_retransmits(now, emit); + } + + pub fn take_next_write_inner(&mut self, crypto: &impl QlCrypto) -> Option { + self.take_next_control_write(crypto) + .or_else(|| stream::take_next_stream_write(self, crypto)) + } + + pub fn complete_write_inner( + &mut self, + write_id: WriteId, + result: Result<(), QlError>, + emit: &mut impl OutputFn, + ) { + let now = self.state.now; + let Some(active) = self.state.active_writes.remove(&write_id) else { + return; + }; + + if let OutboundWriteKind::StreamAck { .. } = active.kind { + if let Some(token) = active.token { + self.clear_ack_outbound_token(token, result.is_err()); + } + } + + if let Err(error) = result { + // only fail the stream if this frame is still in flight + // ACKs and protocol reset can remove it before write completion arrives + if let OutboundWriteKind::StreamFrame { stream_id, tx_seq } = active.kind { + if self + .streams + .get(&stream_id) + .is_some_and(|stream| stream.control.in_flight.contains_key(&tx_seq)) + { + self.fail_stream_by_id(stream_id, error.clone(), emit); + } + } + + if self.is_handshake_token(active.token) { + if let Some(entry) = self.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + self.drop_outbound(); + self.abort_streams(error, emit); + } + + return; + } + + if let Some((session_key, recent_ready)) = self.connected_session_for_token(active.token) { + if let Some(entry) = self.peer.as_mut() { + entry.session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + recent_ready, + }; + } + self.emit_peer_status(emit); + self.record_activity(now); + } + + if let Some(token) = active.token { + self.schedule_handshake_retry_after_write(token, now); + } + + if let OutboundWriteKind::StreamFrame { stream_id, tx_seq } = active.kind { + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream + .control + .complete_write(tx_seq, now + self.config.stream_ack_timeout); + } + } + } + + pub fn next_deadline_inner(&self) -> Option { + [ + self.state.next_deadline(), + self.streams.stream_retry_deadline(), + self.handshake_deadline(), + self.keep_alive_deadline(), + ] + .into_iter() + .flatten() + .min() + } +} + +impl Engine { + fn emit_peer_status(&self, emit: &mut impl OutputFn) { + if let Some(peer) = self.peer.as_ref() { + emit(EngineOutput::PeerStatusChanged { + peer: peer.peer, + session: peer.session.clone(), + }); + } + } + + fn next_control_meta(&self, valid_for: Duration) -> ControlMeta { + ControlMeta { + packet_id: self.state.next_packet_id(), + valid_until: wire::now_secs() + valid_for.as_secs(), + } + } + + fn keep_alive_deadline(&self) -> Option { + let config = self.keep_alive_config()?; + let entry = self.peer.as_ref()?; + let PeerSession::Connected { keepalive, .. } = &entry.session else { + return None; + }; + let base = keepalive.last_activity?; + Some( + base + if keepalive.pending { + config.timeout + } else { + config.interval + }, + ) + } + + fn handshake_deadline(&self) -> Option { + let entry = self.peer.as_ref()?; + match &entry.session { + PeerSession::Initiator { deadline, .. } | PeerSession::Responder { deadline, .. } => { + Some(*deadline) + } + PeerSession::Disconnected | PeerSession::Connected { .. } => None, + } + } + + fn is_replayed_control(&mut self, peer: XID, meta: ControlMeta) -> bool { + self.state + .replay_cache + .check_and_store_valid_until(ReplayKey::new(peer, meta.packet_id), meta.valid_until) + } + + // TODO: why do we pass 'now' if it's in state? + fn handle_incoming( + &mut self, + now: Instant, + mut bytes: Vec, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let Ok(record) = access_mut::(&mut bytes) + else { + return; + }; + let record = unsafe { record.unseal_unchecked() }; + let sender: XID = (&record.header.sender).into(); + let recipient: XID = (&record.header.recipient).into(); + if recipient != self.identity.xid { + return; + } + if !matches!(&record.payload, wire::ArchivedQlPayload::Pair(_)) { + let Some(peer) = self.peer.as_ref().map(|peer| peer.peer) else { + return; + }; + if sender != peer { + return; + } + } + let Ok(header) = wire::deserialize_value(&record.header) else { + return; + }; + match &mut record.payload { + wire::ArchivedQlPayload::Handshake(message) => { + self.handle_handshake(now, sender, &header, message, crypto, emit) + } + wire::ArchivedQlPayload::Stream(encrypted) => { + stream::handle_stream(self, now, sender, &header, encrypted, emit) + } + wire::ArchivedQlPayload::Heartbeat(encrypted) => { + self.handle_heartbeat(now, &header, encrypted, crypto, emit) + } + wire::ArchivedQlPayload::Pair(request) => { + peer::handle_pairing(self, now, &header, request, crypto, emit) + } + wire::ArchivedQlPayload::Unpair(unpair_record) => { + peer::handle_unpair(self, sender, &header, unpair_record, emit) + } + } + } + + fn handle_handshake( + &mut self, + now: Instant, + peer: XID, + header: &QlHeader, + message: &mut wire::handshake::ArchivedHandshakeRecord, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + match message { + wire::handshake::ArchivedHandshakeRecord::Hello(hello) => { + handshake::handle_hello(self, now, peer, hello, crypto, emit) + } + wire::handshake::ArchivedHandshakeRecord::HelloReply(reply) => { + handshake::handle_hello_reply(self, now, peer, reply, emit) + } + wire::handshake::ArchivedHandshakeRecord::Confirm(confirm) => { + handshake::handle_confirm(self, now, peer, confirm, crypto, emit) + } + wire::handshake::ArchivedHandshakeRecord::Ready(ready) => { + handshake::handle_ready(self, now, peer, header, ready, emit) + } + } + } + + fn handle_heartbeat( + &mut self, + now: Instant, + header: &QlHeader, + encrypted: &mut ArchivedEncryptedMessage, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + let (body, should_reply) = { + let Some(peer_record) = self.peer.as_ref() else { + return; + }; + let PeerSession::Connected { + session_key, + keepalive, + .. + } = &peer_record.session + else { + return; + }; + let Ok(body) = wire::heartbeat::decrypt_heartbeat(header, encrypted, session_key) + else { + return; + }; + (body, !keepalive.pending) + }; + if self.is_replayed_control(header.sender, body.meta) { + return; + } + self.record_activity(now); + if should_reply { + self.send_heartbeat_message(now, crypto); + } + self.emit_peer_status(emit); + } + + fn handle_ready_retransmits(&mut self, now: Instant, emit: &mut impl OutputFn) { + let mut timed_out = Vec::new(); + for (stream_id, stream) in self.streams.iter() { + let exhausted = stream.control.in_flight.iter().any(|(_, in_flight)| { + matches!( + in_flight.write_state, + InFlightWriteState::WaitingRetry { retry_at } + if retry_at <= now && in_flight.attempt >= self.config.stream_retry_limit + ) + }); + if exhausted { + timed_out.push(*stream_id); + } + } + + for stream_id in timed_out { + self.fail_stream_by_id(stream_id, QlError::Timeout, emit); + } + } + + fn clear_ack_outbound_token(&mut self, token: Token, retry: bool) { + for stream in self.streams.values_mut() { + let control = &mut stream.control; + if control.ack_outbound_token == Some(token) { + control.ack_outbound_token = None; + if retry { + control.note_ack(true); + } + break; + } + } + } + + fn clear_active_writes_for_stream(&mut self, stream_id: StreamId) { + self.state + .active_writes + .retain(|_, active| match active.kind { + OutboundWriteKind::Control => true, + OutboundWriteKind::StreamAck { + stream_id: active_stream_id, + } + | OutboundWriteKind::StreamClose { + stream_id: active_stream_id, + } => active_stream_id != stream_id, + OutboundWriteKind::StreamFrame { + stream_id: active_stream_id, + .. + } => active_stream_id != stream_id, + }); + } + + fn is_handshake_token(&self, token: Option) -> bool { + let Some(token) = token else { + return false; + }; + matches!(self.peer.as_ref().map(|entry| &entry.session), + Some(PeerSession::Initiator { handshake_token, .. }) if *handshake_token == token) + || matches!(self.peer.as_ref().map(|entry| &entry.session), + Some(PeerSession::Responder { handshake_token, .. }) if *handshake_token == token) + } + + fn connected_session_for_token( + &self, + token: Option, + ) -> Option<(SymmetricKey, Option)> { + let token = token?; + self.peer.as_ref().and_then(|entry| match &entry.session { + PeerSession::Responder { + hello, + reply, + deadline, + handshake_token, + stage: HandshakeResponder::SendingReady { session_key, ready }, + } if *handshake_token == token => Some(( + session_key.clone(), + Some(RecentReady { + hello: hello.clone(), + reply: reply.clone(), + ready: ready.clone(), + expires_at: *deadline, + }), + )), + _ => None, + }) + } + + fn handshake_write_pending(&self, token: Token) -> bool { + self.state + .active_writes + .values() + .any(|active| active.token == Some(token)) + || self + .state + .control_outbound + .iter() + .any(|message| message.token == token) + } + + fn clear_handshake_retry_at(&mut self, token: Token) { + let Some(entry) = self.peer.as_mut() else { + return; + }; + match &mut entry.session { + PeerSession::Initiator { + handshake_token, + stage: HandshakeInitiator::WaitingHelloReply { retry_at, .. }, + .. + } if *handshake_token == token => *retry_at = None, + PeerSession::Initiator { + handshake_token, + stage: HandshakeInitiator::WaitingReady { retry_at, .. }, + .. + } if *handshake_token == token => *retry_at = None, + PeerSession::Responder { + handshake_token, + stage: HandshakeResponder::WaitingConfirm { retry_at, .. }, + .. + } if *handshake_token == token => *retry_at = None, + _ => {} + } + } + + fn schedule_handshake_retry_after_write(&mut self, token: Token, now: Instant) { + if self.config.handshake_retry_interval.is_zero() || self.config.max_handshake_retries == 0 + { + return; + } + let retry_at = now + self.config.handshake_retry_interval; + let Some(entry) = self.peer.as_mut() else { + return; + }; + let scheduled = match &mut entry.session { + PeerSession::Initiator { + handshake_token, + stage: + HandshakeInitiator::WaitingHelloReply { + retry_at: stage_retry_at, + .. + }, + .. + } if *handshake_token == token => { + *stage_retry_at = Some(retry_at); + true + } + PeerSession::Initiator { + handshake_token, + stage: + HandshakeInitiator::WaitingReady { + retry_at: stage_retry_at, + .. + }, + .. + } if *handshake_token == token => { + *stage_retry_at = Some(retry_at); + true + } + PeerSession::Responder { + handshake_token, + stage: + HandshakeResponder::WaitingConfirm { + retry_at: stage_retry_at, + .. + }, + .. + } if *handshake_token == token => { + *stage_retry_at = Some(retry_at); + true + } + _ => false, + }; + if scheduled { + self.state.schedule_handshake_retry(token, retry_at); + } + } + + fn stream_write_session(&self) -> Option<(XID, SymmetricKey)> { + self.peer.as_ref().and_then(|peer| { + peer.session + .session_key() + .map(|key| (peer.peer, key.clone())) + }) + } + + fn issue_write( + &mut self, + kind: OutboundWriteKind, + token: Option, + bytes: Vec, + ) -> OutboundWrite { + let id = self.state.next_write_id(); + self.state + .active_writes + .insert(id, ActiveWrite { token, kind }); + OutboundWrite { id, bytes } + } + + fn take_next_control_write(&mut self, crypto: &impl QlCrypto) -> Option { + while let Some(message) = self.state.control_outbound.pop_front() { + let bytes = match message.payload { + ControlWritePayload::Encoded(bytes) => bytes, + ControlWritePayload::StreamClose { + stream_id, + target, + code, + payload, + } => { + let Some((recipient, session_key)) = self.stream_write_session() else { + continue; + }; + let body = StreamBody::Message(StreamMessage { + tx_seq: StreamSeq::START, + ack: StreamAck::EMPTY, + valid_until: wire::now_secs() + .saturating_add(self.config.packet_expiration.as_secs()), + frame: StreamFrame::Close(StreamFrameClose { + stream_id, + target, + code, + payload, + }), + }); + let record = encrypt_stream( + QlHeader { + sender: self.identity.xid, + recipient, + }, + &session_key, + &body, + encrypted_message_nonce(crypto), + ); + wire::encode_record(&record) + } + }; + return Some(self.issue_write(message.kind, Some(message.token), bytes)); + } + None + } + + fn send_ephemeral_close(&mut self, stream_id: StreamId, target: CloseTarget, code: CloseCode) { + self.state + .enqueue_stream_close(&self.config, true, stream_id, target, code, Vec::new()); + } + + fn send_heartbeat_message(&mut self, now: Instant, crypto: &impl QlCrypto) { + let Some(peer) = self.peer.as_ref().map(|peer| peer.peer) else { + return; + }; + let meta = self.next_control_meta(self.config.packet_expiration); + let token = self.state.next_token(); + let deadline = now + self.config.packet_expiration; + let message = { + let Some(peer_record) = self.peer.as_ref() else { + return; + }; + let PeerSession::Connected { session_key, .. } = &peer_record.session else { + return; + }; + wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: self.identity.xid, + recipient: peer, + }, + session_key, + wire::heartbeat::HeartbeatBody { meta }, + encrypted_message_nonce(crypto), + ) + }; + self.state.enqueue_handshake_message( + &self.config, + token, + deadline, + wire::encode_record(&message), + ); + } + + fn keep_alive_config(&self) -> Option { + self.config + .keep_alive + .filter(|config| !config.interval.is_zero() && !config.timeout.is_zero()) + } + + fn record_activity(&mut self, now: Instant) { + if let Some(PeerRecord { + session: PeerSession::Connected { keepalive, .. }, + .. + }) = self.peer.as_mut() + { + keepalive.last_activity = Some(now); + keepalive.pending = false; + } + } + + fn drop_outbound(&mut self) { + self.state.control_outbound.clear(); + self.state.active_writes.clear(); + } + + fn fail_handshake(&mut self, error: QlError, emit: &mut impl OutputFn) { + if let Some(entry) = self.peer.as_mut() { + if matches!( + entry.session, + PeerSession::Initiator { .. } | PeerSession::Responder { .. } + ) { + entry.session = PeerSession::Disconnected; + } + } + self.emit_peer_status(emit); + self.drop_outbound(); + self.abort_streams(error, emit); + } + + fn handle_handshake_retry_timeout(&mut self, token: Token, emit: &mut impl OutputFn) { + enum RetryAction { + Resend { + peer: XID, + deadline: Instant, + record: wire::handshake::HandshakeRecord, + }, + Fail, + Ignore, + } + + let now = self.state.now; + let action = { + let Some(entry) = self.peer.as_mut() else { + return; + }; + let peer = entry.peer; + match &mut entry.session { + PeerSession::Initiator { + handshake_token, + hello, + deadline, + stage: + HandshakeInitiator::WaitingHelloReply { + retry_count, + retry_at, + }, + .. + } if *handshake_token == token && retry_at.is_some_and(|at| at <= now) => { + *retry_at = None; + if *retry_count >= self.config.max_handshake_retries { + RetryAction::Fail + } else { + *retry_count = retry_count.saturating_add(1); + RetryAction::Resend { + peer, + deadline: *deadline, + record: wire::handshake::HandshakeRecord::Hello(hello.clone()), + } + } + } + PeerSession::Initiator { + handshake_token, + deadline, + stage: + HandshakeInitiator::WaitingReady { + confirm, + retry_count, + retry_at, + .. + }, + .. + } if *handshake_token == token && retry_at.is_some_and(|at| at <= now) => { + *retry_at = None; + if *retry_count >= self.config.max_handshake_retries { + RetryAction::Fail + } else { + *retry_count = retry_count.saturating_add(1); + RetryAction::Resend { + peer, + deadline: *deadline, + record: wire::handshake::HandshakeRecord::Confirm(confirm.clone()), + } + } + } + PeerSession::Responder { + handshake_token, + reply, + deadline, + stage: + HandshakeResponder::WaitingConfirm { + retry_count, + retry_at, + .. + }, + .. + } if *handshake_token == token && retry_at.is_some_and(|at| at <= now) => { + *retry_at = None; + if *retry_count >= self.config.max_handshake_retries { + RetryAction::Fail + } else { + *retry_count = retry_count.saturating_add(1); + RetryAction::Resend { + peer, + deadline: *deadline, + record: wire::handshake::HandshakeRecord::HelloReply(reply.clone()), + } + } + } + _ => RetryAction::Ignore, + } + }; + + match action { + RetryAction::Resend { + peer, + deadline, + record, + } => { + if self.handshake_write_pending(token) { + return; + } + handshake::enqueue_handshake_record(self, token, deadline, peer, record); + } + RetryAction::Fail => self.fail_handshake(QlError::Timeout, emit), + RetryAction::Ignore => {} + } + } + + fn abort_streams(&mut self, error: QlError, emit: &mut impl OutputFn) { + let streams = std::mem::take(&mut self.streams).into_inner(); + for (stream_id, stream) in streams { + self.fail_stream(stream_id, stream, error.clone(), emit); + } + } + + fn fail_stream_by_id(&mut self, stream_id: StreamId, error: QlError, emit: &mut impl OutputFn) { + let Some(stream) = self.streams.remove(&stream_id) else { + return; + }; + self.fail_stream(stream_id, stream, error, emit); + } + + pub fn fail_stream( + &mut self, + stream_id: StreamId, + stream: StreamState, + error: QlError, + emit: &mut impl OutputFn, + ) { + self.clear_active_writes_for_stream(stream_id); + match stream.role { + StreamRole::Initiator(_) => { + emit(EngineOutput::OutboundFailed { + stream_id, + error: error.clone(), + }); + emit(EngineOutput::InboundFailed { stream_id, error }); + } + StreamRole::Responder(stream) => { + emit(EngineOutput::InboundFailed { + stream_id, + error: error.clone(), + }); + if stream.response_started || stream.response.is_closed() { + emit(EngineOutput::OutboundFailed { stream_id, error }); + } + } + StreamRole::Provisional(_) => {} + } + emit(EngineOutput::StreamReaped { stream_id }); + } + + pub fn handle_timeouts( + &mut self, + now: Instant, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + loop { + let Some(entry) = self + .state + .timeouts + .peek_mut() + .filter(|entry| entry.0.at <= now) + else { + break; + }; + let entry = std::collections::binary_heap::PeekMut::pop(entry).0; + match entry.kind { + TimeoutKind::Outbound { token } => { + self.state + .control_outbound + .retain(|message| message.token != token); + } + TimeoutKind::HandshakeRetry { token } => { + self.handle_handshake_retry_timeout(token, emit); + } + TimeoutKind::StreamAckDelay { stream_id, token } => { + if let Some(stream) = self.streams.get_mut(&stream_id) { + let control = &mut stream.control; + if control.ack_delay_token == Some(token) { + control.ack_delay_token = None; + control.ack_immediate = true; + } + } + } + TimeoutKind::StreamProvisional { stream_id, token } => { + let should_reset = self + .streams + .get(&stream_id) + .and_then(StreamState::provisional_timeout_token) + .is_some_and(|stream_token| stream_token == token); + if should_reset { + self.streams.remove(&stream_id); + self.send_ephemeral_close( + stream_id, + CloseTarget::Both, + CloseCode::PROTOCOL, + ); + } + } + } + } + + if let Some(PeerRecord { + session: PeerSession::Connected { recent_ready, .. }, + .. + }) = self.peer.as_mut() + { + if recent_ready + .as_ref() + .is_some_and(|ready| ready.expires_at <= now) + { + *recent_ready = None; + } + } + + let handshake_due = self + .handshake_deadline() + .is_some_and(|deadline| deadline <= now); + if handshake_due { + self.fail_handshake(QlError::Timeout, emit); + return; + } + + let keepalive_due = self + .keep_alive_deadline() + .is_some_and(|deadline| deadline <= now); + if !keepalive_due { + return; + } + + let Some(entry) = self.peer.as_ref() else { + return; + }; + let PeerSession::Connected { keepalive, .. } = &entry.session else { + return; + }; + + if keepalive.pending { + if let Some(entry) = self.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(emit); + self.drop_outbound(); + self.abort_streams(QlError::SendFailed, emit); + return; + } + + self.send_heartbeat_message(now, crypto); + if let Some(entry) = self.peer.as_mut() { + if let PeerSession::Connected { keepalive, .. } = &mut entry.session { + keepalive.pending = true; + keepalive.last_activity = Some(now); + } + } + } +} + +fn encrypted_message_nonce(crypto: &impl QlCrypto) -> [u8; NONCE_SIZE] { + let mut nonce = [0u8; NONCE_SIZE]; + crypto.fill_random_bytes(&mut nonce); + nonce +} diff --git a/ql-engine/src/engine/implementation/peer.rs b/ql-engine/src/engine/implementation/peer.rs new file mode 100644 index 00000000..ef2dd06b --- /dev/null +++ b/ql-engine/src/engine/implementation/peer.rs @@ -0,0 +1,153 @@ +use super::*; + +pub fn handle_bind_peer(engine: &mut Engine, peer: Peer, emit: &mut impl OutputFn) { + if let Some(existing) = engine.peer.as_ref() { + emit(EngineOutput::PeerStatusChanged { + peer: existing.peer, + session: PeerSession::Disconnected, + }); + } + bind_peer_record(engine, peer, emit); +} + +pub fn handle_pair_local(engine: &mut Engine, now: Instant, crypto: &impl QlCrypto) { + let Some(peer) = engine.peer.as_ref() else { + return; + }; + let meta = engine.next_control_meta(engine.config.packet_expiration); + let Ok(record) = wire::pair::build_pair_request( + &engine.identity, + crypto, + peer.peer, + &peer.encapsulation_key, + meta, + ) else { + return; + }; + let token = engine.state.next_token(); + engine.state.enqueue_handshake_message( + &engine.config, + token, + now + engine.config.packet_expiration, + wire::encode_record(&record), + ); +} + +pub fn handle_unpair_local(engine: &mut Engine, now: Instant, emit: &mut impl OutputFn) { + let Some(peer) = engine.peer.as_ref().map(|peer| peer.peer) else { + return; + }; + let meta = engine.next_control_meta(engine.config.packet_expiration); + let record = wire::unpair::build_unpair_record( + &engine.identity, + QlHeader { + sender: engine.identity.xid, + recipient: peer, + }, + meta, + ); + unpair_peer(engine, emit); + let token = engine.state.next_token(); + engine.state.enqueue_handshake_message( + &engine.config, + token, + now + engine.config.packet_expiration, + wire::encode_record(&record), + ); +} + +pub fn handle_pairing( + engine: &mut Engine, + now: Instant, + header: &QlHeader, + request: &mut wire::pair::ArchivedPairRequestRecord, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, +) { + let payload = match wire::pair::decrypt_pair_request(&engine.identity, header, request) { + Ok(payload) => payload, + Err(_) => return, + }; + let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); + if engine.is_replayed_control(peer, payload.meta) { + return; + } + if let Some(existing) = engine.peer.as_ref() { + if existing.peer != peer + || existing.signing_key != payload.signing_pub_key + || existing.encapsulation_key != payload.encapsulation_pub_key + { + return; + } + } else { + bind_peer_record( + engine, + Peer { + peer, + signing_key: payload.signing_pub_key, + encapsulation_key: payload.encapsulation_pub_key, + }, + emit, + ); + } + handshake::handle_connect(engine, now, crypto, emit); +} + +pub fn handle_unpair( + engine: &mut Engine, + peer: XID, + header: &QlHeader, + record: &wire::unpair::ArchivedUnpairRecord, + emit: &mut impl OutputFn, +) { + { + let Some(peer_record) = engine.peer.as_ref() else { + return; + }; + if wire::unpair::verify_unpair_record(header, record, &peer_record.signing_key).is_err() { + return; + } + } + let meta: ControlMeta = (&record.meta).into(); + if engine.is_replayed_control(peer, meta) { + return; + } + unpair_peer(engine, emit); +} + +fn bind_peer_record(engine: &mut Engine, peer: Peer, emit: &mut impl OutputFn) { + reset_runtime(engine, QlError::Cancelled, emit); + engine.peer = Some(PeerRecord::new( + peer.peer, + peer.signing_key, + peer.encapsulation_key, + )); + engine.emit_peer_status(emit); + if let Some(peer) = engine.peer.as_ref() { + emit(EngineOutput::PersistPeer(peer.snapshot())); + } +} + +fn reset_runtime(engine: &mut Engine, error: QlError, emit: &mut impl OutputFn) { + let streams = std::mem::take(&mut engine.streams).into_inner(); + for (stream_id, stream) in streams { + engine.fail_stream(stream_id, stream, error.clone(), emit); + } + engine.state.control_outbound.clear(); + engine.state.active_writes.clear(); + engine.state.timeouts.clear(); +} + +fn unpair_peer(engine: &mut Engine, emit: &mut impl OutputFn) { + let Some(peer) = engine.peer.as_ref().map(|peer| peer.peer) else { + return; + }; + engine.drop_outbound(); + engine.abort_streams(QlError::SendFailed, emit); + engine.peer = None; + emit(EngineOutput::PeerStatusChanged { + peer, + session: PeerSession::Disconnected, + }); + emit(EngineOutput::ClearPeer); +} diff --git a/ql-engine/src/engine/implementation/stream.rs b/ql-engine/src/engine/implementation/stream.rs new file mode 100644 index 00000000..a957017c --- /dev/null +++ b/ql-engine/src/engine/implementation/stream.rs @@ -0,0 +1,822 @@ +use std::cmp::Reverse; + +use super::*; +use crate::{ + engine::{ + EngineConfig, EngineState, StreamConfig, + state::{StreamNamespace, TimeoutEntry}, + stream::*, + }, + wire::stream::*, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum StreamHandleResult { + Keep, + Remove, + Reap, +} + +pub fn open_stream( + engine: &mut Engine, + now: Instant, + request_head: Vec, + request_prefix: Option, + _config: StreamConfig, +) -> Result { + let Some(entry) = engine.peer.as_ref() else { + return Err(QlError::NoPeerBound); + }; + if !entry.session.is_connected() { + return Err(QlError::MissingSession); + } + + let stream_namespace = StreamNamespace::for_local(engine.identity.xid, entry.peer); + let stream_id = engine.state.next_stream_id(stream_namespace); + let request_prefix_fin = request_prefix.as_ref().is_some_and(|chunk| chunk.fin); + let frame = StreamFrameOpen { + stream_id, + request_head, + request_prefix, + }; + let mut stream = StreamState { + meta: StreamMeta { + stream_id, + last_activity: now, + }, + control: StreamControl { + pending: std::collections::VecDeque::from([StreamFrame::Open(frame)]), + ..Default::default() + }, + role: StreamRole::Initiator(InitiatorStream { + request: OutboundPhase::from_prefix(request_prefix_fin), + response: InboundState::new(), + }), + }; + drive_stream(&mut stream); + engine.streams.insert(stream_id, stream); + Ok(stream_id) +} + +pub fn handle_close_stream( + engine: &mut Engine, + now: Instant, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, +) { + let Some(stream) = engine.streams.get_mut(&stream_id) else { + return; + }; + + let mut dirty = false; + + if matches!(target, CloseTarget::Request | CloseTarget::Both) { + if let Some(inbound) = stream.inbound_mut(StreamSide::Request) { + dirty |= inbound.close(); + } + if let Some(outbound) = stream.outbound_mut(StreamSide::Request) { + dirty |= outbound.close(); + } + } + if matches!(target, CloseTarget::Response | CloseTarget::Both) { + if let Some(inbound) = stream.inbound_mut(StreamSide::Response) { + dirty |= inbound.close(); + } + if let Some(outbound) = stream.outbound_mut(StreamSide::Response) { + dirty |= outbound.close(); + } + } + + if dirty { + stream + .control + .queue_frame_front(close_frame(stream_id, target, code, payload)); + stream.meta.last_activity = now; + drive_stream(stream); + } +} + +pub fn handle_outbound_data(engine: &mut Engine, stream_id: StreamId, bytes: Vec) { + if bytes.is_empty() { + return; + } + let Some(stream) = engine.streams.get_mut(&stream_id) else { + return; + }; + let Some(side) = stream.outbound_side() else { + return; + }; + if let StreamRole::Responder(state) = &mut stream.role { + if side == StreamSide::Response { + state.response_started = true; + } + } + let Some(outbound) = stream.outbound_mut(side) else { + return; + }; + if !outbound.can_queue_data() { + return; + } + let chunk = BodyChunk { bytes, fin: false }; + stream + .control + .queue_frame_back(StreamFrame::Data(StreamFrameData { stream_id, chunk })); + drive_stream(stream); +} + +pub fn handle_outbound_finished(engine: &mut Engine, stream_id: StreamId) { + let Some(stream) = engine.streams.get_mut(&stream_id) else { + return; + }; + let Some(side) = stream.outbound_side() else { + return; + }; + if let StreamRole::Responder(state) = &mut stream.role { + if side == StreamSide::Response { + state.response_started = true; + } + } + let Some(outbound) = stream.outbound_mut(side) else { + return; + }; + outbound.finish(); + drive_stream(stream); +} + +pub fn handle_stream( + engine: &mut Engine, + now: Instant, + _peer: XID, + header: &QlHeader, + encrypted: &mut ArchivedEncryptedMessage, + emit: &mut impl OutputFn, +) { + let body = { + let Some(peer_record) = engine.peer.as_ref() else { + return; + }; + let PeerSession::Connected { session_key, .. } = &peer_record.session else { + return; + }; + match decrypt_stream(header, encrypted, session_key) { + Ok(body) => body, + Err(_) => return, + } + }; + engine.record_activity(now); + + let message = match body { + StreamBody::Ack(StreamAckBody { stream_id, ack, .. }) => { + process_stream_ack(engine, now, stream_id, ack, emit); + if let Some(stream) = engine.streams.get_mut(&stream_id) { + stream.meta.last_activity = now; + } + maybe_reap_stream(engine, stream_id, emit); + return; + } + StreamBody::Message(message) => message, + }; + + let stream_id = message.frame.stream_id(); + process_stream_ack(engine, now, stream_id, message.ack, emit); + + if !engine.streams.contains_key(&stream_id) { + let Some(peer_record) = engine.peer.as_ref() else { + return; + }; + let local_namespace = StreamNamespace::for_local(engine.identity.xid, peer_record.peer); + if !local_namespace.remote().matches(stream_id) { + return; + } + let token = engine.state.next_token(); + engine.streams.insert( + stream_id, + StreamState { + meta: StreamMeta { + stream_id, + last_activity: now, + }, + control: StreamControl::default(), + role: StreamRole::Provisional(ProvisionalStream { + timeout_token: token, + }), + }, + ); + engine.state.timeouts.push(Reverse(TimeoutEntry { + at: now + engine.config.packet_expiration, + kind: TimeoutKind::StreamProvisional { stream_id, token }, + })); + } + + let disposition = { + let (state, streams) = (&mut engine.state, &mut engine.streams); + let Some(stream) = streams.get_mut(&stream_id) else { + return; + }; + stream.meta.last_activity = now; + + match stream + .control + .buffer_incoming(message.tx_seq, message.frame) + { + BufferIncomingResult::OutOfWindow => { + if stream.is_provisional() { + state.enqueue_stream_close( + &engine.config, + true, + stream_id, + CloseTarget::Both, + CloseCode::PROTOCOL, + Vec::new(), + ); + StreamHandleResult::Remove + } else { + queue_protocol_close(stream, emit); + stream.meta.last_activity = now; + StreamHandleResult::Keep + } + } + BufferIncomingResult::Duplicate | BufferIncomingResult::AlreadyBuffered => { + stream.control.note_ack(true); + schedule_stream_ack(state, &engine.config, stream, now); + StreamHandleResult::Keep + } + BufferIncomingResult::Buffered { out_of_order } => { + stream.control.note_ack(out_of_order); + drain_committed_stream_frames(state, &engine.config, stream, now, emit) + } + } + }; + match disposition { + StreamHandleResult::Keep => {} + StreamHandleResult::Remove => { + engine.streams.remove(&stream_id); + } + StreamHandleResult::Reap => { + engine.streams.remove(&stream_id); + emit(EngineOutput::StreamReaped { stream_id }); + } + } +} + +pub fn take_next_stream_write( + engine: &mut Engine, + crypto: &impl QlCrypto, +) -> Option { + let (recipient, session_key) = engine.stream_write_session()?; + let stream_ids: Vec<_> = engine.streams.scan_from_cursor().collect(); + for stream_id in stream_ids { + let write = take_next_write_for_stream(engine, stream_id, recipient, &session_key, crypto); + if write.is_some() { + engine.streams.advance_cursor_after(stream_id); + return write; + } + } + None +} + +pub fn process_stream_ack( + engine: &mut Engine, + now: Instant, + stream_id: StreamId, + ack: StreamAck, + emit: &mut impl OutputFn, +) { + if ack == StreamAck::EMPTY { + return; + } + + let should_reap = { + let Some(stream) = engine.streams.get_mut(&stream_id) else { + return; + }; + stream.control.clear_fast_recovery(ack.base); + let fast_retransmit = stream + .control + .fast_retransmit_candidate(ack, engine.config.stream_fast_retransmit_threshold); + + loop { + let acked_tx_seq = stream + .control + .in_flight + .iter() + .find_map(|(tx_seq, in_flight)| match in_flight.write_state { + // ignore acks for writes that have not been sent out yet + InFlightWriteState::Ready => None, + InFlightWriteState::Issued | InFlightWriteState::WaitingRetry { .. } => { + StreamControl::ack_covers(ack, tx_seq).then_some(tx_seq) + } + }); + let Some(tx_seq) = acked_tx_seq else { + break; + }; + let Some(in_flight) = stream.control.remove_in_flight(tx_seq) else { + continue; + }; + + match in_flight.frame { + StreamFrame::Open(StreamFrameOpen { request_prefix, .. }) => { + if let StreamRole::Initiator(stream) = &mut stream.role { + if request_prefix.as_ref().is_some_and(|chunk| chunk.fin) + && stream.request.close() + { + emit(EngineOutput::OutboundClosed { stream_id }); + } + } + } + StreamFrame::Data(StreamFrameData { + chunk: BodyChunk { fin: true, .. }, + .. + }) => { + if let Some(side) = stream.outbound_side() { + if let Some(outbound) = stream.outbound_mut(side) { + if outbound.close() { + emit(EngineOutput::OutboundClosed { stream_id }); + } + } + } + } + StreamFrame::Close(StreamFrameClose { + target, + code, + payload, + .. + }) => { + for side in [StreamSide::Request, StreamSide::Response] { + let affects_outbound = matches!( + (target, side), + (CloseTarget::Request, StreamSide::Request) + | (CloseTarget::Response, StreamSide::Response) + | (CloseTarget::Both, _) + ); + if affects_outbound { + if let Some(outbound) = stream.outbound_mut(side) { + if outbound.close() { + emit(EngineOutput::OutboundFailed { + stream_id, + error: QlError::StreamClosed { + target, + code, + payload: payload.clone(), + }, + }); + } + } + } + } + } + StreamFrame::Data(_) => {} + } + } + + if let Some(tx_seq) = fast_retransmit { + stream.control.schedule_fast_retransmit(tx_seq, now); + } + drive_stream(stream); + stream.can_reap() + }; + + if should_reap { + engine.streams.remove(&stream_id); + emit(EngineOutput::StreamReaped { stream_id }); + } +} + +fn schedule_stream_ack( + state: &mut EngineState, + config: &EngineConfig, + stream: &mut StreamState, + now: Instant, +) { + let stream_id = stream.meta.stream_id; + let control = &mut stream.control; + if !control.ack_dirty { + return; + } + if control.ack_immediate || config.stream_ack_delay.is_zero() { + control.ack_delay_token = None; + return; + } + if control.ack_delay_token.is_some() { + return; + } + let token = state.next_token(); + control.ack_delay_token = Some(token); + state.timeouts.push(Reverse(TimeoutEntry { + at: now + config.stream_ack_delay, + kind: TimeoutKind::StreamAckDelay { stream_id, token }, + })); +} + +fn drain_committed_stream_frames( + state: &mut EngineState, + config: &EngineConfig, + stream: &mut StreamState, + now: Instant, + emit: &mut impl OutputFn, +) -> StreamHandleResult { + let stream_id = stream.meta.stream_id; + loop { + let next = stream.control.pop_next_committable(); + let Some((_tx_seq, frame)) = next else { + break; + }; + if stream.is_provisional() && !matches!(frame, StreamFrame::Open(_)) { + state.enqueue_stream_close( + config, + true, + stream_id, + CloseTarget::Both, + CloseCode::PROTOCOL, + Vec::new(), + ); + return StreamHandleResult::Remove; + } + match frame { + StreamFrame::Open(frame) => handle_stream_open(stream, now, frame, emit), + StreamFrame::Close(frame) => handle_stream_close_from_peer(stream, frame, emit), + StreamFrame::Data(frame) => handle_stream_data(stream, now, frame, emit), + } + } + stream.control.maybe_force_ack_for_progress(); + schedule_stream_ack(state, config, stream, now); + if stream.can_reap() { + StreamHandleResult::Reap + } else { + StreamHandleResult::Keep + } +} + +fn handle_stream_open( + stream: &mut StreamState, + now: Instant, + frame: StreamFrameOpen, + emit: &mut impl OutputFn, +) { + let StreamFrameOpen { + stream_id, + request_head, + request_prefix, + } = frame; + if !stream.is_provisional() { + queue_protocol_close(stream, emit); + return; + } + stream.meta.last_activity = now; + stream.role = StreamRole::Responder(ResponderStream { + request: InboundState::new(), + response: OutboundPhase::from_prefix(false), + response_started: false, + }); + if let Some(chunk) = request_prefix.as_ref() { + let Some(inbound) = stream.inbound_mut(StreamSide::Request) else { + return; + }; + if chunk.fin { + inbound.close(); + } + } + emit(EngineOutput::InboundStreamOpened { + stream_id, + request_head, + request_prefix, + }); +} + +fn handle_stream_close_from_peer( + stream: &mut StreamState, + frame: StreamFrameClose, + emit: &mut impl OutputFn, +) { + let StreamFrameClose { + target, + code, + payload, + .. + } = frame; + apply_remote_close(stream, target, code, payload, emit); +} + +fn handle_stream_data( + stream: &mut StreamState, + now: Instant, + frame: StreamFrameData, + emit: &mut impl OutputFn, +) { + let StreamFrameData { stream_id, chunk } = frame; + let Some(side) = stream.inbound_side() else { + queue_protocol_close(stream, emit); + return; + }; + let Some(inbound) = stream.inbound_mut(side) else { + queue_protocol_close(stream, emit); + return; + }; + if inbound.closed { + queue_protocol_close(stream, emit); + } else { + if !chunk.bytes.is_empty() { + emit(EngineOutput::InboundData { + stream_id, + bytes: chunk.bytes, + }); + } + if chunk.fin && inbound.close() { + emit(EngineOutput::InboundFinished { stream_id }); + } + } + stream.meta.last_activity = now; +} + +fn drive_stream(stream: &mut StreamState) { + let (meta, control, role) = stream.parts_mut(); + match role { + StreamRole::Initiator(stream) => { + drive_stream_outbound(meta.stream_id, control, Some(&mut stream.request)); + } + StreamRole::Responder(stream) => { + drive_stream_outbound(meta.stream_id, control, Some(&mut stream.response)); + } + StreamRole::Provisional(_) => drive_stream_outbound(meta.stream_id, control, None), + } +} + +fn drive_stream_outbound( + stream_id: StreamId, + control: &mut StreamControl, + mut outbound: Option<&mut OutboundPhase>, +) { + loop { + if control.send_window_has_space() { + if let Some(frame) = control.pending.pop_front() { + enqueue_stream_frame(control, frame, 0); + continue; + } + } + if !control.send_window_has_space() { + return; + } + + let Some(outbound) = outbound.as_deref_mut() else { + return; + }; + if outbound.queue_fin() { + enqueue_stream_frame( + control, + StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: Vec::new(), + fin: true, + }, + }), + 0, + ); + continue; + } + return; + } +} + +fn enqueue_stream_frame(control: &mut StreamControl, frame: StreamFrame, attempt: u8) { + let tx_seq = control.take_tx_seq(); + enqueue_stream_frame_with_seq(control, tx_seq, frame, attempt); +} + +fn enqueue_stream_frame_with_seq( + control: &mut StreamControl, + tx_seq: StreamSeq, + frame: StreamFrame, + attempt: u8, +) { + control.insert_in_flight(InFlightFrame { + tx_seq, + frame, + attempt, + write_state: InFlightWriteState::Ready, + }); +} + +fn queue_protocol_close(stream: &mut StreamState, emit: &mut impl OutputFn) { + let stream_id = stream.meta.stream_id; + let control = &mut stream.control; + control.clear_transient_buffers(); + control.queue_frame_front(close_frame( + stream_id, + CloseTarget::Both, + CloseCode::PROTOCOL, + Vec::new(), + )); + for side in [StreamSide::Request, StreamSide::Response] { + if let Some(outbound) = stream.outbound_mut(side) { + if outbound.close() { + emit(EngineOutput::OutboundFailed { + stream_id, + error: QlError::StreamProtocol, + }); + } + } + if let Some(inbound) = stream.inbound_mut(side) { + if inbound.close() { + emit(EngineOutput::InboundFailed { + stream_id, + error: QlError::StreamProtocol, + }); + } + } + } + drive_stream(stream); +} + +fn apply_remote_close( + stream: &mut StreamState, + target: CloseTarget, + code: CloseCode, + payload: Vec, + emit: &mut impl OutputFn, +) { + let stream_id = stream.meta.stream_id; + let error = QlError::StreamClosed { + target, + code, + payload: payload.clone(), + }; + if matches!(target, CloseTarget::Request | CloseTarget::Both) { + if let Some(inbound) = stream.inbound_mut(StreamSide::Request) { + if inbound.close() { + emit(EngineOutput::InboundFailed { + stream_id, + error: error.clone(), + }); + } + } + if let Some(outbound) = stream.outbound_mut(StreamSide::Request) { + if outbound.close() { + emit(EngineOutput::OutboundFailed { + stream_id, + error: error.clone(), + }); + } + } + } + if matches!(target, CloseTarget::Response | CloseTarget::Both) { + if let Some(inbound) = stream.inbound_mut(StreamSide::Response) { + if inbound.close() { + emit(EngineOutput::InboundFailed { + stream_id, + error: error.clone(), + }); + } + } + if let Some(outbound) = stream.outbound_mut(StreamSide::Response) { + if outbound.close() { + emit(EngineOutput::OutboundFailed { + stream_id, + error: error.clone(), + }); + } + } + } +} + +fn maybe_reap_stream(engine: &mut Engine, stream_id: StreamId, emit: &mut impl OutputFn) { + if engine + .streams + .get(&stream_id) + .is_some_and(StreamState::can_reap) + { + engine.streams.remove(&stream_id); + emit(EngineOutput::StreamReaped { stream_id }); + } +} + +fn take_next_write_for_stream( + engine: &mut Engine, + stream_id: StreamId, + recipient: XID, + session_key: &SymmetricKey, + crypto: &impl QlCrypto, +) -> Option { + #[derive(Clone, Copy)] + enum StreamWriteSelection { + Ack, + InitialFrame { tx_seq: StreamSeq }, + RetryFrame { tx_seq: StreamSeq }, + } + + let now = engine.state.now; + let selection = { + let stream = engine.streams.get(&stream_id)?; + let is_provisional = stream.is_provisional(); + let control = &stream.control; + if !is_provisional { + if let Some(tx_seq) = control.in_flight.iter().find_map(|(tx_seq, in_flight)| { + matches!( + in_flight.write_state, + InFlightWriteState::WaitingRetry { retry_at } + if retry_at <= now && in_flight.attempt < engine.config.stream_retry_limit + ) + .then_some(tx_seq) + }) { + Some(StreamWriteSelection::RetryFrame { tx_seq }) + } else if let Some(tx_seq) = control.in_flight.iter().find_map(|(tx_seq, in_flight)| { + matches!(in_flight.write_state, InFlightWriteState::Ready).then_some(tx_seq) + }) { + Some(StreamWriteSelection::InitialFrame { tx_seq }) + } else if control.ack_dirty + && control.ack_immediate + && control.ack_outbound_token.is_none() + { + Some(StreamWriteSelection::Ack) + } else { + None + } + } else if control.ack_dirty && control.ack_immediate && control.ack_outbound_token.is_none() + { + Some(StreamWriteSelection::Ack) + } else { + None + } + }?; + + match selection { + StreamWriteSelection::Ack => { + let token = engine.state.next_token(); + let ack = { + let stream = engine.streams.get_mut(&stream_id)?; + let control = &mut stream.control; + if !(control.ack_dirty + && control.ack_immediate + && control.ack_outbound_token.is_none()) + { + return None; + } + let ack = control.current_ack(); + control.clear_ack_schedule(); + control.note_ack_sent(ack); + control.ack_outbound_token = Some(token); + ack + }; + + let body = StreamBody::Ack(StreamAckBody { + stream_id, + ack, + valid_until: wire::now_secs() + .saturating_add(engine.config.packet_expiration.as_secs()), + }); + let record = encrypt_stream( + QlHeader { + sender: engine.identity.xid, + recipient, + }, + session_key, + &body, + encrypted_message_nonce(crypto), + ); + Some(engine.issue_write( + OutboundWriteKind::StreamAck { stream_id }, + Some(token), + wire::encode_record(&record), + )) + } + StreamWriteSelection::InitialFrame { tx_seq } + | StreamWriteSelection::RetryFrame { tx_seq } => { + let (ack, frame) = { + let stream = engine.streams.get_mut(&stream_id)?; + let inbound_alive = match &stream.role { + StreamRole::Initiator(state) => !state.response.closed, + StreamRole::Responder(state) => !state.request.closed, + StreamRole::Provisional(_) => return None, + }; + let control = &mut stream.control; + let ack = control.take_piggyback_ack(inbound_alive); + let frame = control.mark_write_issued(tx_seq)?; + (ack, frame) + }; + + let body = StreamBody::Message(StreamMessage { + tx_seq, + ack, + valid_until: wire::now_secs() + .saturating_add(engine.config.packet_expiration.as_secs()), + frame, + }); + let record = encrypt_stream( + QlHeader { + sender: engine.identity.xid, + recipient, + }, + session_key, + &body, + encrypted_message_nonce(crypto), + ); + Some(engine.issue_write( + OutboundWriteKind::StreamFrame { stream_id, tx_seq }, + None, + wire::encode_record(&record), + )) + } + } +} diff --git a/ql-engine/src/engine/mod.rs b/ql-engine/src/engine/mod.rs new file mode 100644 index 00000000..1922ec5b --- /dev/null +++ b/ql-engine/src/engine/mod.rs @@ -0,0 +1,171 @@ +mod implementation; +pub mod replay_cache; +mod ring; +mod state; +pub(crate) mod stream; + +#[cfg(test)] +mod tests; + +use std::time::{Duration, Instant}; + +use bc_components::XID; +pub use state::{ + Engine, EngineState, HandshakeInitiator, HandshakeResponder, KeepAliveState, OutboundWrite, + PeerRecord, PeerSession, RecentReady, Token, WriteId, +}; + +use crate::{ + identity::QlIdentity, + wire::stream::{BodyChunk, CloseCode, CloseTarget}, + Peer, QlError, StreamId, +}; + +pub trait QlCrypto { + fn fill_random_bytes(&self, data: &mut [u8]); +} + +#[derive(Debug, Clone, Copy)] +pub struct KeepAliveConfig { + pub interval: Duration, + pub timeout: Duration, +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct StreamConfig {} + +#[derive(Debug, Clone, Copy)] +pub struct EngineConfig { + pub handshake_timeout: Duration, + pub handshake_retry_interval: Duration, + pub max_handshake_retries: u8, + pub packet_expiration: Duration, + pub stream_ack_delay: Duration, + pub stream_ack_timeout: Duration, + pub stream_fast_retransmit_threshold: u8, + pub stream_retry_limit: u8, + pub keep_alive: Option, +} + +impl Default for EngineConfig { + fn default() -> Self { + Self { + handshake_timeout: Duration::from_secs(5), + handshake_retry_interval: Duration::from_millis(750), + max_handshake_retries: 3, + packet_expiration: Duration::from_secs(30), + stream_ack_delay: Duration::from_millis(5), + stream_ack_timeout: Duration::from_millis(150), + stream_fast_retransmit_threshold: 2, + stream_retry_limit: 5, + keep_alive: None, + } + } +} + +#[derive(Debug)] +pub enum EngineInput { + BindPeer(Peer), + Pair, + Connect, + Unpair, + CloseStream { + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + }, + + OutboundData { + stream_id: StreamId, + bytes: Vec, + }, + OutboundFinished { + stream_id: StreamId, + }, + Incoming(Vec), + TimerExpired, +} + +#[derive(Debug)] +pub enum EngineOutput { + PeerStatusChanged { + peer: XID, + session: PeerSession, + }, + PersistPeer(Peer), + ClearPeer, + + InboundStreamOpened { + stream_id: StreamId, + request_head: Vec, + request_prefix: Option, + }, + InboundData { + stream_id: StreamId, + bytes: Vec, + }, + InboundFinished { + stream_id: StreamId, + }, + InboundFailed { + stream_id: StreamId, + error: QlError, + }, + + OutboundClosed { + stream_id: StreamId, + }, + OutboundFailed { + stream_id: StreamId, + error: QlError, + }, + + StreamReaped { + stream_id: StreamId, + }, +} + +pub trait OutputFn: FnMut(EngineOutput) {} + +impl OutputFn for T where T: FnMut(EngineOutput) {} + +impl Engine { + pub fn new(config: EngineConfig, identity: QlIdentity, peer: Option) -> Self { + Self { + config: config, + identity, + peer: peer + .map(|peer| PeerRecord::new(peer.peer, peer.signing_key, peer.encapsulation_key)), + state: EngineState::new(), + streams: stream::StreamStore::default(), + } + } + + pub fn run_tick( + &mut self, + now: Instant, + input: EngineInput, + crypto: &impl QlCrypto, + emit: &mut impl OutputFn, + ) { + self.run_tick_inner(now, input, crypto, emit); + } + + pub fn take_next_write(&mut self, crypto: &impl QlCrypto) -> Option { + self.take_next_write_inner(crypto) + } + + pub fn complete_write( + &mut self, + write_id: WriteId, + result: Result<(), QlError>, + emit: &mut impl OutputFn, + ) { + self.complete_write_inner(write_id, result, emit); + } + + pub fn next_deadline(&self) -> Option { + self.next_deadline_inner() + } +} diff --git a/ql2/src/engine/replay_cache.rs b/ql-engine/src/engine/replay_cache.rs similarity index 98% rename from ql2/src/engine/replay_cache.rs rename to ql-engine/src/engine/replay_cache.rs index 7643c5c8..8b5d5dc3 100644 --- a/ql2/src/engine/replay_cache.rs +++ b/ql-engine/src/engine/replay_cache.rs @@ -10,6 +10,7 @@ use crate::PacketId; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ReplayKey { + /// unfortunately we need this in the key, to avoid replay attacks of pair/unpair. pub peer: XID, pub packet_id: PacketId, } diff --git a/ql2/src/engine/ring.rs b/ql-engine/src/engine/ring.rs similarity index 98% rename from ql2/src/engine/ring.rs rename to ql-engine/src/engine/ring.rs index d1f4bf64..4ad7f567 100644 --- a/ql2/src/engine/ring.rs +++ b/ql-engine/src/engine/ring.rs @@ -189,10 +189,8 @@ impl<'a, const N: usize, T> Iterator for SeqRingDrain<'a, N, T> { mod tests { use super::*; use crate::{ - engine::stream::{BufferIncomingResult, InFlightFrame, StreamControl}, - wire::stream::{ - BodyChunk, Direction, StreamAck, StreamFrame, StreamFrameData, StreamFrameOpen, - }, + engine::stream::{BufferIncomingResult, InFlightFrame, InFlightWriteState, StreamControl}, + wire::stream::{BodyChunk, StreamAck, StreamFrame, StreamFrameData, StreamFrameOpen}, StreamId, }; @@ -201,7 +199,6 @@ mod tests { StreamSeq(tx_seq), StreamFrame::Data(StreamFrameData { stream_id, - dir: Direction::Request, chunk: BodyChunk { bytes: vec![byte], fin: false, @@ -318,6 +315,7 @@ mod tests { request_prefix: None, }), attempt: 0, + write_state: InFlightWriteState::Ready, }; control.insert_in_flight(frame); control.next_tx_seq = StreamSeq(tx_seq + 1); diff --git a/ql2/src/engine/state.rs b/ql-engine/src/engine/state.rs similarity index 55% rename from ql2/src/engine/state.rs rename to ql-engine/src/engine/state.rs index 7debf60f..0bbf9c86 100644 --- a/ql2/src/engine/state.rs +++ b/ql-engine/src/engine/state.rs @@ -7,36 +7,70 @@ use std::{ use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey, XID}; -use super::{ - replay_cache::ReplayCache, - stream::{QueuedWrite, StreamState}, - EngineConfig, StreamConfig, -}; +use super::{replay_cache::ReplayCache, stream::StreamStore, EngineConfig}; use crate::{ - platform::QlIdentity, + identity::QlIdentity, wire::{ - handshake::{Hello, HelloReply, ResponderSecrets}, - stream::{BodyChunk, Direction, RejectCode, ResetCode, StreamBody}, + handshake::{Confirm, Hello, HelloReply, Ready, ResponderSecrets}, + stream::{CloseCode, CloseTarget}, StreamSeq, }, - PacketId, Peer, QlError, StreamId, + PacketId, Peer, StreamId, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Token(pub u64); #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct OpenId(pub u64); +pub struct WriteId(pub u64); #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct TrackedWrite { - pub stream_id: StreamId, - pub tx_seq: StreamSeq, +pub enum OutboundWriteKind { + Control, + StreamAck { + stream_id: StreamId, + }, + StreamFrame { + stream_id: StreamId, + tx_seq: StreamSeq, + }, + StreamClose { + stream_id: StreamId, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OutboundWrite { + pub id: WriteId, + pub bytes: Vec, +} + +#[derive(Debug)] +pub struct ControlWrite { + pub token: Token, + pub kind: OutboundWriteKind, + pub payload: ControlWritePayload, +} + +#[derive(Debug)] +pub enum ControlWritePayload { + Encoded(Vec), + StreamClose { + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + }, +} + +#[derive(Debug, Clone, Copy)] +pub struct ActiveWrite { + pub token: Option, + pub kind: OutboundWriteKind, } #[derive(Debug, Clone)] pub struct KeepAliveState { - pub token: Token, pub pending: bool, pub last_activity: Option, } @@ -44,17 +78,45 @@ pub struct KeepAliveState { impl Default for KeepAliveState { fn default() -> Self { Self { - token: Token(0), pending: false, last_activity: None, } } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum InitiatorStage { - WaitingHelloReply, - SendingConfirm, +#[derive(Debug, Clone, PartialEq)] +pub enum HandshakeInitiator { + WaitingHelloReply { + retry_count: u8, + retry_at: Option, + }, + WaitingReady { + reply: HelloReply, + confirm: Confirm, + retry_count: u8, + retry_at: Option, + }, +} + +#[derive(Debug, Clone)] +pub enum HandshakeResponder { + WaitingConfirm { + secrets: ResponderSecrets, + retry_count: u8, + retry_at: Option, + }, + SendingReady { + session_key: SymmetricKey, + ready: Ready, + }, +} + +#[derive(Debug, Clone)] +pub struct RecentReady { + pub hello: Hello, + pub reply: HelloReply, + pub ready: Ready, + pub expires_at: Instant, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -64,9 +126,9 @@ pub enum StreamNamespace { } impl StreamNamespace { - const BIT: u64 = 1 << 63; + const BIT: u32 = 1 << 31; - pub fn bit(self) -> u64 { + pub fn bit(self) -> u32 { match self { Self::Low => 0, Self::High => Self::BIT, @@ -100,18 +162,19 @@ pub enum PeerSession { hello: Hello, session_key: SymmetricKey, deadline: Instant, - stage: InitiatorStage, + stage: HandshakeInitiator, }, Responder { handshake_token: Token, hello: Hello, reply: HelloReply, - secrets: ResponderSecrets, deadline: Instant, + stage: HandshakeResponder, }, Connected { session_key: SymmetricKey, keepalive: KeepAliveState, + recent_ready: Option, }, } @@ -155,171 +218,12 @@ impl PeerRecord { } } -#[derive(Debug)] -pub enum EngineInput { - BindPeer(Peer), - Pair, - Connect, - Unpair, - - OpenStream { - open_id: OpenId, - request_head: Vec, - request_prefix: Option, - config: StreamConfig, - }, - AcceptStream { - stream_id: StreamId, - response_head: Vec, - response_prefix: Option, - }, - RejectStream { - stream_id: StreamId, - code: RejectCode, - }, - - OutboundData { - stream_id: StreamId, - dir: Direction, - bytes: Vec, - }, - OutboundFinished { - stream_id: StreamId, - dir: Direction, - }, - - ResetOutbound { - stream_id: StreamId, - dir: Direction, - code: ResetCode, - }, - ResetInbound { - stream_id: StreamId, - dir: Direction, - code: ResetCode, - }, - PendingAcceptDropped { - stream_id: StreamId, - }, - ResponderDropped { - stream_id: StreamId, - }, - - Incoming(Vec), - WriteCompleted { - token: Token, - tracked: Option, - result: Result<(), QlError>, - }, - TimerExpired, -} - -#[derive(Debug)] -pub enum EngineOutput { - SetTimer(Option), - WriteMessage { - token: Token, - tracked: Option, - bytes: Vec, - }, - - PeerStatusChanged { - peer: XID, - session: PeerSession, - }, - PersistPeer(Peer), - ClearPeer, - - OpenStarted { - open_id: OpenId, - stream_id: StreamId, - }, - OpenAccepted { - open_id: OpenId, - stream_id: StreamId, - response_head: Vec, - response_prefix: Option, - }, - OpenFailed { - open_id: OpenId, - stream_id: StreamId, - error: QlError, - }, - - InboundStreamOpened { - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, - }, - InboundData { - stream_id: StreamId, - dir: Direction, - bytes: Vec, - }, - InboundFinished { - stream_id: StreamId, - dir: Direction, - }, - InboundFailed { - stream_id: StreamId, - dir: Direction, - error: QlError, - }, - - NeedOutboundData { - stream_id: StreamId, - dir: Direction, - }, - OutboundClosed { - stream_id: StreamId, - dir: Direction, - }, - OutboundFailed { - stream_id: StreamId, - dir: Direction, - error: QlError, - }, - - StreamReaped { - stream_id: StreamId, - }, -} - -pub trait OutputFn: FnMut(EngineOutput) {} - -impl OutputFn for T where T: FnMut(EngineOutput) {} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TimeoutKind { - Outbound { - token: Token, - }, - Handshake { - token: Token, - }, - KeepAliveSend { - token: Token, - }, - KeepAliveTimeout { - token: Token, - }, - StreamOpen { - stream_id: StreamId, - token: Token, - }, - StreamMessage { - stream_id: StreamId, - tx_seq: StreamSeq, - attempt: u8, - }, - StreamAckDelay { - stream_id: StreamId, - token: Token, - }, - StreamProvisional { - stream_id: StreamId, - token: Token, - }, + Outbound { token: Token }, + HandshakeRetry { token: Token }, + StreamAckDelay { stream_id: StreamId, token: Token }, + StreamProvisional { stream_id: StreamId, token: Token }, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -340,47 +244,39 @@ impl PartialOrd for TimeoutEntry { } } -#[derive(Debug)] -pub enum HelloAction { - StartResponder, - ResendReply { - reply: HelloReply, - deadline: Instant, - }, - Ignore, -} - pub struct Engine { pub config: EngineConfig, pub identity: QlIdentity, + pub peer: Option, pub state: EngineState, - pub streams: HashMap, + pub streams: StreamStore, } pub struct EngineState { - pub peer: Option, pub replay_cache: ReplayCache, pub next_token: Cell, + pub next_write_id: Cell, pub next_packet_id: Cell, - pub next_stream_id: Cell, - pub outbound: VecDeque, + pub next_stream_id: Cell, + pub control_outbound: VecDeque, + pub active_writes: HashMap, pub timeouts: BinaryHeap>, - pub write_in_flight: Option, + pub now: Instant, } impl EngineState { - pub fn new(peer: Option) -> Self { + pub fn new() -> Self { Self { - peer: peer - .map(|peer| PeerRecord::new(peer.peer, peer.signing_key, peer.encapsulation_key)), replay_cache: ReplayCache::new(), next_token: Cell::new(1), + next_write_id: Cell::new(1), next_packet_id: Cell::new(1), next_stream_id: Cell::new(1), - outbound: VecDeque::new(), + control_outbound: VecDeque::new(), + active_writes: HashMap::new(), timeouts: BinaryHeap::new(), - write_in_flight: None, + now: Instant::now(), } } @@ -394,6 +290,12 @@ impl EngineState { Token(token) } + pub fn next_write_id(&self) -> WriteId { + let id = self.next_write_id.get(); + self.next_write_id.set(id.wrapping_add(1)); + WriteId(id) + } + pub fn next_packet_id(&self) -> PacketId { let id = self.next_packet_id.get(); self.next_packet_id.set(id.wrapping_add(1)); @@ -413,38 +315,75 @@ impl EngineState { deadline: Instant, bytes: Vec, ) { - self.outbound.push_back(QueuedWrite { + self.control_outbound.push_back(ControlWrite { token, - payload: super::stream::QueuedPayload::PreEncoded(bytes), + kind: OutboundWriteKind::Control, + payload: ControlWritePayload::Encoded(bytes), }); self.timeouts.push(Reverse(TimeoutEntry { at: deadline, - kind: TimeoutKind::Handshake { token }, + kind: TimeoutKind::Outbound { token }, + })); + } + + pub fn schedule_handshake_retry(&mut self, token: Token, at: Instant) { + self.timeouts.push(Reverse(TimeoutEntry { + at, + kind: TimeoutKind::HandshakeRetry { token }, })); + } + + pub fn enqueue_control( + &mut self, + config: &EngineConfig, + priority: bool, + bytes: Vec, + ) -> Token { + let token = self.next_token(); + let message = ControlWrite { + token, + kind: OutboundWriteKind::Control, + payload: ControlWritePayload::Encoded(bytes), + }; + if priority { + self.control_outbound.push_front(message); + } else { + self.control_outbound.push_back(message); + } self.timeouts.push(Reverse(TimeoutEntry { - at: deadline, + at: self.now + config.packet_expiration, kind: TimeoutKind::Outbound { token }, })); + token } - pub fn enqueue_stream_body( + pub fn enqueue_stream_close( &mut self, config: &EngineConfig, priority: bool, - body: StreamBody, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, ) -> Token { let token = self.next_token(); - let message = QueuedWrite { + let message = ControlWrite { token, - payload: super::stream::QueuedPayload::Stream { body }, + kind: OutboundWriteKind::StreamClose { stream_id }, + payload: ControlWritePayload::StreamClose { + stream_id, + target, + code, + payload, + }, }; if priority { - self.outbound.push_front(message); + self.control_outbound.push_front(message); } else { - self.outbound.push_back(message); + self.control_outbound.push_back(message); } self.timeouts.push(Reverse(TimeoutEntry { - at: Instant::now() + config.packet_expiration, + at: self.now + config.packet_expiration, kind: TimeoutKind::Outbound { token }, })); token diff --git a/ql-engine/src/engine/stream.rs b/ql-engine/src/engine/stream.rs new file mode 100644 index 00000000..35095973 --- /dev/null +++ b/ql-engine/src/engine/stream.rs @@ -0,0 +1,563 @@ +use std::{ + collections::{HashMap, VecDeque}, + time::Instant, +}; + +use super::{Token, ring::SeqRing}; +use crate::{ + StreamId, + wire::{ + StreamSeq, + stream::{CloseCode, CloseTarget, StreamAck, StreamFrame, StreamFrameClose}, + }, +}; + +// todo: need to figure out protocol behavior for: if the peer ACKs your Open and then stays silent forever, the stream will stay pending forever + +pub const STREAM_WINDOW_CAPACITY: usize = 8; +pub const STREAM_WINDOW_SIZE: u32 = STREAM_WINDOW_CAPACITY as u32; +pub const STREAM_ACK_EAGER_THRESHOLD: u32 = STREAM_WINDOW_SIZE / 2; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamSide { + Request, + Response, +} + +#[derive(Debug)] +pub struct StreamMeta { + pub stream_id: StreamId, + pub last_activity: Instant, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OutboundPhase { + Ready, + FinPending, + FinQueued, + Closed, +} + +impl OutboundPhase { + pub fn from_prefix(fin: bool) -> Self { + if fin { Self::FinQueued } else { Self::Ready } + } + + pub fn is_closed(&self) -> bool { + *self == Self::Closed + } + + pub fn can_queue_data(&self) -> bool { + *self == Self::Ready + } + + pub fn finish(&mut self) { + *self = match *self { + Self::Ready | Self::FinPending => Self::FinPending, + Self::FinQueued => Self::FinQueued, + Self::Closed => Self::Closed, + }; + } + + pub fn queue_fin(&mut self) -> bool { + if *self != Self::FinPending { + return false; + } + *self = Self::FinQueued; + true + } + + pub fn close(&mut self) -> bool { + if *self == Self::Closed { + return false; + } + *self = Self::Closed; + true + } +} + +#[derive(Debug)] +pub struct InboundState { + pub closed: bool, +} + +impl InboundState { + pub fn new() -> Self { + Self { closed: false } + } + + pub fn close(&mut self) -> bool { + if self.closed { + return false; + } + self.closed = true; + true + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InFlightWriteState { + /// The frame has never been handed out to be written. + Ready, + /// The frame was handed out and is awaiting `complete_write`. + Issued, + /// The frame write completed and is waiting for retransmit eligibility. + WaitingRetry { retry_at: Instant }, +} + +#[derive(Debug)] +pub struct InFlightFrame { + pub tx_seq: StreamSeq, + pub frame: StreamFrame, + pub attempt: u8, + pub write_state: InFlightWriteState, +} + +#[derive(Debug)] +pub enum BufferIncomingResult { + Duplicate, + AlreadyBuffered, + Buffered { out_of_order: bool }, + OutOfWindow, +} + +// TODO: does it really make sense to have terminal control frames have sequence ids? +#[derive(Debug)] +pub struct StreamControl { + pub pending: VecDeque, + pub in_flight: SeqRing, + pub next_tx_seq: StreamSeq, + pub recv_buffer: SeqRing, + pub ack_dirty: bool, + pub ack_immediate: bool, + pub ack_delay_token: Option, + pub ack_outbound_token: Option, + pub last_sent_ack_base: StreamSeq, + pub fast_recovery: Option, +} + +impl Default for StreamControl { + fn default() -> Self { + Self { + pending: VecDeque::new(), + in_flight: SeqRing::new(StreamSeq::START), + next_tx_seq: StreamSeq::START, + recv_buffer: SeqRing::new(StreamSeq::START), + ack_dirty: false, + ack_immediate: false, + ack_delay_token: None, + ack_outbound_token: None, + last_sent_ack_base: StreamSeq(0), + fast_recovery: None, + } + } +} + +impl StreamControl { + pub fn take_tx_seq(&mut self) -> StreamSeq { + let tx_seq = self.next_tx_seq; + self.next_tx_seq = self.next_tx_seq.next(); + tx_seq + } + + pub fn send_window_has_space(&self) -> bool { + self.in_flight.accepts_seq(self.next_tx_seq) + } + + pub fn committed_rx_seq(&self) -> StreamSeq { + self.recv_buffer.base_seq().prev() + } + + pub fn queue_frame_back(&mut self, frame: StreamFrame) { + self.pending.push_back(frame); + } + + pub fn queue_frame_front(&mut self, frame: StreamFrame) { + self.pending.push_front(frame); + } + + pub fn note_ack(&mut self, immediate: bool) { + self.ack_dirty = true; + self.ack_immediate |= immediate; + } + + pub fn clear_ack_schedule(&mut self) { + self.ack_dirty = false; + self.ack_immediate = false; + self.ack_delay_token = None; + } + + pub fn maybe_force_ack_for_progress(&mut self) { + if !self.ack_dirty { + return; + } + let committed = self.committed_rx_seq(); + let progressed = self + .last_sent_ack_base + .forward_distance_to(committed) + .unwrap_or(0); + if progressed >= STREAM_ACK_EAGER_THRESHOLD { + self.ack_immediate = true; + } + } + + pub fn note_ack_sent(&mut self, ack: StreamAck) { + if ack.base.serial_gt(self.last_sent_ack_base) { + self.last_sent_ack_base = ack.base; + } + } + + pub fn take_piggyback_ack(&mut self, inbound_alive: bool) -> StreamAck { + if !inbound_alive || !self.ack_dirty { + return StreamAck::EMPTY; + } + let ack = self.current_ack(); + self.clear_ack_schedule(); + self.note_ack_sent(ack); + ack + } + + pub fn current_ack(&self) -> StreamAck { + StreamAck { + base: self.committed_rx_seq(), + bitmap: self.recv_buffer.bitmap(), + } + } + + pub fn buffer_incoming( + &mut self, + tx_seq: StreamSeq, + frame: StreamFrame, + ) -> BufferIncomingResult { + if tx_seq.serial_lt(self.recv_buffer.base_seq()) { + return BufferIncomingResult::Duplicate; + } + if !self.recv_buffer.accepts_seq(tx_seq) { + return BufferIncomingResult::OutOfWindow; + } + if self.recv_buffer.contains_key(&tx_seq) { + return BufferIncomingResult::AlreadyBuffered; + } + + let out_of_order = tx_seq != self.recv_buffer.base_seq(); + let _ = self.recv_buffer.insert(tx_seq, frame); + BufferIncomingResult::Buffered { out_of_order } + } + + pub fn pop_next_committable(&mut self) -> Option<(StreamSeq, StreamFrame)> { + self.recv_buffer.take_front() + } + + pub fn insert_in_flight(&mut self, frame: InFlightFrame) { + let _ = self.in_flight.set(frame.tx_seq, frame); + } + + pub fn fast_retransmit_candidate(&self, ack: StreamAck, threshold: u8) -> Option { + if threshold == 0 { + return None; + } + + let hole = self + .in_flight + .iter() + .map(|(tx_seq, _)| tx_seq) + .find(|tx_seq| !Self::ack_covers(ack, *tx_seq))?; + + if self.fast_recovery == Some(hole) { + return None; + } + + let later_acked = self + .in_flight + .iter() + .map(|(tx_seq, _)| tx_seq) + .filter(|tx_seq| tx_seq.serial_gt(hole) && Self::ack_covers(ack, *tx_seq)) + .count(); + + (later_acked >= threshold as usize).then_some(hole) + } + + pub fn schedule_fast_retransmit(&mut self, tx_seq: StreamSeq, now: Instant) { + if let Some(in_flight) = self.in_flight.get_mut(&tx_seq) { + in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at: now }; + self.fast_recovery = Some(tx_seq); + } + } + + pub fn mark_write_issued(&mut self, tx_seq: StreamSeq) -> Option { + let in_flight = self.in_flight.get_mut(&tx_seq)?; + match in_flight.write_state { + InFlightWriteState::Issued => return None, + InFlightWriteState::WaitingRetry { .. } => { + in_flight.attempt = in_flight.attempt.saturating_add(1); + } + InFlightWriteState::Ready => {} + } + in_flight.write_state = InFlightWriteState::Issued; + Some(in_flight.frame.clone()) + } + + pub fn complete_write(&mut self, tx_seq: StreamSeq, retry_at: Instant) { + if let Some(in_flight) = self.in_flight.get_mut(&tx_seq) { + in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at }; + } + } + + pub fn set_retry_deadline(&mut self, tx_seq: StreamSeq, retry_at: Instant) { + if let Some(in_flight) = self.in_flight.get_mut(&tx_seq) { + in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at }; + } + } + + pub fn clear_fast_recovery(&mut self, ack_base: StreamSeq) { + let should_clear = self.fast_recovery.is_some_and(|tx_seq| { + tx_seq.serial_lte(ack_base) || !self.in_flight.contains_key(&tx_seq) + }); + if should_clear { + self.fast_recovery = None; + } + } + + pub fn remove_in_flight(&mut self, tx_seq: StreamSeq) -> Option { + let removed = self.in_flight.remove(&tx_seq); + self.in_flight.advance_empty_front_until(self.next_tx_seq); + if self.fast_recovery == Some(tx_seq) { + self.fast_recovery = None; + } + removed + } + + pub fn clear_transient_buffers(&mut self) { + self.pending.clear(); + self.in_flight.clear_with_base(self.next_tx_seq); + self.recv_buffer + .clear_with_base(self.committed_rx_seq().next()); + self.clear_ack_schedule(); + self.ack_outbound_token = None; + self.fast_recovery = None; + } + + pub fn ack_covers(ack: StreamAck, tx_seq: StreamSeq) -> bool { + if tx_seq.serial_lte(ack.base) { + return true; + } + let Some(delta) = ack.base.forward_distance_to(tx_seq) else { + return false; + }; + if !(1..=STREAM_WINDOW_SIZE).contains(&delta) { + return false; + } + (ack.bitmap & (1u8 << (delta - 1))) != 0 + } +} + +#[derive(Debug)] +pub struct InitiatorStream { + pub request: OutboundPhase, + pub response: InboundState, +} + +#[derive(Debug)] +pub struct ResponderStream { + pub request: InboundState, + pub response: OutboundPhase, + pub response_started: bool, +} + +#[derive(Debug)] +pub struct ProvisionalStream { + pub timeout_token: Token, +} + +#[derive(Debug)] +pub enum StreamRole { + Initiator(InitiatorStream), + Responder(ResponderStream), + Provisional(ProvisionalStream), +} + +#[derive(Debug)] +pub struct StreamState { + pub meta: StreamMeta, + pub control: StreamControl, + pub role: StreamRole, +} + +impl StreamState { + pub fn parts_mut(&mut self) -> (&mut StreamMeta, &mut StreamControl, &mut StreamRole) { + (&mut self.meta, &mut self.control, &mut self.role) + } + + pub fn outbound_mut(&mut self, side: StreamSide) -> Option<&mut OutboundPhase> { + match &mut self.role { + StreamRole::Initiator(state) if side == StreamSide::Request => Some(&mut state.request), + StreamRole::Responder(state) if side == StreamSide::Response => { + Some(&mut state.response) + } + _ => None, + } + } + + pub fn inbound_mut(&mut self, side: StreamSide) -> Option<&mut InboundState> { + match &mut self.role { + StreamRole::Initiator(state) if side == StreamSide::Response => { + Some(&mut state.response) + } + StreamRole::Responder(state) if side == StreamSide::Request => Some(&mut state.request), + _ => None, + } + } + + pub fn provisional_timeout_token(&self) -> Option { + match &self.role { + StreamRole::Provisional(state) => Some(state.timeout_token), + _ => None, + } + } + + pub fn outbound_side(&self) -> Option { + match &self.role { + StreamRole::Initiator(_) => Some(StreamSide::Request), + StreamRole::Responder(_) => Some(StreamSide::Response), + StreamRole::Provisional(_) => None, + } + } + + pub fn inbound_side(&self) -> Option { + match &self.role { + StreamRole::Initiator(_) => Some(StreamSide::Response), + StreamRole::Responder(_) => Some(StreamSide::Request), + StreamRole::Provisional(_) => None, + } + } + + pub fn is_provisional(&self) -> bool { + matches!(&self.role, StreamRole::Provisional(_)) + } + + pub fn can_reap(&self) -> bool { + if !self.control.pending.is_empty() + || !self.control.in_flight.is_empty() + || !self.control.recv_buffer.is_empty() + || self.control.ack_dirty + || self.control.ack_outbound_token.is_some() + { + return false; + } + match &self.role { + StreamRole::Initiator(state) => state.request.is_closed() && state.response.closed, + StreamRole::Responder(state) => state.request.closed && state.response.is_closed(), + StreamRole::Provisional(_) => false, + } + } +} + +#[derive(Debug, Default)] +pub struct StreamStore { + streams: HashMap, + order: Vec, + cursor: usize, +} + +impl StreamStore { + pub fn len(&self) -> usize { + self.streams.len() + } + + pub fn contains_key(&self, stream_id: &StreamId) -> bool { + self.streams.contains_key(stream_id) + } + + pub fn insert(&mut self, stream_id: StreamId, stream: StreamState) -> Option { + if !self.streams.contains_key(&stream_id) { + self.order.push(stream_id); + } + self.streams.insert(stream_id, stream) + } + + pub fn get(&self, stream_id: &StreamId) -> Option<&StreamState> { + self.streams.get(stream_id) + } + + pub fn get_mut(&mut self, stream_id: &StreamId) -> Option<&mut StreamState> { + self.streams.get_mut(stream_id) + } + + pub fn remove(&mut self, stream_id: &StreamId) -> Option { + let removed = self.streams.remove(stream_id); + if removed.is_some() { + if let Some(index) = self.order.iter().position(|id| id == stream_id) { + self.order.remove(index); + if self.order.is_empty() { + self.cursor = 0; + } else if index < self.cursor { + self.cursor -= 1; + } else if self.cursor >= self.order.len() { + self.cursor = 0; + } + } + } + removed + } + + pub fn values(&self) -> impl Iterator { + self.streams.values() + } + + pub fn values_mut(&mut self) -> impl Iterator { + self.streams.values_mut() + } + + pub fn iter(&self) -> impl Iterator { + self.streams.iter() + } + + pub fn into_inner(self) -> HashMap { + self.streams + } + + pub fn scan_from_cursor(&self) -> impl Iterator + '_ { + let len = self.order.len(); + (0..len).map(move |offset| self.order[(self.cursor + offset) % len]) + } + + pub fn advance_cursor_after(&mut self, stream_id: StreamId) { + if let Some(index) = self.order.iter().position(|id| *id == stream_id) { + self.cursor = if self.order.is_empty() { + 0 + } else { + (index + 1) % self.order.len() + }; + } + } + + pub fn stream_retry_deadline(&self) -> Option { + self.streams + .values() + .flat_map(|stream| { + stream + .control + .in_flight + .iter() + .filter_map(|(_, in_flight)| match in_flight.write_state { + InFlightWriteState::WaitingRetry { retry_at } => Some(retry_at), + InFlightWriteState::Ready | InFlightWriteState::Issued => None, + }) + }) + .min() + } +} + +pub fn close_frame( + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, +) -> StreamFrame { + StreamFrame::Close(StreamFrameClose { + stream_id, + target, + code, + payload, + }) +} diff --git a/ql-engine/src/engine/tests/handshake.rs b/ql-engine/src/engine/tests/handshake.rs new file mode 100644 index 00000000..83fe1f67 --- /dev/null +++ b/ql-engine/src/engine/tests/handshake.rs @@ -0,0 +1,824 @@ +use super::*; + +fn handshake_bytes( + sender: XID, + recipient: XID, + record: wire::handshake::HandshakeRecord, +) -> Vec { + wire::encode_record(&QlRecord { + header: QlHeader { sender, recipient }, + payload: QlPayload::Handshake(record), + }) +} + +fn build_reply( + initiator_identity: &QlIdentity, + responder_identity: &QlIdentity, + responder_crypto: &TestCrypto, + hello: &wire::handshake::Hello, + packet_id: u32, +) -> wire::handshake::HelloReply { + let hello_bytes = wire::encode_value(hello); + let hello_view = wire::access_value::(&hello_bytes).unwrap(); + let (reply, _secrets) = wire::handshake::respond_hello( + responder_identity, + responder_crypto, + initiator_identity.xid, + &initiator_identity.signing_public_key, + &initiator_identity.encapsulation_public_key, + hello_view, + wire::ControlMeta { + packet_id: PacketId(packet_id), + valid_until: wire::now_secs().saturating_add(60), + }, + ) + .unwrap(); + reply +} + +fn build_confirm( + initiator_identity: &QlIdentity, + responder_identity: &QlIdentity, + hello: &wire::handshake::Hello, + reply: &wire::handshake::HelloReply, + initiator_secret: &SymmetricKey, + packet_id: u32, +) -> wire::handshake::Confirm { + let reply_bytes = wire::encode_value(reply); + let reply_view = + wire::access_value::(&reply_bytes).unwrap(); + let (confirm, _session_key) = wire::handshake::build_confirm( + initiator_identity, + responder_identity.xid, + &responder_identity.signing_public_key, + hello, + reply_view, + initiator_secret, + wire::ControlMeta { + packet_id: PacketId(packet_id), + valid_until: wire::now_secs().saturating_add(60), + }, + ) + .unwrap(); + confirm +} + +fn pump_between(a: &mut EngineWrapper, b: &mut EngineWrapper, now: Instant) { + loop { + let mut progressed = false; + + while let Some(write) = a.take_next_write() { + let bytes = write.bytes.clone(); + let _ = a.complete_write_collect(write.id, Ok(())); + let _ = b.run_tick_collect(now, EngineInput::Incoming(bytes)); + progressed = true; + } + + while let Some(write) = b.take_next_write() { + let bytes = write.bytes.clone(); + let _ = b.complete_write_collect(write.id, Ok(())); + let _ = a.run_tick_collect(now, EngineInput::Incoming(bytes)); + progressed = true; + } + + if !progressed { + break; + } + } +} + +#[test] +fn handshake_deadline_is_derived_from_peer_state() { + let mut config = EngineConfig::default(); + config.handshake_timeout = Duration::from_secs(5); + config.handshake_retry_interval = Duration::ZERO; + config.max_handshake_retries = 0; + + let identity = test_identity(); + let peer_identity = test_identity(); + let mut engine = EngineWrapper::new( + Engine::new( + config, + identity.clone(), + Some(peer_from_identity(&peer_identity)), + ), + TestCrypto::new(103), + ); + let now = Instant::now(); + + let _outputs = engine.run_tick_collect(now, EngineInput::Connect); + assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(5))); + + let write = engine.take_next_write().unwrap(); + let _outputs = engine.complete_write_collect(write.id, Ok(())); + assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(5))); + + let outputs = engine.run_tick_collect(now + Duration::from_secs(4), EngineInput::TimerExpired); + assert!(!outputs.iter().any(|output| { + matches!( + output, + EngineOutput::PeerStatusChanged { + session: PeerSession::Disconnected, + .. + } + ) + })); + assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(5))); + + let outputs = engine.run_tick_collect(now + Duration::from_secs(5), EngineInput::TimerExpired); + assert!(outputs.iter().any(|output| { + matches!( + output, + EngineOutput::PeerStatusChanged { + session: PeerSession::Disconnected, + .. + } + ) + })); +} + +#[test] +fn initiator_retries_hello_after_retry_interval() { + let mut config = EngineConfig::default(); + config.handshake_timeout = Duration::from_secs(5); + config.handshake_retry_interval = Duration::from_millis(250); + config.max_handshake_retries = 2; + + let identity = test_identity(); + let peer_identity = test_identity(); + let mut engine = EngineWrapper::new( + Engine::new(config, identity, Some(peer_from_identity(&peer_identity))), + TestCrypto::new(111), + ); + let now = Instant::now(); + + let _ = engine.run_tick_collect(now, EngineInput::Connect); + let hello_write = engine.take_next_write().unwrap(); + let hello_bytes = hello_write.bytes.clone(); + let _ = engine.complete_write_collect(hello_write.id, Ok(())); + + let _ = engine.run_tick_collect(now + Duration::from_millis(250), EngineInput::TimerExpired); + let retry_write = engine.take_next_write().unwrap(); + assert_eq!(retry_write.bytes, hello_bytes); + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Initiator { + stage: HandshakeInitiator::WaitingHelloReply { retry_count: 1, .. }, + .. + }) + )); +} + +#[test] +fn responder_retries_hello_reply_after_retry_interval() { + let mut config = EngineConfig::default(); + config.handshake_timeout = Duration::from_secs(5); + config.handshake_retry_interval = Duration::from_millis(250); + config.max_handshake_retries = 2; + + let responder_identity = test_identity(); + let initiator_identity = test_identity(); + let initiator_crypto = TestCrypto::new(112); + let responder_crypto = TestCrypto::new(113); + let mut engine = EngineWrapper::new( + Engine::new( + config, + responder_identity.clone(), + Some(peer_from_identity(&initiator_identity)), + ), + responder_crypto, + ); + let now = Instant::now(); + + let (hello, _secret) = wire::handshake::build_hello( + &initiator_identity, + &initiator_crypto, + responder_identity.xid, + &responder_identity.encapsulation_public_key, + wire::ControlMeta { + packet_id: PacketId(81), + valid_until: wire::now_secs().saturating_add(60), + }, + ) + .unwrap(); + + let _ = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + initiator_identity.xid, + responder_identity.xid, + wire::handshake::HandshakeRecord::Hello(hello), + )), + ); + let reply_write = engine.take_next_write().unwrap(); + let reply_bytes = reply_write.bytes.clone(); + let _ = engine.complete_write_collect(reply_write.id, Ok(())); + + let _ = engine.run_tick_collect(now + Duration::from_millis(250), EngineInput::TimerExpired); + let retry_write = engine.take_next_write().unwrap(); + assert_eq!(retry_write.bytes, reply_bytes); + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Responder { + stage: HandshakeResponder::WaitingConfirm { retry_count: 1, .. }, + .. + }) + )); +} + +#[test] +fn initiator_retries_confirm_after_retry_interval() { + let mut config = EngineConfig::default(); + config.handshake_timeout = Duration::from_secs(5); + config.handshake_retry_interval = Duration::from_millis(250); + config.max_handshake_retries = 2; + + let identity = test_identity(); + let peer_identity = test_identity(); + let responder_crypto = TestCrypto::new(114); + let mut engine = EngineWrapper::new( + Engine::new( + config, + identity.clone(), + Some(peer_from_identity(&peer_identity)), + ), + TestCrypto::new(115), + ); + let now = Instant::now(); + + let _ = engine.run_tick_collect(now, EngineInput::Connect); + let hello_write = engine.take_next_write().unwrap(); + let hello_record = wire::decode_record(&hello_write.bytes).unwrap(); + let QlPayload::Handshake(wire::handshake::HandshakeRecord::Hello(hello)) = hello_record.payload + else { + panic!("expected hello record"); + }; + let _ = engine.complete_write_collect(hello_write.id, Ok(())); + + let reply = build_reply(&identity, &peer_identity, &responder_crypto, &hello, 82); + let _ = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + peer_identity.xid, + identity.xid, + wire::handshake::HandshakeRecord::HelloReply(reply), + )), + ); + let confirm_write = engine.take_next_write().unwrap(); + let confirm_bytes = confirm_write.bytes.clone(); + let _ = engine.complete_write_collect(confirm_write.id, Ok(())); + + let _ = engine.run_tick_collect(now + Duration::from_millis(250), EngineInput::TimerExpired); + let retry_write = engine.take_next_write().unwrap(); + assert_eq!(retry_write.bytes, confirm_bytes); + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Initiator { + stage: HandshakeInitiator::WaitingReady { retry_count: 1, .. }, + .. + }) + )); +} + +#[test] +fn duplicate_hello_resends_hello_reply() { + let responder_identity = test_identity(); + let initiator_identity = test_identity(); + let initiator_crypto = TestCrypto::new(116); + let responder_crypto = TestCrypto::new(117); + let mut engine = EngineWrapper::new( + Engine::new( + EngineConfig::default(), + responder_identity.clone(), + Some(peer_from_identity(&initiator_identity)), + ), + responder_crypto, + ); + let now = Instant::now(); + + let (hello, _secret) = wire::handshake::build_hello( + &initiator_identity, + &initiator_crypto, + responder_identity.xid, + &responder_identity.encapsulation_public_key, + wire::ControlMeta { + packet_id: PacketId(83), + valid_until: wire::now_secs().saturating_add(60), + }, + ) + .unwrap(); + let hello_bytes = handshake_bytes( + initiator_identity.xid, + responder_identity.xid, + wire::handshake::HandshakeRecord::Hello(hello), + ); + + let _ = engine.run_tick_collect(now, EngineInput::Incoming(hello_bytes.clone())); + let reply_write = engine.take_next_write().unwrap(); + let reply_bytes = reply_write.bytes.clone(); + let _ = engine.complete_write_collect(reply_write.id, Ok(())); + + let _ = engine.run_tick_collect(now, EngineInput::Incoming(hello_bytes)); + let resent_reply = engine.take_next_write().unwrap(); + assert_eq!(resent_reply.bytes, reply_bytes); +} + +#[test] +fn duplicate_hello_reply_resends_confirm() { + let identity = test_identity(); + let peer_identity = test_identity(); + let responder_crypto = TestCrypto::new(118); + let mut engine = EngineWrapper::new( + Engine::new( + EngineConfig::default(), + identity.clone(), + Some(peer_from_identity(&peer_identity)), + ), + TestCrypto::new(119), + ); + let now = Instant::now(); + + let _ = engine.run_tick_collect(now, EngineInput::Connect); + let hello_write = engine.take_next_write().unwrap(); + let hello_record = wire::decode_record(&hello_write.bytes).unwrap(); + let QlPayload::Handshake(wire::handshake::HandshakeRecord::Hello(hello)) = hello_record.payload + else { + panic!("expected hello record"); + }; + let _ = engine.complete_write_collect(hello_write.id, Ok(())); + + let reply = build_reply(&identity, &peer_identity, &responder_crypto, &hello, 84); + let reply_bytes = handshake_bytes( + peer_identity.xid, + identity.xid, + wire::handshake::HandshakeRecord::HelloReply(reply.clone()), + ); + + let _ = engine.run_tick_collect(now, EngineInput::Incoming(reply_bytes.clone())); + let confirm_write = engine.take_next_write().unwrap(); + let confirm_bytes = confirm_write.bytes.clone(); + let _ = engine.complete_write_collect(confirm_write.id, Ok(())); + + let _ = engine.run_tick_collect(now, EngineInput::Incoming(reply_bytes)); + let resent_confirm = engine.take_next_write().unwrap(); + assert_eq!(resent_confirm.bytes, confirm_bytes); +} + +#[test] +fn responder_resends_ready_for_duplicate_confirm_after_connecting() { + let responder_identity = test_identity(); + let initiator_identity = test_identity(); + let initiator_crypto = TestCrypto::new(120); + let responder_crypto = TestCrypto::new(121); + let mut engine = EngineWrapper::new( + Engine::new( + EngineConfig::default(), + responder_identity.clone(), + Some(peer_from_identity(&initiator_identity)), + ), + responder_crypto, + ); + let now = Instant::now(); + + let (hello, initiator_secret) = wire::handshake::build_hello( + &initiator_identity, + &initiator_crypto, + responder_identity.xid, + &responder_identity.encapsulation_public_key, + wire::ControlMeta { + packet_id: PacketId(85), + valid_until: wire::now_secs().saturating_add(60), + }, + ) + .unwrap(); + let _ = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + initiator_identity.xid, + responder_identity.xid, + wire::handshake::HandshakeRecord::Hello(hello.clone()), + )), + ); + + let reply_write = engine.take_next_write().unwrap(); + let reply_record = wire::decode_record(&reply_write.bytes).unwrap(); + let QlPayload::Handshake(wire::handshake::HandshakeRecord::HelloReply(reply)) = + reply_record.payload + else { + panic!("expected hello reply"); + }; + let _ = engine.complete_write_collect(reply_write.id, Ok(())); + + let confirm = build_confirm( + &initiator_identity, + &responder_identity, + &hello, + &reply, + &initiator_secret, + 86, + ); + let confirm_bytes = handshake_bytes( + initiator_identity.xid, + responder_identity.xid, + wire::handshake::HandshakeRecord::Confirm(confirm.clone()), + ); + + let _ = engine.run_tick_collect(now, EngineInput::Incoming(confirm_bytes.clone())); + let ready_write = engine.take_next_write().unwrap(); + let ready_bytes = ready_write.bytes.clone(); + let _ = engine.complete_write_collect(ready_write.id, Ok(())); + + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Connected { + recent_ready: Some(_), + .. + }) + )); + + let _ = engine.run_tick_collect(now, EngineInput::Incoming(confirm_bytes)); + let resent_ready = engine.take_next_write().unwrap(); + assert_eq!(resent_ready.bytes, ready_bytes); +} + +#[test] +fn stale_hello_reply_does_not_abort_fresh_handshake() { + let identity = test_identity(); + let peer_identity = test_identity(); + let responder_crypto = TestCrypto::new(122); + let stale_initiator_crypto = TestCrypto::new(123); + let mut engine = EngineWrapper::new( + Engine::new( + EngineConfig::default(), + identity.clone(), + Some(peer_from_identity(&peer_identity)), + ), + TestCrypto::new(124), + ); + let now = Instant::now(); + + let (stale_hello, _stale_secret) = wire::handshake::build_hello( + &identity, + &stale_initiator_crypto, + peer_identity.xid, + &peer_identity.encapsulation_public_key, + wire::ControlMeta { + packet_id: PacketId(87), + valid_until: wire::now_secs().saturating_add(60), + }, + ) + .unwrap(); + let stale_reply = build_reply( + &identity, + &peer_identity, + &responder_crypto, + &stale_hello, + 88, + ); + + let _ = engine.run_tick_collect(now, EngineInput::Connect); + let hello_write = engine.take_next_write().unwrap(); + let hello_record = wire::decode_record(&hello_write.bytes).unwrap(); + let QlPayload::Handshake(wire::handshake::HandshakeRecord::Hello(current_hello)) = + hello_record.payload + else { + panic!("expected hello record"); + }; + let _ = engine.complete_write_collect(hello_write.id, Ok(())); + + let outputs = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + peer_identity.xid, + identity.xid, + wire::handshake::HandshakeRecord::HelloReply(stale_reply), + )), + ); + assert!(!outputs.iter().any(|output| matches!( + output, + EngineOutput::PeerStatusChanged { + session: PeerSession::Disconnected, + .. + } + ))); + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Initiator { + stage: HandshakeInitiator::WaitingHelloReply { .. }, + .. + }) + )); + + let current_reply = build_reply( + &identity, + &peer_identity, + &responder_crypto, + ¤t_hello, + 89, + ); + let _ = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + peer_identity.xid, + identity.xid, + wire::handshake::HandshakeRecord::HelloReply(current_reply), + )), + ); + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Initiator { + stage: HandshakeInitiator::WaitingReady { .. }, + .. + }) + )); + assert!(engine.take_next_write().is_some()); +} + +#[test] +fn stale_confirm_does_not_abort_fresh_handshake() { + let responder_identity = test_identity(); + let initiator_identity = test_identity(); + let responder_crypto = TestCrypto::new(125); + let initiator_crypto = TestCrypto::new(126); + let stale_initiator_crypto = TestCrypto::new(127); + let mut engine = EngineWrapper::new( + Engine::new( + EngineConfig::default(), + responder_identity.clone(), + Some(peer_from_identity(&initiator_identity)), + ), + responder_crypto, + ); + let now = Instant::now(); + + let (stale_hello, stale_secret) = wire::handshake::build_hello( + &initiator_identity, + &stale_initiator_crypto, + responder_identity.xid, + &responder_identity.encapsulation_public_key, + wire::ControlMeta { + packet_id: PacketId(90), + valid_until: wire::now_secs().saturating_add(60), + }, + ) + .unwrap(); + let stale_reply = build_reply( + &initiator_identity, + &responder_identity, + &TestCrypto::new(128), + &stale_hello, + 91, + ); + let stale_confirm = build_confirm( + &initiator_identity, + &responder_identity, + &stale_hello, + &stale_reply, + &stale_secret, + 92, + ); + + let (hello, initiator_secret) = wire::handshake::build_hello( + &initiator_identity, + &initiator_crypto, + responder_identity.xid, + &responder_identity.encapsulation_public_key, + wire::ControlMeta { + packet_id: PacketId(93), + valid_until: wire::now_secs().saturating_add(60), + }, + ) + .unwrap(); + let _ = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + initiator_identity.xid, + responder_identity.xid, + wire::handshake::HandshakeRecord::Hello(hello.clone()), + )), + ); + + let reply_write = engine.take_next_write().unwrap(); + let reply_record = wire::decode_record(&reply_write.bytes).unwrap(); + let QlPayload::Handshake(wire::handshake::HandshakeRecord::HelloReply(reply)) = + reply_record.payload + else { + panic!("expected hello reply"); + }; + let _ = engine.complete_write_collect(reply_write.id, Ok(())); + + let outputs = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + initiator_identity.xid, + responder_identity.xid, + wire::handshake::HandshakeRecord::Confirm(stale_confirm), + )), + ); + assert!(!outputs.iter().any(|output| matches!( + output, + EngineOutput::PeerStatusChanged { + session: PeerSession::Disconnected, + .. + } + ))); + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Responder { + stage: HandshakeResponder::WaitingConfirm { .. }, + .. + }) + )); + + let confirm = build_confirm( + &initiator_identity, + &responder_identity, + &hello, + &reply, + &initiator_secret, + 94, + ); + let _ = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + initiator_identity.xid, + responder_identity.xid, + wire::handshake::HandshakeRecord::Confirm(confirm), + )), + ); + assert!(engine.take_next_write().is_some()); +} + +#[test] +fn initiator_waits_for_ready_before_connecting() { + let config = EngineConfig::default(); + let identity = test_identity(); + let peer_identity = test_identity(); + let responder_crypto = TestCrypto::new(129); + let mut engine = EngineWrapper::new( + Engine::new( + config, + identity.clone(), + Some(peer_from_identity(&peer_identity)), + ), + TestCrypto::new(130), + ); + let now = Instant::now(); + + let _outputs = engine.run_tick_collect(now, EngineInput::Connect); + + let hello_write = engine.take_next_write().unwrap(); + let hello_record = wire::decode_record(&hello_write.bytes).unwrap(); + let QlPayload::Handshake(wire::handshake::HandshakeRecord::Hello(hello)) = hello_record.payload + else { + panic!("expected hello record"); + }; + let _outputs = engine.complete_write_collect(hello_write.id, Ok(())); + + let reply = build_reply(&identity, &peer_identity, &responder_crypto, &hello, 95); + let _outputs = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + peer_identity.xid, + identity.xid, + wire::handshake::HandshakeRecord::HelloReply(reply), + )), + ); + + let confirm_write = engine.take_next_write().unwrap(); + let _outputs = engine.complete_write_collect(confirm_write.id, Ok(())); + + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Initiator { + stage: HandshakeInitiator::WaitingReady { .. }, + .. + }) + )); + assert!(matches!( + engine.open_stream(now, Vec::new(), None, StreamConfig::default()), + Err(QlError::MissingSession) + )); + + let pending_session_key = match engine.peer.as_ref().map(|peer| &peer.session) { + Some(PeerSession::Initiator { session_key, .. }) => session_key.clone(), + other => panic!("expected pending initiator session, got {other:?}"), + }; + let outputs = engine.run_tick_collect( + now, + EngineInput::Incoming(handshake_bytes( + peer_identity.xid, + identity.xid, + wire::handshake::HandshakeRecord::Ready(wire::handshake::build_ready( + QlHeader { + sender: peer_identity.xid, + recipient: identity.xid, + }, + &pending_session_key, + wire::ControlMeta { + packet_id: PacketId(96), + valid_until: wire::now_secs().saturating_add(60), + }, + [9; wire::encrypted_message::NONCE_SIZE], + )), + )), + ); + + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Connected { .. }) + )); + assert!(outputs.iter().any(|output| matches!( + output, + EngineOutput::PeerStatusChanged { + session: PeerSession::Connected { .. }, + .. + } + ))); +} + +#[test] +fn handshake_retry_limit_disconnects_initiator() { + let mut config = EngineConfig::default(); + config.handshake_timeout = Duration::from_secs(5); + config.handshake_retry_interval = Duration::from_millis(250); + config.max_handshake_retries = 1; + + let identity = test_identity(); + let peer_identity = test_identity(); + let mut engine = EngineWrapper::new( + Engine::new( + config, + identity, + Some(peer_from_identity(&peer_identity)), + ), + TestCrypto::new(131), + ); + let now = Instant::now(); + + let _ = engine.run_tick_collect(now, EngineInput::Connect); + let hello_write = engine.take_next_write().unwrap(); + let hello_bytes = hello_write.bytes.clone(); + let _ = engine.complete_write_collect(hello_write.id, Ok(())); + + let _ = engine.run_tick_collect(now + Duration::from_millis(250), EngineInput::TimerExpired); + let retry_write = engine.take_next_write().unwrap(); + assert_eq!(retry_write.bytes, hello_bytes); + let _ = engine.complete_write_collect(retry_write.id, Ok(())); + + let outputs = engine.run_tick_collect(now + Duration::from_millis(500), EngineInput::TimerExpired); + assert!(outputs.iter().any(|output| matches!( + output, + EngineOutput::PeerStatusChanged { + session: PeerSession::Disconnected, + .. + } + ))); + assert!(matches!( + engine.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Disconnected) + )); +} + +#[test] +fn simultaneous_connect_converges_to_connected_peers() { + let config = EngineConfig::default(); + let identity_a = test_identity(); + let identity_b = test_identity(); + let mut a = EngineWrapper::new( + Engine::new( + config, + identity_a.clone(), + Some(peer_from_identity(&identity_b)), + ), + TestCrypto::new(132), + ); + let mut b = EngineWrapper::new( + Engine::new( + config, + identity_b.clone(), + Some(peer_from_identity(&identity_a)), + ), + TestCrypto::new(133), + ); + let now = Instant::now(); + + let _ = a.run_tick_collect(now, EngineInput::Connect); + let _ = b.run_tick_collect(now, EngineInput::Connect); + + let hello_a = a.take_next_write().unwrap(); + let hello_a_bytes = hello_a.bytes.clone(); + let _ = a.complete_write_collect(hello_a.id, Ok(())); + + let hello_b = b.take_next_write().unwrap(); + let hello_b_bytes = hello_b.bytes.clone(); + let _ = b.complete_write_collect(hello_b.id, Ok(())); + + let _ = a.run_tick_collect(now, EngineInput::Incoming(hello_b_bytes)); + let _ = b.run_tick_collect(now, EngineInput::Incoming(hello_a_bytes)); + + pump_between(&mut a, &mut b, now); + + assert!(matches!(a.peer.as_ref().map(|peer| &peer.session), Some(PeerSession::Connected { .. }))); + assert!(matches!(b.peer.as_ref().map(|peer| &peer.session), Some(PeerSession::Connected { .. }))); +} diff --git a/ql-engine/src/engine/tests/liveness.rs b/ql-engine/src/engine/tests/liveness.rs new file mode 100644 index 00000000..2c3ac9d4 --- /dev/null +++ b/ql-engine/src/engine/tests/liveness.rs @@ -0,0 +1,87 @@ +use super::*; + +#[test] +fn replayed_heartbeat_is_ignored() { + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(EngineConfig::default(), 101, 4); + let heartbeat = wire::heartbeat::encrypt_heartbeat( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + wire::heartbeat::HeartbeatBody { + meta: wire::ControlMeta { + packet_id: PacketId(7), + valid_until: wire::now_secs().saturating_add(60), + }, + }, + [3; wire::encrypted_message::NONCE_SIZE], + ); + let bytes = wire::encode_record(&heartbeat); + + let _first = engine.run_tick_collect(now, EngineInput::Incoming(bytes.clone())); + let first_write = engine.take_next_write().unwrap(); + let first_record = wire::decode_record(&first_write.bytes).unwrap(); + assert!(matches!(first_record.payload, QlPayload::Heartbeat(_))); + let _ = engine.complete_write_collect(first_write.id, Ok(())); + + let _second = engine.run_tick_collect(now, EngineInput::Incoming(bytes)); + assert!(engine.take_next_write().is_none()); +} + +#[test] +fn keepalive_deadline_is_derived_from_peer_state() { + let mut config = EngineConfig::default(); + config.keep_alive = Some(KeepAliveConfig { + interval: Duration::from_secs(5), + timeout: Duration::from_secs(7), + }); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 103, 6); + + let heartbeat = encrypt_heartbeat_record( + peer.xid, + engine.engine.identity.xid, + &session_key, + 1, + [7; wire::encrypted_message::NONCE_SIZE], + ); + let outputs = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&heartbeat))); + let _ = outputs; + assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(5))); + + let write = engine.take_next_write().unwrap(); + let record = wire::decode_record(&write.bytes).unwrap(); + assert!(matches!(record.payload, QlPayload::Heartbeat(_))); + let _ = engine.complete_write_collect(write.id, Ok(())); + + let outputs = engine.run_tick_collect(now + Duration::from_secs(5), EngineInput::TimerExpired); + let _ = outputs; + assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(12))); + + let write = engine.take_next_write().unwrap(); + let record = wire::decode_record(&write.bytes).unwrap(); + assert!(matches!(record.payload, QlPayload::Heartbeat(_))); + let _ = engine.complete_write_collect(write.id, Ok(())); + + let outputs = engine.run_tick_collect(now + Duration::from_secs(12), EngineInput::TimerExpired); + assert!(outputs.iter().any(|output| { + matches!( + output, + EngineOutput::PeerStatusChanged { + session: PeerSession::Disconnected, + .. + } + ) + })); +} diff --git a/ql-engine/src/engine/tests/mod.rs b/ql-engine/src/engine/tests/mod.rs new file mode 100644 index 00000000..1c396ce0 --- /dev/null +++ b/ql-engine/src/engine/tests/mod.rs @@ -0,0 +1,410 @@ +mod handshake; +mod liveness; +mod peer; +mod stream; + +use std::{ + cell::Cell, + mem, + ops::{Deref, DerefMut}, + time::{Duration, Instant}, +}; + +use bc_components::{SymmetricKey, MLDSA, MLKEM, XID}; + +use crate::{ + engine::{state::StreamNamespace, stream::*, *}, + identity::QlIdentity, + wire::{self, stream::*, QlHeader, QlPayload, QlRecord, StreamSeq}, + PacketId, Peer, +}; + +#[derive(Clone)] +struct TestCrypto { + nonce_seed: u8, + nonce_counter: Cell, +} + +impl TestCrypto { + fn new(seed: u8) -> Self { + Self { + nonce_seed: seed, + nonce_counter: Cell::new(0), + } + } +} + +impl QlCrypto for TestCrypto { + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self.nonce_seed.wrapping_add(self.nonce_counter.get()); + self.nonce_counter + .set(self.nonce_counter.get().wrapping_add(1)); + data.fill(value); + } +} + +#[derive(Clone, Copy)] +enum Side { + A, + B, +} + +impl Side { + fn other(self) -> Self { + match self { + Side::A => Side::B, + Side::B => Side::A, + } + } +} + +struct Harness { + now: Instant, + a: EngineWrapper, + b: EngineWrapper, +} + +struct SingleEngineHarness { + now: Instant, + engine: EngineWrapper, + peer: QlIdentity, + session_key: SymmetricKey, +} + +impl SingleEngineHarness { + fn connected(config: EngineConfig, nonce_seed: u8, session_fill: u8) -> Self { + let local_identity = test_identity(); + let peer = test_identity(); + let session_key = SymmetricKey::from_data([session_fill; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut engine = Engine::new( + config, + local_identity.clone(), + Some(peer_from_identity(&peer)), + ); + engine.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key: session_key.clone(), + keepalive: KeepAliveState::default(), + recent_ready: None, + }; + Self { + now: Instant::now(), + engine: EngineWrapper::new(engine, TestCrypto::new(nonce_seed)), + peer, + session_key, + } + } +} + +impl Harness { + fn connected(config: EngineConfig) -> Self { + let identity_a = test_identity(); + let identity_b = test_identity(); + let peer_a = peer_from_identity(&identity_a); + let peer_b = peer_from_identity(&identity_b); + let crypto_a = TestCrypto::new(1); + let crypto_b = TestCrypto::new(2); + let session_key = SymmetricKey::from_data([7; SymmetricKey::SYMMETRIC_KEY_SIZE]); + let mut a = Engine::new(config, identity_a.clone(), Some(peer_b)); + let mut b = Engine::new(config, identity_b.clone(), Some(peer_a)); + a.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key: session_key.clone(), + keepalive: KeepAliveState::default(), + recent_ready: None, + }; + b.peer.as_mut().unwrap().session = PeerSession::Connected { + session_key, + keepalive: KeepAliveState::default(), + recent_ready: None, + }; + Self { + now: Instant::now(), + a: EngineWrapper::new(a, crypto_a), + b: EngineWrapper::new(b, crypto_b), + } + } + + fn run_side(&mut self, side: Side, input: EngineInput) { + match side { + Side::A => self.a.run_tick(self.now, input), + Side::B => self.b.run_tick(self.now, input), + } + + while let Some(write) = match side { + Side::A => self.a.take_next_write(), + Side::B => self.b.take_next_write(), + } { + let bytes = write.bytes.clone(); + self.complete_side_write(side, write.id, Ok(())); + self.run_side(side.other(), EngineInput::Incoming(bytes)); + } + } + + fn complete_side_write(&mut self, side: Side, write_id: WriteId, result: Result<(), QlError>) { + match side { + Side::A => self.a.complete_write(write_id, result), + Side::B => self.b.complete_write(write_id, result), + } + } +} + +struct EngineWrapper { + engine: Engine, + crypto: TestCrypto, + outputs: Vec, +} + +impl Deref for EngineWrapper { + type Target = Engine; + + fn deref(&self) -> &Self::Target { + &self.engine + } +} + +impl DerefMut for EngineWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.engine + } +} + +impl EngineWrapper { + fn new(engine: Engine, crypto: TestCrypto) -> Self { + Self { + engine, + crypto, + outputs: Vec::new(), + } + } + + fn run_tick(&mut self, now: Instant, input: EngineInput) { + self.engine + .run_tick(now, input, &self.crypto, &mut |output| { + self.outputs.push(output) + }); + } + + fn run_tick_collect(&mut self, now: Instant, input: EngineInput) -> Vec { + self.run_tick(now, input); + self.drain_outputs() + } + + fn complete_write(&mut self, write_id: WriteId, result: Result<(), QlError>) { + self.engine + .complete_write(write_id, result, &mut |output| self.outputs.push(output)); + } + + fn take_next_write(&mut self) -> Option { + self.engine.take_next_write(&self.crypto) + } + + fn complete_write_collect( + &mut self, + write_id: WriteId, + result: Result<(), QlError>, + ) -> Vec { + self.complete_write(write_id, result); + self.drain_outputs() + } + + fn open_stream( + &mut self, + now: Instant, + request_head: Vec, + request_prefix: Option, + config: StreamConfig, + ) -> Result { + self.engine + .open_stream(now, request_head, request_prefix, config) + } + + fn drain_outputs(&mut self) -> Vec { + mem::take(&mut self.outputs) + } +} + +fn test_identity() -> QlIdentity { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + QlIdentity::from_keys( + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + ) +} + +fn peer_from_identity(identity: &QlIdentity) -> Peer { + Peer { + peer: identity.xid, + signing_key: identity.signing_public_key.clone(), + encapsulation_key: identity.encapsulation_public_key.clone(), + } +} + +fn decode_stream_body(bytes: &[u8], session_key: &SymmetricKey) -> (QlHeader, StreamBody) { + let record = wire::decode_record(bytes).unwrap(); + let aad = record.header.aad(); + let QlPayload::Stream(encrypted) = record.payload else { + panic!("expected stream payload"); + }; + let plaintext = encrypted.decrypt(session_key, &aad).unwrap(); + let body = wire::access_value::(&plaintext) + .and_then(wire::deserialize_value) + .unwrap(); + (record.header, body) +} + +fn encrypt_heartbeat_record( + sender: XID, + recipient: XID, + session_key: &SymmetricKey, + packet_id: u32, + nonce: [u8; wire::encrypted_message::NONCE_SIZE], +) -> QlRecord { + wire::heartbeat::encrypt_heartbeat( + QlHeader { sender, recipient }, + session_key, + wire::heartbeat::HeartbeatBody { + meta: crate::wire::ControlMeta { + packet_id: PacketId(packet_id), + valid_until: wire::now_secs().saturating_add(60), + }, + }, + nonce, + ) +} + +fn insert_inflight_gap_stream(engine: &mut EngineWrapper, stream_id: StreamId, now: Instant) { + let retry_at = now + Duration::from_secs(60); + let mut stream = StreamState { + meta: StreamMeta { + stream_id, + last_activity: now, + }, + control: StreamControl::default(), + role: StreamRole::Initiator(InitiatorStream { + request: OutboundPhase::from_prefix(false), + response: InboundState::new(), + }), + }; + let control = &mut stream.control; + control.next_tx_seq = StreamSeq(6); + control.insert_in_flight(InFlightFrame { + tx_seq: StreamSeq::START, + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + attempt: 0, + write_state: InFlightWriteState::WaitingRetry { retry_at }, + }); + for (tx_seq, byte) in [(2, b'a'), (3, b'b'), (4, b'c'), (5, b'd')] { + control.insert_in_flight(InFlightFrame { + tx_seq: StreamSeq(tx_seq), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: vec![byte], + fin: false, + }, + }), + attempt: 0, + write_state: InFlightWriteState::WaitingRetry { retry_at }, + }); + } + engine.streams.insert(stream_id, stream); +} + +fn insert_inflight_stream_with_data( + engine: &mut EngineWrapper, + stream_id: StreamId, + now: Instant, + data_seqs: &[u32], +) { + let retry_at = now + Duration::from_secs(60); + let mut stream = StreamState { + meta: StreamMeta { + stream_id, + last_activity: now, + }, + control: StreamControl::default(), + role: StreamRole::Initiator(InitiatorStream { + request: OutboundPhase::from_prefix(false), + response: InboundState::new(), + }), + }; + let control = &mut stream.control; + control.next_tx_seq = StreamSeq(data_seqs.iter().copied().max().unwrap_or(1) + 1); + control.insert_in_flight(InFlightFrame { + tx_seq: StreamSeq::START, + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + attempt: 0, + write_state: InFlightWriteState::WaitingRetry { retry_at }, + }); + for &tx_seq in data_seqs { + control.insert_in_flight(InFlightFrame { + tx_seq: StreamSeq(tx_seq), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: vec![b'a' + (tx_seq as u8)], + fin: false, + }, + }), + attempt: 0, + write_state: InFlightWriteState::WaitingRetry { retry_at }, + }); + } + engine.streams.insert(stream_id, stream); +} + +fn insert_unwritten_inflight_stream_with_data( + engine: &mut EngineWrapper, + stream_id: StreamId, + now: Instant, + data_seqs: &[u32], +) { + let mut stream = StreamState { + meta: StreamMeta { + stream_id, + last_activity: now, + }, + control: StreamControl::default(), + role: StreamRole::Initiator(InitiatorStream { + request: OutboundPhase::from_prefix(false), + response: InboundState::new(), + }), + }; + let control = &mut stream.control; + control.next_tx_seq = StreamSeq(data_seqs.iter().copied().max().unwrap_or(1) + 1); + control.insert_in_flight(InFlightFrame { + tx_seq: StreamSeq::START, + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + attempt: 0, + write_state: InFlightWriteState::Ready, + }); + for &tx_seq in data_seqs { + control.insert_in_flight(InFlightFrame { + tx_seq: StreamSeq(tx_seq), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: vec![b'a' + (tx_seq as u8)], + fin: false, + }, + }), + attempt: 0, + write_state: InFlightWriteState::Ready, + }); + } + engine.streams.insert(stream_id, stream); +} diff --git a/ql-engine/src/engine/tests/peer.rs b/ql-engine/src/engine/tests/peer.rs new file mode 100644 index 00000000..11ea08aa --- /dev/null +++ b/ql-engine/src/engine/tests/peer.rs @@ -0,0 +1,42 @@ +use super::*; + +#[test] +fn replayed_unpair_is_ignored_after_rebind() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key: _session_key, + } = SingleEngineHarness::connected(config, 111, 5); + let peer_b = peer_from_identity(&peer); + let bytes = wire::encode_record(&wire::unpair::build_unpair_record( + &peer, + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + wire::ControlMeta { + packet_id: PacketId(9), + valid_until: wire::now_secs().saturating_add(60), + }, + )); + + let first = engine.run_tick_collect(now, EngineInput::Incoming(bytes.clone())); + assert!(first + .iter() + .any(|output| matches!(output, EngineOutput::ClearPeer))); + assert!(engine.peer.is_none()); + + let _ = engine.run_tick_collect(now, EngineInput::BindPeer(peer_b.clone())); + assert!(engine.peer.is_some()); + + let second = engine.run_tick_collect(now, EngineInput::Incoming(bytes)); + assert!(!second + .iter() + .any(|output| matches!(output, EngineOutput::ClearPeer))); + assert_eq!( + engine.peer.as_ref().map(|peer| peer.peer), + Some(peer_b.peer) + ); +} diff --git a/ql-engine/src/engine/tests/stream.rs b/ql-engine/src/engine/tests/stream.rs new file mode 100644 index 00000000..1c8e7238 --- /dev/null +++ b/ql-engine/src/engine/tests/stream.rs @@ -0,0 +1,1569 @@ +#![allow(clippy::too_many_lines)] + +use super::*; + +#[test] +fn simultaneous_opens_use_disjoint_stream_id_namespaces() { + let config = EngineConfig::default(); + let mut harness = Harness::connected(config); + let now = harness.now; + + let stream_id_a = harness + .a + .open_stream(now, b"a-open".to_vec(), None, StreamConfig::default()) + .unwrap(); + let stream_id_b = harness + .b + .open_stream(now, b"b-open".to_vec(), None, StreamConfig::default()) + .unwrap(); + + assert_ne!(stream_id_a, stream_id_b); + assert!(StreamNamespace::for_local( + harness.a.engine.identity.xid, + harness.b.engine.identity.xid + ) + .matches(stream_id_a)); + assert!(StreamNamespace::for_local( + harness.b.engine.identity.xid, + harness.a.engine.identity.xid + ) + .matches(stream_id_b)); + + let write_a = harness.a.take_next_write().unwrap(); + let write_b = harness.b.take_next_write().unwrap(); + + let _ = harness.a.complete_write_collect(write_a.id, Ok(())); + let _ = harness.b.complete_write_collect(write_b.id, Ok(())); + + let outputs_a_incoming = harness + .a + .run_tick_collect(now, EngineInput::Incoming(write_b.bytes)); + let outputs_b_incoming = harness + .b + .run_tick_collect(now, EngineInput::Incoming(write_a.bytes)); + + assert!(outputs_a_incoming.iter().any(|output| matches!( + output, + EngineOutput::InboundStreamOpened { + stream_id, + request_head, + .. + } if *stream_id == stream_id_b && request_head == b"b-open" + ))); + assert!(outputs_b_incoming.iter().any(|output| matches!( + output, + EngineOutput::InboundStreamOpened { + stream_id, + request_head, + .. + } if *stream_id == stream_id_a && request_head == b"a-open" + ))); + assert_eq!(harness.a.streams.len(), 2); + assert_eq!(harness.b.streams.len(), 2); +} + +#[test] +fn invalid_future_frame_does_not_ack_outstanding_open() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 31, 5); + let stream_id = engine + .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) + .unwrap(); + + let message = StreamMessage { + tx_seq: StreamSeq(2), + ack: crate::wire::stream::StreamAck { + base: StreamSeq(0), + bitmap: 0, + }, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"resp".to_vec(), + fin: false, + }, + }), + }; + + let body = StreamBody::Message(message); + let record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &body, + [9; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_incoming = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); + + assert!(!outputs_incoming + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + + let stream = engine.streams.get(&stream_id).unwrap(); + assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); +} + +#[test] +fn ack_for_issued_open_is_applied_before_write_completion() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 33, 7); + let stream_id = engine + .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) + .unwrap(); + + let _open_write = engine.take_next_write().unwrap(); + + let message = StreamMessage { + tx_seq: StreamSeq::START, + ack: StreamAck { + base: StreamSeq::START, + bitmap: 0, + }, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"resp".to_vec(), + fin: false, + }, + }), + }; + let record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(message), + [10; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_incoming = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); + + assert!(outputs_incoming.iter().any(|output| matches!( + output, + EngineOutput::InboundData { + stream_id: id, + bytes, + } if *id == stream_id && bytes == b"resp" + ))); + let stream = engine.streams.get(&stream_id).unwrap(); + assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); +} + +#[test] +fn ack_does_not_retire_ready_data() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 35, 8); + let stream_id = engine + .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) + .unwrap(); + + let _open_write = engine.take_next_write().unwrap(); + let _ = engine.run_tick_collect( + now, + EngineInput::OutboundData { + stream_id, + bytes: b"body".to_vec(), + }, + ); + + let message = StreamMessage { + tx_seq: StreamSeq::START, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0, + }, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"resp".to_vec(), + fin: false, + }, + }), + }; + let record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(message), + [11; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_incoming = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); + + assert!(outputs_incoming.iter().any(|output| matches!( + output, + EngineOutput::InboundData { + stream_id: id, + bytes, + } if *id == stream_id && bytes == b"resp" + ))); + + let stream = engine.streams.get(&stream_id).unwrap(); + assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); + assert!(stream.control.in_flight.contains_key(&StreamSeq(2))); + + let write = engine.take_next_write().unwrap(); + let (_, body) = decode_stream_body(&write.bytes, &session_key); + assert!(matches!( + body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(2), + frame: StreamFrame::Data(StreamFrameData { + stream_id: id, + chunk: BodyChunk { bytes, fin: false }, + }), + .. + }) if id == stream_id && bytes == b"body" + )); +} + +#[test] +fn late_failed_write_after_remote_close_ack_is_ignored() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 37, 9); + let stream_id = engine + .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) + .unwrap(); + + let open_write = engine.take_next_write().unwrap(); + + let record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(StreamMessage { + tx_seq: StreamSeq::START, + ack: StreamAck { + base: StreamSeq::START, + bitmap: 0, + }, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Close(StreamFrameClose { + stream_id, + target: CloseTarget::Both, + code: CloseCode::PROTOCOL, + payload: Vec::new(), + }), + }), + [12; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_close = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); + + assert!(outputs_close.iter().any(|output| matches!( + output, + EngineOutput::OutboundFailed { + stream_id: id, + error: QlError::StreamClosed { + target: CloseTarget::Both, + code: CloseCode::PROTOCOL, + payload, + }, + } if *id == stream_id + && payload.is_empty() + ))); + assert!(outputs_close.iter().any(|output| matches!( + output, + EngineOutput::InboundFailed { + stream_id: id, + error: QlError::StreamClosed { + target: CloseTarget::Both, + code: CloseCode::PROTOCOL, + payload, + }, + } if *id == stream_id + && payload.is_empty() + ))); + let stream = engine.streams.get(&stream_id).unwrap(); + assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); + + let outputs_late = engine.complete_write_collect(open_write.id, Err(QlError::SendFailed)); + assert!(outputs_late.is_empty()); + assert!(engine.streams.contains_key(&stream_id)); +} + +#[test] +fn local_close_both_is_idempotent() { + let SingleEngineHarness { + now, + mut engine, + session_key, + .. + } = SingleEngineHarness::connected(EngineConfig::default(), 39, 10); + let stream_id = engine + .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) + .unwrap(); + + let open_write = engine.take_next_write().unwrap(); + let _ = engine.complete_write_collect(open_write.id, Ok(())); + + let _ = engine.run_tick_collect( + now, + EngineInput::CloseStream { + stream_id, + target: CloseTarget::Request, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }, + ); + let request_close = engine.take_next_write().unwrap(); + let (_, request_close_body) = decode_stream_body(&request_close.bytes, &session_key); + assert!(matches!( + request_close_body, + StreamBody::Message(StreamMessage { + frame: StreamFrame::Close(StreamFrameClose { + stream_id: id, + target: CloseTarget::Request, + .. + }), + .. + }) if id == stream_id + )); + let _ = engine.complete_write_collect(request_close.id, Ok(())); + + let _ = engine.run_tick_collect( + now, + EngineInput::CloseStream { + stream_id, + target: CloseTarget::Both, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }, + ); + let both_close = engine.take_next_write().unwrap(); + let (_, both_close_body) = decode_stream_body(&both_close.bytes, &session_key); + assert!(matches!( + both_close_body, + StreamBody::Message(StreamMessage { + frame: StreamFrame::Close(StreamFrameClose { + stream_id: id, + target: CloseTarget::Both, + .. + }), + .. + }) if id == stream_id + )); + let _ = engine.complete_write_collect(both_close.id, Ok(())); + + let _ = engine.run_tick_collect( + now, + EngineInput::CloseStream { + stream_id, + target: CloseTarget::Both, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }, + ); + assert!(engine.take_next_write().is_none()); +} + +#[test] +fn out_of_order_remote_stream_buffers_until_open_arrives() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 41, 6); + let stream_id = + StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); + + let data_message = StreamMessage { + tx_seq: StreamSeq(2), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"hello".to_vec(), + fin: false, + }, + }), + }; + let data_body = StreamBody::Message(data_message); + let data_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &data_body, + [11; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_data = engine.run_tick_collect( + now, + EngineInput::Incoming(wire::encode_record(&data_record)), + ); + + assert!(!outputs_data + .iter() + .any(|output| matches!(output, EngineOutput::InboundStreamOpened { .. }))); + assert!(!outputs_data + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + assert!(engine.take_next_write().is_some()); + assert!(engine + .streams + .get(&stream_id) + .is_some_and(StreamState::is_provisional)); + + let open_message = StreamMessage { + tx_seq: StreamSeq(1), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(crate::wire::stream::StreamFrameOpen { + stream_id, + request_head: b"late-open".to_vec(), + request_prefix: None, + }), + }; + let open_body = StreamBody::Message(open_message); + let open_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &open_body, + [12; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs_open = engine.run_tick_collect( + now, + EngineInput::Incoming(wire::encode_record(&open_record)), + ); + + assert!(outputs_open.iter().any(|output| matches!( + output, + EngineOutput::InboundStreamOpened { + stream_id: id, + request_head, + request_prefix: None, + } if *id == stream_id && request_head == b"late-open" + ))); + assert!(outputs_open.iter().any(|output| matches!( + output, + EngineOutput::InboundData { + stream_id: id, + bytes, + } if *id == stream_id && bytes == b"hello" + ))); +} + +#[test] +fn delayed_ack_only_does_not_consume_sequence_space() { + let mut harness = Harness::connected(EngineConfig::default()); + let stream_id = harness + .a + .open_stream( + harness.now, + b"open-head".to_vec(), + None, + StreamConfig::default(), + ) + .unwrap(); + let open_write = harness.a.take_next_write().unwrap(); + harness.complete_side_write(Side::A, open_write.id, Ok(())); + harness.run_side(Side::B, EngineInput::Incoming(open_write.bytes)); + + harness.now += EngineConfig::default().stream_ack_delay; + harness.run_side(Side::B, EngineInput::TimerExpired); + + let _outputs_b = harness.b.drain_outputs(); + + let stream = harness.b.streams.get(&stream_id).unwrap(); + assert!(stream.control.in_flight.is_empty()); + assert_eq!(stream.control.next_tx_seq, StreamSeq::START); +} + +#[test] +fn half_window_progress_flushes_ack_before_timer() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 61, 8); + let stream_id = + StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); + let messages = [ + StreamMessage { + tx_seq: StreamSeq(1), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(crate::wire::stream::StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + }, + StreamMessage { + tx_seq: StreamSeq(2), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"a".to_vec(), + fin: false, + }, + }), + }, + StreamMessage { + tx_seq: StreamSeq(3), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"b".to_vec(), + fin: false, + }, + }), + }, + StreamMessage { + tx_seq: StreamSeq(4), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"c".to_vec(), + fin: false, + }, + }), + }, + ]; + + for message in messages.iter().take(3) { + let body = StreamBody::Message(message.clone()); + let record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &body, + [message.tx_seq.0 as u8; wire::encrypted_message::NONCE_SIZE], + ); + let _outputs = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); + assert!(engine.take_next_write().is_none()); + } + + let body = StreamBody::Message(messages[3].clone()); + let record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &body, + [4; wire::encrypted_message::NONCE_SIZE], + ); + let _outputs = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); + + let ack_write = engine.take_next_write().unwrap(); + let (_, ack_body) = decode_stream_body(&ack_write.bytes, &session_key); + assert!(matches!(ack_body, StreamBody::Ack(_))); +} + +#[test] +fn out_of_order_loss_reports_selective_ack_bitmap() { + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(EngineConfig::default(), 71, 3); + let stream_id = + StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); + let messages = [ + StreamMessage { + tx_seq: StreamSeq(1), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + }, + StreamMessage { + tx_seq: StreamSeq(2), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"a".to_vec(), + fin: false, + }, + }), + }, + StreamMessage { + tx_seq: StreamSeq(4), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"c".to_vec(), + fin: false, + }, + }), + }, + StreamMessage { + tx_seq: StreamSeq(5), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"d".to_vec(), + fin: false, + }, + }), + }, + ]; + + for message in &messages[..2] { + let record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(message.clone()), + [message.tx_seq.0 as u8; wire::encrypted_message::NONCE_SIZE], + ); + let _outputs = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); + assert!(engine.take_next_write().is_none()); + } + + let record4 = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(messages[2].clone()), + [4; wire::encrypted_message::NONCE_SIZE], + ); + let outputs4 = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record4))); + let ack_write4 = engine.take_next_write().unwrap(); + let (_, ack_body4) = decode_stream_body(&ack_write4.bytes, &session_key); + assert!(matches!( + ack_body4, + StreamBody::Ack(StreamAckBody { + stream_id: id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0010, + }, + .. + }) if id == stream_id + )); + assert!(!outputs4 + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + let _ = engine.complete_write_collect(ack_write4.id, Ok(())); + + let record5 = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(messages[3].clone()), + [5; wire::encrypted_message::NONCE_SIZE], + ); + let outputs5 = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record5))); + let ack_write5 = engine.take_next_write().unwrap(); + let (_, ack_body5) = decode_stream_body(&ack_write5.bytes, &session_key); + assert!(matches!( + ack_body5, + StreamBody::Ack(StreamAckBody { + stream_id: id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0110, + }, + .. + }) if id == stream_id + )); + assert!(!outputs5 + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); +} + +#[test] +fn selective_ack_only_body_retires_acked_gap_tail() { + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(EngineConfig::default(), 81, 2); + let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_inflight_gap_stream(&mut engine, stream_id, now); + + let ack_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0110, + }, + valid_until: wire::now_secs().saturating_add(60), + }), + [9; wire::encrypted_message::NONCE_SIZE], + ); + + let outputs = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&ack_record))); + + assert!(!outputs + .iter() + .any(|output| matches!(output, EngineOutput::OutboundFailed { .. }))); + let stream = engine.streams.get(&stream_id).unwrap(); + let remaining: Vec<_> = stream + .control + .in_flight + .iter() + .map(|(seq, _)| seq) + .collect(); + assert_eq!(remaining, vec![StreamSeq(3)]); + assert_eq!(stream.control.next_tx_seq, StreamSeq(6)); +} + +#[test] +fn fast_retransmit_resends_oldest_gap_when_threshold_met() { + let mut config = EngineConfig::default(); + config.stream_fast_retransmit_threshold = 2; + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 83, 9); + let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_inflight_gap_stream(&mut engine, stream_id, now); + + let ack_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0110, + }, + valid_until: wire::now_secs().saturating_add(60), + }), + [10; wire::encrypted_message::NONCE_SIZE], + ); + + let _outputs = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&ack_record))); + + let write = engine.take_next_write().unwrap(); + let (_, body) = decode_stream_body(&write.bytes, &session_key); + assert!(matches!( + body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(3), + frame: StreamFrame::Data(StreamFrameData { .. }), + .. + }) + )); + + let stream = engine.streams.get(&stream_id).unwrap(); + let remaining: Vec<_> = stream + .control + .in_flight + .iter() + .map(|(seq, _)| seq) + .collect(); + assert_eq!(remaining, vec![StreamSeq(3)]); + let frame = stream.control.in_flight.get(&StreamSeq(3)).unwrap(); + assert_eq!(frame.attempt, 1); + assert!(matches!(frame.write_state, InFlightWriteState::Issued)); +} + +#[test] +fn fast_retransmit_respects_configured_threshold() { + let mut config = EngineConfig::default(); + config.stream_fast_retransmit_threshold = 3; + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 85, 10); + let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_inflight_gap_stream(&mut engine, stream_id, now); + + let ack_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0110, + }, + valid_until: wire::now_secs().saturating_add(60), + }), + [11; wire::encrypted_message::NONCE_SIZE], + ); + + let _outputs = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&ack_record))); + + if let Some(write) = engine.take_next_write() { + let (_, body) = decode_stream_body(&write.bytes, &session_key); + assert!(matches!(body, StreamBody::Ack(_))); + } + + let stream = engine.streams.get(&stream_id).unwrap(); + let remaining: Vec<_> = stream + .control + .in_flight + .iter() + .map(|(seq, _)| seq) + .collect(); + assert_eq!(remaining, vec![StreamSeq(3)]); + let frame = stream.control.in_flight.get(&StreamSeq(3)).unwrap(); + assert_eq!(frame.attempt, 0); + assert!(matches!( + frame.write_state, + InFlightWriteState::WaitingRetry { .. } + )); +} + +#[test] +fn timeout_retransmit_reuses_original_tx_seq_and_slot() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer: _, + session_key, + } = SingleEngineHarness::connected(config, 91, 1); + let tracked_stream_id = engine + .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) + .unwrap(); + let write = engine.take_next_write().unwrap(); + let (_, initial_body) = decode_stream_body(&write.bytes, &session_key); + assert!(matches!( + &initial_body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(1), + frame: StreamFrame::Open(_), + .. + }) + )); + let _outputs_written = engine.complete_write_collect(write.id, Ok(())); + + let stream = engine.streams.get(&tracked_stream_id).unwrap(); + assert_eq!(stream.control.in_flight.len(), 1); + assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); + assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); + + let _outputs_timeout = + engine.run_tick_collect(now + config.stream_ack_timeout, EngineInput::TimerExpired); + let retransmit_write = engine.take_next_write().unwrap(); + let (_, retransmit_body) = decode_stream_body(&retransmit_write.bytes, &session_key); + assert!(matches!( + retransmit_body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(1), + frame: StreamFrame::Open(StreamFrameOpen { stream_id, .. }), + .. + }) if stream_id == tracked_stream_id + )); + + let stream = engine.streams.get(&tracked_stream_id).unwrap(); + assert_eq!(stream.control.in_flight.len(), 1); + assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); + assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); + assert_eq!( + stream + .control + .in_flight + .get(&StreamSeq::START) + .unwrap() + .attempt, + 1 + ); +} + +#[test] +fn take_next_write_drains_multiple_stream_frames_before_completion() { + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(EngineConfig::default(), 93, 12); + let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_unwritten_inflight_stream_with_data(&mut engine, stream_id, now, &[2, 3]); + + let writes = { + let mut writes = Vec::new(); + while let Some(write) = engine.take_next_write() { + writes.push(write); + } + writes + }; + assert_eq!(writes.len(), 3); + + let tx_seqs: Vec<_> = writes + .iter() + .map( + |write| match decode_stream_body(&write.bytes, &session_key).1 { + StreamBody::Message(message) => message.tx_seq, + other => panic!("expected stream message, got {other:?}"), + }, + ) + .collect(); + assert_eq!(tx_seqs, vec![StreamSeq::START, StreamSeq(2), StreamSeq(3)]); + + let unique_ids: std::collections::HashSet<_> = writes.iter().map(|write| write.id).collect(); + assert_eq!(unique_ids.len(), writes.len()); + assert_eq!(engine.state.active_writes.len(), writes.len()); + assert!(engine.take_next_write().is_none()); + + let stream = engine.streams.get(&stream_id).unwrap(); + assert!(stream + .control + .in_flight + .iter() + .all(|(_, in_flight)| matches!(in_flight.write_state, InFlightWriteState::Issued))); +} + +#[test] +fn take_next_write_does_not_reissue_outstanding_frame() { + let SingleEngineHarness { + now, + mut engine, + peer, + session_key: _session_key, + } = SingleEngineHarness::connected(EngineConfig::default(), 95, 13); + let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_unwritten_inflight_stream_with_data(&mut engine, stream_id, now, &[]); + + let write = engine.take_next_write().unwrap(); + assert!(engine.take_next_write().is_none()); + assert!(engine.state.active_writes.contains_key(&write.id)); +} + +#[test] +fn take_next_write_round_robins_across_ready_streams() { + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(EngineConfig::default(), 97, 14); + let stream_id1 = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + let stream_id2 = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_unwritten_inflight_stream_with_data(&mut engine, stream_id1, now, &[2]); + insert_unwritten_inflight_stream_with_data(&mut engine, stream_id2, now, &[2]); + + let scheduled: Vec<_> = { + let mut writes = Vec::new(); + while let Some(write) = engine.take_next_write() { + writes.push(write); + } + writes + } + .into_iter() + .map( + |write| match decode_stream_body(&write.bytes, &session_key).1 { + StreamBody::Message(message) => (message.frame.stream_id(), message.tx_seq), + other => panic!("expected stream message, got {other:?}"), + }, + ) + .collect(); + + assert_eq!( + scheduled, + vec![ + (stream_id1, StreamSeq::START), + (stream_id2, StreamSeq::START), + (stream_id1, StreamSeq(2)), + (stream_id2, StreamSeq(2)), + ] + ); +} + +#[test] +fn stale_ack_delay_timer_after_piggyback_does_not_emit_extra_ack_only() { + let mut harness = Harness::connected(EngineConfig::default()); + let stream_id = harness + .a + .open_stream( + harness.now, + b"open-head".to_vec(), + None, + StreamConfig::default(), + ) + .unwrap(); + let open_write = harness.a.take_next_write().unwrap(); + harness.complete_side_write(Side::A, open_write.id, Ok(())); + harness.run_side(Side::B, EngineInput::Incoming(open_write.bytes)); + let _ = harness.a.drain_outputs(); + let _ = harness.b.drain_outputs(); + + harness.run_side( + Side::B, + EngineInput::OutboundData { + stream_id, + bytes: b"resp".to_vec(), + }, + ); + let _ = harness.a.drain_outputs(); + let _ = harness.b.drain_outputs(); + + harness.now += EngineConfig::default().stream_ack_delay; + harness.run_side(Side::B, EngineInput::TimerExpired); + let _outputs_b_timer = harness.b.drain_outputs(); + + assert!(harness.b.take_next_write().is_none()); +} + +#[test] +fn provisional_timeout_after_late_open_is_ignored() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 63, 11); + let stream_id = + StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); + + let early_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(2), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"hello".to_vec(), + fin: false, + }, + }), + }), + [31; wire::encrypted_message::NONCE_SIZE], + ); + let _ = engine.run_tick_collect( + now, + EngineInput::Incoming(wire::encode_record(&early_record)), + ); + + let open_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(StreamMessage { + tx_seq: StreamSeq::START, + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"late-open".to_vec(), + request_prefix: None, + }), + }), + [32; wire::encrypted_message::NONCE_SIZE], + ); + let outputs_open = engine.run_tick_collect( + now, + EngineInput::Incoming(wire::encode_record(&open_record)), + ); + assert!(outputs_open.iter().any(|output| matches!( + output, + EngineOutput::InboundStreamOpened { stream_id: id, .. } if *id == stream_id + ))); + + let _outputs_timeout = + engine.run_tick_collect(now + config.packet_expiration, EngineInput::TimerExpired); + + assert!(matches!( + engine.streams.get(&stream_id).map(|stream| &stream.role), + Some(StreamRole::Responder(_)) + )); + if let Some(write) = engine.take_next_write() { + let (_, body) = decode_stream_body(&write.bytes, &session_key); + assert!(!matches!( + body, + StreamBody::Message(StreamMessage { + frame: StreamFrame::Close(_), + .. + }) + )); + } +} + +#[test] +fn ack_only_write_failure_immediately_requeues_ack_without_spending_extra_seq() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 65, 12); + let stream_id = + StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); + let open_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(StreamMessage { + tx_seq: StreamSeq::START, + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + }), + [33; wire::encrypted_message::NONCE_SIZE], + ); + let outputs_open = engine.run_tick_collect( + now, + EngineInput::Incoming(wire::encode_record(&open_record)), + ); + assert!(outputs_open.iter().any(|output| matches!( + output, + EngineOutput::InboundStreamOpened { stream_id: id, .. } if *id == stream_id + ))); + + let _outputs_ack = + engine.run_tick_collect(now + config.stream_ack_delay, EngineInput::TimerExpired); + let ack_write = engine.take_next_write().unwrap(); + let (_, ack_body) = decode_stream_body(&ack_write.bytes, &session_key); + assert!(matches!( + ack_body, + StreamBody::Ack(StreamAckBody { + stream_id: id, + ack: StreamAck { + base: StreamSeq::START, + bitmap: 0, + }, + .. + }) if id == stream_id + )); + + let outputs_failed = engine.complete_write_collect(ack_write.id, Err(QlError::SendFailed)); + assert!(!outputs_failed + .iter() + .any(|output| matches!(output, EngineOutput::StreamReaped { .. }))); + let retry_write = engine.take_next_write().unwrap(); + let (_, retry_body) = decode_stream_body(&retry_write.bytes, &session_key); + assert!(matches!( + retry_body, + StreamBody::Ack(StreamAckBody { + stream_id: id, + ack: StreamAck { + base: StreamSeq::START, + bitmap: 0, + }, + .. + }) if id == stream_id + )); + + let _ = engine.complete_write_collect(retry_write.id, Ok(())); + + let _outputs_data = engine.run_tick_collect( + now + config.stream_ack_delay, + EngineInput::OutboundData { + stream_id, + bytes: b"resp".to_vec(), + }, + ); + let response_write = engine.take_next_write().unwrap(); + let (_, body) = decode_stream_body(&response_write.bytes, &session_key); + assert!(matches!( + body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq::START, + frame: StreamFrame::Data(StreamFrameData { + stream_id: id, + chunk: BodyChunk { bytes, fin: false }, + }), + .. + }) if id == stream_id && bytes == b"resp" + )); + let stream = engine.streams.get(&stream_id).unwrap(); + assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); +} + +#[test] +fn duplicate_committed_data_is_acked_without_redelivery() { + let config = EngineConfig::default(); + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 67, 13); + let stream_id = + StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); + + for (nonce, body) in [ + ( + 34u8, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq::START, + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + }), + ), + ( + 35u8, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(2), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"hello".to_vec(), + fin: false, + }, + }), + }), + ), + ] { + let record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &body, + [nonce; wire::encrypted_message::NONCE_SIZE], + ); + let _ = engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); + } + + let duplicate_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(2), + ack: StreamAck::EMPTY, + valid_until: wire::now_secs().saturating_add(60), + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: b"hello".to_vec(), + fin: false, + }, + }), + }), + [36; wire::encrypted_message::NONCE_SIZE], + ); + let outputs_dup = engine.run_tick_collect( + now, + EngineInput::Incoming(wire::encode_record(&duplicate_record)), + ); + + assert!(!outputs_dup + .iter() + .any(|output| matches!(output, EngineOutput::InboundData { .. }))); + let ack_write = engine.take_next_write().unwrap(); + let (_, body) = decode_stream_body(&ack_write.bytes, &session_key); + assert!(matches!( + body, + StreamBody::Ack(StreamAckBody { + stream_id: id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0, + }, + .. + }) if id == stream_id + )); +} + +#[test] +fn repeated_identical_gap_ack_only_fast_retransmits_once() { + let mut config = EngineConfig::default(); + config.stream_fast_retransmit_threshold = 2; + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 69, 14); + let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_inflight_gap_stream(&mut engine, stream_id, now); + + let local_xid = engine.engine.identity.xid; + let remote_xid = peer.xid; + let ack_record = |nonce: u8| { + wire::stream::encrypt_stream( + QlHeader { + sender: remote_xid, + recipient: local_xid, + }, + &session_key, + &StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0110, + }, + valid_until: wire::now_secs().saturating_add(60), + }), + [nonce; wire::encrypted_message::NONCE_SIZE], + ) + }; + + let _outputs_first = engine.run_tick_collect( + now, + EngineInput::Incoming(wire::encode_record(&ack_record(37))), + ); + let write = engine.take_next_write().unwrap(); + let (_, body) = decode_stream_body(&write.bytes, &session_key); + assert!(matches!( + body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(3), + .. + }) + )); + + let _ = engine.complete_write_collect(write.id, Ok(())); + + let _outputs_second = engine.run_tick_collect( + now, + EngineInput::Incoming(wire::encode_record(&ack_record(38))), + ); + assert!(engine.take_next_write().is_none()); +} + +#[test] +fn fast_recovery_clears_after_gap_is_acked_and_allows_next_gap() { + let mut config = EngineConfig::default(); + config.stream_fast_retransmit_threshold = 1; + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 73, 15); + let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_inflight_stream_with_data(&mut engine, stream_id, now, &[2, 3, 4, 5, 6]); + + let first_ack = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0010, + }, + valid_until: wire::now_secs().saturating_add(60), + }), + [39; wire::encrypted_message::NONCE_SIZE], + ); + let _outputs_first = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&first_ack))); + let write_first = engine.take_next_write().unwrap(); + let (_, first_body) = decode_stream_body(&write_first.bytes, &session_key); + assert!(matches!( + first_body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(3), + .. + }) + )); + + let _ = engine.complete_write_collect(write_first.id, Ok(())); + + let second_ack = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: StreamSeq(4), + bitmap: 0b0000_0010, + }, + valid_until: wire::now_secs().saturating_add(60), + }), + [40; wire::encrypted_message::NONCE_SIZE], + ); + let _outputs_second = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&second_ack))); + let write_second = engine.take_next_write().unwrap(); + let (_, second_body) = decode_stream_body(&write_second.bytes, &session_key); + assert!(matches!( + second_body, + StreamBody::Message(StreamMessage { + tx_seq: StreamSeq(5), + .. + }) + )); +} + +#[test] +fn fast_retransmit_and_retry_deadline_same_tick_only_send_once() { + let mut config = EngineConfig::default(); + config.stream_fast_retransmit_threshold = 2; + let SingleEngineHarness { + now, + mut engine, + peer, + session_key, + } = SingleEngineHarness::connected(config, 75, 16); + let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( + engine.engine.identity.xid, + peer.xid, + )); + insert_inflight_gap_stream(&mut engine, stream_id, now); + engine + .streams + .get_mut(&stream_id) + .unwrap() + .control + .set_retry_deadline(StreamSeq(3), now); + + let ack_record = wire::stream::encrypt_stream( + QlHeader { + sender: peer.xid, + recipient: engine.engine.identity.xid, + }, + &session_key, + &StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: StreamSeq(2), + bitmap: 0b0000_0110, + }, + valid_until: wire::now_secs().saturating_add(60), + }), + [41; wire::encrypted_message::NONCE_SIZE], + ); + let _outputs_ack = + engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&ack_record))); + let _write = engine.take_next_write().unwrap(); + assert!(engine.take_next_write().is_none()); + + let _outputs_timeout = engine.run_tick_collect(now, EngineInput::TimerExpired); + assert!(engine.take_next_write().is_none()); +} diff --git a/ql2/src/platform.rs b/ql-engine/src/identity.rs similarity index 55% rename from ql2/src/platform.rs rename to ql-engine/src/identity.rs index 168944d2..b4e12886 100644 --- a/ql2/src/platform.rs +++ b/ql-engine/src/identity.rs @@ -1,13 +1,7 @@ -use std::{future::Future, pin::Pin, time::Duration}; - use bc_components::{ MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, SigningPublicKey, XID, }; -use crate::{engine::PeerSession, Peer, QlError}; - -pub type PlatformFuture<'a, T> = Pin + 'a>>; - #[derive(Debug, Clone)] pub struct QlIdentity { pub xid: XID, @@ -33,19 +27,3 @@ impl QlIdentity { } } } - -pub trait QlCrypto { - fn fill_random_bytes(&self, data: &mut [u8]); -} - -pub trait QlPlatform: QlCrypto { - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; - - fn load_peer(&self) -> PlatformFuture<'_, Option>; - fn persist_peer(&self, peer: Peer); - fn clear_peer(&self); - - fn handle_peer_status(&self, peer: XID, session: &PeerSession); - // fn handle_inbound(&self, event: crate::runtime::HandlerEvent); -} diff --git a/ql2/src/lib.rs b/ql-engine/src/lib.rs similarity index 69% rename from ql2/src/lib.rs rename to ql-engine/src/lib.rs index f89b06d4..b4a1e8ac 100644 --- a/ql2/src/lib.rs +++ b/ql-engine/src/lib.rs @@ -1,14 +1,9 @@ pub mod engine; -mod id; -pub mod platform; +pub mod identity; // pub mod rpc; -// pub mod runtime; pub mod wire; -pub use id::*; - -// #[cfg(test)] -// mod tests; +pub use wire::{PacketId, StreamId}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Peer { @@ -31,12 +26,11 @@ pub enum QlError { Timeout, #[error("send failed")] SendFailed, - #[error("stream rejected {code:?}")] - StreamRejected { code: wire::stream::RejectCode }, - #[error("stream reset {code:?}")] - StreamReset { - dir: wire::stream::Direction, - code: wire::stream::ResetCode, + #[error("stream closed {code:?}")] + StreamClosed { + target: wire::stream::CloseTarget, + code: wire::stream::CloseCode, + payload: Vec, }, #[error("stream protocol error")] StreamProtocol, diff --git a/ql2/src/wire/codec.rs b/ql-engine/src/wire/codec.rs similarity index 100% rename from ql2/src/wire/codec.rs rename to ql-engine/src/wire/codec.rs diff --git a/ql2/src/wire/encrypted_message.rs b/ql-engine/src/wire/encrypted_message.rs similarity index 100% rename from ql2/src/wire/encrypted_message.rs rename to ql-engine/src/wire/encrypted_message.rs diff --git a/ql2/src/wire/handshake/crypto.rs b/ql-engine/src/wire/handshake/crypto.rs similarity index 82% rename from ql2/src/wire/handshake/crypto.rs rename to ql-engine/src/wire/handshake/crypto.rs index 7e78ba40..f2960ae9 100644 --- a/ql2/src/wire/handshake/crypto.rs +++ b/ql-engine/src/wire/handshake/crypto.rs @@ -5,14 +5,17 @@ use bc_components::{ use rkyv::{Archive, Serialize}; use super::{ - verify_signature, ArchivedConfirm, ArchivedHello, ArchivedHelloReply, Confirm, Hello, - HelloReply, + verify_signature, ArchivedConfirm, ArchivedHello, ArchivedHelloReply, ArchivedReady, Confirm, + Hello, HelloReply, Ready, ReadyBody, }; use crate::{ - platform::{QlCrypto, QlIdentity}, + engine::QlCrypto, + identity::QlIdentity, wire::{ - encode_value, ensure_not_expired, AsWireMlKemCiphertext, AsWireNonce, AsWireXid, - ControlMeta, + access_value, deserialize_value, encode_value, + encrypted_message::{EncryptedMessage, NONCE_SIZE}, + ensure_not_expired, AsWireMlKemCiphertext, AsWireNonce, AsWireXid, ControlMeta, + QlHeader, }, QlError, }; @@ -200,6 +203,38 @@ pub fn finalize_confirm( confirm: &ArchivedConfirm, secrets: &ResponderSecrets, ) -> Result { + verify_confirm( + initiator, + responder, + initiator_signing_key, + hello, + reply, + confirm, + )?; + Ok(derive_session_key( + &secrets.initiator_secret, + &secrets.responder_secret, + &handshake_transcript( + initiator, + responder, + &hello.meta, + &hello.nonce, + &hello.kem_ct, + &reply.meta, + &reply.nonce, + &reply.kem_ct, + ), + )) +} + +pub fn verify_confirm( + initiator: XID, + responder: XID, + initiator_signing_key: &MLDSAPublicKey, + hello: &Hello, + reply: &HelloReply, + confirm: &ArchivedConfirm, +) -> Result<(), QlError> { let confirm_meta: ControlMeta = (&confirm.meta).into(); ensure_not_expired(confirm_meta.valid_until)?; let confirm_signature = MLDSASignature::try_from(&confirm.signature)?; @@ -215,11 +250,33 @@ pub fn finalize_confirm( ); let proof_data = confirm_proof_data(&confirm_meta, &transcript); verify_signature(initiator_signing_key, &confirm_signature, &proof_data)?; - Ok(derive_session_key( - &secrets.initiator_secret, - &secrets.responder_secret, - &transcript, - )) + Ok(()) +} + +pub fn build_ready( + header: QlHeader, + session_key: &SymmetricKey, + meta: ControlMeta, + nonce: [u8; NONCE_SIZE], +) -> Ready { + let aad = header.aad(); + let body_bytes = encode_value(&ReadyBody { meta }); + Ready { + encrypted: EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce), + } +} + +pub fn decrypt_ready( + header: &QlHeader, + ready: &mut ArchivedReady, + session_key: &SymmetricKey, +) -> Result { + let aad = header.aad(); + let plaintext = ready.encrypted.decrypt(session_key, &aad)?; + let body = access_value::(plaintext)?; + let body = deserialize_value(body)?; + ensure_not_expired(body.meta.valid_until)?; + Ok(body) } fn handshake_transcript( diff --git a/ql2/src/wire/handshake/mod.rs b/ql-engine/src/wire/handshake/mod.rs similarity index 78% rename from ql2/src/wire/handshake/mod.rs rename to ql-engine/src/wire/handshake/mod.rs index 62b3f43f..756a3e78 100644 --- a/ql2/src/wire/handshake/mod.rs +++ b/ql-engine/src/wire/handshake/mod.rs @@ -1,7 +1,10 @@ use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, Nonce}; use rkyv::{Archive, Deserialize, Serialize}; -use super::{AsWireMlDsaSignature, AsWireMlKemCiphertext, AsWireNonce, ControlMeta}; +use super::{ + encrypted_message::EncryptedMessage, AsWireMlDsaSignature, AsWireMlKemCiphertext, + AsWireNonce, ControlMeta, +}; use crate::QlError; mod crypto; @@ -12,6 +15,7 @@ pub enum HandshakeRecord { Hello(Hello), HelloReply(HelloReply), Confirm(Confirm), + Ready(Ready), } #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -43,6 +47,16 @@ pub struct Confirm { pub signature: MLDSASignature, } +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Ready { + pub encrypted: EncryptedMessage, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct ReadyBody { + pub meta: ControlMeta, +} + pub fn verify_signature( signing_key: &MLDSAPublicKey, signature: &MLDSASignature, diff --git a/ql2/src/wire/heartbeat/crypto.rs b/ql-engine/src/wire/heartbeat/crypto.rs similarity index 100% rename from ql2/src/wire/heartbeat/crypto.rs rename to ql-engine/src/wire/heartbeat/crypto.rs diff --git a/ql2/src/wire/heartbeat/mod.rs b/ql-engine/src/wire/heartbeat/mod.rs similarity index 100% rename from ql2/src/wire/heartbeat/mod.rs rename to ql-engine/src/wire/heartbeat/mod.rs diff --git a/ql2/src/id.rs b/ql-engine/src/wire/id.rs similarity index 68% rename from ql2/src/id.rs rename to ql-engine/src/wire/id.rs index d4398df5..1c32f62c 100644 --- a/ql2/src/id.rs +++ b/ql-engine/src/wire/id.rs @@ -1,6 +1,5 @@ use std::fmt; -use dcbor::CBOR; use rkyv::{Archive, Deserialize, Serialize}; macro_rules! define_id { @@ -18,6 +17,7 @@ macro_rules! define_id { PartialOrd, Ord, )] + #[repr(transparent)] pub struct $name(pub $ty); impl fmt::Display for $name { @@ -25,25 +25,11 @@ macro_rules! define_id { write!(f, "{}", self.0) } } - - impl From<$name> for CBOR { - fn from(value: $name) -> Self { - CBOR::from(value.0) - } - } - - impl TryFrom for $name { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - Ok(Self(<$ty>::try_from(value)?)) - } - } }; } define_id!(PacketId, u32); -define_id!(StreamId, u64); +define_id!(StreamId, u32); impl From<&ArchivedPacketId> for PacketId { fn from(value: &ArchivedPacketId) -> Self { diff --git a/ql-engine/src/wire/mod.rs b/ql-engine/src/wire/mod.rs new file mode 100644 index 00000000..a1b7f548 --- /dev/null +++ b/ql-engine/src/wire/mod.rs @@ -0,0 +1,522 @@ +//! quantum link protocol wire format +//! +//! naming conventions: +//! - *Record - unencrypted messages +//! - *Body - message content after decrypting +//! + +use bc_components::XID; +use rkyv::{ + api::{ + high::{to_bytes_in, HighSerializer, HighValidator}, + low::{self, LowDeserializer}, + }, + bytecheck::CheckBytes, + ser::allocator::ArenaHandle, + Archive, Deserialize, Portable, Serialize, +}; + +pub mod encrypted_message; +pub mod handshake; +pub mod heartbeat; +mod id; +pub mod pair; +pub mod seq; +pub mod stream; +pub mod unpair; + +pub use id::*; +pub use seq::StreamSeq; + +mod codec; + +pub(crate) use codec::*; + +use self::{ + encrypted_message::EncryptedMessage, handshake::HandshakeRecord, pair::PairRequestRecord, + unpair::UnpairRecord, +}; +use crate::QlError; + +pub(crate) type WireArchiveError = rkyv::rancor::Error; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct QlRecord { + pub header: QlHeader, + pub payload: QlPayload, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct QlHeader { + #[rkyv(with = AsWireXid)] + pub sender: XID, + #[rkyv(with = AsWireXid)] + pub recipient: XID, +} + +impl QlHeader { + pub fn aad(&self) -> Vec { + encode_value(self) + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct ControlMeta { + pub packet_id: PacketId, + pub valid_until: u64, +} + +impl From<&ArchivedControlMeta> for ControlMeta { + fn from(value: &ArchivedControlMeta) -> Self { + Self { + packet_id: (&value.packet_id).into(), + valid_until: value.valid_until.to_native(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum QlPayload { + Handshake(HandshakeRecord), + Pair(PairRequestRecord), + Unpair(UnpairRecord), + Heartbeat(EncryptedMessage), + Stream(EncryptedMessage), +} + +pub fn encode_record(record: &QlRecord) -> Vec { + encode_value(record) +} + +pub fn access_record(bytes: &[u8]) -> Result<&ArchivedQlRecord, QlError> { + access_value(bytes) +} + +pub fn decode_record(bytes: &[u8]) -> Result { + deserialize_value(access_record(bytes)?) +} + +pub(crate) fn encode_value( + value: &impl for<'a> Serialize, ArenaHandle<'a>, WireArchiveError>>, +) -> Vec { + to_bytes_in::<_, WireArchiveError>(value, Vec::new()) + .expect("wire serialization should not fail") +} + +pub(crate) fn access_value(bytes: &[u8]) -> Result<&T, QlError> +where + T: Portable + for<'a> CheckBytes>, +{ + rkyv::access::(bytes).map_err(|_| QlError::InvalidPayload) +} + +pub(crate) fn deserialize_value( + value: &impl rkyv::Deserialize>, +) -> Result { + low::deserialize::(value).map_err(|_| QlError::InvalidPayload) +} + +pub(crate) fn ensure_not_expired(valid_until: u64) -> Result<(), QlError> { + if now_secs() > valid_until { + Err(QlError::Timeout) + } else { + Ok(()) + } +} + +pub(crate) fn now_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} + +#[test] +fn ql_record_round_trip() { + let record = QlRecord { + header: QlHeader { + sender: XID::from_data([1; XID::XID_SIZE]), + recipient: XID::from_data([2; XID::XID_SIZE]), + }, + payload: QlPayload::Heartbeat(encrypted_message::EncryptedMessage::encrypt( + &bc_components::SymmetricKey::from_data( + [7; bc_components::SymmetricKey::SYMMETRIC_KEY_SIZE], + ), + vec![3u8, 4, 5], + b"roundtrip", + [8; encrypted_message::NONCE_SIZE], + )), + }; + + let bytes = encode_record(&record); + let decoded = decode_record(&bytes).unwrap(); + + assert_eq!(decoded, record); +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{engine::QlCrypto, identity::QlIdentity}; + + struct SizeTestCrypto(std::sync::atomic::AtomicU8); + + impl SizeTestCrypto { + fn new(seed: u8) -> Self { + Self(std::sync::atomic::AtomicU8::new(seed)) + } + } + + impl QlCrypto for SizeTestCrypto { + fn fill_random_bytes(&self, data: &mut [u8]) { + let seed = self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + for (index, byte) in data.iter_mut().enumerate() { + *byte = seed.wrapping_add(index as u8); + } + } + } + + fn size_test_identity() -> QlIdentity { + use bc_components::{MLDSA, MLKEM}; + + let (signing_private_key, signing_public_key) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private_key, encapsulation_public_key) = MLKEM::MLKEM512.keypair(); + QlIdentity::from_keys( + signing_private_key, + signing_public_key, + encapsulation_private_key, + encapsulation_public_key, + ) + } + + fn size_test_meta(packet_id: u32) -> ControlMeta { + ControlMeta { + packet_id: PacketId(packet_id), + valid_until: now_secs().saturating_add(60), + } + } + + /* + #[test] + fn protocol_record_size_breakdown() { + use crate::{ + wire::{handshake::HandshakeRecord, heartbeat::HeartbeatBody}, + StreamId, + }; + + let identity_a = size_test_identity(); + let identity_b = size_test_identity(); + let crypto_a = SizeTestCrypto::new(1); + let crypto_b = SizeTestCrypto::new(2); + + let initiator = identity_a.xid; + let responder = identity_b.xid; + + let (hello, initiator_secret) = handshake::build_hello( + &identity_a, + &crypto_a, + responder, + &identity_b.encapsulation_public_key, + size_test_meta(1), + ) + .unwrap(); + let hello_record = QlRecord { + header: QlHeader { + sender: initiator, + recipient: responder, + }, + payload: QlPayload::Handshake(HandshakeRecord::Hello(hello.clone())), + }; + let hello_size = encode_record(&hello_record).len(); + let hello_bytes = encode_value(&hello); + let hello_view = access_value::(&hello_bytes).unwrap(); + + let (hello_reply, responder_secrets) = handshake::respond_hello( + &identity_b, + &crypto_b, + initiator, + &identity_a.signing_public_key, + &identity_a.encapsulation_public_key, + hello_view, + size_test_meta(2), + ) + .unwrap(); + let reply_record = QlRecord { + header: QlHeader { + sender: responder, + recipient: initiator, + }, + payload: QlPayload::Handshake(HandshakeRecord::HelloReply(hello_reply.clone())), + }; + let reply_size = encode_record(&reply_record).len(); + let reply_bytes = encode_value(&hello_reply); + let reply_view = access_value::(&reply_bytes).unwrap(); + + let (confirm, session_key) = handshake::build_confirm( + &identity_a, + responder, + &identity_b.signing_public_key, + &hello, + reply_view, + &initiator_secret, + size_test_meta(3), + ) + .unwrap(); + let confirm_size = encode_record(&QlRecord { + header: QlHeader { + sender: initiator, + recipient: responder, + }, + payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm.clone())), + }) + .len(); + + let confirm_bytes = encode_value(&confirm); + let confirm_view = access_value::(&confirm_bytes).unwrap(); + let _session_key_b = handshake::finalize_confirm( + initiator, + responder, + &identity_a.signing_public_key, + &hello, + &hello_reply, + confirm_view, + &responder_secrets, + ) + .unwrap(); + + let pair_size = encode_record( + &pair::build_pair_request( + &identity_a, + &crypto_a, + responder, + &identity_b.encapsulation_public_key, + size_test_meta(11), + ) + .unwrap(), + ) + .len(); + + let heartbeat_size = encode_record(&heartbeat::encrypt_heartbeat( + QlHeader { + sender: initiator, + recipient: responder, + }, + &session_key, + HeartbeatBody { + meta: size_test_meta(12), + }, + [12; encrypted_message::NONCE_SIZE], + )) + .len(); + + let unpair_size = encode_record(&unpair::build_unpair_record( + &identity_a, + QlHeader { + sender: initiator, + recipient: responder, + }, + size_test_meta(13), + )) + .len(); + + let stream_header = QlHeader { + sender: initiator, + recipient: responder, + }; + let stream_record_size = |body: &stream::StreamBody, nonce: u8| { + encode_record(&stream::encrypt_stream( + stream_header.clone(), + &session_key, + body, + [nonce; encrypted_message::NONCE_SIZE], + )) + .len() + }; + + let stream_ack_body = stream::StreamBody::Ack(stream::StreamAckBody { + stream_id: StreamId(2), + ack: stream::StreamAck { + base: StreamSeq(19), + bitmap: 0b0000_0110, + }, + valid_until: now_secs().saturating_add(60), + }); + let stream_ack_record = stream::encrypt_stream( + stream_header.clone(), + &session_key, + &stream_ack_body, + [20; encrypted_message::NONCE_SIZE], + ); + let stream_ack_encrypted = match &stream_ack_record.payload { + QlPayload::Stream(encrypted) => encrypted, + _ => unreachable!(), + }; + let stream_ack_header_size = encode_value(&stream_header).len(); + let stream_ack_body_size = encode_value(&stream_ack_body).len(); + let stream_ack_envelope_size = encode_value(stream_ack_encrypted).len(); + let stream_ack_payload_size = encode_value(&stream_ack_record.payload).len(); + + let stream_open_body = stream::StreamBody::Message(stream::StreamMessage { + tx_seq: StreamSeq(21), + ack: stream::StreamAck::EMPTY, + valid_until: now_secs().saturating_add(60), + frame: stream::StreamFrame::Open(stream::StreamFrameOpen { + stream_id: StreamId(2), + request_head: vec![1, 2, 3], + request_prefix: Some(stream::BodyChunk { + bytes: vec![9, 9, 9], + fin: false, + }), + }), + }); + let stream_open_body_size = encode_value(&stream_open_body).len(); + + let stream_message_no_ack = stream::StreamBody::Message(stream::StreamMessage { + tx_seq: StreamSeq(20), + ack: stream::StreamAck::EMPTY, + valid_until: now_secs().saturating_add(60), + frame: stream::StreamFrame::Data(stream::StreamFrameData { + stream_id: StreamId(2), + dir: stream::Direction::Request, + chunk: stream::BodyChunk { + bytes: vec![7, 8, 9, 10], + fin: false, + }, + }), + }); + let stream_message_with_ack = stream::StreamBody::Message(stream::StreamMessage { + tx_seq: StreamSeq(20), + ack: stream::StreamAck { + base: StreamSeq(19), + bitmap: 0b0000_0110, + }, + valid_until: now_secs().saturating_add(60), + frame: stream::StreamFrame::Data(stream::StreamFrameData { + stream_id: StreamId(2), + dir: stream::Direction::Request, + chunk: stream::BodyChunk { + bytes: vec![7, 8, 9, 10], + fin: false, + }, + }), + }); + + let stream_ack_size = stream_record_size(&stream_ack_body, 20); + let stream_open_size = stream_record_size(&stream_open_body, 21); + let stream_accept_size = stream_record_size( + &stream::StreamBody::Message(stream::StreamMessage { + tx_seq: StreamSeq(22), + ack: stream::StreamAck::EMPTY, + valid_until: now_secs().saturating_add(60), + frame: stream::StreamFrame::Accept(stream::StreamFrameAccept { + stream_id: StreamId(2), + response_head: vec![4, 5, 6], + response_prefix: Some(stream::BodyChunk { + bytes: vec![1, 2], + fin: false, + }), + }), + }), + 22, + ); + let stream_reject_size = stream_record_size( + &stream::StreamBody::Message(stream::StreamMessage { + tx_seq: StreamSeq(23), + ack: stream::StreamAck::EMPTY, + valid_until: now_secs().saturating_add(60), + frame: stream::StreamFrame::Reject(stream::StreamFrameReject { + stream_id: StreamId(2), + code: stream::RejectCode::InvalidHead, + }), + }), + 23, + ); + let stream_data_no_ack_size = stream_record_size(&stream_message_no_ack, 24); + let stream_data_with_ack_size = stream_record_size(&stream_message_with_ack, 25); + let stream_fin_size = stream_record_size( + &stream::StreamBody::Message(stream::StreamMessage { + tx_seq: StreamSeq(26), + ack: stream::StreamAck::EMPTY, + valid_until: now_secs().saturating_add(60), + frame: stream::StreamFrame::Data(stream::StreamFrameData { + stream_id: StreamId(2), + dir: stream::Direction::Response, + chunk: stream::BodyChunk { + bytes: Vec::new(), + fin: true, + }, + }), + }), + 26, + ); + let stream_reset_size = stream_record_size( + &stream::StreamBody::Message(stream::StreamMessage { + tx_seq: StreamSeq(27), + ack: stream::StreamAck::EMPTY, + valid_until: now_secs().saturating_add(60), + frame: stream::StreamFrame::Reset(stream::StreamFrameReset { + stream_id: StreamId(2), + target: stream::ResetTarget::Both, + code: stream::ResetCode::Protocol, + }), + }), + 27, + ); + + let print_size = |label: &str, size: usize| { + println!("{label:<28}: {size} bytes"); + }; + + print_size("ql2 size hello", hello_size); + print_size("ql2 size hello_reply", reply_size); + print_size("ql2 size confirm", confirm_size); + print_size("ql2 size pair", pair_size); + print_size("ql2 size heartbeat", heartbeat_size); + print_size("ql2 size unpair", unpair_size); + print_size("ql2 size stream ack-only", stream_ack_size); + print_size("ql2 size stream open", stream_open_size); + print_size("ql2 size stream accept", stream_accept_size); + print_size("ql2 size stream reject", stream_reject_size); + print_size("ql2 size stream data no ack", stream_data_no_ack_size); + print_size("ql2 size stream data w/ ack", stream_data_with_ack_size); + print_size("ql2 size stream fin", stream_fin_size); + print_size("ql2 size stream reset", stream_reset_size); + println!( + "ql2 stream ack breakdown : header={} aad={} plaintext={} envelope={} payload={} full={}", + stream_ack_header_size, + stream_header.aad().len(), + stream_ack_body_size, + stream_ack_envelope_size, + stream_ack_payload_size, + stream_ack_size, + ); + println!( + "ql2 stream open delta : open_body={} ack_body={} (+{} bytes)", + stream_open_body_size, + stream_ack_body_size, + stream_open_body_size.saturating_sub(stream_ack_body_size), + ); + println!( + "ql2 stream data ack delta : no_ack={} with_ack={} (+{} bytes)", + stream_data_no_ack_size, + stream_data_with_ack_size, + stream_data_with_ack_size.saturating_sub(stream_data_no_ack_size), + ); + + assert!(hello_size > 0); + assert!(reply_size > 0); + assert!(confirm_size > 0); + assert!(pair_size > 0); + assert!(heartbeat_size > 0); + assert!(unpair_size > 0); + assert!(stream_ack_size > 0); + assert!(stream_open_size > 0); + assert!(stream_accept_size > 0); + assert!(stream_reject_size > 0); + assert!(stream_data_no_ack_size > 0); + assert!(stream_data_with_ack_size > 0); + assert!(stream_fin_size > 0); + assert!(stream_reset_size > 0); + } + */ +} diff --git a/ql2/src/wire/pair/crypto.rs b/ql-engine/src/wire/pair/crypto.rs similarity index 98% rename from ql2/src/wire/pair/crypto.rs rename to ql-engine/src/wire/pair/crypto.rs index 8e4748df..ab7e4033 100644 --- a/ql2/src/wire/pair/crypto.rs +++ b/ql-engine/src/wire/pair/crypto.rs @@ -5,7 +5,8 @@ use rkyv::{Archive, Serialize}; use super::{PairRequestBody, PairRequestRecord}; use crate::{ - platform::{QlCrypto, QlIdentity}, + engine::QlCrypto, + identity::QlIdentity, wire::{ access_value, deserialize_value, encode_value, encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage, NONCE_SIZE}, diff --git a/ql2/src/wire/pair/mod.rs b/ql-engine/src/wire/pair/mod.rs similarity index 100% rename from ql2/src/wire/pair/mod.rs rename to ql-engine/src/wire/pair/mod.rs diff --git a/ql2/src/wire/seq.rs b/ql-engine/src/wire/seq.rs similarity index 100% rename from ql2/src/wire/seq.rs rename to ql-engine/src/wire/seq.rs diff --git a/ql2/src/wire/stream/crypto.rs b/ql-engine/src/wire/stream/crypto.rs similarity index 100% rename from ql2/src/wire/stream/crypto.rs rename to ql-engine/src/wire/stream/crypto.rs diff --git a/ql2/src/wire/stream/mod.rs b/ql-engine/src/wire/stream/mod.rs similarity index 53% rename from ql2/src/wire/stream/mod.rs rename to ql-engine/src/wire/stream/mod.rs index 8a2ff7d7..b6bfeab5 100644 --- a/ql2/src/wire/stream/mod.rs +++ b/ql-engine/src/wire/stream/mod.rs @@ -47,7 +47,7 @@ impl From<&ArchivedStreamAckBody> for StreamAckBody { #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct StreamMessage { pub tx_seq: StreamSeq, - pub ack: Option, + pub ack: StreamAck, pub valid_until: u64, pub frame: StreamFrame, } @@ -67,23 +67,26 @@ impl From<&ArchivedStreamAck> for StreamAck { } } +impl StreamAck { + pub const EMPTY: Self = Self { + base: StreamSeq(0), + bitmap: 0, + }; +} + #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub enum StreamFrame { Open(StreamFrameOpen), - Accept(StreamFrameAccept), - Reject(StreamFrameReject), Data(StreamFrameData), - Reset(StreamFrameReset), + Close(StreamFrameClose), } impl StreamFrame { pub fn stream_id(&self) -> StreamId { match self { StreamFrame::Open(StreamFrameOpen { stream_id, .. }) - | StreamFrame::Accept(StreamFrameAccept { stream_id, .. }) - | StreamFrame::Reject(StreamFrameReject { stream_id, .. }) | StreamFrame::Data(StreamFrameData { stream_id, .. }) - | StreamFrame::Reset(StreamFrameReset { stream_id, .. }) => *stream_id, + | StreamFrame::Close(StreamFrameClose { stream_id, .. }) => *stream_id, } } } @@ -120,42 +123,9 @@ impl From<&ArchivedStreamFrameOpen> for StreamFrameOpen { } } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct StreamFrameAccept { - pub stream_id: StreamId, - pub response_head: Vec, - pub response_prefix: Option, -} - -impl From<&ArchivedStreamFrameAccept> for StreamFrameAccept { - fn from(value: &ArchivedStreamFrameAccept) -> Self { - Self { - stream_id: (&value.stream_id).into(), - response_head: value.response_head.as_slice().to_vec(), - response_prefix: value.response_prefix.as_ref().map(Into::into), - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamFrameReject { - pub stream_id: StreamId, - pub code: RejectCode, -} - -impl From<&ArchivedStreamFrameReject> for StreamFrameReject { - fn from(value: &ArchivedStreamFrameReject) -> Self { - Self { - stream_id: (&value.stream_id).into(), - code: (&value.code).into(), - } - } -} - #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct StreamFrameData { pub stream_id: StreamId, - pub dir: Direction, pub chunk: BodyChunk, } @@ -163,101 +133,67 @@ impl From<&ArchivedStreamFrameData> for StreamFrameData { fn from(value: &ArchivedStreamFrameData) -> Self { Self { stream_id: (&value.stream_id).into(), - dir: (&value.dir).into(), chunk: (&value.chunk).into(), } } } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamFrameReset { +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct StreamFrameClose { pub stream_id: StreamId, - pub target: ResetTarget, - pub code: ResetCode, + pub target: CloseTarget, + pub code: CloseCode, + pub payload: Vec, } -impl From<&ArchivedStreamFrameReset> for StreamFrameReset { - fn from(value: &ArchivedStreamFrameReset) -> Self { +impl From<&ArchivedStreamFrameClose> for StreamFrameClose { + fn from(value: &ArchivedStreamFrameClose) -> Self { Self { stream_id: (&value.stream_id).into(), target: (&value.target).into(), code: (&value.code).into(), + payload: value.payload.as_slice().to_vec(), } } } #[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] -pub enum Direction { - Request = 1, - Response = 2, -} - -impl From<&ArchivedDirection> for Direction { - fn from(value: &ArchivedDirection) -> Self { - match value { - ArchivedDirection::Request => Self::Request, - ArchivedDirection::Response => Self::Response, - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum ResetTarget { +pub enum CloseTarget { Request = 1, Response = 2, Both = 3, } -impl From<&ArchivedResetTarget> for ResetTarget { - fn from(value: &ArchivedResetTarget) -> Self { +impl From<&ArchivedCloseTarget> for CloseTarget { + fn from(value: &ArchivedCloseTarget) -> Self { match value { - ArchivedResetTarget::Request => Self::Request, - ArchivedResetTarget::Response => Self::Response, - ArchivedResetTarget::Both => Self::Both, + ArchivedCloseTarget::Request => Self::Request, + ArchivedCloseTarget::Response => Self::Response, + ArchivedCloseTarget::Both => Self::Both, } } } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum RejectCode { - Unknown = 0, - UnknownRoute = 1, - InvalidHead = 2, - Busy = 3, - Unhandled = 4, -} +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct CloseCode(pub u16); -impl From<&ArchivedRejectCode> for RejectCode { - fn from(value: &ArchivedRejectCode) -> Self { - match value { - ArchivedRejectCode::Unknown => Self::Unknown, - ArchivedRejectCode::UnknownRoute => Self::UnknownRoute, - ArchivedRejectCode::InvalidHead => Self::InvalidHead, - ArchivedRejectCode::Busy => Self::Busy, - ArchivedRejectCode::Unhandled => Self::Unhandled, - } - } -} +impl CloseCode { + pub const CANCELLED: Self = Self(0); + pub const PROTOCOL: Self = Self(1); + pub const INVALID_DATA: Self = Self(2); + pub const TIMEOUT: Self = Self(3); -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum ResetCode { - Cancelled = 0, - InvalidData = 1, - Protocol = 2, - Timeout = 3, + pub const UNKNOWN: Self = Self(16); + pub const UNKNOWN_ROUTE: Self = Self(17); + pub const INVALID_HEAD: Self = Self(18); + pub const BUSY: Self = Self(19); + pub const UNHANDLED: Self = Self(20); } -impl From<&ArchivedResetCode> for ResetCode { - fn from(value: &ArchivedResetCode) -> Self { - match value { - ArchivedResetCode::Cancelled => Self::Cancelled, - ArchivedResetCode::InvalidData => Self::InvalidData, - ArchivedResetCode::Protocol => Self::Protocol, - ArchivedResetCode::Timeout => Self::Timeout, - } +impl From<&ArchivedCloseCode> for CloseCode { + fn from(value: &ArchivedCloseCode) -> Self { + Self(value.0.to_native()) } } diff --git a/ql2/src/wire/unpair/crypto.rs b/ql-engine/src/wire/unpair/crypto.rs similarity index 98% rename from ql2/src/wire/unpair/crypto.rs rename to ql-engine/src/wire/unpair/crypto.rs index 3225d1a7..05df157d 100644 --- a/ql2/src/wire/unpair/crypto.rs +++ b/ql-engine/src/wire/unpair/crypto.rs @@ -3,7 +3,7 @@ use rkyv::{Archive, Serialize}; use super::UnpairRecord; use crate::{ - platform::QlIdentity, + identity::QlIdentity, wire::{encode_value, ensure_not_expired, ControlMeta, QlHeader, QlPayload, QlRecord}, QlError, }; diff --git a/ql2/src/wire/unpair/mod.rs b/ql-engine/src/wire/unpair/mod.rs similarity index 100% rename from ql2/src/wire/unpair/mod.rs rename to ql-engine/src/wire/unpair/mod.rs diff --git a/ql/Cargo.toml b/ql-runtime/Cargo.toml similarity index 75% rename from ql/Cargo.toml rename to ql-runtime/Cargo.toml index 315382eb..f7eda226 100644 --- a/ql/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -1,8 +1,8 @@ [package] -name = "ql" +name = "ql-runtime" version = "0.1.0" edition = "2021" -description = "Quantum Link handshake prototype" +description = "Quantum Link runtime" license = "Proprietary" [dependencies] @@ -10,10 +10,10 @@ async-channel = { version = "2.5" } bc-components = { version = "0.28.0", default-features = false, features = [ "pqcrypto", ] } -dcbor = { version = "0.23.3" } futures-lite = { version = "2.5" } oneshot = { version = "0.1.11" } -thiserror = { version = "2" } +piper = { version = "0.2.4" } +ql-engine = { path = "../ql-engine" } [dev-dependencies] tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs new file mode 100644 index 00000000..6c5a5844 --- /dev/null +++ b/ql-runtime/src/command.rs @@ -0,0 +1,30 @@ +use crate::{ + OpenedStreamDelivery, StreamConfig, + wire::stream::{CloseCode, CloseTarget}, + Peer, QlError, StreamId, +}; + +pub(crate) enum RuntimeCommand { + BindPeer { + peer: Peer, + }, + Pair, + Connect, + Unpair, + OpenStream { + request_head: Vec, + request_reader: piper::Reader, + start: oneshot::Sender>, + config: StreamConfig, + }, + PollStream { + stream_id: StreamId, + }, + CloseStream { + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + }, + Incoming(Vec), +} diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs new file mode 100644 index 00000000..4c7eb536 --- /dev/null +++ b/ql-runtime/src/driver.rs @@ -0,0 +1,564 @@ +use std::{ + collections::{HashMap, VecDeque}, + future::Future, + task::Poll, + time::Instant, +}; + +use futures_lite::future::poll_fn; + +use crate::{ + engine::{Engine, EngineInput, EngineOutput, WriteId}, + command::RuntimeCommand, + handle::{InboundByteStream, InboundStream, OutboundByteStream}, + platform::{PlatformFuture, QlPlatform}, + wire::stream::{BodyChunk, CloseCode, CloseTarget}, + HandlerEvent, InboundEvent, OpenedStreamDelivery, QlError, Runtime, StreamId, +}; + +struct InFlightWrite<'a> { + id: WriteId, + future: PlatformFuture<'a, Result<(), QlError>>, +} + +enum DriverEvent { + Command(RuntimeCommand), + WriteCompleted { + write_id: WriteId, + result: Result<(), QlError>, + }, + TimerExpired, + Closed, +} + +enum OutboundIo { + Open { + reader: piper::Reader, + finish_queued: bool, + }, + Closed, +} + +impl OutboundIo { + fn new(reader: piper::Reader) -> Self { + Self::Open { + reader, + finish_queued: false, + } + } + + fn close(&mut self) { + *self = Self::Closed; + } + + fn poll_pending(&mut self, stream_id: StreamId, pending_inputs: &mut VecDeque) { + let Self::Open { + reader, + finish_queued, + } = self + else { + return; + }; + + let available = reader.len(); + if available > 0 { + let mut bytes = vec![0; available]; + let read = reader.try_drain(&mut bytes); + if read > 0 { + bytes.truncate(read); + pending_inputs.push_back(EngineInput::OutboundData { stream_id, bytes }); + } + } + + if reader.is_closed() && !*finish_queued { + *finish_queued = true; + pending_inputs.push_back(EngineInput::OutboundFinished { stream_id }); + } + } +} + +enum InboundIo { + Open(async_channel::Sender), + Closed, +} + +impl InboundIo { + fn new(tx: async_channel::Sender) -> Self { + Self::Open(tx) + } + + fn write_or_close( + &mut self, + stream_id: StreamId, + target: CloseTarget, + bytes: Vec, + ) -> Option { + let Self::Open(tx) = self else { + return Some(EngineInput::CloseStream { + stream_id, + target, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }); + }; + if tx.try_send(InboundEvent::Data(bytes)).is_err() { + tx.close(); + *self = Self::Closed; + return Some(EngineInput::CloseStream { + stream_id, + target, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }); + } + None + } + + fn finish(&mut self) { + if let Self::Open(tx) = self { + let _ = tx.try_send(InboundEvent::Finished); + tx.close(); + } + *self = Self::Closed; + } + + fn fail(&mut self, error: QlError) { + if let Self::Open(tx) = self { + let _ = tx.try_send(InboundEvent::Failed(error)); + tx.close(); + } + *self = Self::Closed; + } + + fn close(&mut self) { + if let Self::Open(tx) = self { + let _ = tx.try_send(InboundEvent::Failed(QlError::Cancelled)); + tx.close(); + } + *self = Self::Closed; + } + + fn apply_prefix( + &mut self, + stream_id: StreamId, + target: CloseTarget, + prefix: &BodyChunk, + ) -> Option { + let mut input = None; + if !prefix.bytes.is_empty() { + input = self.write_or_close(stream_id, target, prefix.bytes.clone()); + } + if prefix.fin { + self.finish(); + } + input + } +} + +enum DriverStreamIo { + Initiator { + request: OutboundIo, + response: InboundIo, + }, + Responder { + request: InboundIo, + response: OutboundIo, + }, +} + +impl DriverStreamIo { + fn poll_pending(&mut self, stream_id: StreamId, pending_inputs: &mut VecDeque) { + match self { + Self::Initiator { request, .. } => request.poll_pending(stream_id, pending_inputs), + Self::Responder { response, .. } => response.poll_pending(stream_id, pending_inputs), + } + } + + fn outbound_mut(&mut self) -> &mut OutboundIo { + match self { + Self::Initiator { request, .. } => request, + Self::Responder { response, .. } => response, + } + } + + fn inbound_mut(&mut self) -> &mut InboundIo { + match self { + Self::Initiator { response, .. } => response, + Self::Responder { request, .. } => request, + } + } + + fn inbound_target(&self) -> CloseTarget { + match self { + Self::Initiator { .. } => CloseTarget::Response, + Self::Responder { .. } => CloseTarget::Request, + } + } + + fn close_all(&mut self) { + match self { + Self::Initiator { request, response } => { + request.close(); + response.close(); + } + Self::Responder { request, response } => { + request.close(); + response.close(); + } + } + } +} + +struct DriverState { + engine: Engine, + pending_inputs: VecDeque, + streams: HashMap, + runtime_tx: async_channel::Sender, + stream_send_buffer_bytes: usize, + max_concurrent_message_writes: usize, +} + +impl DriverState { + fn drive_command<'a, P: QlPlatform>( + &mut self, + command: RuntimeCommand, + platform: &'a P, + in_flight: &mut Vec>, + ) { + match command { + RuntimeCommand::BindPeer { peer } => { + self.drive_input(EngineInput::BindPeer(peer), platform, in_flight); + } + RuntimeCommand::Pair => { + self.drive_input(EngineInput::Pair, platform, in_flight); + } + RuntimeCommand::Connect => { + self.drive_input(EngineInput::Connect, platform, in_flight); + } + RuntimeCommand::Unpair => { + self.drive_input(EngineInput::Unpair, platform, in_flight); + } + RuntimeCommand::Incoming(bytes) => { + self.drive_input(EngineInput::Incoming(bytes), platform, in_flight); + } + RuntimeCommand::OpenStream { + request_head, + request_reader, + start, + config, + } => { + match self + .engine + .open_stream(Instant::now(), request_head, None, config) + { + Ok(stream_id) => { + let (response_tx, response_rx) = async_channel::unbounded(); + self.streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::new(request_reader), + response: InboundIo::new(response_tx), + }, + ); + let _ = start.send(Ok(OpenedStreamDelivery { + stream_id, + response: response_rx, + })); + self.poll_stream(stream_id); + self.drive_pending(platform, in_flight); + } + Err(error) => { + let _ = start.send(Err(error)); + } + } + } + RuntimeCommand::PollStream { stream_id } => { + self.poll_stream(stream_id); + self.drive_pending(platform, in_flight); + } + RuntimeCommand::CloseStream { + stream_id, + target, + code, + payload, + } => { + self.drive_input( + EngineInput::CloseStream { + stream_id, + target, + code, + payload, + }, + platform, + in_flight, + ); + } + } + } + + fn drive_input<'a, P: QlPlatform>( + &mut self, + input: EngineInput, + platform: &'a P, + in_flight: &mut Vec>, + ) { + self.pending_inputs.push_back(input); + self.drive_pending(platform, in_flight); + } + + fn drive_write_completed<'a, P: QlPlatform>( + &mut self, + write_id: WriteId, + result: Result<(), QlError>, + platform: &'a P, + in_flight: &mut Vec>, + ) { + { + let runtime_tx = self.runtime_tx.clone(); + let stream_send_buffer_bytes = self.stream_send_buffer_bytes; + let pending_inputs = &mut self.pending_inputs; + let streams = &mut self.streams; + self.engine.complete_write(write_id, result, &mut |output| { + handle_output( + output, + platform, + &runtime_tx, + stream_send_buffer_bytes, + pending_inputs, + streams, + ) + }); + } + self.fill_write_slots(platform, in_flight); + self.drive_pending(platform, in_flight); + } + + fn drive_pending<'a, P: QlPlatform>( + &mut self, + platform: &'a P, + in_flight: &mut Vec>, + ) { + while let Some(input) = self.pending_inputs.pop_front() { + { + let runtime_tx = &self.runtime_tx; + let stream_send_buffer_bytes = self.stream_send_buffer_bytes; + let pending_inputs = &mut self.pending_inputs; + let streams = &mut self.streams; + self.engine + .run_tick(Instant::now(), input, platform, &mut |output| { + handle_output( + output, + platform, + runtime_tx, + stream_send_buffer_bytes, + pending_inputs, + streams, + ) + }); + } + self.fill_write_slots(platform, in_flight); + } + + self.fill_write_slots(platform, in_flight); + } + + fn fill_write_slots<'a, P: QlPlatform>( + &mut self, + platform: &'a P, + in_flight: &mut Vec>, + ) { + while in_flight.len() < self.max_concurrent_message_writes { + let Some(write) = self.engine.take_next_write(platform) else { + break; + }; + in_flight.push(InFlightWrite { + id: write.id, + future: platform.write_message(write.bytes), + }); + } + } + + fn poll_stream(&mut self, stream_id: StreamId) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.poll_pending(stream_id, &mut self.pending_inputs); + } +} + +fn handle_output( + output: EngineOutput, + platform: &P, + runtime_tx: &async_channel::Sender, + stream_send_buffer_bytes: usize, + pending_inputs: &mut VecDeque, + streams: &mut HashMap, +) { + match output { + EngineOutput::PeerStatusChanged { peer, session } => { + platform.handle_peer_status(peer, &session); + } + EngineOutput::PersistPeer(peer) => platform.persist_peer(peer), + EngineOutput::ClearPeer => platform.clear_peer(), + EngineOutput::InboundStreamOpened { + stream_id, + request_head, + request_prefix, + } => { + let (request_tx, request_rx) = async_channel::unbounded(); + let mut request = InboundIo::new(request_tx); + if let Some(prefix) = request_prefix.as_ref() { + if let Some(input) = request.apply_prefix(stream_id, CloseTarget::Request, prefix) { + pending_inputs.push_back(input); + } + } + + let (response_reader, response_writer) = piper::pipe(stream_send_buffer_bytes); + streams.insert( + stream_id, + DriverStreamIo::Responder { + request, + response: OutboundIo::new(response_reader), + }, + ); + + platform.handle_inbound(HandlerEvent::Stream(InboundStream { + stream_id, + request_head, + request: InboundByteStream::new( + stream_id, + CloseTarget::Request, + request_rx, + runtime_tx.clone(), + ), + response: OutboundByteStream::new( + stream_id, + CloseTarget::Response, + response_writer, + runtime_tx.clone(), + ), + })); + } + EngineOutput::InboundData { stream_id, bytes } => { + let Some(stream) = streams.get_mut(&stream_id) else { + return; + }; + let target = stream.inbound_target(); + let inbound = stream.inbound_mut(); + if let Some(input) = inbound.write_or_close(stream_id, target, bytes) { + pending_inputs.push_back(input); + } + } + EngineOutput::InboundFinished { stream_id } => { + let Some(stream) = streams.get_mut(&stream_id) else { + return; + }; + stream.inbound_mut().finish(); + } + EngineOutput::InboundFailed { stream_id, error } => { + let Some(stream) = streams.get_mut(&stream_id) else { + return; + }; + stream.inbound_mut().fail(error); + } + EngineOutput::OutboundClosed { stream_id } + | EngineOutput::OutboundFailed { stream_id, .. } => { + let Some(stream) = streams.get_mut(&stream_id) else { + return; + }; + stream.outbound_mut().close(); + } + EngineOutput::StreamReaped { stream_id } => { + if let Some(mut stream) = streams.remove(&stream_id) { + stream.close_all(); + } + } + } +} + +async fn next_driver_event( + rx: &async_channel::Receiver, + platform: &P, + next_timer: Option, + in_flight: &mut Vec>, +) -> DriverEvent { + let recv_future = rx.recv(); + futures_lite::pin!(recv_future); + + let mut sleep_future = next_timer.map(|deadline| { + let timeout = deadline.saturating_duration_since(Instant::now()); + platform.sleep(timeout) + }); + + poll_fn(|cx| { + for write in in_flight.iter_mut() { + if let Poll::Ready(result) = write.future.as_mut().poll(cx) { + return Poll::Ready(DriverEvent::WriteCompleted { + write_id: write.id, + result, + }); + } + } + + if let Some(future) = sleep_future.as_mut() { + if let Poll::Ready(()) = future.as_mut().poll(cx) { + return Poll::Ready(DriverEvent::TimerExpired); + } + } + + recv_future.as_mut().poll(cx).map(|res| match res { + Ok(command) => DriverEvent::Command(command), + Err(_) => DriverEvent::Closed, + }) + }) + .await +} + +impl Runtime

{ + pub async fn run(self) { + let Runtime { + identity, + platform, + config, + rx, + tx, + } = self; + let peer = platform.load_peer().await; + let runtime_tx = tx.upgrade().expect("runtime tx"); + let mut state = DriverState { + engine: Engine::new(config.engine, identity, peer), + pending_inputs: VecDeque::new(), + streams: HashMap::new(), + runtime_tx, + stream_send_buffer_bytes: config.stream_send_buffer_bytes, + max_concurrent_message_writes: config.max_concurrent_message_writes, + }; + let mut in_flight = Vec::new(); + + loop { + state.drive_pending(&platform, &mut in_flight); + + if rx.is_closed() && state.pending_inputs.is_empty() && in_flight.is_empty() { + break; + } + + match next_driver_event(&rx, &platform, state.engine.next_deadline(), &mut in_flight) + .await + { + DriverEvent::Command(command) => { + state.drive_command(command, &platform, &mut in_flight); + } + DriverEvent::WriteCompleted { write_id, result } => { + if let Some(index) = in_flight.iter().position(|write| write.id == write_id) { + in_flight.swap_remove(index); + } + state.drive_write_completed(write_id, result, &platform, &mut in_flight); + } + DriverEvent::TimerExpired => { + state.drive_input(EngineInput::TimerExpired, &platform, &mut in_flight); + } + DriverEvent::Closed => break, + } + } + } +} diff --git a/ql-runtime/src/handle.rs b/ql-runtime/src/handle.rs new file mode 100644 index 00000000..1a2618a9 --- /dev/null +++ b/ql-runtime/src/handle.rs @@ -0,0 +1,293 @@ +use async_channel::{Receiver, Sender}; +use futures_lite::future::poll_fn; + +use crate::{ + command::RuntimeCommand, InboundEvent, OpenedStreamDelivery, StreamConfig, + wire::stream::{CloseCode, CloseTarget}, + Peer, QlError, StreamId, +}; + +#[derive(Clone)] +pub struct RuntimeHandle { + pub(crate) tx: Sender, + pub(crate) stream_send_buffer_bytes: usize, +} + +pub struct DuplexStream { + pub stream_id: StreamId, + pub request: OutboundByteStream, + pub response: InboundByteStream, +} + +#[derive(Debug)] +pub struct InboundStream { + pub stream_id: StreamId, + pub request_head: Vec, + pub request: InboundByteStream, + pub response: OutboundByteStream, +} + +pub struct InboundByteStream { + stream_id: StreamId, + target: CloseTarget, + rx: Receiver, + tx: Sender, + finished: bool, +} + +impl std::fmt::Debug for InboundByteStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InboundByteStream") + .field("stream_id", &self.stream_id) + .field("target", &self.target) + .field("finished", &self.finished) + .finish_non_exhaustive() + } +} + +pub struct OutboundByteStream { + stream_id: StreamId, + target: CloseTarget, + writer: Option, + tx: Sender, +} + +impl std::fmt::Debug for OutboundByteStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OutboundByteStream") + .field("stream_id", &self.stream_id) + .field("target", &self.target) + .field("closed", &self.writer.is_none()) + .finish_non_exhaustive() + } +} + +impl InboundByteStream { + pub(crate) fn new( + stream_id: StreamId, + target: CloseTarget, + rx: Receiver, + tx: Sender, + ) -> Self { + Self { + stream_id, + target, + rx, + tx, + finished: false, + } + } + + pub async fn next_chunk(&mut self) -> Result>, QlError> { + if self.finished { + return Ok(None); + } + match self.rx.recv().await { + Ok(InboundEvent::Data(bytes)) => Ok(Some(bytes)), + Ok(InboundEvent::Finished) => { + self.finished = true; + Ok(None) + } + Ok(InboundEvent::Failed(error)) => { + self.finished = true; + Err(error) + } + Err(_) => { + self.finished = true; + Err(QlError::Cancelled) + } + } + } + + pub async fn close(mut self, code: CloseCode, payload: Vec) -> Result<(), QlError> { + if self.finished { + return Ok(()); + } + self.finished = true; + self.tx + .send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code, + payload, + }) + .await + .map_err(|_| QlError::Cancelled) + } +} + +impl Drop for InboundByteStream { + fn drop(&mut self) { + if self.finished { + return; + } + let _ = self.tx.try_send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }); + } +} + +impl OutboundByteStream { + pub(crate) fn new( + stream_id: StreamId, + target: CloseTarget, + writer: piper::Writer, + tx: Sender, + ) -> Self { + Self { + stream_id, + target, + writer: Some(writer), + tx, + } + } + + fn poll_runtime(&self) -> Result<(), QlError> { + self.tx + .try_send(RuntimeCommand::PollStream { + stream_id: self.stream_id, + }) + .map_err(|_| QlError::Cancelled) + } + + pub async fn write(&mut self, bytes: &[u8]) -> Result { + if bytes.is_empty() { + return Ok(0); + } + self.poll_runtime()?; + let writer = self.writer.as_mut().expect("stream not finished or closed"); + let written = poll_fn(|cx| writer.poll_fill_bytes(cx, bytes)).await; + if written == 0 { + self.writer.take(); + return Err(QlError::Cancelled); + } + self.poll_runtime()?; + Ok(written) + } + + pub async fn write_all(&mut self, mut bytes: &[u8]) -> Result<(), QlError> { + while !bytes.is_empty() { + let written = self.write(bytes).await?; + if written == 0 { + return Err(QlError::Cancelled); + } + bytes = &bytes[written..]; + } + Ok(()) + } + + pub async fn finish(mut self) -> Result<(), QlError> { + if self.writer.take().is_none() { + return Ok(()); + } + self.poll_runtime() + } + + pub async fn close(mut self, code: CloseCode, payload: Vec) -> Result<(), QlError> { + if self.writer.take().is_none() { + return Ok(()); + } + self.tx + .send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code, + payload, + }) + .await + .map_err(|_| QlError::Cancelled) + } +} + +impl Drop for OutboundByteStream { + fn drop(&mut self) { + if self.writer.take().is_none() { + return; + } + let _ = self.tx.try_send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }); + } +} + +impl RuntimeHandle { + pub fn bind_peer(&self, peer: Peer) { + self.send(RuntimeCommand::BindPeer { peer }) + } + + pub fn pair(&self) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Pair) + .map_err(|_| QlError::Cancelled) + } + + pub fn connect(&self) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Connect) + .map_err(|_| QlError::Cancelled) + } + + pub fn unpair(&self) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Unpair) + .map_err(|_| QlError::Cancelled) + } + + pub fn send_incoming(&self, bytes: Vec) { + self.send(RuntimeCommand::Incoming(bytes)) + } + + pub async fn open_stream( + &self, + request_head: Vec, + config: StreamConfig, + ) -> Result { + let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); + let (start_tx, start_rx) = oneshot::channel(); + + self.tx + .send(RuntimeCommand::OpenStream { + request_head, + request_reader, + start: start_tx, + config, + }) + .await + .map_err(|_| QlError::Cancelled)?; + + let OpenedStreamDelivery { + stream_id, + response, + } = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; + + Ok(DuplexStream { + stream_id, + request: OutboundByteStream::new( + stream_id, + CloseTarget::Request, + request_writer, + self.tx.clone(), + ), + response: InboundByteStream::new( + stream_id, + CloseTarget::Response, + response, + self.tx.clone(), + ), + }) + } +} + +impl RuntimeHandle { + #[inline] + #[track_caller] + fn send(&self, cmd: RuntimeCommand) { + self.tx.send_blocking(cmd).expect("runtime is alive") + } +} diff --git a/ql2/src/runtime/mod.rs b/ql-runtime/src/lib.rs similarity index 52% rename from ql2/src/runtime/mod.rs rename to ql-runtime/src/lib.rs index 670cc8ae..14ac5f98 100644 --- a/ql2/src/runtime/mod.rs +++ b/ql-runtime/src/lib.rs @@ -1,34 +1,44 @@ pub use handle::{ - AcceptedStream, InboundByteStream, InboundStream, OutboundByteStream, PendingAccept, - PendingStream, RuntimeHandle, StreamResponder, + DuplexStream, InboundByteStream, InboundStream, OutboundByteStream, RuntimeHandle, }; +pub use ql_engine::{engine, identity, wire, PacketId, Peer, QlError, StreamId}; pub use crate::engine::{ - EngineConfig, InitiatorStage, KeepAliveConfig, PeerSession, StreamConfig, Token, + EngineConfig, HandshakeInitiator, KeepAliveConfig, PeerSession, StreamConfig, }; pub(crate) mod command; pub(crate) mod driver; pub mod handle; +pub mod platform; -use crate::{platform::QlPlatform, StreamId}; +#[cfg(test)] +mod tests; + +use self::platform::QlPlatform; +use crate::identity::QlIdentity; #[derive(Debug, Clone, Copy)] pub struct RuntimeConfig { pub engine: EngineConfig, + pub stream_send_buffer_bytes: usize, + pub max_concurrent_message_writes: usize, } impl Default for RuntimeConfig { fn default() -> Self { Self { engine: EngineConfig::default(), + stream_send_buffer_bytes: 64 * 1024, + max_concurrent_message_writes: 4, } } } impl RuntimeConfig { pub(crate) fn normalized(mut self) -> Self { - self.engine = self.engine.normalized(); + self.stream_send_buffer_bytes = self.stream_send_buffer_bytes.max(1); + self.max_concurrent_message_writes = self.max_concurrent_message_writes.max(1); self } } @@ -45,21 +55,24 @@ pub(crate) enum InboundEvent { Failed(crate::QlError), } -pub(crate) struct AcceptedStreamDelivery { +pub(crate) struct OpenedStreamDelivery { pub stream_id: StreamId, - pub response_head: Vec, pub response: async_channel::Receiver, - pub tx: async_channel::Sender, } pub struct Runtime

{ + identity: QlIdentity, platform: P, config: RuntimeConfig, rx: async_channel::Receiver, tx: async_channel::WeakSender, } -pub fn new_runtime

(platform: P, config: RuntimeConfig) -> (Runtime

, RuntimeHandle) +pub fn new_runtime

( + identity: QlIdentity, + platform: P, + config: RuntimeConfig, +) -> (Runtime

, RuntimeHandle) where P: QlPlatform, { @@ -67,11 +80,15 @@ where let (tx, rx) = async_channel::unbounded(); ( Runtime { + identity, platform, config, rx, tx: tx.downgrade(), }, - RuntimeHandle { tx }, + RuntimeHandle { + tx, + stream_send_buffer_bytes: config.stream_send_buffer_bytes, + }, ) } diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs new file mode 100644 index 00000000..8a6ce873 --- /dev/null +++ b/ql-runtime/src/platform.rs @@ -0,0 +1,22 @@ +use std::{future::Future, pin::Pin, time::Duration}; + +use bc_components::XID; + +use crate::{ + engine::{PeerSession, QlCrypto}, + Peer, QlError, +}; + +pub type PlatformFuture<'a, T> = Pin + 'a>>; + +pub trait QlPlatform: QlCrypto { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; + + fn load_peer(&self) -> PlatformFuture<'_, Option>; + fn persist_peer(&self, peer: Peer); + fn clear_peer(&self); + + fn handle_peer_status(&self, peer: XID, session: &PeerSession); + fn handle_inbound(&self, event: super::HandlerEvent); +} diff --git a/ql2/src/rpc/client.rs b/ql-runtime/src/rpc/client.rs similarity index 100% rename from ql2/src/rpc/client.rs rename to ql-runtime/src/rpc/client.rs diff --git a/ql2/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs similarity index 100% rename from ql2/src/rpc/mod.rs rename to ql-runtime/src/rpc/mod.rs diff --git a/ql2/src/rpc/modality.rs b/ql-runtime/src/rpc/modality.rs similarity index 100% rename from ql2/src/rpc/modality.rs rename to ql-runtime/src/rpc/modality.rs diff --git a/ql2/src/rpc/server.rs b/ql-runtime/src/rpc/server.rs similarity index 100% rename from ql2/src/rpc/server.rs rename to ql-runtime/src/rpc/server.rs diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs new file mode 100644 index 00000000..096f38e6 --- /dev/null +++ b/ql-runtime/src/tests/handshake.rs @@ -0,0 +1,125 @@ +use std::time::Duration; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn connect_round_trip_changes_peer_status() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_timeout_disconnects() { + run_local_test(async { + let config = RuntimeConfig { + engine: crate::engine::EngineConfig { + handshake_timeout: Duration::from_millis(60), + ..default_runtime_config().engine + }, + ..default_runtime_config() + }; + let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn confirm_write_failure_disconnects_initiator() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new_with_stream_write_failure(1, 1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let second = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut second_request = second.request; + let mut second_response = second.response; + assert_eq!(second_request.next_chunk().await.unwrap(), None); + second_response.write_all(b"ok").await.unwrap(); + second_response.finish().await.unwrap(); + }); + + let mut first = handle_a + .open_stream(Vec::new(), crate::StreamConfig::default()) + .await + .unwrap(); + let _ = first.request.finish().await; + let _ = first.response.next_chunk().await; + + assert_no_status_for( + &status_a, + identity_b.xid, + PeerStage::Disconnected, + Duration::from_millis(150), + ) + .await; + + let mut second = handle_a + .open_stream(Vec::new(), crate::StreamConfig::default()) + .await + .unwrap(); + second.request.finish().await.unwrap(); + assert_eq!(second.response.next_chunk().await.unwrap(), Some(b"ok".to_vec())); + assert_eq!(second.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(2), responder_task) + .await + .unwrap() + .unwrap(); + }) + .await; +} diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs new file mode 100644 index 00000000..57ff7e53 --- /dev/null +++ b/ql-runtime/src/tests/heartbeat.rs @@ -0,0 +1,217 @@ +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn keepalive_disabled_no_heartbeat() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let result = tokio::time::timeout(Duration::from_millis(120), heartbeat_rx.recv()).await; + assert!(result.is_err(), "unexpected heartbeat while disabled"); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_sent_after_idle() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(30), + timeout: Duration::from_millis(80), + }; + let config_a = RuntimeConfig { + engine: crate::engine::EngineConfig { + keep_alive: Some(keep_alive), + ..default_runtime_config().engine + }, + ..default_runtime_config() + }; + let config_b = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn stream_activity_prevents_keepalive_timeout() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(120), + timeout: Duration::from_millis(40), + }; + let config_a = RuntimeConfig { + engine: crate::engine::EngineConfig { + keep_alive: Some(keep_alive), + ..default_runtime_config().engine + }, + ..default_runtime_config() + }; + let config_b = default_runtime_config(); + let (platform_a, outbound_a, status_a, inbound_a) = TestPlatform::new_with_inbound(1); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); + spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); + spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) + .await + .unwrap() + .unwrap(); + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_a.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let response = stream.response; + response.finish().await.unwrap(); + }); + + let stream = handle_b.open_stream(Vec::new(), crate::StreamConfig::default()).await; + let mut stream = stream.unwrap(); + stream.request.finish().await.unwrap(); + assert_eq!(stream.response.next_chunk().await.unwrap(), None); + + let disconnect = tokio::time::timeout(keep_alive.timeout + Duration::from_millis(20), async { + loop { + if let Ok(event) = status_a.recv().await { + if event.peer == identity_b.xid && event.stage == PeerStage::Disconnected { + return; + } + } + } + }) + .await; + assert!(disconnect.is_err(), "unexpected disconnect"); + + let _ = responder_task.await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn heartbeat_timeout_disconnects_and_fails_pending_open() { + run_local_test(async { + let keep_alive = KeepAliveConfig { + interval: Duration::from_millis(80), + timeout: Duration::from_millis(60), + }; + let config_a = RuntimeConfig { + engine: crate::engine::EngineConfig { + keep_alive: Some(keep_alive), + ..default_runtime_config().engine + }, + ..default_runtime_config() + }; + let config_b = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(2); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(1); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let drop_flag = Arc::new(AtomicBool::new(false)); + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let response = stream.response; + response.finish().await.unwrap(); + }); + + drop_flag.store(true, Ordering::Relaxed); + + let mut pending = handle_a + .open_stream(Vec::new(), crate::StreamConfig::default()) + .await + .unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; + + let result = tokio::time::timeout(Duration::from_millis(300), pending.response.next_chunk()) + .await; + assert!(result.is_ok(), "pending stream never resolved after disconnect"); + + responder_task.abort(); + }) + .await; +} diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs new file mode 100644 index 00000000..e7c3e65f --- /dev/null +++ b/ql-runtime/src/tests/mod.rs @@ -0,0 +1,389 @@ +use std::{ + future::Future, + sync::{ + atomic::{AtomicU8, AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; + +use async_channel::{Receiver, Sender}; +use bc_components::{MLDSA, MLKEM}; +use tokio::task::LocalSet; + +use crate::{ + engine::QlCrypto, + identity::QlIdentity, + new_runtime, + platform::PlatformFuture, + wire::{self, QlPayload}, + HandlerEvent, KeepAliveConfig, Peer, PeerSession, QlError, RuntimeConfig, RuntimeHandle, +}; + +mod heartbeat; +mod handshake; +mod stream; +mod unpair; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PeerStage { + Disconnected, + Initiator, + Responder, + Connected, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct StatusEvent { + peer: bc_components::XID, + stage: PeerStage, +} + +#[derive(Debug, Clone)] +struct WriteStats { + active: Arc, + max_active: Arc, +} + +impl WriteStats { + fn new() -> Self { + Self { + active: Arc::new(AtomicUsize::new(0)), + max_active: Arc::new(AtomicUsize::new(0)), + } + } + + fn max_active(&self) -> usize { + self.max_active.load(Ordering::Relaxed) + } +} + +struct TestPlatform { + outbound: Sender>, + status: Sender, + inbound: Option>, + nonce_seed: u8, + nonce_counter: AtomicU8, + stream_write_counter: AtomicUsize, + fail_stream_write_at: Option, + write_delay: Duration, + write_stats: Option, +} + +impl TestPlatform { + fn new(seed: u8) -> (Self, Receiver>, Receiver) { + Self::new_inner(seed, None, None, Duration::ZERO, None) + } + + fn new_with_inbound( + seed: u8, + ) -> ( + Self, + Receiver>, + Receiver, + Receiver, + ) { + let (inbound_tx, inbound_rx) = async_channel::unbounded(); + let (platform, outbound_rx, status_rx) = + Self::new_inner(seed, Some(inbound_tx), None, Duration::ZERO, None); + (platform, outbound_rx, status_rx, inbound_rx) + } + + fn new_with_stream_write_failure( + seed: u8, + fail_stream_write_at: usize, + ) -> (Self, Receiver>, Receiver) { + Self::new_inner( + seed, + None, + Some(fail_stream_write_at), + Duration::ZERO, + None, + ) + } + + fn new_with_delayed_writes( + seed: u8, + delay: Duration, + write_stats: WriteStats, + ) -> (Self, Receiver>, Receiver) { + Self::new_inner(seed, None, None, delay, Some(write_stats)) + } + + fn new_inner( + seed: u8, + inbound: Option>, + fail_stream_write_at: Option, + write_delay: Duration, + write_stats: Option, + ) -> (Self, Receiver>, Receiver) { + let (outbound, outbound_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + ( + Self { + outbound, + status, + inbound, + nonce_seed: seed, + nonce_counter: AtomicU8::new(0), + stream_write_counter: AtomicUsize::new(0), + fail_stream_write_at, + write_delay, + write_stats, + }, + outbound_rx, + status_rx, + ) + } +} + +impl QlCrypto for TestPlatform { + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self + .nonce_seed + .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); + data.fill(value); + } +} + +impl crate::platform::QlPlatform for TestPlatform { + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + let outbound = self.outbound.clone(); + let write_delay = self.write_delay; + let fail_stream_write_at = self.fail_stream_write_at; + let write_stats = self.write_stats.clone(); + + Box::pin(async move { + if let Some(stats) = write_stats.as_ref() { + let active = stats.active.fetch_add(1, Ordering::Relaxed) + 1; + stats.max_active.fetch_max(active, Ordering::Relaxed); + } + + if !write_delay.is_zero() { + tokio::time::sleep(write_delay).await; + } + + let mut should_fail = false; + if is_stream_payload(&message) { + let count = self.stream_write_counter.fetch_add(1, Ordering::Relaxed) + 1; + should_fail = fail_stream_write_at == Some(count); + } + + let result = if should_fail { + Err(QlError::SendFailed) + } else { + outbound + .send(message) + .await + .map_err(|_| QlError::InvalidPayload) + }; + + if let Some(stats) = write_stats.as_ref() { + stats.active.fetch_sub(1, Ordering::Relaxed); + } + + result + }) + } + + fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(tokio::time::sleep(duration)) + } + + fn load_peer(&self) -> PlatformFuture<'_, Option> { + Box::pin(async { None }) + } + + fn persist_peer(&self, _peer: Peer) {} + + fn clear_peer(&self) {} + + fn handle_peer_status(&self, peer: bc_components::XID, session: &PeerSession) { + let stage = match session { + PeerSession::Disconnected => PeerStage::Disconnected, + PeerSession::Initiator { .. } => PeerStage::Initiator, + PeerSession::Responder { .. } => PeerStage::Responder, + PeerSession::Connected { .. } => PeerStage::Connected, + }; + let _ = self.status.try_send(StatusEvent { peer, stage }); + } + + fn handle_inbound(&self, event: HandlerEvent) { + if let Some(tx) = &self.inbound { + let _ = tx.try_send(event); + } + } +} + +fn is_stream_payload(bytes: &[u8]) -> bool { + wire::decode_record(bytes) + .ok() + .is_some_and(|record| matches!(record.payload, QlPayload::Stream(_))) +} + +fn new_identity() -> QlIdentity { + let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); + let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); + QlIdentity::from_keys( + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + ) +} + +fn peer_from_identity(identity: &QlIdentity) -> Peer { + Peer { + peer: identity.xid, + signing_key: identity.signing_public_key.clone(), + encapsulation_key: identity.encapsulation_public_key.clone(), + } +} + +fn register_peers( + handle_a: &RuntimeHandle, + handle_b: &RuntimeHandle, + id_a: &QlIdentity, + id_b: &QlIdentity, +) { + handle_a.bind_peer(peer_from_identity(id_b)); + handle_b.bind_peer(peer_from_identity(id_a)); +} + +fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + handle.send_incoming(bytes); + } + }); +} + +fn spawn_drop_every_nth_stream_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + nth: usize, +) { + tokio::task::spawn_local(async move { + let mut stream_count = 0usize; + while let Ok(bytes) = outbound.recv().await { + if nth > 0 && is_stream_payload(&bytes) { + stream_count = stream_count.saturating_add(1); + if stream_count % nth == 0 { + continue; + } + } + handle.send_incoming(bytes); + } + }); +} + +fn is_heartbeat(bytes: &[u8]) -> bool { + wire::decode_record(bytes) + .ok() + .is_some_and(|record| matches!(record.payload, QlPayload::Heartbeat(_))) +} + +fn spawn_heartbeat_tap_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + heartbeat_tx: Sender<()>, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if is_heartbeat(&bytes) { + let _ = heartbeat_tx.send(()).await; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if is_heartbeat(&bytes) { + continue; + } + handle.send_incoming(bytes); + } + }); +} + +fn spawn_gated_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + drop_flag: Arc, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if drop_flag.load(Ordering::Relaxed) { + continue; + } + handle.send_incoming(bytes); + } + }); +} + +async fn run_local_test(future: F) +where + F: Future, +{ + let local = LocalSet::new(); + local.run_until(future).await; +} + +async fn await_status( + receiver: &Receiver, + peer: bc_components::XID, + stage: PeerStage, +) { + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if let Ok(event) = receiver.recv().await { + if event.peer == peer && event.stage == stage { + return; + } + } + } + }) + .await + .unwrap(); +} + +async fn assert_no_status_for( + receiver: &Receiver, + peer: bc_components::XID, + stage: PeerStage, + window: Duration, +) { + let res = tokio::time::timeout(window, async { + loop { + let event = receiver.recv().await.unwrap(); + if event.peer == peer && event.stage == stage { + return; + } + } + }) + .await; + assert!(res.is_err(), "unexpected status event: {stage:?}"); +} + +async fn read_all(mut stream: crate::InboundByteStream) -> Result, QlError> { + let mut data = Vec::new(); + while let Some(chunk) = stream.next_chunk().await? { + data.extend_from_slice(&chunk); + } + Ok(data) +} + +fn default_runtime_config() -> RuntimeConfig { + RuntimeConfig { + engine: crate::engine::EngineConfig { + handshake_timeout: Duration::from_millis(300), + stream_ack_timeout: Duration::from_millis(30), + stream_retry_limit: 8, + ..Default::default() + }, + ..Default::default() + } +} diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs new file mode 100644 index 00000000..476999ef --- /dev/null +++ b/ql-runtime/src/tests/stream.rs @@ -0,0 +1,386 @@ +use std::time::Duration; + +use super::*; +use crate::{ + StreamConfig, + wire::stream::{CloseCode, CloseTarget}, +}; + +#[tokio::test(flavor = "current_thread")] +async fn open_stream_duplex_happy_path() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let inbound = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + assert_eq!(inbound.request_head, b"req-head".to_vec()); + + let mut request = inbound.request; + let mut response = inbound.response; + + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2])); + response.write_all(&[9]).await.unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), Some(vec![3, 4])); + response.write_all(&[8, 7]).await.unwrap(); + assert_eq!(request.next_chunk().await.unwrap(), None); + response.finish().await.unwrap(); + }); + + let mut stream = handle_a + .open_stream(b"req-head".to_vec(), StreamConfig::default()) + .await + .unwrap(); + stream.request.write_all(&[1, 2]).await.unwrap(); + assert_eq!(stream.response.next_chunk().await.unwrap(), Some(vec![9])); + stream.request.write_all(&[3, 4]).await.unwrap(); + stream.request.finish().await.unwrap(); + assert_eq!(stream.response.next_chunk().await.unwrap(), Some(vec![8, 7])); + assert_eq!(stream.response.next_chunk().await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn stream_backpressure_with_small_runtime_buffer() { + run_local_test(async { + let config = RuntimeConfig { + stream_send_buffer_bytes: 4, + ..default_runtime_config() + }; + let payload: Vec = (0..40).collect(); + + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + let (done_tx, done_rx) = async_channel::bounded(1); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let request_data = read_all(stream.request).await.unwrap(); + stream.response.finish().await.unwrap(); + done_tx.send(request_data).await.unwrap(); + }); + + let mut stream = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + stream.request.write_all(&payload).await.unwrap(); + stream.request.finish().await.unwrap(); + assert_eq!(stream.response.next_chunk().await.unwrap(), None); + + let received = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(received, payload); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_responder_rejects_as_unhandled() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + drop(stream.response); + }); + + let mut stream = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + stream.request.finish().await.unwrap(); + + let err = stream.response.next_chunk().await.unwrap_err(); + assert!(matches!( + err, + QlError::StreamClosed { + target: CloseTarget::Response, + code: CloseCode::CANCELLED, + payload, + } if payload.is_empty() + )); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_inbound_reader_cancels_remote_writer() { + run_local_test(async { + let config = RuntimeConfig { + stream_send_buffer_bytes: 4, + ..default_runtime_config() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + let (go_tx, go_rx) = async_channel::bounded(1); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let mut response = stream.response; + assert_eq!(request.next_chunk().await.unwrap(), None); + response.write_all(&[1, 2, 3, 4]).await.unwrap(); + go_rx.recv().await.unwrap(); + let err = response.write_all(&[5; 64]).await.unwrap_err(); + assert!(matches!(err, QlError::Cancelled)); + }); + + let mut stream = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + stream.request.finish().await.unwrap(); + assert_eq!(stream.response.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); + drop(stream.response); + go_tx.send(()).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn max_concurrent_message_writes_is_respected() { + run_local_test(async { + let stats = WriteStats::new(); + let config = RuntimeConfig { + max_concurrent_message_writes: 2, + ..default_runtime_config() + }; + let (platform_a, outbound_a, status_a) = + TestPlatform::new_with_delayed_writes(1, Duration::from_millis(40), stats.clone()); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + for _ in 0..4 { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let request = stream.request; + let response = stream.response; + let _ = read_all(request).await; + let _ = response.finish().await; + } + }); + + let mut tasks = Vec::new(); + for i in 0..4u8 { + let handle = handle_a.clone(); + tasks.push(tokio::task::spawn_local(async move { + let mut stream = handle + .open_stream(vec![i], StreamConfig::default()) + .await + .unwrap(); + stream.request.write_all(&[i; 8]).await.unwrap(); + stream.request.finish().await.unwrap(); + assert_eq!(stream.response.next_chunk().await.unwrap(), None); + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + } + + tokio::time::timeout(Duration::from_secs(4), responder) + .await + .unwrap() + .unwrap(); + + assert!( + stats.max_active() <= 2, + "max active writes exceeded: {}", + stats.max_active() + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn stream_round_trip_survives_packet_drops() { + run_local_test(async { + let config = RuntimeConfig { + engine: crate::engine::EngineConfig { + stream_retry_limit: 12, + stream_ack_timeout: Duration::from_millis(20), + ..default_runtime_config().engine + }, + stream_send_buffer_bytes: 4, + ..default_runtime_config() + }; + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let request_payload: Vec = (0..32).collect(); + let response_payload: Vec = (100..132).collect(); + let expected_response = response_payload.clone(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_drop_every_nth_stream_forwarder(outbound_a, handle_b.clone(), 3); + spawn_drop_every_nth_stream_forwarder(outbound_b, handle_a.clone(), 3); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let received_request = read_all(stream.request).await.unwrap(); + let mut response = stream.response; + response.write_all(&response_payload).await.unwrap(); + response.finish().await.unwrap(); + received_request + }); + + let mut stream = handle_a + .open_stream(Vec::new(), StreamConfig::default()) + .await + .unwrap(); + stream.request.write_all(&request_payload).await.unwrap(); + stream.request.finish().await.unwrap(); + + let mut received_response = Vec::new(); + while let Some(chunk) = stream.response.next_chunk().await.unwrap() { + received_response.extend_from_slice(&chunk); + } + assert_eq!(received_response, expected_response); + + let received_request = tokio::time::timeout(Duration::from_secs(4), responder) + .await + .unwrap() + .unwrap(); + assert_eq!(received_request, request_payload); + }) + .await; +} diff --git a/ql-runtime/src/tests/unpair.rs b/ql-runtime/src/tests/unpair.rs new file mode 100644 index 00000000..e73be578 --- /dev/null +++ b/ql-runtime/src/tests/unpair.rs @@ -0,0 +1,76 @@ +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn unpair_aborts_active_stream_and_clears_peer() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(); + let identity_b = new_identity(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let mut request = stream.request; + let _response = stream.response; + let first = request.next_chunk().await; + assert!(matches!(first, Ok(Some(_)) | Ok(None) | Err(_))); + let second = request.next_chunk().await; + assert!(matches!( + second, + Ok(None) + | Err(QlError::Cancelled) + | Err(QlError::SendFailed) + | Err(QlError::StreamClosed { .. }) + | Err(QlError::StreamProtocol) + )); + }); + + let mut stream = handle_a + .open_stream(Vec::new(), crate::StreamConfig::default()) + .await + .unwrap(); + stream.request.write_all(&[1, 2, 3, 4]).await.unwrap(); + + handle_a.unpair().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; + await_status(&status_b, identity_a.xid, PeerStage::Disconnected).await; + + let write_err = stream.request.write_all(&[5, 6, 7, 8]).await.unwrap_err(); + assert!(matches!(write_err, QlError::Cancelled)); + + let open_err_a = handle_a + .open_stream(Vec::new(), crate::StreamConfig::default()) + .await; + let open_err_b = handle_b + .open_stream(Vec::new(), crate::StreamConfig::default()) + .await; + + assert!(matches!(open_err_a, Err(QlError::NoPeerBound))); + assert!(matches!(open_err_b, Err(QlError::NoPeerBound))); + + tokio::time::timeout(std::time::Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} diff --git a/ql/README.md b/ql/README.md deleted file mode 100644 index d39e4e90..00000000 --- a/ql/README.md +++ /dev/null @@ -1,143 +0,0 @@ -# QL Protocol (v2) - -QL is a compact, session-oriented protocol for authenticated and encrypted messaging -between peers over arbitrary transports. It targets low-bandwidth and high-latency -links while preserving strong cryptography, explicit request/response semantics, and -a clean developer-facing API. - -This crate (`ql`) implements the protocol stack: wire format, crypto, runtime state -machine, and routing. For a deeper comparison with v1, see `ql-protocol-v2.md`. - -## features -- Fixed CBOR wire format: `QlRecord` = `[tag, header, payload]`. -- Mutual-auth handshake (`Hello`, `HelloReply`, `Confirm`) with signed transcript. -- Session keys derived from KEM secrets; payloads use AEAD (ChaCha20-Poly1305). -- Sessions are ephemeral and scoped to a handshake; no long-lived symmetric keys. -- First-contact pairing request with KEM-wrapped payloads and proof signature. -- Encrypted messages with explicit `Request`, `Response`, `Event`, and `Nack`. -- `MessageId`, `RouteId`, and `valid_until` for routing and freshness. -- Heartbeats for keepalive and disconnect detection. -- Runtime state machine for sessions, timeouts, outbound queues, and correlation. -- Router for typed dispatch and automatic response wiring. -- Transport abstraction via `QlPlatform` for BLE, TCP, or other links. - -## overview -QL provides a full session protocol rather than isolated message sealing. It covers: -- Mutual authentication and end-to-end encryption above the transport. -- First-contact pairing for provisioning keys and establishing trust. -- Typed routing with explicit request/response/event semantics. -- Runtime lifecycle management (handshake, keepalive, timeouts, errors). -- Portability across transports via a minimal platform abstraction. - -### security -- Mutual authentication via a signed handshake transcript. -- Session keys derived from KEM secrets; payloads are protected with AEAD - and header AAD. -- End-to-end protection above the transport layer; pairing supports first-contact - key exchange and proof of key possession. -- Message freshness enforced via `valid_until`; replay caching is not built-in, - so applications can optionally track `MessageId` if needed. - -### session vs per-message sealing -- v1 (gstp + envelope) signs every message and then encrypts it to the recipient. - each message uses fresh encapsulation, so keys and signatures are per-message. -- v2 (ql) signs the handshake transcript once, derives a session key, then uses - AEAD for each message with the header as AAD. -- encryption strength uses the same primitive (ChaCha20-Poly1305). post-quantum - security depends on key schemes (ML-KEM + ML-DSA with `pqcrypto` enabled). -- tradeoffs: v2 is faster and smaller; v1 has per-message signature and key - isolation. v2's AEAD provides in-session integrity but is not publicly - verifiable and has a larger blast radius if a session key leaks. - -### performance -- Public-key operations are paid once per session; steady-state traffic is - symmetric AEAD. -- Compact CBOR record framing keeps headers and serialization overhead small. -- Optional heartbeats provide liveness detection without heavy traffic. - -### developer experience -- Typed routes via `RequestResponse` and `Event` traits with explicit `RouteId`. -- Router handles decode, dispatch, and response wiring automatically. -- Runtime manages sessions, timeouts, outbound queues, and request correlation. -- `QlPlatform` abstracts the transport for portability and testability. - -## message sizes -Sizes below are CBOR record sizes from `protocol_record_size_breakdown` in -`ql/src/tests/mod.rs`. - -| Record | Size (bytes) | -| :-- | --: | -| Handshake Hello | 132 | -| Handshake HelloReply | 2563 | -| Handshake Confirm | 2510 | -| Pair request | 4065 | -| Message (empty payload) | 199 | -| Heartbeat | 196 | - -Handshake total is 5205 bytes (132 + 2563 + 2510). At 20 kBps transport -throughput, raw transmit time is about 0.26 s. - -## protocol overview - -### record framing -All traffic is encoded as a `QlRecord` with a small, fixed shape: -- `tag` selects the payload type (handshake, pair, record, heartbeat). -- `header` is unencrypted but authenticated data (AEAD AAD) used for routing - (sender and recipient XIDs). -- `payload` is a CBOR-encoded handshake/pair body or an encrypted message. - -### handshake -The handshake is a three-message exchange: -- `Hello`: initiator sends a nonce and KEM ciphertext. -- `HelloReply`: responder returns its nonce, KEM ciphertext, and a signature - over the transcript. -- `Confirm`: initiator signs the transcript to confirm mutual authentication. - -Both sides derive the session key from the KEM secrets and transcript digest. -After the handshake, all records use symmetric AEAD with the header as AAD. - -### pairing (first-contact) -Pairing is a standalone request that KEM-encrypts a payload containing: -- a `MessageId` and `valid_until` timestamp -- the sender's signing and encapsulation public keys -- a proof signature binding those keys - -This enables establishing trust without an existing session. - -### message records -Steady-state messages are sent as encrypted records with a typed body: -- `MessageKind`: `Request`, `Response`, `Event`, or `Nack` -- `MessageId`, `RouteId`, `valid_until`, and CBOR payload - -Nacks communicate standard failure reasons (unknown route, invalid payload, -expired) so peers can recover consistently. - -### heartbeats -Heartbeats are lightweight encrypted records used by the runtime to maintain -session liveness and detect disconnects. - -### routing and dispatch -`RouteId` maps to concrete request/response or event types. The router decodes -payloads, dispatches handlers, and ensures each request receives a response or -a `Nack`. - -### sequence diagram -```mermaid -sequenceDiagram - participant A as Initiator - participant B as Responder - A->>B: Hello (nonce, KEM ct) - B->>A: HelloReply (nonce, KEM ct, signature) - A->>B: Confirm (signature) - Note over A,B: Session key derived, AEAD enabled - A->>B: Encrypted Record (Request) - B->>A: Encrypted Record (Response) - A-->>B: Encrypted Heartbeat (optional) -``` - -## code map -- Wire format: `ql/src/wire/*` -- Cryptography: `ql/src/crypto/*` -- Runtime state machine: `ql/src/runtime/*` -- Routing and traits: `ql/src/router.rs`, `ql/src/lib.rs` -- Transport abstraction: `ql/src/platform.rs` diff --git a/ql/ql-v2.presenterm.md b/ql/ql-v2.presenterm.md deleted file mode 100644 index d4a0fff2..00000000 --- a/ql/ql-v2.presenterm.md +++ /dev/null @@ -1,285 +0,0 @@ ---- -theme: - name: gruvbox-dark ---- - -# quantumlink protocol v2 -post-quantum, session-based message protocol - - - -# ql v1: constraints -- no message id / sequence id -- no protocol-level request/response pairing -- each platform had to interpret + correlate by hand -- no ack/nack -- no notion of 'liveness'/'connected' status -- ~6.6KB min sealed event - - sender xid document (pq pubkeys) - - per-message signature - - recipient encryption (+ continuations) -- more a utility crate than a protocol - - - -# v1 vs v2 - - - - -## v1 -- gstp sealed envelope per message -- per-message sign+encrypt (envelope) -- implicit req/resp in enum variants -- app-owned pairing, timeouts, keepalive, connected status - - - -## v2 -- compact record + typed payloads -- handshake signatures + per‑message aead under symmetric session key -- explicit kind + ids + nack -- runtime handles pairing, timeouts, keepalive, connected status, request/response matching - - - -# design shift: per-message -> session -- v1 sealed each message -- v2 signs once, then aead per message - -```text -v1: seal(msg) = sign(msg) + encrypt(recipient) -v2: session_key = handshake() -v2: aead(msg, aad=header) -``` - - -_aead = authenticated encryption with associated data_ - -_aad = additional authenticated data (visible, integrity-protected)_ - - - -# configurable host platform -- same runtime across keyos / mobile / desktop -- host supplies pq keys, io, timers, callbacks - -```rust -pub trait QlPlatform { - // pq identity - fn signing_private_key(&self) -> &MLDSAPrivateKey; - fn signing_public_key(&self) -> &MLDSAPublicKey; - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey; - fn encapsulation_public_key(&self) -> &MLKEMPublicKey; - - // transport + runtime hooks - fn fill_bytes(&self, data: &mut [u8]); - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; - - // event handlers - fn handle_peer_status(&self, peer: XID, session: &PeerSession); - fn handle_inbound(&self, event: HandlerEvent); -} -``` - - - -# multi-peer runtime -- runtime tracks sessions per peer -- concurrent handshakes + keepalive per peer - -```rust -handle.register_peer(peer, signing_key, encapsulation_key); -handle.connect(peer)?; -``` - - - -# protocol breakdown -```mermaid +render +width:90% -sequenceDiagram - participant A as initiator - participant B as responder - - Note over A,B: pairing (first contact) - A->>B: pair request (kem + signed payload) - - Note over A,B: handshake (mutual auth) - A->>B: hello (nonce + kem ct) - B->>A: hello reply (nonce + kem ct + signature) - A->>B: confirm (signature) - - Note over A,B: session established - A->>B: request / event (aead + aad header) - B->>A: response / nack (aead + aad header) - A-->>B: heartbeat (optional) -``` - - - -# wire framing: routable header -- record = [tag, header, payload] -- header is unencrypted but authenticated (aad) - -```rust -pub struct QlRecord { - pub header: QlHeader, - pub payload: QlPayload, -} - -pub struct QlHeader { - pub sender: XID, - pub recipient: XID, -} -``` - - - -# handshake flow + records -- hello: nonce + mlkem ciphertext -- reply: nonce + mlkem ciphertext + mldsa signature -- confirm: mldsa signature, then session key - -```rust -pub struct Hello { - pub nonce: Nonce, - pub kem_ct: MLKEMCiphertext, -} - -pub struct HelloReply { - pub nonce: Nonce, - pub kem_ct: MLKEMCiphertext, - pub signature: MLDSASignature, -} - -pub struct Confirm { - pub signature: MLDSASignature, -} -``` - - - -# session key derivation -- transcript binds ids + nonces + kem ciphertexts -- session key = digest(initiator_secret, responder_secret, transcript) - -```rust -let transcript = cbor([ - initiator, responder, - hello.nonce, reply.nonce, - hello.kem_ct, reply.kem_ct, -]); -let payload = cbor([initiator_secret, responder_secret, transcript]); -let digest = Digest::from_image(payload); -let session_key = SymmetricKey::from_data(*digest.data()); -``` - - - -# message modalities -- request / response -- event: fire-and-forget or acked -- nack for structured failure - -```rust -pub enum MessageKind { - Request, - Response, - Event, - Nack, -} -``` - - - -# message body: routing + expiry -- message_id + route_id -- valid_until for freshness - -```rust -pub struct MessageBody { - pub message_id: MessageId, - pub valid_until: u64, - pub kind: MessageKind, - pub route_id: RouteId, - pub payload: CBOR, -} -``` - - - -# nack reasons -- unknown route / invalid payload / expired - -```rust -pub enum Nack { - Unknown, - UnknownRoute, - InvalidPayload, - Expired, -} -``` - - - -# type-safe routing -- route id is const per type -- compiler couples request -> response - -```rust -pub trait RequestResponse: QlCodec { - const ID: RouteId; - type Response: QlCodec; -} - -pub trait Event: QlCodec { - const ID: RouteId; -} -``` - - - -# router wiring -- builder ties route ids to handlers -- unknown routes auto-nack - -```rust -let router = Router::builder() - .add_request_handler::() - .add_event_handler::() - .build(state); -``` - - - -# runtime api flow -- request returns response or nack -- events are fire-and-forget (or acked) - -```rust -let reply = handle.request(msg, peer, RequestConfig::default()).await?; -handle.send_event(status, peer); -``` - - - -# performance snapshot (cbor sizes) -| proto | message | bytes | notes | -| :-- | :-- | --: | :-- | -| v1 | sealed msg (exchange_rate) | 6645 | sign+encrypt | -| v1 | sealed heartbeat | 6633 | sign+encrypt | -| v2 | hello | 132 | kem+nonce | -| v2 | hello reply | 2563 | sig+kem | -| v2 | confirm | 2510 | sig | -| v2 | pair request | 4065 | sig+kem | -| v2 | message (empty) | 199 | steady-state | -| v2 | heartbeat | 196 | steady-state | - -handshake total: 5205 bytes - - - -# close -- smaller packets, clearer flow, typed api -- ql v2 is the protocol, not just a crate diff --git a/ql/src/id.rs b/ql/src/id.rs deleted file mode 100644 index bc90db15..00000000 --- a/ql/src/id.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::fmt; - -use dcbor::CBOR; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct MessageId(pub u64); - -impl fmt::Display for MessageId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for CBOR { - fn from(value: MessageId) -> Self { - CBOR::from(value.0) - } -} - -impl TryFrom for MessageId { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let value: u64 = value.try_into()?; - Ok(Self(value)) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct RouteId(pub u64); - -impl fmt::Display for RouteId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for CBOR { - fn from(value: RouteId) -> Self { - CBOR::from(value.0) - } -} - -impl TryFrom for RouteId { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let value: u64 = value.try_into()?; - Ok(Self(value)) - } -} diff --git a/ql/src/lib.rs b/ql/src/lib.rs deleted file mode 100644 index 6fc69d21..00000000 --- a/ql/src/lib.rs +++ /dev/null @@ -1,68 +0,0 @@ -mod id; -pub mod platform; -pub mod router; -pub mod runtime; -pub mod wire; - -pub use id::*; - -#[cfg(test)] -mod tests; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Peer { - pub peer: bc_components::XID, - pub signing_key: bc_components::MLDSAPublicKey, - pub encapsulation_key: bc_components::MLKEMPublicKey, -} - -pub trait QlCodec: Into + TryFrom {} -impl QlCodec for T where T: Into + TryFrom {} - -pub trait RequestResponse: QlCodec { - const ID: RouteId; - type Response: QlCodec; -} - -pub trait Event: QlCodec { - const ID: RouteId; -} - -pub trait QlStream: QlCodec { - const ID: RouteId; - type StreamMeta: QlCodec; -} - -pub trait QlUpload: QlCodec { - const ID: RouteId; - type Response: QlCodec; -} - -#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] -pub enum QlError { - #[error("invalid payload")] - InvalidPayload, - #[error("invalid handshake role")] - InvalidRole, - #[error("invalid signature")] - InvalidSignature, - #[error("missing session for {0}")] - MissingSession(bc_components::XID), - #[error("unknown peer {0}")] - UnknownPeer(bc_components::XID), - #[error("timeout")] - Timeout, - #[error("send failed")] - SendFailed, - #[error("nack {nack:?}")] - Nack { - id: MessageId, - nack: wire::message::Nack, - }, - #[error("cancelled")] - Cancelled, - #[error("transfer cancelled")] - TransferCancelled { id: MessageId }, - #[error("transfer protocol error")] - TransferProtocol { id: MessageId }, -} diff --git a/ql/src/platform.rs b/ql/src/platform.rs deleted file mode 100644 index be13e32f..00000000 --- a/ql/src/platform.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::{future::Future, pin::Pin, time::Duration}; - -use bc_components::{ - MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, SigningPublicKey, XID, -}; - -use crate::{ - runtime::{HandlerEvent, PeerSession}, - Peer, QlError, -}; - -pub type PlatformFuture<'a, T> = Pin + 'a>>; - -pub trait QlPlatform { - fn signing_private_key(&self) -> &MLDSAPrivateKey; - fn signing_public_key(&self) -> &MLDSAPublicKey; - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey; - fn encapsulation_public_key(&self) -> &MLKEMPublicKey; - - fn fill_random_bytes(&self, data: &mut [u8]); - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; - - fn load_peers(&self) -> PlatformFuture<'_, Vec>; - fn persist_peers(&self, peers: Vec); - - fn handle_peer_status(&self, peer: XID, session: &PeerSession); - fn handle_inbound(&self, event: HandlerEvent); -} - -pub(crate) trait QlPlatformExt: QlPlatform { - fn xid(&self) -> XID { - XID::new(SigningPublicKey::MLDSA(self.signing_public_key().clone())) - } -} - -impl QlPlatformExt for T {} diff --git a/ql/src/router.rs b/ql/src/router.rs deleted file mode 100644 index 1f7c74a4..00000000 --- a/ql/src/router.rs +++ /dev/null @@ -1,377 +0,0 @@ -use std::collections::HashMap; - -use thiserror::Error; - -use crate::{ - runtime::{HandlerEvent, InboundByteStream, OutboundTransfer, Responder}, - wire::message::{Ack, Nack}, - Event, QlCodec, QlError, QlStream, QlUpload, RequestResponse, RouteId, -}; - -pub trait RequestHandler -where - M: RequestResponse, -{ - fn handle(&mut self, request: QlRequest); - fn default_response() -> M::Response; -} - -pub trait StreamRequestHandler -where - M: QlStream, -{ - fn handle(&mut self, request: QlStreamRequest); -} - -pub trait UploadRequestHandler -where - M: QlUpload, -{ - fn handle(&mut self, request: QlUploadRequest); - fn default_response() -> M::Response; -} - -pub trait EventHandler -where - M: Event, -{ - fn handle(&mut self, event: M); -} - -pub struct QlRequest -where - M: RequestResponse, -{ - pub message: M, - pub responder: QlResponder, -} - -pub struct QlStreamRequest -where - M: QlStream, -{ - pub message: M, - pub responder: QlStreamResponder, -} - -pub struct QlUploadRequest -where - M: QlUpload, -{ - pub message: M, - pub body: InboundByteStream, - pub responder: QlResponder, -} - -pub struct QlResponder -where - R: QlCodec, -{ - responder: Option, - default: fn() -> R, -} - -pub struct QlStreamResponder -where - M: QlCodec, -{ - responder: Option, - _meta: std::marker::PhantomData M>, -} - -impl QlResponder -where - R: QlCodec, -{ - pub fn respond(mut self, response: R) -> Result<(), QlError> { - self.respond_inner(response) - } - - pub fn respond_nack(mut self, reason: Nack) -> Result<(), QlError> { - let responder = self.responder.take().unwrap(); - responder.respond_nack(reason) - } - - fn respond_inner(&mut self, response: R) -> Result<(), QlError> { - let responder = self.responder.take().unwrap(); - responder.respond(response) - } -} - -impl QlStreamResponder -where - M: QlCodec, -{ - pub fn respond_stream(mut self, meta: M) -> Result { - let responder = self.responder.take().unwrap(); - responder.respond_stream(meta) - } - - pub fn respond_nack(mut self, reason: Nack) -> Result<(), QlError> { - let responder = self.responder.take().unwrap(); - responder.respond_nack(reason) - } -} - -impl Drop for QlResponder -where - R: QlCodec, -{ - fn drop(&mut self) { - if self.responder.is_some() { - let default = (self.default)(); - let _ = self.respond_inner(default); - } - } -} - -impl Drop for QlStreamResponder -where - M: QlCodec, -{ - fn drop(&mut self) { - if let Some(responder) = self.responder.take() { - let _ = responder.respond_nack(Nack::Unknown); - } - } -} - -#[derive(Debug, Error)] -pub enum RouterError { - #[error(transparent)] - Decode(#[from] dcbor::Error), - #[error("missing handler {0}")] - MissingHandler(RouteId), - #[error(transparent)] - Runtime(#[from] QlError), -} - -type RouterHandler = fn(&mut S, HandlerEvent) -> Result<(), RouterError>; - -pub struct RouterBuilder { - handlers: HashMap>, -} - -impl Default for RouterBuilder { - fn default() -> Self { - Self::new() - } -} - -impl RouterBuilder { - pub fn new() -> Self { - Self { - handlers: HashMap::new(), - } - } - - pub fn add_request_handler(self) -> Self - where - M: RequestResponse, - S: RequestHandler, - { - self.add_handler(M::ID, handle_request::) - } - - pub fn add_stream_request_handler(self) -> Self - where - M: QlStream, - S: StreamRequestHandler, - { - self.add_handler(M::ID, handle_stream_request::) - } - - pub fn add_upload_request_handler(self) -> Self - where - M: QlUpload, - S: UploadRequestHandler, - { - self.add_handler(M::ID, handle_upload_request::) - } - - pub fn add_event_handler(self) -> Self - where - M: Event, - S: EventHandler, - { - self.add_handler(M::ID, handle_event::) - } - - pub fn build(mut self, state: S) -> Router { - self.handlers.shrink_to_fit(); - Router { - handlers: self.handlers, - state, - } - } - - fn add_handler(mut self, id: RouteId, handler: RouterHandler) -> Self { - if self.handlers.insert(id, handler).is_some() { - panic!("duplicate route_id {id}"); - } - self - } -} - -pub struct Router { - state: S, - handlers: HashMap>, -} - -impl Router { - pub fn builder() -> RouterBuilder { - RouterBuilder::new() - } - - pub fn handle(&mut self, event: HandlerEvent) -> Result<(), RouterError> { - match event { - HandlerEvent::Request(request) => { - let route_id = request.message.route_id; - let handler = match self.handlers.get(&route_id) { - Some(handler) => handler, - None => { - let _ = request.respond_to.respond_nack(Nack::UnknownRoute); - return Ok(()); - } - }; - handler(&mut self.state, HandlerEvent::Request(request)) - } - HandlerEvent::UploadRequest(request) => { - let route_id = request.route_id; - let handler = match self.handlers.get(&route_id) { - Some(handler) => handler, - None => { - let _ = request.respond_to.respond_nack(Nack::UnknownRoute); - return Ok(()); - } - }; - handler(&mut self.state, HandlerEvent::UploadRequest(request)) - } - HandlerEvent::Event(event) => { - let route_id = event.message.route_id; - let handler = self - .handlers - .get(&route_id) - .ok_or(RouterError::MissingHandler(route_id))?; - handler(&mut self.state, HandlerEvent::Event(event)) - } - } - } -} - -fn handle_request(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> -where - M: RequestResponse, - S: RequestHandler, -{ - let (payload, responder) = match event { - HandlerEvent::Request(request) => (request.message.payload, request.respond_to), - HandlerEvent::UploadRequest(request) => { - let _ = request.respond_to.respond_nack(Nack::InvalidPayload); - return Err(RouterError::Runtime(QlError::InvalidPayload)); - } - HandlerEvent::Event(_) => return Err(RouterError::Runtime(QlError::InvalidPayload)), - }; - let message = match M::try_from(payload) { - Ok(message) => message, - Err(error) => { - let _ = responder.respond_nack(Nack::InvalidPayload); - return Err(RouterError::Decode(error)); - } - }; - let responder = QlResponder { - responder: Some(responder), - default: S::default_response, - }; - state.handle(QlRequest { message, responder }); - Ok(()) -} - -fn handle_stream_request(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> -where - M: QlStream, - S: StreamRequestHandler, -{ - let (payload, responder) = match event { - HandlerEvent::Request(request) => (request.message.payload, request.respond_to), - HandlerEvent::UploadRequest(request) => { - let _ = request.respond_to.respond_nack(Nack::InvalidPayload); - return Err(RouterError::Runtime(QlError::InvalidPayload)); - } - HandlerEvent::Event(_) => return Err(RouterError::Runtime(QlError::InvalidPayload)), - }; - let message = match M::try_from(payload) { - Ok(message) => message, - Err(error) => { - let _ = responder.respond_nack(Nack::InvalidPayload); - return Err(RouterError::Decode(error)); - } - }; - let responder = QlStreamResponder { - responder: Some(responder), - _meta: std::marker::PhantomData, - }; - state.handle(QlStreamRequest { message, responder }); - Ok(()) -} - -fn handle_upload_request(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> -where - M: QlUpload, - S: UploadRequestHandler, -{ - let (meta, body, responder) = match event { - HandlerEvent::UploadRequest(request) => (request.meta, request.body, request.respond_to), - HandlerEvent::Request(request) => { - let _ = request.respond_to.respond_nack(Nack::InvalidPayload); - return Err(RouterError::Runtime(QlError::InvalidPayload)); - } - HandlerEvent::Event(_) => return Err(RouterError::Runtime(QlError::InvalidPayload)), - }; - let message = match M::try_from(meta) { - Ok(message) => message, - Err(error) => { - let _ = responder.respond_nack(Nack::InvalidPayload); - return Err(RouterError::Decode(error)); - } - }; - let responder = QlResponder { - responder: Some(responder), - default: S::default_response, - }; - state.handle(QlUploadRequest { - message, - body, - responder, - }); - Ok(()) -} - -fn handle_event(state: &mut S, event: HandlerEvent) -> Result<(), RouterError> -where - M: Event, - S: EventHandler, -{ - let (payload, responder) = match event { - HandlerEvent::Event(event) => (event.message.payload, None), - HandlerEvent::Request(request) => (request.message.payload, Some(request.respond_to)), - HandlerEvent::UploadRequest(request) => { - let _ = request.respond_to.respond_nack(Nack::InvalidPayload); - return Err(RouterError::Runtime(QlError::InvalidPayload)); - } - }; - let message = match M::try_from(payload) { - Ok(message) => message, - Err(error) => { - if let Some(responder) = responder { - let _ = responder.respond_nack(Nack::InvalidPayload); - } - return Err(RouterError::Decode(error)); - } - }; - state.handle(message); - if let Some(responder) = responder { - responder.respond(Ack)?; - } - Ok(()) -} diff --git a/ql/src/runtime/core.rs b/ql/src/runtime/core.rs deleted file mode 100644 index 6406905f..00000000 --- a/ql/src/runtime/core.rs +++ /dev/null @@ -1,2201 +0,0 @@ -use std::{ - cmp::Reverse, collections::binary_heap::PeekMut, future::Future, task::Poll, time::Instant, -}; - -use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SigningPublicKey, XID}; -use dcbor::CBOR; -use futures_lite::future::poll_fn; - -use crate::{ - platform::{QlPlatform, QlPlatformExt}, - runtime::{ - internal::{ - next_timeout_deadline, now_secs, peer_hello_wins, HelloAction, InFlightWrite, - InboundStreamDelivery, InboundStreamItem, InboundTransferOpen, InboundTransferState, - KeepAliveState, LoopStep, OutboundAwaiting, OutboundMessage, OutboundPayload, - OutboundStreamInput, OutboundTransferStage, OutboundTransferState, PendingEntry, - PendingStreamEntry, RuntimeCommand, RuntimeState, TimeoutEntry, TimeoutKind, - }, - replay_cache::{ReplayKey, ReplayNamespace}, - HandlerEvent, InboundByteStream, InboundEvent, InboundRequest, InboundUploadRequest, - InitiatorStage, KeepAliveConfig, PeerSession, Responder, Runtime, Token, - }, - wire::{ - handshake::{self, HandshakeRecord}, - heartbeat::{self, HeartbeatBody}, - message::{self, MessageBody, MessageKind, Nack}, - pair::{self, PairRequestRecord}, - transfer::{self, TransferBody, TransferFrame}, - unpair::{self, UnpairRecord}, - QlHeader, QlPayload, QlRecord, - }, - MessageId, QlError, RouteId, -}; - -const TRANSFER_RETRY_LIMIT: u8 = 5; - -impl Runtime

{ - pub async fn run(self) { - let mut state = RuntimeState::new(); - for peer in self.platform.load_peers().await { - state - .peers - .upsert_peer(peer.peer, peer.signing_key, peer.encapsulation_key); - } - let mut in_flight: Option> = None; - while !self.rx.is_closed() { - self.drive_outbound_transfers(&mut state); - if in_flight.is_none() { - in_flight = self.start_next_write(&mut state); - } - let step = self.next_step(&state, in_flight.as_mut()).await; - match step { - LoopStep::Event(command) => match command { - RuntimeCommand::RegisterPeer { - peer, - signing_key, - encapsulation_key, - } => { - self.handle_register_peer(&mut state, peer, signing_key, encapsulation_key); - } - RuntimeCommand::Connect { peer } => { - self.handle_connect(&mut state, peer); - } - RuntimeCommand::Unpair { peer } => { - self.handle_send_unpair(&mut state, peer); - } - RuntimeCommand::SendRequest { - recipient, - route_id, - payload, - respond_to, - config, - } => { - self.handle_send_request( - &mut state, recipient, route_id, payload, respond_to, config, - ); - } - RuntimeCommand::SendStreamRequest { - recipient, - route_id, - payload, - respond_to, - config, - } => { - self.handle_send_stream_request( - &mut state, recipient, route_id, payload, respond_to, config, - ); - } - RuntimeCommand::SendUploadRequest { - recipient, - route_id, - payload, - respond_to, - chunk_rx, - start, - config, - } => { - self.handle_send_upload_request( - &mut state, recipient, route_id, payload, respond_to, chunk_rx, start, - config, - ); - } - RuntimeCommand::SendEvent { - recipient, - route_id, - payload, - } => { - self.handle_send_event(&mut state, recipient, route_id, payload); - } - RuntimeCommand::SendResponse { - id, - recipient, - payload, - kind, - } => { - self.handle_send_response(&mut state, id, recipient, payload, kind); - } - RuntimeCommand::StartResponseStream { - request_id, - recipient, - meta, - chunk_rx, - } => { - self.handle_start_response_stream( - &mut state, request_id, recipient, meta, chunk_rx, - ); - } - RuntimeCommand::PollOutboundTransfer { - recipient, - transfer_id, - } => { - self.drive_outbound_transfer(&mut state, recipient, transfer_id); - } - RuntimeCommand::CancelOutboundTransfer { - recipient, - transfer_id, - } => { - self.handle_cancel_outbound_transfer(&mut state, recipient, transfer_id); - } - RuntimeCommand::CancelInboundTransfer { - sender, - transfer_id, - } => { - self.handle_cancel_inbound_transfer(&mut state, sender, transfer_id); - } - RuntimeCommand::Incoming(bytes) => { - self.handle_incoming(&mut state, bytes); - } - }, - LoopStep::Timeout => { - self.handle_timeouts(&mut state); - } - LoopStep::WriteDone { - peer, - token, - message_id, - result, - } => { - in_flight = None; - self.handle_write_done(&mut state, peer, token, message_id, result); - } - LoopStep::Quit => break, - } - } - } - - fn start_next_write<'a>(&'a self, state: &mut RuntimeState) -> Option> { - while let Some(message) = state.outbound.pop_front() { - let bytes = match message.payload { - OutboundPayload::PreEncoded(bytes) => bytes, - OutboundPayload::DeferredMessage(body) => { - let Some(session_key) = state - .peers - .peer(message.peer) - .and_then(|entry| entry.session.session_key()) - else { - if let Some(id) = message.message_id { - if let Some(entry) = state.pending.remove(&id) { - let _ = entry.tx.send(Err(QlError::SendFailed)); - } - if let Some(entry) = state.pending_stream.remove(&id) { - let _ = entry.tx.send(Err(QlError::SendFailed)); - } - } - continue; - }; - let message = message::encrypt_message( - QlHeader { - sender: self.platform.xid(), - recipient: message.peer, - }, - session_key, - body, - ); - CBOR::from(message).to_cbor_data() - } - }; - return Some(InFlightWrite { - peer: message.peer, - token: message.token, - message_id: message.message_id, - future: self.platform.write_message(bytes), - }); - } - None - } - - async fn next_step<'a>( - &'a self, - state: &RuntimeState, - mut in_flight: Option<&mut InFlightWrite<'a>>, - ) -> LoopStep { - let recv_future = self.rx.recv(); - futures_lite::pin!(recv_future); - - let mut sleep_future = next_timeout_deadline(state).map(|deadline| { - let timeout = deadline.saturating_duration_since(Instant::now()); - self.platform.sleep(timeout) - }); - - poll_fn(|cx| { - if let Some(in_flight) = in_flight.as_mut() { - if let Poll::Ready(result) = in_flight.future.as_mut().poll(cx) { - return Poll::Ready(LoopStep::WriteDone { - peer: in_flight.peer, - token: in_flight.token, - message_id: in_flight.message_id, - result, - }); - } - } - - if let Some(future) = sleep_future.as_mut() { - if let Poll::Ready(()) = future.as_mut().poll(cx) { - return Poll::Ready(LoopStep::Timeout); - } - } - - recv_future.as_mut().poll(cx).map(|res| match res { - Ok(event) => LoopStep::Event(event), - Err(_) => LoopStep::Quit, - }) - }) - .await - } - - fn handle_connect(&self, state: &mut RuntimeState, peer: XID) { - let encapsulation_key = match state.peers.peer(peer) { - Some(entry) => match &entry.session { - PeerSession::Connected { .. } - | PeerSession::Initiator { .. } - | PeerSession::Responder { .. } => { - return; - } - PeerSession::Disconnected => entry.encapsulation_key.clone(), - }, - None => return, - }; - - let (hello, session_key) = match handshake::build_hello( - &self.platform, - self.platform.xid(), - peer, - &encapsulation_key, - ) { - Ok(result) => result, - Err(_) => return, - }; - - let deadline = Instant::now() + self.config.handshake_timeout; - let token = state.next_token(); - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Initiator { - handshake_token: token, - hello: hello.clone(), - session_key, - deadline, - stage: InitiatorStage::WaitingHelloReply, - }; - self.platform.handle_peer_status(peer, &entry.session); - } - - let message = QlRecord { - header: QlHeader { - sender: self.platform.xid(), - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), - }; - let bytes = CBOR::from(message).to_cbor_data(); - self.enqueue_handshake_message(state, peer, token, deadline, bytes); - } - - fn handle_register_peer( - &self, - state: &mut RuntimeState, - peer: XID, - signing_key: MLDSAPublicKey, - encapsulation_key: MLKEMPublicKey, - ) { - { - let entry = state - .peers - .upsert_peer(peer, signing_key, encapsulation_key); - if let PeerSession::Disconnected = entry.session { - self.platform.handle_peer_status(peer, &entry.session); - } - } - self.persist_peers(state); - } - - fn handle_send_request( - &self, - state: &mut RuntimeState, - recipient: XID, - route_id: RouteId, - payload: CBOR, - respond_to: oneshot::Sender>, - config: super::RequestConfig, - ) { - let id = state.next_message_id(); - let timeout = config - .timeout - .unwrap_or(self.config.default_request_timeout); - if timeout.is_zero() { - let _ = respond_to.send(Err(QlError::Timeout)); - return; - } - let Some(entry) = state.peers.peer(recipient) else { - let _ = respond_to.send(Err(QlError::UnknownPeer(recipient))); - return; - }; - if !entry.session.is_connected() { - let _ = respond_to.send(Err(QlError::MissingSession(recipient))); - return; - } - let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); - let body = MessageBody { - message_id: id, - valid_until, - kind: MessageKind::Request, - route_id, - payload, - }; - state.pending.insert( - id, - PendingEntry { - recipient, - tx: respond_to, - }, - ); - state.timeouts.push(Reverse(TimeoutEntry { - at: Instant::now() + timeout, - kind: TimeoutKind::Request { id }, - })); - let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound( - state, - recipient, - OutboundPayload::DeferredMessage(body), - outbound_deadline, - Some(id), - ); - } - - fn handle_send_stream_request( - &self, - state: &mut RuntimeState, - recipient: XID, - route_id: RouteId, - payload: CBOR, - respond_to: oneshot::Sender>, - config: super::RequestConfig, - ) { - let id = state.next_message_id(); - let timeout = config - .timeout - .unwrap_or(self.config.default_request_timeout); - if timeout.is_zero() { - let _ = respond_to.send(Err(QlError::Timeout)); - return; - } - let Some(entry) = state.peers.peer(recipient) else { - let _ = respond_to.send(Err(QlError::UnknownPeer(recipient))); - return; - }; - if !entry.session.is_connected() { - let _ = respond_to.send(Err(QlError::MissingSession(recipient))); - return; - } - let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); - let body = MessageBody { - message_id: id, - valid_until, - kind: MessageKind::Request, - route_id, - payload, - }; - state.pending_stream.insert( - id, - PendingStreamEntry { - recipient, - tx: respond_to, - }, - ); - state.timeouts.push(Reverse(TimeoutEntry { - at: Instant::now() + timeout, - kind: TimeoutKind::Request { id }, - })); - let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound( - state, - recipient, - OutboundPayload::DeferredMessage(body), - outbound_deadline, - Some(id), - ); - } - - fn handle_send_upload_request( - &self, - state: &mut RuntimeState, - recipient: XID, - route_id: RouteId, - payload: CBOR, - respond_to: oneshot::Sender>, - chunk_rx: async_channel::Receiver, - start: oneshot::Sender>, - config: super::RequestConfig, - ) { - let timeout = config - .timeout - .unwrap_or(self.config.default_request_timeout); - if timeout.is_zero() { - let _ = start.send(Err(QlError::Timeout)); - return; - } - let Some(entry) = state.peers.peer(recipient) else { - let _ = start.send(Err(QlError::UnknownPeer(recipient))); - return; - }; - if !entry.session.is_connected() { - let _ = start.send(Err(QlError::MissingSession(recipient))); - return; - } - - let request_id = state.next_message_id(); - state.pending.insert( - request_id, - PendingEntry { - recipient, - tx: respond_to, - }, - ); - state.timeouts.push(Reverse(TimeoutEntry { - at: Instant::now() + timeout, - kind: TimeoutKind::Request { id: request_id }, - })); - - let transfer_id = request_id; - let key = (recipient, transfer_id); - if state.outbound_transfers.contains_key(&key) { - let _ = state.pending.remove(&request_id); - let _ = start.send(Err(QlError::SendFailed)); - return; - } - - state.outbound_transfers.insert( - key, - OutboundTransferState { - request_id, - peer: recipient, - transfer_id, - stage: OutboundTransferStage::Opening, - next_seq: 1, - open_route_id: Some(route_id), - open_meta: Some(payload), - chunk_rx, - awaiting: None, - }, - ); - - let _ = start.send(Ok(request_id)); - } - - fn handle_send_event( - &self, - state: &mut RuntimeState, - recipient: XID, - route_id: RouteId, - payload: CBOR, - ) { - let id = state.next_message_id(); - let Some(entry) = state.peers.peer(recipient) else { - return; - }; - if !entry.session.is_connected() { - return; - } - let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); - let body = MessageBody { - message_id: id, - valid_until, - kind: MessageKind::Event, - route_id, - payload, - }; - let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound( - state, - recipient, - OutboundPayload::DeferredMessage(body), - outbound_deadline, - None, - ); - } - - fn handle_send_response( - &self, - state: &mut RuntimeState, - id: MessageId, - recipient: XID, - payload: CBOR, - kind: MessageKind, - ) { - let kind = match kind { - MessageKind::Response | MessageKind::Nack => kind, - _ => return, - }; - let Some(entry) = state.peers.peer(recipient) else { - return; - }; - if !entry.session.is_connected() { - return; - } - - let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); - let body = MessageBody { - message_id: id, - valid_until, - kind, - route_id: RouteId(0), - payload, - }; - let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound( - state, - recipient, - OutboundPayload::DeferredMessage(body), - outbound_deadline, - None, - ); - } - - fn handle_start_response_stream( - &self, - state: &mut RuntimeState, - request_id: MessageId, - recipient: XID, - meta: CBOR, - chunk_rx: async_channel::Receiver, - ) { - if !matches!( - state.peers.peer(recipient), - Some(entry) if entry.session.is_connected() - ) { - return; - } - - let transfer_id = request_id; - let key = (recipient, transfer_id); - if state.outbound_transfers.contains_key(&key) { - return; - } - - state.outbound_transfers.insert( - key, - OutboundTransferState { - request_id, - peer: recipient, - transfer_id, - stage: OutboundTransferStage::Opening, - next_seq: 1, - open_route_id: None, - open_meta: Some(meta), - chunk_rx, - awaiting: None, - }, - ); - } - - fn handle_cancel_outbound_transfer( - &self, - state: &mut RuntimeState, - recipient: XID, - transfer_id: MessageId, - ) { - let key = (recipient, transfer_id); - let mut found = false; - if let Some(transfer) = state.outbound_transfers.get_mut(&key) { - found = true; - transfer.stage = OutboundTransferStage::Cancelling; - transfer.awaiting = None; - transfer.chunk_rx.close(); - } - if found { - self.drive_outbound_transfer(state, recipient, transfer_id); - } - } - - fn handle_cancel_inbound_transfer( - &self, - state: &mut RuntimeState, - sender: XID, - transfer_id: MessageId, - ) { - if state - .inbound_transfers - .remove(&(sender, transfer_id)) - .is_some() - { - self.send_transfer_frame(state, sender, transfer_id, TransferFrame::Cancel, false); - } - } - - fn handle_send_unpair(&self, state: &mut RuntimeState, peer: XID) { - if state.peers.peer(peer).is_none() { - return; - } - let message = unpair::build_unpair_record( - &self.platform, - QlHeader { - sender: self.platform.xid(), - recipient: peer, - }, - state.next_message_id(), - now_secs().saturating_add(self.config.message_expiration.as_secs()), - ); - let bytes = CBOR::from(message).to_cbor_data(); - self.unpair_peer(state, peer); - let deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound( - state, - peer, - OutboundPayload::PreEncoded(bytes), - deadline, - None, - ); - } - - fn handle_incoming(&self, state: &mut RuntimeState, bytes: Vec) { - let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) else { - return; - }; - let QlRecord { header, payload } = record; - if header.recipient != self.platform.xid() { - return; - } - match payload { - QlPayload::Handshake(message) => { - self.handle_handshake(state, header, message); - } - QlPayload::Pair(request) => { - self.handle_pairing(state, header, request); - } - QlPayload::Unpair(unpair) => { - self.handle_unpair(state, header, unpair); - } - QlPayload::Message(encrypted) => { - self.handle_record(state, header, encrypted); - } - QlPayload::Heartbeat(encrypted) => { - self.handle_heartbeat(state, header, encrypted); - } - QlPayload::Transfer(encrypted) => { - self.handle_transfer(state, header, encrypted); - } - } - } - - fn handle_handshake( - &self, - state: &mut RuntimeState, - header: QlHeader, - message: HandshakeRecord, - ) { - match message { - HandshakeRecord::Hello(hello) => { - self.handle_hello(state, header, hello); - } - HandshakeRecord::HelloReply(reply) => { - self.handle_hello_reply(state, header, reply); - } - HandshakeRecord::Confirm(confirm) => { - self.handle_confirm(state, header, confirm); - } - } - } - - fn handle_pairing( - &self, - state: &mut RuntimeState, - header: QlHeader, - request: PairRequestRecord, - ) { - let payload = match pair::decrypt_pair_request(&self.platform, &header, request) { - Ok(payload) => payload, - Err(_) => return, - }; - let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); - state - .peers - .upsert_peer(peer, payload.signing_pub_key, payload.encapsulation_pub_key); - self.persist_peers(state); - self.handle_connect(state, peer); - } - - fn handle_unpair(&self, state: &mut RuntimeState, header: QlHeader, record: UnpairRecord) { - let peer = header.sender; - let Some(signing_key) = state - .peers - .peer(peer) - .map(|entry| entry.signing_key.clone()) - else { - return; - }; - if unpair::verify_unpair_record(&header, &record, &signing_key).is_err() { - return; - } - let replay_key = ReplayKey::new(peer, ReplayNamespace::Peer, record.message_id); - if state - .replay_cache - .check_and_store_valid_until(replay_key, record.valid_until) - { - return; - } - self.unpair_peer(state, peer); - } - - fn unpair_peer(&self, state: &mut RuntimeState, peer: XID) { - if state.peers.remove_peer(peer).is_none() { - return; - } - self.drop_outbound_for_peer(state, peer); - self.fail_pending_for_peer(state, peer); - self.fail_pending_stream_for_peer(state, peer); - self.abort_transfers_for_peer(state, peer, QlError::SendFailed); - state.replay_cache.clear_peer(peer); - self.platform - .handle_peer_status(peer, &PeerSession::Disconnected); - self.persist_peers(state); - } - - fn persist_peers(&self, state: &RuntimeState) { - self.platform.persist_peers(state.peers.all()); - } - - fn handle_record( - &self, - state: &mut RuntimeState, - header: QlHeader, - encrypted: bc_components::EncryptedMessage, - ) { - let peer = header.sender; - let session_key = match state.peers.peer(peer) { - Some(entry) => match &entry.session { - PeerSession::Connected { session_key, .. } => session_key.clone(), - _ => return, - }, - None => return, - }; - let record = match message::decrypt_message(&header, &encrypted, &session_key) { - Ok(record) => record, - Err(message::MessageError::Nack { id, nack, kind }) => { - self.handle_message_nack(state, peer, id, nack, kind); - return; - } - Err(message::MessageError::Error(_)) => return, - }; - let namespace = match record.kind { - MessageKind::Request | MessageKind::Event => ReplayNamespace::Peer, - MessageKind::Response | MessageKind::Nack => ReplayNamespace::Local, - }; - let replay_key = ReplayKey::new(peer, namespace, record.message_id); - if state - .replay_cache - .check_and_store_valid_until(replay_key, record.valid_until) - { - return; - } - self.record_activity(state, peer); - match record.kind { - MessageKind::Response => { - self.resolve_pending_ok(state, peer, record.message_id, record.payload); - } - MessageKind::Nack => { - let nack = Nack::from(record.payload); - self.resolve_pending_nack(state, peer, record.message_id, nack); - } - MessageKind::Request => { - let Some(tx) = self.tx.upgrade() else { - return; - }; - let responder = Responder::new(record.message_id, record.sender, tx); - self.platform - .handle_inbound(HandlerEvent::Request(InboundRequest { - message: record, - respond_to: responder, - })); - } - MessageKind::Event => { - self.platform - .handle_inbound(HandlerEvent::Event(InboundEvent { message: record })); - } - } - } - - fn handle_message_nack( - &self, - state: &mut RuntimeState, - peer: XID, - id: MessageId, - nack: Nack, - kind: MessageKind, - ) { - if kind != MessageKind::Request { - return; - } - self.handle_send_response(state, id, peer, CBOR::from(nack), MessageKind::Nack); - } - - fn handle_heartbeat( - &self, - state: &mut RuntimeState, - header: QlHeader, - encrypted: bc_components::EncryptedMessage, - ) { - let peer = header.sender; - let (session_key, should_reply) = { - let Some(entry) = state.peers.peer(peer) else { - return; - }; - match &entry.session { - PeerSession::Connected { - session_key, - keepalive, - } => (session_key.clone(), !keepalive.pending), - _ => return, - } - }; - if heartbeat::decrypt_heartbeat(&header, &encrypted, &session_key).is_err() { - return; - } - self.record_activity(state, peer); - if should_reply { - self.send_heartbeat_message(state, peer, session_key); - } - } - - fn handle_transfer( - &self, - state: &mut RuntimeState, - header: QlHeader, - encrypted: bc_components::EncryptedMessage, - ) { - let peer = header.sender; - let session_key = match state.peers.peer(peer) { - Some(entry) => match &entry.session { - PeerSession::Connected { session_key, .. } => session_key.clone(), - _ => return, - }, - None => return, - }; - let body = match transfer::decrypt_transfer(&header, &encrypted, &session_key) { - Ok(body) => body, - Err(_) => return, - }; - - let replay_key = ReplayKey::new(peer, ReplayNamespace::Transfer, body.message_id); - if state - .replay_cache - .check_and_store_valid_until(replay_key, body.valid_until) - { - return; - } - - self.record_activity(state, peer); - self.handle_transfer_frame(state, peer, body.transfer_id, body.frame); - } - - fn handle_transfer_frame( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - frame: TransferFrame, - ) { - match frame { - TransferFrame::OpenResponse { request_id, meta } => { - self.handle_transfer_open_response(state, peer, transfer_id, request_id, meta); - } - TransferFrame::OpenRequest { - request_id, - route_id, - meta, - } => { - self.handle_transfer_open_request( - state, - peer, - transfer_id, - request_id, - route_id, - meta, - ); - } - TransferFrame::Chunk { seq, data } => { - self.handle_transfer_chunk(state, peer, transfer_id, seq, data); - } - TransferFrame::Finish { seq } => { - self.handle_transfer_finish(state, peer, transfer_id, seq); - } - TransferFrame::Ack { next_seq } => { - self.handle_transfer_ack(state, peer, transfer_id, next_seq); - } - TransferFrame::Cancel => { - self.handle_transfer_cancel(state, peer, transfer_id); - } - TransferFrame::CancelAck => { - self.handle_transfer_cancel_ack(state, peer, transfer_id); - } - } - } - - fn handle_transfer_open_response( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - request_id: MessageId, - meta: CBOR, - ) { - let open = InboundTransferOpen::Response { - request_id, - meta: meta.clone(), - }; - if self.handle_duplicate_transfer_open(state, peer, transfer_id, &open) { - return; - } - - let Some(pending) = state.pending_stream.remove(&request_id) else { - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); - return; - }; - if pending.recipient != peer { - let _ = pending.tx.send(Err(QlError::SendFailed)); - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); - return; - } - - let Some(tx) = self.tx.upgrade() else { - let _ = pending.tx.send(Err(QlError::Cancelled)); - return; - }; - - let (chunk_tx, chunk_rx) = async_channel::bounded(1); - - let delivery = InboundStreamDelivery { - peer, - transfer_id, - meta, - rx: chunk_rx, - tx, - }; - if pending.tx.send(Ok(delivery)).is_err() { - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); - return; - } - - state.inbound_transfers.insert( - (peer, transfer_id), - InboundTransferState { - open, - expected_seq: 1, - chunk_tx, - }, - ); - - self.send_transfer_frame( - state, - peer, - transfer_id, - TransferFrame::Ack { next_seq: 1 }, - true, - ); - } - - fn handle_transfer_open_request( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - request_id: MessageId, - route_id: RouteId, - meta: CBOR, - ) { - let open = InboundTransferOpen::Request { - request_id, - route_id, - meta: meta.clone(), - }; - if self.handle_duplicate_transfer_open(state, peer, transfer_id, &open) { - return; - } - - let Some(tx) = self.tx.upgrade() else { - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); - return; - }; - - let (chunk_tx, chunk_rx) = async_channel::bounded(1); - let responder = Responder::new(request_id, peer, tx.clone()); - let body = InboundByteStream::new(peer, transfer_id, chunk_rx, tx); - self.platform - .handle_inbound(HandlerEvent::UploadRequest(InboundUploadRequest { - sender: peer, - recipient: self.platform.xid(), - route_id, - message_id: request_id, - meta, - body, - respond_to: responder, - })); - - state.inbound_transfers.insert( - (peer, transfer_id), - InboundTransferState { - open, - expected_seq: 1, - chunk_tx, - }, - ); - - self.send_transfer_frame( - state, - peer, - transfer_id, - TransferFrame::Ack { next_seq: 1 }, - true, - ); - } - - fn handle_duplicate_transfer_open( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - open: &InboundTransferOpen, - ) -> bool { - let key = (peer, transfer_id); - let Some(existing) = state.inbound_transfers.get(&key) else { - return false; - }; - - let frame = if &existing.open == open { - TransferFrame::Ack { next_seq: 1 } - } else { - TransferFrame::Cancel - }; - self.send_transfer_frame(state, peer, transfer_id, frame, true); - true - } - - fn handle_transfer_chunk( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - seq: u32, - data: Vec, - ) { - let key = (peer, transfer_id); - let Some(mut transfer_state) = state.inbound_transfers.remove(&key) else { - return; - }; - - if seq < transfer_state.expected_seq { - self.send_transfer_frame( - state, - peer, - transfer_id, - TransferFrame::Ack { - next_seq: transfer_state.expected_seq, - }, - true, - ); - state.inbound_transfers.insert(key, transfer_state); - return; - } - - if seq > transfer_state.expected_seq { - let _ = transfer_state.chunk_tx.try_send(InboundStreamItem::Error( - QlError::TransferProtocol { id: transfer_id }, - )); - transfer_state.chunk_tx.close(); - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); - return; - } - - match transfer_state - .chunk_tx - .try_send(InboundStreamItem::Chunk(data)) - { - Ok(()) => { - transfer_state.expected_seq = transfer_state.expected_seq.saturating_add(1); - self.send_transfer_frame( - state, - peer, - transfer_id, - TransferFrame::Ack { - next_seq: transfer_state.expected_seq, - }, - true, - ); - state.inbound_transfers.insert(key, transfer_state); - } - Err(async_channel::TrySendError::Full(_)) => { - state.inbound_transfers.insert(key, transfer_state); - } - Err(async_channel::TrySendError::Closed(_)) => { - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); - } - } - } - - fn handle_transfer_finish( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - seq: u32, - ) { - let key = (peer, transfer_id); - let Some(mut transfer_state) = state.inbound_transfers.remove(&key) else { - return; - }; - - if seq < transfer_state.expected_seq { - self.send_transfer_frame( - state, - peer, - transfer_id, - TransferFrame::Ack { - next_seq: transfer_state.expected_seq, - }, - true, - ); - state.inbound_transfers.insert(key, transfer_state); - return; - } - - if seq > transfer_state.expected_seq { - let _ = transfer_state.chunk_tx.try_send(InboundStreamItem::Error( - QlError::TransferProtocol { id: transfer_id }, - )); - transfer_state.chunk_tx.close(); - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); - return; - } - - match transfer_state - .chunk_tx - .try_send(InboundStreamItem::Finished) - { - Ok(()) => { - transfer_state.expected_seq = transfer_state.expected_seq.saturating_add(1); - transfer_state.chunk_tx.close(); - self.send_transfer_frame( - state, - peer, - transfer_id, - TransferFrame::Ack { - next_seq: transfer_state.expected_seq, - }, - true, - ); - } - Err(async_channel::TrySendError::Full(_)) => { - state.inbound_transfers.insert(key, transfer_state); - } - Err(async_channel::TrySendError::Closed(_)) => { - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::Cancel, true); - } - } - } - - fn handle_transfer_ack( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - next_seq: u32, - ) { - let key = (peer, transfer_id); - let Some(mut transfer_state) = state.outbound_transfers.remove(&key) else { - return; - }; - - let matched = match transfer_state.awaiting.as_ref() { - Some(OutboundAwaiting::Open { .. }) => next_seq == 1, - Some(OutboundAwaiting::Chunk { seq, .. }) => next_seq == seq.saturating_add(1), - Some(OutboundAwaiting::Finish { seq }) => next_seq == seq.saturating_add(1), - Some(OutboundAwaiting::Cancel) | None => false, - }; - if !matched { - state.outbound_transfers.insert(key, transfer_state); - return; - } - - match transfer_state.awaiting.take() { - Some(OutboundAwaiting::Open { .. }) => { - transfer_state.stage = OutboundTransferStage::Streaming; - state.outbound_transfers.insert(key, transfer_state); - } - Some(OutboundAwaiting::Chunk { seq, .. }) => { - transfer_state.next_seq = seq.saturating_add(1); - transfer_state.stage = OutboundTransferStage::Streaming; - state.outbound_transfers.insert(key, transfer_state); - } - Some(OutboundAwaiting::Finish { .. }) => { - transfer_state.chunk_rx.close(); - } - Some(OutboundAwaiting::Cancel) | None => { - state.outbound_transfers.insert(key, transfer_state); - } - } - } - - fn handle_transfer_cancel(&self, state: &mut RuntimeState, peer: XID, transfer_id: MessageId) { - let key = (peer, transfer_id); - let mut acknowledged = false; - - if let Some(transfer_state) = state.outbound_transfers.remove(&key) { - transfer_state.chunk_rx.close(); - acknowledged = true; - } - - if let Some(transfer_state) = state.inbound_transfers.remove(&key) { - let error = QlError::TransferCancelled { id: transfer_id }; - let _ = transfer_state - .chunk_tx - .try_send(InboundStreamItem::Error(error)); - transfer_state.chunk_tx.close(); - acknowledged = true; - } - - if acknowledged { - self.send_transfer_frame(state, peer, transfer_id, TransferFrame::CancelAck, true); - } - } - - fn handle_transfer_cancel_ack( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - ) { - let key = (peer, transfer_id); - let Some(transfer_state) = state.outbound_transfers.remove(&key) else { - return; - }; - if !matches!(transfer_state.awaiting, Some(OutboundAwaiting::Cancel)) { - state.outbound_transfers.insert(key, transfer_state); - return; - } - - transfer_state.chunk_rx.close(); - } - - fn drive_outbound_transfers(&self, state: &mut RuntimeState) { - let keys: Vec<(XID, MessageId)> = state.outbound_transfers.keys().copied().collect(); - for (peer, transfer_id) in keys { - self.drive_outbound_transfer(state, peer, transfer_id); - } - } - - fn drive_outbound_transfer(&self, state: &mut RuntimeState, peer: XID, transfer_id: MessageId) { - let key = (peer, transfer_id); - let Some(mut transfer_state) = state.outbound_transfers.remove(&key) else { - return; - }; - - if transfer_state.awaiting.is_some() { - state.outbound_transfers.insert(key, transfer_state); - return; - } - - match transfer_state.stage { - OutboundTransferStage::Opening => { - let Some(meta) = transfer_state.open_meta.take() else { - transfer_state.chunk_rx.close(); - return; - }; - let awaiting = OutboundAwaiting::Open { - request_id: transfer_state.request_id, - route_id: transfer_state.open_route_id, - meta, - }; - if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { - state.outbound_transfers.insert(key, transfer_state); - } - } - OutboundTransferStage::Streaming => match transfer_state.chunk_rx.try_recv() { - Ok(OutboundStreamInput::Chunk(data)) => { - let seq = transfer_state.next_seq; - let awaiting = OutboundAwaiting::Chunk { seq, data }; - if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { - state.outbound_transfers.insert(key, transfer_state); - } - } - Ok(OutboundStreamInput::Finish) => { - let seq = transfer_state.next_seq; - transfer_state.stage = OutboundTransferStage::Finishing; - let awaiting = OutboundAwaiting::Finish { seq }; - if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { - state.outbound_transfers.insert(key, transfer_state); - } - } - Err(async_channel::TryRecvError::Empty) => { - state.outbound_transfers.insert(key, transfer_state); - } - Err(async_channel::TryRecvError::Closed) => { - transfer_state.stage = OutboundTransferStage::Cancelling; - let awaiting = OutboundAwaiting::Cancel; - if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { - state.outbound_transfers.insert(key, transfer_state); - } - } - }, - OutboundTransferStage::Finishing => { - state.outbound_transfers.insert(key, transfer_state); - } - OutboundTransferStage::Cancelling => { - let awaiting = OutboundAwaiting::Cancel; - if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, 0) { - state.outbound_transfers.insert(key, transfer_state); - } - } - } - } - - fn send_outbound_awaiting( - &self, - state: &mut RuntimeState, - transfer_state: &mut OutboundTransferState, - awaiting: OutboundAwaiting, - attempt: u8, - ) -> bool { - let frame = match &awaiting { - OutboundAwaiting::Open { - request_id, - route_id, - meta, - } => match route_id { - Some(route_id) => TransferFrame::OpenRequest { - request_id: *request_id, - route_id: *route_id, - meta: meta.clone(), - }, - None => TransferFrame::OpenResponse { - request_id: *request_id, - meta: meta.clone(), - }, - }, - OutboundAwaiting::Chunk { seq, data } => TransferFrame::Chunk { - seq: *seq, - data: data.clone(), - }, - OutboundAwaiting::Finish { seq } => TransferFrame::Finish { seq: *seq }, - OutboundAwaiting::Cancel => TransferFrame::Cancel, - }; - - let priority = matches!(awaiting, OutboundAwaiting::Cancel); - if !self.send_transfer_frame( - state, - transfer_state.peer, - transfer_state.transfer_id, - frame, - priority, - ) { - transfer_state.chunk_rx.close(); - return false; - } - - transfer_state.awaiting = Some(awaiting); - let at = Instant::now() + self.transfer_ack_timeout(); - match transfer_state.awaiting.as_ref() { - Some(OutboundAwaiting::Open { .. }) => state.timeouts.push(Reverse(TimeoutEntry { - at, - kind: TimeoutKind::TransferAck { - peer: transfer_state.peer, - transfer_id: transfer_state.transfer_id, - next_seq: 1, - attempt, - }, - })), - Some(OutboundAwaiting::Chunk { seq, .. }) => { - state.timeouts.push(Reverse(TimeoutEntry { - at, - kind: TimeoutKind::TransferAck { - peer: transfer_state.peer, - transfer_id: transfer_state.transfer_id, - next_seq: seq.saturating_add(1), - attempt, - }, - })) - } - Some(OutboundAwaiting::Finish { seq }) => state.timeouts.push(Reverse(TimeoutEntry { - at, - kind: TimeoutKind::TransferAck { - peer: transfer_state.peer, - transfer_id: transfer_state.transfer_id, - next_seq: seq.saturating_add(1), - attempt, - }, - })), - Some(OutboundAwaiting::Cancel) => state.timeouts.push(Reverse(TimeoutEntry { - at, - kind: TimeoutKind::TransferCancelAck { - peer: transfer_state.peer, - transfer_id: transfer_state.transfer_id, - attempt, - }, - })), - None => {} - } - - true - } - - fn send_transfer_frame( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - frame: TransferFrame, - priority: bool, - ) -> bool { - let Some(session_key) = state - .peers - .peer(peer) - .and_then(|entry| entry.session.session_key()) - .cloned() - else { - return false; - }; - - let body = TransferBody { - message_id: state.next_message_id(), - valid_until: now_secs().saturating_add(self.config.message_expiration.as_secs()), - transfer_id, - frame, - }; - let record = transfer::encrypt_transfer( - QlHeader { - sender: self.platform.xid(), - recipient: peer, - }, - &session_key, - body, - ); - let bytes = CBOR::from(record).to_cbor_data(); - self.enqueue_outbound_preencoded( - state, - peer, - bytes, - Instant::now() + self.config.message_expiration, - priority, - ); - true - } - - fn transfer_ack_timeout(&self) -> std::time::Duration { - if self.config.default_request_timeout.is_zero() { - std::time::Duration::from_millis(200) - } else { - self.config.default_request_timeout - } - } - - fn handle_transfer_ack_timeout( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - next_seq: u32, - attempt: u8, - ) { - let key = (peer, transfer_id); - let Some(mut transfer_state) = state.outbound_transfers.remove(&key) else { - return; - }; - - let expected = match transfer_state.awaiting.as_ref() { - Some(OutboundAwaiting::Open { .. }) => Some(1), - Some(OutboundAwaiting::Chunk { seq, .. }) => Some(seq.saturating_add(1)), - Some(OutboundAwaiting::Finish { seq }) => Some(seq.saturating_add(1)), - _ => None, - }; - if expected != Some(next_seq) { - state.outbound_transfers.insert(key, transfer_state); - return; - } - - if attempt >= TRANSFER_RETRY_LIMIT { - transfer_state.chunk_rx.close(); - return; - } - - let Some(awaiting) = transfer_state.awaiting.take() else { - state.outbound_transfers.insert(key, transfer_state); - return; - }; - if self.send_outbound_awaiting(state, &mut transfer_state, awaiting, attempt + 1) { - state.outbound_transfers.insert(key, transfer_state); - } - } - - fn handle_transfer_cancel_ack_timeout( - &self, - state: &mut RuntimeState, - peer: XID, - transfer_id: MessageId, - attempt: u8, - ) { - let key = (peer, transfer_id); - let Some(mut transfer_state) = state.outbound_transfers.remove(&key) else { - return; - }; - - if !matches!(transfer_state.awaiting, Some(OutboundAwaiting::Cancel)) { - state.outbound_transfers.insert(key, transfer_state); - return; - } - - if attempt >= TRANSFER_RETRY_LIMIT { - transfer_state.chunk_rx.close(); - return; - } - - transfer_state.awaiting = None; - if self.send_outbound_awaiting( - state, - &mut transfer_state, - OutboundAwaiting::Cancel, - attempt + 1, - ) { - state.outbound_transfers.insert(key, transfer_state); - } - } - - fn send_heartbeat_message( - &self, - state: &mut RuntimeState, - peer: XID, - session_key: bc_components::SymmetricKey, - ) { - let message_id = state.next_message_id(); - let valid_until = now_secs().saturating_add(self.config.message_expiration.as_secs()); - let message = heartbeat::encrypt_heartbeat( - QlHeader { - sender: self.platform.xid(), - recipient: peer, - }, - &session_key, - HeartbeatBody { - message_id, - valid_until, - }, - ); - let bytes = CBOR::from(message).to_cbor_data(); - let outbound_deadline = Instant::now() + self.config.message_expiration; - self.enqueue_outbound( - state, - peer, - OutboundPayload::PreEncoded(bytes), - outbound_deadline, - None, - ); - } - - fn keep_alive_config(&self) -> Option { - self.config - .keep_alive - .filter(|config| !config.interval.is_zero() && !config.timeout.is_zero()) - } - - fn record_activity(&self, state: &mut RuntimeState, peer: XID) { - let Some(config) = self.keep_alive_config() else { - return; - }; - let token = state.next_token(); - let Some(entry) = state.peers.peer_mut(peer) else { - return; - }; - let PeerSession::Connected { keepalive, .. } = &mut entry.session else { - return; - }; - let now = Instant::now(); - keepalive.last_activity = Some(now); - keepalive.pending = false; - keepalive.token = token; - state.timeouts.push(Reverse(TimeoutEntry { - at: now + config.interval, - kind: TimeoutKind::KeepAliveSend { peer, token }, - })); - } - - fn drop_outbound_for_peer(&self, state: &mut RuntimeState, peer: XID) { - state.outbound.retain(|message| { - if message.peer == peer { - if let Some(id) = message.message_id { - if let Some(entry) = state.pending.remove(&id) { - let _ = entry.tx.send(Err(QlError::SendFailed)); - } - if let Some(entry) = state.pending_stream.remove(&id) { - let _ = entry.tx.send(Err(QlError::SendFailed)); - } - } - false - } else { - true - } - }); - } - - fn fail_pending_for_peer(&self, state: &mut RuntimeState, peer: XID) { - state - .pending - .extract_if(|_id, entry| entry.recipient == peer) - .for_each(|(_, entry)| { - let _ = entry.tx.send(Err(QlError::SendFailed)); - }); - } - - fn fail_pending_stream_for_peer(&self, state: &mut RuntimeState, peer: XID) { - state - .pending_stream - .extract_if(|_id, entry| entry.recipient == peer) - .for_each(|(_, entry)| { - let _ = entry.tx.send(Err(QlError::SendFailed)); - }); - } - - fn abort_transfers_for_peer(&self, state: &mut RuntimeState, peer: XID, error: QlError) { - state - .outbound_transfers - .extract_if(|(transfer_peer, _), _| *transfer_peer == peer) - .for_each(|(_, transfer_state)| { - transfer_state.chunk_rx.close(); - }); - - state - .inbound_transfers - .extract_if(|(transfer_peer, _), _| *transfer_peer == peer) - .for_each(|(_, transfer_state)| { - let _ = transfer_state - .chunk_tx - .try_send(InboundStreamItem::Error(error.clone())); - transfer_state.chunk_tx.close(); - }); - } - - fn resolve_pending_ok( - &self, - state: &mut RuntimeState, - sender: XID, - id: MessageId, - payload: CBOR, - ) { - if let Some(entry) = state.pending.remove(&id) { - if entry.recipient == sender { - let _ = entry.tx.send(Ok(payload)); - } - return; - } - if let Some(entry) = state.pending_stream.remove(&id) { - if entry.recipient == sender { - let _ = entry.tx.send(Err(QlError::InvalidPayload)); - } - } - } - - fn resolve_pending_nack( - &self, - state: &mut RuntimeState, - sender: XID, - id: MessageId, - nack: Nack, - ) { - if let Some(entry) = state.pending.remove(&id) { - if entry.recipient == sender { - let _ = entry.tx.send(Err(QlError::Nack { id, nack })); - } - return; - } - if let Some(entry) = state.pending_stream.remove(&id) { - if entry.recipient == sender { - let _ = entry.tx.send(Err(QlError::Nack { id, nack })); - } - } - } - - fn handle_hello( - &self, - state: &mut RuntimeState, - header: QlHeader, - hello: crate::wire::handshake::Hello, - ) { - let peer = header.sender; - let action = match state.peers.peer(peer) { - Some(entry) => match &entry.session { - PeerSession::Initiator { - hello: local_hello, .. - } => { - if peer_hello_wins(local_hello, self.platform.xid(), &hello, peer) { - HelloAction::StartResponder - } else { - HelloAction::Ignore - } - } - PeerSession::Responder { - hello: stored, - reply, - deadline, - .. - } => { - if stored.nonce == hello.nonce { - HelloAction::ResendReply { - reply: reply.clone(), - deadline: *deadline, - } - } else { - HelloAction::StartResponder - } - } - PeerSession::Disconnected | PeerSession::Connected { .. } => { - HelloAction::StartResponder - } - }, - None => return, - }; - - match action { - HelloAction::StartResponder => { - self.start_responder_handshake(state, peer, hello); - } - HelloAction::ResendReply { reply, deadline } => { - let message = QlRecord { - header: QlHeader { - sender: self.platform.xid(), - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), - }; - let bytes = CBOR::from(message).to_cbor_data(); - self.enqueue_outbound( - state, - peer, - OutboundPayload::PreEncoded(bytes), - deadline, - None, - ); - } - HelloAction::Ignore => {} - } - } - - fn handle_hello_reply( - &self, - state: &mut RuntimeState, - header: QlHeader, - reply: crate::wire::handshake::HelloReply, - ) { - let peer = header.sender; - let (hello, initiator_secret, stage, responder_signing_key) = match state.peers.peer(peer) { - Some(entry) => match &entry.session { - PeerSession::Initiator { - hello, - session_key, - stage, - .. - } => ( - hello.clone(), - session_key.clone(), - *stage, - entry.signing_key.clone(), - ), - _ => return, - }, - None => return, - }; - - if stage != InitiatorStage::WaitingHelloReply { - return; - } - - let confirm = match handshake::build_confirm( - &self.platform, - self.platform.xid(), - peer, - &responder_signing_key, - &hello, - &reply, - &initiator_secret, - ) { - Ok((confirm, session_key)) => { - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::new(), - }; - self.platform.handle_peer_status(peer, &entry.session); - } - self.record_activity(state, peer); - confirm - } - Err(_) => { - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Disconnected; - self.platform.handle_peer_status(peer, &entry.session); - } - return; - } - }; - - let message = QlRecord { - header: QlHeader { - sender: self.platform.xid(), - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm)), - }; - let bytes = CBOR::from(message).to_cbor_data(); - let deadline = Instant::now() + self.config.handshake_timeout; - self.enqueue_outbound( - state, - peer, - OutboundPayload::PreEncoded(bytes), - deadline, - None, - ); - } - - fn handle_confirm( - &self, - state: &mut RuntimeState, - header: QlHeader, - confirm: crate::wire::handshake::Confirm, - ) { - let peer = header.sender; - let (hello, reply, secrets, initiator_signing_key) = match state.peers.peer(peer) { - Some(entry) => match &entry.session { - PeerSession::Responder { - hello, - reply, - secrets, - .. - } => ( - hello.clone(), - reply.clone(), - secrets.clone(), - entry.signing_key.clone(), - ), - _ => return, - }, - None => return, - }; - - match handshake::finalize_confirm( - peer, - self.platform.xid(), - &initiator_signing_key, - &hello, - &reply, - &confirm, - &secrets, - ) { - Ok(session_key) => { - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::new(), - }; - self.platform.handle_peer_status(peer, &entry.session); - } - self.record_activity(state, peer); - } - Err(_) => { - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Disconnected; - self.platform.handle_peer_status(peer, &entry.session); - } - } - } - } - - fn start_responder_handshake( - &self, - state: &mut RuntimeState, - peer: XID, - hello: crate::wire::handshake::Hello, - ) { - let encapsulation_key = match state.peers.peer(peer) { - Some(entry) => entry.encapsulation_key.clone(), - None => return, - }; - let (reply, secrets) = match handshake::respond_hello( - &self.platform, - peer, - self.platform.xid(), - &encapsulation_key, - &hello, - ) { - Ok(result) => result, - Err(_) => { - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Disconnected; - self.platform.handle_peer_status(peer, &entry.session); - } - return; - } - }; - - let deadline = Instant::now() + self.config.handshake_timeout; - let token = state.next_token(); - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Responder { - handshake_token: token, - hello: hello.clone(), - reply: reply.clone(), - secrets, - deadline, - }; - self.platform.handle_peer_status(peer, &entry.session); - } - - let message = QlRecord { - header: QlHeader { - sender: self.platform.xid(), - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), - }; - let bytes = CBOR::from(message).to_cbor_data(); - self.enqueue_handshake_message(state, peer, token, deadline, bytes); - } - - fn enqueue_handshake_message( - &self, - state: &mut RuntimeState, - peer: XID, - token: Token, - deadline: Instant, - bytes: Vec, - ) { - state.outbound.push_back(OutboundMessage { - peer, - token, - message_id: None, - payload: OutboundPayload::PreEncoded(bytes), - }); - state.timeouts.push(Reverse(TimeoutEntry { - at: deadline, - kind: TimeoutKind::Handshake { peer, token }, - })); - state.timeouts.push(Reverse(TimeoutEntry { - at: deadline, - kind: TimeoutKind::Outbound { token }, - })); - } - - fn enqueue_outbound( - &self, - state: &mut RuntimeState, - peer: XID, - payload: OutboundPayload, - deadline: Instant, - message_id: Option, - ) { - let token = state.next_token(); - state.outbound.push_back(OutboundMessage { - peer, - token, - message_id, - payload, - }); - state.timeouts.push(Reverse(TimeoutEntry { - at: deadline, - kind: TimeoutKind::Outbound { token }, - })); - } - - fn enqueue_outbound_preencoded( - &self, - state: &mut RuntimeState, - peer: XID, - bytes: Vec, - deadline: Instant, - priority: bool, - ) { - let token = state.next_token(); - let message = OutboundMessage { - peer, - token, - message_id: None, - payload: OutboundPayload::PreEncoded(bytes), - }; - if priority { - state.outbound.push_front(message); - } else { - state.outbound.push_back(message); - } - state.timeouts.push(Reverse(TimeoutEntry { - at: deadline, - kind: TimeoutKind::Outbound { token }, - })); - } - - fn handle_timeouts(&self, state: &mut RuntimeState) { - let now = Instant::now(); - loop { - let Some(entry) = state.timeouts.peek_mut().filter(|e| e.0.at <= now) else { - break; - }; - let entry = PeekMut::pop(entry).0; - match entry.kind { - TimeoutKind::Outbound { token } => { - let mut message_id = None; - state.outbound.retain(|message| { - if message.token == token { - message_id = message.message_id; - false - } else { - true - } - }); - if let Some(id) = message_id { - if let Some(entry) = state.pending.remove(&id) { - let _ = entry.tx.send(Err(QlError::SendFailed)); - } - if let Some(entry) = state.pending_stream.remove(&id) { - let _ = entry.tx.send(Err(QlError::SendFailed)); - } - } - } - TimeoutKind::Handshake { peer, token } => { - let Some(entry) = state.peers.peer(peer) else { - continue; - }; - let should_disconnect = match &entry.session { - PeerSession::Initiator { - handshake_token, .. - } - | PeerSession::Responder { - handshake_token, .. - } => *handshake_token == token, - _ => false, - }; - if should_disconnect { - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Disconnected; - self.platform.handle_peer_status(peer, &entry.session); - } - state.outbound.retain(|message| message.peer != peer); - } - } - TimeoutKind::Request { id } => { - if let Some(entry) = state.pending.remove(&id) { - let _ = entry.tx.send(Err(QlError::Timeout)); - } - if let Some(entry) = state.pending_stream.remove(&id) { - let _ = entry.tx.send(Err(QlError::Timeout)); - } - } - TimeoutKind::KeepAliveSend { peer, token } => { - let Some(config) = self.keep_alive_config() else { - continue; - }; - let session_key = { - let Some(entry) = state.peers.peer(peer) else { - continue; - }; - let PeerSession::Connected { - session_key, - keepalive, - } = &entry.session - else { - continue; - }; - if keepalive.token == token && !keepalive.pending { - session_key.clone() - } else { - continue; - } - }; - self.send_heartbeat_message(state, peer, session_key); - if let Some(entry) = state.peers.peer_mut(peer) { - if let PeerSession::Connected { keepalive, .. } = &mut entry.session { - if keepalive.token == token { - keepalive.pending = true; - } - } - } - state.timeouts.push(Reverse(TimeoutEntry { - at: now + config.timeout, - kind: TimeoutKind::KeepAliveTimeout { peer, token }, - })); - } - TimeoutKind::KeepAliveTimeout { peer, token } => { - let Some(entry) = state.peers.peer(peer) else { - continue; - }; - - let should_disconnect = match &entry.session { - PeerSession::Connected { keepalive, .. } => { - keepalive.token == token && keepalive.pending - } - _ => false, - }; - - if should_disconnect { - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Disconnected; - self.platform.handle_peer_status(peer, &entry.session); - } - self.drop_outbound_for_peer(state, peer); - self.fail_pending_for_peer(state, peer); - self.fail_pending_stream_for_peer(state, peer); - self.abort_transfers_for_peer(state, peer, QlError::SendFailed); - } - } - TimeoutKind::TransferAck { - peer, - transfer_id, - next_seq, - attempt, - } => { - self.handle_transfer_ack_timeout(state, peer, transfer_id, next_seq, attempt); - } - TimeoutKind::TransferCancelAck { - peer, - transfer_id, - attempt, - } => { - self.handle_transfer_cancel_ack_timeout(state, peer, transfer_id, attempt); - } - } - } - } - - fn handle_write_done( - &self, - state: &mut RuntimeState, - peer: XID, - token: Token, - message_id: Option, - result: Result<(), QlError>, - ) { - if result.is_ok() { - return; - } - - if let Some(id) = message_id { - if let Some(entry) = state.pending.remove(&id) { - let _ = entry.tx.send(Err(QlError::SendFailed)); - } - if let Some(entry) = state.pending_stream.remove(&id) { - let _ = entry.tx.send(Err(QlError::SendFailed)); - } - } - let should_disconnect = match state.peers.peer(peer).map(|entry| &entry.session) { - Some(PeerSession::Initiator { - handshake_token, .. - }) if *handshake_token == token => true, - Some(PeerSession::Responder { - handshake_token, .. - }) if *handshake_token == token => true, - _ => false, - }; - if should_disconnect { - if let Some(entry) = state.peers.peer_mut(peer) { - entry.session = PeerSession::Disconnected; - self.platform.handle_peer_status(peer, &entry.session); - } - state.outbound.retain(|message| message.peer != peer); - self.fail_pending_for_peer(state, peer); - self.fail_pending_stream_for_peer(state, peer); - self.abort_transfers_for_peer(state, peer, QlError::SendFailed); - } - } -} diff --git a/ql/src/runtime/handle.rs b/ql/src/runtime/handle.rs deleted file mode 100644 index 90176191..00000000 --- a/ql/src/runtime/handle.rs +++ /dev/null @@ -1,492 +0,0 @@ -use std::{ - future::Future, - marker::PhantomData, - pin::{pin, Pin}, - task::{Context, Poll}, -}; - -use async_channel::Sender; -use bc_components::{MLDSAPublicKey, MLKEMPublicKey, XID}; -use dcbor::CBOR; - -use crate::{ - runtime::{ - internal::{InboundStreamDelivery, InboundStreamItem, OutboundStreamInput, RuntimeCommand}, - RequestConfig, - }, - wire::message::Ack, - Event, MessageId, QlCodec, QlError, QlStream, QlUpload, RequestResponse, RouteId, -}; - -#[derive(Clone)] -pub struct RuntimeHandle { - pub(crate) tx: async_channel::Sender, -} - -pub struct Response { - rx: oneshot::Receiver>, - _type: PhantomData T>, -} - -pub struct StreamResponse { - rx: oneshot::Receiver>, - _type: PhantomData T>, -} - -pub struct InboundStream { - pub meta: T, - pub body: InboundByteStream, -} - -pub struct InboundByteStream { - sender: XID, - transfer_id: MessageId, - rx: async_channel::Receiver, - tx: Sender, - finished: bool, -} - -impl std::fmt::Debug for InboundByteStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InboundByteStream") - .field("sender", &self.sender) - .field("transfer_id", &self.transfer_id) - .field("finished", &self.finished) - .finish_non_exhaustive() - } -} - -pub struct OutboundTransfer { - recipient: XID, - transfer_id: MessageId, - chunk_tx: Option>, - tx: Sender, -} - -pub struct UploadRequest { - pub transfer: OutboundTransfer, - pub response: Response, -} - -impl Response { - pub async fn recv(self) -> Result { - self.rx.await.unwrap_or(Err(QlError::Cancelled)) - } -} - -impl Future for Response -where - T: QlCodec, -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - pin!(&mut self.rx).poll(cx).map(|result| { - let payload = result.unwrap_or(Err(QlError::Cancelled))?; - T::try_from(payload).map_err(|_| QlError::InvalidPayload) - }) - } -} - -impl StreamResponse { - pub async fn recv(self) -> Result, QlError> { - let delivery = self.rx.await.unwrap_or(Err(QlError::Cancelled))?; - let InboundStreamDelivery { - peer, - transfer_id, - meta, - rx, - tx, - } = delivery; - Ok(InboundStream { - meta, - body: InboundByteStream::new(peer, transfer_id, rx, tx), - }) - } -} - -impl Future for StreamResponse -where - T: QlCodec, -{ - type Output = Result, QlError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - pin!(&mut self.rx).poll(cx).map(|result| { - let delivery = result.unwrap_or(Err(QlError::Cancelled))?; - let InboundStreamDelivery { - peer, - transfer_id, - meta, - rx, - tx, - } = delivery; - let meta = T::try_from(meta).map_err(|_| QlError::InvalidPayload)?; - Ok(InboundStream { - meta, - body: InboundByteStream::new(peer, transfer_id, rx, tx), - }) - }) - } -} - -impl UploadRequest -where - R: QlCodec, -{ - pub async fn finish(self) -> Result { - let Self { transfer, response } = self; - transfer.finish().await?; - response.await - } -} - -impl InboundByteStream { - pub(crate) fn new( - sender: XID, - transfer_id: MessageId, - rx: async_channel::Receiver, - tx: Sender, - ) -> Self { - Self { - sender, - transfer_id, - rx, - tx, - finished: false, - } - } - - pub async fn next_chunk(&mut self) -> Result>, QlError> { - if self.finished { - return Ok(None); - } - match self.rx.recv().await { - Ok(InboundStreamItem::Chunk(chunk)) => Ok(Some(chunk)), - Ok(InboundStreamItem::Finished) => { - self.finished = true; - Ok(None) - } - Ok(InboundStreamItem::Error(error)) => { - self.finished = true; - Err(error) - } - Err(_) => { - self.finished = true; - Err(QlError::TransferCancelled { - id: self.transfer_id, - }) - } - } - } -} - -impl Drop for InboundByteStream { - fn drop(&mut self) { - if self.finished { - return; - } - let _ = self.tx.try_send(RuntimeCommand::CancelInboundTransfer { - sender: self.sender, - transfer_id: self.transfer_id, - }); - } -} - -impl OutboundTransfer { - pub(crate) fn new( - recipient: XID, - transfer_id: MessageId, - chunk_tx: Sender, - tx: Sender, - ) -> Self { - Self { - recipient, - transfer_id, - chunk_tx: Some(chunk_tx), - tx, - } - } - - pub async fn write_next(&mut self, chunk: Vec) -> Result<(), QlError> { - let chunk_tx = self - .chunk_tx - .as_ref() - .expect("transfer not finished or cancelled"); - chunk_tx - .send(OutboundStreamInput::Chunk(chunk)) - .await - .map_err(|_| QlError::TransferCancelled { - id: self.transfer_id, - })?; - self.tx - .send(RuntimeCommand::PollOutboundTransfer { - recipient: self.recipient, - transfer_id: self.transfer_id, - }) - .await - .map_err(|_| QlError::Cancelled)?; - Ok(()) - } - - pub async fn finish(mut self) -> Result<(), QlError> { - let Some(chunk_tx) = self.chunk_tx.take() else { - return Ok(()); - }; - if chunk_tx.send(OutboundStreamInput::Finish).await.is_err() { - return Ok(()); - } - self.tx - .send(RuntimeCommand::PollOutboundTransfer { - recipient: self.recipient, - transfer_id: self.transfer_id, - }) - .await - .map_err(|_| QlError::Cancelled)?; - chunk_tx.closed().await; - Ok(()) - } - - pub async fn cancel(mut self) -> Result<(), QlError> { - self.chunk_tx.take(); - self.tx - .send(RuntimeCommand::CancelOutboundTransfer { - recipient: self.recipient, - transfer_id: self.transfer_id, - }) - .await - .map_err(|_| QlError::Cancelled) - } -} - -impl Drop for OutboundTransfer { - fn drop(&mut self) { - if self.chunk_tx.take().is_none() { - return; - } - let _ = self.tx.try_send(RuntimeCommand::CancelOutboundTransfer { - recipient: self.recipient, - transfer_id: self.transfer_id, - }); - } -} - -impl RuntimeHandle { - pub fn register_peer( - &self, - peer: XID, - signing_key: MLDSAPublicKey, - encapsulation_key: MLKEMPublicKey, - ) { - self.send(RuntimeCommand::RegisterPeer { - peer, - signing_key, - encapsulation_key, - }) - } - - pub fn connect(&self, peer: XID) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Connect { peer }) - .map_err(|_| QlError::Cancelled) - } - - pub fn unpair(&self, peer: XID) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Unpair { peer }) - .map_err(|_| QlError::Cancelled) - } - - pub fn send_incoming(&self, bytes: Vec) { - self.send(RuntimeCommand::Incoming(bytes)) - } - - pub fn request( - &self, - message: M, - recipient: XID, - config: RequestConfig, - ) -> Response - where - M: RequestResponse, - { - let (tx, rx) = oneshot::channel(); - self.send(RuntimeCommand::SendRequest { - recipient, - route_id: M::ID, - payload: message.into(), - respond_to: tx, - config, - }); - Response { - rx, - _type: PhantomData, - } - } - - pub fn request_stream( - &self, - message: M, - recipient: XID, - config: RequestConfig, - ) -> StreamResponse - where - M: QlStream, - { - let (tx, rx) = oneshot::channel(); - self.send(RuntimeCommand::SendStreamRequest { - recipient, - route_id: M::ID, - payload: message.into(), - respond_to: tx, - config, - }); - StreamResponse { - rx, - _type: PhantomData, - } - } - - pub async fn request_upload( - &self, - message: M, - recipient: XID, - config: RequestConfig, - ) -> Result, QlError> - where - M: QlUpload, - { - let upload = self - .send_request_upload_raw(recipient, M::ID, message.into(), config) - .await?; - Ok(UploadRequest { - transfer: upload.transfer, - response: Response { - rx: upload.response.rx, - _type: PhantomData, - }, - }) - } - - pub fn send_event(&self, message: M, recipient: XID) - where - M: Event, - { - self.send_event_raw(recipient, M::ID, message.into()) - } - - pub fn send_event_with_ack( - &self, - message: M, - recipient: XID, - config: RequestConfig, - ) -> Response - where - M: Event, - { - let (tx, rx) = oneshot::channel(); - self.send(RuntimeCommand::SendRequest { - recipient, - route_id: M::ID, - payload: message.into(), - respond_to: tx, - config, - }); - Response { - rx, - _type: PhantomData, - } - } - - pub fn send_event_raw(&self, recipient: XID, route_id: RouteId, payload: CBOR) { - self.send(RuntimeCommand::SendEvent { - recipient, - route_id, - payload, - }) - } - - pub fn send_request_raw( - &self, - recipient: XID, - route_id: RouteId, - payload: CBOR, - config: RequestConfig, - ) -> Response { - let (tx, rx) = oneshot::channel(); - self.send(RuntimeCommand::SendRequest { - recipient, - route_id, - payload, - respond_to: tx, - config, - }); - Response { - rx, - _type: PhantomData, - } - } - - pub fn send_request_stream_raw( - &self, - recipient: XID, - route_id: RouteId, - payload: CBOR, - config: RequestConfig, - ) -> StreamResponse { - let (tx, rx) = oneshot::channel(); - self.send(RuntimeCommand::SendStreamRequest { - recipient, - route_id, - payload, - respond_to: tx, - config, - }); - StreamResponse { - rx, - _type: PhantomData, - } - } - - pub async fn send_request_upload_raw( - &self, - recipient: XID, - route_id: RouteId, - payload: CBOR, - config: RequestConfig, - ) -> Result, QlError> { - let (response_tx, response_rx) = oneshot::channel(); - let (chunk_tx, chunk_rx) = async_channel::bounded(1); - let (start_tx, start_rx) = oneshot::channel(); - self.tx - .send(RuntimeCommand::SendUploadRequest { - recipient, - route_id, - payload, - respond_to: response_tx, - chunk_rx, - start: start_tx, - config, - }) - .await - .map_err(|_| QlError::Cancelled)?; - - let transfer_id = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; - - Ok(UploadRequest { - transfer: OutboundTransfer::new(recipient, transfer_id, chunk_tx, self.tx.clone()), - response: Response { - rx: response_rx, - _type: PhantomData, - }, - }) - } -} - -impl RuntimeHandle { - #[inline] - #[track_caller] - fn send(&self, cmd: RuntimeCommand) { - self.tx.send_blocking(cmd).expect("runtime is alive") - } -} diff --git a/ql/src/runtime/internal.rs b/ql/src/runtime/internal.rs deleted file mode 100644 index 0a94ec6b..00000000 --- a/ql/src/runtime/internal.rs +++ /dev/null @@ -1,545 +0,0 @@ -use std::{ - cell::Cell, - cmp::Reverse, - collections::{BinaryHeap, HashMap, VecDeque}, - time::{Instant, SystemTime, UNIX_EPOCH}, -}; - -use async_channel::{Receiver, Sender}; -use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey, XID}; -use dcbor::CBOR; - -use crate::{ - platform::PlatformFuture, - runtime::{replay_cache::ReplayCache, RequestConfig}, - wire::{ - handshake::{Hello, HelloReply}, - message::{MessageBody, MessageKind}, - }, - MessageId, Peer, QlError, RouteId, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -// Monotonic token for timeout correlation. -pub struct Token(u64); - -#[derive(Debug, Clone)] -// Per-peer keepalive timers and ping state. -pub struct KeepAliveState { - pub token: Token, - pub pending: bool, - pub last_activity: Option, -} - -impl KeepAliveState { - pub fn new() -> Self { - Self { - token: Token(0), - pending: false, - last_activity: None, - } - } -} - -impl Default for KeepAliveState { - fn default() -> Self { - Self::new() - } -} - -#[derive(Debug, Clone)] -// Registered peer identity and current session. -pub struct PeerRecord { - pub peer: XID, - pub signing_key: MLDSAPublicKey, - pub encapsulation_key: MLKEMPublicKey, - pub session: PeerSession, -} - -impl PeerRecord { - pub fn new(peer: XID, signing_key: MLDSAPublicKey, encapsulation_key: MLKEMPublicKey) -> Self { - Self { - peer, - signing_key, - encapsulation_key, - session: PeerSession::Disconnected, - } - } -} - -#[derive(Debug, Clone)] -// In-memory registry of known peers. -pub struct PeerStore { - peers: Vec, -} - -impl PeerStore { - pub fn new() -> Self { - Self { peers: Vec::new() } - } - - pub fn peer(&self, peer: XID) -> Option<&PeerRecord> { - self.peers.iter().find(|record| record.peer == peer) - } - - pub fn peer_mut(&mut self, peer: XID) -> Option<&mut PeerRecord> { - self.peers.iter_mut().find(|record| record.peer == peer) - } - - pub fn upsert_peer( - &mut self, - peer: XID, - signing_key: MLDSAPublicKey, - encapsulation_key: MLKEMPublicKey, - ) -> &mut PeerRecord { - if let Some(index) = self.peers.iter().position(|record| record.peer == peer) { - let record = &mut self.peers[index]; - record.signing_key = signing_key; - record.encapsulation_key = encapsulation_key; - return record; - } - self.peers - .push(PeerRecord::new(peer, signing_key, encapsulation_key)); - self.peers.last_mut().expect("peer record just inserted") - } - - pub fn all(&self) -> Vec { - self.peers - .iter() - .map(|record| Peer { - peer: record.peer, - signing_key: record.signing_key.clone(), - encapsulation_key: record.encapsulation_key.clone(), - }) - .collect() - } - - pub fn remove_peer(&mut self, peer: XID) -> Option { - let index = self.peers.iter().position(|record| record.peer == peer)?; - Some(self.peers.remove(index)) - } -} - -#[derive(Debug, Clone)] -// Session state machine for a peer. -pub enum PeerSession { - // No active handshake or session. - Disconnected, - // Local side initiated the handshake. - Initiator { - handshake_token: Token, - hello: Hello, - session_key: SymmetricKey, - deadline: Instant, - stage: InitiatorStage, - }, - // Local side is responding to a handshake. - Responder { - handshake_token: Token, - hello: Hello, - reply: HelloReply, - secrets: crate::wire::handshake::ResponderSecrets, - deadline: Instant, - }, - // Encrypted session is established. - Connected { - session_key: SymmetricKey, - keepalive: KeepAliveState, - }, -} - -impl PeerSession { - #[inline] - pub fn is_connected(&self) -> bool { - matches!(self, PeerSession::Connected { .. }) - } - - #[inline] - pub fn session_key(&self) -> Option<&SymmetricKey> { - match self { - PeerSession::Connected { session_key, .. } => Some(session_key), - _ => None, - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -// Initiator-side handshake progression. -pub enum InitiatorStage { - // Waiting for hello reply. - WaitingHelloReply, - // Waiting for confirm completion. - WaitingConfirmAck, -} - -// Producer messages for outbound transfer data. -pub(crate) enum OutboundStreamInput { - // Emit one data chunk. - Chunk(Vec), - // Mark stream end. - Finish, -} - -// Consumer messages for inbound transfer reads. -pub(crate) enum InboundStreamItem { - // Next received data chunk. - Chunk(Vec), - // Clean stream completion. - Finished, - // Terminal stream failure. - Error(QlError), -} - -// Identity of an accepted inbound transfer open frame. -#[derive(Debug, Clone, PartialEq)] -pub enum InboundTransferOpen { - // Streamed response correlated to a prior request. - Response { - request_id: MessageId, - meta: CBOR, - }, - // Streamed upload request with route metadata. - Request { - request_id: MessageId, - route_id: RouteId, - meta: CBOR, - }, -} - -// Runtime-delivered stream metadata and receiver. -pub(crate) struct InboundStreamDelivery { - pub peer: XID, - pub transfer_id: MessageId, - pub meta: CBOR, - pub rx: Receiver, - pub tx: Sender, -} - -// Last sender frame currently awaiting ack. -pub enum OutboundAwaiting { - // Open frame with request correlation. - Open { - request_id: MessageId, - route_id: Option, - meta: CBOR, - }, - // Data frame at a specific sequence. - Chunk { - seq: u32, - data: Vec, - }, - // Finish frame at a specific sequence. - Finish { - seq: u32, - }, - // Cancel frame awaiting cancel-ack. - Cancel, -} - -// Coarse sender-side transfer lifecycle. -pub enum OutboundTransferStage { - // Opening frame not yet acknowledged. - Opening, - // Streaming chunks frame-by-frame. - Streaming, - // Finish frame sent, waiting for ack. - Finishing, - // Cancellation in progress. - Cancelling, -} - -// Runtime state for one outbound transfer. -pub struct OutboundTransferState { - pub request_id: MessageId, - pub peer: XID, - pub transfer_id: MessageId, - pub stage: OutboundTransferStage, - pub next_seq: u32, - pub open_route_id: Option, - pub open_meta: Option, - pub chunk_rx: Receiver, - pub awaiting: Option, -} - -// Runtime state for one inbound transfer. -pub struct InboundTransferState { - pub open: InboundTransferOpen, - pub expected_seq: u32, - pub chunk_tx: Sender, -} - -// Commands consumed by the runtime loop. -pub(crate) enum RuntimeCommand { - // Upsert a peer record. - RegisterPeer { - peer: XID, - signing_key: MLDSAPublicKey, - encapsulation_key: MLKEMPublicKey, - }, - // Start handshake with a peer. - Connect { - peer: XID, - }, - // Send unpair and remove peer. - Unpair { - peer: XID, - }, - // Send unary request and await unary response. - SendRequest { - recipient: XID, - route_id: RouteId, - payload: CBOR, - respond_to: oneshot::Sender>, - config: RequestConfig, - }, - // Send unary request and await streamed response. - SendStreamRequest { - recipient: XID, - route_id: RouteId, - payload: CBOR, - respond_to: oneshot::Sender>, - config: RequestConfig, - }, - // Send streamed request and await unary response. - SendUploadRequest { - recipient: XID, - route_id: RouteId, - payload: CBOR, - respond_to: oneshot::Sender>, - chunk_rx: Receiver, - start: oneshot::Sender>, - config: RequestConfig, - }, - // Send fire-and-forget event. - SendEvent { - recipient: XID, - route_id: RouteId, - payload: CBOR, - }, - // Send unary response or nack. - SendResponse { - id: MessageId, - recipient: XID, - payload: CBOR, - kind: MessageKind, - }, - // Start sender-side streamed response. - StartResponseStream { - request_id: MessageId, - recipient: XID, - meta: CBOR, - chunk_rx: Receiver, - }, - // Prompt immediate outbound transfer polling. - PollOutboundTransfer { - recipient: XID, - transfer_id: MessageId, - }, - // Cancel sender-side active transfer. - CancelOutboundTransfer { - recipient: XID, - transfer_id: MessageId, - }, - // Cancel receiver-side active transfer. - CancelInboundTransfer { - sender: XID, - transfer_id: MessageId, - }, - // Process raw incoming bytes. - Incoming(Vec), -} - -// Mutable state owned by the runtime loop. -pub struct RuntimeState { - pub peers: PeerStore, - pub next_token: Cell, - pub outbound: VecDeque, - pub timeouts: BinaryHeap>, - pub pending: HashMap, - pub pending_stream: HashMap, - pub outbound_transfers: HashMap<(XID, MessageId), OutboundTransferState>, - pub inbound_transfers: HashMap<(XID, MessageId), InboundTransferState>, - pub next_message_id: Cell, - pub replay_cache: ReplayCache, -} - -impl RuntimeState { - pub fn new() -> Self { - Self { - peers: PeerStore::new(), - next_token: Cell::new(Token(1)), - outbound: VecDeque::new(), - timeouts: BinaryHeap::new(), - pending: HashMap::new(), - pending_stream: HashMap::new(), - outbound_transfers: HashMap::new(), - inbound_transfers: HashMap::new(), - next_message_id: Cell::new(MessageId(1)), - replay_cache: ReplayCache::new(), - } - } - - pub fn next_token(&self) -> Token { - let token = self.next_token.get(); - self.next_token.set(Token(token.0.wrapping_add(1))); - token - } - - pub fn next_message_id(&self) -> MessageId { - let id = self.next_message_id.get(); - self.next_message_id.set(MessageId(id.0.wrapping_add(1))); - id - } -} - -// Pending unary response waiter. -pub struct PendingEntry { - pub recipient: XID, - pub tx: oneshot::Sender>, -} - -// Pending streamed response opener waiter. -pub struct PendingStreamEntry { - pub recipient: XID, - pub tx: oneshot::Sender>, -} - -// Currently executing platform write. -pub struct InFlightWrite<'a> { - pub peer: XID, - pub token: Token, - pub message_id: Option, - pub future: PlatformFuture<'a, Result<(), QlError>>, -} - -// Queued payload representation. -pub enum OutboundPayload { - // Payload already encoded into bytes. - PreEncoded(Vec), - // Payload to encrypt at send time. - DeferredMessage(MessageBody), -} - -// Outbound queue item with timeout token. -pub struct OutboundMessage { - pub peer: XID, - pub token: Token, - pub message_id: Option, - pub payload: OutboundPayload, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -// Runtime timeout categories. -pub enum TimeoutKind { - // Outbound queue item expired. - Outbound { - token: Token, - }, - // Handshake stage expired. - Handshake { - peer: XID, - token: Token, - }, - // Request waiting for reply expired. - Request { - id: MessageId, - }, - // Send keepalive ping now. - KeepAliveSend { - peer: XID, - token: Token, - }, - // Keepalive pong timeout. - KeepAliveTimeout { - peer: XID, - token: Token, - }, - // Transfer data/open/finish ack timeout. - TransferAck { - peer: XID, - transfer_id: MessageId, - next_seq: u32, - attempt: u8, - }, - // Transfer cancel-ack timeout. - TransferCancelAck { - peer: XID, - transfer_id: MessageId, - attempt: u8, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -// One scheduled timeout entry. -pub struct TimeoutEntry { - pub at: Instant, - pub kind: TimeoutKind, -} - -impl Ord for TimeoutEntry { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.at.cmp(&other.at) - } -} - -impl PartialOrd for TimeoutEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -// Outcome of one runtime loop poll cycle. -pub enum LoopStep { - // Received a runtime command. - Event(RuntimeCommand), - // One or more timeouts fired. - Timeout, - // In-flight write completed. - WriteDone { - peer: XID, - token: Token, - message_id: Option, - result: Result<(), QlError>, - }, - // Runtime should exit loop. - Quit, -} - -// Decision for inbound hello handling. -pub enum HelloAction { - // Become responder for this hello. - StartResponder, - // Re-send existing hello reply. - ResendReply { - reply: HelloReply, - deadline: Instant, - }, - // Ignore this hello. - Ignore, -} - -pub fn next_timeout_deadline(state: &RuntimeState) -> Option { - state.timeouts.peek().map(|entry| entry.0.at) -} - -pub fn peer_hello_wins( - local_hello: &Hello, - local_sender: XID, - peer_hello: &Hello, - peer_sender: XID, -) -> bool { - use std::cmp::Ordering; - - match peer_hello.nonce.data().cmp(local_hello.nonce.data()) { - Ordering::Less => true, - Ordering::Greater => false, - Ordering::Equal => peer_sender.data().cmp(local_sender.data()) == Ordering::Less, - } -} - -pub fn now_secs() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) -} diff --git a/ql/src/runtime/mod.rs b/ql/src/runtime/mod.rs deleted file mode 100644 index 490a73b6..00000000 --- a/ql/src/runtime/mod.rs +++ /dev/null @@ -1,180 +0,0 @@ -pub use handle::{ - InboundByteStream, InboundStream, OutboundTransfer, Response, RuntimeHandle, StreamResponse, - UploadRequest, -}; -pub use internal::{InitiatorStage, PeerSession, Token}; - -mod core; -pub mod handle; -pub(crate) mod internal; -pub mod replay_cache; - -use std::time::Duration; - -use bc_components::XID; -use dcbor::CBOR; - -use crate::{ - wire::message::{DecryptedMessage, MessageKind, Nack}, - MessageId, QlCodec, QlError, RouteId, -}; - -#[derive(Debug, Clone, Default)] -pub struct RequestConfig { - pub timeout: Option, -} - -#[derive(Debug, Clone, Copy)] -pub struct KeepAliveConfig { - pub interval: Duration, - pub timeout: Duration, -} - -#[derive(Debug, Clone, Copy)] -pub struct RuntimeConfig { - pub handshake_timeout: Duration, - pub default_request_timeout: Duration, - pub message_expiration: Duration, - pub keep_alive: Option, -} - -impl RuntimeConfig { - pub fn new(handshake_timeout: Duration) -> Self { - Self { - handshake_timeout, - default_request_timeout: Duration::from_secs(5), - message_expiration: Duration::from_secs(30), - keep_alive: None, - } - } - - pub fn with_request_timeout(mut self, timeout: Duration) -> Self { - self.default_request_timeout = timeout; - self - } - - pub fn with_message_expiration(mut self, expiration: Duration) -> Self { - self.message_expiration = expiration; - self - } - - pub fn with_keep_alive(mut self, config: KeepAliveConfig) -> Self { - self.keep_alive = Some(config); - self - } -} - -#[derive(Debug)] -pub enum HandlerEvent { - Request(InboundRequest), - UploadRequest(InboundUploadRequest), - Event(InboundEvent), -} - -#[derive(Debug)] -pub struct InboundRequest { - pub message: DecryptedMessage, - pub respond_to: Responder, -} - -#[derive(Debug)] -pub struct InboundUploadRequest { - pub sender: XID, - pub recipient: XID, - pub route_id: RouteId, - pub message_id: MessageId, - pub meta: CBOR, - pub body: InboundByteStream, - pub respond_to: Responder, -} - -#[derive(Debug)] -pub struct InboundEvent { - pub message: DecryptedMessage, -} - -#[derive(Debug, Clone)] -pub struct Responder { - id: MessageId, - recipient: XID, - tx: async_channel::Sender, -} - -impl Responder { - pub(crate) fn new( - id: MessageId, - recipient: XID, - tx: async_channel::Sender, - ) -> Self { - Self { id, recipient, tx } - } - - pub fn respond(self, response: R) -> Result<(), QlError> - where - R: QlCodec, - { - self.tx - .try_send(internal::RuntimeCommand::SendResponse { - id: self.id, - recipient: self.recipient, - payload: response.into(), - kind: MessageKind::Response, - }) - .map_err(|_| QlError::Cancelled) - } - - pub fn respond_nack(self, reason: Nack) -> Result<(), QlError> { - self.tx - .try_send(internal::RuntimeCommand::SendResponse { - id: self.id, - recipient: self.recipient, - payload: CBOR::from(reason), - kind: MessageKind::Nack, - }) - .map_err(|_| QlError::Cancelled) - } - - pub fn respond_stream(self, meta: M) -> Result - where - M: QlCodec, - { - let (chunk_tx, chunk_rx) = async_channel::bounded(1); - self.tx - .send_blocking(internal::RuntimeCommand::StartResponseStream { - request_id: self.id, - recipient: self.recipient, - meta: meta.into(), - chunk_rx, - }) - .map_err(|_| QlError::Cancelled)?; - Ok(handle::OutboundTransfer::new( - self.recipient, - self.id, - chunk_tx, - self.tx, - )) - } -} - -pub struct Runtime

{ - platform: P, - config: RuntimeConfig, - rx: async_channel::Receiver, - tx: async_channel::WeakSender, -} - -pub fn new_runtime

(platform: P, config: RuntimeConfig) -> (Runtime

, RuntimeHandle) -where - P: crate::platform::QlPlatform, -{ - let (tx, rx) = async_channel::unbounded(); - ( - Runtime { - platform, - config, - rx, - tx: tx.downgrade(), - }, - RuntimeHandle { tx }, - ) -} diff --git a/ql/src/runtime/replay_cache.rs b/ql/src/runtime/replay_cache.rs deleted file mode 100644 index 80e60d21..00000000 --- a/ql/src/runtime/replay_cache.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::{ - cmp::Reverse, - collections::{binary_heap::PeekMut, BinaryHeap, HashSet}, -}; - -use bc_components::XID; - -use crate::{runtime::internal::now_secs, MessageId}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ReplayNamespace { - Peer, - Local, - Transfer, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ReplayKey { - pub peer: XID, - pub namespace: ReplayNamespace, - pub message_id: MessageId, -} - -impl ReplayKey { - pub const fn new(peer: XID, namespace: ReplayNamespace, message_id: MessageId) -> Self { - Self { - peer, - namespace, - message_id, - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct ExpiryEntry { - expires_at: u64, - key: ReplayKey, -} - -impl Ord for ExpiryEntry { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.expires_at - .cmp(&other.expires_at) - .then_with(|| self.key.cmp(&other.key)) - } -} - -impl PartialOrd for ExpiryEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -#[derive(Debug, Default)] -pub struct ReplayCache { - entries: HashSet, - expirations: BinaryHeap>, -} - -impl ReplayCache { - pub fn new() -> Self { - Self { - entries: HashSet::new(), - expirations: BinaryHeap::new(), - } - } - - pub fn len(&self) -> usize { - self.entries.len() - } - - pub fn is_empty(&self) -> bool { - self.entries.is_empty() - } - - pub fn add(&mut self, key: ReplayKey, expires_at: u64) { - if self.entries.insert(key) { - self.expirations - .push(Reverse(ExpiryEntry { expires_at, key })); - } - } - - pub fn check_and_store(&mut self, key: ReplayKey, expires_at: u64) -> bool { - let now_secs = now_secs(); - self.check_and_store_at(key, expires_at, now_secs) - } - - pub fn check_and_store_valid_until(&mut self, key: ReplayKey, valid_until: u64) -> bool { - let now_secs = now_secs(); - self.check_and_store_at(key, valid_until, now_secs) - } - - pub fn purge_expired(&mut self) { - let now_secs = now_secs(); - self.purge_expired_at(now_secs); - } - - pub fn clear_peer(&mut self, peer: XID) { - self.entries.retain(|entry| entry.peer != peer); - self.expirations.retain(|entry| entry.0.key.peer != peer); - } - - fn check_and_store_at(&mut self, key: ReplayKey, expires_at: u64, now_secs: u64) -> bool { - self.purge_expired_at(now_secs); - if self.entries.contains(&key) { - return true; - } - self.entries.insert(key); - self.expirations - .push(Reverse(ExpiryEntry { expires_at, key })); - false - } - - fn purge_expired_at(&mut self, now_secs: u64) { - while let Some(entry) = self.expirations.peek_mut() { - if entry.0.expires_at > now_secs { - break; - } - let entry = PeekMut::pop(entry).0; - self.entries.remove(&entry.key); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn peer_with_byte(byte: u8) -> XID { - XID::from_data([byte; XID::XID_SIZE]) - } - - #[test] - fn check_and_store_detects_replay() { - let mut cache = ReplayCache::new(); - let peer = peer_with_byte(1); - let key = ReplayKey::new(peer, ReplayNamespace::Peer, MessageId(1)); - let now_secs = 100; - let expires_at = 110; - - assert!(!cache.check_and_store_at(key, expires_at, now_secs)); - assert!(cache.check_and_store_at(key, expires_at, now_secs)); - } - - #[test] - fn purge_expired_removes_old_entries() { - let mut cache = ReplayCache::new(); - let now_secs = 100; - let expired_at = 99; - let future_at = 110; - - let key_old = ReplayKey::new(peer_with_byte(2), ReplayNamespace::Peer, MessageId(2)); - let key_new = ReplayKey::new(peer_with_byte(3), ReplayNamespace::Peer, MessageId(3)); - - cache.add(key_old, expired_at); - cache.add(key_new, future_at); - - cache.purge_expired_at(now_secs); - assert_eq!(cache.len(), 1); - assert!(!cache.check_and_store_at(key_old, future_at, now_secs)); - } - - #[test] - fn clear_peer_removes_peer_entries() { - let mut cache = ReplayCache::new(); - let now_secs = 100; - let expires_at = 110; - - let peer_a = peer_with_byte(4); - let peer_b = peer_with_byte(5); - let key_a = ReplayKey::new(peer_a, ReplayNamespace::Peer, MessageId(4)); - let key_b = ReplayKey::new(peer_b, ReplayNamespace::Peer, MessageId(5)); - - cache.add(key_a, expires_at); - cache.add(key_b, expires_at); - - cache.clear_peer(peer_a); - assert_eq!(cache.len(), 1); - assert!(!cache.check_and_store_at(key_a, expires_at, now_secs)); - } -} diff --git a/ql/src/tests/handshake.rs b/ql/src/tests/handshake.rs deleted file mode 100644 index 34580edb..00000000 --- a/ql/src/tests/handshake.rs +++ /dev/null @@ -1,292 +0,0 @@ -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn handshake_initiator_connects() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn handshake_timeout_disconnects() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(50)); - let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - - let peer_b = platform_b.xid(); - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer( - peer_b, - platform_b.signing_public_key().clone(), - platform_b.encapsulation_public_key().clone(), - ); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Disconnected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn simultaneous_handshakes_resolve() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - handle_b.connect(peer_a.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; - await_status(&status_b, peer_a.xid, PeerStage::Responder).await; - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn invalid_signature_disconnects() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, _status_b) = TestPlatform::new(2); - let (wrong_private, wrong_public) = MLDSA::MLDSA44.keypair(); - let _ = wrong_private; - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b.xid, wrong_public, peer_b.encapsulation_key.clone()); - handle_b.register_peer( - peer_a.xid, - peer_a.signing_key.clone(), - peer_a.encapsulation_key.clone(), - ); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn pairing_request_triggers_handshake() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let pairing_message = pair::build_pair_request( - &platform_a, - peer_b.xid, - &peer_b.encapsulation_key, - MessageId(1), - Duration::from_secs(1), - ) - .unwrap(); - let pairing_bytes = CBOR::from(pairing_message).to_cbor_data(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer( - peer_b.xid, - peer_b.signing_key.clone(), - peer_b.encapsulation_key.clone(), - ); - - handle_b.send_incoming(pairing_bytes); - - await_status(&status_b, peer_a.xid, PeerStage::Initiator).await; - await_status(&status_a, peer_b.xid, PeerStage::Responder).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn blocked_write_still_times_out() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(40)); - let (platform_a, _outbound_a, status_a, _write_gate) = BlockingPlatform::new(2); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(1); - - let signing_b = platform_b.signing_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Initiator).await; - await_status(&status_a, peer_b, PeerStage::Disconnected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn handshake_timeout_drops_queued_messages() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(60)); - let (platform_a, outbound_a, status_a, write_gate) = BlockingPlatform::new(2); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(1); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer( - peer_b.xid, - peer_b.signing_key.clone(), - peer_b.encapsulation_key.clone(), - ); - - handle_a.connect(peer_b.xid).unwrap(); - await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; - - let (hello, _secret) = wire::handshake::build_hello( - &platform_b, - peer_b.xid, - peer_a.xid, - &peer_a.encapsulation_key, - ) - .unwrap(); - let message = QlRecord { - header: QlHeader { - sender: peer_b.xid, - recipient: peer_a.xid, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), - }; - let bytes = CBOR::from(message).to_cbor_data(); - handle_a.send_incoming(bytes); - - await_status(&status_a, peer_b.xid, PeerStage::Responder).await; - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - - write_gate.add_permits(1); - let _ = tokio::time::timeout(Duration::from_millis(100), outbound_a.recv()) - .await - .unwrap() - .unwrap(); - - write_gate.add_permits(1); - let second = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; - assert!( - second.is_err(), - "expected queued handshake reply to be dropped" - ); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn multi_peer_simultaneous_handshakes() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let (platform_c, outbound_c, status_c) = TestPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let peer_c = peer_identity(&platform_c); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - let (runtime_c, handle_c) = - new_runtime(platform_c, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - tokio::task::spawn_local(async move { runtime_c.run().await }); - - spawn_routed_forwarder( - outbound_a, - vec![ - (peer_b.xid, handle_b.clone()), - (peer_c.xid, handle_c.clone()), - ], - ); - spawn_routed_forwarder(outbound_b, vec![(peer_a.xid, handle_a.clone())]); - spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); - - let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); - - handle_a.connect(peer_b.xid).unwrap(); - handle_a.connect(peer_c.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_a, peer_c.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_c, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} diff --git a/ql/src/tests/heartbeat.rs b/ql/src/tests/heartbeat.rs deleted file mode 100644 index cc73b271..00000000 --- a/ql/src/tests/heartbeat.rs +++ /dev/null @@ -1,641 +0,0 @@ -use bc_components::SymmetricKey; - -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_ignored_without_session() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, _status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer( - peer_b, - platform_b.signing_public_key().clone(), - platform_b.encapsulation_public_key().clone(), - ); - - let heartbeat = wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: peer_b, - recipient: peer_a, - }, - &SymmetricKey::new(), - HeartbeatBody { - message_id: MessageId(1), - valid_until: now_secs().saturating_add(60), - }, - ); - let bytes = CBOR::from(heartbeat).to_cbor_data(); - handle_a.send_incoming(bytes); - - let result = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; - assert!(result.is_err(), "expected heartbeat to be ignored"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn keepalive_disabled_no_heartbeat() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - let result = tokio::time::timeout(Duration::from_millis(120), heartbeat_rx.recv()).await; - assert!(result.is_err(), "unexpected heartbeat while disabled"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_sent_after_idle() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(30), - timeout: Duration::from_millis(80), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_reply_when_connected() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(30), - timeout: Duration::from_millis(80), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); - let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); - spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(200), heartbeat_ab_rx.recv()) - .await - .unwrap() - .unwrap(); - tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn any_message_clears_pending() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(120), - timeout: Duration::from_millis(40), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) - .await - .unwrap() - .unwrap(); - - handle_b.send_event_raw(peer_a, RouteId(99), CBOR::from(1u8)); - - let window = keep_alive.timeout + Duration::from_millis(20); - let disconnect = tokio::time::timeout(window, async { - loop { - if let Ok(event) = status_a.recv().await { - if event.peer == peer_b && event.stage == PeerStage::Disconnected { - return; - } - } - } - }) - .await; - assert!(disconnect.is_err(), "unexpected disconnect"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_timeout_disconnects_and_drops_outbound() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(80), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(2); - let (platform_b, outbound_b, status_b) = TestPlatform::new(1); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let drop_flag = Arc::new(AtomicBool::new(false)); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - drop_flag.store(true, Ordering::Relaxed); - - let response = handle_a.send_request_raw( - peer_b, - RouteId(9), - CBOR::from(9u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ); - - await_status(&status_a, peer_b, PeerStage::Disconnected).await; - - let result = tokio::time::timeout(Duration::from_millis(300), response.recv()) - .await - .unwrap(); - assert!( - matches!(result, Err(QlError::SendFailed)), - "unexpected result: {result:?}" - ); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn no_ping_pong() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(200), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); - let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); - spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(300), heartbeat_ab_rx.recv()) - .await - .unwrap() - .unwrap(); - tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) - .await - .unwrap() - .unwrap(); - - let followup = - tokio::time::timeout(Duration::from_millis(50), heartbeat_ab_rx.recv()).await; - assert!(followup.is_err(), "unexpected heartbeat ping-pong"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn invalid_heartbeat_ignored() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - - let signing_a = platform_a.signing_public_key().clone(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_a = platform_a.encapsulation_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - handle_b.register_peer(peer_a, signing_a.clone(), encap_a.clone()); - - handle_a.connect(peer_b).unwrap(); - - await_status(&status_a, peer_b, PeerStage::Connected).await; - await_status(&status_b, peer_a, PeerStage::Connected).await; - - let heartbeat = wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: peer_b, - recipient: peer_a, - }, - &SymmetricKey::new(), - HeartbeatBody { - message_id: MessageId(42), - valid_until: now_secs().saturating_add(30), - }, - ); - let bytes = CBOR::from(heartbeat).to_cbor_data(); - handle_a.send_incoming(bytes); - - let result = tokio::time::timeout(Duration::from_millis(50), heartbeat_rx.recv()).await; - assert!(result.is_err(), "unexpected heartbeat reply"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn multi_peer_keepalive_disconnect_isolated() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(40), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let config_c = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let (platform_c, outbound_c, status_c) = TestPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let peer_c = peer_identity(&platform_c); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - let (runtime_c, handle_c) = new_runtime(platform_c, config_c); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - tokio::task::spawn_local(async move { runtime_c.run().await }); - - let drop_b_to_a = Arc::new(AtomicBool::new(false)); - spawn_routed_forwarder( - outbound_a, - vec![ - (peer_b.xid, handle_b.clone()), - (peer_c.xid, handle_c.clone()), - ], - ); - spawn_routed_forwarder_with_filter(outbound_b, vec![(peer_a.xid, handle_a.clone())], { - let drop_b_to_a = drop_b_to_a.clone(); - move |record| { - !(drop_b_to_a.load(Ordering::Relaxed) && record.header.recipient == peer_a.xid) - } - }); - spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); - - let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); - - handle_a.connect(peer_b.xid).unwrap(); - handle_a.connect(peer_c.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_a, peer_c.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_c, peer_a.xid, PeerStage::Connected).await; - - drop_b_to_a.store(true, Ordering::Relaxed); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - - let disconnect = - tokio::time::timeout(keep_alive.timeout + Duration::from_millis(80), async { - loop { - if let Ok(event) = status_a.recv().await { - if event.peer == peer_c.xid && event.stage == PeerStage::Disconnected { - return; - } - } - } - }) - .await; - assert!(disconnect.is_err(), "unexpected disconnect for peer C"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn multi_peer_disconnect_drops_outbound_for_one() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(40), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let config_c = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let (platform_c, outbound_c, status_c, inbound_c) = InboundPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let peer_c = peer_identity(&platform_c); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - let (runtime_c, handle_c) = new_runtime(platform_c, config_c); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - tokio::task::spawn_local(async move { runtime_c.run().await }); - - let drop_b_to_a = Arc::new(AtomicBool::new(false)); - spawn_routed_forwarder( - outbound_a, - vec![ - (peer_b.xid, handle_b.clone()), - (peer_c.xid, handle_c.clone()), - ], - ); - spawn_routed_forwarder_with_filter(outbound_b, vec![(peer_a.xid, handle_a.clone())], { - let drop_b_to_a = drop_b_to_a.clone(); - move |record| { - !(drop_b_to_a.load(Ordering::Relaxed) && record.header.recipient == peer_a.xid) - } - }); - spawn_routed_forwarder(outbound_c, vec![(peer_a.xid, handle_a.clone())]); - - let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); - - handle_a.connect(peer_b.xid).unwrap(); - handle_a.connect(peer_c.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_a, peer_c.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_c, peer_a.xid, PeerStage::Connected).await; - - let inbound_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_c.recv().await { - let _ = request.respond_to.respond(55u8); - } - }); - - drop_b_to_a.store(true, Ordering::Relaxed); - - let request_b = handle_a.send_request_raw( - peer_b.xid, - RouteId(10), - CBOR::from(10u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ); - let request_c = handle_a.send_request_raw( - peer_c.xid, - RouteId(11), - CBOR::from(11u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ); - - let response_c = tokio::time::timeout(Duration::from_millis(200), request_c.recv()) - .await - .expect("response wait") - .expect("response channel"); - let value: u8 = response_c.try_into().unwrap(); - assert_eq!(value, 55u8); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - - let result_b = tokio::time::timeout(Duration::from_millis(200), request_b.recv()) - .await - .expect("response wait"); - assert!(matches!(result_b, Err(QlError::SendFailed))); - - let _ = inbound_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn multi_peer_activity_is_per_peer() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(100), - timeout: Duration::from_millis(40), - }; - let config_a = RuntimeConfig::new(Duration::from_millis(200)).with_keep_alive(keep_alive); - let config_b = RuntimeConfig::new(Duration::from_millis(200)); - let config_c = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let (platform_c, outbound_c, status_c) = TestPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let peer_c = peer_identity(&platform_c); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - let (runtime_c, handle_c) = new_runtime(platform_c, config_c); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - tokio::task::spawn_local(async move { runtime_c.run().await }); - - let drop_all_c = Arc::new(AtomicBool::new(false)); - spawn_routed_forwarder( - outbound_a, - vec![ - (peer_b.xid, handle_b.clone()), - (peer_c.xid, handle_c.clone()), - ], - ); - spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); - spawn_gated_forwarder(outbound_c, handle_a.clone(), drop_all_c.clone()); - - let _ = register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - let _ = register_peers(&handle_a, &handle_c, &peer_a, &peer_c); - - handle_a.connect(peer_b.xid).unwrap(); - handle_a.connect(peer_c.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_a, peer_c.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - await_status(&status_c, peer_a.xid, PeerStage::Connected).await; - - drop_all_c.store(true, Ordering::Relaxed); - - tokio::time::sleep(keep_alive.interval + Duration::from_millis(5)).await; - - handle_b.send_event_raw(peer_a.xid, RouteId(99), CBOR::from(1u8)); - - await_status(&status_a, peer_c.xid, PeerStage::Disconnected).await; - - let disconnect = - tokio::time::timeout(keep_alive.timeout + Duration::from_millis(30), async { - loop { - if let Ok(event) = status_a.recv().await { - if event.peer == peer_b.xid && event.stage == PeerStage::Disconnected { - return; - } - } - } - }) - .await; - assert!(disconnect.is_err(), "unexpected disconnect for peer B"); - }) - .await; -} diff --git a/ql/src/tests/mod.rs b/ql/src/tests/mod.rs deleted file mode 100644 index eae65d8b..00000000 --- a/ql/src/tests/mod.rs +++ /dev/null @@ -1,660 +0,0 @@ -use std::{ - future::Future, - sync::{ - atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, - Arc, - }, - time::Duration, -}; - -use async_channel::{Receiver, Sender}; -use bc_components::{ - MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, MLDSA, MLKEM, XID, -}; -use dcbor::CBOR; -use tokio::{sync::Semaphore, task::LocalSet}; - -use crate::{ - platform::{PlatformFuture, QlPlatform, QlPlatformExt}, - runtime::{ - internal::now_secs, new_runtime, HandlerEvent, KeepAliveConfig, PeerSession, RequestConfig, - RuntimeConfig, RuntimeHandle, - }, - wire::{ - self, - handshake::HandshakeRecord, - heartbeat::HeartbeatBody, - message::{encrypt_message, MessageBody, MessageKind, Nack}, - pair, QlHeader, QlPayload, QlRecord, - }, - MessageId, QlError, RouteId, -}; - -mod handshake; -mod heartbeat; -mod persistence; -mod requests; -mod streams; -mod unpair; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum PeerStage { - Disconnected, - Initiator, - Responder, - Connected, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct StatusEvent { - peer: XID, - stage: PeerStage, -} - -struct TestPlatform { - signing_private: MLDSAPrivateKey, - signing_public: MLDSAPublicKey, - encapsulation_private: MLKEMPrivateKey, - encapsulation_public: MLKEMPublicKey, - outbound: Sender>, - status: Sender, - nonce_seed: u8, - nonce_counter: AtomicU8, -} - -impl TestPlatform { - fn new(seed: u8) -> (Self, Receiver>, Receiver) { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - }, - outbound_rx, - status_rx, - ) - } - - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } -} - -impl QlPlatform for TestPlatform { - fn signing_private_key(&self) -> &MLDSAPrivateKey { - &self.signing_private - } - - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } - - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } - - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - Box::pin(async move { - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn load_peers(&self) -> PlatformFuture<'_, Vec> { - Box::pin(async { Vec::new() }) - } - - fn persist_peers(&self, _peers: Vec) {} - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, _event: crate::runtime::HandlerEvent) {} -} - -struct BlockingPlatform { - signing_private: MLDSAPrivateKey, - signing_public: MLDSAPublicKey, - encapsulation_private: MLKEMPrivateKey, - encapsulation_public: MLKEMPublicKey, - outbound: Sender>, - status: Sender, - nonce_seed: u8, - nonce_counter: AtomicU8, - write_gate: Arc, -} - -struct InboundPlatform { - signing_private: MLDSAPrivateKey, - signing_public: MLDSAPublicKey, - encapsulation_private: MLKEMPrivateKey, - encapsulation_public: MLKEMPublicKey, - outbound: Sender>, - status: Sender, - inbound: Sender, - nonce_seed: u8, - nonce_counter: AtomicU8, -} - -impl InboundPlatform { - fn new( - seed: u8, - ) -> ( - Self, - Receiver>, - Receiver, - Receiver, - ) { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - let (inbound, inbound_rx) = async_channel::unbounded(); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - inbound, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - }, - outbound_rx, - status_rx, - inbound_rx, - ) - } -} - -impl QlPlatform for InboundPlatform { - fn signing_private_key(&self) -> &MLDSAPrivateKey { - &self.signing_private - } - - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } - - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } - - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - Box::pin(async move { - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn load_peers(&self) -> PlatformFuture<'_, Vec> { - Box::pin(async { Vec::new() }) - } - - fn persist_peers(&self, _peers: Vec) {} - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, event: HandlerEvent) { - let _ = self.inbound.try_send(event); - } -} - -impl BlockingPlatform { - fn new( - seed: u8, - ) -> ( - Self, - Receiver>, - Receiver, - Arc, - ) { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - let write_gate = Arc::new(Semaphore::new(0)); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - write_gate: write_gate.clone(), - }, - outbound_rx, - status_rx, - write_gate, - ) - } -} - -impl QlPlatform for BlockingPlatform { - fn signing_private_key(&self) -> &MLDSAPrivateKey { - &self.signing_private - } - - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } - - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } - - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - let write_gate = self.write_gate.clone(); - Box::pin(async move { - let _permit = write_gate.acquire().await.unwrap(); - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn load_peers(&self) -> PlatformFuture<'_, Vec> { - Box::pin(async { Vec::new() }) - } - - fn persist_peers(&self, _peers: Vec) {} - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, _event: crate::runtime::HandlerEvent) {} -} - -async fn run_local_test(future: F) -where - F: Future, -{ - let local = LocalSet::new(); - local.run_until(future).await; -} - -fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - handle.send_incoming(bytes); - } - }); -} - -fn is_heartbeat(bytes: &[u8]) -> bool { - let Ok(record) = CBOR::try_from_data(bytes).and_then(QlRecord::try_from) else { - return false; - }; - matches!(record.payload, QlPayload::Heartbeat(_)) -} - -fn is_transfer(bytes: &[u8]) -> bool { - let Ok(record) = CBOR::try_from_data(bytes).and_then(QlRecord::try_from) else { - return false; - }; - matches!(record.payload, QlPayload::Transfer(_)) -} - -fn spawn_heartbeat_tap_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - heartbeat_tx: Sender<()>, -) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if is_heartbeat(&bytes) { - let _ = heartbeat_tx.send(()).await; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if is_heartbeat(&bytes) { - continue; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_drop_first_transfer_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - let mut dropped = false; - while let Ok(bytes) = outbound.recv().await { - if !dropped && is_transfer(&bytes) { - dropped = true; - continue; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_duplicate_first_transfer_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - let mut duplicated = false; - while let Ok(bytes) = outbound.recv().await { - if !duplicated && is_transfer(&bytes) { - duplicated = true; - handle.send_incoming(bytes.clone()); - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_gated_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - drop_flag: Arc, -) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if drop_flag.load(Ordering::Relaxed) { - continue; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_routed_forwarder(outbound: Receiver>, routes: Vec<(XID, RuntimeHandle)>) { - spawn_routed_forwarder_with_filter(outbound, routes, |_| true); -} - -fn spawn_routed_forwarder_with_filter( - outbound: Receiver>, - routes: Vec<(XID, RuntimeHandle)>, - filter: F, -) where - F: Fn(&QlRecord) -> bool + Send + Sync + 'static, -{ - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) else { - continue; - }; - if !filter(&record) { - continue; - } - if let Some((_, handle)) = routes - .iter() - .find(|(peer, _)| *peer == record.header.recipient) - { - handle.send_incoming(bytes); - } - } - }); -} - -#[derive(Clone)] -struct PeerIdentity { - xid: XID, - signing_key: MLDSAPublicKey, - encapsulation_key: MLKEMPublicKey, -} - -fn peer_identity(platform: &impl QlPlatformExt) -> PeerIdentity { - PeerIdentity { - xid: platform.xid(), - signing_key: platform.signing_public_key().clone(), - encapsulation_key: platform.encapsulation_public_key().clone(), - } -} - -fn register_peers( - handle_a: &RuntimeHandle, - handle_b: &RuntimeHandle, - identity_a: &PeerIdentity, - identity_b: &PeerIdentity, -) -> (XID, XID) { - let peer_a = identity_a.xid; - let peer_b = identity_b.xid; - handle_a.register_peer( - peer_b, - identity_b.signing_key.clone(), - identity_b.encapsulation_key.clone(), - ); - handle_b.register_peer( - peer_a, - identity_a.signing_key.clone(), - identity_a.encapsulation_key.clone(), - ); - (peer_a, peer_b) -} - -async fn await_status( - receiver: &Receiver, - peer: XID, - stage: PeerStage, -) -> StatusEvent { - tokio::time::timeout(Duration::from_secs(1), async { - loop { - if let Ok(event) = receiver.recv().await { - if event.peer == peer && event.stage == stage { - return event; - } - } - } - }) - .await - .unwrap() -} - -#[test] -fn protocol_record_size_breakdown() { - let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - - let initiator = platform_a.xid(); - let responder = platform_b.xid(); - - let (hello, initiator_secret) = wire::handshake::build_hello( - &platform_a, - initiator, - responder, - platform_b.encapsulation_public_key(), - ) - .unwrap(); - let hello_record = QlRecord { - header: QlHeader { - sender: initiator, - recipient: responder, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(hello.clone())), - }; - let hello_size = CBOR::from(hello_record).to_cbor_data().len(); - - let (hello_reply, responder_secrets) = wire::handshake::respond_hello( - &platform_b, - initiator, - responder, - platform_a.encapsulation_public_key(), - &hello, - ) - .unwrap(); - let reply_record = QlRecord { - header: QlHeader { - sender: responder, - recipient: initiator, - }, - payload: QlPayload::Handshake(HandshakeRecord::HelloReply(hello_reply.clone())), - }; - let reply_size = CBOR::from(reply_record).to_cbor_data().len(); - - let (confirm, session_key) = wire::handshake::build_confirm( - &platform_a, - initiator, - responder, - platform_b.signing_public_key(), - &hello, - &hello_reply, - &initiator_secret, - ) - .unwrap(); - let _session_key_b = wire::handshake::finalize_confirm( - initiator, - responder, - platform_a.signing_public_key(), - &hello, - &hello_reply, - &confirm, - &responder_secrets, - ) - .unwrap(); - let confirm_record = QlRecord { - header: QlHeader { - sender: initiator, - recipient: responder, - }, - payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm)), - }; - let confirm_size = CBOR::from(confirm_record).to_cbor_data().len(); - - let pair_record = pair::build_pair_request( - &platform_a, - responder, - platform_b.encapsulation_public_key(), - MessageId(1), - Duration::from_secs(60), - ) - .unwrap(); - let pair_size = CBOR::from(pair_record).to_cbor_data().len(); - - let message_record = encrypt_message( - QlHeader { - sender: initiator, - recipient: responder, - }, - &session_key, - MessageBody { - message_id: MessageId(2), - valid_until: now_secs().saturating_add(60), - kind: MessageKind::Event, - route_id: RouteId(1), - payload: CBOR::null(), - }, - ); - let message_size = CBOR::from(message_record).to_cbor_data().len(); - - let heartbeat_record = wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: initiator, - recipient: responder, - }, - &session_key, - HeartbeatBody { - message_id: MessageId(3), - valid_until: now_secs().saturating_add(60), - }, - ); - let heartbeat_size = CBOR::from(heartbeat_record).to_cbor_data().len(); - - let print_size = |label: &str, size: usize| { - println!("{label:<21}: {size} bytes"); - }; - - print_size("ql size hello", hello_size); - print_size("ql size hello_reply", reply_size); - print_size("ql size confirm", confirm_size); - print_size("ql size pair_request", pair_size); - print_size("ql size message", message_size); - print_size("ql size heartbeat", heartbeat_size); -} diff --git a/ql/src/tests/persistence.rs b/ql/src/tests/persistence.rs deleted file mode 100644 index 98a4411c..00000000 --- a/ql/src/tests/persistence.rs +++ /dev/null @@ -1,228 +0,0 @@ -use std::sync::atomic::{AtomicU8, Ordering}; - -use async_channel::{Receiver, Sender}; -use bc_components::{ - MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, MLDSA, MLKEM, XID, -}; - -use super::*; - -type PersistPlatformParts = ( - PersistPlatform, - Receiver>, - Receiver, - Receiver>, -); - -struct PersistPlatform { - signing_private: MLDSAPrivateKey, - signing_public: MLDSAPublicKey, - encapsulation_private: MLKEMPrivateKey, - encapsulation_public: MLKEMPublicKey, - outbound: Sender>, - status: Sender, - persisted: Sender>, - loaded_peers: Vec, - nonce_seed: u8, - nonce_counter: AtomicU8, -} - -impl PersistPlatform { - fn new(seed: u8, loaded_peers: Vec) -> PersistPlatformParts { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - let (persisted, persisted_rx) = async_channel::unbounded(); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - persisted, - loaded_peers, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - }, - outbound_rx, - status_rx, - persisted_rx, - ) - } -} - -impl QlPlatform for PersistPlatform { - fn signing_private_key(&self) -> &MLDSAPrivateKey { - &self.signing_private - } - - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } - - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } - - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - Box::pin(async move { - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn load_peers(&self) -> PlatformFuture<'_, Vec> { - let peers = self.loaded_peers.clone(); - Box::pin(async move { peers }) - } - - fn persist_peers(&self, peers: Vec) { - let _ = self.persisted.try_send(peers); - } - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, _event: HandlerEvent) {} -} - -#[tokio::test(flavor = "current_thread")] -async fn register_peer_persists_snapshot() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, _outbound_a, _status_a, persisted_a) = PersistPlatform::new(1, Vec::new()); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - let peer_b = platform_b.xid(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.register_peer(peer_b, signing_b.clone(), encap_b.clone()); - - let snapshot = tokio::time::timeout(Duration::from_secs(1), persisted_a.recv()) - .await - .unwrap() - .unwrap(); - assert_eq!( - snapshot, - vec![crate::Peer { - peer: peer_b, - signing_key: signing_b, - encapsulation_key: encap_b, - }] - ); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn loaded_peers_can_connect_without_register() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_b = peer_identity(&platform_b); - - let (platform_a, outbound_a, status_a, _persisted_a) = PersistPlatform::new( - 1, - vec![crate::Peer { - peer: peer_b.xid, - signing_key: peer_b.signing_key.clone(), - encapsulation_key: peer_b.encapsulation_key.clone(), - }], - ); - let peer_a = peer_identity(&platform_a); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_b.register_peer( - peer_a.xid, - peer_a.signing_key.clone(), - peer_a.encapsulation_key.clone(), - ); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn pairing_persists_snapshot() { - run_local_test(async { - let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); - let peer_a = peer_identity(&platform_a); - - let (platform_b, _outbound_b, _status_b, persisted_b) = PersistPlatform::new(2, Vec::new()); - let peer_b = peer_identity(&platform_b); - - let pairing_message = pair::build_pair_request( - &platform_a, - peer_b.xid, - &peer_b.encapsulation_key, - MessageId(1), - Duration::from_secs(1), - ) - .unwrap(); - let pairing_bytes = CBOR::from(pairing_message).to_cbor_data(); - - let (runtime_b, handle_b) = - new_runtime(platform_b, RuntimeConfig::new(Duration::from_millis(200))); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - handle_b.send_incoming(pairing_bytes); - - let snapshot = tokio::time::timeout(Duration::from_secs(1), persisted_b.recv()) - .await - .unwrap() - .unwrap(); - assert_eq!( - snapshot, - vec![crate::Peer { - peer: peer_a.xid, - signing_key: peer_a.signing_key, - encapsulation_key: peer_a.encapsulation_key, - }] - ); - }) - .await; -} diff --git a/ql/src/tests/requests.rs b/ql/src/tests/requests.rs deleted file mode 100644 index 23bab759..00000000 --- a/ql/src/tests/requests.rs +++ /dev/null @@ -1,446 +0,0 @@ -use super::*; - -fn spawn_delayed_message_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - delay: Duration, -) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - let is_message = CBOR::try_from_data(&bytes) - .and_then(QlRecord::try_from) - .is_ok_and(|record| matches!(record.payload, QlPayload::Message(_))); - if is_message { - tokio::time::sleep(delay).await; - } - handle.send_incoming(bytes); - } - }); -} - -#[tokio::test(flavor = "current_thread")] -async fn request_response_round_trip() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let inbound_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let _ = request.respond_to.respond(99u8); - } - }); - - let response = handle_a.send_request_raw( - peer_b.xid, - RouteId(7), - CBOR::from(12u8), - RequestConfig::default(), - ); - - let response = response.recv().await.unwrap(); - let value: u8 = response.try_into().unwrap(); - assert_eq!(value, 99u8); - let _ = inbound_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_timeout_returns_error() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(30)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let ticket = handle_a.send_request_raw( - peer_b.xid, - RouteId(1), - CBOR::from(1u8), - RequestConfig { - timeout: Some(Duration::from_millis(30)), - }, - ); - - let result = tokio::time::timeout(Duration::from_millis(200), ticket.recv()) - .await - .unwrap(); - assert!(matches!(result, Err(QlError::Timeout))); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_nack_resolves_pending() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let inbound_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let _ = request.respond_to.respond_nack(Nack::InvalidPayload); - } - }); - - let response = handle_a.send_request_raw( - peer_b.xid, - RouteId(2), - CBOR::from(2u8), - RequestConfig::default(), - ); - - let result = response.recv().await; - assert!(matches!( - result, - Err(QlError::Nack { - nack: Nack::InvalidPayload, - .. - }) - )); - let _ = inbound_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_dispatches_to_platform_callback() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let inbound_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let _ = request.respond_to.respond(7u8); - } - }); - - let ticket = handle_a.send_request_raw( - peer_b.xid, - RouteId(3), - CBOR::from(1u8), - RequestConfig::default(), - ); - - let response = ticket.recv().await.unwrap(); - let value: u8 = response.try_into().unwrap(); - assert_eq!(value, 7u8); - let _ = inbound_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn replayed_message_is_ignored() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - tokio::task::spawn_local({ - let handle_b = handle_b.clone(); - async move { - while let Ok(bytes) = outbound_a.recv().await { - let Ok(record) = CBOR::try_from_data(&bytes).and_then(QlRecord::try_from) - else { - handle_b.send_incoming(bytes); - continue; - }; - if matches!(record.payload, QlPayload::Message(_)) { - handle_b.send_incoming(bytes.clone()); - handle_b.send_incoming(bytes); - continue; - } - handle_b.send_incoming(bytes); - } - } - }); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - handle_a.send_event_raw(peer_b.xid, RouteId(9), CBOR::from(1u8)); - - let first = tokio::time::timeout(Duration::from_secs(1), inbound_b.recv()) - .await - .unwrap() - .unwrap(); - match first { - HandlerEvent::Event(event) => { - assert_eq!(event.message.route_id, RouteId(9)); - } - HandlerEvent::Request(_) => panic!("unexpected request"), - HandlerEvent::UploadRequest(_) => panic!("unexpected upload request"), - } - - let second = tokio::time::timeout(Duration::from_millis(50), inbound_b.recv()).await; - assert!(second.is_err(), "replay delivered a second event"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn expired_request_returns_expired_nack() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_message_expiration(Duration::from_secs(1)) - .with_request_timeout(Duration::from_secs(3)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_delayed_message_forwarder(outbound_a, handle_b.clone(), Duration::from_millis(2000)); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let ticket = handle_a.send_request_raw( - peer_b.xid, - RouteId(4), - CBOR::from(1u8), - RequestConfig::default(), - ); - - let result = tokio::time::timeout(Duration::from_secs(5), ticket.recv()) - .await - .unwrap(); - assert!(matches!( - result, - Err(QlError::Nack { - nack: Nack::Expired, - .. - }) - )); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn expired_event_does_not_send_nack() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_message_expiration(Duration::from_secs(1)) - .with_request_timeout(Duration::from_secs(3)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_delayed_message_forwarder(outbound_a, handle_b.clone(), Duration::from_millis(1500)); - - let (message_tx, message_rx) = async_channel::unbounded(); - tokio::task::spawn_local({ - let handle_a = handle_a.clone(); - async move { - while let Ok(bytes) = outbound_b.recv().await { - let is_message = CBOR::try_from_data(&bytes) - .and_then(QlRecord::try_from) - .is_ok_and(|record| matches!(record.payload, QlPayload::Message(_))); - if is_message { - let _ = message_tx.send(()).await; - } - handle_a.send_incoming(bytes); - } - } - }); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - handle_a.send_event_raw(peer_b.xid, RouteId(10), CBOR::from(2u8)); - - let unexpected = tokio::time::timeout(Duration::from_secs(3), message_rx.recv()).await; - assert!( - unexpected.is_err(), - "expired event should not generate nack" - ); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn session_reset_fails_queued_request() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(60)); - let (platform_a, outbound_a, status_a, write_gate) = BlockingPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let (reset_hello, _secret) = wire::handshake::build_hello( - &platform_b, - peer_b.xid, - peer_a.xid, - &peer_a.encapsulation_key, - ) - .unwrap(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - write_gate.add_permits(2); - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let blocked = handle_a.send_request_raw( - peer_b.xid, - RouteId(12), - CBOR::from(12u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ); - let queued = handle_a.send_request_raw( - peer_b.xid, - RouteId(13), - CBOR::from(13u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ); - - let hello_message = QlRecord { - header: QlHeader { - sender: peer_b.xid, - recipient: peer_a.xid, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(reset_hello)), - }; - handle_a.send_incoming(CBOR::from(hello_message).to_cbor_data()); - - await_status(&status_a, peer_b.xid, PeerStage::Responder).await; - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - - let queued_result = tokio::time::timeout(Duration::from_millis(300), queued.recv()) - .await - .unwrap(); - assert!(matches!(queued_result, Err(QlError::Timeout))); - - let blocked_result = tokio::time::timeout(Duration::from_millis(300), blocked.recv()) - .await - .unwrap(); - assert!(matches!(blocked_result, Err(QlError::Timeout))); - }) - .await; -} diff --git a/ql/src/tests/streams.rs b/ql/src/tests/streams.rs deleted file mode 100644 index 0b04ec8a..00000000 --- a/ql/src/tests/streams.rs +++ /dev/null @@ -1,552 +0,0 @@ -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn request_stream_round_trip() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let mut stream = request.respond_to.respond_stream(7u8).unwrap(); - stream.write_next(vec![1, 2, 3]).await.unwrap(); - stream.write_next(vec![4, 5]).await.unwrap(); - stream.finish().await.unwrap(); - } - }); - - let mut response = handle_a - .send_request_stream_raw( - peer_b.xid, - RouteId(201), - CBOR::from(1u8), - RequestConfig::default(), - ) - .recv() - .await - .unwrap(); - - assert_eq!(response.meta, CBOR::from(7u8)); - assert_eq!( - response.body.next_chunk().await.unwrap(), - Some(vec![1, 2, 3]) - ); - assert_eq!(response.body.next_chunk().await.unwrap(), Some(vec![4, 5])); - assert_eq!(response.body.next_chunk().await.unwrap(), None); - - let _ = responder_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn dropping_inbound_stream_cancels_sender() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let mut stream = request.respond_to.respond_stream(1u8).unwrap(); - stream.write_next(vec![9]).await.unwrap(); - stream.finish().await - } else { - Err(QlError::Cancelled) - } - }); - - let mut response = handle_a - .send_request_stream_raw( - peer_b.xid, - RouteId(202), - CBOR::from(2u8), - RequestConfig::default(), - ) - .recv() - .await - .unwrap(); - - assert_eq!(response.body.next_chunk().await.unwrap(), Some(vec![9])); - drop(response); - - let result = tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - assert!(result.is_ok()); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn sender_cancel_surfaces_error_on_receiver() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let mut stream = request.respond_to.respond_stream(1u8).unwrap(); - stream.write_next(vec![7]).await.unwrap(); - stream.cancel().await.unwrap(); - } - }); - - let mut response = handle_a - .send_request_stream_raw( - peer_b.xid, - RouteId(203), - CBOR::from(3u8), - RequestConfig::default(), - ) - .recv() - .await - .unwrap(); - - let first = response.body.next_chunk().await; - match first { - Ok(Some(_)) => { - let second = response.body.next_chunk().await; - assert!(matches!(second, Err(QlError::TransferCancelled { .. }))); - } - Err(QlError::TransferCancelled { .. }) => {} - other => panic!("unexpected first chunk result: {other:?}"), - } - - let _ = responder_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_upload_round_trip() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::UploadRequest(request)) = inbound_b.recv().await { - assert_eq!(request.route_id, RouteId(204)); - assert_eq!(request.meta, CBOR::from("meta")); - let mut body = request.body; - let mut bytes = Vec::new(); - while let Some(chunk) = body.next_chunk().await.unwrap() { - bytes.extend(chunk); - } - assert_eq!(bytes, vec![1, 2, 3, 4]); - request.respond_to.respond(4u8).unwrap(); - } - }); - - let mut upload = handle_a - .send_request_upload_raw( - peer_b.xid, - RouteId(204), - CBOR::from("meta"), - RequestConfig::default(), - ) - .await - .unwrap(); - upload.transfer.write_next(vec![1, 2]).await.unwrap(); - upload.transfer.write_next(vec![3, 4]).await.unwrap(); - upload.transfer.finish().await.unwrap(); - let response = upload.response.recv().await.unwrap(); - - assert_eq!(response, CBOR::from(4u8)); - - let _ = responder_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn duplicate_open_response_resends_ack_without_cancelling_stream() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(30)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_drop_first_transfer_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let mut stream = request.respond_to.respond_stream(7u8).unwrap(); - stream.write_next(vec![1, 2, 3]).await.unwrap(); - stream.write_next(vec![4, 5]).await.unwrap(); - stream.finish().await.unwrap(); - } - }); - - let mut response = tokio::time::timeout( - Duration::from_secs(1), - handle_a - .send_request_stream_raw( - peer_b.xid, - RouteId(205), - CBOR::from(1u8), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ) - .recv(), - ) - .await - .unwrap() - .unwrap(); - - assert_eq!(response.meta, CBOR::from(7u8)); - assert_eq!( - response.body.next_chunk().await.unwrap(), - Some(vec![1, 2, 3]) - ); - assert_eq!(response.body.next_chunk().await.unwrap(), Some(vec![4, 5])); - assert_eq!(response.body.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn duplicate_open_request_retries_without_redispatching_upload() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(30)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_drop_first_transfer_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::UploadRequest(request)) = inbound_b.recv().await { - assert_eq!(request.route_id, RouteId(206)); - assert_eq!(request.meta, CBOR::from("meta")); - let mut body = request.body; - let mut bytes = Vec::new(); - while let Some(chunk) = body.next_chunk().await.unwrap() { - bytes.extend(chunk); - } - assert_eq!(bytes, vec![1, 2, 3, 4]); - request.respond_to.respond(4u8).unwrap(); - } - - let second = tokio::time::timeout(Duration::from_millis(150), inbound_b.recv()).await; - assert!(second.is_err(), "duplicate upload request dispatched"); - }); - - let mut upload = tokio::time::timeout( - Duration::from_secs(1), - handle_a.send_request_upload_raw( - peer_b.xid, - RouteId(206), - CBOR::from("meta"), - RequestConfig { - timeout: Some(Duration::from_millis(200)), - }, - ), - ) - .await - .unwrap() - .unwrap(); - upload.transfer.write_next(vec![1, 2]).await.unwrap(); - upload.transfer.write_next(vec![3, 4]).await.unwrap(); - upload.transfer.finish().await.unwrap(); - let response = tokio::time::timeout(Duration::from_secs(1), upload.response.recv()) - .await - .unwrap() - .unwrap(); - - assert_eq!(response, CBOR::from(4u8)); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn replayed_transfer_open_request_is_silently_ignored() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let transfer_count = Arc::new(AtomicUsize::new(0)); - spawn_duplicate_first_transfer_forwarder(outbound_a, handle_b.clone()); - tokio::task::spawn_local({ - let handle_a = handle_a.clone(); - let transfer_count = transfer_count.clone(); - async move { - while let Ok(bytes) = outbound_b.recv().await { - if is_transfer(&bytes) { - transfer_count.fetch_add(1, Ordering::Relaxed); - } - handle_a.send_incoming(bytes); - } - } - }); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let upload = handle_a - .send_request_upload_raw( - peer_b.xid, - RouteId(207), - CBOR::from("meta"), - RequestConfig::default(), - ) - .await - .unwrap(); - - let request = match tokio::time::timeout(Duration::from_secs(1), inbound_b.recv()) - .await - .unwrap() - .unwrap() - { - HandlerEvent::UploadRequest(request) => request, - other => panic!("unexpected inbound event: {other:?}"), - }; - - assert_eq!(request.route_id, RouteId(207)); - assert_eq!(request.meta, CBOR::from("meta")); - - let second = tokio::time::timeout(Duration::from_millis(50), inbound_b.recv()).await; - assert!(second.is_err(), "replayed transfer redispatched upload"); - - tokio::time::timeout(Duration::from_secs(1), async { - while transfer_count.load(Ordering::Relaxed) == 0 { - tokio::task::yield_now().await; - } - }) - .await - .unwrap(); - tokio::time::sleep(Duration::from_millis(50)).await; - assert_eq!( - transfer_count.load(Ordering::Relaxed), - 1, - "replayed transfer produced extra ack" - ); - - drop(upload); - drop(request); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn replayed_transfer_open_response_is_silently_ignored() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let transfer_count = Arc::new(AtomicUsize::new(0)); - tokio::task::spawn_local({ - let handle_b = handle_b.clone(); - let transfer_count = transfer_count.clone(); - async move { - while let Ok(bytes) = outbound_a.recv().await { - if is_transfer(&bytes) { - transfer_count.fetch_add(1, Ordering::Relaxed); - } - handle_b.send_incoming(bytes); - } - } - }); - spawn_duplicate_first_transfer_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - if let Ok(HandlerEvent::Request(request)) = inbound_b.recv().await { - let stream = request.respond_to.respond_stream(7u8).unwrap(); - tokio::time::sleep(Duration::from_millis(250)).await; - drop(stream); - } - }); - - let response = tokio::time::timeout( - Duration::from_secs(1), - handle_a - .send_request_stream_raw( - peer_b.xid, - RouteId(208), - CBOR::from(1u8), - RequestConfig::default(), - ) - .recv(), - ) - .await - .unwrap() - .unwrap(); - - assert_eq!(response.meta, CBOR::from(7u8)); - - tokio::time::timeout(Duration::from_secs(1), async { - while transfer_count.load(Ordering::Relaxed) == 0 { - tokio::task::yield_now().await; - } - }) - .await - .unwrap(); - tokio::time::sleep(Duration::from_millis(50)).await; - assert_eq!( - transfer_count.load(Ordering::Relaxed), - 1, - "replayed transfer produced extra ack" - ); - - drop(response); - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} diff --git a/ql/src/tests/unpair.rs b/ql/src/tests/unpair.rs deleted file mode 100644 index 612f7cbb..00000000 --- a/ql/src/tests/unpair.rs +++ /dev/null @@ -1,160 +0,0 @@ -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn connected_unpair_removes_peer_on_both_sides() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - handle_a.connect(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - handle_a.unpair(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - - let result_a = handle_a - .send_request_raw( - peer_b.xid, - RouteId(90), - CBOR::from(1u8), - RequestConfig::default(), - ) - .recv() - .await; - assert!(matches!(result_a, Err(QlError::UnknownPeer(peer)) if peer == peer_b.xid)); - - let result_b = handle_b - .send_request_raw( - peer_a.xid, - RouteId(91), - CBOR::from(1u8), - RequestConfig::default(), - ) - .recv() - .await; - assert!(matches!(result_b, Err(QlError::UnknownPeer(peer)) if peer == peer_a.xid)); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn unpair_works_without_session() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)) - .with_request_timeout(Duration::from_millis(200)); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - - handle_a.unpair(peer_b.xid).unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - - let result_a = handle_a - .send_request_raw( - peer_b.xid, - RouteId(92), - CBOR::from(1u8), - RequestConfig::default(), - ) - .recv() - .await; - assert!(matches!(result_a, Err(QlError::UnknownPeer(peer)) if peer == peer_b.xid)); - - let result_b = handle_b - .send_request_raw( - peer_a.xid, - RouteId(93), - CBOR::from(1u8), - RequestConfig::default(), - ) - .recv() - .await; - assert!(matches!(result_b, Err(QlError::UnknownPeer(peer)) if peer == peer_a.xid)); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn invalid_unpair_signature_is_ignored() { - run_local_test(async { - let config = RuntimeConfig::new(Duration::from_millis(200)); - let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, status_b) = TestPlatform::new(2); - let (fake_signer, _fake_outbound, _fake_status) = TestPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let forged_unpair = wire::unpair::build_unpair_record( - &fake_signer, - QlHeader { - sender: peer_a.xid, - recipient: peer_b.xid, - }, - MessageId(777), - now_secs().saturating_add(60), - ); - let forged_bytes = CBOR::from(forged_unpair).to_cbor_data(); - - let (runtime_b, handle_b) = new_runtime(platform_b, config); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - handle_b.register_peer( - peer_a.xid, - peer_a.signing_key.clone(), - peer_a.encapsulation_key.clone(), - ); - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - - handle_b.send_incoming(forged_bytes); - - tokio::time::sleep(Duration::from_millis(20)).await; - - let result = handle_b - .send_request_raw( - peer_a.xid, - RouteId(94), - CBOR::from(1u8), - RequestConfig::default(), - ) - .recv() - .await; - assert!(matches!(result, Err(QlError::MissingSession(peer)) if peer == peer_a.xid)); - }) - .await; -} diff --git a/ql/src/wire/handshake/crypto.rs b/ql/src/wire/handshake/crypto.rs deleted file mode 100644 index e7ea43a5..00000000 --- a/ql/src/wire/handshake/crypto.rs +++ /dev/null @@ -1,131 +0,0 @@ -use bc_components::{ - Digest, MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, Nonce, SymmetricKey, XID, -}; -use dcbor::CBOR; - -use super::{verify_transcript_signature, Confirm, Hello, HelloReply}; -use crate::{platform::QlPlatform, QlError}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ResponderSecrets { - pub initiator_secret: SymmetricKey, - pub responder_secret: SymmetricKey, -} - -pub fn build_hello( - platform: &impl QlPlatform, - _sender: XID, - _recipient: XID, - recipient_encapsulation_key: &MLKEMPublicKey, -) -> Result<(Hello, SymmetricKey), QlError> { - let nonce = next_nonce(platform); - let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); - Ok((Hello { nonce, kem_ct }, session_key)) -} - -pub fn respond_hello( - platform: &impl QlPlatform, - initiator: XID, - responder: XID, - initiator_encapsulation_key: &MLKEMPublicKey, - hello: &Hello, -) -> Result<(HelloReply, ResponderSecrets), QlError> { - let initiator_secret = platform - .encapsulation_private_key() - .decapsulate_shared_secret(&hello.kem_ct) - .map_err(|_| QlError::InvalidPayload)?; - let nonce = next_nonce(platform); - let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(); - let transcript = handshake_transcript(initiator, responder, hello, &nonce, &kem_ct); - let signature = platform.signing_private_key().sign(&transcript); - let reply = HelloReply { - nonce, - kem_ct, - signature, - }; - Ok(( - reply, - ResponderSecrets { - initiator_secret, - responder_secret, - }, - )) -} - -pub fn build_confirm( - platform: &impl QlPlatform, - initiator: XID, - responder: XID, - responder_signing_key: &MLDSAPublicKey, - hello: &Hello, - reply: &HelloReply, - initiator_secret: &SymmetricKey, -) -> Result<(Confirm, SymmetricKey), QlError> { - let transcript = handshake_transcript(initiator, responder, hello, &reply.nonce, &reply.kem_ct); - verify_transcript_signature(responder_signing_key, &reply.signature, &transcript)?; - let responder_secret = platform - .encapsulation_private_key() - .decapsulate_shared_secret(&reply.kem_ct) - .map_err(|_| QlError::InvalidPayload)?; - let signature = platform.signing_private_key().sign(&transcript); - let confirm = Confirm { signature }; - let session_key = derive_session_key(initiator_secret, &responder_secret, &transcript); - Ok((confirm, session_key)) -} - -pub fn finalize_confirm( - initiator: XID, - responder: XID, - initiator_signing_key: &MLDSAPublicKey, - hello: &Hello, - reply: &HelloReply, - confirm: &Confirm, - secrets: &ResponderSecrets, -) -> Result { - let transcript = handshake_transcript(initiator, responder, hello, &reply.nonce, &reply.kem_ct); - verify_transcript_signature(initiator_signing_key, &confirm.signature, &transcript)?; - Ok(derive_session_key( - &secrets.initiator_secret, - &secrets.responder_secret, - &transcript, - )) -} - -fn handshake_transcript( - initiator: XID, - responder: XID, - hello: &Hello, - responder_nonce: &Nonce, - responder_kem_ct: &MLKEMCiphertext, -) -> Vec { - CBOR::from(vec![ - CBOR::from(initiator), - CBOR::from(responder), - CBOR::from(hello.nonce.clone()), - CBOR::from(responder_nonce.clone()), - CBOR::from(hello.kem_ct.clone()), - CBOR::from(responder_kem_ct.clone()), - ]) - .to_cbor_data() -} - -fn next_nonce(platform: &impl QlPlatform) -> Nonce { - let mut data = [0u8; Nonce::NONCE_SIZE]; - platform.fill_random_bytes(&mut data); - Nonce::from_data(data) -} - -fn derive_session_key( - initiator_secret: &SymmetricKey, - responder_secret: &SymmetricKey, - transcript: &[u8], -) -> SymmetricKey { - let payload = CBOR::from(vec![ - CBOR::from(initiator_secret.as_bytes()), - CBOR::from(responder_secret.as_bytes()), - CBOR::from(transcript), - ]) - .to_cbor_data(); - let digest = Digest::from_image(payload); - SymmetricKey::from_data(*digest.data()) -} diff --git a/ql/src/wire/handshake/mod.rs b/ql/src/wire/handshake/mod.rs deleted file mode 100644 index 9eebbea7..00000000 --- a/ql/src/wire/handshake/mod.rs +++ /dev/null @@ -1,122 +0,0 @@ -use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, Nonce}; -use dcbor::CBOR; - -use super::take_fields; -use crate::QlError; - -mod crypto; -pub use crypto::*; - -#[derive(Debug, Clone, PartialEq)] -pub enum HandshakeRecord { - Hello(Hello), - HelloReply(HelloReply), - Confirm(Confirm), -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Hello { - pub nonce: Nonce, - pub kem_ct: MLKEMCiphertext, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct HelloReply { - pub nonce: Nonce, - pub kem_ct: MLKEMCiphertext, - pub signature: MLDSASignature, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Confirm { - pub signature: MLDSASignature, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum HandshakeKind { - Hello = 1, - HelloReply, - Confirm, -} - -impl TryFrom for HandshakeKind { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let tag: u8 = value.try_into()?; - match tag { - 1 => Ok(Self::Hello), - 2 => Ok(Self::HelloReply), - 3 => Ok(Self::Confirm), - _ => Err(dcbor::Error::msg("unknown message tag")), - } - } -} - -pub fn verify_transcript_signature( - signing_key: &MLDSAPublicKey, - signature: &MLDSASignature, - transcript: &[u8], -) -> Result<(), QlError> { - match signing_key.verify(signature, transcript) { - Ok(true) => Ok(()), - _ => Err(QlError::InvalidSignature), - } -} - -impl From for CBOR { - fn from(value: HandshakeRecord) -> Self { - match value { - HandshakeRecord::Hello(message) => CBOR::from(vec![ - CBOR::from(HandshakeKind::Hello as u8), - CBOR::from(message.nonce), - CBOR::from(message.kem_ct), - ]), - HandshakeRecord::HelloReply(message) => CBOR::from(vec![ - CBOR::from(HandshakeKind::HelloReply as u8), - CBOR::from(message.nonce), - CBOR::from(message.kem_ct), - CBOR::from(message.signature), - ]), - HandshakeRecord::Confirm(message) => CBOR::from(vec![ - CBOR::from(HandshakeKind::Confirm as u8), - CBOR::from(message.signature), - ]), - } - } -} - -impl TryFrom for HandshakeRecord { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let mut iter = value.try_into_array()?.into_iter(); - let tag: HandshakeKind = iter - .next() - .ok_or_else(|| dcbor::Error::msg("missing handshake tag"))? - .try_into()?; - match tag { - HandshakeKind::Hello => { - let [nonce_cbor, kem_ct_cbor] = take_fields(iter)?; - Ok(HandshakeRecord::Hello(Hello { - nonce: nonce_cbor.try_into()?, - kem_ct: kem_ct_cbor.try_into()?, - })) - } - HandshakeKind::HelloReply => { - let [nonce_cbor, kem_ct_cbor, signature_cbor] = take_fields(iter)?; - Ok(HandshakeRecord::HelloReply(HelloReply { - nonce: nonce_cbor.try_into()?, - kem_ct: kem_ct_cbor.try_into()?, - signature: signature_cbor.try_into()?, - })) - } - HandshakeKind::Confirm => { - let [signature_cbor] = take_fields(iter)?; - Ok(HandshakeRecord::Confirm(Confirm { - signature: signature_cbor.try_into()?, - })) - } - } - } -} diff --git a/ql/src/wire/heartbeat/crypto.rs b/ql/src/wire/heartbeat/crypto.rs deleted file mode 100644 index 6f542d1b..00000000 --- a/ql/src/wire/heartbeat/crypto.rs +++ /dev/null @@ -1,40 +0,0 @@ -use bc_components::{Nonce, SymmetricKey}; -use dcbor::CBOR; - -use super::HeartbeatBody; -use crate::{ - wire::{ensure_not_expired, QlHeader, QlPayload, QlRecord}, - QlError, -}; - -pub fn encrypt_heartbeat( - header: QlHeader, - session_key: &SymmetricKey, - body: HeartbeatBody, -) -> QlRecord { - let aad = header.aad(); - let body_bytes = CBOR::from(body).to_cbor_data(); - let encrypted = session_key.encrypt(body_bytes, Some(aad), None::); - QlRecord { - header, - payload: QlPayload::Heartbeat(encrypted), - } -} - -pub fn decrypt_heartbeat( - header: &QlHeader, - encrypted: &bc_components::EncryptedMessage, - session_key: &SymmetricKey, -) -> Result { - let aad = header.aad(); - if encrypted.aad() != aad { - return Err(QlError::InvalidPayload); - } - let plaintext = session_key - .decrypt(encrypted) - .map_err(|_| QlError::InvalidPayload)?; - let cbor = CBOR::try_from_data(plaintext).map_err(|_| QlError::InvalidPayload)?; - let body = HeartbeatBody::try_from(cbor).map_err(|_| QlError::InvalidPayload)?; - ensure_not_expired(body.message_id, body.valid_until)?; - Ok(body) -} diff --git a/ql/src/wire/heartbeat/mod.rs b/ql/src/wire/heartbeat/mod.rs deleted file mode 100644 index bae5131a..00000000 --- a/ql/src/wire/heartbeat/mod.rs +++ /dev/null @@ -1,35 +0,0 @@ -use dcbor::CBOR; - -use super::take_fields; -use crate::MessageId; - -mod crypto; -pub use crypto::*; - -#[derive(Debug, Clone, PartialEq)] -pub struct HeartbeatBody { - pub message_id: MessageId, - pub valid_until: u64, -} - -impl From for CBOR { - fn from(value: HeartbeatBody) -> Self { - CBOR::from(vec![ - CBOR::from(value.message_id), - CBOR::from(value.valid_until), - ]) - } -} - -impl TryFrom for HeartbeatBody { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let iter = value.try_into_array()?.into_iter(); - let [message_id, valid_until] = take_fields(iter)?; - Ok(Self { - message_id: message_id.try_into()?, - valid_until: valid_until.try_into()?, - }) - } -} diff --git a/ql/src/wire/message/crypto.rs b/ql/src/wire/message/crypto.rs deleted file mode 100644 index 14613b8a..00000000 --- a/ql/src/wire/message/crypto.rs +++ /dev/null @@ -1,78 +0,0 @@ -use bc_components::{Nonce, SymmetricKey}; -use dcbor::CBOR; - -use super::{DecryptedMessage, MessageBody, MessageKind, Nack}; -use crate::{ - wire::{ensure_not_expired, QlHeader, QlPayload, QlRecord}, - MessageId, QlError, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum MessageError { - Nack { - id: MessageId, - nack: Nack, - kind: MessageKind, - }, - Error(QlError), -} - -impl From for MessageError { - fn from(value: QlError) -> Self { - Self::Error(value) - } -} - -pub fn encrypt_message( - header: QlHeader, - session_key: &SymmetricKey, - body: MessageBody, -) -> QlRecord { - let aad = CBOR::from(header.clone()).to_cbor_data(); - let body_bytes = CBOR::from(body).to_cbor_data(); - let encrypted = session_key.encrypt(body_bytes, Some(aad), None::); - QlRecord { - header, - payload: QlPayload::Message(encrypted), - } -} - -pub fn decrypt_message( - header: &QlHeader, - encrypted: &bc_components::EncryptedMessage, - session_key: &SymmetricKey, -) -> Result { - let aad = header.aad(); - if encrypted.aad() != aad { - return Err(QlError::InvalidPayload.into()); - } - let body = decrypt_body(session_key, encrypted)?; - ensure_not_expired(body.message_id, body.valid_until).map_err(|err| match err { - QlError::Nack { id, nack } => MessageError::Nack { - id, - nack, - kind: body.kind, - }, - other => MessageError::Error(other), - })?; - Ok(DecryptedMessage { - sender: header.sender, - recipient: header.recipient, - kind: body.kind, - message_id: body.message_id, - route_id: body.route_id, - valid_until: body.valid_until, - payload: body.payload, - }) -} - -fn decrypt_body( - session_key: &SymmetricKey, - encrypted: &bc_components::EncryptedMessage, -) -> Result { - let plaintext = session_key - .decrypt(encrypted) - .map_err(|_| QlError::InvalidPayload)?; - let cbor = CBOR::try_from_data(plaintext).map_err(|_| QlError::InvalidPayload)?; - MessageBody::try_from(cbor).map_err(|_| QlError::InvalidPayload) -} diff --git a/ql/src/wire/message/mod.rs b/ql/src/wire/message/mod.rs deleted file mode 100644 index ea25b601..00000000 --- a/ql/src/wire/message/mod.rs +++ /dev/null @@ -1,144 +0,0 @@ -use bc_components::XID; -use dcbor::CBOR; - -use super::take_fields; -use crate::{MessageId, RouteId}; - -mod crypto; -pub use crypto::*; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MessageKind { - Request, - Response, - Event, - Nack, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct Ack; - -#[derive(Debug, Clone, PartialEq)] -pub struct MessageBody { - pub message_id: MessageId, - pub valid_until: u64, - pub kind: MessageKind, - pub route_id: RouteId, - pub payload: CBOR, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct DecryptedMessage { - pub sender: XID, - pub recipient: XID, - pub kind: MessageKind, - pub message_id: MessageId, - pub route_id: RouteId, - pub valid_until: u64, - pub payload: CBOR, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Nack { - Unknown, - UnknownRoute, - InvalidPayload, - Expired, -} - -impl From for CBOR { - fn from(value: MessageKind) -> Self { - let kind = match value { - MessageKind::Request => 1, - MessageKind::Response => 2, - MessageKind::Event => 3, - MessageKind::Nack => 6, - }; - CBOR::from(kind) - } -} - -impl TryFrom for MessageKind { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let kind: u64 = value.try_into()?; - match kind { - 1 => Ok(Self::Request), - 2 => Ok(Self::Response), - 3 => Ok(Self::Event), - 6 => Ok(Self::Nack), - _ => Err(dcbor::Error::msg("unknown record kind")), - } - } -} - -impl From for CBOR { - fn from(_value: Ack) -> Self { - CBOR::null() - } -} - -impl TryFrom for Ack { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - if value.is_null() { - Ok(Self) - } else { - Err(dcbor::Error::msg("expected null")) - } - } -} - -impl From for CBOR { - fn from(value: MessageBody) -> Self { - CBOR::from(vec![ - CBOR::from(value.message_id), - CBOR::from(value.valid_until), - CBOR::from(value.kind), - CBOR::from(value.route_id), - value.payload, - ]) - } -} - -impl TryFrom for MessageBody { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let iter = value.try_into_array()?.into_iter(); - let [message_id, valid_until, kind, route_id, payload] = take_fields(iter)?; - Ok(Self { - message_id: message_id.try_into()?, - valid_until: valid_until.try_into()?, - kind: kind.try_into()?, - route_id: route_id.try_into()?, - payload, - }) - } -} - -impl From for CBOR { - fn from(value: Nack) -> Self { - let value = match value { - Nack::Unknown => 0, - Nack::UnknownRoute => 1, - Nack::InvalidPayload => 2, - Nack::Expired => 3, - }; - CBOR::from(value) - } -} - -impl From for Nack { - fn from(value: CBOR) -> Self { - let value: u8 = value.try_into().unwrap_or_default(); - match value { - 1 => Nack::UnknownRoute, - 2 => Nack::InvalidPayload, - 3 => Nack::Expired, - _ => Nack::Unknown, - } - } -} diff --git a/ql/src/wire/mod.rs b/ql/src/wire/mod.rs deleted file mode 100644 index 663a329c..00000000 --- a/ql/src/wire/mod.rs +++ /dev/null @@ -1,226 +0,0 @@ -use dcbor::CBOR; - -pub mod handshake; -pub mod heartbeat; -pub mod message; -pub mod pair; -pub mod transfer; -pub mod unpair; - -use bc_components::{EncryptedMessage, XID}; - -use self::{handshake::HandshakeRecord, pair::PairRequestRecord, unpair::UnpairRecord}; -use crate::{MessageId, QlError}; - -#[derive(Debug, Clone, PartialEq)] -pub struct QlRecord { - pub header: QlHeader, - pub payload: QlPayload, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct QlHeader { - pub sender: XID, - pub recipient: XID, -} - -impl QlHeader { - pub fn aad(&self) -> Vec { - CBOR::from(self.clone()).to_cbor_data() - } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum QlPayload { - Handshake(HandshakeRecord), - Pair(PairRequestRecord), - Unpair(UnpairRecord), - Message(EncryptedMessage), - Heartbeat(EncryptedMessage), - Transfer(EncryptedMessage), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum QlTag { - Handshake = 1, - Pairing = 2, - Record = 3, - Heartbeat = 4, - Unpair = 5, - Transfer = 6, -} - -impl From for CBOR { - fn from(value: QlTag) -> Self { - CBOR::from(value as u8) - } -} - -impl TryFrom for QlTag { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let tag: u8 = value.try_into()?; - match tag { - 1 => Ok(Self::Handshake), - 2 => Ok(Self::Pairing), - 3 => Ok(Self::Record), - 4 => Ok(Self::Heartbeat), - 5 => Ok(Self::Unpair), - 6 => Ok(Self::Transfer), - _ => Err(dcbor::Error::msg("unknown message tag")), - } - } -} - -impl From for CBOR { - fn from(value: QlRecord) -> Self { - let (tag, payload) = match value.payload { - QlPayload::Handshake(message) => (QlTag::Handshake, CBOR::from(message)), - QlPayload::Pair(message) => (QlTag::Pairing, CBOR::from(message)), - QlPayload::Message(message) => (QlTag::Record, CBOR::from(message)), - QlPayload::Heartbeat(message) => (QlTag::Heartbeat, CBOR::from(message)), - QlPayload::Unpair(message) => (QlTag::Unpair, CBOR::from(message)), - QlPayload::Transfer(message) => (QlTag::Transfer, CBOR::from(message)), - }; - CBOR::from(vec![ - CBOR::from(tag as u8), - CBOR::from(value.header), - payload, - ]) - } -} - -impl TryFrom for QlRecord { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let iter = value.try_into_array()?.into_iter(); - let [tag_cbor, header_cbor, payload] = take_fields(iter)?; - let tag = QlTag::try_from(tag_cbor)?; - let header = QlHeader::try_from(header_cbor)?; - match tag { - QlTag::Handshake => { - let message = HandshakeRecord::try_from(payload)?; - Ok(QlRecord { - header, - payload: QlPayload::Handshake(message), - }) - } - QlTag::Pairing => { - let message = PairRequestRecord::try_from(payload)?; - Ok(QlRecord { - header, - payload: QlPayload::Pair(message), - }) - } - QlTag::Record => { - let message = EncryptedMessage::try_from(payload)?; - Ok(QlRecord { - header, - payload: QlPayload::Message(message), - }) - } - QlTag::Heartbeat => { - let message = EncryptedMessage::try_from(payload)?; - Ok(QlRecord { - header, - payload: QlPayload::Heartbeat(message), - }) - } - QlTag::Unpair => { - let message = UnpairRecord::try_from(payload)?; - Ok(QlRecord { - header, - payload: QlPayload::Unpair(message), - }) - } - QlTag::Transfer => { - let message = EncryptedMessage::try_from(payload)?; - Ok(QlRecord { - header, - payload: QlPayload::Transfer(message), - }) - } - } - } -} - -impl From for CBOR { - fn from(value: QlHeader) -> Self { - CBOR::from(vec![CBOR::from(value.sender), CBOR::from(value.recipient)]) - } -} - -impl TryFrom for QlHeader { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let iter = value.try_into_array()?.into_iter(); - let [sender_cbor, recipient_cbor] = take_fields(iter)?; - Ok(Self { - sender: sender_cbor.try_into()?, - recipient: recipient_cbor.try_into()?, - }) - } -} - -pub(crate) fn take_fields( - mut iter: impl Iterator, -) -> Result<[CBOR; N], dcbor::Error> { - use std::mem::MaybeUninit; - - let mut fields: [MaybeUninit; N] = [const { MaybeUninit::uninit() }; N]; - for (index, slot) in fields.iter_mut().enumerate() { - let Some(value) = iter.next() else { - for init in &mut fields[..index] { - unsafe { init.assume_init_drop() }; - } - return Err(dcbor::Error::msg("array too short")); - }; - slot.write(value); - } - let result = unsafe { std::ptr::read(&fields as *const _ as *const [CBOR; N]) }; - if iter.next().is_some() { - return Err(dcbor::Error::msg("array too long")); - } - Ok(result) -} - -pub(crate) fn ensure_not_expired(id: MessageId, valid_until: u64) -> Result<(), QlError> { - let now = now_secs(); - if now > valid_until { - Err(QlError::Nack { - id, - nack: message::Nack::Expired, - }) - } else { - Ok(()) - } -} - -pub(crate) fn now_secs() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) -} - -#[test] -fn take_fields_reads_exact_count() { - let values = vec![CBOR::from(1u8), CBOR::from(2u8), CBOR::from(3u8)]; - let mut iter = values.into_iter(); - let [first, second, third] = take_fields(&mut iter).unwrap(); - assert_eq!(u8::try_from(first).unwrap(), 1); - assert_eq!(u8::try_from(second).unwrap(), 2); - assert_eq!(u8::try_from(third).unwrap(), 3); - assert!(iter.next().is_none()); -} - -#[test] -fn take_fields_rejects_short_arrays() { - let values = vec![CBOR::from(1u8)]; - let mut iter = values.into_iter(); - let result: Result<[CBOR; 2], _> = take_fields(&mut iter); - assert!(result.is_err()); -} diff --git a/ql/src/wire/pair/crypto.rs b/ql/src/wire/pair/crypto.rs deleted file mode 100644 index 8aca5b6d..00000000 --- a/ql/src/wire/pair/crypto.rs +++ /dev/null @@ -1,124 +0,0 @@ -use std::time::Duration; - -use bc_components::{ - MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, Nonce, SigningPublicKey, SymmetricKey, XID, -}; -use dcbor::CBOR; - -use super::{PairRequestBody, PairRequestRecord}; -use crate::{ - platform::{QlPlatform, QlPlatformExt}, - wire::{ensure_not_expired, now_secs, QlHeader, QlPayload, QlRecord}, - MessageId, QlError, -}; - -pub fn build_pair_request( - platform: &impl QlPlatform, - recipient: XID, - recipient_encapsulation_key: &MLKEMPublicKey, - message_id: MessageId, - valid_for: Duration, -) -> Result { - let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); - let header = QlHeader { - sender: platform.xid(), - recipient, - }; - let valid_until = now_secs().saturating_add(valid_for.as_secs()); - let signing_pub_key = platform.signing_public_key().clone(); - let sender_encapsulation_key = platform.encapsulation_public_key().clone(); - let proof_data = pairing_proof_data( - &header, - &kem_ct, - message_id, - valid_until, - &signing_pub_key, - &sender_encapsulation_key, - ); - let proof = platform.signing_private_key().sign(&proof_data); - let body = PairRequestBody { - message_id, - valid_until, - signing_pub_key, - encapsulation_pub_key: sender_encapsulation_key, - proof, - }; - let body_bytes = CBOR::from(body).to_cbor_data(); - let aad = pairing_aad(&header, &kem_ct); - let encrypted = session_key.encrypt(body_bytes, Some(aad), None::); - Ok(QlRecord { - header, - payload: QlPayload::Pair(PairRequestRecord { kem_ct, encrypted }), - }) -} - -pub fn decrypt_pair_request( - platform: &impl QlPlatform, - header: &QlHeader, - request: PairRequestRecord, -) -> Result { - let PairRequestRecord { kem_ct, encrypted } = request; - let session_key = platform - .encapsulation_private_key() - .decapsulate_shared_secret(&kem_ct) - .map_err(|_| QlError::InvalidPayload)?; - let aad = pairing_aad(header, &kem_ct); - if encrypted.aad() != aad { - return Err(QlError::InvalidPayload); - } - let decrypted = decrypt_body(&session_key, &encrypted)?; - ensure_not_expired(decrypted.message_id, decrypted.valid_until)?; - if XID::new(SigningPublicKey::MLDSA(decrypted.signing_pub_key.clone())) != header.sender { - return Err(QlError::InvalidPayload); - } - let proof_data = pairing_proof_data( - header, - &kem_ct, - decrypted.message_id, - decrypted.valid_until, - &decrypted.signing_pub_key, - &decrypted.encapsulation_pub_key, - ); - if decrypted - .signing_pub_key - .verify(&decrypted.proof, &proof_data) - .unwrap_or(false) - { - Ok(decrypted) - } else { - Err(QlError::InvalidSignature) - } -} - -fn pairing_proof_data( - header: &QlHeader, - kem_ct: &MLKEMCiphertext, - message_id: MessageId, - valid_until: u64, - signing_pub_key: &MLDSAPublicKey, - encapsulation_pub_key: &MLKEMPublicKey, -) -> Vec { - CBOR::from(vec![ - CBOR::from(pairing_aad(header, kem_ct)), - CBOR::from(message_id), - CBOR::from(valid_until), - CBOR::from(signing_pub_key.clone()), - CBOR::from(encapsulation_pub_key.clone()), - ]) - .to_cbor_data() -} - -fn decrypt_body( - key: &SymmetricKey, - encrypted: &bc_components::EncryptedMessage, -) -> Result { - let plaintext = key - .decrypt(encrypted) - .map_err(|_| QlError::InvalidPayload)?; - let cbor = CBOR::try_from_data(plaintext).map_err(|_| QlError::InvalidPayload)?; - PairRequestBody::try_from(cbor).map_err(|_| QlError::InvalidPayload) -} - -fn pairing_aad(header: &QlHeader, kem_ct: &MLKEMCiphertext) -> Vec { - CBOR::from(vec![CBOR::from(header.clone()), CBOR::from(kem_ct.clone())]).to_cbor_data() -} diff --git a/ql/src/wire/pair/mod.rs b/ql/src/wire/pair/mod.rs deleted file mode 100644 index b14045eb..00000000 --- a/ql/src/wire/pair/mod.rs +++ /dev/null @@ -1,71 +0,0 @@ -use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey}; -use dcbor::CBOR; - -use super::take_fields; -use crate::MessageId; - -mod crypto; -pub use crypto::*; - -#[derive(Debug, Clone, PartialEq)] -pub struct PairRequestRecord { - pub kem_ct: MLKEMCiphertext, - pub encrypted: bc_components::EncryptedMessage, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct PairRequestBody { - pub message_id: MessageId, - pub valid_until: u64, - pub signing_pub_key: MLDSAPublicKey, - pub encapsulation_pub_key: MLKEMPublicKey, - pub proof: MLDSASignature, -} - -impl From for CBOR { - fn from(value: PairRequestRecord) -> Self { - CBOR::from(vec![CBOR::from(value.kem_ct), CBOR::from(value.encrypted)]) - } -} - -impl TryFrom for PairRequestRecord { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let iter = value.try_into_array()?.into_iter(); - let [kem_ct_cbor, encrypted_cbor] = take_fields(iter)?; - Ok(Self { - kem_ct: kem_ct_cbor.try_into()?, - encrypted: encrypted_cbor.try_into()?, - }) - } -} - -impl From for CBOR { - fn from(value: PairRequestBody) -> Self { - CBOR::from(vec![ - CBOR::from(value.message_id), - CBOR::from(value.valid_until), - CBOR::from(value.signing_pub_key), - CBOR::from(value.encapsulation_pub_key), - CBOR::from(value.proof), - ]) - } -} - -impl TryFrom for PairRequestBody { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let iter = value.try_into_array()?.into_iter(); - let [message_id, valid_until, signing_pub_key, encapsulation_pub_key, proof] = - take_fields(iter)?; - Ok(Self { - message_id: message_id.try_into()?, - valid_until: valid_until.try_into()?, - signing_pub_key: signing_pub_key.try_into()?, - encapsulation_pub_key: encapsulation_pub_key.try_into()?, - proof: proof.try_into()?, - }) - } -} diff --git a/ql/src/wire/transfer/crypto.rs b/ql/src/wire/transfer/crypto.rs deleted file mode 100644 index dec752d2..00000000 --- a/ql/src/wire/transfer/crypto.rs +++ /dev/null @@ -1,42 +0,0 @@ -use bc_components::{Nonce, SymmetricKey}; -use dcbor::CBOR; - -use super::TransferBody; -use crate::{ - wire::{now_secs, QlHeader, QlPayload, QlRecord}, - QlError, -}; - -pub fn encrypt_transfer( - header: QlHeader, - session_key: &SymmetricKey, - body: TransferBody, -) -> QlRecord { - let aad = header.aad(); - let body_bytes = CBOR::from(body).to_cbor_data(); - let encrypted = session_key.encrypt(body_bytes, Some(aad), None::); - QlRecord { - header, - payload: QlPayload::Transfer(encrypted), - } -} - -pub fn decrypt_transfer( - header: &QlHeader, - encrypted: &bc_components::EncryptedMessage, - session_key: &SymmetricKey, -) -> Result { - let aad = header.aad(); - if encrypted.aad() != aad { - return Err(QlError::InvalidPayload); - } - let plaintext = session_key - .decrypt(encrypted) - .map_err(|_| QlError::InvalidPayload)?; - let cbor = CBOR::try_from_data(plaintext).map_err(|_| QlError::InvalidPayload)?; - let body = TransferBody::try_from(cbor).map_err(|_| QlError::InvalidPayload)?; - if now_secs() > body.valid_until { - return Err(QlError::InvalidPayload); - } - Ok(body) -} diff --git a/ql/src/wire/transfer/mod.rs b/ql/src/wire/transfer/mod.rs deleted file mode 100644 index f6b874b0..00000000 --- a/ql/src/wire/transfer/mod.rs +++ /dev/null @@ -1,194 +0,0 @@ -use dcbor::CBOR; - -use super::take_fields; -use crate::{MessageId, RouteId}; - -mod crypto; -pub use crypto::*; - -#[derive(Debug, Clone, PartialEq)] -pub struct TransferBody { - pub message_id: MessageId, - pub valid_until: u64, - pub transfer_id: MessageId, - pub frame: TransferFrame, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum TransferFrame { - OpenResponse { - request_id: MessageId, - meta: CBOR, - }, - OpenRequest { - request_id: MessageId, - route_id: RouteId, - meta: CBOR, - }, - Chunk { - seq: u32, - data: Vec, - }, - Finish { - seq: u32, - }, - Ack { - next_seq: u32, - }, - Cancel, - CancelAck, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TransferKind { - OpenResponse = 1, - OpenRequest, - Chunk, - Finish, - Ack, - Cancel, - CancelAck, -} - -impl From for CBOR { - fn from(value: TransferBody) -> Self { - CBOR::from(vec![ - CBOR::from(value.message_id), - CBOR::from(value.valid_until), - CBOR::from(value.transfer_id), - CBOR::from(value.frame), - ]) - } -} - -impl TryFrom for TransferBody { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let iter = value.try_into_array()?.into_iter(); - let [message_id, valid_until, transfer_id, frame] = take_fields(iter)?; - Ok(Self { - message_id: message_id.try_into()?, - valid_until: valid_until.try_into()?, - transfer_id: transfer_id.try_into()?, - frame: frame.try_into()?, - }) - } -} - -impl From for CBOR { - fn from(value: TransferFrame) -> Self { - match value { - TransferFrame::OpenResponse { request_id, meta } => CBOR::from(vec![ - CBOR::from(TransferKind::OpenResponse as u8), - CBOR::from(request_id), - meta, - ]), - TransferFrame::OpenRequest { - request_id, - route_id, - meta, - } => CBOR::from(vec![ - CBOR::from(TransferKind::OpenRequest as u8), - CBOR::from(request_id), - CBOR::from(route_id), - meta, - ]), - TransferFrame::Chunk { seq, data } => CBOR::from(vec![ - CBOR::from(TransferKind::Chunk as u8), - CBOR::from(seq), - CBOR::from(data), - ]), - TransferFrame::Finish { seq } => CBOR::from(vec![ - CBOR::from(TransferKind::Finish as u8), - CBOR::from(seq), - ]), - TransferFrame::Ack { next_seq } => CBOR::from(vec![ - CBOR::from(TransferKind::Ack as u8), - CBOR::from(next_seq), - ]), - TransferFrame::Cancel => CBOR::from(vec![CBOR::from(TransferKind::Cancel as u8)]), - TransferFrame::CancelAck => CBOR::from(vec![CBOR::from(TransferKind::CancelAck as u8)]), - } - } -} - -impl TryFrom for TransferFrame { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let mut iter = value.try_into_array()?.into_iter(); - let tag: TransferKind = iter - .next() - .ok_or_else(|| dcbor::Error::msg("missing transfer frame tag"))? - .try_into()?; - match tag { - TransferKind::OpenResponse => { - let [request_id, meta] = take_fields(iter)?; - Ok(Self::OpenResponse { - request_id: request_id.try_into()?, - meta, - }) - } - TransferKind::OpenRequest => { - let [request_id, route_id, meta] = take_fields(iter)?; - Ok(Self::OpenRequest { - request_id: request_id.try_into()?, - route_id: route_id.try_into()?, - meta, - }) - } - TransferKind::Chunk => { - let [seq, data] = take_fields(iter)?; - Ok(Self::Chunk { - seq: seq.try_into()?, - data: data.try_into()?, - }) - } - TransferKind::Finish => { - let [seq] = take_fields(iter)?; - Ok(Self::Finish { - seq: seq.try_into()?, - }) - } - TransferKind::Ack => { - let [next_seq] = take_fields(iter)?; - Ok(Self::Ack { - next_seq: next_seq.try_into()?, - }) - } - TransferKind::Cancel => { - if iter.next().is_some() { - Err(dcbor::Error::msg("array too long")) - } else { - Ok(Self::Cancel) - } - } - TransferKind::CancelAck => { - if iter.next().is_some() { - Err(dcbor::Error::msg("array too long")) - } else { - Ok(Self::CancelAck) - } - } - } - } -} - -impl TryFrom for TransferKind { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let tag: u8 = value.try_into()?; - match tag { - 1 => Ok(Self::OpenResponse), - 2 => Ok(Self::OpenRequest), - 3 => Ok(Self::Chunk), - 4 => Ok(Self::Finish), - 5 => Ok(Self::Ack), - 6 => Ok(Self::Cancel), - 7 => Ok(Self::CancelAck), - _ => Err(dcbor::Error::msg("unknown transfer frame tag")), - } - } -} diff --git a/ql/src/wire/unpair/crypto.rs b/ql/src/wire/unpair/crypto.rs deleted file mode 100644 index ca319ff6..00000000 --- a/ql/src/wire/unpair/crypto.rs +++ /dev/null @@ -1,58 +0,0 @@ -use bc_components::MLDSAPublicKey; -use dcbor::CBOR; - -use super::UnpairRecord; -use crate::{ - platform::QlPlatform, - wire::{now_secs, QlHeader, QlPayload, QlRecord}, - MessageId, QlError, -}; - -pub fn build_unpair_record( - platform: &impl QlPlatform, - header: QlHeader, - message_id: MessageId, - valid_until: u64, -) -> QlRecord { - let signature = - platform - .signing_private_key() - .sign(&unpair_proof_data(&header, message_id, valid_until)); - QlRecord { - header, - payload: QlPayload::Unpair(UnpairRecord { - message_id, - valid_until, - signature, - }), - } -} - -pub fn verify_unpair_record( - header: &QlHeader, - record: &UnpairRecord, - signing_key: &MLDSAPublicKey, -) -> Result<(), QlError> { - if now_secs() > record.valid_until { - return Err(QlError::InvalidPayload); - } - let proof_data = unpair_proof_data(header, record.message_id, record.valid_until); - if signing_key - .verify(&record.signature, &proof_data) - .unwrap_or(false) - { - Ok(()) - } else { - Err(QlError::InvalidSignature) - } -} - -fn unpair_proof_data(header: &QlHeader, message_id: MessageId, valid_until: u64) -> Vec { - CBOR::from(vec![ - CBOR::from("ql-unpair-v1"), - CBOR::from(header.clone()), - CBOR::from(message_id), - CBOR::from(valid_until), - ]) - .to_cbor_data() -} diff --git a/ql/src/wire/unpair/mod.rs b/ql/src/wire/unpair/mod.rs deleted file mode 100644 index cc81bab4..00000000 --- a/ql/src/wire/unpair/mod.rs +++ /dev/null @@ -1,39 +0,0 @@ -use bc_components::MLDSASignature; -use dcbor::CBOR; - -use super::take_fields; -use crate::MessageId; - -mod crypto; -pub use crypto::*; - -#[derive(Debug, Clone, PartialEq)] -pub struct UnpairRecord { - pub message_id: MessageId, - pub valid_until: u64, - pub signature: MLDSASignature, -} - -impl From for CBOR { - fn from(value: UnpairRecord) -> Self { - CBOR::from(vec![ - CBOR::from(value.message_id), - CBOR::from(value.valid_until), - CBOR::from(value.signature), - ]) - } -} - -impl TryFrom for UnpairRecord { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let iter = value.try_into_array()?.into_iter(); - let [message_id, valid_until, signature] = take_fields(iter)?; - Ok(Self { - message_id: message_id.try_into()?, - valid_until: valid_until.try_into()?, - signature: signature.try_into()?, - }) - } -} diff --git a/ql2/README.md b/ql2/README.md deleted file mode 100644 index d39e4e90..00000000 --- a/ql2/README.md +++ /dev/null @@ -1,143 +0,0 @@ -# QL Protocol (v2) - -QL is a compact, session-oriented protocol for authenticated and encrypted messaging -between peers over arbitrary transports. It targets low-bandwidth and high-latency -links while preserving strong cryptography, explicit request/response semantics, and -a clean developer-facing API. - -This crate (`ql`) implements the protocol stack: wire format, crypto, runtime state -machine, and routing. For a deeper comparison with v1, see `ql-protocol-v2.md`. - -## features -- Fixed CBOR wire format: `QlRecord` = `[tag, header, payload]`. -- Mutual-auth handshake (`Hello`, `HelloReply`, `Confirm`) with signed transcript. -- Session keys derived from KEM secrets; payloads use AEAD (ChaCha20-Poly1305). -- Sessions are ephemeral and scoped to a handshake; no long-lived symmetric keys. -- First-contact pairing request with KEM-wrapped payloads and proof signature. -- Encrypted messages with explicit `Request`, `Response`, `Event`, and `Nack`. -- `MessageId`, `RouteId`, and `valid_until` for routing and freshness. -- Heartbeats for keepalive and disconnect detection. -- Runtime state machine for sessions, timeouts, outbound queues, and correlation. -- Router for typed dispatch and automatic response wiring. -- Transport abstraction via `QlPlatform` for BLE, TCP, or other links. - -## overview -QL provides a full session protocol rather than isolated message sealing. It covers: -- Mutual authentication and end-to-end encryption above the transport. -- First-contact pairing for provisioning keys and establishing trust. -- Typed routing with explicit request/response/event semantics. -- Runtime lifecycle management (handshake, keepalive, timeouts, errors). -- Portability across transports via a minimal platform abstraction. - -### security -- Mutual authentication via a signed handshake transcript. -- Session keys derived from KEM secrets; payloads are protected with AEAD - and header AAD. -- End-to-end protection above the transport layer; pairing supports first-contact - key exchange and proof of key possession. -- Message freshness enforced via `valid_until`; replay caching is not built-in, - so applications can optionally track `MessageId` if needed. - -### session vs per-message sealing -- v1 (gstp + envelope) signs every message and then encrypts it to the recipient. - each message uses fresh encapsulation, so keys and signatures are per-message. -- v2 (ql) signs the handshake transcript once, derives a session key, then uses - AEAD for each message with the header as AAD. -- encryption strength uses the same primitive (ChaCha20-Poly1305). post-quantum - security depends on key schemes (ML-KEM + ML-DSA with `pqcrypto` enabled). -- tradeoffs: v2 is faster and smaller; v1 has per-message signature and key - isolation. v2's AEAD provides in-session integrity but is not publicly - verifiable and has a larger blast radius if a session key leaks. - -### performance -- Public-key operations are paid once per session; steady-state traffic is - symmetric AEAD. -- Compact CBOR record framing keeps headers and serialization overhead small. -- Optional heartbeats provide liveness detection without heavy traffic. - -### developer experience -- Typed routes via `RequestResponse` and `Event` traits with explicit `RouteId`. -- Router handles decode, dispatch, and response wiring automatically. -- Runtime manages sessions, timeouts, outbound queues, and request correlation. -- `QlPlatform` abstracts the transport for portability and testability. - -## message sizes -Sizes below are CBOR record sizes from `protocol_record_size_breakdown` in -`ql/src/tests/mod.rs`. - -| Record | Size (bytes) | -| :-- | --: | -| Handshake Hello | 132 | -| Handshake HelloReply | 2563 | -| Handshake Confirm | 2510 | -| Pair request | 4065 | -| Message (empty payload) | 199 | -| Heartbeat | 196 | - -Handshake total is 5205 bytes (132 + 2563 + 2510). At 20 kBps transport -throughput, raw transmit time is about 0.26 s. - -## protocol overview - -### record framing -All traffic is encoded as a `QlRecord` with a small, fixed shape: -- `tag` selects the payload type (handshake, pair, record, heartbeat). -- `header` is unencrypted but authenticated data (AEAD AAD) used for routing - (sender and recipient XIDs). -- `payload` is a CBOR-encoded handshake/pair body or an encrypted message. - -### handshake -The handshake is a three-message exchange: -- `Hello`: initiator sends a nonce and KEM ciphertext. -- `HelloReply`: responder returns its nonce, KEM ciphertext, and a signature - over the transcript. -- `Confirm`: initiator signs the transcript to confirm mutual authentication. - -Both sides derive the session key from the KEM secrets and transcript digest. -After the handshake, all records use symmetric AEAD with the header as AAD. - -### pairing (first-contact) -Pairing is a standalone request that KEM-encrypts a payload containing: -- a `MessageId` and `valid_until` timestamp -- the sender's signing and encapsulation public keys -- a proof signature binding those keys - -This enables establishing trust without an existing session. - -### message records -Steady-state messages are sent as encrypted records with a typed body: -- `MessageKind`: `Request`, `Response`, `Event`, or `Nack` -- `MessageId`, `RouteId`, `valid_until`, and CBOR payload - -Nacks communicate standard failure reasons (unknown route, invalid payload, -expired) so peers can recover consistently. - -### heartbeats -Heartbeats are lightweight encrypted records used by the runtime to maintain -session liveness and detect disconnects. - -### routing and dispatch -`RouteId` maps to concrete request/response or event types. The router decodes -payloads, dispatches handlers, and ensures each request receives a response or -a `Nack`. - -### sequence diagram -```mermaid -sequenceDiagram - participant A as Initiator - participant B as Responder - A->>B: Hello (nonce, KEM ct) - B->>A: HelloReply (nonce, KEM ct, signature) - A->>B: Confirm (signature) - Note over A,B: Session key derived, AEAD enabled - A->>B: Encrypted Record (Request) - B->>A: Encrypted Record (Response) - A-->>B: Encrypted Heartbeat (optional) -``` - -## code map -- Wire format: `ql/src/wire/*` -- Cryptography: `ql/src/crypto/*` -- Runtime state machine: `ql/src/runtime/*` -- Routing and traits: `ql/src/router.rs`, `ql/src/lib.rs` -- Transport abstraction: `ql/src/platform.rs` diff --git a/ql2/ql-v2.presenterm.md b/ql2/ql-v2.presenterm.md deleted file mode 100644 index d4a0fff2..00000000 --- a/ql2/ql-v2.presenterm.md +++ /dev/null @@ -1,285 +0,0 @@ ---- -theme: - name: gruvbox-dark ---- - -# quantumlink protocol v2 -post-quantum, session-based message protocol - - - -# ql v1: constraints -- no message id / sequence id -- no protocol-level request/response pairing -- each platform had to interpret + correlate by hand -- no ack/nack -- no notion of 'liveness'/'connected' status -- ~6.6KB min sealed event - - sender xid document (pq pubkeys) - - per-message signature - - recipient encryption (+ continuations) -- more a utility crate than a protocol - - - -# v1 vs v2 - - - - -## v1 -- gstp sealed envelope per message -- per-message sign+encrypt (envelope) -- implicit req/resp in enum variants -- app-owned pairing, timeouts, keepalive, connected status - - - -## v2 -- compact record + typed payloads -- handshake signatures + per‑message aead under symmetric session key -- explicit kind + ids + nack -- runtime handles pairing, timeouts, keepalive, connected status, request/response matching - - - -# design shift: per-message -> session -- v1 sealed each message -- v2 signs once, then aead per message - -```text -v1: seal(msg) = sign(msg) + encrypt(recipient) -v2: session_key = handshake() -v2: aead(msg, aad=header) -``` - - -_aead = authenticated encryption with associated data_ - -_aad = additional authenticated data (visible, integrity-protected)_ - - - -# configurable host platform -- same runtime across keyos / mobile / desktop -- host supplies pq keys, io, timers, callbacks - -```rust -pub trait QlPlatform { - // pq identity - fn signing_private_key(&self) -> &MLDSAPrivateKey; - fn signing_public_key(&self) -> &MLDSAPublicKey; - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey; - fn encapsulation_public_key(&self) -> &MLKEMPublicKey; - - // transport + runtime hooks - fn fill_bytes(&self, data: &mut [u8]); - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; - - // event handlers - fn handle_peer_status(&self, peer: XID, session: &PeerSession); - fn handle_inbound(&self, event: HandlerEvent); -} -``` - - - -# multi-peer runtime -- runtime tracks sessions per peer -- concurrent handshakes + keepalive per peer - -```rust -handle.register_peer(peer, signing_key, encapsulation_key); -handle.connect(peer)?; -``` - - - -# protocol breakdown -```mermaid +render +width:90% -sequenceDiagram - participant A as initiator - participant B as responder - - Note over A,B: pairing (first contact) - A->>B: pair request (kem + signed payload) - - Note over A,B: handshake (mutual auth) - A->>B: hello (nonce + kem ct) - B->>A: hello reply (nonce + kem ct + signature) - A->>B: confirm (signature) - - Note over A,B: session established - A->>B: request / event (aead + aad header) - B->>A: response / nack (aead + aad header) - A-->>B: heartbeat (optional) -``` - - - -# wire framing: routable header -- record = [tag, header, payload] -- header is unencrypted but authenticated (aad) - -```rust -pub struct QlRecord { - pub header: QlHeader, - pub payload: QlPayload, -} - -pub struct QlHeader { - pub sender: XID, - pub recipient: XID, -} -``` - - - -# handshake flow + records -- hello: nonce + mlkem ciphertext -- reply: nonce + mlkem ciphertext + mldsa signature -- confirm: mldsa signature, then session key - -```rust -pub struct Hello { - pub nonce: Nonce, - pub kem_ct: MLKEMCiphertext, -} - -pub struct HelloReply { - pub nonce: Nonce, - pub kem_ct: MLKEMCiphertext, - pub signature: MLDSASignature, -} - -pub struct Confirm { - pub signature: MLDSASignature, -} -``` - - - -# session key derivation -- transcript binds ids + nonces + kem ciphertexts -- session key = digest(initiator_secret, responder_secret, transcript) - -```rust -let transcript = cbor([ - initiator, responder, - hello.nonce, reply.nonce, - hello.kem_ct, reply.kem_ct, -]); -let payload = cbor([initiator_secret, responder_secret, transcript]); -let digest = Digest::from_image(payload); -let session_key = SymmetricKey::from_data(*digest.data()); -``` - - - -# message modalities -- request / response -- event: fire-and-forget or acked -- nack for structured failure - -```rust -pub enum MessageKind { - Request, - Response, - Event, - Nack, -} -``` - - - -# message body: routing + expiry -- message_id + route_id -- valid_until for freshness - -```rust -pub struct MessageBody { - pub message_id: MessageId, - pub valid_until: u64, - pub kind: MessageKind, - pub route_id: RouteId, - pub payload: CBOR, -} -``` - - - -# nack reasons -- unknown route / invalid payload / expired - -```rust -pub enum Nack { - Unknown, - UnknownRoute, - InvalidPayload, - Expired, -} -``` - - - -# type-safe routing -- route id is const per type -- compiler couples request -> response - -```rust -pub trait RequestResponse: QlCodec { - const ID: RouteId; - type Response: QlCodec; -} - -pub trait Event: QlCodec { - const ID: RouteId; -} -``` - - - -# router wiring -- builder ties route ids to handlers -- unknown routes auto-nack - -```rust -let router = Router::builder() - .add_request_handler::() - .add_event_handler::() - .build(state); -``` - - - -# runtime api flow -- request returns response or nack -- events are fire-and-forget (or acked) - -```rust -let reply = handle.request(msg, peer, RequestConfig::default()).await?; -handle.send_event(status, peer); -``` - - - -# performance snapshot (cbor sizes) -| proto | message | bytes | notes | -| :-- | :-- | --: | :-- | -| v1 | sealed msg (exchange_rate) | 6645 | sign+encrypt | -| v1 | sealed heartbeat | 6633 | sign+encrypt | -| v2 | hello | 132 | kem+nonce | -| v2 | hello reply | 2563 | sig+kem | -| v2 | confirm | 2510 | sig | -| v2 | pair request | 4065 | sig+kem | -| v2 | message (empty) | 199 | steady-state | -| v2 | heartbeat | 196 | steady-state | - -handshake total: 5205 bytes - - - -# close -- smaller packets, clearer flow, typed api -- ql v2 is the protocol, not just a crate diff --git a/ql2/src/engine/mod.rs b/ql2/src/engine/mod.rs deleted file mode 100644 index ac1029da..00000000 --- a/ql2/src/engine/mod.rs +++ /dev/null @@ -1,2323 +0,0 @@ -pub mod replay_cache; -mod ring; -mod state; -mod stream; - -#[cfg(test)] -mod tests; - -use std::{ - cmp::Reverse, - collections::HashMap, - mem, - time::{Duration, Instant}, -}; - -use bc_components::{SigningPublicKey, XID}; -use rkyv::access_mut; -pub use state::{ - Engine, EngineInput, EngineOutput, EngineState, InitiatorStage, KeepAliveState, OpenId, - OutputFn, PeerRecord, PeerSession, Token, TrackedWrite, -}; - -use self::{replay_cache::ReplayKey, state::*, stream::*}; -use crate::{ - platform::{QlCrypto, QlIdentity}, - wire::{ - self, - encrypted_message::{ArchivedEncryptedMessage, NONCE_SIZE}, - handshake::{self, HandshakeRecord, Hello}, - heartbeat::{self, HeartbeatBody}, - stream::{ - decrypt_stream, encrypt_stream, BodyChunk, Direction, RejectCode, ResetCode, - ResetTarget, StreamAck, StreamAckBody, StreamBody, StreamFrame, StreamFrameAccept, - StreamFrameData, StreamFrameOpen, StreamFrameReject, StreamFrameReset, StreamMessage, - }, - unpair::{self}, - ControlMeta, QlHeader, QlPayload, QlRecord, StreamSeq, - }, - Peer, QlError, StreamId, -}; - -#[derive(Debug, Clone, Copy)] -pub struct KeepAliveConfig { - pub interval: Duration, - pub timeout: Duration, -} - -#[derive(Debug, Clone, Copy, Default)] -pub struct StreamConfig { - pub open_timeout: Option, -} - -#[derive(Debug, Clone, Copy)] -pub struct EngineConfig { - pub handshake_timeout: Duration, - pub default_open_timeout: Duration, - pub packet_expiration: Duration, - pub stream_ack_delay: Duration, - pub stream_ack_timeout: Duration, - pub stream_retry_limit: u8, - pub keep_alive: Option, -} - -impl Default for EngineConfig { - fn default() -> Self { - Self { - handshake_timeout: Duration::from_secs(5), - default_open_timeout: Duration::from_secs(5), - packet_expiration: Duration::from_secs(30), - stream_ack_delay: Duration::from_millis(5), - stream_ack_timeout: Duration::from_millis(150), - stream_retry_limit: 5, - keep_alive: None, - } - } -} - -impl Engine { - pub fn new(config: EngineConfig, identity: QlIdentity, peer: Option) -> Self { - Self { - config: config, - identity, - state: EngineState::new(peer), - streams: HashMap::new(), - } - } - - pub fn run_tick( - &mut self, - now: Instant, - input: EngineInput, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - match input { - EngineInput::BindPeer(peer) => self.handle_bind_peer(peer, emit), - EngineInput::Pair => self.handle_pair_local(now, crypto), - EngineInput::Connect => self.handle_connect(now, crypto, emit), - EngineInput::Unpair => self.handle_unpair_local(now, emit), - EngineInput::OpenStream { - open_id, - request_head, - request_prefix, - config, - } => self.handle_open_stream(now, open_id, request_head, request_prefix, config, emit), - EngineInput::AcceptStream { - stream_id, - response_head, - response_prefix, - } => self.handle_accept_stream(now, stream_id, response_head, response_prefix), - EngineInput::RejectStream { stream_id, code } => { - self.handle_reject_stream(now, stream_id, code) - } - EngineInput::OutboundData { - stream_id, - dir, - bytes, - } => self.handle_outbound_data(stream_id, dir, bytes), - EngineInput::OutboundFinished { stream_id, dir } => { - self.handle_outbound_finished(stream_id, dir) - } - EngineInput::ResetOutbound { - stream_id, - dir, - code, - } => self.handle_reset_outbound(now, stream_id, dir, code), - EngineInput::ResetInbound { - stream_id, - dir, - code, - } => self.handle_reset_inbound(now, stream_id, dir, code), - EngineInput::PendingAcceptDropped { stream_id } => { - self.handle_pending_accept_dropped(stream_id, emit) - } - EngineInput::ResponderDropped { stream_id } => { - self.handle_responder_dropped(now, stream_id) - } - EngineInput::Incoming(bytes) => self.handle_incoming(now, bytes, crypto, emit), - EngineInput::WriteCompleted { - token, - tracked, - result, - } => self.handle_write_done(now, token, tracked, result, emit), - EngineInput::TimerExpired => self.handle_timeouts(now, crypto, emit), - } - - self.drive_streams(now, emit); - self.maybe_start_next_write(crypto, emit); - emit(EngineOutput::SetTimer(self.state.next_deadline())); - } - - fn emit_peer_status(&self, emit: &mut impl OutputFn) { - if let Some(peer) = self.state.peer.as_ref() { - emit(EngineOutput::PeerStatusChanged { - peer: peer.peer, - session: peer.session.clone(), - }); - } - } - - fn next_control_meta(&self, valid_for: Duration) -> ControlMeta { - ControlMeta { - packet_id: self.state.next_packet_id(), - valid_until: wire::now_secs() + valid_for.as_secs(), - } - } - - fn is_replayed_control(&mut self, peer: XID, meta: ControlMeta) -> bool { - self.state - .replay_cache - .check_and_store_valid_until(ReplayKey::new(peer, meta.packet_id), meta.valid_until) - } - - fn bind_peer_record(&mut self, peer: Peer, emit: &mut impl OutputFn) { - self.reset_runtime(QlError::Cancelled, emit); - self.state.peer = Some(PeerRecord::new( - peer.peer, - peer.signing_key, - peer.encapsulation_key, - )); - self.emit_peer_status(emit); - if let Some(peer) = self.state.peer.as_ref() { - emit(EngineOutput::PersistPeer(peer.snapshot())); - } - } - - fn reset_runtime(&mut self, error: QlError, emit: &mut impl OutputFn) { - let streams = mem::take(&mut self.streams); - for (stream_id, stream) in streams { - self.fail_stream(stream_id, stream, error.clone(), emit); - } - self.state.outbound.clear(); - self.state.timeouts.clear(); - self.state.write_in_flight = None; - } - - fn handle_bind_peer(&mut self, peer: Peer, emit: &mut impl OutputFn) { - if let Some(existing) = self.state.peer.as_ref() { - emit(EngineOutput::PeerStatusChanged { - peer: existing.peer, - session: PeerSession::Disconnected, - }); - } - self.bind_peer_record(peer, emit); - } - - fn handle_pair_local(&mut self, now: Instant, crypto: &impl QlCrypto) { - let Some(peer) = self.state.peer.as_ref() else { - return; - }; - let meta = self.next_control_meta(self.config.packet_expiration); - let Ok(record) = wire::pair::build_pair_request( - &self.identity, - crypto, - peer.peer, - &peer.encapsulation_key, - meta, - ) else { - return; - }; - let token = self.state.next_token(); - self.enqueue_handshake_message( - token, - now + self.config.packet_expiration, - wire::encode_record(&record), - ); - } - - fn handle_connect(&mut self, now: Instant, crypto: &impl QlCrypto, emit: &mut impl OutputFn) { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - let peer = peer_record.peer; - let meta = self.next_control_meta(self.config.handshake_timeout); - let (hello, session_key) = match &peer_record.session { - PeerSession::Connected { .. } - | PeerSession::Initiator { .. } - | PeerSession::Responder { .. } => { - return; - } - PeerSession::Disconnected => { - match handshake::build_hello( - &self.identity, - crypto, - peer, - &peer_record.encapsulation_key, - meta, - ) { - Ok(result) => result, - Err(_) => return, - } - } - }; - - let deadline = now + self.config.handshake_timeout; - let token = self.state.next_token(); - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Initiator { - handshake_token: token, - hello: hello.clone(), - session_key, - deadline, - stage: InitiatorStage::WaitingHelloReply, - }; - } - self.emit_peer_status(emit); - - let record = QlRecord { - header: QlHeader { - sender: self.identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), - }; - self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); - } - - fn handle_unpair_local(&mut self, now: Instant, emit: &mut impl OutputFn) { - let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { - return; - }; - let meta = self.next_control_meta(self.config.packet_expiration); - let record = unpair::build_unpair_record( - &self.identity, - QlHeader { - sender: self.identity.xid, - recipient: peer, - }, - meta, - ); - self.unpair_peer(emit); - let token = self.state.next_token(); - self.enqueue_handshake_message( - token, - now + self.config.packet_expiration, - wire::encode_record(&record), - ); - } - - fn handle_open_stream( - &mut self, - now: Instant, - open_id: OpenId, - request_head: Vec, - request_prefix: Option, - config: StreamConfig, - emit: &mut impl OutputFn, - ) { - let Some(entry) = self.state.peer.as_ref() else { - emit(EngineOutput::OpenFailed { - open_id, - stream_id: StreamId(0), - error: QlError::NoPeerBound, - }); - return; - }; - if !entry.session.is_connected() { - emit(EngineOutput::OpenFailed { - open_id, - stream_id: StreamId(0), - error: QlError::MissingSession, - }); - return; - } - - let stream_namespace = StreamNamespace::for_local(self.identity.xid, entry.peer); - let stream_id = self.state.next_stream_id(stream_namespace); - let open_timeout = config - .open_timeout - .unwrap_or(self.config.default_open_timeout); - let token = self.state.next_token(); - let request_prefix_fin = request_prefix.as_ref().is_some_and(|chunk| chunk.fin); - let frame = StreamFrameOpen { - stream_id, - request_head, - request_prefix, - }; - let stream = StreamState::Initiator(InitiatorStream { - meta: StreamMeta { - stream_id, - last_activity: now, - }, - control: StreamControl { - pending: std::collections::VecDeque::from([StreamFrame::Open(frame)]), - ..Default::default() - }, - request: OutboundState::from_prefix(Direction::Request, request_prefix_fin), - response: InboundState::new(), - accept: InitiatorAccept::Opening(OpenWaiter { - open_id: Some(open_id), - open_timeout_token: token, - }), - }); - self.streams.insert(stream_id, stream); - self.state.timeouts.push(Reverse(TimeoutEntry { - at: now + open_timeout, - kind: TimeoutKind::StreamOpen { stream_id, token }, - })); - emit(EngineOutput::OpenStarted { open_id, stream_id }); - } - - fn handle_accept_stream( - &mut self, - now: Instant, - stream_id: StreamId, - response_head: Vec, - response_prefix: Option, - ) { - let Some(StreamState::Responder(stream)) = self.streams.get_mut(&stream_id) else { - return; - }; - let ResponderResponse::Pending = stream.response else { - return; - }; - let response_prefix_fin = response_prefix.as_ref().is_some_and(|chunk| chunk.fin); - stream - .control - .pending - .push_back(StreamFrame::Accept(StreamFrameAccept { - stream_id, - response_head, - response_prefix, - })); - stream.response = ResponderResponse::Accepted { - body: OutboundState::from_prefix(Direction::Response, response_prefix_fin), - }; - stream.meta.last_activity = now; - } - - fn handle_reject_stream(&mut self, now: Instant, stream_id: StreamId, code: RejectCode) { - let Some(StreamState::Responder(stream)) = self.streams.get_mut(&stream_id) else { - return; - }; - let ResponderResponse::Pending = stream.response else { - return; - }; - stream - .control - .pending - .push_back(StreamFrame::Reject(StreamFrameReject { stream_id, code })); - stream.response = ResponderResponse::Rejecting; - stream.meta.last_activity = now; - } - - fn handle_outbound_data(&mut self, stream_id: StreamId, dir: Direction, bytes: Vec) { - if bytes.is_empty() { - return; - } - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let Some(outbound) = stream.outbound_mut(dir) else { - return; - }; - if !outbound.take_pending_pull() { - return; - } - let chunk = BodyChunk { bytes, fin: false }; - stream - .control_mut() - .queue_frame_back(StreamFrame::Data(StreamFrameData { - stream_id, - dir, - chunk, - })); - } - - fn handle_outbound_finished(&mut self, stream_id: StreamId, dir: Direction) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let Some(outbound) = stream.outbound_mut(dir) else { - return; - }; - outbound.finish(); - } - - fn handle_reset_outbound( - &mut self, - now: Instant, - stream_id: StreamId, - dir: Direction, - code: ResetCode, - ) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let Some(outbound) = stream.outbound_mut(dir) else { - return; - }; - if outbound.is_closed() { - return; - } - outbound.close(); - stream.control_mut().queue_frame_front(reset_frame( - stream_id, - reset_target_for_dir(dir), - code, - )); - *stream.last_activity_mut() = now; - } - - fn handle_reset_inbound( - &mut self, - now: Instant, - stream_id: StreamId, - dir: Direction, - code: ResetCode, - ) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let Some(inbound) = stream.inbound_mut(dir) else { - return; - }; - if inbound.closed { - return; - } - inbound.closed = true; - stream.control_mut().queue_frame_front(reset_frame( - stream_id, - reset_target_for_dir(dir), - code, - )); - *stream.last_activity_mut() = now; - } - - fn handle_responder_dropped(&mut self, now: Instant, stream_id: StreamId) { - self.handle_reject_stream(now, stream_id, RejectCode::Unhandled); - } - - fn handle_pending_accept_dropped(&mut self, stream_id: StreamId, emit: &mut impl OutputFn) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - if let StreamState::Initiator(stream) = stream { - match &mut stream.accept { - InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { - waiter.open_id = None; - } - InitiatorAccept::Open { .. } => {} - } - } - self.maybe_reap_stream(stream_id, emit); - } - - fn handle_incoming( - &mut self, - now: Instant, - mut bytes: Vec, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - let Ok(record) = access_mut::(&mut bytes) - else { - return; - }; - let record = unsafe { record.unseal_unchecked() }; - let sender: XID = (&record.header.sender).into(); - let recipient: XID = (&record.header.recipient).into(); - if recipient != self.identity.xid { - return; - } - if !matches!(&record.payload, wire::ArchivedQlPayload::Pair(_)) { - let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { - return; - }; - if sender != peer { - return; - } - } - let Ok(header) = wire::deserialize_value(&record.header) else { - return; - }; - match &mut record.payload { - wire::ArchivedQlPayload::Handshake(message) => { - self.handle_handshake(now, sender, message, crypto, emit) - } - wire::ArchivedQlPayload::Stream(encrypted) => { - self.handle_stream(now, sender, &header, encrypted, emit) - } - wire::ArchivedQlPayload::Heartbeat(encrypted) => { - self.handle_heartbeat(now, &header, encrypted, crypto, emit) - } - wire::ArchivedQlPayload::Pair(request) => { - self.handle_pairing(now, &header, request, crypto, emit) - } - wire::ArchivedQlPayload::Unpair(unpair_record) => { - self.handle_unpair(sender, &header, unpair_record, emit) - } - } - } - - fn handle_handshake( - &mut self, - now: Instant, - peer: XID, - message: &wire::handshake::ArchivedHandshakeRecord, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - match message { - wire::handshake::ArchivedHandshakeRecord::Hello(hello) => { - self.handle_hello(now, peer, hello, crypto, emit) - } - wire::handshake::ArchivedHandshakeRecord::HelloReply(reply) => { - self.handle_hello_reply(now, peer, reply, emit) - } - wire::handshake::ArchivedHandshakeRecord::Confirm(confirm) => { - self.handle_confirm(now, peer, confirm, emit) - } - } - } - - fn handle_pairing( - &mut self, - now: Instant, - header: &QlHeader, - request: &mut wire::pair::ArchivedPairRequestRecord, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - let payload = match wire::pair::decrypt_pair_request(&self.identity, header, request) { - Ok(payload) => payload, - Err(_) => return, - }; - let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); - if self.is_replayed_control(peer, payload.meta) { - return; - } - if let Some(existing) = self.state.peer.as_ref() { - if existing.peer != peer - || existing.signing_key != payload.signing_pub_key - || existing.encapsulation_key != payload.encapsulation_pub_key - { - return; - } - } else { - self.bind_peer_record( - Peer { - peer, - signing_key: payload.signing_pub_key, - encapsulation_key: payload.encapsulation_pub_key, - }, - emit, - ); - } - self.handle_connect(now, crypto, emit); - } - - fn handle_unpair( - &mut self, - peer: XID, - header: &QlHeader, - record: &wire::unpair::ArchivedUnpairRecord, - emit: &mut impl OutputFn, - ) { - { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - if unpair::verify_unpair_record(header, record, &peer_record.signing_key).is_err() { - return; - } - } - let meta: ControlMeta = (&record.meta).into(); - if self.is_replayed_control(peer, meta) { - return; - } - self.unpair_peer(emit); - } - - fn handle_heartbeat( - &mut self, - now: Instant, - header: &QlHeader, - encrypted: &mut ArchivedEncryptedMessage, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - let (body, should_reply) = { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - let PeerSession::Connected { - session_key, - keepalive, - .. - } = &peer_record.session - else { - return; - }; - let Ok(body) = heartbeat::decrypt_heartbeat(header, encrypted, session_key) else { - return; - }; - (body, !keepalive.pending) - }; - if self.is_replayed_control(header.sender, body.meta) { - return; - } - self.record_activity(now); - if should_reply { - self.send_heartbeat_message(now, crypto); - } - self.emit_peer_status(emit); - } - - fn handle_stream( - &mut self, - now: Instant, - _peer: XID, - header: &QlHeader, - encrypted: &mut ArchivedEncryptedMessage, - emit: &mut impl OutputFn, - ) { - let body = { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - let PeerSession::Connected { session_key, .. } = &peer_record.session else { - return; - }; - match decrypt_stream(header, encrypted, session_key) { - Ok(body) => body, - Err(_) => return, - } - }; - - let message = match body { - StreamBody::Ack(StreamAckBody { stream_id, ack, .. }) => { - self.process_stream_ack(stream_id, ack, emit); - self.record_activity(now); - if self.streams.contains_key(&stream_id) { - self.record_stream_activity(stream_id, now); - self.maybe_reap_stream(stream_id, emit); - } - return; - } - StreamBody::Message(message) => message, - }; - - let stream_id = message.frame.stream_id(); - if let Some(ack) = message.ack { - self.process_stream_ack(stream_id, ack, emit); - } - - if !self.streams.contains_key(&stream_id) { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - let local_namespace = StreamNamespace::for_local(self.identity.xid, peer_record.peer); - if !local_namespace.remote().matches(stream_id) { - return; - } - let token = self.state.next_token(); - self.streams.insert( - stream_id, - StreamState::Provisional(ProvisionalStream { - meta: StreamMeta { - stream_id, - last_activity: now, - }, - control: StreamControl::default(), - timeout_token: token, - }), - ); - self.state.timeouts.push(Reverse(TimeoutEntry { - at: now + self.config.default_open_timeout, - kind: TimeoutKind::StreamProvisional { stream_id, token }, - })); - } - - let buffer_result = { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - *stream.last_activity_mut() = now; - stream - .control_mut() - .buffer_incoming(message.tx_seq, message.frame) - }; - - match buffer_result { - BufferIncomingResult::OutOfWindow => { - if self - .streams - .get(&stream_id) - .is_some_and(StreamState::is_provisional) - { - self.streams.remove(&stream_id); - self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); - } else if let Some(stream) = self.streams.get_mut(&stream_id) { - Self::queue_protocol_reset(stream, emit); - *stream.last_activity_mut() = now; - } - return; - } - BufferIncomingResult::Duplicate | BufferIncomingResult::AlreadyBuffered => { - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.control_mut().note_ack(true); - } - self.schedule_stream_ack(stream_id, now); - self.record_activity(now); - self.record_stream_activity(stream_id, now); - return; - } - BufferIncomingResult::Buffered { out_of_order } => { - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.control_mut().note_ack(out_of_order); - } - } - } - self.record_activity(now); - self.record_stream_activity(stream_id, now); - self.drain_committed_stream_frames(now, stream_id, emit); - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.control_mut().maybe_force_ack_for_progress(); - } - self.schedule_stream_ack(stream_id, now); - } - - fn schedule_stream_ack(&mut self, stream_id: StreamId, now: Instant) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let control = stream.control_mut(); - if !control.ack_dirty { - return; - } - if control.ack_immediate || self.config.stream_ack_delay.is_zero() { - control.ack_delay_token = None; - return; - } - if control.ack_delay_token.is_some() { - return; - } - let token = self.state.next_token(); - control.ack_delay_token = Some(token); - self.state.timeouts.push(Reverse(TimeoutEntry { - at: now + self.config.stream_ack_delay, - kind: TimeoutKind::StreamAckDelay { stream_id, token }, - })); - } - - fn drain_committed_stream_frames( - &mut self, - now: Instant, - stream_id: StreamId, - emit: &mut impl OutputFn, - ) { - loop { - let next = { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - stream.control_mut().pop_next_committable() - }; - let Some((_tx_seq, frame)) = next else { - break; - }; - if self - .streams - .get(&stream_id) - .is_some_and(StreamState::is_provisional) - && !matches!(frame, StreamFrame::Open(_)) - { - self.streams.remove(&stream_id); - self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); - return; - } - match frame { - StreamFrame::Open(frame) => self.handle_stream_open(now, frame, emit), - StreamFrame::Accept(frame) => self.handle_stream_accept_from_peer(now, frame, emit), - StreamFrame::Reject(frame) => self.handle_stream_reject_from_peer(frame, emit), - StreamFrame::Data(frame) => self.handle_stream_data(now, frame, emit), - StreamFrame::Reset(frame) => self.handle_stream_reset(now, frame, emit), - } - if !self.streams.contains_key(&stream_id) { - return; - } - } - self.maybe_reap_stream(stream_id, emit); - } - - fn handle_stream_open( - &mut self, - now: Instant, - frame: StreamFrameOpen, - emit: &mut impl OutputFn, - ) { - let StreamFrameOpen { - stream_id, - request_head, - request_prefix, - } = frame; - let control = match self.streams.remove(&stream_id) { - Some(StreamState::Provisional(stream)) => stream.control, - Some(mut stream) => { - Self::queue_protocol_reset(&mut stream, emit); - self.streams.insert(stream_id, stream); - return; - } - None => StreamControl::default(), - }; - - let mut stream = StreamState::Responder(ResponderStream { - meta: StreamMeta { - stream_id, - last_activity: now, - }, - control, - request: InboundState::new(), - response: ResponderResponse::Pending, - }); - if let Some(chunk) = request_prefix.as_ref() { - let Some(inbound) = stream.inbound_mut(Direction::Request) else { - return; - }; - if chunk.fin { - inbound.closed = true; - } - } - self.streams.insert(stream_id, stream); - emit(EngineOutput::InboundStreamOpened { - stream_id, - request_head, - request_prefix, - }); - } - - fn handle_stream_accept_from_peer( - &mut self, - now: Instant, - frame: StreamFrameAccept, - emit: &mut impl OutputFn, - ) { - let StreamFrameAccept { - stream_id, - response_head, - response_prefix, - } = frame; - let mut protocol = false; - let mut response_prefix_output = None; - { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - match stream { - StreamState::Initiator(stream) => match &mut stream.accept { - InitiatorAccept::Opening(waiter) => { - if let Some(open_id) = waiter.open_id.take() { - emit(EngineOutput::OpenAccepted { - open_id, - stream_id, - response_head: response_head.clone(), - response_prefix: response_prefix.clone(), - }); - } else { - stream.response.closed = true; - stream.control.queue_frame_front(reset_frame( - stream_id, - ResetTarget::Response, - ResetCode::Cancelled, - )); - } - stream.accept = InitiatorAccept::Open { response_head }; - stream.meta.last_activity = now; - response_prefix_output = response_prefix.clone(); - } - InitiatorAccept::WaitingAccept(waiter) => { - if let Some(open_id) = waiter.open_id.take() { - emit(EngineOutput::OpenAccepted { - open_id, - stream_id, - response_head: response_head.clone(), - response_prefix: response_prefix.clone(), - }); - } else { - stream.response.closed = true; - stream.control.queue_frame_front(reset_frame( - stream_id, - ResetTarget::Response, - ResetCode::Cancelled, - )); - } - stream.accept = InitiatorAccept::Open { response_head }; - stream.meta.last_activity = now; - response_prefix_output = response_prefix.clone(); - } - InitiatorAccept::Open { - response_head: stored, - } => { - if *stored != response_head { - protocol = true; - } - } - }, - _ => protocol = true, - } - } - - if protocol { - self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); - return; - } - - if let Some(chunk) = response_prefix_output.as_ref() { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let Some(inbound) = stream.inbound_mut(Direction::Response) else { - Self::queue_protocol_reset(stream, emit); - return; - }; - if chunk.fin && !inbound.closed { - inbound.closed = true; - self.maybe_reap_stream(stream_id, emit); - } - } - } - - fn handle_stream_reject_from_peer( - &mut self, - frame: StreamFrameReject, - emit: &mut impl OutputFn, - ) { - let StreamFrameReject { stream_id, code } = frame; - let mut protocol = false; - let mut remove_after = false; - { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - match stream { - StreamState::Initiator(stream) => match &mut stream.accept { - InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { - if let Some(open_id) = waiter.open_id.take() { - emit(EngineOutput::OpenFailed { - open_id, - stream_id, - error: QlError::StreamRejected { code }, - }); - } - emit(EngineOutput::OutboundClosed { - stream_id, - dir: Direction::Request, - }); - emit(EngineOutput::InboundFailed { - stream_id, - dir: Direction::Response, - error: QlError::StreamRejected { code }, - }); - stream.request.close(); - stream.response.closed = true; - remove_after = true; - } - InitiatorAccept::Open { .. } => protocol = true, - }, - _ => protocol = true, - } - } - if remove_after { - self.streams.remove(&stream_id); - emit(EngineOutput::StreamReaped { stream_id }); - } - if protocol { - self.send_ephemeral_reset(stream_id, ResetTarget::Both, ResetCode::Protocol); - } - } - - fn handle_stream_data( - &mut self, - now: Instant, - frame: StreamFrameData, - emit: &mut impl OutputFn, - ) { - let StreamFrameData { - stream_id, - dir, - chunk, - } = frame; - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - if dir == Direction::Response - && matches!( - stream, - StreamState::Initiator(InitiatorStream { - accept: InitiatorAccept::Opening(_) | InitiatorAccept::WaitingAccept(_), - .. - }) - ) - { - Self::queue_protocol_reset(stream, emit); - *stream.last_activity_mut() = now; - return; - } - let Some(inbound) = stream.inbound_mut(dir) else { - Self::queue_protocol_reset(stream, emit); - return; - }; - if inbound.closed { - Self::queue_protocol_reset(stream, emit); - } else { - if !chunk.bytes.is_empty() { - emit(EngineOutput::InboundData { - stream_id, - dir, - bytes: chunk.bytes, - }); - } - if chunk.fin && !inbound.closed { - inbound.closed = true; - emit(EngineOutput::InboundFinished { stream_id, dir }); - } - } - *stream.last_activity_mut() = now; - self.maybe_reap_stream(stream_id, emit); - } - - fn handle_stream_reset( - &mut self, - now: Instant, - frame: StreamFrameReset, - emit: &mut impl OutputFn, - ) { - let StreamFrameReset { - stream_id, - target, - code, - } = frame; - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - Self::apply_remote_reset(stream, target, code, emit); - *stream.last_activity_mut() = now; - self.maybe_reap_stream(stream_id, emit); - } - - fn process_stream_ack( - &mut self, - stream_id: StreamId, - ack: StreamAck, - emit: &mut impl OutputFn, - ) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let acked: Vec<_> = stream - .control() - .in_flight - .iter() - .map(|(tx_seq, _)| tx_seq) - .filter(|tx_seq| StreamControl::ack_covers(ack, *tx_seq)) - .collect(); - if acked.is_empty() { - return; - } - let mut acked_frames = Vec::with_capacity(acked.len()); - for tx_seq in acked { - if let Some(in_flight) = stream.control_mut().remove_in_flight(tx_seq) { - acked_frames.push(in_flight.frame); - } - } - let _ = stream; - - for frame in acked_frames { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - match frame { - StreamFrame::Open(StreamFrameOpen { request_prefix, .. }) => { - if let StreamState::Initiator(stream) = stream { - if let InitiatorAccept::Opening(waiter) = &stream.accept { - stream.accept = InitiatorAccept::WaitingAccept(OpenWaiter { - open_id: waiter.open_id, - open_timeout_token: waiter.open_timeout_token, - }); - } - if request_prefix.as_ref().is_some_and(|chunk| chunk.fin) { - stream.request.close(); - emit(EngineOutput::OutboundClosed { - stream_id, - dir: Direction::Request, - }); - } - } - } - StreamFrame::Accept(StreamFrameAccept { - response_prefix, .. - }) => { - if let StreamState::Responder(stream) = stream { - if response_prefix.as_ref().is_some_and(|chunk| chunk.fin) { - if let ResponderResponse::Accepted { body } = &mut stream.response { - body.close(); - emit(EngineOutput::OutboundClosed { - stream_id, - dir: Direction::Response, - }); - } - } - } - } - StreamFrame::Reject(_) => {} - StreamFrame::Data(StreamFrameData { - dir, - chunk: BodyChunk { fin: true, .. }, - .. - }) => { - if let Some(outbound) = stream.outbound_mut(dir) { - outbound.close(); - emit(EngineOutput::OutboundClosed { stream_id, dir }); - } - } - StreamFrame::Reset(StreamFrameReset { target, code, .. }) => { - for outbound_dir in [Direction::Request, Direction::Response] { - let affects_outbound = matches!( - (target, outbound_dir), - (ResetTarget::Request, Direction::Request) - | (ResetTarget::Response, Direction::Response) - | (ResetTarget::Both, _) - ); - if affects_outbound { - if let Some(outbound) = stream.outbound_mut(outbound_dir) { - outbound.close(); - emit(EngineOutput::OutboundFailed { - stream_id, - dir: outbound_dir, - error: QlError::StreamReset { - dir: outbound_dir, - code, - }, - }); - } - } - } - } - StreamFrame::Data(_) => {} - } - } - - self.maybe_reap_stream(stream_id, emit); - } - - fn drive_streams(&mut self, now: Instant, emit: &mut impl OutputFn) { - let config = &self.config; - let state = &mut self.state; - for stream in self.streams.values_mut() { - Self::drive_stream(config, state, now, stream, emit); - } - } - - fn drive_stream( - config: &EngineConfig, - state: &mut EngineState, - _now: Instant, - stream: &mut StreamState, - emit: &mut impl OutputFn, - ) { - match stream { - StreamState::Initiator(stream) => { - Self::drive_stream_outbound( - config, - state, - stream.meta.stream_id, - &mut stream.control, - Some(&mut stream.request), - emit, - ); - } - StreamState::Responder(stream) => { - let stream_id = stream.meta.stream_id; - match &mut stream.response { - ResponderResponse::Accepted { body, .. } => { - Self::drive_stream_outbound( - config, - state, - stream_id, - &mut stream.control, - Some(body), - emit, - ); - } - _ => { - Self::drive_stream_outbound( - config, - state, - stream_id, - &mut stream.control, - None, - emit, - ); - } - } - } - StreamState::Provisional(stream) => Self::drive_stream_outbound( - config, - state, - stream.meta.stream_id, - &mut stream.control, - None, - emit, - ), - } - } - - fn drive_stream_outbound( - config: &EngineConfig, - state: &mut EngineState, - stream_id: StreamId, - control: &mut StreamControl, - mut outbound: Option<&mut OutboundState>, - emit: &mut impl OutputFn, - ) { - loop { - if control.send_window_has_space() { - if let Some(frame) = control.pending.pop_front() { - Self::enqueue_stream_frame(config, state, control, frame, 0, false); - continue; - } - } - if control.ack_dirty && control.ack_immediate && control.ack_outbound_token.is_none() { - Self::enqueue_stream_ack_body(config, state, control, stream_id, false); - continue; - } - if !control.send_window_has_space() { - return; - } - - let Some(outbound) = outbound.as_deref_mut() else { - return; - }; - if outbound.request_data() { - emit(EngineOutput::NeedOutboundData { - stream_id, - dir: outbound.dir, - }); - return; - } - if outbound.queue_fin() { - Self::enqueue_stream_frame( - config, - state, - control, - StreamFrame::Data(StreamFrameData { - stream_id, - dir: outbound.dir, - chunk: BodyChunk { - bytes: Vec::new(), - fin: true, - }, - }), - 0, - false, - ); - continue; - } - return; - } - } - - fn enqueue_stream_frame( - config: &EngineConfig, - state: &mut EngineState, - control: &mut StreamControl, - frame: StreamFrame, - attempt: u8, - priority: bool, - ) { - let tx_seq = control.take_tx_seq(); - Self::enqueue_stream_frame_with_seq( - config, state, control, tx_seq, frame, attempt, priority, - ); - } - - fn enqueue_stream_frame_with_seq( - config: &EngineConfig, - state: &mut EngineState, - control: &mut StreamControl, - tx_seq: StreamSeq, - frame: StreamFrame, - attempt: u8, - priority: bool, - ) { - control.insert_in_flight(InFlightFrame { - tx_seq, - frame: frame.clone(), - attempt, - }); - let ack = control.ack_dirty.then(|| control.current_ack()); - if ack.is_some() { - control.clear_ack_schedule(); - } - let valid_until = wire::now_secs().saturating_add(config.packet_expiration.as_secs()); - state.enqueue_stream_body( - config, - priority, - StreamBody::Message(StreamMessage { - tx_seq, - ack, - valid_until, - frame, - }), - ); - } - - fn enqueue_stream_ack_body( - config: &EngineConfig, - state: &mut EngineState, - control: &mut StreamControl, - stream_id: StreamId, - priority: bool, - ) { - if !control.ack_dirty { - return; - } - let ack = control.current_ack(); - control.clear_ack_schedule(); - let valid_until = wire::now_secs().saturating_add(config.packet_expiration.as_secs()); - let token = state.enqueue_stream_body( - config, - priority, - StreamBody::Ack(StreamAckBody { - stream_id, - ack, - valid_until, - }), - ); - control.ack_outbound_token = Some(token); - } - - fn queue_protocol_reset(stream: &mut StreamState, emit: &mut impl OutputFn) { - let stream_id = stream.stream_id(); - let control = stream.control_mut(); - control.clear_transient_buffers(); - control.queue_frame_front(reset_frame( - stream_id, - ResetTarget::Both, - ResetCode::Protocol, - )); - for dir in [Direction::Request, Direction::Response] { - if let Some(outbound) = stream.outbound_mut(dir) { - outbound.close(); - emit(EngineOutput::OutboundFailed { - stream_id, - dir, - error: QlError::StreamProtocol, - }); - } - if let Some(inbound) = stream.inbound_mut(dir) { - if !inbound.closed { - inbound.closed = true; - emit(EngineOutput::InboundFailed { - stream_id, - dir, - error: QlError::StreamProtocol, - }); - } - } - } - if let StreamState::Initiator(stream) = stream { - match &mut stream.accept { - InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { - if let Some(open_id) = waiter.open_id.take() { - emit(EngineOutput::OpenFailed { - open_id, - stream_id, - error: QlError::StreamProtocol, - }); - } - } - InitiatorAccept::Open { .. } => {} - } - } - } - - fn apply_remote_reset( - stream: &mut StreamState, - target: ResetTarget, - code: ResetCode, - emit: &mut impl OutputFn, - ) { - let stream_id = stream.stream_id(); - let request_error = QlError::StreamReset { - dir: Direction::Request, - code, - }; - let response_error = QlError::StreamReset { - dir: Direction::Response, - code, - }; - - if matches!(target, ResetTarget::Request | ResetTarget::Both) { - if let Some(inbound) = stream.inbound_mut(Direction::Request) { - if !inbound.closed { - inbound.closed = true; - emit(EngineOutput::InboundFailed { - stream_id, - dir: Direction::Request, - error: request_error.clone(), - }); - } - } - if let Some(outbound) = stream.outbound_mut(Direction::Request) { - outbound.close(); - emit(EngineOutput::OutboundFailed { - stream_id, - dir: Direction::Request, - error: request_error.clone(), - }); - } - } - if matches!(target, ResetTarget::Response | ResetTarget::Both) { - if let Some(inbound) = stream.inbound_mut(Direction::Response) { - if !inbound.closed { - inbound.closed = true; - emit(EngineOutput::InboundFailed { - stream_id, - dir: Direction::Response, - error: response_error.clone(), - }); - } - } - if let Some(outbound) = stream.outbound_mut(Direction::Response) { - outbound.close(); - emit(EngineOutput::OutboundFailed { - stream_id, - dir: Direction::Response, - error: response_error.clone(), - }); - } - } - - if let StreamState::Initiator(stream) = stream { - match &mut stream.accept { - InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { - if let Some(open_id) = waiter.open_id.take() { - emit(EngineOutput::OpenFailed { - open_id, - stream_id, - error: match target { - ResetTarget::Request => request_error, - _ => response_error, - }, - }); - } - } - InitiatorAccept::Open { .. } => {} - } - } - } - - fn maybe_reap_stream(&mut self, stream_id: StreamId, emit: &mut impl OutputFn) { - if self - .streams - .get(&stream_id) - .is_some_and(StreamState::can_reap) - { - self.streams.remove(&stream_id); - emit(EngineOutput::StreamReaped { stream_id }); - } - } - - fn clear_ack_outbound_token(&mut self, token: Token, retry: bool) { - for stream in self.streams.values_mut() { - let control = stream.control_mut(); - if control.ack_outbound_token == Some(token) { - control.ack_outbound_token = None; - if retry { - control.note_ack(true); - } - break; - } - } - } - - fn note_sent_stream_ack(&mut self, body: &StreamBody) { - let (stream_id, ack) = match body { - StreamBody::Ack(StreamAckBody { stream_id, ack, .. }) => (*stream_id, *ack), - StreamBody::Message(StreamMessage { - frame, - ack: Some(ack), - .. - }) => (frame.stream_id(), *ack), - StreamBody::Message(_) => return, - }; - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.control_mut().note_ack_sent(ack); - } - } - - fn send_ephemeral_reset(&mut self, stream_id: StreamId, dir: ResetTarget, code: ResetCode) { - let valid_until = wire::now_secs().saturating_add(self.config.packet_expiration.as_secs()); - self.enqueue_stream_body( - true, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq::START, - ack: None, - valid_until, - frame: StreamFrame::Reset(StreamFrameReset { - stream_id, - target: dir, - code, - }), - }), - ); - } - - fn enqueue_handshake_message(&mut self, token: Token, deadline: Instant, bytes: Vec) { - self.state - .enqueue_handshake_message(&self.config, token, deadline, bytes); - } - - fn enqueue_stream_body(&mut self, priority: bool, body: StreamBody) -> Token { - self.state.enqueue_stream_body(&self.config, priority, body) - } - - fn handle_hello( - &mut self, - now: Instant, - peer: XID, - hello: &wire::handshake::ArchivedHello, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - let action = match self.state.peer.as_ref() { - Some(entry) => { - if handshake::verify_hello(peer, self.identity.xid, &entry.signing_key, hello) - .is_err() - { - return; - } - match &entry.session { - PeerSession::Initiator { - hello: local_hello, .. - } => { - if peer_hello_wins(local_hello, self.identity.xid, hello, peer) { - HelloAction::StartResponder - } else { - HelloAction::Ignore - } - } - PeerSession::Responder { - hello: stored, - reply, - deadline, - .. - } => { - if stored.nonce == (&hello.nonce).into() { - HelloAction::ResendReply { - reply: reply.clone(), - deadline: *deadline, - } - } else { - HelloAction::StartResponder - } - } - PeerSession::Disconnected | PeerSession::Connected { .. } => { - HelloAction::StartResponder - } - } - } - None => return, - }; - let meta: ControlMeta = (&hello.meta).into(); - if self.is_replayed_control(peer, meta) { - return; - } - - match action { - HelloAction::StartResponder => { - self.start_responder_handshake(now, peer, hello, crypto, emit) - } - HelloAction::ResendReply { reply, deadline } => { - let record = QlRecord { - header: QlHeader { - sender: self.identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), - }; - let token = self.state.next_token(); - self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); - } - HelloAction::Ignore => {} - } - } - - fn handle_hello_reply( - &mut self, - now: Instant, - peer: XID, - reply: &wire::handshake::ArchivedHelloReply, - emit: &mut impl OutputFn, - ) { - let deadline = now + self.config.handshake_timeout; - let confirm_meta = self.next_control_meta(self.config.handshake_timeout); - let res = { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - let PeerSession::Initiator { - hello, - session_key, - stage, - .. - } = &peer_record.session - else { - return; - }; - if *stage != InitiatorStage::WaitingHelloReply { - return; - } - handshake::build_confirm( - &self.identity, - peer, - &peer_record.signing_key, - hello, - reply, - session_key, - confirm_meta, - ) - .map(|(confirm, session_key)| (hello.clone(), confirm, session_key)) - }; - let (hello, confirm, session_key) = match res { - Ok(result) => result, - Err(_) => { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(emit); - return; - } - }; - let reply_meta: ControlMeta = (&reply.meta).into(); - if self.is_replayed_control(peer, reply_meta) { - return; - } - let token = self.state.next_token(); - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Initiator { - handshake_token: token, - hello, - session_key, - deadline, - stage: InitiatorStage::SendingConfirm, - }; - } - - let record = QlRecord { - header: QlHeader { - sender: self.identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm)), - }; - self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); - } - - fn handle_confirm( - &mut self, - now: Instant, - peer: XID, - confirm: &wire::handshake::ArchivedConfirm, - emit: &mut impl OutputFn, - ) { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - let PeerSession::Responder { - hello, - reply, - secrets, - .. - } = &peer_record.session - else { - return; - }; - - match handshake::finalize_confirm( - peer, - self.identity.xid, - &peer_record.signing_key, - hello, - reply, - confirm, - secrets, - ) { - Ok(session_key) => { - let meta: ControlMeta = (&confirm.meta).into(); - if self.is_replayed_control(peer, meta) { - return; - } - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - }; - } - self.record_activity(now); - self.emit_peer_status(emit); - } - Err(_) => { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(emit); - } - } - } - - fn start_responder_handshake( - &mut self, - now: Instant, - peer: XID, - hello: &wire::handshake::ArchivedHello, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - let reply_meta = self.next_control_meta(self.config.handshake_timeout); - let res = { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - handshake::respond_hello( - &self.identity, - crypto, - peer, - &peer_record.signing_key, - &peer_record.encapsulation_key, - hello, - reply_meta, - ) - }; - let (reply, secrets) = match res { - Ok(result) => result, - Err(_) => { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(emit); - return; - } - }; - let Ok(hello) = wire::deserialize_value(hello) else { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(emit); - return; - }; - - let deadline = now + self.config.handshake_timeout; - let token = self.state.next_token(); - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Responder { - handshake_token: token, - hello, - reply: reply.clone(), - secrets, - deadline, - }; - } - self.emit_peer_status(emit); - - let record = QlRecord { - header: QlHeader { - sender: self.identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), - }; - self.enqueue_handshake_message(token, deadline, wire::encode_record(&record)); - } - - fn send_heartbeat_message(&mut self, now: Instant, crypto: &impl QlCrypto) { - let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { - return; - }; - let meta = self.next_control_meta(self.config.packet_expiration); - let token = self.state.next_token(); - let deadline = now + self.config.packet_expiration; - let message = { - let Some(peer_record) = self.state.peer.as_ref() else { - return; - }; - let PeerSession::Connected { session_key, .. } = &peer_record.session else { - return; - }; - heartbeat::encrypt_heartbeat( - QlHeader { - sender: self.identity.xid, - recipient: peer, - }, - session_key, - HeartbeatBody { meta }, - next_encrypted_message_nonce(crypto), - ) - }; - self.enqueue_handshake_message(token, deadline, wire::encode_record(&message)); - } - - fn keep_alive_config(&self) -> Option { - self.config - .keep_alive - .filter(|config| !config.interval.is_zero() && !config.timeout.is_zero()) - } - - fn record_activity(&mut self, now: Instant) { - let Some(config) = self.keep_alive_config() else { - return; - }; - let token = self.state.next_token(); - let Some(entry) = self.state.peer.as_mut() else { - return; - }; - let PeerSession::Connected { keepalive, .. } = &mut entry.session else { - return; - }; - keepalive.last_activity = Some(now); - keepalive.pending = false; - keepalive.token = token; - self.state.timeouts.push(Reverse(TimeoutEntry { - at: now + config.interval, - kind: TimeoutKind::KeepAliveSend { token }, - })); - } - - fn record_stream_activity(&mut self, stream_id: StreamId, now: Instant) { - if let Some(stream) = self.streams.get_mut(&stream_id) { - *stream.last_activity_mut() = now; - } - } - - fn drop_outbound(&mut self, emit: &mut impl OutputFn) { - while let Some(message) = self.state.outbound.pop_front() { - if let QueuedPayload::Stream { body } = message.payload { - match body { - StreamBody::Ack(_) => self.clear_ack_outbound_token(message.token, false), - StreamBody::Message(message) => { - let stream_id = message.frame.stream_id(); - self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); - } - } - } - } - } - - fn abort_streams(&mut self, error: QlError, emit: &mut impl OutputFn) { - let streams = mem::take(&mut self.streams); - for (stream_id, stream) in streams { - self.fail_stream(stream_id, stream, error.clone(), emit); - } - } - - fn fail_stream_by_id(&mut self, stream_id: StreamId, error: QlError, emit: &mut impl OutputFn) { - let Some(stream) = self.streams.remove(&stream_id) else { - return; - }; - self.fail_stream(stream_id, stream, error, emit); - } - - fn fail_stream( - &mut self, - stream_id: StreamId, - stream: StreamState, - error: QlError, - emit: &mut impl OutputFn, - ) { - match stream { - StreamState::Initiator(stream) => { - match stream.accept { - InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { - if let Some(open_id) = waiter.open_id { - emit(EngineOutput::OpenFailed { - open_id, - stream_id, - error: error.clone(), - }); - } - } - InitiatorAccept::Open { .. } => {} - } - emit(EngineOutput::OutboundFailed { - stream_id, - dir: Direction::Request, - error: error.clone(), - }); - emit(EngineOutput::InboundFailed { - stream_id, - dir: Direction::Response, - error, - }); - } - StreamState::Responder(stream) => { - emit(EngineOutput::InboundFailed { - stream_id, - dir: Direction::Request, - error: error.clone(), - }); - if matches!(stream.response, ResponderResponse::Accepted { .. }) { - emit(EngineOutput::OutboundFailed { - stream_id, - dir: Direction::Response, - error, - }); - } - } - StreamState::Provisional(_) => {} - } - emit(EngineOutput::StreamReaped { stream_id }); - } - - fn unpair_peer(&mut self, emit: &mut impl OutputFn) { - let Some(peer) = self.state.peer.as_ref().map(|peer| peer.peer) else { - return; - }; - self.drop_outbound(emit); - self.abort_streams(QlError::SendFailed, emit); - self.state.peer = None; - emit(EngineOutput::PeerStatusChanged { - peer, - session: PeerSession::Disconnected, - }); - emit(EngineOutput::ClearPeer); - } - - fn handle_timeouts(&mut self, now: Instant, crypto: &impl QlCrypto, emit: &mut impl OutputFn) { - loop { - let Some(entry) = self - .state - .timeouts - .peek_mut() - .filter(|entry| entry.0.at <= now) - else { - break; - }; - let entry = std::collections::binary_heap::PeekMut::pop(entry).0; - match entry.kind { - TimeoutKind::Outbound { token } => { - let mut timed_out_stream = None; - let mut timed_out_ack = false; - self.state.outbound.retain(|message| { - if message.token == token { - if let QueuedPayload::Stream { body } = &message.payload { - match body { - StreamBody::Ack(_) => timed_out_ack = true, - StreamBody::Message(message) => { - timed_out_stream = Some(message.frame.stream_id()) - } - } - } - false - } else { - true - } - }); - if let Some(stream_id) = timed_out_stream { - self.fail_stream_by_id(stream_id, QlError::SendFailed, emit); - } else if timed_out_ack { - self.clear_ack_outbound_token(token, true); - } - } - TimeoutKind::Handshake { token } => { - let Some(entry) = self.state.peer.as_ref() else { - continue; - }; - let should_disconnect = matches!( - &entry.session, - PeerSession::Initiator { handshake_token, .. } | PeerSession::Responder { handshake_token, .. } - if *handshake_token == token - ); - if should_disconnect { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(emit); - self.drop_outbound(emit); - self.abort_streams(QlError::SendFailed, emit); - } - } - TimeoutKind::KeepAliveSend { token } => { - let Some(config) = self.keep_alive_config() else { - continue; - }; - let should_send = { - let Some(entry) = self.state.peer.as_ref() else { - continue; - }; - let PeerSession::Connected { keepalive, .. } = &entry.session else { - continue; - }; - keepalive.token == token && !keepalive.pending - }; - if should_send { - self.send_heartbeat_message(now, crypto); - } - if let Some(entry) = self.state.peer.as_mut() { - if let PeerSession::Connected { keepalive, .. } = &mut entry.session { - if keepalive.token == token { - keepalive.pending = true; - } - } - } - self.state.timeouts.push(Reverse(TimeoutEntry { - at: now + config.timeout, - kind: TimeoutKind::KeepAliveTimeout { token }, - })); - } - TimeoutKind::KeepAliveTimeout { token } => { - let Some(entry) = self.state.peer.as_ref() else { - continue; - }; - let should_disconnect = matches!(&entry.session, PeerSession::Connected { keepalive, .. } if keepalive.token == token && keepalive.pending); - if should_disconnect { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(emit); - self.drop_outbound(emit); - self.abort_streams(QlError::SendFailed, emit); - } - } - TimeoutKind::StreamOpen { stream_id, token } => { - let should_fail = self - .streams - .get(&stream_id) - .and_then(StreamState::open_timeout_token) - .is_some_and(|stream_token| stream_token == token); - if should_fail { - self.fail_stream_by_id(stream_id, QlError::Timeout, emit); - } - } - TimeoutKind::StreamAckDelay { stream_id, token } => { - let should_flush = self - .streams - .get(&stream_id) - .and_then(|stream| stream.control().ack_delay_token) - .is_some_and(|ack_token| ack_token == token); - if should_flush { - if let Some(stream) = self.streams.get_mut(&stream_id) { - let control = stream.control_mut(); - control.ack_delay_token = None; - control.ack_immediate = true; - } - } - } - TimeoutKind::StreamProvisional { stream_id, token } => { - let should_reset = self - .streams - .get(&stream_id) - .and_then(StreamState::provisional_timeout_token) - .is_some_and(|stream_token| stream_token == token); - if should_reset { - self.streams.remove(&stream_id); - self.send_ephemeral_reset( - stream_id, - ResetTarget::Both, - ResetCode::Protocol, - ); - } - } - TimeoutKind::StreamMessage { - stream_id, - tx_seq, - attempt, - } => { - let Some(frame) = self.streams.get(&stream_id).and_then(|stream| { - stream - .control() - .in_flight - .get(&tx_seq) - .and_then(|in_flight| { - (in_flight.attempt == attempt).then_some(in_flight.frame.clone()) - }) - }) else { - continue; - }; - - if attempt >= self.config.stream_retry_limit { - self.fail_stream_by_id(stream_id, QlError::Timeout, emit); - } else { - if let Some(stream) = self.streams.get_mut(&stream_id) { - Self::enqueue_stream_frame_with_seq( - &self.config, - &mut self.state, - stream.control_mut(), - tx_seq, - frame, - attempt.saturating_add(1), - true, - ); - } - } - } - } - } - } - - fn handle_write_done( - &mut self, - now: Instant, - token: Token, - tracked: Option, - result: Result<(), QlError>, - emit: &mut impl OutputFn, - ) { - if self.state.write_in_flight == Some(token) { - self.state.write_in_flight = None; - } - self.clear_ack_outbound_token(token, result.is_err()); - if let Err(error) = result { - if let Some(tracked) = tracked { - self.fail_stream_by_id(tracked.stream_id, error.clone(), emit); - } - let should_disconnect = matches!(self.state.peer.as_ref().map(|entry| &entry.session), - Some(PeerSession::Initiator { handshake_token, .. }) if *handshake_token == token) - || matches!(self.state.peer.as_ref().map(|entry| &entry.session), - Some(PeerSession::Responder { handshake_token, .. }) if *handshake_token == token); - if should_disconnect { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(emit); - self.drop_outbound(emit); - self.abort_streams(error, emit); - } - return; - } - - let connected = self - .state - .peer - .as_ref() - .and_then(|entry| match &entry.session { - PeerSession::Initiator { - session_key, - handshake_token, - stage: InitiatorStage::SendingConfirm, - .. - } if *handshake_token == token => Some(session_key.clone()), - _ => None, - }); - if let Some(session_key) = connected { - if let Some(entry) = self.state.peer.as_mut() { - entry.session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - }; - } - self.emit_peer_status(emit); - self.record_activity(now); - } - - if let Some(tracked) = tracked { - let attempt = self - .streams - .get(&tracked.stream_id) - .and_then(|stream| stream.control().in_flight.get(&tracked.tx_seq)) - .map(|in_flight| in_flight.attempt) - .unwrap_or(0); - self.state.timeouts.push(Reverse(TimeoutEntry { - at: now + self.config.stream_ack_timeout, - kind: TimeoutKind::StreamMessage { - stream_id: tracked.stream_id, - tx_seq: tracked.tx_seq, - attempt, - }, - })); - } - } - - fn maybe_start_next_write(&mut self, crypto: &impl QlCrypto, emit: &mut impl OutputFn) { - if self.state.write_in_flight.is_some() { - return; - } - while let Some(message) = self.state.outbound.pop_front() { - let bytes = match &message.payload { - QueuedPayload::PreEncoded(bytes) => bytes.clone(), - QueuedPayload::Stream { body } => { - let Some((recipient, session_key)) = - self.state.peer.as_ref().and_then(|peer| { - peer.session - .session_key() - .map(|key| (peer.peer, key.clone())) - }) - else { - match body { - StreamBody::Ack(_) => { - self.clear_ack_outbound_token(message.token, false) - } - StreamBody::Message(stream_message) => { - self.fail_stream_by_id( - stream_message.frame.stream_id(), - QlError::SendFailed, - emit, - ); - } - } - continue; - }; - self.note_sent_stream_ack(body); - let record = encrypt_stream( - QlHeader { - sender: self.identity.xid, - recipient, - }, - &session_key, - body, - next_encrypted_message_nonce(crypto), - ); - wire::encode_record(&record) - } - }; - - let tracked = match &message.payload { - QueuedPayload::Stream { - body: StreamBody::Message(stream_message), - } => Some(TrackedWrite { - stream_id: stream_message.frame.stream_id(), - tx_seq: stream_message.tx_seq, - }), - _ => None, - }; - self.state.write_in_flight = Some(message.token); - emit(EngineOutput::WriteMessage { - token: message.token, - tracked, - bytes, - }); - break; - } - } -} - -fn next_encrypted_message_nonce(crypto: &impl QlCrypto) -> [u8; NONCE_SIZE] { - let mut nonce = [0u8; NONCE_SIZE]; - crypto.fill_random_bytes(&mut nonce); - nonce -} - -fn peer_hello_wins( - local_hello: &Hello, - local_sender: XID, - peer_hello: &wire::handshake::ArchivedHello, - peer_sender: XID, -) -> bool { - use std::cmp::Ordering; - - let peer_nonce: bc_components::Nonce = (&peer_hello.nonce).into(); - match peer_nonce.data().cmp(local_hello.nonce.data()) { - Ordering::Less => true, - Ordering::Greater => false, - Ordering::Equal => peer_sender.data().cmp(local_sender.data()) == Ordering::Less, - } -} - -fn reset_target_for_dir(dir: Direction) -> ResetTarget { - match dir { - Direction::Request => ResetTarget::Request, - Direction::Response => ResetTarget::Response, - } -} diff --git a/ql2/src/engine/stream.rs b/ql2/src/engine/stream.rs deleted file mode 100644 index c90d9e09..00000000 --- a/ql2/src/engine/stream.rs +++ /dev/null @@ -1,435 +0,0 @@ -use std::{collections::VecDeque, time::Instant}; - -use super::{ring::SeqRing, OpenId, Token}; -use crate::{ - wire::{ - stream::{ - Direction, ResetCode, ResetTarget, StreamAck, StreamBody, StreamFrame, StreamFrameReset, - }, - StreamSeq, - }, - StreamId, -}; - -pub const STREAM_WINDOW_CAPACITY: usize = 8; -pub const STREAM_WINDOW_SIZE: u32 = STREAM_WINDOW_CAPACITY as u32; -pub const STREAM_ACK_EAGER_THRESHOLD: u32 = STREAM_WINDOW_SIZE / 2; - -#[derive(Debug)] -pub struct StreamMeta { - pub stream_id: StreamId, - pub last_activity: Instant, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum OutboundPhase { - Ready, - PendingPull, - FinPending, - FinQueued, - Closed, -} - -#[derive(Debug)] -pub struct OutboundState { - pub dir: Direction, - pub phase: OutboundPhase, -} - -impl OutboundState { - pub fn from_prefix(dir: Direction, fin: bool) -> Self { - Self { - dir, - phase: if fin { - OutboundPhase::FinQueued - } else { - OutboundPhase::Ready - }, - } - } - - pub fn is_closed(&self) -> bool { - self.phase == OutboundPhase::Closed - } - - pub fn request_data(&mut self) -> bool { - if self.phase != OutboundPhase::Ready { - return false; - } - self.phase = OutboundPhase::PendingPull; - true - } - - pub fn take_pending_pull(&mut self) -> bool { - if self.phase != OutboundPhase::PendingPull { - return false; - } - self.phase = OutboundPhase::Ready; - true - } - - pub fn finish(&mut self) { - self.phase = match self.phase { - OutboundPhase::Ready | OutboundPhase::PendingPull | OutboundPhase::FinPending => { - OutboundPhase::FinPending - } - OutboundPhase::FinQueued => OutboundPhase::FinQueued, - OutboundPhase::Closed => OutboundPhase::Closed, - }; - } - - pub fn queue_fin(&mut self) -> bool { - if self.phase != OutboundPhase::FinPending { - return false; - } - self.phase = OutboundPhase::FinQueued; - true - } - - pub fn close(&mut self) { - self.phase = OutboundPhase::Closed; - } -} - -#[derive(Debug)] -pub struct InboundState { - pub closed: bool, -} - -impl InboundState { - pub fn new() -> Self { - Self { closed: false } - } -} - -#[derive(Debug)] -pub struct OpenWaiter { - pub open_id: Option, - pub open_timeout_token: Token, -} - -#[derive(Debug)] -pub enum InitiatorAccept { - Opening(OpenWaiter), - WaitingAccept(OpenWaiter), - Open { response_head: Vec }, -} - -#[derive(Debug)] -pub struct InFlightFrame { - pub tx_seq: StreamSeq, - pub frame: StreamFrame, - pub attempt: u8, -} - -#[derive(Debug)] -pub enum BufferIncomingResult { - Duplicate, - AlreadyBuffered, - Buffered { out_of_order: bool }, - OutOfWindow, -} - -#[derive(Debug)] -pub struct StreamControl { - pub pending: VecDeque, - pub in_flight: SeqRing, - pub next_tx_seq: StreamSeq, - pub recv_buffer: SeqRing, - pub ack_dirty: bool, - pub ack_immediate: bool, - pub ack_delay_token: Option, - pub ack_outbound_token: Option, - pub last_sent_ack_base: StreamSeq, -} - -impl Default for StreamControl { - fn default() -> Self { - Self { - pending: VecDeque::new(), - in_flight: SeqRing::new(StreamSeq::START), - next_tx_seq: StreamSeq::START, - recv_buffer: SeqRing::new(StreamSeq::START), - ack_dirty: false, - ack_immediate: false, - ack_delay_token: None, - ack_outbound_token: None, - last_sent_ack_base: StreamSeq(0), - } - } -} - -impl StreamControl { - pub fn take_tx_seq(&mut self) -> StreamSeq { - let tx_seq = self.next_tx_seq; - self.next_tx_seq = self.next_tx_seq.next(); - tx_seq - } - - pub fn send_window_has_space(&self) -> bool { - self.in_flight.accepts_seq(self.next_tx_seq) - } - - pub fn committed_rx_seq(&self) -> StreamSeq { - self.recv_buffer.base_seq().prev() - } - - pub fn queue_frame_back(&mut self, frame: StreamFrame) { - self.pending.push_back(frame); - } - - pub fn queue_frame_front(&mut self, frame: StreamFrame) { - self.pending.push_front(frame); - } - - pub fn note_ack(&mut self, immediate: bool) { - self.ack_dirty = true; - self.ack_immediate |= immediate; - } - - pub fn clear_ack_schedule(&mut self) { - self.ack_dirty = false; - self.ack_immediate = false; - self.ack_delay_token = None; - } - - pub fn maybe_force_ack_for_progress(&mut self) { - if !self.ack_dirty { - return; - } - let committed = self.committed_rx_seq(); - let progressed = self - .last_sent_ack_base - .forward_distance_to(committed) - .unwrap_or(0); - if progressed >= STREAM_ACK_EAGER_THRESHOLD { - self.ack_immediate = true; - } - } - - pub fn note_ack_sent(&mut self, ack: StreamAck) { - if ack.base.serial_gt(self.last_sent_ack_base) { - self.last_sent_ack_base = ack.base; - } - } - - pub fn current_ack(&self) -> StreamAck { - StreamAck { - base: self.committed_rx_seq(), - bitmap: self.recv_buffer.bitmap(), - } - } - - pub fn buffer_incoming( - &mut self, - tx_seq: StreamSeq, - frame: StreamFrame, - ) -> BufferIncomingResult { - if tx_seq.serial_lt(self.recv_buffer.base_seq()) { - return BufferIncomingResult::Duplicate; - } - if !self.recv_buffer.accepts_seq(tx_seq) { - return BufferIncomingResult::OutOfWindow; - } - if self.recv_buffer.contains_key(&tx_seq) { - return BufferIncomingResult::AlreadyBuffered; - } - - let out_of_order = tx_seq != self.recv_buffer.base_seq(); - let _ = self.recv_buffer.insert(tx_seq, frame); - BufferIncomingResult::Buffered { out_of_order } - } - - pub fn pop_next_committable(&mut self) -> Option<(StreamSeq, StreamFrame)> { - self.recv_buffer.take_front() - } - - pub fn insert_in_flight(&mut self, frame: InFlightFrame) { - let _ = self.in_flight.set(frame.tx_seq, frame); - } - - pub fn remove_in_flight(&mut self, tx_seq: StreamSeq) -> Option { - let removed = self.in_flight.remove(&tx_seq); - self.in_flight.advance_empty_front_until(self.next_tx_seq); - removed - } - - pub fn clear_transient_buffers(&mut self) { - self.pending.clear(); - self.in_flight.clear_with_base(self.next_tx_seq); - self.recv_buffer - .clear_with_base(self.committed_rx_seq().next()); - self.clear_ack_schedule(); - } - - pub fn ack_covers(ack: StreamAck, tx_seq: StreamSeq) -> bool { - if tx_seq.serial_lte(ack.base) { - return true; - } - let Some(delta) = ack.base.forward_distance_to(tx_seq) else { - return false; - }; - if !(1..=STREAM_WINDOW_SIZE).contains(&delta) { - return false; - } - (ack.bitmap & (1u8 << (delta - 1))) != 0 - } -} - -#[derive(Debug)] -pub struct InitiatorStream { - pub meta: StreamMeta, - pub control: StreamControl, - pub request: OutboundState, - pub response: InboundState, - pub accept: InitiatorAccept, -} - -#[derive(Debug)] -pub enum ResponderResponse { - Pending, - Accepted { body: OutboundState }, - Rejecting, -} - -#[derive(Debug)] -pub struct ResponderStream { - pub meta: StreamMeta, - pub control: StreamControl, - pub request: InboundState, - pub response: ResponderResponse, -} - -#[derive(Debug)] -pub struct ProvisionalStream { - pub meta: StreamMeta, - pub control: StreamControl, - pub timeout_token: Token, -} - -#[derive(Debug)] -pub enum StreamState { - Initiator(InitiatorStream), - Responder(ResponderStream), - Provisional(ProvisionalStream), -} - -impl StreamState { - pub fn stream_id(&self) -> StreamId { - match self { - Self::Initiator(state) => state.meta.stream_id, - Self::Responder(state) => state.meta.stream_id, - Self::Provisional(state) => state.meta.stream_id, - } - } - - pub fn last_activity_mut(&mut self) -> &mut Instant { - match self { - Self::Initiator(state) => &mut state.meta.last_activity, - Self::Responder(state) => &mut state.meta.last_activity, - Self::Provisional(state) => &mut state.meta.last_activity, - } - } - - pub fn control(&self) -> &StreamControl { - match self { - Self::Initiator(state) => &state.control, - Self::Responder(state) => &state.control, - Self::Provisional(state) => &state.control, - } - } - - pub fn control_mut(&mut self) -> &mut StreamControl { - match self { - Self::Initiator(state) => &mut state.control, - Self::Responder(state) => &mut state.control, - Self::Provisional(state) => &mut state.control, - } - } - - pub fn outbound_mut(&mut self, dir: Direction) -> Option<&mut OutboundState> { - match self { - Self::Initiator(state) if dir == Direction::Request => Some(&mut state.request), - Self::Responder(state) if dir == Direction::Response => match &mut state.response { - ResponderResponse::Accepted { body } => Some(body), - _ => None, - }, - _ => None, - } - } - - pub fn inbound_mut(&mut self, dir: Direction) -> Option<&mut InboundState> { - match self { - Self::Initiator(state) if dir == Direction::Response => Some(&mut state.response), - Self::Responder(state) if dir == Direction::Request => Some(&mut state.request), - _ => None, - } - } - - pub fn open_timeout_token(&self) -> Option { - match self { - Self::Initiator(state) => match &state.accept { - InitiatorAccept::Opening(waiter) | InitiatorAccept::WaitingAccept(waiter) => { - Some(waiter.open_timeout_token) - } - InitiatorAccept::Open { .. } => None, - }, - _ => None, - } - } - - pub fn provisional_timeout_token(&self) -> Option { - match self { - Self::Provisional(state) => Some(state.timeout_token), - _ => None, - } - } - - pub fn is_provisional(&self) -> bool { - matches!(self, Self::Provisional(_)) - } - - pub fn can_reap(&self) -> bool { - if !self.control().pending.is_empty() - || !self.control().in_flight.is_empty() - || !self.control().recv_buffer.is_empty() - || self.control().ack_dirty - || self.control().ack_outbound_token.is_some() - { - return false; - } - match self { - Self::Initiator(state) => { - matches!(state.accept, InitiatorAccept::Open { .. }) - && state.request.is_closed() - && state.response.closed - } - Self::Responder(state) => match &state.response { - ResponderResponse::Accepted { body } => state.request.closed && body.is_closed(), - ResponderResponse::Rejecting => true, - ResponderResponse::Pending => false, - }, - Self::Provisional(_) => false, - } - } -} - -#[derive(Debug)] -pub enum QueuedPayload { - PreEncoded(Vec), - Stream { body: StreamBody }, -} - -#[derive(Debug)] -pub struct QueuedWrite { - pub token: Token, - pub payload: QueuedPayload, -} - -pub fn reset_frame(stream_id: StreamId, target: ResetTarget, code: ResetCode) -> StreamFrame { - StreamFrame::Reset(StreamFrameReset { - stream_id, - target, - code, - }) -} diff --git a/ql2/src/engine/tests.rs b/ql2/src/engine/tests.rs deleted file mode 100644 index 494b1ae1..00000000 --- a/ql2/src/engine/tests.rs +++ /dev/null @@ -1,1360 +0,0 @@ -use std::{cell::Cell, mem, time::Instant}; - -use bc_components::{SymmetricKey, MLDSA, MLKEM}; - -use super::*; -use crate::{ - platform::{QlCrypto, QlIdentity}, - wire::{ - self, - stream::{ - BodyChunk, StreamAck, StreamAckBody, StreamBody, StreamFrame, StreamFrameAccept, - StreamFrameData, StreamFrameOpen, StreamMessage, - }, - QlHeader, QlPayload, - }, - PacketId, Peer, -}; - -struct TestCrypto { - identity: QlIdentity, - nonce_seed: u8, - nonce_counter: Cell, -} - -impl TestCrypto { - fn new(seed: u8) -> Self { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - Self { - identity: QlIdentity::from_keys( - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - ), - nonce_seed: seed, - nonce_counter: Cell::new(0), - } - } - - fn xid(&self) -> XID { - self.identity.xid - } - - fn peer(&self) -> Peer { - Peer { - peer: self.xid(), - signing_key: self.identity.signing_public_key.clone(), - encapsulation_key: self.identity.encapsulation_public_key.clone(), - } - } -} - -impl QlCrypto for TestCrypto { - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self.nonce_seed.wrapping_add(self.nonce_counter.get()); - self.nonce_counter - .set(self.nonce_counter.get().wrapping_add(1)); - data.fill(value); - } -} - -#[derive(Clone, Copy)] -enum Side { - A, - B, -} - -impl Side { - fn other(self) -> Self { - match self { - Side::A => Side::B, - Side::B => Side::A, - } - } -} - -struct Harness { - now: Instant, - a: Engine, - b: Engine, - crypto_a: TestCrypto, - crypto_b: TestCrypto, - outputs_a: Vec, - outputs_b: Vec, -} - -fn run_engine( - engine: &mut Engine, - now: Instant, - input: EngineInput, - crypto: &TestCrypto, -) -> Vec { - let mut outputs = Vec::new(); - engine.run_tick(now, input, crypto, &mut |output| outputs.push(output)); - outputs -} - -fn take_single_write(outputs: &[EngineOutput]) -> (Token, Option, Vec) { - let writes: Vec<_> = outputs - .iter() - .filter_map(|output| match output { - EngineOutput::WriteMessage { - token, - tracked, - bytes, - } => Some((*token, *tracked, bytes.clone())), - _ => None, - }) - .collect(); - assert_eq!(writes.len(), 1); - writes.into_iter().next().unwrap() -} - -fn decode_stream_body(bytes: &[u8], session_key: &SymmetricKey) -> (QlHeader, StreamBody) { - let record = wire::decode_record(bytes).unwrap(); - let aad = record.header.aad(); - let QlPayload::Stream(encrypted) = record.payload else { - panic!("expected stream payload"); - }; - let plaintext = encrypted.decrypt(session_key, &aad).unwrap(); - let body = wire::access_value::(&plaintext) - .and_then(wire::deserialize_value) - .unwrap(); - (record.header, body) -} - -fn connected_engine(local: &TestCrypto, peer: Peer, session_key: SymmetricKey) -> Engine { - let mut engine = Engine::new(EngineConfig::default(), local.identity.clone(), Some(peer)); - engine.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - }; - engine -} - -impl Harness { - fn new(config: EngineConfig) -> Self { - let crypto_a = TestCrypto::new(1); - let crypto_b = TestCrypto::new(2); - let peer_a = crypto_a.peer(); - let peer_b = crypto_b.peer(); - let session_key = SymmetricKey::from_data([7; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); - let mut b = Engine::new(config, crypto_b.identity.clone(), Some(peer_a)); - a.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key: session_key.clone(), - keepalive: KeepAliveState::default(), - }; - b.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - }; - Self { - now: Instant::now(), - a, - b, - crypto_a, - crypto_b, - outputs_a: Vec::new(), - outputs_b: Vec::new(), - } - } - - fn send_a(&mut self, input: EngineInput) { - self.run_side(Side::A, input); - } - - fn send_b(&mut self, input: EngineInput) { - self.run_side(Side::B, input); - } - - fn drain_a(&mut self) -> Vec { - mem::take(&mut self.outputs_a) - } - - fn drain_b(&mut self) -> Vec { - mem::take(&mut self.outputs_b) - } - - fn run_side(&mut self, side: Side, input: EngineInput) { - let mut outputs = Vec::new(); - match side { - Side::A => self - .a - .run_tick(self.now, input, &self.crypto_a, &mut |output| { - outputs.push(output) - }), - Side::B => self - .b - .run_tick(self.now, input, &self.crypto_b, &mut |output| { - outputs.push(output) - }), - } - - let writes: Vec<(Token, Option, Vec)> = outputs - .iter() - .filter_map(|output| match output { - EngineOutput::WriteMessage { - token, - tracked, - bytes, - } => Some((*token, *tracked, bytes.clone())), - _ => None, - }) - .collect(); - - match side { - Side::A => self.outputs_a.extend(outputs), - Side::B => self.outputs_b.extend(outputs), - } - - for (token, tracked, bytes) in writes { - self.run_side( - side, - EngineInput::WriteCompleted { - token, - tracked, - result: Ok(()), - }, - ); - self.run_side(side.other(), EngineInput::Incoming(bytes)); - } - } -} - -#[test] -fn open_prefix_is_delivered_on_setup_output() { - let mut harness = Harness::new(EngineConfig::default()); - let request_prefix = BodyChunk { - bytes: b"req".to_vec(), - fin: true, - }; - - harness.send_a(EngineInput::OpenStream { - open_id: OpenId(1), - request_head: b"open-head".to_vec(), - request_prefix: Some(request_prefix.clone()), - config: StreamConfig::default(), - }); - - harness.now += EngineConfig::default().stream_ack_delay; - harness.send_b(EngineInput::TimerExpired); - - let outputs_a = harness.drain_a(); - let outputs_b = harness.drain_b(); - let stream_id = outputs_a - .iter() - .find_map(|output| match output { - EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), - _ => None, - }) - .unwrap(); - - assert!(outputs_a.iter().any(|output| matches!( - output, - EngineOutput::OpenStarted { - open_id: OpenId(1), - stream_id: id, - } if *id == stream_id - ))); - assert!( - StreamNamespace::for_local(harness.crypto_a.xid(), harness.crypto_b.xid()) - .matches(stream_id) - ); - assert!(outputs_a.iter().any(|output| matches!( - output, - EngineOutput::OutboundClosed { - stream_id: id, - dir: Direction::Request, - } if *id == stream_id - ))); - - let opened = outputs_b.iter().find_map(|output| match output { - EngineOutput::InboundStreamOpened { - stream_id, - request_head, - request_prefix, - } => Some((*stream_id, request_head.clone(), request_prefix.clone())), - _ => None, - }); - assert_eq!( - opened, - Some(( - stream_id, - b"open-head".to_vec(), - Some(request_prefix.clone()), - )) - ); - assert!(!outputs_b - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - assert!(!outputs_b - .iter() - .any(|output| matches!(output, EngineOutput::InboundFinished { .. }))); -} - -#[test] -fn unary_exchange_uses_open_and_accept_prefixes() { - let mut harness = Harness::new(EngineConfig::default()); - let request_prefix = BodyChunk { - bytes: b"req".to_vec(), - fin: true, - }; - let response_prefix = BodyChunk { - bytes: b"resp".to_vec(), - fin: true, - }; - - harness.send_a(EngineInput::OpenStream { - open_id: OpenId(7), - request_head: b"request-head".to_vec(), - request_prefix: Some(request_prefix.clone()), - config: StreamConfig::default(), - }); - - let outputs_a_open = harness.drain_a(); - let outputs_b = harness.drain_b(); - let started_stream_id = outputs_a_open - .iter() - .find_map(|output| match output { - EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), - _ => None, - }) - .unwrap(); - let stream_id = outputs_b - .iter() - .find_map(|output| match output { - EngineOutput::InboundStreamOpened { stream_id, .. } => Some(*stream_id), - _ => None, - }) - .unwrap(); - assert_eq!(stream_id, started_stream_id); - - harness.send_b(EngineInput::AcceptStream { - stream_id, - response_head: b"response-head".to_vec(), - response_prefix: Some(response_prefix.clone()), - }); - - harness.now += EngineConfig::default().stream_ack_delay; - harness.send_a(EngineInput::TimerExpired); - - let outputs_a = harness.drain_a(); - let outputs_b = harness.drain_b(); - - let accepted = outputs_a.iter().find_map(|output| match output { - EngineOutput::OpenAccepted { - open_id, - stream_id, - response_head, - response_prefix, - } => Some(( - *open_id, - *stream_id, - response_head.clone(), - response_prefix.clone(), - )), - _ => None, - }); - assert_eq!( - accepted, - Some(( - OpenId(7), - stream_id, - b"response-head".to_vec(), - Some(response_prefix.clone()), - )) - ); - assert!(!outputs_a - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - assert!(!outputs_a - .iter() - .any(|output| matches!(output, EngineOutput::InboundFinished { .. }))); - assert!(outputs_b.iter().any(|output| matches!( - output, - EngineOutput::OutboundClosed { - stream_id: id, - dir: Direction::Response, - } if *id == stream_id - ))); -} - -#[test] -fn simultaneous_opens_use_disjoint_stream_id_namespaces() { - let config = EngineConfig::default(); - let crypto_a = TestCrypto::new(11); - let crypto_b = TestCrypto::new(22); - let peer_a = crypto_a.peer(); - let peer_b = crypto_b.peer(); - let session_key = SymmetricKey::from_data([9; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); - let mut b = Engine::new(config, crypto_b.identity.clone(), Some(peer_a)); - a.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key: session_key.clone(), - keepalive: KeepAliveState::default(), - }; - b.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - }; - let now = Instant::now(); - - let outputs_a_open = run_engine( - &mut a, - now, - EngineInput::OpenStream { - open_id: OpenId(1), - request_head: b"a-open".to_vec(), - request_prefix: None, - config: StreamConfig::default(), - }, - &crypto_a, - ); - let outputs_b_open = run_engine( - &mut b, - now, - EngineInput::OpenStream { - open_id: OpenId(2), - request_head: b"b-open".to_vec(), - request_prefix: None, - config: StreamConfig::default(), - }, - &crypto_b, - ); - - let stream_id_a = outputs_a_open - .iter() - .find_map(|output| match output { - EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), - _ => None, - }) - .unwrap(); - let stream_id_b = outputs_b_open - .iter() - .find_map(|output| match output { - EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), - _ => None, - }) - .unwrap(); - - assert_ne!(stream_id_a, stream_id_b); - assert!(StreamNamespace::for_local(crypto_a.xid(), crypto_b.xid()).matches(stream_id_a)); - assert!(StreamNamespace::for_local(crypto_b.xid(), crypto_a.xid()).matches(stream_id_b)); - - let (token_a, tracked_a, bytes_a) = take_single_write(&outputs_a_open); - let (token_b, tracked_b, bytes_b) = take_single_write(&outputs_b_open); - - let _ = run_engine( - &mut a, - now, - EngineInput::WriteCompleted { - token: token_a, - tracked: tracked_a, - result: Ok(()), - }, - &crypto_a, - ); - let _ = run_engine( - &mut b, - now, - EngineInput::WriteCompleted { - token: token_b, - tracked: tracked_b, - result: Ok(()), - }, - &crypto_b, - ); - - let outputs_a_incoming = run_engine(&mut a, now, EngineInput::Incoming(bytes_b), &crypto_a); - let outputs_b_incoming = run_engine(&mut b, now, EngineInput::Incoming(bytes_a), &crypto_b); - - assert!(outputs_a_incoming.iter().any(|output| matches!( - output, - EngineOutput::InboundStreamOpened { - stream_id, - request_head, - .. - } if *stream_id == stream_id_b && request_head == b"b-open" - ))); - assert!(outputs_b_incoming.iter().any(|output| matches!( - output, - EngineOutput::InboundStreamOpened { - stream_id, - request_head, - .. - } if *stream_id == stream_id_a && request_head == b"a-open" - ))); - assert_eq!(a.streams.len(), 2); - assert_eq!(b.streams.len(), 2); -} - -#[test] -fn invalid_future_frame_does_not_ack_outstanding_open() { - let config = EngineConfig::default(); - let crypto_a = TestCrypto::new(31); - let crypto_b = TestCrypto::new(32); - let peer_a = crypto_a.peer(); - let peer_b = crypto_b.peer(); - let session_key = SymmetricKey::from_data([5; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); - let mut _b = Engine::new(config, crypto_b.identity.clone(), Some(peer_a)); - a.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key: session_key.clone(), - keepalive: KeepAliveState::default(), - }; - - let now = Instant::now(); - let outputs_open = run_engine( - &mut a, - now, - EngineInput::OpenStream { - open_id: OpenId(9), - request_head: b"open".to_vec(), - request_prefix: None, - config: StreamConfig::default(), - }, - &crypto_a, - ); - let stream_id = outputs_open - .iter() - .find_map(|output| match output { - EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), - _ => None, - }) - .unwrap(); - - let message = StreamMessage { - tx_seq: StreamSeq(2), - ack: Some(crate::wire::stream::StreamAck { - base: StreamSeq(0), - bitmap: 0, - }), - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Accept(StreamFrameAccept { - stream_id, - response_head: Vec::new(), - response_prefix: None, - }), - }; - - let body = StreamBody::Message(message); - let record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &body, - [9; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_incoming = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&record)), - &crypto_a, - ); - - assert!(!outputs_incoming - .iter() - .any(|output| matches!(output, EngineOutput::OpenAccepted { .. }))); - - let stream = a.streams.get(&stream_id).unwrap(); - assert!(stream.control().in_flight.contains_key(&StreamSeq::START)); - match stream { - StreamState::Initiator(state) => { - assert!(matches!(state.accept, InitiatorAccept::Opening(_))); - } - _ => panic!("expected initiator stream"), - } -} - -#[test] -fn out_of_order_remote_stream_buffers_until_open_arrives() { - let config = EngineConfig::default(); - let crypto_a = TestCrypto::new(41); - let crypto_b = TestCrypto::new(42); - let peer_b = crypto_b.peer(); - let session_key = SymmetricKey::from_data([6; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); - a.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key: session_key.clone(), - keepalive: KeepAliveState::default(), - }; - - let now = Instant::now(); - let stream_id = StreamId(StreamNamespace::for_local(crypto_b.xid(), crypto_a.xid()).bit() | 1); - - let data_message = StreamMessage { - tx_seq: StreamSeq(2), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - dir: Direction::Request, - chunk: BodyChunk { - bytes: b"hello".to_vec(), - fin: false, - }, - }), - }; - let data_body = StreamBody::Message(data_message); - let data_record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &data_body, - [11; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_data = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&data_record)), - &crypto_a, - ); - - assert!(!outputs_data - .iter() - .any(|output| matches!(output, EngineOutput::InboundStreamOpened { .. }))); - assert!(!outputs_data - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - assert!(outputs_data - .iter() - .any(|output| matches!(output, EngineOutput::WriteMessage { .. }))); - assert!(matches!( - a.streams.get(&stream_id), - Some(StreamState::Provisional(_)) - )); - - let open_message = StreamMessage { - tx_seq: StreamSeq(1), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(crate::wire::stream::StreamFrameOpen { - stream_id, - request_head: b"late-open".to_vec(), - request_prefix: None, - }), - }; - let open_body = StreamBody::Message(open_message); - let open_record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &open_body, - [12; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_open = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&open_record)), - &crypto_a, - ); - - assert!(outputs_open.iter().any(|output| matches!( - output, - EngineOutput::InboundStreamOpened { - stream_id: id, - request_head, - request_prefix: None, - } if *id == stream_id && request_head == b"late-open" - ))); - assert!(outputs_open.iter().any(|output| matches!( - output, - EngineOutput::InboundData { - stream_id: id, - dir: Direction::Request, - bytes, - } if *id == stream_id && bytes == b"hello" - ))); -} - -#[test] -fn out_of_order_response_data_waits_for_accept() { - let config = EngineConfig::default(); - let crypto_a = TestCrypto::new(51); - let crypto_b = TestCrypto::new(52); - let peer_b = crypto_b.peer(); - let session_key = SymmetricKey::from_data([4; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); - a.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key: session_key.clone(), - keepalive: KeepAliveState::default(), - }; - - let now = Instant::now(); - let outputs_open = run_engine( - &mut a, - now, - EngineInput::OpenStream { - open_id: OpenId(12), - request_head: b"req".to_vec(), - request_prefix: None, - config: StreamConfig::default(), - }, - &crypto_a, - ); - let stream_id = outputs_open - .iter() - .find_map(|output| match output { - EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), - _ => None, - }) - .unwrap(); - - let data_message = StreamMessage { - tx_seq: StreamSeq(2), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - dir: Direction::Response, - chunk: BodyChunk { - bytes: b"resp".to_vec(), - fin: false, - }, - }), - }; - let data_body = StreamBody::Message(data_message); - let data_record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &data_body, - [21; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_data = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&data_record)), - &crypto_a, - ); - assert!(!outputs_data - .iter() - .any(|output| matches!(output, EngineOutput::OpenAccepted { .. }))); - assert!(!outputs_data - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - - let accept_message = StreamMessage { - tx_seq: StreamSeq(1), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Accept(StreamFrameAccept { - stream_id, - response_head: b"resp-head".to_vec(), - response_prefix: None, - }), - }; - let accept_body = StreamBody::Message(accept_message); - let accept_record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &accept_body, - [22; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_accept = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&accept_record)), - &crypto_a, - ); - - assert!(outputs_accept.iter().any(|output| matches!( - output, - EngineOutput::OpenAccepted { - open_id: OpenId(12), - stream_id: id, - response_head, - response_prefix: None, - } if *id == stream_id && response_head == b"resp-head" - ))); - assert!(outputs_accept.iter().any(|output| matches!( - output, - EngineOutput::InboundData { - stream_id: id, - dir: Direction::Response, - bytes, - } if *id == stream_id && bytes == b"resp" - ))); -} - -#[test] -fn delayed_ack_only_does_not_consume_sequence_space() { - let mut harness = Harness::new(EngineConfig::default()); - - harness.send_a(EngineInput::OpenStream { - open_id: OpenId(21), - request_head: b"open-head".to_vec(), - request_prefix: None, - config: StreamConfig::default(), - }); - - let outputs_a = harness.drain_a(); - let _outputs_b = harness.drain_b(); - let stream_id = outputs_a - .iter() - .find_map(|output| match output { - EngineOutput::OpenStarted { stream_id, .. } => Some(*stream_id), - _ => None, - }) - .unwrap(); - - harness.now += EngineConfig::default().stream_ack_delay; - harness.send_b(EngineInput::TimerExpired); - - let outputs_b = harness.drain_b(); - assert!(outputs_b - .iter() - .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); - - let stream = harness.b.streams.get(&stream_id).unwrap(); - assert!(stream.control().in_flight.is_empty()); - assert_eq!(stream.control().next_tx_seq, StreamSeq::START); -} - -#[test] -fn half_window_progress_flushes_ack_before_timer() { - let config = EngineConfig::default(); - let crypto_a = TestCrypto::new(61); - let crypto_b = TestCrypto::new(62); - let peer_b = crypto_b.peer(); - let session_key = SymmetricKey::from_data([8; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b)); - a.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key: session_key.clone(), - keepalive: KeepAliveState::default(), - }; - - let now = Instant::now(); - let stream_id = StreamId(StreamNamespace::for_local(crypto_b.xid(), crypto_a.xid()).bit() | 1); - let messages = [ - StreamMessage { - tx_seq: StreamSeq(1), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(crate::wire::stream::StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - }, - StreamMessage { - tx_seq: StreamSeq(2), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - dir: Direction::Request, - chunk: BodyChunk { - bytes: b"a".to_vec(), - fin: false, - }, - }), - }, - StreamMessage { - tx_seq: StreamSeq(3), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - dir: Direction::Request, - chunk: BodyChunk { - bytes: b"b".to_vec(), - fin: false, - }, - }), - }, - StreamMessage { - tx_seq: StreamSeq(4), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - dir: Direction::Request, - chunk: BodyChunk { - bytes: b"c".to_vec(), - fin: false, - }, - }), - }, - ]; - - for message in messages.iter().take(3) { - let body = StreamBody::Message(message.clone()); - let record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &body, - [message.tx_seq.0 as u8; wire::encrypted_message::NONCE_SIZE], - ); - let outputs = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&record)), - &crypto_a, - ); - assert!(!outputs - .iter() - .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); - } - - let body = StreamBody::Message(messages[3].clone()); - let record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &body, - [4; wire::encrypted_message::NONCE_SIZE], - ); - let outputs = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&record)), - &crypto_a, - ); - - assert!(outputs - .iter() - .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); -} - -#[test] -fn out_of_order_loss_reports_selective_ack_bitmap() { - let crypto_a = TestCrypto::new(71); - let crypto_b = TestCrypto::new(72); - let session_key = SymmetricKey::from_data([3; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let peer_b = crypto_b.peer(); - let mut a = connected_engine(&crypto_a, peer_b, session_key.clone()); - - let now = Instant::now(); - let stream_id = StreamId(StreamNamespace::for_local(crypto_b.xid(), crypto_a.xid()).bit() | 1); - let messages = [ - StreamMessage { - tx_seq: StreamSeq(1), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - }, - StreamMessage { - tx_seq: StreamSeq(2), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - dir: Direction::Request, - chunk: BodyChunk { - bytes: b"a".to_vec(), - fin: false, - }, - }), - }, - StreamMessage { - tx_seq: StreamSeq(4), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - dir: Direction::Request, - chunk: BodyChunk { - bytes: b"c".to_vec(), - fin: false, - }, - }), - }, - StreamMessage { - tx_seq: StreamSeq(5), - ack: None, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - dir: Direction::Request, - chunk: BodyChunk { - bytes: b"d".to_vec(), - fin: false, - }, - }), - }, - ]; - - for message in &messages[..2] { - let record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &StreamBody::Message(message.clone()), - [message.tx_seq.0 as u8; wire::encrypted_message::NONCE_SIZE], - ); - let outputs = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&record)), - &crypto_a, - ); - assert!(!outputs - .iter() - .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); - } - - let record4 = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &StreamBody::Message(messages[2].clone()), - [4; wire::encrypted_message::NONCE_SIZE], - ); - let outputs4 = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&record4)), - &crypto_a, - ); - let (ack_token4, ack_tracked4, ack_bytes4) = take_single_write(&outputs4); - assert_eq!(ack_tracked4, None); - let (_, ack_body4) = decode_stream_body(&ack_bytes4, &session_key); - assert!(matches!( - ack_body4, - StreamBody::Ack(StreamAckBody { - stream_id: id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0010, - }, - .. - }) if id == stream_id - )); - assert!(!outputs4 - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - // the engine only starts a new outbound write after the previous one reports - // `WriteCompleted`. We need to retire the ACK-only write for seq 4 here so the - // follow-up out-of-order receive for seq 5 can emit its own updated ACK body. - let _ = run_engine( - &mut a, - now, - EngineInput::WriteCompleted { - token: ack_token4, - tracked: ack_tracked4, - result: Ok(()), - }, - &crypto_a, - ); - - let record5 = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &StreamBody::Message(messages[3].clone()), - [5; wire::encrypted_message::NONCE_SIZE], - ); - let outputs5 = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&record5)), - &crypto_a, - ); - let (_, _, ack_bytes5) = take_single_write(&outputs5); - let (_, ack_body5) = decode_stream_body(&ack_bytes5, &session_key); - assert!(matches!( - ack_body5, - StreamBody::Ack(StreamAckBody { - stream_id: id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0110, - }, - .. - }) if id == stream_id - )); - assert!(!outputs5 - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); -} - -#[test] -fn selective_ack_only_body_retires_acked_gap_tail() { - let crypto_a = TestCrypto::new(81); - let crypto_b = TestCrypto::new(82); - let session_key = SymmetricKey::from_data([2; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let peer_b = crypto_b.peer(); - let mut a = connected_engine(&crypto_a, peer_b, session_key.clone()); - - let now = Instant::now(); - let stream_id = a - .state - .next_stream_id(StreamNamespace::for_local(crypto_a.xid(), crypto_b.xid())); - let mut stream = StreamState::Initiator(InitiatorStream { - meta: StreamMeta { - stream_id, - last_activity: now, - }, - control: StreamControl::default(), - request: OutboundState::from_prefix(Direction::Request, false), - response: InboundState::new(), - accept: InitiatorAccept::Opening(OpenWaiter { - open_id: Some(OpenId(1)), - open_timeout_token: Token(999), - }), - }); - let control = stream.control_mut(); - control.next_tx_seq = StreamSeq(6); - control.insert_in_flight(InFlightFrame { - tx_seq: StreamSeq(1), - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - attempt: 0, - }); - for (tx_seq, byte) in [(2, b'a'), (3, b'b'), (4, b'c'), (5, b'd')] { - control.insert_in_flight(InFlightFrame { - tx_seq: StreamSeq(tx_seq), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - dir: Direction::Request, - chunk: BodyChunk { - bytes: vec![byte], - fin: false, - }, - }), - attempt: 0, - }); - } - a.streams.insert(stream_id, stream); - - let ack_record = wire::stream::encrypt_stream( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - &StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0110, - }, - valid_until: wire::now_secs().saturating_add(60), - }), - [9; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs = run_engine( - &mut a, - now, - EngineInput::Incoming(wire::encode_record(&ack_record)), - &crypto_a, - ); - - assert!(!outputs - .iter() - .any(|output| matches!(output, EngineOutput::OutboundFailed { .. }))); - let stream = a.streams.get(&stream_id).unwrap(); - let remaining: Vec<_> = stream - .control() - .in_flight - .iter() - .map(|(seq, _)| seq) - .collect(); - assert_eq!(remaining, vec![StreamSeq(3)]); - assert_eq!(stream.control().next_tx_seq, StreamSeq(6)); -} - -#[test] -fn timeout_retransmit_reuses_original_tx_seq_and_slot() { - let config = EngineConfig::default(); - let crypto_a = TestCrypto::new(91); - let crypto_b = TestCrypto::new(92); - let peer_b = crypto_b.peer(); - let session_key = SymmetricKey::from_data([1; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = connected_engine(&crypto_a, peer_b, session_key.clone()); - - let now = Instant::now(); - let outputs_open = run_engine( - &mut a, - now, - EngineInput::OpenStream { - open_id: OpenId(44), - request_head: b"open".to_vec(), - request_prefix: None, - config: StreamConfig::default(), - }, - &crypto_a, - ); - let (token, tracked, bytes) = take_single_write(&outputs_open); - let tracked = tracked.unwrap(); - let (_, initial_body) = decode_stream_body(&bytes, &session_key); - assert!(matches!( - initial_body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(1), - frame: StreamFrame::Open(_), - .. - }) - )); - - let _outputs_written = run_engine( - &mut a, - now, - EngineInput::WriteCompleted { - token, - tracked: Some(tracked), - result: Ok(()), - }, - &crypto_a, - ); - - let stream = a.streams.get(&tracked.stream_id).unwrap(); - assert_eq!(stream.control().in_flight.len(), 1); - assert!(stream.control().in_flight.contains_key(&StreamSeq::START)); - assert_eq!(stream.control().next_tx_seq, StreamSeq(2)); - - let outputs_timeout = run_engine( - &mut a, - now + config.stream_ack_timeout, - EngineInput::TimerExpired, - &crypto_a, - ); - let (_, retransmit_tracked, retransmit_bytes) = take_single_write(&outputs_timeout); - assert_eq!(retransmit_tracked, Some(tracked)); - let (_, retransmit_body) = decode_stream_body(&retransmit_bytes, &session_key); - assert!(matches!( - retransmit_body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(1), - frame: StreamFrame::Open(_), - .. - }) - )); - - let stream = a.streams.get(&tracked.stream_id).unwrap(); - assert_eq!(stream.control().in_flight.len(), 1); - assert!(stream.control().in_flight.contains_key(&StreamSeq::START)); - assert_eq!(stream.control().next_tx_seq, StreamSeq(2)); - assert_eq!( - stream - .control() - .in_flight - .get(&StreamSeq::START) - .unwrap() - .attempt, - 1 - ); -} - -#[test] -fn replayed_heartbeat_is_ignored() { - let crypto_a = TestCrypto::new(101); - let crypto_b = TestCrypto::new(102); - let session_key = SymmetricKey::from_data([4; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let peer_b = crypto_b.peer(); - let mut a = connected_engine(&crypto_a, peer_b, session_key.clone()); - let now = Instant::now(); - let heartbeat = wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - &session_key, - wire::heartbeat::HeartbeatBody { - meta: wire::ControlMeta { - packet_id: PacketId(7), - valid_until: wire::now_secs().saturating_add(60), - }, - }, - [3; wire::encrypted_message::NONCE_SIZE], - ); - let bytes = wire::encode_record(&heartbeat); - - let first = run_engine(&mut a, now, EngineInput::Incoming(bytes.clone()), &crypto_a); - assert!(first - .iter() - .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); - - let second = run_engine(&mut a, now, EngineInput::Incoming(bytes), &crypto_a); - assert!(!second - .iter() - .any(|output| matches!(output, EngineOutput::WriteMessage { tracked: None, .. }))); -} - -#[test] -fn replayed_unpair_is_ignored_after_rebind() { - let config = EngineConfig::default(); - let crypto_a = TestCrypto::new(111); - let crypto_b = TestCrypto::new(112); - let peer_b = crypto_b.peer(); - let session_key = SymmetricKey::from_data([5; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = Engine::new(config, crypto_a.identity.clone(), Some(peer_b.clone())); - a.state.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - }; - let now = Instant::now(); - let bytes = wire::encode_record(&wire::unpair::build_unpair_record( - &crypto_b.identity, - QlHeader { - sender: crypto_b.xid(), - recipient: crypto_a.xid(), - }, - wire::ControlMeta { - packet_id: PacketId(9), - valid_until: wire::now_secs().saturating_add(60), - }, - )); - - let first = run_engine(&mut a, now, EngineInput::Incoming(bytes.clone()), &crypto_a); - assert!(first - .iter() - .any(|output| matches!(output, EngineOutput::ClearPeer))); - assert!(a.state.peer.is_none()); - - let _ = run_engine( - &mut a, - now, - EngineInput::BindPeer(peer_b.clone()), - &crypto_a, - ); - assert!(a.state.peer.is_some()); - - let second = run_engine(&mut a, now, EngineInput::Incoming(bytes), &crypto_a); - assert!(!second - .iter() - .any(|output| matches!(output, EngineOutput::ClearPeer))); - assert_eq!( - a.state.peer.as_ref().map(|peer| peer.peer), - Some(peer_b.peer) - ); -} diff --git a/ql2/src/runtime/command.rs b/ql2/src/runtime/command.rs deleted file mode 100644 index e9eafd33..00000000 --- a/ql2/src/runtime/command.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::{ - runtime::{AcceptedStreamDelivery, StreamConfig}, - wire::stream::{Direction, RejectCode, ResetCode}, - Peer, QlError, StreamId, -}; - -pub(crate) enum RuntimeCommand { - BindPeer { - peer: Peer, - }, - Pair, - Connect, - Unpair, - OpenStream { - request_head: Vec, - request_rx: async_channel::Receiver>, - accepted: oneshot::Sender>, - start: oneshot::Sender>, - config: StreamConfig, - }, - AcceptStream { - stream_id: StreamId, - response_head: Vec, - response_rx: async_channel::Receiver>, - }, - RejectStream { - stream_id: StreamId, - code: RejectCode, - }, - PollStream { - stream_id: StreamId, - }, - ResetOutbound { - stream_id: StreamId, - dir: Direction, - code: ResetCode, - }, - ResetInbound { - stream_id: StreamId, - dir: Direction, - code: ResetCode, - }, - ResponderDropped { - stream_id: StreamId, - }, - PendingAcceptDropped { - stream_id: StreamId, - }, - Incoming(Vec), -} diff --git a/ql2/src/runtime/driver.rs b/ql2/src/runtime/driver.rs deleted file mode 100644 index 285b7709..00000000 --- a/ql2/src/runtime/driver.rs +++ /dev/null @@ -1,721 +0,0 @@ -use std::{ - collections::{HashMap, VecDeque}, - future::Future, - task::Poll, - time::Instant, -}; - -use futures_lite::future::poll_fn; - -use crate::{ - engine::{self, Engine, EngineInput, EngineOutput, OpenId}, - platform::{PlatformFuture, QlPlatform}, - runtime::{ - command::RuntimeCommand, - handle::{InboundByteStream, InboundStream, StreamResponder}, - AcceptedStreamDelivery, HandlerEvent, InboundEvent, Runtime, - }, - wire::stream::{BodyChunk, Direction, ResetCode}, - QlError, StreamId, -}; - -struct InFlightWrite<'a> { - token: engine::Token, - tracked: Option, - future: PlatformFuture<'a, Result<(), QlError>>, -} - -enum DriverEvent { - Command(RuntimeCommand), - WriteCompleted { - token: engine::Token, - tracked: Option, - result: Result<(), QlError>, - }, - TimerExpired, - Closed, -} - -struct PendingOpen { - request_rx: async_channel::Receiver>, - start_tx: oneshot::Sender>, - accepted_tx: oneshot::Sender>, -} - -struct PendingAcceptDelivery { - tx: oneshot::Sender>, - response_rx: async_channel::Receiver, -} - -enum OutboundIo { - Open { - dir: Direction, - rx: async_channel::Receiver>, - pending_pull: bool, - finish_queued: bool, - }, - Closed, -} - -impl OutboundIo { - fn new(dir: Direction, rx: async_channel::Receiver>) -> Self { - Self::Open { - dir, - rx, - pending_pull: false, - finish_queued: false, - } - } - - fn set_pending_pull(&mut self) { - if let Self::Open { pending_pull, .. } = self { - *pending_pull = true; - } - } - - fn close(&mut self) { - *self = Self::Closed; - } - - fn poll_pending(&mut self, stream_id: StreamId, pending: &mut VecDeque) { - let Self::Open { - dir, - rx, - pending_pull, - finish_queued, - } = self - else { - return; - }; - - if !*pending_pull { - if rx.is_closed() && !*finish_queued { - *finish_queued = true; - pending.push_back(EngineInput::OutboundFinished { stream_id, dir: *dir }); - } - return; - } - - match rx.try_recv() { - Ok(bytes) => { - if bytes.is_empty() { - return; - } - *pending_pull = false; - pending.push_back(EngineInput::OutboundData { - stream_id, - dir: *dir, - bytes, - }); - if rx.is_closed() && rx.is_empty() && !*finish_queued { - *finish_queued = true; - pending.push_back(EngineInput::OutboundFinished { stream_id, dir: *dir }); - } - } - Err(async_channel::TryRecvError::Empty) => { - if rx.is_closed() && !*finish_queued { - *pending_pull = false; - *finish_queued = true; - pending.push_back(EngineInput::OutboundFinished { stream_id, dir: *dir }); - } - } - Err(async_channel::TryRecvError::Closed) => { - *pending_pull = false; - if !*finish_queued { - *finish_queued = true; - pending.push_back(EngineInput::OutboundFinished { stream_id, dir: *dir }); - } - } - } - } -} - -enum InboundIo { - Open(async_channel::Sender), - Closed, -} - -impl InboundIo { - fn new(tx: async_channel::Sender) -> Self { - Self::Open(tx) - } - - fn write_or_cancel( - &mut self, - stream_id: StreamId, - dir: Direction, - bytes: Vec, - pending: &mut VecDeque, - ) { - let Self::Open(tx) = self else { - pending.push_back(EngineInput::ResetInbound { - stream_id, - dir, - code: ResetCode::Cancelled, - }); - return; - }; - if tx.try_send(InboundEvent::Data(bytes)).is_err() { - tx.close(); - *self = Self::Closed; - pending.push_back(EngineInput::ResetInbound { - stream_id, - dir, - code: ResetCode::Cancelled, - }); - } - } - - fn finish(&mut self) { - if let Self::Open(tx) = self { - let _ = tx.try_send(InboundEvent::Finished); - tx.close(); - } - *self = Self::Closed; - } - - fn fail(&mut self, error: QlError) { - if let Self::Open(tx) = self { - let _ = tx.try_send(InboundEvent::Failed(error)); - tx.close(); - } - *self = Self::Closed; - } - - fn close(&mut self) { - if let Self::Open(tx) = self { - let _ = tx.try_send(InboundEvent::Failed(QlError::Cancelled)); - tx.close(); - } - *self = Self::Closed; - } - - fn apply_prefix( - &mut self, - stream_id: StreamId, - dir: Direction, - prefix: &BodyChunk, - pending: &mut VecDeque, - ) { - if !prefix.bytes.is_empty() { - self.write_or_cancel(stream_id, dir, prefix.bytes.clone(), pending); - } - if prefix.fin { - self.finish(); - } - } -} - -enum PendingAcceptState { - Waiting(PendingAcceptDelivery), - Dropped, - Resolved, -} - -enum ResponderResponseIo { - Pending, - Streaming(OutboundIo), - Rejected, -} - -enum DriverStreamIo { - Initiator { - request: OutboundIo, - response: InboundIo, - pending_accept: PendingAcceptState, - }, - Responder { - request: InboundIo, - response: ResponderResponseIo, - }, -} - -impl DriverStreamIo { - fn outbound_mut(&mut self, dir: Direction) -> Option<&mut OutboundIo> { - match self { - Self::Initiator { request, .. } if dir == Direction::Request => Some(request), - Self::Responder { - response: ResponderResponseIo::Streaming(outbound), - .. - } if dir == Direction::Response => Some(outbound), - _ => None, - } - } - - fn inbound_mut(&mut self, dir: Direction) -> Option<&mut InboundIo> { - match self { - Self::Initiator { response, .. } if dir == Direction::Response => Some(response), - Self::Responder { request, .. } if dir == Direction::Request => Some(request), - _ => None, - } - } - - fn close_all(&mut self) { - match self { - Self::Initiator { - request, - response, - pending_accept, - } => { - request.close(); - response.close(); - *pending_accept = PendingAcceptState::Resolved; - } - Self::Responder { request, response } => { - request.close(); - if let ResponderResponseIo::Streaming(outbound) = response { - outbound.close(); - } - *response = ResponderResponseIo::Rejected; - } - } - } -} - -struct DriverState { - engine: Engine, - pending_inputs: VecDeque, - next_timer: Option, - next_open_id: u64, - pending_opens: HashMap, - streams: HashMap, -} - -impl DriverState { - fn new( - config: engine::EngineConfig, - local_xid: bc_components::XID, - peer: Option, - ) -> Self { - let engine = Engine::new(config, local_xid, peer); - Self { - engine, - pending_inputs: VecDeque::new(), - next_timer: None, - next_open_id: 1, - pending_opens: HashMap::new(), - streams: HashMap::new(), - } - } - - fn push_input(&mut self, input: EngineInput) { - self.pending_inputs.push_back(input); - } - - fn translate_command(&mut self, command: RuntimeCommand) { - match command { - RuntimeCommand::BindPeer { peer } => self.push_input(EngineInput::BindPeer(peer)), - RuntimeCommand::Pair => self.push_input(EngineInput::Pair), - RuntimeCommand::Connect => self.push_input(EngineInput::Connect), - RuntimeCommand::Unpair => self.push_input(EngineInput::Unpair), - RuntimeCommand::Incoming(bytes) => self.push_input(EngineInput::Incoming(bytes)), - RuntimeCommand::OpenStream { - request_head, - request_rx, - accepted, - start, - config, - } => { - let open_id = OpenId(self.next_open_id); - self.next_open_id = self.next_open_id.wrapping_add(1); - self.pending_opens.insert( - open_id, - PendingOpen { - request_rx, - start_tx: start, - accepted_tx: accepted, - }, - ); - self.push_input(EngineInput::OpenStream { - open_id, - request_head, - request_prefix: None, - config, - }); - } - RuntimeCommand::AcceptStream { - stream_id, - response_head, - response_rx, - } => { - if let Some(DriverStreamIo::Responder { response, .. }) = - self.streams.get_mut(&stream_id) - { - *response = ResponderResponseIo::Streaming(OutboundIo::new( - Direction::Response, - response_rx, - )); - } - self.push_input(EngineInput::AcceptStream { - stream_id, - response_head, - response_prefix: None, - }); - } - RuntimeCommand::RejectStream { stream_id, code } => { - if let Some(DriverStreamIo::Responder { response, .. }) = - self.streams.get_mut(&stream_id) - { - *response = ResponderResponseIo::Rejected; - } - self.push_input(EngineInput::RejectStream { stream_id, code }); - } - RuntimeCommand::PollStream { stream_id } => self.poll_stream(stream_id), - RuntimeCommand::ResetOutbound { - stream_id, - dir, - code, - } => self.push_input(EngineInput::ResetOutbound { - stream_id, - dir, - code, - }), - RuntimeCommand::ResetInbound { - stream_id, - dir, - code, - } => self.push_input(EngineInput::ResetInbound { - stream_id, - dir, - code, - }), - RuntimeCommand::ResponderDropped { stream_id } => { - self.push_input(EngineInput::ResponderDropped { stream_id }); - } - RuntimeCommand::PendingAcceptDropped { stream_id } => { - if let Some(DriverStreamIo::Initiator { pending_accept, .. }) = - self.streams.get_mut(&stream_id) - { - if matches!(pending_accept, PendingAcceptState::Waiting(_)) { - *pending_accept = PendingAcceptState::Dropped; - } - } - self.push_input(EngineInput::PendingAcceptDropped { stream_id }); - self.push_input(EngineInput::ResetInbound { - stream_id, - dir: Direction::Response, - code: ResetCode::Cancelled, - }); - } - } - } - - fn poll_stream(&mut self, stream_id: StreamId) { - if let Some(stream) = self.streams.get_mut(&stream_id) { - match stream { - DriverStreamIo::Initiator { request, .. } => { - request.poll_pending(stream_id, &mut self.pending_inputs) - } - DriverStreamIo::Responder { response, .. } => { - if let ResponderResponseIo::Streaming(outbound) = response { - outbound.poll_pending(stream_id, &mut self.pending_inputs); - } - } - } - } - } -} - -impl Runtime

{ - pub async fn run(self) { - let runtime_tx = self.tx.upgrade().expect("runtime tx"); - let local_xid = self.platform.xid(); - let mut state = DriverState::new( - self.config.engine, - local_xid, - self.platform.load_peer().await, - ); - let mut in_flight: Option> = None; - - loop { - if let Some(input) = state.pending_inputs.pop_front() { - let now = Instant::now(); - let pending_inputs = &mut state.pending_inputs; - let next_timer = &mut state.next_timer; - let pending_opens = &mut state.pending_opens; - let streams = &mut state.streams; - state - .engine - .run_tick(now, input, &self.platform, &mut |output| { - self.apply_output( - pending_inputs, - next_timer, - pending_opens, - streams, - &runtime_tx, - &mut in_flight, - output, - ); - }); - continue; - } - - if self.rx.is_closed() { - break; - } - - match self - .next_driver_event(state.next_timer, in_flight.as_mut()) - .await - { - DriverEvent::Command(command) => state.translate_command(command), - DriverEvent::WriteCompleted { - token, - tracked, - result, - } => { - in_flight = None; - state.push_input(EngineInput::WriteCompleted { - token, - tracked, - result, - }); - } - DriverEvent::TimerExpired => state.push_input(EngineInput::TimerExpired), - DriverEvent::Closed => break, - } - } - } - - async fn next_driver_event<'a>( - &'a self, - next_timer: Option, - mut in_flight: Option<&mut InFlightWrite<'a>>, - ) -> DriverEvent { - let recv_future = self.rx.recv(); - futures_lite::pin!(recv_future); - - let mut sleep_future = next_timer.map(|deadline| { - let timeout = deadline.saturating_duration_since(Instant::now()); - self.platform.sleep(timeout) - }); - - poll_fn(|cx| { - if let Some(in_flight) = in_flight.as_mut() { - if let Poll::Ready(result) = in_flight.future.as_mut().poll(cx) { - return Poll::Ready(DriverEvent::WriteCompleted { - token: in_flight.token, - tracked: in_flight.tracked, - result, - }); - } - } - - if let Some(future) = sleep_future.as_mut() { - if let Poll::Ready(()) = future.as_mut().poll(cx) { - return Poll::Ready(DriverEvent::TimerExpired); - } - } - - recv_future.as_mut().poll(cx).map(|res| match res { - Ok(command) => DriverEvent::Command(command), - Err(_) => DriverEvent::Closed, - }) - }) - .await - } - - fn apply_output<'a>( - &'a self, - pending_inputs: &mut VecDeque, - next_timer: &mut Option, - pending_opens: &mut HashMap, - streams: &mut HashMap, - runtime_tx: &async_channel::Sender, - in_flight: &mut Option>, - output: EngineOutput, - ) { - match output { - EngineOutput::SetTimer(deadline) => *next_timer = deadline, - EngineOutput::WriteMessage { - token, - tracked, - bytes, - } => { - *in_flight = Some(InFlightWrite { - token, - tracked, - future: self.platform.write_message(bytes), - }); - } - EngineOutput::PeerStatusChanged { peer, session } => { - self.platform.handle_peer_status(peer, &session); - } - EngineOutput::PersistPeer(peer) => self.platform.persist_peer(peer), - EngineOutput::ClearPeer => self.platform.clear_peer(), - EngineOutput::OpenStarted { open_id, stream_id } => { - let Some(pending) = pending_opens.remove(&open_id) else { - return; - }; - let _ = pending.start_tx.send(Ok(stream_id)); - let (response_tx, response_rx) = async_channel::unbounded(); - streams.insert( - stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::new(Direction::Request, pending.request_rx), - response: InboundIo::new(response_tx), - pending_accept: PendingAcceptState::Waiting(PendingAcceptDelivery { - tx: pending.accepted_tx, - response_rx, - }), - }, - ); - } - EngineOutput::OpenAccepted { - stream_id, - response_head, - response_prefix, - .. - } => { - let Some(DriverStreamIo::Initiator { - response, - pending_accept, - .. - }) = streams.get_mut(&stream_id) - else { - return; - }; - if let Some(prefix) = response_prefix.as_ref() { - response.apply_prefix(stream_id, Direction::Response, prefix, pending_inputs); - } - match std::mem::replace(pending_accept, PendingAcceptState::Resolved) { - PendingAcceptState::Waiting(delivery) => { - let _ = delivery.tx.send(Ok(AcceptedStreamDelivery { - stream_id, - response_head, - response: delivery.response_rx, - tx: runtime_tx.clone(), - })); - } - PendingAcceptState::Dropped => { - *pending_accept = PendingAcceptState::Dropped; - } - PendingAcceptState::Resolved => {} - } - } - EngineOutput::OpenFailed { - open_id, - stream_id, - error, - } => { - if let Some(pending) = pending_opens.remove(&open_id) { - let _ = pending.start_tx.send(Err(error)); - return; - } - let Some(DriverStreamIo::Initiator { pending_accept, .. }) = - streams.get_mut(&stream_id) - else { - return; - }; - match std::mem::replace(pending_accept, PendingAcceptState::Resolved) { - PendingAcceptState::Waiting(delivery) => { - let _ = delivery.tx.send(Err(error)); - } - PendingAcceptState::Dropped => { - *pending_accept = PendingAcceptState::Dropped; - } - PendingAcceptState::Resolved => {} - } - } - EngineOutput::InboundStreamOpened { - stream_id, - request_head, - request_prefix, - } => { - let (request_tx, request_rx) = async_channel::unbounded(); - let mut request = InboundIo::new(request_tx); - if let Some(prefix) = request_prefix.as_ref() { - request.apply_prefix(stream_id, Direction::Request, prefix, pending_inputs); - } - streams.insert( - stream_id, - DriverStreamIo::Responder { - request, - response: ResponderResponseIo::Pending, - }, - ); - self.platform - .handle_inbound(HandlerEvent::Stream(InboundStream { - stream_id, - request_head, - request: InboundByteStream::new( - stream_id, - Direction::Request, - request_rx, - runtime_tx.clone(), - ), - respond_to: StreamResponder::new(stream_id, runtime_tx.clone()), - })); - } - EngineOutput::InboundData { - stream_id, - dir, - bytes, - } => { - if let Some(stream) = streams.get_mut(&stream_id) { - if let Some(inbound) = stream.inbound_mut(dir) { - inbound.write_or_cancel(stream_id, dir, bytes, pending_inputs); - } - } - } - EngineOutput::InboundFinished { stream_id, dir } => { - if let Some(stream) = streams.get_mut(&stream_id) { - if let Some(inbound) = stream.inbound_mut(dir) { - inbound.finish(); - } - } - } - EngineOutput::InboundFailed { - stream_id, - dir, - error, - } => { - if let Some(stream) = streams.get_mut(&stream_id) { - if let Some(inbound) = stream.inbound_mut(dir) { - inbound.fail(error); - } - } - } - EngineOutput::NeedOutboundData { stream_id, dir } => { - if let Some(stream) = streams.get_mut(&stream_id) { - if let Some(outbound) = stream.outbound_mut(dir) { - outbound.set_pending_pull(); - } - } - poll_stream(streams, pending_inputs, stream_id); - } - EngineOutput::OutboundClosed { stream_id, dir } - | EngineOutput::OutboundFailed { stream_id, dir, .. } => { - if let Some(stream) = streams.get_mut(&stream_id) { - if let Some(outbound) = stream.outbound_mut(dir) { - outbound.close(); - } - } - } - EngineOutput::StreamReaped { stream_id } => { - if let Some(mut stream) = streams.remove(&stream_id) { - stream.close_all(); - } - } - } - } -} - -fn poll_stream( - streams: &mut HashMap, - pending_inputs: &mut VecDeque, - stream_id: StreamId, -) { - if let Some(stream) = streams.get_mut(&stream_id) { - match stream { - DriverStreamIo::Initiator { request, .. } => { - request.poll_pending(stream_id, pending_inputs) - } - DriverStreamIo::Responder { response, .. } => { - if let ResponderResponseIo::Streaming(outbound) = response { - outbound.poll_pending(stream_id, pending_inputs); - } - } - } - } -} diff --git a/ql2/src/runtime/handle.rs b/ql2/src/runtime/handle.rs deleted file mode 100644 index 443d81ab..00000000 --- a/ql2/src/runtime/handle.rs +++ /dev/null @@ -1,388 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use async_channel::{Receiver, Sender}; - -use crate::{ - runtime::{command::RuntimeCommand, AcceptedStreamDelivery, InboundEvent, StreamConfig}, - wire::stream::{Direction, RejectCode, ResetCode}, - Peer, QlError, StreamId, -}; - -#[derive(Clone)] -pub struct RuntimeHandle { - pub(crate) tx: Sender, -} - -pub struct PendingStream { - pub request: OutboundByteStream, - pub accepted: PendingAccept, -} - -#[derive(Debug)] -pub struct AcceptedStream { - pub stream_id: StreamId, - pub response_head: Vec, - pub response: InboundByteStream, -} - -#[derive(Debug)] -pub struct InboundStream { - pub stream_id: StreamId, - pub request_head: Vec, - pub request: InboundByteStream, - pub respond_to: StreamResponder, -} - -#[derive(Debug)] -pub struct StreamResponder { - stream_id: StreamId, - tx: Sender, - armed: bool, -} - -pub struct InboundByteStream { - stream_id: StreamId, - dir: Direction, - rx: Receiver, - tx: Sender, - finished: bool, -} - -impl std::fmt::Debug for InboundByteStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InboundByteStream") - .field("stream_id", &self.stream_id) - .field("dir", &self.dir) - .field("finished", &self.finished) - .finish_non_exhaustive() - } -} - -pub struct OutboundByteStream { - stream_id: StreamId, - dir: Direction, - chunks: Option>>, - tx: Sender, -} - -pub struct PendingAccept { - stream_id: StreamId, - rx: Option>>, - tx: Sender, -} - -impl Future for PendingAccept { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - let Some(rx) = this.rx.as_mut() else { - return Poll::Ready(Err(QlError::Cancelled)); - }; - Pin::new(rx).poll(cx).map(|result| match result { - Ok(Ok(delivery)) => { - let AcceptedStreamDelivery { - stream_id, - response_head, - response, - tx, - } = delivery; - this.rx = None; - Ok(AcceptedStream { - stream_id, - response_head, - response: InboundByteStream::new(stream_id, Direction::Response, response, tx), - }) - } - Ok(Err(error)) => { - this.rx = None; - Err(error) - } - Err(_) => { - this.rx = None; - Err(QlError::Cancelled) - } - }) - } -} - -impl Drop for PendingAccept { - fn drop(&mut self) { - if self.rx.take().is_none() { - return; - } - let _ = self.tx.try_send(RuntimeCommand::PendingAcceptDropped { - stream_id: self.stream_id, - }); - } -} - -impl InboundByteStream { - pub(crate) fn new( - stream_id: StreamId, - dir: Direction, - rx: Receiver, - tx: Sender, - ) -> Self { - Self { - stream_id, - dir, - rx, - tx, - finished: false, - } - } - - pub async fn next_chunk(&mut self) -> Result>, QlError> { - if self.finished { - return Ok(None); - } - match self.rx.recv().await { - Ok(InboundEvent::Data(bytes)) => Ok(Some(bytes)), - Ok(InboundEvent::Finished) => { - self.finished = true; - Ok(None) - } - Ok(InboundEvent::Failed(error)) => { - self.finished = true; - Err(error) - } - Err(_) => { - self.finished = true; - Err(QlError::Cancelled) - } - } - } - - pub async fn reset(mut self, code: ResetCode) -> Result<(), QlError> { - self.finished = true; - self.tx - .send(RuntimeCommand::ResetInbound { - stream_id: self.stream_id, - dir: self.dir, - code, - }) - .await - .map_err(|_| QlError::Cancelled) - } -} - -impl Drop for InboundByteStream { - fn drop(&mut self) { - if self.finished { - return; - } - let _ = self.tx.try_send(RuntimeCommand::ResetInbound { - stream_id: self.stream_id, - dir: self.dir, - code: ResetCode::Cancelled, - }); - } -} - -impl OutboundByteStream { - pub(crate) fn new( - stream_id: StreamId, - dir: Direction, - chunks: Sender>, - tx: Sender, - ) -> Self { - Self { - stream_id, - dir, - chunks: Some(chunks), - tx, - } - } - - pub async fn write(&mut self, bytes: &[u8]) -> Result { - if bytes.is_empty() { - return Ok(0); - } - let sender = self.chunks.as_ref().expect("stream not finished or reset"); - sender - .send(bytes.to_vec()) - .await - .map_err(|_| QlError::Cancelled)?; - self.tx - .try_send(RuntimeCommand::PollStream { - stream_id: self.stream_id, - }) - .map_err(|_| QlError::Cancelled)?; - Ok(bytes.len()) - } - - pub async fn write_all(&mut self, mut bytes: &[u8]) -> Result<(), QlError> { - while !bytes.is_empty() { - let written = self.write(bytes).await?; - if written == 0 { - return Err(QlError::Cancelled); - } - bytes = &bytes[written..]; - } - Ok(()) - } - - pub async fn finish(mut self) -> Result<(), QlError> { - if self.chunks.take().is_none() { - return Ok(()); - } - self.tx - .try_send(RuntimeCommand::PollStream { - stream_id: self.stream_id, - }) - .map_err(|_| QlError::Cancelled)?; - Ok(()) - } - - pub async fn reset(mut self, code: ResetCode) -> Result<(), QlError> { - self.chunks.take(); - self.tx - .send(RuntimeCommand::ResetOutbound { - stream_id: self.stream_id, - dir: self.dir, - code, - }) - .await - .map_err(|_| QlError::Cancelled) - } -} - -impl Drop for OutboundByteStream { - fn drop(&mut self) { - if self.chunks.take().is_none() { - return; - } - let _ = self.tx.try_send(RuntimeCommand::ResetOutbound { - stream_id: self.stream_id, - dir: self.dir, - code: ResetCode::Cancelled, - }); - } -} - -impl StreamResponder { - pub(crate) fn new(stream_id: StreamId, tx: Sender) -> Self { - Self { - stream_id, - tx, - armed: true, - } - } - - pub fn accept(mut self, response_head: Vec) -> Result { - self.armed = false; - let (response_tx, response_rx) = async_channel::bounded(1); - self.tx - .send_blocking(RuntimeCommand::AcceptStream { - stream_id: self.stream_id, - response_head, - response_rx, - }) - .map_err(|_| QlError::Cancelled)?; - Ok(OutboundByteStream::new( - self.stream_id, - Direction::Response, - response_tx, - self.tx.clone(), - )) - } - - pub fn reject(mut self, code: RejectCode) -> Result<(), QlError> { - self.armed = false; - self.tx - .try_send(RuntimeCommand::RejectStream { - stream_id: self.stream_id, - code, - }) - .map_err(|_| QlError::Cancelled) - } -} - -impl Drop for StreamResponder { - fn drop(&mut self) { - if !self.armed { - return; - } - let _ = self.tx.try_send(RuntimeCommand::ResponderDropped { - stream_id: self.stream_id, - }); - } -} - -impl RuntimeHandle { - pub fn bind_peer(&self, peer: Peer) { - self.send(RuntimeCommand::BindPeer { peer }) - } - - pub fn pair(&self) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Pair) - .map_err(|_| QlError::Cancelled) - } - - pub fn connect(&self) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Connect) - .map_err(|_| QlError::Cancelled) - } - - pub fn unpair(&self) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Unpair) - .map_err(|_| QlError::Cancelled) - } - - pub fn send_incoming(&self, bytes: Vec) { - self.send(RuntimeCommand::Incoming(bytes)) - } - - pub async fn open_stream( - &self, - request_head: Vec, - config: StreamConfig, - ) -> Result { - let (request_tx, request_rx) = async_channel::bounded(1); - let (accepted_tx, accepted_rx) = oneshot::channel(); - let (start_tx, start_rx) = oneshot::channel(); - - self.tx - .send(RuntimeCommand::OpenStream { - request_head, - request_rx, - accepted: accepted_tx, - start: start_tx, - config, - }) - .await - .map_err(|_| QlError::Cancelled)?; - - let stream_id = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; - - Ok(PendingStream { - request: OutboundByteStream::new( - stream_id, - Direction::Request, - request_tx, - self.tx.clone(), - ), - accepted: PendingAccept { - stream_id, - rx: Some(accepted_rx), - tx: self.tx.clone(), - }, - }) - } -} - -impl RuntimeHandle { - #[inline] - #[track_caller] - fn send(&self, cmd: RuntimeCommand) { - self.tx.send_blocking(cmd).expect("runtime is alive") - } -} diff --git a/ql2/src/tests/handshake.rs b/ql2/src/tests/handshake.rs deleted file mode 100644 index ab656a02..00000000 --- a/ql2/src/tests/handshake.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::time::Duration; - -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn handshake_initiator_connects() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn handshake_timeout_disconnects() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(60), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn confirm_write_failure_disconnects_initiator() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new_with_failed_write(1, 2); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Initiator).await; - await_status(&status_b, peer_a.xid, PeerStage::Responder).await; - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - }) - .await; -} diff --git a/ql2/src/tests/heartbeat.rs b/ql2/src/tests/heartbeat.rs deleted file mode 100644 index 3f4dc0c7..00000000 --- a/ql2/src/tests/heartbeat.rs +++ /dev/null @@ -1,455 +0,0 @@ -use bc_components::SymmetricKey; - -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_ignored_without_session() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, _status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - - let peer_a = platform_a.xid(); - let peer_b = platform_b.xid(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.bind_peer(Peer { - peer: peer_b, - signing_key: platform_b.signing_public_key().clone(), - encapsulation_key: platform_b.encapsulation_public_key().clone(), - }); - - let heartbeat = wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: peer_b, - recipient: peer_a, - }, - &SymmetricKey::new(), - HeartbeatBody { - packet_id: PacketId(1), - valid_until: now_secs().saturating_add(60), - }, - test_encryption_nonce(1), - ); - handle_a.send_incoming(wire::encode_record(&heartbeat)); - - let result = tokio::time::timeout(Duration::from_millis(50), outbound_a.recv()).await; - assert!(result.is_err(), "expected heartbeat to be ignored"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn keepalive_disabled_no_heartbeat() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let result = tokio::time::timeout(Duration::from_millis(120), heartbeat_rx.recv()).await; - assert!(result.is_err(), "unexpected heartbeat while disabled"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_sent_after_idle() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(30), - timeout: Duration::from_millis(80), - }; - let config_a = RuntimeConfig { - engine: EngineConfig { - keep_alive: Some(keep_alive), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let config_b = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_reply_when_connected() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(30), - timeout: Duration::from_millis(80), - }; - let config_a = RuntimeConfig { - engine: EngineConfig { - keep_alive: Some(keep_alive), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let config_b = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); - let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); - spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(200), heartbeat_ab_rx.recv()) - .await - .unwrap() - .unwrap(); - tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn any_stream_clears_pending() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(120), - timeout: Duration::from_millis(40), - }; - let config_a = RuntimeConfig { - engine: EngineConfig { - keep_alive: Some(keep_alive), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let config_b = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a, inbound_a) = InboundPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) - .await - .unwrap() - .unwrap(); - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_a.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - response.finish().await.unwrap(); - }); - - let pending = handle_b - .open_stream(Vec::new(), Default::default()) - .await - .unwrap(); - pending.request.finish().await.unwrap(); - let _ = pending.accepted.await.unwrap(); - - let window = keep_alive.timeout + Duration::from_millis(20); - let disconnect = tokio::time::timeout(window, async { - loop { - if let Ok(event) = status_a.recv().await { - if event.peer == peer_b.xid && event.stage == PeerStage::Disconnected { - return; - } - } - } - }) - .await; - assert!(disconnect.is_err(), "unexpected disconnect"); - - let _ = responder_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_timeout_disconnects_and_drops_outbound() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(80), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - keep_alive: Some(keep_alive), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let config_b = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(2); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(1); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let drop_flag = Arc::new(AtomicBool::new(false)); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - response.finish().await.unwrap(); - }); - - drop_flag.store(true, Ordering::Relaxed); - - let pending = handle_a - .open_stream(Vec::new(), Default::default()) - .await - .unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - - let result = tokio::time::timeout(Duration::from_millis(300), pending.accepted) - .await - .unwrap(); - assert!(matches!(result, Err(QlError::SendFailed))); - - responder_task.abort(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn no_ping_pong() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(200), - timeout: Duration::from_millis(60), - }; - let config_a = RuntimeConfig { - engine: EngineConfig { - keep_alive: Some(keep_alive), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let config_b = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_ab_tx, heartbeat_ab_rx) = async_channel::unbounded(); - let (heartbeat_ba_tx, heartbeat_ba_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_ab_tx); - spawn_heartbeat_tap_forwarder(outbound_b, handle_a.clone(), heartbeat_ba_tx); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(300), heartbeat_ab_rx.recv()) - .await - .unwrap() - .unwrap(); - tokio::time::timeout(Duration::from_millis(200), heartbeat_ba_rx.recv()) - .await - .unwrap() - .unwrap(); - - let followup = - tokio::time::timeout(Duration::from_millis(50), heartbeat_ab_rx.recv()).await; - assert!(followup.is_err(), "unexpected heartbeat ping-pong"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn invalid_heartbeat_ignored() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let heartbeat = wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: peer_b.xid, - recipient: peer_a.xid, - }, - &SymmetricKey::new(), - HeartbeatBody { - packet_id: PacketId(42), - valid_until: now_secs().saturating_add(30), - }, - test_encryption_nonce(42), - ); - handle_a.send_incoming(wire::encode_record(&heartbeat)); - - let result = tokio::time::timeout(Duration::from_millis(50), heartbeat_rx.recv()).await; - assert!(result.is_err(), "unexpected heartbeat reply"); - }) - .await; -} diff --git a/ql2/src/tests/mod.rs b/ql2/src/tests/mod.rs deleted file mode 100644 index f50a4d10..00000000 --- a/ql2/src/tests/mod.rs +++ /dev/null @@ -1,1027 +0,0 @@ -use std::{ - future::Future, - sync::{ - atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, - Arc, Mutex, - }, - time::Duration, -}; - -use async_channel::{Receiver, Sender}; -use bc_components::{ - Digest, MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, SymmetricKey, MLDSA, - MLKEM, XID, -}; -use rkyv::{Archive, Serialize}; -use tokio::task::LocalSet; - -use crate::{ - platform::{PlatformFuture, QlCrypto, QlPlatform}, - runtime::{ - new_runtime, EngineConfig, HandlerEvent, KeepAliveConfig, PeerSession, RuntimeConfig, - RuntimeHandle, - }, - wire::{ - self, handshake::HandshakeRecord, heartbeat::HeartbeatBody, now_secs, pair, - AsWireMlKemCiphertext, AsWireNonce, AsWireXid, QlHeader, QlPayload, QlRecord, - }, - PacketId, Peer, QlError, -}; - -mod handshake; -mod heartbeat; -mod persistence; -mod rpc; -mod stream; -mod unpair; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum PeerStage { - Disconnected, - Initiator, - Responder, - Connected, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct StatusEvent { - peer: XID, - stage: PeerStage, -} - -struct TestPlatform { - signing_private: MLDSAPrivateKey, - signing_public: MLDSAPublicKey, - encapsulation_private: MLKEMPrivateKey, - encapsulation_public: MLKEMPublicKey, - outbound: Sender>, - status: Sender, - nonce_seed: u8, - nonce_counter: AtomicU8, - fail_on_write: Option, - write_counter: AtomicUsize, -} - -impl TestPlatform { - fn new(seed: u8) -> (Self, Receiver>, Receiver) { - Self::new_with_fail_on_write(seed, None) - } - - fn new_with_failed_write( - seed: u8, - fail_on_write: usize, - ) -> (Self, Receiver>, Receiver) { - Self::new_with_fail_on_write(seed, Some(fail_on_write)) - } - - fn new_with_fail_on_write( - seed: u8, - fail_on_write: Option, - ) -> (Self, Receiver>, Receiver) { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - fail_on_write, - write_counter: AtomicUsize::new(0), - }, - outbound_rx, - status_rx, - ) - } - - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } -} - -impl QlCrypto for TestPlatform { - fn signing_private_key(&self) -> &MLDSAPrivateKey { - &self.signing_private - } - - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } - - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } -} - -impl QlPlatform for TestPlatform { - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let fail_on_write = self.fail_on_write; - let write_index = self.write_counter.fetch_add(1, Ordering::Relaxed) + 1; - let outbound = self.outbound.clone(); - Box::pin(async move { - if fail_on_write == Some(write_index) { - return Err(QlError::SendFailed); - } - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn load_peer(&self) -> PlatformFuture<'_, Option> { - Box::pin(async { None }) - } - - fn persist_peer(&self, _peer: Peer) {} - - fn clear_peer(&self) {} - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, _event: HandlerEvent) {} -} - -struct InboundPlatform { - signing_private: MLDSAPrivateKey, - signing_public: MLDSAPublicKey, - encapsulation_private: MLKEMPrivateKey, - encapsulation_public: MLKEMPublicKey, - outbound: Sender>, - status: Sender, - inbound: Sender, - nonce_seed: u8, - nonce_counter: AtomicU8, -} - -impl InboundPlatform { - fn new( - seed: u8, - ) -> ( - Self, - Receiver>, - Receiver, - Receiver, - ) { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - let (inbound, inbound_rx) = async_channel::unbounded(); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - inbound, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - }, - outbound_rx, - status_rx, - inbound_rx, - ) - } -} - -impl QlCrypto for InboundPlatform { - fn signing_private_key(&self) -> &MLDSAPrivateKey { - &self.signing_private - } - - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { - &self.encapsulation_private - } - - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } - - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } -} - -impl QlPlatform for InboundPlatform { - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - Box::pin(async move { - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn load_peer(&self) -> PlatformFuture<'_, Option> { - Box::pin(async { None }) - } - - fn persist_peer(&self, _peer: Peer) {} - - fn clear_peer(&self) {} - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, event: HandlerEvent) { - let _ = self.inbound.try_send(event); - } -} - -async fn run_local_test(future: F) -where - F: Future, -{ - let local = LocalSet::new(); - local.run_until(future).await; -} - -fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - handle.send_incoming(bytes); - } - }); -} - -fn is_stream(bytes: &[u8]) -> bool { - let Ok(record) = wire::decode_record(bytes) else { - return false; - }; - matches!(record.payload, QlPayload::Stream(_)) -} - -fn is_heartbeat(bytes: &[u8]) -> bool { - let Ok(record) = wire::decode_record(bytes) else { - return false; - }; - matches!(record.payload, QlPayload::Heartbeat(_)) -} - -fn spawn_drop_first_stream_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - let mut dropped = false; - while let Ok(bytes) = outbound.recv().await { - if !dropped && is_stream(&bytes) { - dropped = true; - continue; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_drop_first_stream_when( - outbound: Receiver>, - handle: RuntimeHandle, - armed: Arc, -) { - tokio::task::spawn_local(async move { - let mut dropped = false; - while let Ok(bytes) = outbound.recv().await { - if armed.load(Ordering::Relaxed) && !dropped && is_stream(&bytes) { - dropped = true; - continue; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_duplicate_first_stream_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - let mut duplicated = false; - while let Ok(bytes) = outbound.recv().await { - if !duplicated && is_stream(&bytes) { - duplicated = true; - handle.send_incoming(bytes.clone()); - } - handle.send_incoming(bytes); - } - }); -} - -#[derive(Clone)] -struct SessionKeyMaterial { - initiator_encapsulation_private: MLKEMPrivateKey, - responder_encapsulation_private: MLKEMPrivateKey, -} - -fn session_key_material( - initiator: &TestPlatform, - responder: &InboundPlatform, -) -> SessionKeyMaterial { - SessionKeyMaterial { - initiator_encapsulation_private: initiator.encapsulation_private.clone(), - responder_encapsulation_private: responder.encapsulation_private.clone(), - } -} - -#[derive(Default)] -struct SessionTrace { - hello_header: Option, - hello: Option, - reply: Option, - session_key: Option, -} - -#[derive(Archive, Serialize)] -struct TestHandshakeTranscript { - #[rkyv(with = AsWireXid)] - initiator: XID, - #[rkyv(with = AsWireXid)] - responder: XID, - #[rkyv(with = AsWireNonce)] - initiator_nonce: bc_components::Nonce, - #[rkyv(with = AsWireNonce)] - responder_nonce: bc_components::Nonce, - #[rkyv(with = AsWireMlKemCiphertext)] - initiator_kem_ct: bc_components::MLKEMCiphertext, - #[rkyv(with = AsWireMlKemCiphertext)] - responder_kem_ct: bc_components::MLKEMCiphertext, -} - -#[derive(Archive, Serialize)] -struct TestSessionKeyMaterial { - initiator_secret: Vec, - responder_secret: Vec, - transcript: Vec, -} - -fn derive_session_key( - trace: &SessionTrace, - key_material: &SessionKeyMaterial, -) -> Option { - let header = trace.hello_header.as_ref()?; - let hello = trace.hello.as_ref()?; - let reply = trace.reply.as_ref()?; - let initiator_secret = key_material - .responder_encapsulation_private - .decapsulate_shared_secret(&hello.kem_ct) - .ok()?; - let responder_secret = key_material - .initiator_encapsulation_private - .decapsulate_shared_secret(&reply.kem_ct) - .ok()?; - let transcript = wire::encode_value(&TestHandshakeTranscript { - initiator: header.sender, - responder: header.recipient, - initiator_nonce: hello.nonce.clone(), - responder_nonce: reply.nonce.clone(), - initiator_kem_ct: hello.kem_ct.clone(), - responder_kem_ct: reply.kem_ct.clone(), - }); - let payload = wire::encode_value(&TestSessionKeyMaterial { - initiator_secret: initiator_secret.as_bytes().to_vec(), - responder_secret: responder_secret.as_bytes().to_vec(), - transcript, - }); - let digest = Digest::from_image(payload); - Some(SymmetricKey::from_data(*digest.data())) -} - -fn test_encryption_nonce(seed: u8) -> [u8; wire::encrypted_message::NONCE_SIZE] { - [seed; wire::encrypted_message::NONCE_SIZE] -} - -fn spawn_stream_mutating_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - key_material: SessionKeyMaterial, - trace: Arc>, - mutator: F, -) where - F: FnMut(&QlHeader, &mut wire::stream::StreamBody) -> bool + 'static, -{ - tokio::task::spawn_local(async move { - let mut mutator = mutator; - while let Ok(bytes) = outbound.recv().await { - let Ok(record) = wire::access_record(&bytes) else { - handle.send_incoming(bytes); - continue; - }; - - { - let mut trace = trace.lock().unwrap(); - match &record.payload { - wire::ArchivedQlPayload::Handshake( - wire::handshake::ArchivedHandshakeRecord::Hello(hello), - ) => { - trace.hello_header = Some(wire::deserialize_value(&record.header).unwrap()); - trace.hello = Some(wire::deserialize_value(hello).unwrap()); - } - wire::ArchivedQlPayload::Handshake( - wire::handshake::ArchivedHandshakeRecord::HelloReply(reply), - ) => { - trace.reply = Some(wire::deserialize_value(reply).unwrap()); - } - _ => {} - } - if trace.session_key.is_none() { - trace.session_key = derive_session_key(&trace, &key_material); - } - } - - let session_key = trace.lock().unwrap().session_key.clone(); - if let (Some(session_key), wire::ArchivedQlPayload::Stream(encrypted)) = - (session_key, &record.payload) - { - let header = wire::deserialize_value(&record.header).unwrap(); - let encrypted = wire::deserialize_value(encrypted).unwrap(); - let plaintext = encrypted.decrypt(&session_key, &header.aad()); - if let Ok(plaintext) = plaintext { - let body = wire::access_value::(&plaintext) - .and_then(wire::deserialize_value); - if let Ok(mut body) = body { - if mutator(&header, &mut body) { - let mutated = wire::stream::encrypt_stream( - header, - &session_key, - body.clone(), - test_encryption_nonce(body.packet_id.0 as u8), - ); - handle.send_incoming(wire::encode_record(&mutated)); - continue; - } - } - } - } - - handle.send_incoming(bytes); - } - }); -} - -fn spawn_drop_every_nth_stream_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - nth: usize, -) { - tokio::task::spawn_local(async move { - let mut stream_count = 0usize; - while let Ok(bytes) = outbound.recv().await { - if nth > 0 && is_stream(&bytes) { - stream_count = stream_count.saturating_add(1); - if stream_count % nth == 0 { - continue; - } - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_heartbeat_tap_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - heartbeat_tx: Sender<()>, -) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if is_heartbeat(&bytes) { - let _ = heartbeat_tx.send(()).await; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if is_heartbeat(&bytes) { - continue; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_gated_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - drop_flag: Arc, -) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if drop_flag.load(Ordering::Relaxed) { - continue; - } - handle.send_incoming(bytes); - } - }); -} - -#[derive(Clone)] -struct PeerIdentity { - xid: XID, - signing_key: MLDSAPublicKey, - encapsulation_key: MLKEMPublicKey, -} - -fn peer_identity(platform: &impl QlCrypto) -> PeerIdentity { - PeerIdentity { - xid: platform.xid(), - signing_key: platform.signing_public_key().clone(), - encapsulation_key: platform.encapsulation_public_key().clone(), - } -} - -fn register_peers( - handle_a: &RuntimeHandle, - handle_b: &RuntimeHandle, - identity_a: &PeerIdentity, - identity_b: &PeerIdentity, -) { - handle_a.bind_peer(Peer { - peer: identity_b.xid, - signing_key: identity_b.signing_key.clone(), - encapsulation_key: identity_b.encapsulation_key.clone(), - }); - handle_b.bind_peer(Peer { - peer: identity_a.xid, - signing_key: identity_a.signing_key.clone(), - encapsulation_key: identity_a.encapsulation_key.clone(), - }); -} - -type PersistPlatformParts = ( - PersistPlatform, - Receiver>, - Receiver, - Receiver>, -); - -struct PersistPlatform { - signing_private: MLDSAPrivateKey, - signing_public: MLDSAPublicKey, - encapsulation_private: MLKEMPrivateKey, - encapsulation_public: MLKEMPublicKey, - outbound: Sender>, - status: Sender, - persisted: Sender>, - loaded_peer: Option, - nonce_seed: u8, - nonce_counter: AtomicU8, -} - -impl PersistPlatform { - fn new(seed: u8, loaded_peer: Option) -> PersistPlatformParts { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - let (outbound, outbound_rx) = async_channel::unbounded(); - let (status, status_rx) = async_channel::unbounded(); - let (persisted, persisted_rx) = async_channel::unbounded(); - ( - Self { - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - outbound, - status, - persisted, - loaded_peer, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), - }, - outbound_rx, - status_rx, - persisted_rx, - ) - } -} - -impl QlCrypto for PersistPlatform { - fn signing_private_key(&self) -> &MLDSAPrivateKey { - &self.signing_private - } - fn signing_public_key(&self) -> &MLDSAPublicKey { - &self.signing_public - } - fn encapsulation_private_key(&self) -> &MLKEMPrivateKey { - &self.encapsulation_private - } - fn encapsulation_public_key(&self) -> &MLKEMPublicKey { - &self.encapsulation_public - } - - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); - } -} - -impl QlPlatform for PersistPlatform { - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - let outbound = self.outbound.clone(); - Box::pin(async move { - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) - }) - } - - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) - } - - fn load_peer(&self) -> PlatformFuture<'_, Option> { - let peer = self.loaded_peer.clone(); - Box::pin(async move { peer }) - } - - fn persist_peer(&self, peer: crate::Peer) { - let _ = self.persisted.try_send(Some(peer)); - } - - fn clear_peer(&self) { - let _ = self.persisted.try_send(None); - } - - fn handle_peer_status(&self, peer: XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); - } - - fn handle_inbound(&self, _event: HandlerEvent) {} -} - -async fn await_status( - receiver: &Receiver, - peer: XID, - stage: PeerStage, -) -> StatusEvent { - tokio::time::timeout(Duration::from_secs(1), async { - loop { - if let Ok(event) = receiver.recv().await { - if event.peer == peer && event.stage == stage { - return event; - } - } - } - }) - .await - .unwrap() -} - -#[test] -fn protocol_record_size_breakdown() { - let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - - let initiator = platform_a.xid(); - let responder = platform_b.xid(); - - let (hello, initiator_secret) = wire::handshake::build_hello( - &platform_a, - initiator, - responder, - platform_b.encapsulation_public_key(), - ) - .unwrap(); - let hello_record = QlRecord { - header: QlHeader { - sender: initiator, - recipient: responder, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(hello.clone())), - }; - let hello_size = wire::encode_record(&hello_record).len(); - let hello_bytes = wire::encode_value(&hello); - let hello_view = wire::access_value::(&hello_bytes).unwrap(); - - let (hello_reply, responder_secrets) = wire::handshake::respond_hello( - &platform_b, - initiator, - responder, - platform_a.encapsulation_public_key(), - hello_view, - ) - .unwrap(); - let reply_record = QlRecord { - header: QlHeader { - sender: responder, - recipient: initiator, - }, - payload: QlPayload::Handshake(HandshakeRecord::HelloReply(hello_reply.clone())), - }; - let reply_size = wire::encode_record(&reply_record).len(); - let reply_bytes = wire::encode_value(&hello_reply); - let reply_view = - wire::access_value::(&reply_bytes).unwrap(); - - let (confirm, session_key) = wire::handshake::build_confirm( - &platform_a, - initiator, - responder, - platform_b.signing_public_key(), - &hello, - reply_view, - &initiator_secret, - ) - .unwrap(); - let confirm_bytes = wire::encode_value(&confirm); - let confirm_view = - wire::access_value::(&confirm_bytes).unwrap(); - let confirm_record = QlRecord { - header: QlHeader { - sender: initiator, - recipient: responder, - }, - payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm.clone())), - }; - let confirm_size = wire::encode_record(&confirm_record).len(); - let _session_key_b = wire::handshake::finalize_confirm( - initiator, - responder, - platform_a.signing_public_key(), - &hello, - &hello_reply, - confirm_view, - &responder_secrets, - ) - .unwrap(); - - let pair_size = wire::encode_record( - &pair::build_pair_request( - &platform_a, - responder, - platform_b.encapsulation_public_key(), - PacketId(11), - Duration::from_secs(60), - ) - .unwrap(), - ) - .len(); - - let heartbeat_size = wire::encode_record(&wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: initiator, - recipient: responder, - }, - &session_key, - HeartbeatBody { - packet_id: PacketId(12), - valid_until: wire::now_secs().saturating_add(60), - }, - test_encryption_nonce(12), - )) - .len(); - - let unpair_size = wire::encode_record(&wire::unpair::build_unpair_record( - &platform_a, - QlHeader { - sender: initiator, - recipient: responder, - }, - PacketId(13), - wire::now_secs().saturating_add(60), - )) - .len(); - - let stream_record_size = - |packet_id: PacketId, - packet_ack: Option, - frame: Option| { - wire::encode_record(&wire::stream::encrypt_stream( - QlHeader { - sender: initiator, - recipient: responder, - }, - &session_key, - wire::stream::StreamBody { - packet_id, - valid_until: wire::now_secs().saturating_add(60), - packet_ack, - frame, - }, - test_encryption_nonce(packet_id.0 as u8), - )) - .len() - }; - - let stream_header = QlHeader { - sender: initiator, - recipient: responder, - }; - let stream_ack_body = wire::stream::StreamBody { - packet_id: PacketId(20), - valid_until: wire::now_secs().saturating_add(60), - packet_ack: Some(wire::stream::PacketAck { - packet_id: PacketId(19), - }), - frame: None, - }; - let stream_ack_record = wire::stream::encrypt_stream( - stream_header.clone(), - &session_key, - stream_ack_body.clone(), - test_encryption_nonce(20), - ); - let stream_ack_encrypted = match &stream_ack_record.payload { - QlPayload::Stream(encrypted) => encrypted, - _ => unreachable!(), - }; - let stream_ack_header_size = wire::encode_value(&stream_header).len(); - let stream_ack_body_size = wire::encode_value(&stream_ack_body).len(); - let stream_ack_envelope_size = wire::encode_value(stream_ack_encrypted).len(); - let stream_ack_payload_size = wire::encode_value(&stream_ack_record.payload).len(); - - let stream_open_body = wire::stream::StreamBody { - packet_id: PacketId(21), - valid_until: wire::now_secs().saturating_add(60), - packet_ack: None, - frame: Some(wire::stream::StreamFrame::Open( - wire::stream::StreamFrameOpen { - stream_id: crate::StreamId(2), - request_head: vec![1, 2, 3], - response_max_offset: 1024, - }, - )), - }; - let stream_open_body_size = wire::encode_value(&stream_open_body).len(); - - let stream_ack_size = stream_record_size( - PacketId(20), - Some(wire::stream::PacketAck { - packet_id: PacketId(19), - }), - None, - ); - let stream_open_size = stream_record_size( - PacketId(21), - None, - Some(wire::stream::StreamFrame::Open( - wire::stream::StreamFrameOpen { - stream_id: crate::StreamId(2), - request_head: vec![1, 2, 3], - response_max_offset: 1024, - }, - )), - ); - let stream_accept_size = stream_record_size( - PacketId(22), - None, - Some(wire::stream::StreamFrame::Accept( - wire::stream::StreamFrameAccept { - stream_id: crate::StreamId(2), - response_head: vec![4, 5, 6], - request_max_offset: 2048, - }, - )), - ); - let stream_reject_size = stream_record_size( - PacketId(23), - None, - Some(wire::stream::StreamFrame::Reject( - wire::stream::StreamFrameReject { - stream_id: crate::StreamId(2), - code: wire::stream::RejectCode::InvalidHead, - }, - )), - ); - let stream_data_size = stream_record_size( - PacketId(24), - None, - Some(wire::stream::StreamFrame::Data( - wire::stream::StreamFrameData { - stream_id: crate::StreamId(2), - dir: wire::stream::Direction::Request, - offset: 128, - bytes: vec![7, 8, 9, 10], - }, - )), - ); - let stream_credit_size = stream_record_size( - PacketId(25), - None, - Some(wire::stream::StreamFrame::Credit( - wire::stream::StreamFrameCredit { - stream_id: crate::StreamId(2), - dir: wire::stream::Direction::Response, - recv_offset: 256, - max_offset: 4096, - }, - )), - ); - let stream_finish_size = stream_record_size( - PacketId(26), - None, - Some(wire::stream::StreamFrame::Finish( - wire::stream::StreamFrameFinish { - stream_id: crate::StreamId(2), - dir: wire::stream::Direction::Response, - }, - )), - ); - let stream_reset_size = stream_record_size( - PacketId(27), - None, - Some(wire::stream::StreamFrame::Reset( - wire::stream::StreamFrameReset { - stream_id: crate::StreamId(2), - dir: wire::stream::ResetTarget::Both, - code: wire::stream::ResetCode::Protocol, - }, - )), - ); - - let print_size = |label: &str, size: usize| { - println!("{label:<23}: {size} bytes"); - }; - - print_size("ql2 size hello", hello_size); - print_size("ql2 size hello_reply", reply_size); - print_size("ql2 size confirm", confirm_size); - print_size("ql2 size pair", pair_size); - print_size("ql2 size heartbeat", heartbeat_size); - print_size("ql2 size unpair", unpair_size); - print_size("ql2 size stream ack", stream_ack_size); - print_size("ql2 size stream open", stream_open_size); - print_size("ql2 size stream accept", stream_accept_size); - print_size("ql2 size stream reject", stream_reject_size); - print_size("ql2 size stream data", stream_data_size); - print_size("ql2 size stream credit", stream_credit_size); - print_size("ql2 size stream finish", stream_finish_size); - print_size("ql2 size stream reset", stream_reset_size); - println!( - "ql2 stream ack breakdown : header={} derived_aad={} plaintext={} ciphertext={} envelope(no aad)={} payload={} full={}", - stream_ack_header_size, - stream_header.aad().len(), - stream_ack_body_size, - stream_ack_body_size, - stream_ack_envelope_size, - stream_ack_payload_size, - stream_ack_size, - ); - println!( - "ql2 stream open delta : open_body={} ack_body={} (+{} request_head bytes)", - stream_open_body_size, - stream_ack_body_size, - stream_open_body_size.saturating_sub(stream_ack_body_size), - ); -} diff --git a/ql2/src/tests/persistence.rs b/ql2/src/tests/persistence.rs deleted file mode 100644 index 8fc5cf9f..00000000 --- a/ql2/src/tests/persistence.rs +++ /dev/null @@ -1,139 +0,0 @@ -use std::time::Duration; - -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn register_peer_persists_snapshot() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, _outbound_a, _status_a, persisted_a) = PersistPlatform::new(1, None); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - let peer_b = platform_b.xid(); - let signing_b = platform_b.signing_public_key().clone(); - let encap_b = platform_b.encapsulation_public_key().clone(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - tokio::task::spawn_local(async move { runtime_a.run().await }); - - handle_a.bind_peer(crate::Peer { - peer: peer_b, - signing_key: signing_b.clone(), - encapsulation_key: encap_b.clone(), - }); - - let snapshot = tokio::time::timeout(Duration::from_secs(1), persisted_a.recv()) - .await - .unwrap() - .unwrap(); - assert_eq!( - snapshot, - Some(crate::Peer { - peer: peer_b, - signing_key: signing_b, - encapsulation_key: encap_b, - }) - ); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn loaded_peers_can_connect_without_register() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_b = peer_identity(&platform_b); - - let (platform_a, outbound_a, status_a, _persisted_a) = PersistPlatform::new( - 1, - Some(crate::Peer { - peer: peer_b.xid, - signing_key: peer_b.signing_key.clone(), - encapsulation_key: peer_b.encapsulation_key.clone(), - }), - ); - let peer_a = peer_identity(&platform_a); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - handle_b.bind_peer(crate::Peer { - peer: peer_a.xid, - signing_key: peer_a.signing_key.clone(), - encapsulation_key: peer_a.encapsulation_key.clone(), - }); - - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn pairing_persists_snapshot() { - run_local_test(async { - let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); - let peer_a = peer_identity(&platform_a); - - let (platform_b, _outbound_b, _status_b, persisted_b) = PersistPlatform::new(2, None); - let peer_b = peer_identity(&platform_b); - - let pairing_message = pair::build_pair_request( - &platform_a, - peer_b.xid, - &peer_b.encapsulation_key, - PacketId(1), - Duration::from_secs(1), - ) - .unwrap(); - let pairing_bytes = wire::encode_record(&pairing_message); - - let (runtime_b, handle_b) = new_runtime( - platform_b, - RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }, - ); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - handle_b.send_incoming(pairing_bytes); - - let snapshot = tokio::time::timeout(Duration::from_secs(1), persisted_b.recv()) - .await - .unwrap() - .unwrap(); - assert_eq!( - snapshot, - Some(crate::Peer { - peer: peer_a.xid, - signing_key: peer_a.signing_key, - encapsulation_key: peer_a.encapsulation_key, - }) - ); - }) - .await; -} diff --git a/ql2/src/tests/rpc.rs b/ql2/src/tests/rpc.rs deleted file mode 100644 index 7308c688..00000000 --- a/ql2/src/tests/rpc.rs +++ /dev/null @@ -1,264 +0,0 @@ -use std::time::Duration; - -use dcbor::CBOR; - -use super::*; -use crate::{ - rpc::{MethodId, RequestResponse, RpcHandle, RpcRequestHead, RpcResponseHead}, - runtime::StreamConfig, - wire::stream::RejectCode, - QlError, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -struct AddOne(u64); - -#[derive(Debug, Clone, PartialEq, Eq)] -struct AddOneResponse(u64); - -impl From for CBOR { - fn from(value: AddOne) -> Self { - CBOR::from(value.0) - } -} - -impl TryFrom for AddOne { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - Ok(Self(value.try_into()?)) - } -} - -impl From for CBOR { - fn from(value: AddOneResponse) -> Self { - CBOR::from(value.0) - } -} - -impl TryFrom for AddOneResponse { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - Ok(Self(value.try_into()?)) - } -} - -impl RequestResponse for AddOne { - const METHOD: MethodId = MethodId(1); - type Response = AddOneResponse; -} - -#[tokio::test(flavor = "current_thread")] -async fn rpc_request_response_round_trip() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - let rpc_a = RpcHandle::new(handle_a.clone()); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let request_body = CBOR::from(AddOne(41)).to_cbor_data(); - let response_body = CBOR::from(AddOneResponse(42)).to_cbor_data(); - let request_head = - RpcRequestHead::try_from(CBOR::try_from_data(&stream.request_head).unwrap()) - .unwrap(); - assert_eq!(request_head.method, AddOne::METHOD); - assert_eq!(request_head.content_length, Some(request_body.len() as u64)); - - let mut response = stream - .respond_to - .accept( - CBOR::from(RpcResponseHead::new(Some(response_body.len() as u64))) - .to_cbor_data(), - ) - .unwrap(); - - let request_body = read_body(stream.request).await.unwrap(); - let request = AddOne::try_from(CBOR::try_from_data(&request_body).unwrap()).unwrap(); - - response - .write_all(&CBOR::from(AddOneResponse(request.0 + 1)).to_cbor_data()) - .await - .unwrap(); - response.finish().await.unwrap(); - }); - - let response = rpc_a - .request(AddOne(41), StreamConfig::default()) - .await - .unwrap(); - assert_eq!(response, AddOneResponse(42)); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn rpc_request_response_reject_propagates() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - let rpc_a = RpcHandle::new(handle_a.clone()); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let request_head = - RpcRequestHead::try_from(CBOR::try_from_data(&stream.request_head).unwrap()) - .unwrap(); - assert_eq!(request_head.method, AddOne::METHOD); - stream.respond_to.reject(RejectCode::UnknownRoute).unwrap(); - }); - - let err = rpc_a - .request(AddOne(1), StreamConfig::default()) - .await - .unwrap_err(); - assert!(matches!( - err, - crate::rpc::RpcError::Transport(QlError::StreamRejected { - code: RejectCode::UnknownRoute - }) - )); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn rpc_request_response_content_length_mismatch_errors() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - let rpc_a = RpcHandle::new(handle_a.clone()); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut response = stream - .respond_to - .accept(CBOR::from(RpcResponseHead::new(Some(99))).to_cbor_data()) - .unwrap(); - let _request_body = read_body(stream.request).await.unwrap(); - response - .write_all(&CBOR::from(AddOneResponse(2)).to_cbor_data()) - .await - .unwrap(); - response.finish().await.unwrap(); - }); - - let err = rpc_a - .request(AddOne(1), StreamConfig::default()) - .await - .unwrap_err(); - assert!(matches!( - err, - crate::rpc::RpcError::ContentLengthMismatch { - expected: 99, - actual: 1, - } - )); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -async fn read_body(mut stream: crate::runtime::InboundByteStream) -> Result, QlError> { - let mut body = Vec::new(); - while let Some(chunk) = stream.next_chunk().await? { - body.extend_from_slice(&chunk); - } - Ok(body) -} diff --git a/ql2/src/tests/stream.rs b/ql2/src/tests/stream.rs deleted file mode 100644 index 5692f8fa..00000000 --- a/ql2/src/tests/stream.rs +++ /dev/null @@ -1,1685 +0,0 @@ -use std::{sync::atomic::Ordering, time::Duration}; - -use super::*; -use crate::{ - runtime::{PendingStream, StreamConfig}, - wire::stream::{ - Direction, RejectCode, ResetCode, StreamFrame, StreamFrameCredit, StreamFrameData, - }, -}; - -#[tokio::test(flavor = "current_thread")] -async fn duplex_stream_round_trip() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - packet_ack_timeout: Duration::from_millis(40), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - assert_eq!(stream.request_head, b"req-head".to_vec()); - - let mut request = stream.request; - let mut response = stream.respond_to.accept(b"resp-head".to_vec()).unwrap(); - - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2])); - response.write_all(&[9]).await.unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![3, 4])); - response.write_all(&[8, 7]).await.unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), None); - response.finish().await.unwrap(); - }); - - let pending = handle_a - .open_stream(b"req-head".to_vec(), StreamConfig::default()) - .await - .unwrap(); - let PendingStream { - mut request, - accepted, - } = pending; - request.write_all(&[1, 2]).await.unwrap(); - let mut accepted = accepted.await.unwrap(); - assert_eq!(accepted.response_head, b"resp-head".to_vec()); - assert_eq!(accepted.response.next_chunk().await.unwrap(), Some(vec![9])); - request.write_all(&[3, 4]).await.unwrap(); - request.finish().await.unwrap(); - assert_eq!( - accepted.response.next_chunk().await.unwrap(), - Some(vec![8, 7]) - ); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn duplicate_open_is_idempotent() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_drop_first_stream_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - tokio::time::sleep(Duration::from_millis(120)).await; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - let second = tokio::time::timeout(Duration::from_millis(120), inbound_b.recv()).await; - assert!(second.is_err(), "duplicate open redispatched stream"); - response.finish().await.unwrap(); - }); - - let pending = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - let PendingStream { request, accepted } = pending; - let mut accepted = accepted.await.unwrap(); - request.finish().await.unwrap(); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn duplicate_accept_is_idempotent() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let arm_drop = Arc::new(AtomicBool::new(false)); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_drop_first_stream_when(outbound_a, handle_b.clone(), arm_drop.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - arm_drop.store(true, Ordering::Relaxed); - let response = stream.respond_to.accept(b"accepted".to_vec()).unwrap(); - tokio::time::sleep(Duration::from_millis(150)).await; - response.finish().await.unwrap(); - }); - - let pending = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - let PendingStream { request, accepted } = pending; - let mut accepted = accepted.await.unwrap(); - assert_eq!(accepted.response_head, b"accepted".to_vec()); - tokio::time::sleep(Duration::from_millis(120)).await; - request.finish().await.unwrap(); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn replayed_open_packet_is_ignored() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - packet_ack_timeout: Duration::from_millis(40), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_duplicate_first_stream_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let second = tokio::time::timeout(Duration::from_millis(80), inbound_b.recv()).await; - assert!(second.is_err(), "replayed open redispatched stream"); - let response = stream.respond_to.accept(Vec::new()).unwrap(); - response.finish().await.unwrap(); - }); - - let pending = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - let PendingStream { request, accepted } = pending; - let mut accepted = accepted.await.unwrap(); - request.finish().await.unwrap(); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_reset_can_keep_response_alive() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(40), - max_payload_bytes: 16, - initial_credit: 16, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let mut response = stream.respond_to.accept(b"err".to_vec()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2])); - request.reset(ResetCode::InvalidData).await.unwrap(); - response.write_all(b"invalid").await.unwrap(); - response.finish().await.unwrap(); - }); - - let pending = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - let PendingStream { - mut request, - accepted, - } = pending; - request.write_all(&[1, 2]).await.unwrap(); - let mut accepted = accepted.await.unwrap(); - assert_eq!(accepted.response_head, b"err".to_vec()); - assert_eq!( - accepted.response.next_chunk().await.unwrap(), - Some(b"invalid".to_vec()) - ); - let err = request.write_all(&[3, 4]).await.unwrap_err(); - assert!(matches!(err, QlError::Cancelled)); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn open_timeout_returns_error() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(120), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let pending = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - - let _stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - - let err = pending.accepted.await.unwrap_err(); - assert!(matches!(err, QlError::Timeout)); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn reject_surfaces_stream_rejected() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - stream.respond_to.reject(RejectCode::UnknownRoute).unwrap(); - }); - - let pending = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - let err = pending.accepted.await.unwrap_err(); - assert!(matches!( - err, - QlError::StreamRejected { - code: RejectCode::UnknownRoute - } - )); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn dropping_responder_rejects_unhandled() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - drop(stream.respond_to); - assert!(matches!( - request.next_chunk().await, - Ok(None) | Err(QlError::Cancelled) - )); - }); - - let PendingStream { request, accepted } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.finish().await.unwrap(); - - let err = tokio::time::timeout(Duration::from_secs(1), accepted) - .await - .unwrap() - .unwrap_err(); - assert!(matches!( - err, - QlError::StreamRejected { - code: RejectCode::Unhandled - } - )); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn request_larger_than_ring_buffer_streams_with_backpressure() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - pipe_size_bytes: 4, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let payload: Vec = (0..24).collect(); - let (done_tx, done_rx) = async_channel::bounded(1); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - let mut received = Vec::new(); - while let Some(chunk) = request.next_chunk().await.unwrap() { - received.extend_from_slice(&chunk); - } - done_tx.send(received).await.unwrap(); - response.finish().await.unwrap(); - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.write_all(&payload).await.unwrap(); - request.finish().await.unwrap(); - - let mut accepted = tokio::time::timeout(Duration::from_secs(1), accepted) - .await - .unwrap() - .unwrap(); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - let received = tokio::time::timeout(Duration::from_secs(1), done_rx.recv()) - .await - .unwrap() - .unwrap(); - assert_eq!(received, payload); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn response_larger_than_ring_buffer_streams_with_backpressure() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - pipe_size_bytes: 4, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let payload: Vec = (50..74).collect(); - let expected = payload.clone(); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let mut response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), None); - response.write_all(&payload).await.unwrap(); - response.finish().await.unwrap(); - }); - - let PendingStream { request, accepted } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.finish().await.unwrap(); - - let mut accepted = tokio::time::timeout(Duration::from_secs(1), accepted) - .await - .unwrap() - .unwrap(); - let mut received = Vec::new(); - while let Some(chunk) = accepted.response.next_chunk().await.unwrap() { - received.extend_from_slice(&chunk); - } - assert_eq!(received, expected); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn dropping_pending_accept_cancels_response_side() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - pipe_size_bytes: 4, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let mut response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), None); - let err = response - .write_all(&[1, 2, 3, 4, 5, 6, 7, 8]) - .await - .unwrap_err(); - assert!(matches!(err, QlError::Cancelled)); - }); - - let PendingStream { request, accepted } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - drop(accepted); - request.finish().await.unwrap(); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn dropping_request_writer_sends_cancel() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); - let err = request.next_chunk().await.unwrap_err(); - assert!(matches!( - err, - QlError::StreamReset { - dir: Direction::Request, - code: ResetCode::Cancelled, - } - )); - response.finish().await.unwrap(); - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.write_all(&[1, 2, 3, 4]).await.unwrap(); - let mut accepted = accepted.await.unwrap(); - drop(request); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn dropping_response_writer_sends_cancel() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let mut stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(stream.request.next_chunk().await.unwrap(), None); - response.write_all(&[9, 8, 7, 6]).await.unwrap(); - drop(response); - }); - - let PendingStream { request, accepted } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.finish().await.unwrap(); - let mut accepted = accepted.await.unwrap(); - assert_eq!( - accepted.response.next_chunk().await.unwrap(), - Some(vec![9, 8, 7, 6]) - ); - let err = accepted.response.next_chunk().await.unwrap_err(); - assert!(matches!( - err, - QlError::StreamReset { - dir: Direction::Response, - code: ResetCode::Cancelled, - } - )); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn dropping_request_reader_sends_cancel() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); - drop(request); - response.finish().await.unwrap(); - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.write_all(&[1, 2, 3, 4]).await.unwrap(); - let mut accepted = accepted.await.unwrap(); - // ensure that the runtime can process the drop - tokio::time::sleep(Duration::from_millis(20)).await; - let err = request.write_all(&[5, 6, 7, 8]).await.unwrap_err(); - assert!(matches!(err, QlError::Cancelled)); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn dropping_response_reader_sends_cancel() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - packet_ack_timeout: Duration::from_millis(30), - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - pipe_size_bytes: 4, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let (go_tx, go_rx) = async_channel::bounded(1); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let mut stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(stream.request.next_chunk().await.unwrap(), None); - go_rx.recv().await.unwrap(); - let err = response - .write_all(&[1, 2, 3, 4, 5, 6, 7, 8]) - .await - .unwrap_err(); - assert!(matches!(err, QlError::Cancelled)); - }); - - let PendingStream { request, accepted } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.finish().await.unwrap(); - let accepted = accepted.await.unwrap(); - drop(accepted.response); - go_tx.send(()).await.unwrap(); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn empty_request_finishes_cleanly() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let mut response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), None); - response.write_all(b"ok").await.unwrap(); - response.finish().await.unwrap(); - }); - - let PendingStream { request, accepted } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.finish().await.unwrap(); - let mut accepted = accepted.await.unwrap(); - assert_eq!( - accepted.response.next_chunk().await.unwrap(), - Some(b"ok".to_vec()) - ); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn empty_response_finishes_cleanly() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(300), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1])); - assert_eq!(request.next_chunk().await.unwrap(), None); - response.finish().await.unwrap(); - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.write_all(&[1]).await.unwrap(); - request.finish().await.unwrap(); - let mut accepted = accepted.await.unwrap(); - assert_eq!(accepted.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn stream_survives_every_third_packet_drop() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(500), - packet_ack_timeout: Duration::from_millis(20), - stream_retry_limit: 12, - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - pipe_size_bytes: 4, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let request_payload: Vec = (0..32).collect(); - let response_payload: Vec = (100..132).collect(); - let expected_response = response_payload.clone(); - let (done_tx, done_rx) = async_channel::bounded(1); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_drop_every_nth_stream_forwarder(outbound_a, handle_b.clone(), 3); - spawn_drop_every_nth_stream_forwarder(outbound_b, handle_a.clone(), 3); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let mut response = stream.respond_to.accept(Vec::new()).unwrap(); - let mut received = Vec::new(); - while let Some(chunk) = request.next_chunk().await.unwrap() { - received.extend_from_slice(&chunk); - } - response.write_all(&response_payload).await.unwrap(); - response.finish().await.unwrap(); - done_tx.send(received).await.unwrap(); - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.write_all(&request_payload).await.unwrap(); - request.finish().await.unwrap(); - - let mut accepted = tokio::time::timeout(Duration::from_secs(3), accepted) - .await - .unwrap() - .unwrap(); - let mut received_response = Vec::new(); - while let Some(chunk) = accepted.response.next_chunk().await.unwrap() { - received_response.extend_from_slice(&chunk); - } - assert_eq!(received_response, expected_response); - - let received_request = tokio::time::timeout(Duration::from_secs(3), done_rx.recv()) - .await - .unwrap() - .unwrap(); - assert_eq!(received_request, request_payload); - - tokio::time::timeout(Duration::from_secs(3), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn response_data_before_accept_is_protocol_error() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - stream_retry_limit: 8, - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let key_material = session_key_material(&platform_a, &platform_b); - let trace = Arc::new(Mutex::new(SessionTrace::default())); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_stream_mutating_forwarder( - outbound_a, - handle_b.clone(), - key_material.clone(), - trace.clone(), - |_header, _body| false, - ); - spawn_stream_mutating_forwarder(outbound_b, handle_a.clone(), key_material, trace, { - let mut mutated = false; - move |_header, body| { - if mutated { - return false; - } - if let Some(StreamFrame::Accept(frame)) = body.frame.take() { - mutated = true; - body.frame = Some(StreamFrame::Data(StreamFrameData { - stream_id: frame.stream_id, - dir: Direction::Response, - offset: 0, - bytes: vec![9], - })); - true - } else { - false - } - } - }); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut response = stream.respond_to.accept(Vec::new()).unwrap(); - response.write_all(&[9]).await.unwrap(); - let _ = response.finish().await; - }); - - let PendingStream { request, accepted } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.finish().await.unwrap(); - let err = tokio::time::timeout(Duration::from_secs(1), accepted) - .await - .unwrap() - .unwrap_err(); - assert!(matches!(err, QlError::StreamProtocol)); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn data_offset_gap_is_protocol_error() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - stream_retry_limit: 8, - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let key_material = session_key_material(&platform_a, &platform_b); - let trace = Arc::new(Mutex::new(SessionTrace::default())); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_stream_mutating_forwarder( - outbound_a, - handle_b.clone(), - key_material.clone(), - trace.clone(), - { - let mut mutated = false; - move |_header, body| { - if mutated { - return false; - } - if let Some(StreamFrame::Data(frame)) = body.frame.as_mut() { - mutated = true; - frame.offset = 2; - true - } else { - false - } - } - }, - ); - spawn_stream_mutating_forwarder( - outbound_b, - handle_a.clone(), - key_material, - trace, - |_header, _body| false, - ); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - let err = request.next_chunk().await.unwrap_err(); - assert!(matches!(err, QlError::StreamProtocol)); - let _ = response.finish().await; - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - let _accepted = accepted.await.unwrap(); - request.write_all(&[1, 2, 3, 4]).await.unwrap(); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn data_beyond_credit_is_protocol_error() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - stream_retry_limit: 8, - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let key_material = session_key_material(&platform_a, &platform_b); - let trace = Arc::new(Mutex::new(SessionTrace::default())); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_stream_mutating_forwarder( - outbound_a, - handle_b.clone(), - key_material.clone(), - trace.clone(), - { - let mut mutated = false; - move |_header, body| { - if mutated { - return false; - } - if let Some(StreamFrame::Data(frame)) = body.frame.as_mut() { - mutated = true; - frame.offset = 4; - true - } else { - false - } - } - }, - ); - spawn_stream_mutating_forwarder( - outbound_b, - handle_a.clone(), - key_material, - trace, - |_header, _body| false, - ); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - let err = request.next_chunk().await.unwrap_err(); - assert!(matches!(err, QlError::StreamProtocol)); - let _ = response.finish().await; - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - let _accepted = accepted.await.unwrap(); - request.write_all(&[1, 2, 3, 4]).await.unwrap(); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn credit_regression_is_protocol_error() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - stream_retry_limit: 8, - max_payload_bytes: 4, - initial_credit: 4, - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - let key_material = session_key_material(&platform_a, &platform_b); - let trace = Arc::new(Mutex::new(SessionTrace::default())); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_stream_mutating_forwarder( - outbound_a, - handle_b.clone(), - key_material.clone(), - trace.clone(), - |_header, _body| false, - ); - spawn_stream_mutating_forwarder(outbound_b, handle_a.clone(), key_material, trace, { - let mut mutated = false; - move |_header, body| { - if mutated { - return false; - } - if let Some(StreamFrame::Credit(StreamFrameCredit { - dir: Direction::Request, - recv_offset, - max_offset, - .. - })) = body.frame.as_mut() - { - mutated = true; - *recv_offset = 99; - *max_offset = 99; - true - } else { - false - } - } - }); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); - let err = request.next_chunk().await.unwrap_err(); - assert!(matches!( - err, - QlError::StreamReset { - code: ResetCode::Protocol, - dir: Direction::Request, - } - )); - let _ = response.finish().await; - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - let mut accepted = accepted.await.unwrap(); - request.write_all(&[1, 2, 3, 4]).await.unwrap(); - tokio::time::sleep(Duration::from_millis(20)).await; - let err = request.write_all(&[5, 6, 7, 8]).await.unwrap_err(); - assert!(matches!(err, QlError::Cancelled)); - assert!(matches!( - accepted.response.next_chunk().await, - Ok(None) | Err(_) - )); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn disconnect_during_active_stream_aborts_both_halves() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - default_open_timeout: Duration::from_millis(400), - packet_ack_timeout: Duration::from_millis(30), - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = InboundPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - let handle_b_for_disconnect = handle_b.clone(); - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let mut request = stream.request; - let _response = stream.respond_to.accept(Vec::new()).unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); - let request_outcome = request.next_chunk().await; - assert!(matches!( - request_outcome, - Ok(None) - | Err(QlError::Cancelled) - | Err(QlError::SendFailed) - | Err(QlError::StreamReset { .. }) - | Err(QlError::StreamProtocol) - )); - handle_b_for_disconnect.unpair().unwrap(); - }); - - let PendingStream { - mut request, - accepted, - } = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); - request.write_all(&[1, 2, 3, 4]).await.unwrap(); - let mut accepted = accepted.await.unwrap(); - handle_a.unpair().unwrap(); - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - tokio::time::sleep(Duration::from_millis(20)).await; - - let write_err = request.write_all(&[5, 6, 7, 8]).await.unwrap_err(); - assert!(matches!(write_err, QlError::Cancelled)); - assert!(matches!( - accepted.response.next_chunk().await, - Ok(None) | Err(_) - )); - - tokio::time::timeout(Duration::from_secs(1), responder_task) - .await - .unwrap() - .unwrap(); - }) - .await; -} diff --git a/ql2/src/tests/unpair.rs b/ql2/src/tests/unpair.rs deleted file mode 100644 index 7f6b8a79..00000000 --- a/ql2/src/tests/unpair.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::time::Duration; - -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn connected_unpair_removes_peer_on_both_sides() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - handle_a.connect().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Connected).await; - await_status(&status_b, peer_a.xid, PeerStage::Connected).await; - - handle_a.unpair().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - - let result_a = handle_a.open_stream(Vec::new(), Default::default()).await; - assert!(matches!(result_a, Err(QlError::NoPeerBound))); - - let result_b = handle_b.open_stream(Vec::new(), Default::default()).await; - assert!(matches!(result_b, Err(QlError::NoPeerBound))); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn unpair_works_without_session() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let (runtime_a, handle_a) = new_runtime(platform_a, config); - let (runtime_b, handle_b) = new_runtime(platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &peer_a, &peer_b); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - - handle_a.unpair().unwrap(); - - await_status(&status_a, peer_b.xid, PeerStage::Disconnected).await; - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - - let result_a = handle_a.open_stream(Vec::new(), Default::default()).await; - assert!(matches!(result_a, Err(QlError::NoPeerBound))); - - let result_b = handle_b.open_stream(Vec::new(), Default::default()).await; - assert!(matches!(result_b, Err(QlError::NoPeerBound))); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn invalid_unpair_signature_is_ignored() { - run_local_test(async { - let config = RuntimeConfig { - engine: EngineConfig { - handshake_timeout: Duration::from_millis(200), - ..Default::default() - }, - ..Default::default() - }; - let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, status_b) = TestPlatform::new(2); - let (fake_signer, _fake_outbound, _fake_status) = TestPlatform::new(3); - let peer_a = peer_identity(&platform_a); - let peer_b = peer_identity(&platform_b); - - let forged_unpair = wire::unpair::build_unpair_record( - &fake_signer, - QlHeader { - sender: peer_a.xid, - recipient: peer_b.xid, - }, - PacketId(777), - now_secs().saturating_add(60), - ); - let forged_bytes = wire::encode_record(&forged_unpair); - - let (runtime_b, handle_b) = new_runtime(platform_b, config); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - handle_b.bind_peer(Peer { - peer: peer_a.xid, - signing_key: peer_a.signing_key.clone(), - encapsulation_key: peer_a.encapsulation_key.clone(), - }); - await_status(&status_b, peer_a.xid, PeerStage::Disconnected).await; - - handle_b.send_incoming(forged_bytes); - - tokio::time::sleep(Duration::from_millis(20)).await; - - let result = handle_b.open_stream(Vec::new(), Default::default()).await; - assert!(matches!(result, Err(QlError::MissingSession))); - }) - .await; -} diff --git a/ql2/src/wire/mod.rs b/ql2/src/wire/mod.rs deleted file mode 100644 index 7052eba5..00000000 --- a/ql2/src/wire/mod.rs +++ /dev/null @@ -1,153 +0,0 @@ -//! quantum link protocol wire format -//! -//! naming conventions: -//! - *Record - unencrypted messages -//! - *Body - message content after decrypting -//! - -use bc_components::XID; -use rkyv::{ - api::{ - high::{to_bytes_in, HighSerializer, HighValidator}, - low::{self, LowDeserializer}, - }, - bytecheck::CheckBytes, - ser::allocator::ArenaHandle, - Archive, Deserialize, Portable, Serialize, -}; - -pub mod encrypted_message; -pub mod handshake; -pub mod heartbeat; -pub mod pair; -pub mod seq; -pub mod stream; -pub mod unpair; - -pub use seq::StreamSeq; - -mod codec; - -pub(crate) use codec::*; - -use self::{ - encrypted_message::EncryptedMessage, handshake::HandshakeRecord, pair::PairRequestRecord, - unpair::UnpairRecord, -}; -use crate::{PacketId, QlError}; - -pub(crate) type WireArchiveError = rkyv::rancor::Error; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct QlRecord { - pub header: QlHeader, - pub payload: QlPayload, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct QlHeader { - #[rkyv(with = AsWireXid)] - pub sender: XID, - #[rkyv(with = AsWireXid)] - pub recipient: XID, -} - -impl QlHeader { - pub fn aad(&self) -> Vec { - encode_value(self) - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct ControlMeta { - pub packet_id: PacketId, - pub valid_until: u64, -} - -impl From<&ArchivedControlMeta> for ControlMeta { - fn from(value: &ArchivedControlMeta) -> Self { - Self { - packet_id: (&value.packet_id).into(), - valid_until: value.valid_until.to_native(), - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum QlPayload { - Handshake(HandshakeRecord), - Pair(PairRequestRecord), - Unpair(UnpairRecord), - Heartbeat(EncryptedMessage), - Stream(EncryptedMessage), -} - -pub fn encode_record(record: &QlRecord) -> Vec { - encode_value(record) -} - -pub fn access_record(bytes: &[u8]) -> Result<&ArchivedQlRecord, QlError> { - access_value(bytes) -} - -pub fn decode_record(bytes: &[u8]) -> Result { - deserialize_value(access_record(bytes)?) -} - -pub(crate) fn encode_value( - value: &impl for<'a> Serialize, ArenaHandle<'a>, WireArchiveError>>, -) -> Vec { - to_bytes_in::<_, WireArchiveError>(value, Vec::new()) - .expect("wire serialization should not fail") -} - -pub(crate) fn access_value(bytes: &[u8]) -> Result<&T, QlError> -where - T: Portable + for<'a> CheckBytes>, -{ - rkyv::access::(bytes).map_err(|_| QlError::InvalidPayload) -} - -pub(crate) fn deserialize_value( - value: &impl rkyv::Deserialize>, -) -> Result { - low::deserialize::(value).map_err(|_| QlError::InvalidPayload) -} - -pub(crate) fn ensure_not_expired(valid_until: u64) -> Result<(), QlError> { - if now_secs() > valid_until { - Err(QlError::Timeout) - } else { - Ok(()) - } -} - -pub(crate) fn now_secs() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) -} - -#[test] -fn ql_record_round_trip() { - let record = QlRecord { - header: QlHeader { - sender: XID::from_data([1; XID::XID_SIZE]), - recipient: XID::from_data([2; XID::XID_SIZE]), - }, - payload: QlPayload::Heartbeat(encrypted_message::EncryptedMessage::encrypt( - &bc_components::SymmetricKey::from_data( - [7; bc_components::SymmetricKey::SYMMETRIC_KEY_SIZE], - ), - vec![3u8, 4, 5], - b"roundtrip", - [8; encrypted_message::NONCE_SIZE], - )), - }; - - let bytes = encode_record(&record); - let decoded = decode_record(&bytes).unwrap(); - - assert_eq!(decoded, record); -} From 3faca25103889e64e510691fd5403158c7325ae6 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:43 -0400 Subject: [PATCH 008/304] ql: rewrite wire format and introduce FSM-based session core --- Cargo.lock | 21 +- Cargo.toml | 13 +- ql-engine/Cargo.toml | 9 +- ql-engine/src/arena.rs | 194 ++++ .../src/engine/implementation/handshake.rs | 27 +- ql-engine/src/engine/implementation/mod.rs | 901 +++++++--------- ql-engine/src/engine/implementation/peer.rs | 73 +- ql-engine/src/engine/implementation/stream.rs | 995 +++++------------- ql-engine/src/engine/mod.rs | 226 ++-- ql-engine/src/engine/ring.rs | 392 ------- ql-engine/src/engine/state.rs | 157 +-- ql-engine/src/engine/tests/handshake.rs | 19 +- ql-engine/src/engine/tests/mod.rs | 174 ++- ql-engine/src/engine/tests/stream.rs | 125 +-- ql-engine/src/lib.rs | 3 +- ql-engine/src/stream/internal.rs | 842 +++++++++++++++ ql-engine/src/stream/mod.rs | 270 +++++ ql-engine/src/stream/ring.rs | 194 ++++ .../src/{engine/stream.rs => stream/state.rs} | 501 +++++---- ql-engine/src/stream/tests.rs | 334 ++++++ ql-engine/src/wire/handshake/crypto.rs | 3 +- ql-engine/src/wire/handshake/mod.rs | 4 +- ql-engine/src/wire/mod.rs | 57 +- ql-fsm/Cargo.toml | 20 + ql-fsm/src/implementation/handshake.rs | 713 +++++++++++++ ql-fsm/src/implementation/mod.rs | 135 +++ ql-fsm/src/implementation/peer.rs | 56 + ql-fsm/src/lib.rs | 84 ++ ql-fsm/src/replay_cache.rs | 38 + ql-fsm/src/session/internal.rs | 628 +++++++++++ ql-fsm/src/session/mod.rs | 174 +++ ql-fsm/src/session/ring.rs | 141 +++ ql-fsm/src/session/state.rs | 156 +++ ql-fsm/src/session/tests.rs | 158 +++ ql-fsm/src/state.rs | 177 ++++ ql-runtime/src/driver.rs | 442 +++++--- ql-runtime/src/tests/mod.rs | 10 +- ql-wire/Cargo.toml | 20 + ql-wire/src/codec.rs | 280 +++++ ql-wire/src/encrypted/close/mod.rs | 8 + ql-wire/src/encrypted/heartbeat/mod.rs | 4 + ql-wire/src/encrypted/mod.rs | 68 ++ ql-wire/src/encrypted/stream/mod.rs | 60 ++ ql-wire/src/encrypted/unpair/mod.rs | 4 + ql-wire/src/encrypted_message.rs | 63 ++ ql-wire/src/handshake/crypto.rs | 331 ++++++ ql-wire/src/handshake/mod.rs | 66 ++ ql-wire/src/id.rs | 51 + ql-wire/src/lib.rs | 228 ++++ ql-wire/src/pair/crypto.rs | 132 +++ ql-wire/src/pair/mod.rs | 28 + ql-wire/src/xid.rs | 16 + 52 files changed, 7292 insertions(+), 2533 deletions(-) create mode 100644 ql-engine/src/arena.rs delete mode 100644 ql-engine/src/engine/ring.rs create mode 100644 ql-engine/src/stream/internal.rs create mode 100644 ql-engine/src/stream/mod.rs create mode 100644 ql-engine/src/stream/ring.rs rename ql-engine/src/{engine/stream.rs => stream/state.rs} (69%) create mode 100644 ql-engine/src/stream/tests.rs create mode 100644 ql-fsm/Cargo.toml create mode 100644 ql-fsm/src/implementation/handshake.rs create mode 100644 ql-fsm/src/implementation/mod.rs create mode 100644 ql-fsm/src/implementation/peer.rs create mode 100644 ql-fsm/src/lib.rs create mode 100644 ql-fsm/src/replay_cache.rs create mode 100644 ql-fsm/src/session/internal.rs create mode 100644 ql-fsm/src/session/mod.rs create mode 100644 ql-fsm/src/session/ring.rs create mode 100644 ql-fsm/src/session/state.rs create mode 100644 ql-fsm/src/session/tests.rs create mode 100644 ql-fsm/src/state.rs create mode 100644 ql-wire/Cargo.toml create mode 100644 ql-wire/src/codec.rs create mode 100644 ql-wire/src/encrypted/close/mod.rs create mode 100644 ql-wire/src/encrypted/heartbeat/mod.rs create mode 100644 ql-wire/src/encrypted/mod.rs create mode 100644 ql-wire/src/encrypted/stream/mod.rs create mode 100644 ql-wire/src/encrypted/unpair/mod.rs create mode 100644 ql-wire/src/encrypted_message.rs create mode 100644 ql-wire/src/handshake/crypto.rs create mode 100644 ql-wire/src/handshake/mod.rs create mode 100644 ql-wire/src/id.rs create mode 100644 ql-wire/src/lib.rs create mode 100644 ql-wire/src/pair/crypto.rs create mode 100644 ql-wire/src/pair/mod.rs create mode 100644 ql-wire/src/xid.rs diff --git a/Cargo.lock b/Cargo.lock index d04be45b..f09e182f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1965,7 +1965,16 @@ version = "0.1.0" dependencies = [ "bc-components", "chacha20poly1305", - "dcbor", + "rkyv", + "thiserror", +] + +[[package]] +name = "ql-fsm" +version = "0.1.0" +dependencies = [ + "bc-components", + "ql-wire", "rkyv", "thiserror", ] @@ -1983,6 +1992,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "ql-wire" +version = "0.1.0" +dependencies = [ + "bc-components", + "chacha20poly1305", + "rkyv", + "thiserror", +] + [[package]] name = "quantum-link-macros" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 73284cea..496d5574 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,15 @@ [workspace] resolver = "2" -members = ["api", "backup-shard", "btp", "ql-engine", "ql-runtime", "quantum-link-macros"] +members = [ + "api", + "backup-shard", + "btp", + "ql-fsm", + "ql-engine", + "ql-runtime", + "ql-wire", + "quantum-link-macros", +] [workspace.package] homepage = "https://github.com/Foundation-Devices/foundation-api" @@ -25,6 +34,8 @@ btp = { path = "btp" } foundation-api = { path = "api" } quantum-link-macros = { path = "quantum-link-macros" } ql-protocol = { path = "ql-protocol" } +ql-fsm = { path = "ql-fsm" } +ql-wire = { path = "ql-wire" } [patch.crates-io] pqcrypto-traits = { git = "https://github.com/Foundation-Devices/pqcrypto", rev = "ebadf71214f67cb970242fa1053b4acb65767737" } diff --git a/ql-engine/Cargo.toml b/ql-engine/Cargo.toml index b6a9d09d..4803431f 100644 --- a/ql-engine/Cargo.toml +++ b/ql-engine/Cargo.toml @@ -10,6 +10,11 @@ bc-components = { version = "0.28.0", default-features = false, features = [ "pqcrypto", ] } chacha20poly1305 = { version = "0.10.1" } -dcbor = { version = "0.23.3" } -rkyv = { version = "0.8", default-features = false, features = ["std", "bytecheck", "little_endian", "unaligned", "pointer_width_32"] } +rkyv = { version = "0.8", default-features = false, features = [ + "std", + "bytecheck", + "little_endian", + "unaligned", + "pointer_width_32", +] } thiserror = { version = "2" } diff --git a/ql-engine/src/arena.rs b/ql-engine/src/arena.rs new file mode 100644 index 00000000..b7b5a4a4 --- /dev/null +++ b/ql-engine/src/arena.rs @@ -0,0 +1,194 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ArenaKey { + index: u32, + generation: u32, +} + +impl ArenaKey { + fn index(self) -> usize { + self.index as usize + } +} + +#[derive(Debug)] +struct Slot { + generation: u32, + value: Option, + next_free: Option, +} + +#[derive(Debug)] +pub struct GenerationalArena { + slots: Vec>, + free_head: Option, + len: usize, +} + +impl GenerationalArena { + pub fn new() -> Self { + Self { + slots: Vec::new(), + free_head: None, + len: 0, + } + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn contains(&self, key: ArenaKey) -> bool { + self.get(key).is_some() + } + + pub fn values(&self) -> impl Iterator { + self.slots.iter().filter_map(|slot| slot.value.as_ref()) + } + + pub fn clear(&mut self) { + self.slots.clear(); + self.free_head = None; + self.len = 0; + } + + pub fn insert(&mut self, value: T) -> ArenaKey { + self.len += 1; + + if let Some(index) = self.free_head { + let slot = &mut self.slots[index as usize]; + self.free_head = slot.next_free.take(); + slot.value = Some(value); + return ArenaKey { + index, + generation: slot.generation, + }; + } + + assert!(self.slots.len() < u32::MAX as usize); + let index = self.slots.len() as u32; + self.slots.push(Slot { + generation: 0, + value: Some(value), + next_free: None, + }); + ArenaKey { + index, + generation: 0, + } + } + + pub fn get(&self, key: ArenaKey) -> Option<&T> { + let slot = self.slots.get(key.index())?; + (slot.generation == key.generation) + .then_some(slot.value.as_ref()) + .flatten() + } + + pub fn get_mut(&mut self, key: ArenaKey) -> Option<&mut T> { + let slot = self.slots.get_mut(key.index())?; + (slot.generation == key.generation) + .then_some(slot.value.as_mut()) + .flatten() + } + + pub fn remove(&mut self, key: ArenaKey) -> Option { + let slot = self.slots.get_mut(key.index())?; + if slot.generation != key.generation { + return None; + } + + let value = slot.value.take()?; + slot.generation = slot.generation.wrapping_add(1); + slot.next_free = self.free_head; + self.free_head = Some(key.index); + self.len -= 1; + Some(value) + } + + pub fn retain(&mut self, mut f: impl FnMut(ArenaKey, &mut T) -> bool) { + for (index, slot) in self.slots.iter_mut().enumerate() { + let Some(value) = slot.value.as_mut() else { + continue; + }; + let key = ArenaKey { + index: index as u32, + generation: slot.generation, + }; + if f(key, value) { + continue; + } + let _ = slot.value.take(); + slot.generation = slot.generation.wrapping_add(1); + slot.next_free = self.free_head; + self.free_head = Some(index as u32); + self.len -= 1; + } + } +} + +impl Default for GenerationalArena { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::{ArenaKey, GenerationalArena}; + + #[test] + fn insert_get_remove_round_trips() { + let mut arena = GenerationalArena::new(); + let key = arena.insert("hello"); + + assert_eq!(arena.len(), 1); + assert_eq!(arena.get(key), Some(&"hello")); + assert!(arena.contains(key)); + + assert_eq!(arena.remove(key), Some("hello")); + assert!(arena.is_empty()); + assert_eq!(arena.get(key), None); + assert!(!arena.contains(key)); + } + + #[test] + fn stale_key_does_not_hit_reused_slot() { + let mut arena = GenerationalArena::new(); + let old = arena.insert(10); + assert_eq!(arena.remove(old), Some(10)); + + let new = arena.insert(20); + assert_eq!(old.index(), new.index()); + assert_ne!(old, new); + + assert_eq!(arena.get(old), None); + assert_eq!(arena.get(new), Some(&20)); + } + + #[test] + fn get_mut_updates_value() { + let mut arena = GenerationalArena::new(); + let key = arena.insert(String::from("a")); + + arena.get_mut(key).unwrap().push('b'); + + assert_eq!(arena.get(key).map(String::as_str), Some("ab")); + } + + #[test] + fn remove_rejects_wrong_generation() { + let mut arena = GenerationalArena::new(); + let key = arena.insert(1u32); + let wrong = ArenaKey { + index: key.index as u32, + generation: key.generation.wrapping_add(1), + }; + + assert_eq!(arena.remove(wrong), None); + assert_eq!(arena.get(key), Some(&1)); + } +} diff --git a/ql-engine/src/engine/implementation/handshake.rs b/ql-engine/src/engine/implementation/handshake.rs index 3e3db21a..0acf91dc 100644 --- a/ql-engine/src/engine/implementation/handshake.rs +++ b/ql-engine/src/engine/implementation/handshake.rs @@ -29,12 +29,8 @@ enum HelloReplyAction { }, } -pub fn handle_connect( - engine: &mut Engine, - now: Instant, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, -) { +pub fn handle_connect(engine: &mut Engine, crypto: &impl QlCrypto) { + let now = engine.state.now; let Some(_) = engine.peer.as_ref() else { return; }; @@ -48,18 +44,17 @@ pub fn handle_connect( start_initiator_handshake(config, identity, state, peer_record, now, crypto) }; if started { - engine.emit_peer_status(emit); + engine.emit_peer_status(); } } pub fn handle_hello( engine: &mut Engine, - now: Instant, peer: XID, hello: &wire::handshake::ArchivedHello, crypto: &impl QlCrypto, - emit: &mut impl OutputFn, ) { + let now = engine.state.now; let action = match engine.peer.as_ref() { Some(entry) => { if wire::handshake::verify_hello(peer, engine.identity.xid, &entry.signing_key, hello) @@ -127,7 +122,7 @@ pub fn handle_hello( ) }; if changed { - engine.emit_peer_status(emit); + engine.emit_peer_status(); } } HelloAction::ResendReply { @@ -153,11 +148,10 @@ pub fn handle_hello( pub fn handle_hello_reply( engine: &mut Engine, - now: Instant, peer: XID, reply: &wire::handshake::ArchivedHelloReply, - _emit: &mut impl OutputFn, ) { + let now = engine.state.now; let action = { let Some(peer_record) = engine.peer.as_ref() else { return; @@ -264,12 +258,11 @@ pub fn handle_hello_reply( pub fn handle_confirm( engine: &mut Engine, - now: Instant, peer: XID, confirm: &wire::handshake::ArchivedConfirm, crypto: &impl QlCrypto, - _emit: &mut impl OutputFn, ) { + let now = engine.state.now; if let Some((ready, deadline, token)) = current_ready_resend(engine, now, peer, confirm) { if engine.handshake_write_pending(token) { return; @@ -358,11 +351,9 @@ pub fn handle_confirm( pub fn handle_ready( engine: &mut Engine, - now: Instant, peer: XID, header: &QlHeader, ready: &mut wire::handshake::ArchivedReady, - emit: &mut impl OutputFn, ) { let session_key = { let Some(peer_record) = engine.peer.as_ref() else { @@ -394,8 +385,8 @@ pub fn handle_ready( recent_ready: None, }; } - engine.record_activity(now); - engine.emit_peer_status(emit); + engine.record_activity(); + engine.emit_peer_status(); } fn start_initiator_handshake( diff --git a/ql-engine/src/engine/implementation/mod.rs b/ql-engine/src/engine/implementation/mod.rs index 4549a243..4b747893 100644 --- a/ql-engine/src/engine/implementation/mod.rs +++ b/ql-engine/src/engine/implementation/mod.rs @@ -10,110 +10,97 @@ use rkyv::access_mut; use crate::{ engine::{ replay_cache::ReplayKey, - state::{ActiveWrite, ControlWritePayload, OutboundWriteKind, TimeoutKind}, - stream::{InFlightWriteState, StreamRole, StreamState}, - Engine, EngineInput, EngineOutput, HandshakeInitiator, HandshakeResponder, KeepAliveConfig, - KeepAliveState, OutboundWrite, OutputFn, PeerRecord, PeerSession, QlCrypto, RecentReady, + state::{ActiveWrite, OutboundWriteKind, TimeoutKind}, + Engine, EngineEvent, HandshakeInitiator, HandshakeResponder, KeepAliveConfig, + KeepAliveState, OutboundWrite, PeerRecord, PeerSession, QlCrypto, RecentReady, StreamConfig, Token, WriteId, }, wire::{ self, encrypted_message::{ArchivedEncryptedMessage, NONCE_SIZE}, - stream::{ - encrypt_stream, BodyChunk, CloseCode, CloseTarget, StreamAck, StreamBody, StreamFrame, - StreamFrameClose, StreamMessage, - }, - ControlMeta, QlHeader, StreamSeq, + stream::{BodyChunk, CloseCode, CloseTarget}, + ControlMeta, QlHeader, }, Peer, QlError, StreamId, }; impl Engine { - pub fn open_stream( + pub(crate) fn open_stream_inner( &mut self, - now: Instant, request_head: Vec, request_prefix: Option, config: StreamConfig, ) -> Result { - self.state.now = now; - stream::open_stream(self, now, request_head, request_prefix, config) + stream::open_stream(self, request_head, request_prefix, config) } - pub fn run_tick_inner( + pub(crate) fn bind_peer_inner(&mut self, peer: Peer) { + peer::handle_bind_peer(self, peer); + } + + pub(crate) fn pair_inner(&mut self, crypto: &impl QlCrypto) { + peer::handle_pair_local(self, crypto); + } + + pub(crate) fn connect_inner(&mut self, crypto: &impl QlCrypto) { + handshake::handle_connect(self, crypto); + } + + pub(crate) fn unpair_inner(&mut self) { + peer::handle_unpair_local(self); + } + + pub(crate) fn write_stream_inner( &mut self, - now: Instant, - input: EngineInput, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - self.state.now = now; - match input { - EngineInput::BindPeer(peer) => peer::handle_bind_peer(self, peer, emit), - EngineInput::Pair => peer::handle_pair_local(self, now, crypto), - EngineInput::Connect => handshake::handle_connect(self, now, crypto, emit), - EngineInput::Unpair => peer::handle_unpair_local(self, now, emit), - EngineInput::CloseStream { - stream_id, - target, - code, - payload, - } => stream::handle_close_stream(self, now, stream_id, target, code, payload), - EngineInput::OutboundData { stream_id, bytes } => { - stream::handle_outbound_data(self, stream_id, bytes) - } - EngineInput::OutboundFinished { stream_id } => { - stream::handle_outbound_finished(self, stream_id) - } - EngineInput::Incoming(bytes) => self.handle_incoming(now, bytes, crypto, emit), - EngineInput::TimerExpired => self.handle_timeouts(now, crypto, emit), - } + stream_id: StreamId, + bytes: Vec, + ) -> Result<(), QlError> { + stream::handle_outbound_data(self, stream_id, bytes) + } - self.handle_ready_retransmits(now, emit); + pub(crate) fn finish_stream_inner(&mut self, stream_id: StreamId) -> Result<(), QlError> { + stream::handle_outbound_finished(self, stream_id) } - pub fn take_next_write_inner(&mut self, crypto: &impl QlCrypto) -> Option { - self.take_next_control_write(crypto) - .or_else(|| stream::take_next_stream_write(self, crypto)) + pub(crate) fn close_stream_inner( + &mut self, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + ) -> Result<(), QlError> { + stream::handle_close_stream(self, stream_id, target, code, payload) + } + + pub(crate) fn receive_inner(&mut self, bytes: Vec, crypto: &impl QlCrypto) { + self.handle_incoming(bytes, crypto); } - pub fn complete_write_inner( + pub(crate) fn take_next_write_inner( &mut self, - write_id: WriteId, - result: Result<(), QlError>, - emit: &mut impl OutputFn, - ) { - let now = self.state.now; - let Some(active) = self.state.active_writes.remove(&write_id) else { + crypto: &impl QlCrypto, + ) -> Option { + self.take_next_control_write() + .or_else(|| stream::take_next_stream_write(self, crypto)) + } + + pub(crate) fn complete_write_inner(&mut self, write_id: WriteId, result: Result<(), QlError>) { + let Some(active) = self.state.active_writes.remove(write_id.0) else { return; }; - if let OutboundWriteKind::StreamAck { .. } = active.kind { - if let Some(token) = active.token { - self.clear_ack_outbound_token(token, result.is_err()); - } - } - if let Err(error) = result { - // only fail the stream if this frame is still in flight - // ACKs and protocol reset can remove it before write completion arrives - if let OutboundWriteKind::StreamFrame { stream_id, tx_seq } = active.kind { - if self - .streams - .get(&stream_id) - .is_some_and(|stream| stream.control.in_flight.contains_key(&tx_seq)) - { - self.fail_stream_by_id(stream_id, error.clone(), emit); - } + if let OutboundWriteKind::Stream(completion) = active.kind { + stream::complete_stream_write(self, completion, Err(error.clone())); } if self.is_handshake_token(active.token) { if let Some(entry) = self.peer.as_mut() { entry.session = PeerSession::Disconnected; } - self.emit_peer_status(emit); + self.emit_peer_status(); self.drop_outbound(); - self.abort_streams(error, emit); + self.abort_streams(error); } return; @@ -127,27 +114,108 @@ impl Engine { recent_ready, }; } - self.emit_peer_status(emit); - self.record_activity(now); + self.emit_peer_status(); + self.record_activity(); } if let Some(token) = active.token { - self.schedule_handshake_retry_after_write(token, now); + self.schedule_handshake_retry_after_write(token); + } + + if let OutboundWriteKind::Stream(completion) = active.kind { + stream::complete_stream_write(self, completion, Ok(())); + } + } + + pub(crate) fn on_timer_inner(&mut self, crypto: &impl QlCrypto) { + let now = self.state.now; + loop { + let Some(entry) = self + .state + .timeouts + .peek_mut() + .filter(|entry| entry.0.at <= now) + else { + break; + }; + let entry = std::collections::binary_heap::PeekMut::pop(entry).0; + match entry.kind { + TimeoutKind::Outbound { token } => { + self.state + .control_outbound + .retain(|message| message.token != token); + } + } } - if let OutboundWriteKind::StreamFrame { stream_id, tx_seq } = active.kind { - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream - .control - .complete_write(tx_seq, now + self.config.stream_ack_timeout); + stream::handle_stream_timeouts(self); + + if let Some(PeerRecord { + session: PeerSession::Connected { recent_ready, .. }, + .. + }) = self.peer.as_mut() + { + if recent_ready + .as_ref() + .is_some_and(|ready| ready.expires_at <= now) + { + *recent_ready = None; + } + } + + let handshake_due = self + .handshake_deadline() + .is_some_and(|deadline| deadline <= now); + if handshake_due { + self.fail_handshake(QlError::Timeout); + return; + } + + let handshake_retry_due = self + .handshake_retry_deadline() + .is_some_and(|deadline| deadline <= now); + if handshake_retry_due { + self.handle_handshake_retry_timeout(); + } + + let keepalive_due = self + .keep_alive_deadline() + .is_some_and(|deadline| deadline <= now); + if !keepalive_due { + return; + } + + let Some(entry) = self.peer.as_ref() else { + return; + }; + let PeerSession::Connected { keepalive, .. } = &entry.session else { + return; + }; + + if keepalive.pending { + if let Some(entry) = self.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + self.emit_peer_status(); + self.drop_outbound(); + self.abort_streams(QlError::SendFailed); + return; + } + + self.send_heartbeat_message(crypto); + if let Some(entry) = self.peer.as_mut() { + if let PeerSession::Connected { keepalive, .. } = &mut entry.session { + keepalive.pending = true; + keepalive.last_activity = Some(now); } } } - pub fn next_deadline_inner(&self) -> Option { + pub(crate) fn next_deadline_inner(&self) -> Option { [ self.state.next_deadline(), - self.streams.stream_retry_deadline(), + self.streams.next_deadline(), + self.handshake_retry_deadline(), self.handshake_deadline(), self.keep_alive_deadline(), ] @@ -155,65 +223,27 @@ impl Engine { .flatten() .min() } + + pub(crate) fn abort_inner(&mut self, error: QlError) { + self.abort_streams(error); + } } impl Engine { - fn emit_peer_status(&self, emit: &mut impl OutputFn) { - if let Some(peer) = self.peer.as_ref() { - emit(EngineOutput::PeerStatusChanged { + fn emit_peer_status(&mut self) { + let event = self + .peer + .as_ref() + .map(|peer| EngineEvent::PeerStatusChanged { peer: peer.peer, session: peer.session.clone(), }); + if let Some(event) = event { + self.state.pending_events.push_back(event); } } - fn next_control_meta(&self, valid_for: Duration) -> ControlMeta { - ControlMeta { - packet_id: self.state.next_packet_id(), - valid_until: wire::now_secs() + valid_for.as_secs(), - } - } - - fn keep_alive_deadline(&self) -> Option { - let config = self.keep_alive_config()?; - let entry = self.peer.as_ref()?; - let PeerSession::Connected { keepalive, .. } = &entry.session else { - return None; - }; - let base = keepalive.last_activity?; - Some( - base + if keepalive.pending { - config.timeout - } else { - config.interval - }, - ) - } - - fn handshake_deadline(&self) -> Option { - let entry = self.peer.as_ref()?; - match &entry.session { - PeerSession::Initiator { deadline, .. } | PeerSession::Responder { deadline, .. } => { - Some(*deadline) - } - PeerSession::Disconnected | PeerSession::Connected { .. } => None, - } - } - - fn is_replayed_control(&mut self, peer: XID, meta: ControlMeta) -> bool { - self.state - .replay_cache - .check_and_store_valid_until(ReplayKey::new(peer, meta.packet_id), meta.valid_until) - } - - // TODO: why do we pass 'now' if it's in state? - fn handle_incoming( - &mut self, - now: Instant, - mut bytes: Vec, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { + fn handle_incoming(&mut self, mut bytes: Vec, crypto: &impl QlCrypto) { let Ok(record) = access_mut::(&mut bytes) else { return; @@ -237,55 +267,51 @@ impl Engine { }; match &mut record.payload { wire::ArchivedQlPayload::Handshake(message) => { - self.handle_handshake(now, sender, &header, message, crypto, emit) + self.handle_handshake(sender, &header, message, crypto) } wire::ArchivedQlPayload::Stream(encrypted) => { - stream::handle_stream(self, now, sender, &header, encrypted, emit) + stream::handle_stream(self, sender, &header, encrypted) } wire::ArchivedQlPayload::Heartbeat(encrypted) => { - self.handle_heartbeat(now, &header, encrypted, crypto, emit) + self.handle_heartbeat(&header, encrypted, crypto) } wire::ArchivedQlPayload::Pair(request) => { - peer::handle_pairing(self, now, &header, request, crypto, emit) + peer::handle_pairing(self, &header, request, crypto) } wire::ArchivedQlPayload::Unpair(unpair_record) => { - peer::handle_unpair(self, sender, &header, unpair_record, emit) + peer::handle_unpair(self, sender, &header, unpair_record) } } } fn handle_handshake( &mut self, - now: Instant, peer: XID, header: &QlHeader, message: &mut wire::handshake::ArchivedHandshakeRecord, crypto: &impl QlCrypto, - emit: &mut impl OutputFn, ) { match message { wire::handshake::ArchivedHandshakeRecord::Hello(hello) => { - handshake::handle_hello(self, now, peer, hello, crypto, emit) + handshake::handle_hello(self, peer, hello, crypto) } wire::handshake::ArchivedHandshakeRecord::HelloReply(reply) => { - handshake::handle_hello_reply(self, now, peer, reply, emit) + handshake::handle_hello_reply(self, peer, reply) } wire::handshake::ArchivedHandshakeRecord::Confirm(confirm) => { - handshake::handle_confirm(self, now, peer, confirm, crypto, emit) + handshake::handle_confirm(self, peer, confirm, crypto) } wire::handshake::ArchivedHandshakeRecord::Ready(ready) => { - handshake::handle_ready(self, now, peer, header, ready, emit) + handshake::handle_ready(self, peer, header, ready) } } } fn handle_heartbeat( &mut self, - now: Instant, header: &QlHeader, encrypted: &mut ArchivedEncryptedMessage, crypto: &impl QlCrypto, - emit: &mut impl OutputFn, ) { let (body, should_reply) = { let Some(peer_record) = self.peer.as_ref() else { @@ -308,62 +334,209 @@ impl Engine { if self.is_replayed_control(header.sender, body.meta) { return; } - self.record_activity(now); + self.record_activity(); if should_reply { - self.send_heartbeat_message(now, crypto); + self.send_heartbeat_message(crypto); } - self.emit_peer_status(emit); - } - - fn handle_ready_retransmits(&mut self, now: Instant, emit: &mut impl OutputFn) { - let mut timed_out = Vec::new(); - for (stream_id, stream) in self.streams.iter() { - let exhausted = stream.control.in_flight.iter().any(|(_, in_flight)| { - matches!( - in_flight.write_state, - InFlightWriteState::WaitingRetry { retry_at } - if retry_at <= now && in_flight.attempt >= self.config.stream_retry_limit - ) - }); - if exhausted { - timed_out.push(*stream_id); + self.emit_peer_status(); + } + + fn fail_handshake(&mut self, error: QlError) { + if let Some(entry) = self.peer.as_mut() { + if matches!( + entry.session, + PeerSession::Initiator { .. } | PeerSession::Responder { .. } + ) { + entry.session = PeerSession::Disconnected; + } + } + self.emit_peer_status(); + self.drop_outbound(); + self.abort_streams(error); + } + + fn handle_handshake_retry_timeout(&mut self) { + enum RetryAction { + Resend { + token: Token, + peer: XID, + deadline: Instant, + record: wire::handshake::HandshakeRecord, + }, + Fail, + Ignore, + } + + let now = self.state.now; + let action = { + let Some(entry) = self.peer.as_mut() else { + return; + }; + let peer = entry.peer; + match &mut entry.session { + PeerSession::Initiator { + handshake_token, + hello, + deadline, + stage: + HandshakeInitiator::WaitingHelloReply { + retry_count, + retry_at, + }, + .. + } if retry_at.is_some_and(|at| at <= now) => { + let token = *handshake_token; + *retry_at = None; + if *retry_count >= self.config.max_handshake_retries { + RetryAction::Fail + } else { + *retry_count = retry_count.saturating_add(1); + RetryAction::Resend { + token, + peer, + deadline: *deadline, + record: wire::handshake::HandshakeRecord::Hello(hello.clone()), + } + } + } + PeerSession::Initiator { + handshake_token, + deadline, + stage: + HandshakeInitiator::WaitingReady { + confirm, + retry_count, + retry_at, + .. + }, + .. + } if retry_at.is_some_and(|at| at <= now) => { + let token = *handshake_token; + *retry_at = None; + if *retry_count >= self.config.max_handshake_retries { + RetryAction::Fail + } else { + *retry_count = retry_count.saturating_add(1); + RetryAction::Resend { + token, + peer, + deadline: *deadline, + record: wire::handshake::HandshakeRecord::Confirm(confirm.clone()), + } + } + } + PeerSession::Responder { + handshake_token, + reply, + deadline, + stage: + HandshakeResponder::WaitingConfirm { + retry_count, + retry_at, + .. + }, + .. + } if retry_at.is_some_and(|at| at <= now) => { + let token = *handshake_token; + *retry_at = None; + if *retry_count >= self.config.max_handshake_retries { + RetryAction::Fail + } else { + *retry_count = retry_count.saturating_add(1); + RetryAction::Resend { + token, + peer, + deadline: *deadline, + record: wire::handshake::HandshakeRecord::HelloReply(reply.clone()), + } + } + } + _ => RetryAction::Ignore, } + }; + + match action { + RetryAction::Resend { + token, + peer, + deadline, + record, + } => { + if self.handshake_write_pending(token) { + return; + } + handshake::enqueue_handshake_record(self, token, deadline, peer, record); + } + RetryAction::Fail => self.fail_handshake(QlError::Timeout), + RetryAction::Ignore => {} } + } + + fn abort_streams(&mut self, error: QlError) { + stream::abort_streams(self, error); + } + + fn next_control_meta(&self, valid_for: Duration) -> ControlMeta { + ControlMeta { + packet_id: self.state.next_packet_id(), + valid_until: wire::now_secs() + valid_for.as_secs(), + } + } + + fn keep_alive_deadline(&self) -> Option { + let config = self.keep_alive_config()?; + let entry = self.peer.as_ref()?; + let PeerSession::Connected { keepalive, .. } = &entry.session else { + return None; + }; + let base = keepalive.last_activity?; + Some( + base + if keepalive.pending { + config.timeout + } else { + config.interval + }, + ) + } - for stream_id in timed_out { - self.fail_stream_by_id(stream_id, QlError::Timeout, emit); + fn handshake_deadline(&self) -> Option { + let entry = self.peer.as_ref()?; + match &entry.session { + PeerSession::Initiator { deadline, .. } | PeerSession::Responder { deadline, .. } => { + Some(*deadline) + } + PeerSession::Disconnected | PeerSession::Connected { .. } => None, } } - fn clear_ack_outbound_token(&mut self, token: Token, retry: bool) { - for stream in self.streams.values_mut() { - let control = &mut stream.control; - if control.ack_outbound_token == Some(token) { - control.ack_outbound_token = None; - if retry { - control.note_ack(true); - } - break; + fn handshake_retry_deadline(&self) -> Option { + let entry = self.peer.as_ref()?; + match &entry.session { + PeerSession::Initiator { + stage: HandshakeInitiator::WaitingHelloReply { retry_at, .. }, + .. + } + | PeerSession::Initiator { + stage: HandshakeInitiator::WaitingReady { retry_at, .. }, + .. + } + | PeerSession::Responder { + stage: HandshakeResponder::WaitingConfirm { retry_at, .. }, + .. + } => *retry_at, + PeerSession::Disconnected + | PeerSession::Responder { + stage: HandshakeResponder::SendingReady { .. }, + .. } + | PeerSession::Connected { .. } => None, } } - fn clear_active_writes_for_stream(&mut self, stream_id: StreamId) { + fn is_replayed_control(&mut self, peer: XID, meta: ControlMeta) -> bool { self.state - .active_writes - .retain(|_, active| match active.kind { - OutboundWriteKind::Control => true, - OutboundWriteKind::StreamAck { - stream_id: active_stream_id, - } - | OutboundWriteKind::StreamClose { - stream_id: active_stream_id, - } => active_stream_id != stream_id, - OutboundWriteKind::StreamFrame { - stream_id: active_stream_id, - .. - } => active_stream_id != stream_id, - }); + .replay_cache + .check_and_store_valid_until(ReplayKey::new(peer, meta.packet_id), meta.valid_until) } fn is_handshake_token(&self, token: Option) -> bool { @@ -437,16 +610,17 @@ impl Engine { } } - fn schedule_handshake_retry_after_write(&mut self, token: Token, now: Instant) { + fn schedule_handshake_retry_after_write(&mut self, token: Token) { if self.config.handshake_retry_interval.is_zero() || self.config.max_handshake_retries == 0 { return; } + let now = self.state.now; let retry_at = now + self.config.handshake_retry_interval; let Some(entry) = self.peer.as_mut() else { return; }; - let scheduled = match &mut entry.session { + match &mut entry.session { PeerSession::Initiator { handshake_token, stage: @@ -457,7 +631,6 @@ impl Engine { .. } if *handshake_token == token => { *stage_retry_at = Some(retry_at); - true } PeerSession::Initiator { handshake_token, @@ -469,7 +642,6 @@ impl Engine { .. } if *handshake_token == token => { *stage_retry_at = Some(retry_at); - true } PeerSession::Responder { handshake_token, @@ -481,16 +653,12 @@ impl Engine { .. } if *handshake_token == token => { *stage_retry_at = Some(retry_at); - true } - _ => false, - }; - if scheduled { - self.state.schedule_handshake_retry(token, retry_at); + _ => {} } } - fn stream_write_session(&self) -> Option<(XID, SymmetricKey)> { + fn peer_session(&self) -> Option<(XID, SymmetricKey)> { self.peer.as_ref().and_then(|peer| { peer.session .session_key() @@ -498,70 +666,43 @@ impl Engine { }) } + // todo: this is called in too many places + fn sync_stream_namespace(&mut self) { + use crate::stream::StreamNamespace; + let namespace = self + .peer + .as_ref() + .map(|peer| StreamNamespace::for_local(self.identity.xid, peer.peer)) + .unwrap_or(crate::stream::StreamNamespace::Low); + self.streams.set_local_namespace(namespace); + } + fn issue_write( &mut self, kind: OutboundWriteKind, token: Option, bytes: Vec, ) -> OutboundWrite { - let id = self.state.next_write_id(); - self.state - .active_writes - .insert(id, ActiveWrite { token, kind }); + let id = WriteId(self.state.active_writes.insert(ActiveWrite { token, kind })); OutboundWrite { id, bytes } } - fn take_next_control_write(&mut self, crypto: &impl QlCrypto) -> Option { + fn take_next_control_write(&mut self) -> Option { while let Some(message) = self.state.control_outbound.pop_front() { - let bytes = match message.payload { - ControlWritePayload::Encoded(bytes) => bytes, - ControlWritePayload::StreamClose { - stream_id, - target, - code, - payload, - } => { - let Some((recipient, session_key)) = self.stream_write_session() else { - continue; - }; - let body = StreamBody::Message(StreamMessage { - tx_seq: StreamSeq::START, - ack: StreamAck::EMPTY, - valid_until: wire::now_secs() - .saturating_add(self.config.packet_expiration.as_secs()), - frame: StreamFrame::Close(StreamFrameClose { - stream_id, - target, - code, - payload, - }), - }); - let record = encrypt_stream( - QlHeader { - sender: self.identity.xid, - recipient, - }, - &session_key, - &body, - encrypted_message_nonce(crypto), - ); - wire::encode_record(&record) - } - }; - return Some(self.issue_write(message.kind, Some(message.token), bytes)); + return Some(self.issue_write( + OutboundWriteKind::Control, + Some(message.token), + message.bytes, + )); } None } - fn send_ephemeral_close(&mut self, stream_id: StreamId, target: CloseTarget, code: CloseCode) { - self.state - .enqueue_stream_close(&self.config, true, stream_id, target, code, Vec::new()); - } - - fn send_heartbeat_message(&mut self, now: Instant, crypto: &impl QlCrypto) { + fn send_heartbeat_message(&mut self, crypto: &impl QlCrypto) { let Some(peer) = self.peer.as_ref().map(|peer| peer.peer) else { return; }; + let now = self.state.now; let meta = self.next_control_meta(self.config.packet_expiration); let token = self.state.next_token(); let deadline = now + self.config.packet_expiration; @@ -596,7 +737,8 @@ impl Engine { .filter(|config| !config.interval.is_zero() && !config.timeout.is_zero()) } - fn record_activity(&mut self, now: Instant) { + fn record_activity(&mut self) { + let now = self.state.now; if let Some(PeerRecord { session: PeerSession::Connected { keepalive, .. }, .. @@ -611,279 +753,6 @@ impl Engine { self.state.control_outbound.clear(); self.state.active_writes.clear(); } - - fn fail_handshake(&mut self, error: QlError, emit: &mut impl OutputFn) { - if let Some(entry) = self.peer.as_mut() { - if matches!( - entry.session, - PeerSession::Initiator { .. } | PeerSession::Responder { .. } - ) { - entry.session = PeerSession::Disconnected; - } - } - self.emit_peer_status(emit); - self.drop_outbound(); - self.abort_streams(error, emit); - } - - fn handle_handshake_retry_timeout(&mut self, token: Token, emit: &mut impl OutputFn) { - enum RetryAction { - Resend { - peer: XID, - deadline: Instant, - record: wire::handshake::HandshakeRecord, - }, - Fail, - Ignore, - } - - let now = self.state.now; - let action = { - let Some(entry) = self.peer.as_mut() else { - return; - }; - let peer = entry.peer; - match &mut entry.session { - PeerSession::Initiator { - handshake_token, - hello, - deadline, - stage: - HandshakeInitiator::WaitingHelloReply { - retry_count, - retry_at, - }, - .. - } if *handshake_token == token && retry_at.is_some_and(|at| at <= now) => { - *retry_at = None; - if *retry_count >= self.config.max_handshake_retries { - RetryAction::Fail - } else { - *retry_count = retry_count.saturating_add(1); - RetryAction::Resend { - peer, - deadline: *deadline, - record: wire::handshake::HandshakeRecord::Hello(hello.clone()), - } - } - } - PeerSession::Initiator { - handshake_token, - deadline, - stage: - HandshakeInitiator::WaitingReady { - confirm, - retry_count, - retry_at, - .. - }, - .. - } if *handshake_token == token && retry_at.is_some_and(|at| at <= now) => { - *retry_at = None; - if *retry_count >= self.config.max_handshake_retries { - RetryAction::Fail - } else { - *retry_count = retry_count.saturating_add(1); - RetryAction::Resend { - peer, - deadline: *deadline, - record: wire::handshake::HandshakeRecord::Confirm(confirm.clone()), - } - } - } - PeerSession::Responder { - handshake_token, - reply, - deadline, - stage: - HandshakeResponder::WaitingConfirm { - retry_count, - retry_at, - .. - }, - .. - } if *handshake_token == token && retry_at.is_some_and(|at| at <= now) => { - *retry_at = None; - if *retry_count >= self.config.max_handshake_retries { - RetryAction::Fail - } else { - *retry_count = retry_count.saturating_add(1); - RetryAction::Resend { - peer, - deadline: *deadline, - record: wire::handshake::HandshakeRecord::HelloReply(reply.clone()), - } - } - } - _ => RetryAction::Ignore, - } - }; - - match action { - RetryAction::Resend { - peer, - deadline, - record, - } => { - if self.handshake_write_pending(token) { - return; - } - handshake::enqueue_handshake_record(self, token, deadline, peer, record); - } - RetryAction::Fail => self.fail_handshake(QlError::Timeout, emit), - RetryAction::Ignore => {} - } - } - - fn abort_streams(&mut self, error: QlError, emit: &mut impl OutputFn) { - let streams = std::mem::take(&mut self.streams).into_inner(); - for (stream_id, stream) in streams { - self.fail_stream(stream_id, stream, error.clone(), emit); - } - } - - fn fail_stream_by_id(&mut self, stream_id: StreamId, error: QlError, emit: &mut impl OutputFn) { - let Some(stream) = self.streams.remove(&stream_id) else { - return; - }; - self.fail_stream(stream_id, stream, error, emit); - } - - pub fn fail_stream( - &mut self, - stream_id: StreamId, - stream: StreamState, - error: QlError, - emit: &mut impl OutputFn, - ) { - self.clear_active_writes_for_stream(stream_id); - match stream.role { - StreamRole::Initiator(_) => { - emit(EngineOutput::OutboundFailed { - stream_id, - error: error.clone(), - }); - emit(EngineOutput::InboundFailed { stream_id, error }); - } - StreamRole::Responder(stream) => { - emit(EngineOutput::InboundFailed { - stream_id, - error: error.clone(), - }); - if stream.response_started || stream.response.is_closed() { - emit(EngineOutput::OutboundFailed { stream_id, error }); - } - } - StreamRole::Provisional(_) => {} - } - emit(EngineOutput::StreamReaped { stream_id }); - } - - pub fn handle_timeouts( - &mut self, - now: Instant, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - loop { - let Some(entry) = self - .state - .timeouts - .peek_mut() - .filter(|entry| entry.0.at <= now) - else { - break; - }; - let entry = std::collections::binary_heap::PeekMut::pop(entry).0; - match entry.kind { - TimeoutKind::Outbound { token } => { - self.state - .control_outbound - .retain(|message| message.token != token); - } - TimeoutKind::HandshakeRetry { token } => { - self.handle_handshake_retry_timeout(token, emit); - } - TimeoutKind::StreamAckDelay { stream_id, token } => { - if let Some(stream) = self.streams.get_mut(&stream_id) { - let control = &mut stream.control; - if control.ack_delay_token == Some(token) { - control.ack_delay_token = None; - control.ack_immediate = true; - } - } - } - TimeoutKind::StreamProvisional { stream_id, token } => { - let should_reset = self - .streams - .get(&stream_id) - .and_then(StreamState::provisional_timeout_token) - .is_some_and(|stream_token| stream_token == token); - if should_reset { - self.streams.remove(&stream_id); - self.send_ephemeral_close( - stream_id, - CloseTarget::Both, - CloseCode::PROTOCOL, - ); - } - } - } - } - - if let Some(PeerRecord { - session: PeerSession::Connected { recent_ready, .. }, - .. - }) = self.peer.as_mut() - { - if recent_ready - .as_ref() - .is_some_and(|ready| ready.expires_at <= now) - { - *recent_ready = None; - } - } - - let handshake_due = self - .handshake_deadline() - .is_some_and(|deadline| deadline <= now); - if handshake_due { - self.fail_handshake(QlError::Timeout, emit); - return; - } - - let keepalive_due = self - .keep_alive_deadline() - .is_some_and(|deadline| deadline <= now); - if !keepalive_due { - return; - } - - let Some(entry) = self.peer.as_ref() else { - return; - }; - let PeerSession::Connected { keepalive, .. } = &entry.session else { - return; - }; - - if keepalive.pending { - if let Some(entry) = self.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(emit); - self.drop_outbound(); - self.abort_streams(QlError::SendFailed, emit); - return; - } - - self.send_heartbeat_message(now, crypto); - if let Some(entry) = self.peer.as_mut() { - if let PeerSession::Connected { keepalive, .. } = &mut entry.session { - keepalive.pending = true; - keepalive.last_activity = Some(now); - } - } - } } fn encrypted_message_nonce(crypto: &impl QlCrypto) -> [u8; NONCE_SIZE] { diff --git a/ql-engine/src/engine/implementation/peer.rs b/ql-engine/src/engine/implementation/peer.rs index ef2dd06b..05f7d76f 100644 --- a/ql-engine/src/engine/implementation/peer.rs +++ b/ql-engine/src/engine/implementation/peer.rs @@ -1,16 +1,20 @@ use super::*; -pub fn handle_bind_peer(engine: &mut Engine, peer: Peer, emit: &mut impl OutputFn) { - if let Some(existing) = engine.peer.as_ref() { - emit(EngineOutput::PeerStatusChanged { - peer: existing.peer, - session: PeerSession::Disconnected, - }); +pub fn handle_bind_peer(engine: &mut Engine, peer: Peer) { + if let Some(peer) = engine.peer.as_ref().map(|existing| existing.peer) { + engine + .state + .pending_events + .push_back(EngineEvent::PeerStatusChanged { + peer, + session: PeerSession::Disconnected, + }); } - bind_peer_record(engine, peer, emit); + bind_peer_record(engine, peer); } -pub fn handle_pair_local(engine: &mut Engine, now: Instant, crypto: &impl QlCrypto) { +pub fn handle_pair_local(engine: &mut Engine, crypto: &impl QlCrypto) { + let now = engine.state.now; let Some(peer) = engine.peer.as_ref() else { return; }; @@ -33,7 +37,8 @@ pub fn handle_pair_local(engine: &mut Engine, now: Instant, crypto: &impl QlCryp ); } -pub fn handle_unpair_local(engine: &mut Engine, now: Instant, emit: &mut impl OutputFn) { +pub fn handle_unpair_local(engine: &mut Engine) { + let now = engine.state.now; let Some(peer) = engine.peer.as_ref().map(|peer| peer.peer) else { return; }; @@ -46,7 +51,7 @@ pub fn handle_unpair_local(engine: &mut Engine, now: Instant, emit: &mut impl Ou }, meta, ); - unpair_peer(engine, emit); + unpair_peer(engine); let token = engine.state.next_token(); engine.state.enqueue_handshake_message( &engine.config, @@ -58,11 +63,9 @@ pub fn handle_unpair_local(engine: &mut Engine, now: Instant, emit: &mut impl Ou pub fn handle_pairing( engine: &mut Engine, - now: Instant, header: &QlHeader, request: &mut wire::pair::ArchivedPairRequestRecord, crypto: &impl QlCrypto, - emit: &mut impl OutputFn, ) { let payload = match wire::pair::decrypt_pair_request(&engine.identity, header, request) { Ok(payload) => payload, @@ -87,10 +90,9 @@ pub fn handle_pairing( signing_key: payload.signing_pub_key, encapsulation_key: payload.encapsulation_pub_key, }, - emit, ); } - handshake::handle_connect(engine, now, crypto, emit); + handshake::handle_connect(engine, crypto); } pub fn handle_unpair( @@ -98,7 +100,6 @@ pub fn handle_unpair( peer: XID, header: &QlHeader, record: &wire::unpair::ArchivedUnpairRecord, - emit: &mut impl OutputFn, ) { { let Some(peer_record) = engine.peer.as_ref() else { @@ -112,42 +113,48 @@ pub fn handle_unpair( if engine.is_replayed_control(peer, meta) { return; } - unpair_peer(engine, emit); + unpair_peer(engine); } -fn bind_peer_record(engine: &mut Engine, peer: Peer, emit: &mut impl OutputFn) { - reset_runtime(engine, QlError::Cancelled, emit); +fn bind_peer_record(engine: &mut Engine, peer: Peer) { + reset_runtime(engine, QlError::Cancelled); engine.peer = Some(PeerRecord::new( peer.peer, peer.signing_key, peer.encapsulation_key, )); - engine.emit_peer_status(emit); - if let Some(peer) = engine.peer.as_ref() { - emit(EngineOutput::PersistPeer(peer.snapshot())); + engine.emit_peer_status(); + if let Some(peer) = engine.peer.as_ref().map(PeerRecord::snapshot) { + engine + .state + .pending_events + .push_back(EngineEvent::PersistPeer(peer)); } } -fn reset_runtime(engine: &mut Engine, error: QlError, emit: &mut impl OutputFn) { - let streams = std::mem::take(&mut engine.streams).into_inner(); - for (stream_id, stream) in streams { - engine.fail_stream(stream_id, stream, error.clone(), emit); - } +fn reset_runtime(engine: &mut Engine, error: QlError) { + engine.abort_streams(error); engine.state.control_outbound.clear(); engine.state.active_writes.clear(); engine.state.timeouts.clear(); } -fn unpair_peer(engine: &mut Engine, emit: &mut impl OutputFn) { +fn unpair_peer(engine: &mut Engine) { let Some(peer) = engine.peer.as_ref().map(|peer| peer.peer) else { return; }; engine.drop_outbound(); - engine.abort_streams(QlError::SendFailed, emit); + engine.abort_streams(QlError::SendFailed); engine.peer = None; - emit(EngineOutput::PeerStatusChanged { - peer, - session: PeerSession::Disconnected, - }); - emit(EngineOutput::ClearPeer); + engine + .state + .pending_events + .push_back(EngineEvent::PeerStatusChanged { + peer, + session: PeerSession::Disconnected, + }); + engine + .state + .pending_events + .push_back(EngineEvent::ClearPeer); } diff --git a/ql-engine/src/engine/implementation/stream.rs b/ql-engine/src/engine/implementation/stream.rs index a957017c..02dd25a7 100644 --- a/ql-engine/src/engine/implementation/stream.rs +++ b/ql-engine/src/engine/implementation/stream.rs @@ -1,25 +1,175 @@ -use std::cmp::Reverse; - use super::*; use crate::{ - engine::{ - EngineConfig, EngineState, StreamConfig, - state::{StreamNamespace, TimeoutEntry}, - stream::*, - }, + engine::{state::OutboundWriteKind, Engine, EngineEvent, EngineState, QlCrypto}, + stream::{StreamCloseEvent, StreamCloseKind, StreamError, StreamEventSink, WriteError}, wire::stream::*, }; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum StreamHandleResult { - Keep, - Remove, - Reap, +struct EngineStreamSink<'a> { + state: &'a mut EngineState, +} + +impl EngineStreamSink<'_> { + fn clear_active_writes_for_stream(&mut self, stream_id: StreamId) { + self.state + .active_writes + .retain(|_, active| match active.kind { + OutboundWriteKind::Control => true, + OutboundWriteKind::Stream(completion) => completion.stream_id() != stream_id, + }); + } + + fn emit_remote_close(&mut self, event: StreamCloseEvent) { + let error = QlError::StreamClosed { + target: event.frame.target, + code: event.frame.code, + payload: event.frame.payload, + }; + + match event.role { + crate::stream::StreamLocalRole::Initiator => { + if matches!(event.frame.target, CloseTarget::Request | CloseTarget::Both) { + self.state + .pending_events + .push_back(EngineEvent::OutboundFailed { + stream_id: event.frame.stream_id, + error: error.clone(), + }); + } + if matches!( + event.frame.target, + CloseTarget::Response | CloseTarget::Both + ) { + self.state + .pending_events + .push_back(EngineEvent::InboundFailed { + stream_id: event.frame.stream_id, + error, + }); + } + } + crate::stream::StreamLocalRole::Responder => { + if matches!(event.frame.target, CloseTarget::Request | CloseTarget::Both) { + self.state + .pending_events + .push_back(EngineEvent::InboundFailed { + stream_id: event.frame.stream_id, + error: error.clone(), + }); + } + if matches!( + event.frame.target, + CloseTarget::Response | CloseTarget::Both + ) { + self.state + .pending_events + .push_back(EngineEvent::OutboundFailed { + stream_id: event.frame.stream_id, + error, + }); + } + } + } + } + + fn emit_acked_close(&mut self, event: StreamCloseEvent) { + let affects_outbound = match event.role { + crate::stream::StreamLocalRole::Initiator => { + matches!(event.frame.target, CloseTarget::Request | CloseTarget::Both) + } + crate::stream::StreamLocalRole::Responder => { + matches!( + event.frame.target, + CloseTarget::Response | CloseTarget::Both + ) + } + }; + if !affects_outbound { + return; + } + + self.state + .pending_events + .push_back(EngineEvent::OutboundFailed { + stream_id: event.frame.stream_id, + error: QlError::StreamClosed { + target: event.frame.target, + code: event.frame.code, + payload: event.frame.payload, + }, + }); + } +} + +impl StreamEventSink for EngineStreamSink<'_> { + fn opened( + &mut self, + stream_id: StreamId, + request_head: Vec, + request_prefix: Option, + ) { + self.state + .pending_events + .push_back(EngineEvent::InboundStreamOpened { + stream_id, + request_head, + request_prefix, + }); + } + + fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { + self.state + .pending_events + .push_back(EngineEvent::InboundData { stream_id, bytes }); + } + + fn inbound_finished(&mut self, stream_id: StreamId) { + self.state + .pending_events + .push_back(EngineEvent::InboundFinished { stream_id }); + } + + fn inbound_failed(&mut self, stream_id: StreamId, error: StreamError) { + self.state + .pending_events + .push_back(EngineEvent::InboundFailed { + stream_id, + error: stream_error(error), + }); + } + + fn close(&mut self, event: StreamCloseEvent) { + match event.kind { + StreamCloseKind::Acked => self.emit_acked_close(event), + StreamCloseKind::Remote => self.emit_remote_close(event), + } + } + + fn outbound_closed(&mut self, stream_id: StreamId) { + self.state + .pending_events + .push_back(EngineEvent::OutboundClosed { stream_id }); + } + + fn outbound_failed(&mut self, stream_id: StreamId, error: StreamError) { + self.state + .pending_events + .push_back(EngineEvent::OutboundFailed { + stream_id, + error: stream_error(error), + }); + } + + fn reaped(&mut self, stream_id: StreamId) { + self.clear_active_writes_for_stream(stream_id); + self.state + .pending_events + .push_back(EngineEvent::StreamReaped { stream_id }); + } } pub fn open_stream( engine: &mut Engine, - now: Instant, request_head: Vec, request_prefix: Option, _config: StreamConfig, @@ -31,128 +181,48 @@ pub fn open_stream( return Err(QlError::MissingSession); } - let stream_namespace = StreamNamespace::for_local(engine.identity.xid, entry.peer); - let stream_id = engine.state.next_stream_id(stream_namespace); - let request_prefix_fin = request_prefix.as_ref().is_some_and(|chunk| chunk.fin); - let frame = StreamFrameOpen { - stream_id, - request_head, - request_prefix, - }; - let mut stream = StreamState { - meta: StreamMeta { - stream_id, - last_activity: now, - }, - control: StreamControl { - pending: std::collections::VecDeque::from([StreamFrame::Open(frame)]), - ..Default::default() - }, - role: StreamRole::Initiator(InitiatorStream { - request: OutboundPhase::from_prefix(request_prefix_fin), - response: InboundState::new(), - }), - }; - drive_stream(&mut stream); - engine.streams.insert(stream_id, stream); - Ok(stream_id) + engine.sync_stream_namespace(); + Ok(engine.streams.open_stream(request_head, request_prefix)) } pub fn handle_close_stream( engine: &mut Engine, - now: Instant, stream_id: StreamId, target: CloseTarget, code: CloseCode, payload: Vec, -) { - let Some(stream) = engine.streams.get_mut(&stream_id) else { - return; - }; - - let mut dirty = false; - - if matches!(target, CloseTarget::Request | CloseTarget::Both) { - if let Some(inbound) = stream.inbound_mut(StreamSide::Request) { - dirty |= inbound.close(); - } - if let Some(outbound) = stream.outbound_mut(StreamSide::Request) { - dirty |= outbound.close(); - } - } - if matches!(target, CloseTarget::Response | CloseTarget::Both) { - if let Some(inbound) = stream.inbound_mut(StreamSide::Response) { - dirty |= inbound.close(); - } - if let Some(outbound) = stream.outbound_mut(StreamSide::Response) { - dirty |= outbound.close(); - } - } - - if dirty { - stream - .control - .queue_frame_front(close_frame(stream_id, target, code, payload)); - stream.meta.last_activity = now; - drive_stream(stream); - } +) -> Result<(), QlError> { + engine + .streams + .close_stream(stream_id, target, code, payload) + .map_err(stream_error) } -pub fn handle_outbound_data(engine: &mut Engine, stream_id: StreamId, bytes: Vec) { - if bytes.is_empty() { - return; - } - let Some(stream) = engine.streams.get_mut(&stream_id) else { - return; - }; - let Some(side) = stream.outbound_side() else { - return; - }; - if let StreamRole::Responder(state) = &mut stream.role { - if side == StreamSide::Response { - state.response_started = true; - } - } - let Some(outbound) = stream.outbound_mut(side) else { - return; - }; - if !outbound.can_queue_data() { - return; - } - let chunk = BodyChunk { bytes, fin: false }; - stream - .control - .queue_frame_back(StreamFrame::Data(StreamFrameData { stream_id, chunk })); - drive_stream(stream); +pub fn handle_outbound_data( + engine: &mut Engine, + stream_id: StreamId, + bytes: Vec, +) -> Result<(), QlError> { + engine + .streams + .write_stream(stream_id, bytes) + .map_err(stream_error) } -pub fn handle_outbound_finished(engine: &mut Engine, stream_id: StreamId) { - let Some(stream) = engine.streams.get_mut(&stream_id) else { - return; - }; - let Some(side) = stream.outbound_side() else { - return; - }; - if let StreamRole::Responder(state) = &mut stream.role { - if side == StreamSide::Response { - state.response_started = true; - } - } - let Some(outbound) = stream.outbound_mut(side) else { - return; - }; - outbound.finish(); - drive_stream(stream); +pub fn handle_outbound_finished(engine: &mut Engine, stream_id: StreamId) -> Result<(), QlError> { + engine + .streams + .finish_stream(stream_id) + .map_err(stream_error) } pub fn handle_stream( engine: &mut Engine, - now: Instant, _peer: XID, header: &QlHeader, encrypted: &mut ArchivedEncryptedMessage, - emit: &mut impl OutputFn, ) { + let now = engine.state.now; let body = { let Some(peer_record) = engine.peer.as_ref() else { return; @@ -165,658 +235,103 @@ pub fn handle_stream( Err(_) => return, } }; - engine.record_activity(now); - - let message = match body { - StreamBody::Ack(StreamAckBody { stream_id, ack, .. }) => { - process_stream_ack(engine, now, stream_id, ack, emit); - if let Some(stream) = engine.streams.get_mut(&stream_id) { - stream.meta.last_activity = now; - } - maybe_reap_stream(engine, stream_id, emit); - return; - } - StreamBody::Message(message) => message, - }; - - let stream_id = message.frame.stream_id(); - process_stream_ack(engine, now, stream_id, message.ack, emit); - if !engine.streams.contains_key(&stream_id) { - let Some(peer_record) = engine.peer.as_ref() else { - return; - }; - let local_namespace = StreamNamespace::for_local(engine.identity.xid, peer_record.peer); - if !local_namespace.remote().matches(stream_id) { - return; - } - let token = engine.state.next_token(); - engine.streams.insert( - stream_id, - StreamState { - meta: StreamMeta { - stream_id, - last_activity: now, - }, - control: StreamControl::default(), - role: StreamRole::Provisional(ProvisionalStream { - timeout_token: token, - }), - }, - ); - engine.state.timeouts.push(Reverse(TimeoutEntry { - at: now + engine.config.packet_expiration, - kind: TimeoutKind::StreamProvisional { stream_id, token }, - })); - } + engine.record_activity(); + engine.sync_stream_namespace(); - let disposition = { - let (state, streams) = (&mut engine.state, &mut engine.streams); - let Some(stream) = streams.get_mut(&stream_id) else { - return; - }; - stream.meta.last_activity = now; - - match stream - .control - .buffer_incoming(message.tx_seq, message.frame) - { - BufferIncomingResult::OutOfWindow => { - if stream.is_provisional() { - state.enqueue_stream_close( - &engine.config, - true, - stream_id, - CloseTarget::Both, - CloseCode::PROTOCOL, - Vec::new(), - ); - StreamHandleResult::Remove - } else { - queue_protocol_close(stream, emit); - stream.meta.last_activity = now; - StreamHandleResult::Keep - } - } - BufferIncomingResult::Duplicate | BufferIncomingResult::AlreadyBuffered => { - stream.control.note_ack(true); - schedule_stream_ack(state, &engine.config, stream, now); - StreamHandleResult::Keep - } - BufferIncomingResult::Buffered { out_of_order } => { - stream.control.note_ack(out_of_order); - drain_committed_stream_frames(state, &engine.config, stream, now, emit) - } - } + let mut sink = EngineStreamSink { + state: &mut engine.state, }; - match disposition { - StreamHandleResult::Keep => {} - StreamHandleResult::Remove => { - engine.streams.remove(&stream_id); - } - StreamHandleResult::Reap => { - engine.streams.remove(&stream_id); - emit(EngineOutput::StreamReaped { stream_id }); - } - } + engine.streams.receive(now, body, &mut sink); } pub fn take_next_stream_write( engine: &mut Engine, crypto: &impl QlCrypto, ) -> Option { - let (recipient, session_key) = engine.stream_write_session()?; - let stream_ids: Vec<_> = engine.streams.scan_from_cursor().collect(); - for stream_id in stream_ids { - let write = take_next_write_for_stream(engine, stream_id, recipient, &session_key, crypto); - if write.is_some() { - engine.streams.advance_cursor_after(stream_id); - return write; - } - } - None + let (recipient, session_key) = engine.peer_session()?; + engine.sync_stream_namespace(); + + let outbound = engine.streams.next_outbound( + engine.state.now, + wire::now_secs().saturating_add(engine.config.packet_expiration.as_secs()), + )?; + let record = encrypt_stream( + QlHeader { + sender: engine.identity.xid, + recipient, + }, + &session_key, + &outbound.body, + encrypted_message_nonce(crypto), + ); + + Some(engine.issue_write( + OutboundWriteKind::Stream(outbound.completion), + None, + wire::encode_record(&record), + )) } -pub fn process_stream_ack( +pub fn complete_stream_write( engine: &mut Engine, - now: Instant, - stream_id: StreamId, - ack: StreamAck, - emit: &mut impl OutputFn, + completion: crate::stream::OutboundCompletion, + result: Result<(), QlError>, ) { - if ack == StreamAck::EMPTY { - return; - } - - let should_reap = { - let Some(stream) = engine.streams.get_mut(&stream_id) else { - return; - }; - stream.control.clear_fast_recovery(ack.base); - let fast_retransmit = stream - .control - .fast_retransmit_candidate(ack, engine.config.stream_fast_retransmit_threshold); - - loop { - let acked_tx_seq = stream - .control - .in_flight - .iter() - .find_map(|(tx_seq, in_flight)| match in_flight.write_state { - // ignore acks for writes that have not been sent out yet - InFlightWriteState::Ready => None, - InFlightWriteState::Issued | InFlightWriteState::WaitingRetry { .. } => { - StreamControl::ack_covers(ack, tx_seq).then_some(tx_seq) - } - }); - let Some(tx_seq) = acked_tx_seq else { - break; - }; - let Some(in_flight) = stream.control.remove_in_flight(tx_seq) else { - continue; - }; - - match in_flight.frame { - StreamFrame::Open(StreamFrameOpen { request_prefix, .. }) => { - if let StreamRole::Initiator(stream) = &mut stream.role { - if request_prefix.as_ref().is_some_and(|chunk| chunk.fin) - && stream.request.close() - { - emit(EngineOutput::OutboundClosed { stream_id }); - } - } - } - StreamFrame::Data(StreamFrameData { - chunk: BodyChunk { fin: true, .. }, - .. - }) => { - if let Some(side) = stream.outbound_side() { - if let Some(outbound) = stream.outbound_mut(side) { - if outbound.close() { - emit(EngineOutput::OutboundClosed { stream_id }); - } - } - } - } - StreamFrame::Close(StreamFrameClose { - target, - code, - payload, - .. - }) => { - for side in [StreamSide::Request, StreamSide::Response] { - let affects_outbound = matches!( - (target, side), - (CloseTarget::Request, StreamSide::Request) - | (CloseTarget::Response, StreamSide::Response) - | (CloseTarget::Both, _) - ); - if affects_outbound { - if let Some(outbound) = stream.outbound_mut(side) { - if outbound.close() { - emit(EngineOutput::OutboundFailed { - stream_id, - error: QlError::StreamClosed { - target, - code, - payload: payload.clone(), - }, - }); - } - } - } - } - } - StreamFrame::Data(_) => {} - } - } - - if let Some(tx_seq) = fast_retransmit { - stream.control.schedule_fast_retransmit(tx_seq, now); - } - drive_stream(stream); - stream.can_reap() + let now = engine.state.now; + let mut sink = EngineStreamSink { + state: &mut engine.state, }; - - if should_reap { - engine.streams.remove(&stream_id); - emit(EngineOutput::StreamReaped { stream_id }); - } -} - -fn schedule_stream_ack( - state: &mut EngineState, - config: &EngineConfig, - stream: &mut StreamState, - now: Instant, -) { - let stream_id = stream.meta.stream_id; - let control = &mut stream.control; - if !control.ack_dirty { - return; - } - if control.ack_immediate || config.stream_ack_delay.is_zero() { - control.ack_delay_token = None; - return; - } - if control.ack_delay_token.is_some() { - return; - } - let token = state.next_token(); - control.ack_delay_token = Some(token); - state.timeouts.push(Reverse(TimeoutEntry { - at: now + config.stream_ack_delay, - kind: TimeoutKind::StreamAckDelay { stream_id, token }, - })); -} - -fn drain_committed_stream_frames( - state: &mut EngineState, - config: &EngineConfig, - stream: &mut StreamState, - now: Instant, - emit: &mut impl OutputFn, -) -> StreamHandleResult { - let stream_id = stream.meta.stream_id; - loop { - let next = stream.control.pop_next_committable(); - let Some((_tx_seq, frame)) = next else { - break; - }; - if stream.is_provisional() && !matches!(frame, StreamFrame::Open(_)) { - state.enqueue_stream_close( - config, - true, - stream_id, - CloseTarget::Both, - CloseCode::PROTOCOL, - Vec::new(), - ); - return StreamHandleResult::Remove; - } - match frame { - StreamFrame::Open(frame) => handle_stream_open(stream, now, frame, emit), - StreamFrame::Close(frame) => handle_stream_close_from_peer(stream, frame, emit), - StreamFrame::Data(frame) => handle_stream_data(stream, now, frame, emit), - } - } - stream.control.maybe_force_ack_for_progress(); - schedule_stream_ack(state, config, stream, now); - if stream.can_reap() { - StreamHandleResult::Reap - } else { - StreamHandleResult::Keep - } + engine.streams.complete_outbound( + now, + completion, + result.map_err(|_| WriteError::SendFailed), + &mut sink, + ); } -fn handle_stream_open( - stream: &mut StreamState, - now: Instant, - frame: StreamFrameOpen, - emit: &mut impl OutputFn, -) { - let StreamFrameOpen { - stream_id, - request_head, - request_prefix, - } = frame; - if !stream.is_provisional() { - queue_protocol_close(stream, emit); +pub fn handle_stream_timeouts(engine: &mut Engine) { + let now = engine.state.now; + if !engine + .streams + .next_deadline() + .is_some_and(|deadline| deadline <= now) + { return; } - stream.meta.last_activity = now; - stream.role = StreamRole::Responder(ResponderStream { - request: InboundState::new(), - response: OutboundPhase::from_prefix(false), - response_started: false, - }); - if let Some(chunk) = request_prefix.as_ref() { - let Some(inbound) = stream.inbound_mut(StreamSide::Request) else { - return; - }; - if chunk.fin { - inbound.close(); - } - } - emit(EngineOutput::InboundStreamOpened { - stream_id, - request_head, - request_prefix, - }); -} - -fn handle_stream_close_from_peer( - stream: &mut StreamState, - frame: StreamFrameClose, - emit: &mut impl OutputFn, -) { - let StreamFrameClose { - target, - code, - payload, - .. - } = frame; - apply_remote_close(stream, target, code, payload, emit); -} -fn handle_stream_data( - stream: &mut StreamState, - now: Instant, - frame: StreamFrameData, - emit: &mut impl OutputFn, -) { - let StreamFrameData { stream_id, chunk } = frame; - let Some(side) = stream.inbound_side() else { - queue_protocol_close(stream, emit); - return; + let mut sink = EngineStreamSink { + state: &mut engine.state, }; - let Some(inbound) = stream.inbound_mut(side) else { - queue_protocol_close(stream, emit); - return; - }; - if inbound.closed { - queue_protocol_close(stream, emit); - } else { - if !chunk.bytes.is_empty() { - emit(EngineOutput::InboundData { - stream_id, - bytes: chunk.bytes, - }); - } - if chunk.fin && inbound.close() { - emit(EngineOutput::InboundFinished { stream_id }); - } - } - stream.meta.last_activity = now; + engine.streams.on_timer(now, &mut sink); } -fn drive_stream(stream: &mut StreamState) { - let (meta, control, role) = stream.parts_mut(); - match role { - StreamRole::Initiator(stream) => { - drive_stream_outbound(meta.stream_id, control, Some(&mut stream.request)); - } - StreamRole::Responder(stream) => { - drive_stream_outbound(meta.stream_id, control, Some(&mut stream.response)); - } - StreamRole::Provisional(_) => drive_stream_outbound(meta.stream_id, control, None), - } -} - -fn drive_stream_outbound( - stream_id: StreamId, - control: &mut StreamControl, - mut outbound: Option<&mut OutboundPhase>, -) { - loop { - if control.send_window_has_space() { - if let Some(frame) = control.pending.pop_front() { - enqueue_stream_frame(control, frame, 0); - continue; - } - } - if !control.send_window_has_space() { - return; - } - - let Some(outbound) = outbound.as_deref_mut() else { - return; - }; - if outbound.queue_fin() { - enqueue_stream_frame( - control, - StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: Vec::new(), - fin: true, - }, - }), - 0, - ); - continue; - } - return; - } -} - -fn enqueue_stream_frame(control: &mut StreamControl, frame: StreamFrame, attempt: u8) { - let tx_seq = control.take_tx_seq(); - enqueue_stream_frame_with_seq(control, tx_seq, frame, attempt); -} - -fn enqueue_stream_frame_with_seq( - control: &mut StreamControl, - tx_seq: StreamSeq, - frame: StreamFrame, - attempt: u8, -) { - control.insert_in_flight(InFlightFrame { - tx_seq, - frame, - attempt, - write_state: InFlightWriteState::Ready, - }); -} - -fn queue_protocol_close(stream: &mut StreamState, emit: &mut impl OutputFn) { - let stream_id = stream.meta.stream_id; - let control = &mut stream.control; - control.clear_transient_buffers(); - control.queue_frame_front(close_frame( - stream_id, - CloseTarget::Both, - CloseCode::PROTOCOL, - Vec::new(), - )); - for side in [StreamSide::Request, StreamSide::Response] { - if let Some(outbound) = stream.outbound_mut(side) { - if outbound.close() { - emit(EngineOutput::OutboundFailed { - stream_id, - error: QlError::StreamProtocol, - }); - } - } - if let Some(inbound) = stream.inbound_mut(side) { - if inbound.close() { - emit(EngineOutput::InboundFailed { - stream_id, - error: QlError::StreamProtocol, - }); - } - } - } - drive_stream(stream); -} - -fn apply_remote_close( - stream: &mut StreamState, - target: CloseTarget, - code: CloseCode, - payload: Vec, - emit: &mut impl OutputFn, -) { - let stream_id = stream.meta.stream_id; - let error = QlError::StreamClosed { - target, - code, - payload: payload.clone(), +pub fn abort_streams(engine: &mut Engine, error: QlError) { + let mut sink = EngineStreamSink { + state: &mut engine.state, }; - if matches!(target, CloseTarget::Request | CloseTarget::Both) { - if let Some(inbound) = stream.inbound_mut(StreamSide::Request) { - if inbound.close() { - emit(EngineOutput::InboundFailed { - stream_id, - error: error.clone(), - }); - } - } - if let Some(outbound) = stream.outbound_mut(StreamSide::Request) { - if outbound.close() { - emit(EngineOutput::OutboundFailed { - stream_id, - error: error.clone(), - }); - } - } - } - if matches!(target, CloseTarget::Response | CloseTarget::Both) { - if let Some(inbound) = stream.inbound_mut(StreamSide::Response) { - if inbound.close() { - emit(EngineOutput::InboundFailed { - stream_id, - error: error.clone(), - }); - } - } - if let Some(outbound) = stream.outbound_mut(StreamSide::Response) { - if outbound.close() { - emit(EngineOutput::OutboundFailed { - stream_id, - error: error.clone(), - }); - } - } - } + engine.streams.abort(stream_error_inverse(error), &mut sink); } -fn maybe_reap_stream(engine: &mut Engine, stream_id: StreamId, emit: &mut impl OutputFn) { - if engine - .streams - .get(&stream_id) - .is_some_and(StreamState::can_reap) - { - engine.streams.remove(&stream_id); - emit(EngineOutput::StreamReaped { stream_id }); +fn stream_error(error: StreamError) -> QlError { + match error { + StreamError::MissingStream | StreamError::NotWritable => QlError::StreamProtocol, + StreamError::SendFailed => QlError::SendFailed, + StreamError::Timeout => QlError::Timeout, + StreamError::Cancelled => QlError::Cancelled, + StreamError::StreamProtocol => QlError::StreamProtocol, } } -fn take_next_write_for_stream( - engine: &mut Engine, - stream_id: StreamId, - recipient: XID, - session_key: &SymmetricKey, - crypto: &impl QlCrypto, -) -> Option { - #[derive(Clone, Copy)] - enum StreamWriteSelection { - Ack, - InitialFrame { tx_seq: StreamSeq }, - RetryFrame { tx_seq: StreamSeq }, - } - - let now = engine.state.now; - let selection = { - let stream = engine.streams.get(&stream_id)?; - let is_provisional = stream.is_provisional(); - let control = &stream.control; - if !is_provisional { - if let Some(tx_seq) = control.in_flight.iter().find_map(|(tx_seq, in_flight)| { - matches!( - in_flight.write_state, - InFlightWriteState::WaitingRetry { retry_at } - if retry_at <= now && in_flight.attempt < engine.config.stream_retry_limit - ) - .then_some(tx_seq) - }) { - Some(StreamWriteSelection::RetryFrame { tx_seq }) - } else if let Some(tx_seq) = control.in_flight.iter().find_map(|(tx_seq, in_flight)| { - matches!(in_flight.write_state, InFlightWriteState::Ready).then_some(tx_seq) - }) { - Some(StreamWriteSelection::InitialFrame { tx_seq }) - } else if control.ack_dirty - && control.ack_immediate - && control.ack_outbound_token.is_none() - { - Some(StreamWriteSelection::Ack) - } else { - None - } - } else if control.ack_dirty && control.ack_immediate && control.ack_outbound_token.is_none() - { - Some(StreamWriteSelection::Ack) - } else { - None - } - }?; - - match selection { - StreamWriteSelection::Ack => { - let token = engine.state.next_token(); - let ack = { - let stream = engine.streams.get_mut(&stream_id)?; - let control = &mut stream.control; - if !(control.ack_dirty - && control.ack_immediate - && control.ack_outbound_token.is_none()) - { - return None; - } - let ack = control.current_ack(); - control.clear_ack_schedule(); - control.note_ack_sent(ack); - control.ack_outbound_token = Some(token); - ack - }; - - let body = StreamBody::Ack(StreamAckBody { - stream_id, - ack, - valid_until: wire::now_secs() - .saturating_add(engine.config.packet_expiration.as_secs()), - }); - let record = encrypt_stream( - QlHeader { - sender: engine.identity.xid, - recipient, - }, - session_key, - &body, - encrypted_message_nonce(crypto), - ); - Some(engine.issue_write( - OutboundWriteKind::StreamAck { stream_id }, - Some(token), - wire::encode_record(&record), - )) - } - StreamWriteSelection::InitialFrame { tx_seq } - | StreamWriteSelection::RetryFrame { tx_seq } => { - let (ack, frame) = { - let stream = engine.streams.get_mut(&stream_id)?; - let inbound_alive = match &stream.role { - StreamRole::Initiator(state) => !state.response.closed, - StreamRole::Responder(state) => !state.request.closed, - StreamRole::Provisional(_) => return None, - }; - let control = &mut stream.control; - let ack = control.take_piggyback_ack(inbound_alive); - let frame = control.mark_write_issued(tx_seq)?; - (ack, frame) - }; - - let body = StreamBody::Message(StreamMessage { - tx_seq, - ack, - valid_until: wire::now_secs() - .saturating_add(engine.config.packet_expiration.as_secs()), - frame, - }); - let record = encrypt_stream( - QlHeader { - sender: engine.identity.xid, - recipient, - }, - session_key, - &body, - encrypted_message_nonce(crypto), - ); - Some(engine.issue_write( - OutboundWriteKind::StreamFrame { stream_id, tx_seq }, - None, - wire::encode_record(&record), - )) - } +fn stream_error_inverse(error: QlError) -> StreamError { + match error { + QlError::SendFailed => StreamError::SendFailed, + QlError::Timeout => StreamError::Timeout, + QlError::Cancelled => StreamError::Cancelled, + QlError::StreamProtocol | QlError::StreamClosed { .. } => StreamError::StreamProtocol, + QlError::NoPeerBound + | QlError::MissingSession + | QlError::InvalidPayload + | QlError::InvalidSignature => StreamError::Cancelled, } } diff --git a/ql-engine/src/engine/mod.rs b/ql-engine/src/engine/mod.rs index 1922ec5b..ed4f964a 100644 --- a/ql-engine/src/engine/mod.rs +++ b/ql-engine/src/engine/mod.rs @@ -1,9 +1,6 @@ mod implementation; pub mod replay_cache; -mod ring; mod state; -pub(crate) mod stream; - #[cfg(test)] mod tests; @@ -17,6 +14,7 @@ pub use state::{ use crate::{ identity::QlIdentity, + stream, wire::stream::{BodyChunk, CloseCode, CloseTarget}, Peer, QlError, StreamId, }; @@ -25,70 +23,8 @@ pub trait QlCrypto { fn fill_random_bytes(&self, data: &mut [u8]); } -#[derive(Debug, Clone, Copy)] -pub struct KeepAliveConfig { - pub interval: Duration, - pub timeout: Duration, -} - -#[derive(Debug, Clone, Copy, Default)] -pub struct StreamConfig {} - -#[derive(Debug, Clone, Copy)] -pub struct EngineConfig { - pub handshake_timeout: Duration, - pub handshake_retry_interval: Duration, - pub max_handshake_retries: u8, - pub packet_expiration: Duration, - pub stream_ack_delay: Duration, - pub stream_ack_timeout: Duration, - pub stream_fast_retransmit_threshold: u8, - pub stream_retry_limit: u8, - pub keep_alive: Option, -} - -impl Default for EngineConfig { - fn default() -> Self { - Self { - handshake_timeout: Duration::from_secs(5), - handshake_retry_interval: Duration::from_millis(750), - max_handshake_retries: 3, - packet_expiration: Duration::from_secs(30), - stream_ack_delay: Duration::from_millis(5), - stream_ack_timeout: Duration::from_millis(150), - stream_fast_retransmit_threshold: 2, - stream_retry_limit: 5, - keep_alive: None, - } - } -} - -#[derive(Debug)] -pub enum EngineInput { - BindPeer(Peer), - Pair, - Connect, - Unpair, - CloseStream { - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - }, - - OutboundData { - stream_id: StreamId, - bytes: Vec, - }, - OutboundFinished { - stream_id: StreamId, - }, - Incoming(Vec), - TimerExpired, -} - -#[derive(Debug)] -pub enum EngineOutput { +#[derive(Debug, Clone)] +pub enum EngineEvent { PeerStatusChanged { peer: XID, session: PeerSession, @@ -126,46 +62,170 @@ pub enum EngineOutput { }, } -pub trait OutputFn: FnMut(EngineOutput) {} +#[derive(Debug, Clone, Copy)] +pub struct KeepAliveConfig { + pub interval: Duration, + pub timeout: Duration, +} -impl OutputFn for T where T: FnMut(EngineOutput) {} +#[derive(Debug, Clone, Copy, Default)] +pub struct StreamConfig {} + +#[derive(Debug, Clone, Copy)] +pub struct EngineConfig { + pub handshake_timeout: Duration, + pub handshake_retry_interval: Duration, + pub max_handshake_retries: u8, + pub packet_expiration: Duration, + pub stream_ack_delay: Duration, + pub stream_ack_timeout: Duration, + pub stream_fast_retransmit_threshold: u8, + pub stream_retry_limit: u8, + pub keep_alive: Option, +} + +impl Default for EngineConfig { + fn default() -> Self { + Self { + handshake_timeout: Duration::from_secs(5), + handshake_retry_interval: Duration::from_millis(750), + max_handshake_retries: 3, + packet_expiration: Duration::from_secs(30), + stream_ack_delay: Duration::from_millis(5), + stream_ack_timeout: Duration::from_millis(150), + stream_fast_retransmit_threshold: 2, + stream_retry_limit: 5, + keep_alive: None, + } + } +} impl Engine { pub fn new(config: EngineConfig, identity: QlIdentity, peer: Option) -> Self { + let local_namespace = peer + .as_ref() + .map(|peer| stream::StreamNamespace::for_local(identity.xid, peer.peer)) + .map(|namespace| match namespace { + stream::StreamNamespace::Low => crate::stream::StreamNamespace::Low, + stream::StreamNamespace::High => crate::stream::StreamNamespace::High, + }) + .unwrap_or(crate::stream::StreamNamespace::Low); Self { config: config, identity, peer: peer .map(|peer| PeerRecord::new(peer.peer, peer.signing_key, peer.encapsulation_key)), state: EngineState::new(), - streams: stream::StreamStore::default(), + streams: stream::StreamFsm::new(stream::StreamFsmConfig { + local_namespace, + ack_delay: config.stream_ack_delay, + ack_timeout: config.stream_ack_timeout, + fast_retransmit_threshold: config.stream_fast_retransmit_threshold, + retry_limit: config.stream_retry_limit, + }), } } - pub fn run_tick( + pub fn open_stream( &mut self, now: Instant, - input: EngineInput, - crypto: &impl QlCrypto, - emit: &mut impl OutputFn, - ) { - self.run_tick_inner(now, input, crypto, emit); + request_head: Vec, + request_prefix: Option, + config: StreamConfig, + ) -> Result { + self.state.now = now; + self.open_stream_inner(request_head, request_prefix, config) + } + + pub fn bind_peer(&mut self, now: Instant, peer: Peer) { + self.state.now = now; + self.bind_peer_inner(peer); + } + + pub fn pair(&mut self, now: Instant, crypto: &impl QlCrypto) { + self.state.now = now; + self.pair_inner(crypto); + } + + pub fn connect(&mut self, now: Instant, crypto: &impl QlCrypto) { + self.state.now = now; + self.connect_inner(crypto); } - pub fn take_next_write(&mut self, crypto: &impl QlCrypto) -> Option { + pub fn unpair(&mut self, now: Instant) { + self.state.now = now; + self.unpair_inner(); + } + + pub fn take_next_write( + &mut self, + now: Instant, + crypto: &impl QlCrypto, + ) -> Option { + self.state.now = now; self.take_next_write_inner(crypto) } - pub fn complete_write( + pub fn complete_write(&mut self, now: Instant, write_id: WriteId, result: Result<(), QlError>) { + self.state.now = now; + self.complete_write_inner(write_id, result); + } + + pub fn write_stream( + &mut self, + now: Instant, + stream_id: StreamId, + bytes: Vec, + ) -> Result<(), QlError> { + self.state.now = now; + self.write_stream_inner(stream_id, bytes) + } + + pub fn finish_stream(&mut self, now: Instant, stream_id: StreamId) -> Result<(), QlError> { + self.state.now = now; + self.finish_stream_inner(stream_id) + } + + pub fn close_stream( &mut self, - write_id: WriteId, - result: Result<(), QlError>, - emit: &mut impl OutputFn, - ) { - self.complete_write_inner(write_id, result, emit); + now: Instant, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + ) -> Result<(), QlError> { + self.state.now = now; + self.close_stream_inner(stream_id, target, code, payload) + } + + pub fn receive(&mut self, now: Instant, bytes: Vec, crypto: &impl QlCrypto) { + self.state.now = now; + self.receive_inner(bytes, crypto); + } + + pub fn on_timer(&mut self, now: Instant, crypto: &impl QlCrypto) { + self.state.now = now; + self.on_timer_inner(crypto); } pub fn next_deadline(&self) -> Option { self.next_deadline_inner() } + + pub fn take_next_event(&mut self) -> Option { + self.state.pending_events.pop_front() + } + + pub fn has_pending_events(&self) -> bool { + !self.state.pending_events.is_empty() + } + + pub fn drain_events(&mut self) -> std::collections::vec_deque::Drain<'_, EngineEvent> { + self.state.pending_events.drain(..) + } + + pub fn abort(&mut self, now: Instant, error: QlError) { + self.state.now = now; + self.abort_inner(error); + } } diff --git a/ql-engine/src/engine/ring.rs b/ql-engine/src/engine/ring.rs deleted file mode 100644 index 4ad7f567..00000000 --- a/ql-engine/src/engine/ring.rs +++ /dev/null @@ -1,392 +0,0 @@ -use std::array; - -use crate::wire::StreamSeq; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SeqRingInsertError { - OutOfWindow, - Occupied, -} - -#[derive(Debug)] -pub struct SeqRing { - base_seq: StreamSeq, - head: usize, - len: usize, - slots: [Option; N], -} - -impl SeqRing { - pub fn new(base_seq: StreamSeq) -> Self { - Self { - base_seq, - head: 0, - len: 0, - slots: array::from_fn(|_| None), - } - } - - pub fn base_seq(&self) -> StreamSeq { - self.base_seq - } - - pub fn len(&self) -> usize { - self.len - } - - pub fn is_empty(&self) -> bool { - self.len == 0 - } - - pub fn clear_with_base(&mut self, base_seq: StreamSeq) { - for slot in &mut self.slots { - let _ = slot.take(); - } - self.base_seq = base_seq; - self.head = 0; - self.len = 0; - } - - pub fn contains_key(&self, seq: &StreamSeq) -> bool { - self.get(seq).is_some() - } - - pub fn accepts_seq(&self, seq: StreamSeq) -> bool { - self.offset_for(seq).is_some() - } - - pub fn get(&self, seq: &StreamSeq) -> Option<&T> { - let index = self.index_for(*seq)?; - self.slots[index].as_ref() - } - - pub fn get_mut(&mut self, seq: &StreamSeq) -> Option<&mut T> { - let index = self.index_for(*seq)?; - self.slots[index].as_mut() - } - - pub fn insert(&mut self, seq: StreamSeq, value: T) -> Result<(), SeqRingInsertError> { - let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; - if self.slots[index].is_some() { - return Err(SeqRingInsertError::Occupied); - } - self.slots[index] = Some(value); - self.len += 1; - Ok(()) - } - - pub fn set(&mut self, seq: StreamSeq, value: T) -> Result, SeqRingInsertError> { - let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; - let previous = self.slots[index].replace(value); - if previous.is_none() { - self.len += 1; - } - Ok(previous) - } - - pub fn remove(&mut self, seq: &StreamSeq) -> Option { - let index = self.index_for(*seq)?; - let value = self.slots[index].take(); - if value.is_some() { - self.len -= 1; - } - value - } - - pub fn take_front(&mut self) -> Option<(StreamSeq, T)> { - let value = self.slots[self.head].take()?; - let seq = self.base_seq; - self.len -= 1; - self.head = self.next_index(self.head); - self.base_seq = self.base_seq.next(); - Some((seq, value)) - } - - pub fn advance_empty_front_until(&mut self, limit_seq: StreamSeq) { - while self.base_seq.serial_lt(limit_seq) && self.slots[self.head].is_none() { - self.head = self.next_index(self.head); - self.base_seq = self.base_seq.next(); - } - } - - pub fn drain_front(&mut self) -> SeqRingDrain<'_, N, T> { - SeqRingDrain { ring: self } - } - - pub fn iter(&self) -> SeqRingIter<'_, N, T> { - SeqRingIter { - ring: self, - offset: 0, - } - } - - pub fn bitmap(&self) -> u8 { - debug_assert!(N <= 8); - let mut bitmap = 0u8; - for offset in 0..N { - let index = self.index_for_offset(offset); - if self.slots[index].is_some() { - bitmap |= 1u8 << offset; - } - } - bitmap - } - - fn index_for(&self, seq: StreamSeq) -> Option { - let offset = self.offset_for(seq)?; - Some(self.index_for_offset(offset)) - } - - fn offset_for(&self, seq: StreamSeq) -> Option { - let offset = self.base_seq.forward_distance_to(seq)? as usize; - (offset < N).then_some(offset) - } - - fn index_for_offset(&self, offset: usize) -> usize { - (self.head + offset) % N - } - - fn next_index(&self, index: usize) -> usize { - (index + 1) % N - } -} - -pub struct SeqRingIter<'a, const N: usize, T> { - ring: &'a SeqRing, - offset: usize, -} - -impl<'a, const N: usize, T> Iterator for SeqRingIter<'a, N, T> { - type Item = (StreamSeq, &'a T); - - fn next(&mut self) -> Option { - while self.offset < N { - let offset = self.offset; - self.offset += 1; - let index = self.ring.index_for_offset(offset); - if let Some(value) = self.ring.slots[index].as_ref() { - let seq = self.ring.base_seq.add(offset as u32); - return Some((seq, value)); - } - } - None - } -} - -pub struct SeqRingDrain<'a, const N: usize, T> { - ring: &'a mut SeqRing, -} - -impl<'a, const N: usize, T> Iterator for SeqRingDrain<'a, N, T> { - type Item = (StreamSeq, T); - - fn next(&mut self) -> Option { - self.ring.take_front() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - engine::stream::{BufferIncomingResult, InFlightFrame, InFlightWriteState, StreamControl}, - wire::stream::{BodyChunk, StreamAck, StreamFrame, StreamFrameData, StreamFrameOpen}, - StreamId, - }; - - fn data_frame(stream_id: StreamId, tx_seq: u32, byte: u8) -> (StreamSeq, StreamFrame) { - ( - StreamSeq(tx_seq), - StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: vec![byte], - fin: false, - }, - }), - ) - } - - #[test] - fn seq_ring_drain_front_takes_contiguous_items_in_order() { - let mut ring = SeqRing::<8, u64>::new(StreamSeq(1)); - ring.insert(StreamSeq(2), 20).unwrap(); - ring.insert(StreamSeq(1), 10).unwrap(); - ring.insert(StreamSeq(3), 30).unwrap(); - - let drained: Vec<_> = ring.drain_front().collect(); - assert_eq!( - drained, - vec![(StreamSeq(1), 10), (StreamSeq(2), 20), (StreamSeq(3), 30)] - ); - assert_eq!(ring.base_seq(), StreamSeq(4)); - assert!(ring.is_empty()); - } - - #[test] - fn seq_ring_wraps_and_reuses_slots() { - let mut ring = SeqRing::<4, u64>::new(StreamSeq(1)); - ring.insert(StreamSeq(1), 1).unwrap(); - ring.insert(StreamSeq(2), 2).unwrap(); - ring.insert(StreamSeq(3), 3).unwrap(); - - assert_eq!(ring.take_front(), Some((StreamSeq(1), 1))); - assert_eq!(ring.take_front(), Some((StreamSeq(2), 2))); - - ring.insert(StreamSeq(4), 4).unwrap(); - ring.insert(StreamSeq(5), 5).unwrap(); - - let remaining: Vec<_> = ring.iter().map(|(seq, value)| (seq, *value)).collect(); - assert_eq!( - remaining, - vec![(StreamSeq(3), 3), (StreamSeq(4), 4), (StreamSeq(5), 5)] - ); - } - - #[test] - fn seq_ring_selective_take_can_slide_past_empty_front() { - let mut ring = SeqRing::<8, u64>::new(StreamSeq(1)); - for value in 1..=4 { - ring.insert(StreamSeq(value), value as u64).unwrap(); - } - - assert_eq!(ring.remove(&StreamSeq(2)), Some(2)); - assert_eq!(ring.remove(&StreamSeq(3)), Some(3)); - ring.advance_empty_front_until(StreamSeq(5)); - assert_eq!(ring.base_seq(), StreamSeq(1)); - - assert_eq!(ring.remove(&StreamSeq(1)), Some(1)); - ring.advance_empty_front_until(StreamSeq(5)); - assert_eq!(ring.base_seq(), StreamSeq(4)); - - assert_eq!(ring.remove(&StreamSeq(4)), Some(4)); - ring.advance_empty_front_until(StreamSeq(5)); - assert_eq!(ring.base_seq(), StreamSeq(5)); - assert!(ring.is_empty()); - } - - #[test] - fn stream_control_recv_buffer_preserves_ack_bitmap_and_drain_order() { - let stream_id = StreamId(7); - let mut control = StreamControl::default(); - - let (seq2, frame2) = data_frame(stream_id, 2, b'b'); - let (seq1, frame1) = data_frame(stream_id, 1, b'a'); - let (seq3, frame3) = data_frame(stream_id, 3, b'c'); - - assert!(matches!( - control.buffer_incoming(seq2, frame2), - BufferIncomingResult::Buffered { out_of_order: true } - )); - assert_eq!(control.current_ack().base, StreamSeq(0)); - assert_eq!(control.current_ack().bitmap, 0b0000_0010); - - assert!(matches!( - control.buffer_incoming(seq1, frame1), - BufferIncomingResult::Buffered { - out_of_order: false - } - )); - assert!(matches!( - control.buffer_incoming(seq3, frame3), - BufferIncomingResult::Buffered { out_of_order: true } - )); - - let committed: Vec<_> = std::iter::from_fn(|| control.pop_next_committable()).collect(); - assert_eq!( - committed.iter().map(|(seq, _)| *seq).collect::>(), - vec![StreamSeq(1), StreamSeq(2), StreamSeq(3)] - ); - assert_eq!(control.committed_rx_seq(), StreamSeq(3)); - assert_eq!(control.current_ack().base, StreamSeq(3)); - assert_eq!(control.current_ack().bitmap, 0); - } - - #[test] - fn stream_control_send_window_respects_sequence_range_not_count() { - let stream_id = StreamId(11); - let mut control = StreamControl::default(); - for tx_seq in 1..=8 { - let frame = InFlightFrame { - tx_seq: StreamSeq(tx_seq), - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: vec![tx_seq as u8], - request_prefix: None, - }), - attempt: 0, - write_state: InFlightWriteState::Ready, - }; - control.insert_in_flight(frame); - control.next_tx_seq = StreamSeq(tx_seq + 1); - } - - assert!(!control.send_window_has_space()); - let _ = control.remove_in_flight(StreamSeq(8)); - assert!(!control.send_window_has_space()); - let _ = control.remove_in_flight(StreamSeq(1)); - assert!(control.send_window_has_space()); - assert_eq!(control.in_flight.base_seq(), StreamSeq(2)); - } - - #[test] - fn ack_coverage_handles_wraparound_bitmap() { - let ack = StreamAck { - base: StreamSeq(u32::MAX), - bitmap: 0b0000_0011, - }; - - assert!(StreamControl::ack_covers(ack, StreamSeq(u32::MAX - 1))); - assert!(StreamControl::ack_covers(ack, StreamSeq(u32::MAX))); - assert!(StreamControl::ack_covers(ack, StreamSeq(0))); - assert!(StreamControl::ack_covers(ack, StreamSeq(1))); - assert!(!StreamControl::ack_covers(ack, StreamSeq(2))); - } - - #[test] - fn seq_ring_accepts_window_across_sequence_overflow() { - let mut ring = SeqRing::<4, u64>::new(StreamSeq(u32::MAX - 1)); - ring.insert(StreamSeq(u32::MAX - 1), 1).unwrap(); - ring.insert(StreamSeq(u32::MAX), 2).unwrap(); - ring.insert(StreamSeq(0), 3).unwrap(); - - assert_eq!(ring.take_front(), Some((StreamSeq(u32::MAX - 1), 1))); - assert_eq!(ring.take_front(), Some((StreamSeq(u32::MAX), 2))); - - ring.insert(StreamSeq(1), 4).unwrap(); - ring.insert(StreamSeq(2), 5).unwrap(); - - let remaining: Vec<_> = ring.iter().map(|(seq, value)| (seq, *value)).collect(); - assert_eq!( - remaining, - vec![(StreamSeq(0), 3), (StreamSeq(1), 4), (StreamSeq(2), 5)] - ); - } - - #[test] - fn seq_ring_selective_take_slides_across_sequence_overflow() { - let mut ring = SeqRing::<8, u64>::new(StreamSeq(u32::MAX - 1)); - for (seq, value) in [ - (StreamSeq(u32::MAX - 1), 1u64), - (StreamSeq(u32::MAX), 2u64), - (StreamSeq(0), 3u64), - (StreamSeq(1), 4u64), - ] { - ring.insert(seq, value).unwrap(); - } - - assert_eq!(ring.remove(&StreamSeq(u32::MAX)), Some(2)); - assert_eq!(ring.remove(&StreamSeq(0)), Some(3)); - ring.advance_empty_front_until(StreamSeq(2)); - assert_eq!(ring.base_seq(), StreamSeq(u32::MAX - 1)); - - assert_eq!(ring.remove(&StreamSeq(u32::MAX - 1)), Some(1)); - ring.advance_empty_front_until(StreamSeq(2)); - assert_eq!(ring.base_seq(), StreamSeq(1)); - - assert_eq!(ring.remove(&StreamSeq(1)), Some(4)); - ring.advance_empty_front_until(StreamSeq(2)); - assert_eq!(ring.base_seq(), StreamSeq(2)); - assert!(ring.is_empty()); - } -} diff --git a/ql-engine/src/engine/state.rs b/ql-engine/src/engine/state.rs index 0bbf9c86..481d4f06 100644 --- a/ql-engine/src/engine/state.rs +++ b/ql-engine/src/engine/state.rs @@ -1,42 +1,31 @@ use std::{ cell::Cell, cmp::Reverse, - collections::{BinaryHeap, HashMap, VecDeque}, + collections::{BinaryHeap, VecDeque}, time::Instant, }; use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey, XID}; -use super::{replay_cache::ReplayCache, stream::StreamStore, EngineConfig}; +use super::{replay_cache::ReplayCache, EngineConfig, EngineEvent}; use crate::{ + arena::{ArenaKey, GenerationalArena}, identity::QlIdentity, - wire::{ - handshake::{Confirm, Hello, HelloReply, Ready, ResponderSecrets}, - stream::{CloseCode, CloseTarget}, - StreamSeq, - }, - PacketId, Peer, StreamId, + stream::{self, StreamFsm}, + wire::handshake::{Confirm, Hello, HelloReply, Ready, ResponderSecrets}, + PacketId, Peer, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Token(pub u64); -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct WriteId(pub u64); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct WriteId(pub(crate) ArenaKey); #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum OutboundWriteKind { Control, - StreamAck { - stream_id: StreamId, - }, - StreamFrame { - stream_id: StreamId, - tx_seq: StreamSeq, - }, - StreamClose { - stream_id: StreamId, - }, + Stream(stream::OutboundCompletion), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -48,19 +37,7 @@ pub struct OutboundWrite { #[derive(Debug)] pub struct ControlWrite { pub token: Token, - pub kind: OutboundWriteKind, - pub payload: ControlWritePayload, -} - -#[derive(Debug)] -pub enum ControlWritePayload { - Encoded(Vec), - StreamClose { - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - }, + pub bytes: Vec, } #[derive(Debug, Clone, Copy)] @@ -119,41 +96,6 @@ pub struct RecentReady { pub expires_at: Instant, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamNamespace { - Low, - High, -} - -impl StreamNamespace { - const BIT: u32 = 1 << 31; - - pub fn bit(self) -> u32 { - match self { - Self::Low => 0, - Self::High => Self::BIT, - } - } - - pub fn for_local(local: XID, peer: XID) -> Self { - match local.data().cmp(peer.data()) { - std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Low, - std::cmp::Ordering::Greater => Self::High, - } - } - - pub fn matches(self, stream_id: StreamId) -> bool { - (stream_id.0 & Self::BIT) == self.bit() - } - - pub fn remote(self) -> Self { - match self { - Self::Low => Self::High, - Self::High => Self::Low, - } - } -} - #[derive(Debug, Clone)] pub enum PeerSession { Disconnected, @@ -221,9 +163,6 @@ impl PeerRecord { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TimeoutKind { Outbound { token: Token }, - HandshakeRetry { token: Token }, - StreamAckDelay { stream_id: StreamId, token: Token }, - StreamProvisional { stream_id: StreamId, token: Token }, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -249,18 +188,17 @@ pub struct Engine { pub identity: QlIdentity, pub peer: Option, pub state: EngineState, - pub streams: StreamStore, + pub streams: StreamFsm, } pub struct EngineState { pub replay_cache: ReplayCache, pub next_token: Cell, - pub next_write_id: Cell, pub next_packet_id: Cell, - pub next_stream_id: Cell, + pub pending_events: VecDeque, pub control_outbound: VecDeque, - pub active_writes: HashMap, + pub active_writes: GenerationalArena, pub timeouts: BinaryHeap>, pub now: Instant, } @@ -270,11 +208,10 @@ impl EngineState { Self { replay_cache: ReplayCache::new(), next_token: Cell::new(1), - next_write_id: Cell::new(1), next_packet_id: Cell::new(1), - next_stream_id: Cell::new(1), + pending_events: VecDeque::new(), control_outbound: VecDeque::new(), - active_writes: HashMap::new(), + active_writes: GenerationalArena::new(), timeouts: BinaryHeap::new(), now: Instant::now(), } @@ -290,24 +227,12 @@ impl EngineState { Token(token) } - pub fn next_write_id(&self) -> WriteId { - let id = self.next_write_id.get(); - self.next_write_id.set(id.wrapping_add(1)); - WriteId(id) - } - pub fn next_packet_id(&self) -> PacketId { let id = self.next_packet_id.get(); self.next_packet_id.set(id.wrapping_add(1)); PacketId(id) } - pub fn next_stream_id(&self, namespace: StreamNamespace) -> StreamId { - let seq = self.next_stream_id.get(); - self.next_stream_id.set(seq.wrapping_add(1)); - StreamId((seq & !StreamNamespace::BIT) | namespace.bit()) - } - pub fn enqueue_handshake_message( &mut self, _config: &EngineConfig, @@ -315,24 +240,14 @@ impl EngineState { deadline: Instant, bytes: Vec, ) { - self.control_outbound.push_back(ControlWrite { - token, - kind: OutboundWriteKind::Control, - payload: ControlWritePayload::Encoded(bytes), - }); + self.control_outbound + .push_back(ControlWrite { token, bytes }); self.timeouts.push(Reverse(TimeoutEntry { at: deadline, kind: TimeoutKind::Outbound { token }, })); } - pub fn schedule_handshake_retry(&mut self, token: Token, at: Instant) { - self.timeouts.push(Reverse(TimeoutEntry { - at, - kind: TimeoutKind::HandshakeRetry { token }, - })); - } - pub fn enqueue_control( &mut self, config: &EngineConfig, @@ -340,43 +255,7 @@ impl EngineState { bytes: Vec, ) -> Token { let token = self.next_token(); - let message = ControlWrite { - token, - kind: OutboundWriteKind::Control, - payload: ControlWritePayload::Encoded(bytes), - }; - if priority { - self.control_outbound.push_front(message); - } else { - self.control_outbound.push_back(message); - } - self.timeouts.push(Reverse(TimeoutEntry { - at: self.now + config.packet_expiration, - kind: TimeoutKind::Outbound { token }, - })); - token - } - - pub fn enqueue_stream_close( - &mut self, - config: &EngineConfig, - priority: bool, - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - ) -> Token { - let token = self.next_token(); - let message = ControlWrite { - token, - kind: OutboundWriteKind::StreamClose { stream_id }, - payload: ControlWritePayload::StreamClose { - stream_id, - target, - code, - payload, - }, - }; + let message = ControlWrite { token, bytes }; if priority { self.control_outbound.push_front(message); } else { diff --git a/ql-engine/src/engine/tests/handshake.rs b/ql-engine/src/engine/tests/handshake.rs index 83fe1f67..399c5b37 100644 --- a/ql-engine/src/engine/tests/handshake.rs +++ b/ql-engine/src/engine/tests/handshake.rs @@ -747,11 +747,7 @@ fn handshake_retry_limit_disconnects_initiator() { let identity = test_identity(); let peer_identity = test_identity(); let mut engine = EngineWrapper::new( - Engine::new( - config, - identity, - Some(peer_from_identity(&peer_identity)), - ), + Engine::new(config, identity, Some(peer_from_identity(&peer_identity))), TestCrypto::new(131), ); let now = Instant::now(); @@ -766,7 +762,8 @@ fn handshake_retry_limit_disconnects_initiator() { assert_eq!(retry_write.bytes, hello_bytes); let _ = engine.complete_write_collect(retry_write.id, Ok(())); - let outputs = engine.run_tick_collect(now + Duration::from_millis(500), EngineInput::TimerExpired); + let outputs = + engine.run_tick_collect(now + Duration::from_millis(500), EngineInput::TimerExpired); assert!(outputs.iter().any(|output| matches!( output, EngineOutput::PeerStatusChanged { @@ -819,6 +816,12 @@ fn simultaneous_connect_converges_to_connected_peers() { pump_between(&mut a, &mut b, now); - assert!(matches!(a.peer.as_ref().map(|peer| &peer.session), Some(PeerSession::Connected { .. }))); - assert!(matches!(b.peer.as_ref().map(|peer| &peer.session), Some(PeerSession::Connected { .. }))); + assert!(matches!( + a.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Connected { .. }) + )); + assert!(matches!( + b.peer.as_ref().map(|peer| &peer.session), + Some(PeerSession::Connected { .. }) + )); } diff --git a/ql-engine/src/engine/tests/mod.rs b/ql-engine/src/engine/tests/mod.rs index 1c396ce0..2ab01583 100644 --- a/ql-engine/src/engine/tests/mod.rs +++ b/ql-engine/src/engine/tests/mod.rs @@ -13,12 +13,103 @@ use std::{ use bc_components::{SymmetricKey, MLDSA, MLKEM, XID}; use crate::{ - engine::{state::StreamNamespace, stream::*, *}, + engine::*, identity::QlIdentity, + stream::{state::*, StreamNamespace}, wire::{self, stream::*, QlHeader, QlPayload, QlRecord, StreamSeq}, PacketId, Peer, }; +#[derive(Debug)] +pub enum EngineOutput { + PeerStatusChanged { + peer: XID, + session: PeerSession, + }, + PersistPeer(Peer), + ClearPeer, + + InboundStreamOpened { + stream_id: StreamId, + request_head: Vec, + request_prefix: Option, + }, + InboundData { + stream_id: StreamId, + bytes: Vec, + }, + InboundFinished { + stream_id: StreamId, + }, + InboundFailed { + stream_id: StreamId, + error: QlError, + }, + + OutboundClosed { + stream_id: StreamId, + }, + OutboundFailed { + stream_id: StreamId, + error: QlError, + }, + + StreamReaped { + stream_id: StreamId, + }, +} + +impl EngineEventSink for Vec { + fn peer_status_changed(&mut self, peer: XID, session: PeerSession) { + self.push(EngineOutput::PeerStatusChanged { peer, session }); + } + + fn persist_peer(&mut self, peer: Peer) { + self.push(EngineOutput::PersistPeer(peer)); + } + + fn clear_peer(&mut self) { + self.push(EngineOutput::ClearPeer); + } + + fn inbound_stream_opened( + &mut self, + stream_id: StreamId, + request_head: Vec, + request_prefix: Option, + ) { + self.push(EngineOutput::InboundStreamOpened { + stream_id, + request_head, + request_prefix, + }); + } + + fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { + self.push(EngineOutput::InboundData { stream_id, bytes }); + } + + fn inbound_finished(&mut self, stream_id: StreamId) { + self.push(EngineOutput::InboundFinished { stream_id }); + } + + fn inbound_failed(&mut self, stream_id: StreamId, error: QlError) { + self.push(EngineOutput::InboundFailed { stream_id, error }); + } + + fn outbound_closed(&mut self, stream_id: StreamId) { + self.push(EngineOutput::OutboundClosed { stream_id }); + } + + fn outbound_failed(&mut self, stream_id: StreamId, error: QlError) { + self.push(EngineOutput::OutboundFailed { stream_id, error }); + } + + fn stream_reaped(&mut self, stream_id: StreamId) { + self.push(EngineOutput::StreamReaped { stream_id }); + } +} + #[derive(Clone)] struct TestCrypto { nonce_seed: u8, @@ -58,6 +149,29 @@ impl Side { } } +#[allow(dead_code)] +enum EngineInput { + BindPeer(Peer), + Pair, + Connect, + Unpair, + CloseStream { + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + }, + OutboundData { + stream_id: StreamId, + bytes: Vec, + }, + OutboundFinished { + stream_id: StreamId, + }, + Incoming(Vec), + TimerExpired, +} + struct Harness { now: Instant, a: EngineWrapper, @@ -177,10 +291,35 @@ impl EngineWrapper { } fn run_tick(&mut self, now: Instant, input: EngineInput) { - self.engine - .run_tick(now, input, &self.crypto, &mut |output| { - self.outputs.push(output) - }); + match input { + EngineInput::BindPeer(peer) => self.engine.bind_peer(now, peer, &mut self.outputs), + EngineInput::Pair => self.engine.pair(now, &self.crypto), + EngineInput::Connect => self.engine.connect(now, &self.crypto, &mut self.outputs), + EngineInput::Unpair => self.engine.unpair(now, &mut self.outputs), + EngineInput::CloseStream { + stream_id, + target, + code, + payload, + } => { + let _ = self + .engine + .close_stream(now, stream_id, target, code, payload); + } + EngineInput::OutboundData { stream_id, bytes } => { + let _ = self.engine.write_stream(now, stream_id, bytes); + } + EngineInput::OutboundFinished { stream_id } => { + let _ = self.engine.finish_stream(now, stream_id); + } + EngineInput::Incoming(bytes) => { + self.engine + .receive(now, bytes, &self.crypto, &mut self.outputs); + } + EngineInput::TimerExpired => { + self.engine.on_timer(now, &self.crypto, &mut self.outputs); + } + } } fn run_tick_collect(&mut self, now: Instant, input: EngineInput) -> Vec { @@ -190,11 +329,12 @@ impl EngineWrapper { fn complete_write(&mut self, write_id: WriteId, result: Result<(), QlError>) { self.engine - .complete_write(write_id, result, &mut |output| self.outputs.push(output)); + .complete_write(self.engine.state.now, write_id, result, &mut self.outputs); } fn take_next_write(&mut self) -> Option { - self.engine.take_next_write(&self.crypto) + self.engine + .take_next_write(self.engine.state.now, &self.crypto) } fn complete_write_collect( @@ -277,10 +417,6 @@ fn encrypt_heartbeat_record( fn insert_inflight_gap_stream(engine: &mut EngineWrapper, stream_id: StreamId, now: Instant) { let retry_at = now + Duration::from_secs(60); let mut stream = StreamState { - meta: StreamMeta { - stream_id, - last_activity: now, - }, control: StreamControl::default(), role: StreamRole::Initiator(InitiatorStream { request: OutboundPhase::from_prefix(false), @@ -313,7 +449,7 @@ fn insert_inflight_gap_stream(engine: &mut EngineWrapper, stream_id: StreamId, n write_state: InFlightWriteState::WaitingRetry { retry_at }, }); } - engine.streams.insert(stream_id, stream); + engine.streams.streams.insert(stream_id, stream); } fn insert_inflight_stream_with_data( @@ -324,10 +460,6 @@ fn insert_inflight_stream_with_data( ) { let retry_at = now + Duration::from_secs(60); let mut stream = StreamState { - meta: StreamMeta { - stream_id, - last_activity: now, - }, control: StreamControl::default(), role: StreamRole::Initiator(InitiatorStream { request: OutboundPhase::from_prefix(false), @@ -360,20 +492,16 @@ fn insert_inflight_stream_with_data( write_state: InFlightWriteState::WaitingRetry { retry_at }, }); } - engine.streams.insert(stream_id, stream); + engine.streams.streams.insert(stream_id, stream); } fn insert_unwritten_inflight_stream_with_data( engine: &mut EngineWrapper, stream_id: StreamId, - now: Instant, + _now: Instant, data_seqs: &[u32], ) { let mut stream = StreamState { - meta: StreamMeta { - stream_id, - last_activity: now, - }, control: StreamControl::default(), role: StreamRole::Initiator(InitiatorStream { request: OutboundPhase::from_prefix(false), @@ -406,5 +534,5 @@ fn insert_unwritten_inflight_stream_with_data( write_state: InFlightWriteState::Ready, }); } - engine.streams.insert(stream_id, stream); + engine.streams.streams.insert(stream_id, stream); } diff --git a/ql-engine/src/engine/tests/stream.rs b/ql-engine/src/engine/tests/stream.rs index 1c8e7238..739f1736 100644 --- a/ql-engine/src/engine/tests/stream.rs +++ b/ql-engine/src/engine/tests/stream.rs @@ -58,8 +58,8 @@ fn simultaneous_opens_use_disjoint_stream_id_namespaces() { .. } if *stream_id == stream_id_a && request_head == b"a-open" ))); - assert_eq!(harness.a.streams.len(), 2); - assert_eq!(harness.b.streams.len(), 2); + assert_eq!(harness.a.streams.streams.len(), 2); + assert_eq!(harness.b.streams.streams.len(), 2); } #[test] @@ -109,7 +109,7 @@ fn invalid_future_frame_does_not_ack_outstanding_open() { .iter() .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); } @@ -163,7 +163,7 @@ fn ack_for_issued_open_is_applied_before_write_completion() { bytes, } if *id == stream_id && bytes == b"resp" ))); - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); } @@ -225,7 +225,7 @@ fn ack_does_not_retire_ready_data() { } if *id == stream_id && bytes == b"resp" ))); - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); assert!(stream.control.in_flight.contains_key(&StreamSeq(2))); @@ -309,12 +309,12 @@ fn late_failed_write_after_remote_close_ack_is_ignored() { } if *id == stream_id && payload.is_empty() ))); - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); let outputs_late = engine.complete_write_collect(open_write.id, Err(QlError::SendFailed)); assert!(outputs_late.is_empty()); - assert!(engine.streams.contains_key(&stream_id)); + assert!(engine.streams.streams.contains_key(&stream_id)); } #[test] @@ -440,9 +440,10 @@ fn out_of_order_remote_stream_buffers_until_open_arrives() { .any(|output| matches!(output, EngineOutput::InboundData { .. }))); assert!(engine.take_next_write().is_some()); assert!(engine + .streams .streams .get(&stream_id) - .is_some_and(StreamState::is_provisional)); + .is_some_and(StreamState::awaiting_open)); let open_message = StreamMessage { tx_seq: StreamSeq(1), @@ -508,7 +509,7 @@ fn delayed_ack_only_does_not_consume_sequence_space() { let _outputs_b = harness.b.drain_outputs(); - let stream = harness.b.streams.get(&stream_id).unwrap(); + let stream = harness.b.streams.streams.get(&stream_id).unwrap(); assert!(stream.control.in_flight.is_empty()); assert_eq!(stream.control.next_tx_seq, StreamSeq::START); } @@ -747,10 +748,7 @@ fn selective_ack_only_body_retires_acked_gap_tail() { peer, session_key, } = SingleEngineHarness::connected(EngineConfig::default(), 81, 2); - let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id = engine.streams.next_stream_id(); insert_inflight_gap_stream(&mut engine, stream_id, now); let ack_record = wire::stream::encrypt_stream( @@ -776,7 +774,7 @@ fn selective_ack_only_body_retires_acked_gap_tail() { assert!(!outputs .iter() .any(|output| matches!(output, EngineOutput::OutboundFailed { .. }))); - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); let remaining: Vec<_> = stream .control .in_flight @@ -797,10 +795,7 @@ fn fast_retransmit_resends_oldest_gap_when_threshold_met() { peer, session_key, } = SingleEngineHarness::connected(config, 83, 9); - let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id = engine.streams.next_stream_id(); insert_inflight_gap_stream(&mut engine, stream_id, now); let ack_record = wire::stream::encrypt_stream( @@ -834,7 +829,7 @@ fn fast_retransmit_resends_oldest_gap_when_threshold_met() { }) )); - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); let remaining: Vec<_> = stream .control .in_flight @@ -844,7 +839,10 @@ fn fast_retransmit_resends_oldest_gap_when_threshold_met() { assert_eq!(remaining, vec![StreamSeq(3)]); let frame = stream.control.in_flight.get(&StreamSeq(3)).unwrap(); assert_eq!(frame.attempt, 1); - assert!(matches!(frame.write_state, InFlightWriteState::Issued)); + assert!(matches!( + frame.write_state, + InFlightWriteState::Issued { .. } + )); } #[test] @@ -857,10 +855,7 @@ fn fast_retransmit_respects_configured_threshold() { peer, session_key, } = SingleEngineHarness::connected(config, 85, 10); - let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id = engine.streams.next_stream_id(); insert_inflight_gap_stream(&mut engine, stream_id, now); let ack_record = wire::stream::encrypt_stream( @@ -888,7 +883,7 @@ fn fast_retransmit_respects_configured_threshold() { assert!(matches!(body, StreamBody::Ack(_))); } - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); let remaining: Vec<_> = stream .control .in_flight @@ -928,7 +923,7 @@ fn timeout_retransmit_reuses_original_tx_seq_and_slot() { )); let _outputs_written = engine.complete_write_collect(write.id, Ok(())); - let stream = engine.streams.get(&tracked_stream_id).unwrap(); + let stream = engine.streams.streams.get(&tracked_stream_id).unwrap(); assert_eq!(stream.control.in_flight.len(), 1); assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); @@ -946,7 +941,7 @@ fn timeout_retransmit_reuses_original_tx_seq_and_slot() { }) if stream_id == tracked_stream_id )); - let stream = engine.streams.get(&tracked_stream_id).unwrap(); + let stream = engine.streams.streams.get(&tracked_stream_id).unwrap(); assert_eq!(stream.control.in_flight.len(), 1); assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); @@ -966,13 +961,10 @@ fn take_next_write_drains_multiple_stream_frames_before_completion() { let SingleEngineHarness { now, mut engine, - peer, + peer: _, session_key, } = SingleEngineHarness::connected(EngineConfig::default(), 93, 12); - let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id = engine.streams.next_stream_id(); insert_unwritten_inflight_stream_with_data(&mut engine, stream_id, now, &[2, 3]); let writes = { @@ -1000,12 +992,12 @@ fn take_next_write_drains_multiple_stream_frames_before_completion() { assert_eq!(engine.state.active_writes.len(), writes.len()); assert!(engine.take_next_write().is_none()); - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); assert!(stream .control .in_flight .iter() - .all(|(_, in_flight)| matches!(in_flight.write_state, InFlightWriteState::Issued))); + .all(|(_, in_flight)| matches!(in_flight.write_state, InFlightWriteState::Issued { .. }))); } #[test] @@ -1013,18 +1005,15 @@ fn take_next_write_does_not_reissue_outstanding_frame() { let SingleEngineHarness { now, mut engine, - peer, + peer: _, session_key: _session_key, } = SingleEngineHarness::connected(EngineConfig::default(), 95, 13); - let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id = engine.streams.next_stream_id(); insert_unwritten_inflight_stream_with_data(&mut engine, stream_id, now, &[]); let write = engine.take_next_write().unwrap(); assert!(engine.take_next_write().is_none()); - assert!(engine.state.active_writes.contains_key(&write.id)); + assert!(engine.state.active_writes.contains(write.id.0)); } #[test] @@ -1032,17 +1021,11 @@ fn take_next_write_round_robins_across_ready_streams() { let SingleEngineHarness { now, mut engine, - peer, + peer: _, session_key, } = SingleEngineHarness::connected(EngineConfig::default(), 97, 14); - let stream_id1 = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); - let stream_id2 = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id1 = engine.streams.next_stream_id(); + let stream_id2 = engine.streams.next_stream_id(); insert_unwritten_inflight_stream_with_data(&mut engine, stream_id1, now, &[2]); insert_unwritten_inflight_stream_with_data(&mut engine, stream_id2, now, &[2]); @@ -1109,7 +1092,7 @@ fn stale_ack_delay_timer_after_piggyback_does_not_emit_extra_ack_only() { } #[test] -fn provisional_timeout_after_late_open_is_ignored() { +fn late_opened_stream_ignores_unrelated_timer_tick() { let config = EngineConfig::default(); let SingleEngineHarness { now, @@ -1176,7 +1159,11 @@ fn provisional_timeout_after_late_open_is_ignored() { engine.run_tick_collect(now + config.packet_expiration, EngineInput::TimerExpired); assert!(matches!( - engine.streams.get(&stream_id).map(|stream| &stream.role), + engine + .streams + .streams + .get(&stream_id) + .map(|stream| &stream.role), Some(StreamRole::Responder(_)) )); if let Some(write) = engine.take_next_write() { @@ -1285,7 +1272,7 @@ fn ack_only_write_failure_immediately_requeues_ack_without_spending_extra_seq() .. }) if id == stream_id && bytes == b"resp" )); - let stream = engine.streams.get(&stream_id).unwrap(); + let stream = engine.streams.streams.get(&stream_id).unwrap(); assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); } @@ -1396,10 +1383,7 @@ fn repeated_identical_gap_ack_only_fast_retransmits_once() { peer, session_key, } = SingleEngineHarness::connected(config, 69, 14); - let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id = engine.streams.next_stream_id(); insert_inflight_gap_stream(&mut engine, stream_id, now); let local_xid = engine.engine.identity.xid; @@ -1456,10 +1440,7 @@ fn fast_recovery_clears_after_gap_is_acked_and_allows_next_gap() { peer, session_key, } = SingleEngineHarness::connected(config, 73, 15); - let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id = engine.streams.next_stream_id(); insert_inflight_stream_with_data(&mut engine, stream_id, now, &[2, 3, 4, 5, 6]); let first_ack = wire::stream::encrypt_stream( @@ -1531,17 +1512,21 @@ fn fast_retransmit_and_retry_deadline_same_tick_only_send_once() { peer, session_key, } = SingleEngineHarness::connected(config, 75, 16); - let stream_id = engine.state.next_stream_id(StreamNamespace::for_local( - engine.engine.identity.xid, - peer.xid, - )); + let stream_id = engine.streams.next_stream_id(); insert_inflight_gap_stream(&mut engine, stream_id, now); - engine - .streams - .get_mut(&stream_id) - .unwrap() - .control - .set_retry_deadline(StreamSeq(3), now); + + { + let in_flight = engine + .streams + .streams + .get_mut(&stream_id) + .unwrap() + .control + .in_flight + .get_mut(&StreamSeq(3)) + .unwrap(); + in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at: now }; + } let ack_record = wire::stream::encrypt_stream( QlHeader { diff --git a/ql-engine/src/lib.rs b/ql-engine/src/lib.rs index b4a1e8ac..db9c878d 100644 --- a/ql-engine/src/lib.rs +++ b/ql-engine/src/lib.rs @@ -1,6 +1,7 @@ +pub(crate) mod arena; pub mod engine; pub mod identity; -// pub mod rpc; +pub mod stream; pub mod wire; pub use wire::{PacketId, StreamId}; diff --git a/ql-engine/src/stream/internal.rs b/ql-engine/src/stream/internal.rs new file mode 100644 index 00000000..358f0465 --- /dev/null +++ b/ql-engine/src/stream/internal.rs @@ -0,0 +1,842 @@ +use std::{collections::VecDeque, time::Instant}; + +use super::{state::*, *}; +use crate::{ + wire::{ + stream::{ + BodyChunk, CloseCode, CloseTarget, StreamAck, StreamAckBody, StreamBody, StreamFrame, + StreamFrameClose, StreamFrameData, StreamFrameOpen, StreamMessage, + }, + StreamSeq, + }, + StreamId, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OutboundSelection { + Ack, + InitialFrame { tx_seq: StreamSeq }, + RetryFrame { tx_seq: StreamSeq }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum StreamDisposition { + Keep, + Reap, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TimerAction { + None, + Fail, +} + +impl StreamFsm { + pub fn open_stream_inner( + &mut self, + request_head: Vec, + request_prefix: Option, + ) -> StreamId { + let stream_id = self.next_stream_id(); + let request_prefix_fin = request_prefix.as_ref().is_some_and(|chunk| chunk.fin); + let mut stream = StreamState { + control: StreamControl { + pending: VecDeque::from([StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head, + request_prefix, + })]), + ..Default::default() + }, + role: StreamRole::Initiator(InitiatorStream { + request: OutboundPhase::from_prefix(request_prefix_fin), + response: InboundState::new(), + }), + }; + Self::drive_stream(&mut stream, stream_id); + self.streams.insert(stream_id, stream); + stream_id + } + + pub fn write_stream_inner( + &mut self, + stream_id: StreamId, + bytes: Vec, + ) -> Result<(), StreamError> { + if bytes.is_empty() { + return Ok(()); + } + + let Some(stream) = self.streams.get_mut(&stream_id) else { + return Err(StreamError::MissingStream); + }; + let Some(side) = stream.outbound_side() else { + return Err(StreamError::NotWritable); + }; + if let StreamRole::Responder(state) = &mut stream.role { + if side == StreamSide::Response { + state.response_started = true; + } + } + let Some(outbound) = stream.outbound_mut(side) else { + return Err(StreamError::NotWritable); + }; + if !outbound.can_queue_data() { + return Err(StreamError::NotWritable); + } + + stream + .control + .pending + .push_back(StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { bytes, fin: false }, + })); + Self::drive_stream(stream, stream_id); + Ok(()) + } + + pub fn finish_stream_inner(&mut self, stream_id: StreamId) -> Result<(), StreamError> { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return Err(StreamError::MissingStream); + }; + let Some(side) = stream.outbound_side() else { + return Err(StreamError::NotWritable); + }; + if let StreamRole::Responder(state) = &mut stream.role { + if side == StreamSide::Response { + state.response_started = true; + } + } + let Some(outbound) = stream.outbound_mut(side) else { + return Err(StreamError::NotWritable); + }; + outbound.finish(); + Self::drive_stream(stream, stream_id); + Ok(()) + } + + pub fn close_stream_inner( + &mut self, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + ) -> Result<(), StreamError> { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return Err(StreamError::MissingStream); + }; + + let mut dirty = false; + if matches!(target, CloseTarget::Request | CloseTarget::Both) { + if let Some(inbound) = stream.inbound_mut(StreamSide::Request) { + dirty |= inbound.close(); + } + if let Some(outbound) = stream.outbound_mut(StreamSide::Request) { + dirty |= outbound.close(); + } + } + if matches!(target, CloseTarget::Response | CloseTarget::Both) { + if let Some(inbound) = stream.inbound_mut(StreamSide::Response) { + dirty |= inbound.close(); + } + if let Some(outbound) = stream.outbound_mut(StreamSide::Response) { + dirty |= outbound.close(); + } + } + + if dirty { + stream + .control + .pending + .push_front(close_frame(stream_id, target, code, payload)); + Self::drive_stream(stream, stream_id); + } + + Ok(()) + } + + pub fn receive_inner( + &mut self, + now: Instant, + body: StreamBody, + events: &mut impl StreamEventSink, + ) { + match body { + StreamBody::Ack(StreamAckBody { stream_id, ack, .. }) => { + self.process_ack(now, stream_id, ack, events) + } + StreamBody::Message(StreamMessage { + tx_seq, ack, frame, .. + }) => { + let stream_id = frame.stream_id(); + self.process_ack(now, stream_id, ack, events); + + if !self.streams.contains_key(&stream_id) { + if !self.config.local_namespace.remote().matches(stream_id) { + return; + } + self.streams.insert( + stream_id, + StreamState { + control: StreamControl::default(), + role: StreamRole::Responder(ResponderStream { + opened: false, + request: InboundState::new(), + response: OutboundPhase::Ready, + response_started: false, + }), + }, + ); + } + + let disposition = { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + + match stream.control.buffer_incoming(tx_seq, frame) { + BufferIncomingResult::OutOfWindow => { + Self::queue_protocol_close(stream_id, stream, events); + StreamDisposition::Keep + } + BufferIncomingResult::Duplicate | BufferIncomingResult::AlreadyBuffered => { + stream.control.note_ack(now, self.config.ack_delay, true); + StreamDisposition::Keep + } + BufferIncomingResult::Buffered { out_of_order } => { + stream + .control + .note_ack(now, self.config.ack_delay, out_of_order); + Self::drain_committed_frames(stream_id, stream, events) + } + } + }; + + match disposition { + StreamDisposition::Keep => {} + StreamDisposition::Reap => { + self.streams.remove(&stream_id); + events.reaped(stream_id); + } + } + } + } + } + + pub fn next_outbound_inner(&mut self, now: Instant, valid_until: u64) -> Option { + for offset in 0..self.streams.len() { + let stream_id = self.streams.id_at_offset(offset)?; + let selection = { + let stream = self.streams.get(&stream_id)?; + self.select_outbound(stream, now) + }; + let Some(selection) = selection else { + continue; + }; + + let outbound = match selection { + OutboundSelection::Ack => { + let stream = self.streams.get_mut(&stream_id)?; + let ack = stream.control.current_ack(); + stream.control.clear_ack_schedule(); + stream.control.note_ack_sent(ack); + Outbound { + body: StreamBody::Ack(StreamAckBody { + stream_id, + ack, + valid_until, + }), + completion: OutboundCompletion::Ack { stream_id }, + } + } + OutboundSelection::InitialFrame { tx_seq } + | OutboundSelection::RetryFrame { tx_seq } => { + let issue_id = self.next_issue_id(); + let stream = self.streams.get_mut(&stream_id)?; + let inbound_alive = match stream.role { + StreamRole::Initiator(state) => !state.response.closed, + StreamRole::Responder(state) => !state.request.closed, + }; + let ack = stream.control.take_piggyback_ack(inbound_alive); + let frame = stream.control.mark_write_issued(tx_seq, issue_id)?; + Outbound { + body: StreamBody::Message(StreamMessage { + tx_seq, + ack, + valid_until, + frame, + }), + completion: OutboundCompletion::Frame { + stream_id, + tx_seq, + issue_id, + }, + } + } + }; + + self.streams.advance_cursor_after(stream_id); + return Some(outbound); + } + + None + } + + pub fn complete_outbound_inner( + &mut self, + now: Instant, + completion: OutboundCompletion, + result: Result<(), WriteError>, + events: &mut impl StreamEventSink, + ) { + match completion { + OutboundCompletion::Ack { stream_id } => { + if let Some(stream) = self.streams.get_mut(&stream_id) { + if result.is_err() { + stream.control.note_ack(now, self.config.ack_delay, true); + } + if stream.can_reap() { + self.streams.remove(&stream_id); + events.reaped(stream_id); + } + } + } + OutboundCompletion::Frame { + stream_id, + tx_seq, + issue_id, + } => match result { + Ok(()) => { + if let Some(stream) = self.streams.get_mut(&stream_id) { + let _ = stream.control.complete_write( + tx_seq, + issue_id, + now + self.config.ack_timeout, + ); + } + } + Err(WriteError::SendFailed) => { + let should_fail = self.streams.get(&stream_id).is_some_and(|stream| { + stream.control.frame_write_is_issued(tx_seq, issue_id) + }); + if should_fail { + self.fail_stream_by_id(stream_id, StreamError::SendFailed, events); + } + } + }, + } + } + + pub fn on_timer_inner(&mut self, now: Instant, events: &mut impl StreamEventSink) { + let mut index = 0; + while let Some(stream_id) = self.streams.ordered_id(index) { + let action = { + let stream = self + .streams + .get(&stream_id) + .expect("ordered stream id should exist"); + if stream.control.in_flight.iter().any(|(_, in_flight)| { + matches!( + in_flight.write_state, + InFlightWriteState::WaitingRetry { retry_at } + if retry_at <= now && in_flight.attempt >= self.config.retry_limit + ) + }) { + TimerAction::Fail + } else { + TimerAction::None + } + }; + + match action { + TimerAction::Fail => { + self.fail_stream_by_id(stream_id, StreamError::Timeout, events); + } + TimerAction::None => { + if let Some(stream) = self.streams.get_mut(&stream_id) { + if stream + .control + .ack_deadline() + .is_some_and(|due_at| due_at <= now) + { + stream.control.ack_state = AckState::Immediate; + } + } + index += 1; + } + } + } + } + + pub fn next_deadline_inner(&self) -> Option { + let mut next = None; + for stream in self.streams.values() { + if let Some(deadline) = stream.control.ack_deadline() { + next = min_deadline(next, deadline); + } + for (_, in_flight) in stream.control.in_flight.iter() { + if let InFlightWriteState::WaitingRetry { retry_at } = in_flight.write_state { + next = min_deadline(next, retry_at); + } + } + } + next + } + + pub fn abort_inner(&mut self, error: StreamError, events: &mut impl StreamEventSink) { + while let Some(stream_id) = self.streams.first_id() { + self.fail_stream_by_id(stream_id, error.clone(), events); + } + } +} + +impl StreamFsm { + pub(crate) fn next_stream_id(&mut self) -> StreamId { + let seq = self.next_stream_id; + self.next_stream_id = seq.wrapping_add(1); + StreamId((seq & !StreamNamespace::BIT) | self.config.local_namespace.bit()) + } + + fn next_issue_id(&mut self) -> u64 { + let id = self.next_issue_id; + self.next_issue_id = id.wrapping_add(1); + id + } + + fn select_outbound(&self, stream: &StreamState, now: Instant) -> Option { + if let Some(tx_seq) = stream + .control + .in_flight + .iter() + .find_map(|(tx_seq, in_flight)| { + matches!( + in_flight.write_state, + InFlightWriteState::WaitingRetry { retry_at } + if retry_at <= now && in_flight.attempt < self.config.retry_limit + ) + .then_some(tx_seq) + }) + { + return Some(OutboundSelection::RetryFrame { tx_seq }); + } + if let Some(tx_seq) = stream + .control + .in_flight + .iter() + .find_map(|(tx_seq, in_flight)| { + matches!(in_flight.write_state, InFlightWriteState::Ready).then_some(tx_seq) + }) + { + return Some(OutboundSelection::InitialFrame { tx_seq }); + } + + matches!(stream.control.ack_state, AckState::Immediate).then_some(OutboundSelection::Ack) + } + + fn process_ack( + &mut self, + now: Instant, + stream_id: StreamId, + ack: StreamAck, + events: &mut impl StreamEventSink, + ) { + if ack == StreamAck::EMPTY { + return; + } + + let should_reap = { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.control.clear_fast_recovery(ack.base); + let fast_retransmit = stream + .control + .fast_retransmit_candidate(ack, self.config.fast_retransmit_threshold); + + loop { + let acked_tx_seq = + stream + .control + .in_flight + .iter() + .find_map(|(tx_seq, in_flight)| match in_flight.write_state { + InFlightWriteState::Ready => None, + InFlightWriteState::Issued { .. } + | InFlightWriteState::WaitingRetry { .. } => { + StreamControl::ack_covers(ack, tx_seq).then_some(tx_seq) + } + }); + let Some(tx_seq) = acked_tx_seq else { + break; + }; + let Some(in_flight) = stream.control.remove_in_flight(tx_seq) else { + continue; + }; + + match in_flight.frame { + StreamFrame::Open(StreamFrameOpen { request_prefix, .. }) => { + if let StreamRole::Initiator(state) = &mut stream.role { + if request_prefix.as_ref().is_some_and(|chunk| chunk.fin) + && state.request.close() + { + events.outbound_closed(stream_id); + } + } + } + StreamFrame::Data(StreamFrameData { + chunk: BodyChunk { fin: true, .. }, + .. + }) => { + if let Some(side) = stream.outbound_side() { + if let Some(outbound) = stream.outbound_mut(side) { + if outbound.close() { + events.outbound_closed(stream_id); + } + } + } + } + StreamFrame::Close(frame) => { + let mut changed = false; + for side in [StreamSide::Request, StreamSide::Response] { + let affects_outbound = matches!( + (frame.target, side), + (CloseTarget::Request, StreamSide::Request) + | (CloseTarget::Response, StreamSide::Response) + | (CloseTarget::Both, _) + ); + if affects_outbound { + if let Some(outbound) = stream.outbound_mut(side) { + if outbound.close() { + changed = true; + } + } + } + } + if changed { + events.close(StreamCloseEvent { + kind: StreamCloseKind::Acked, + role: stream.local_role(), + frame, + }); + } + } + StreamFrame::Data(_) => {} + } + } + + if let Some(tx_seq) = fast_retransmit { + stream.control.schedule_fast_retransmit(tx_seq, now); + } + Self::drive_stream(stream, stream_id); + stream.can_reap() + }; + + if should_reap { + self.streams.remove(&stream_id); + events.reaped(stream_id); + } + } + + fn drain_committed_frames( + stream_id: StreamId, + stream: &mut StreamState, + events: &mut impl StreamEventSink, + ) -> StreamDisposition { + loop { + let Some((tx_seq, frame)) = stream.control.pop_next_committable() else { + break; + }; + + if stream.awaiting_open() + && (tx_seq != StreamSeq::START || !matches!(frame, StreamFrame::Open(_))) + { + Self::queue_protocol_close(stream_id, stream, events); + return StreamDisposition::Keep; + } + + match frame { + StreamFrame::Open(frame) => { + Self::handle_stream_open(stream_id, stream, frame, events) + } + StreamFrame::Close(frame) => { + Self::handle_stream_close_from_peer(stream_id, stream, frame, events) + } + StreamFrame::Data(frame) => { + Self::handle_stream_data(stream_id, stream, frame, events) + } + } + } + + stream.control.maybe_force_ack_for_progress(); + if stream.can_reap() { + StreamDisposition::Reap + } else { + StreamDisposition::Keep + } + } + + fn handle_stream_open( + stream_id: StreamId, + stream: &mut StreamState, + frame: StreamFrameOpen, + events: &mut impl StreamEventSink, + ) { + let StreamFrameOpen { + request_head, + request_prefix, + .. + } = frame; + + let StreamRole::Responder(state) = &mut stream.role else { + Self::queue_protocol_close(stream_id, stream, events); + return; + }; + if state.opened { + Self::queue_protocol_close(stream_id, stream, events); + return; + } + + let request_fin = request_prefix.as_ref().is_some_and(|chunk| chunk.fin); + state.opened = true; + if request_fin { + let _ = stream + .inbound_mut(StreamSide::Request) + .expect("responder request side should exist") + .close(); + } + events.opened(stream_id, request_head, request_prefix); + } + + fn handle_stream_close_from_peer( + stream_id: StreamId, + stream: &mut StreamState, + frame: StreamFrameClose, + events: &mut impl StreamEventSink, + ) { + let StreamFrameClose { + target, + code, + payload, + .. + } = frame; + Self::apply_remote_close(stream_id, stream, target, code, payload, events); + } + + fn handle_stream_data( + stream_id: StreamId, + stream: &mut StreamState, + frame: StreamFrameData, + events: &mut impl StreamEventSink, + ) { + let Some(side) = stream.inbound_side() else { + Self::queue_protocol_close(stream_id, stream, events); + return; + }; + let Some(inbound) = stream.inbound_mut(side) else { + Self::queue_protocol_close(stream_id, stream, events); + return; + }; + if inbound.closed { + Self::queue_protocol_close(stream_id, stream, events); + return; + } + + let BodyChunk { bytes, fin } = frame.chunk; + if !bytes.is_empty() { + events.inbound_data(stream_id, bytes); + } + if fin && inbound.close() { + events.inbound_finished(stream_id); + } + } + + fn drive_stream(stream: &mut StreamState, stream_id: StreamId) { + match &mut stream.role { + StreamRole::Initiator(state) => Self::drive_stream_outbound( + stream_id, + &mut stream.control, + Some(&mut state.request), + ), + StreamRole::Responder(state) => Self::drive_stream_outbound( + stream_id, + &mut stream.control, + Some(&mut state.response), + ), + } + } + + fn drive_stream_outbound( + stream_id: StreamId, + control: &mut StreamControl, + mut outbound: Option<&mut OutboundPhase>, + ) { + loop { + if control.send_window_has_space() { + if let Some(frame) = control.pending.pop_front() { + let tx_seq = control.take_tx_seq(); + control.insert_in_flight(InFlightFrame { + tx_seq, + frame, + attempt: 0, + write_state: InFlightWriteState::Ready, + }); + continue; + } + } + if !control.send_window_has_space() { + return; + } + let Some(outbound) = outbound.as_deref_mut() else { + return; + }; + if outbound.queue_fin() { + let tx_seq = control.take_tx_seq(); + control.insert_in_flight(InFlightFrame { + tx_seq, + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: Vec::new(), + fin: true, + }, + }), + attempt: 0, + write_state: InFlightWriteState::Ready, + }); + continue; + } + return; + } + } + + fn queue_protocol_close( + stream_id: StreamId, + stream: &mut StreamState, + events: &mut impl StreamEventSink, + ) { + let opened = !stream.awaiting_open(); + stream.control.clear_transient_buffers(); + stream.control.pending.push_front(close_frame( + stream_id, + CloseTarget::Both, + CloseCode::PROTOCOL, + Vec::new(), + )); + for side in [StreamSide::Request, StreamSide::Response] { + if let Some(outbound) = stream.outbound_mut(side) { + if outbound.close() { + if opened { + events.outbound_failed(stream_id, StreamError::StreamProtocol); + } + } + } + if let Some(inbound) = stream.inbound_mut(side) { + if inbound.close() { + if opened { + events.inbound_failed(stream_id, StreamError::StreamProtocol); + } + } + } + } + Self::drive_stream(stream, stream_id); + } + + fn apply_remote_close( + stream_id: StreamId, + stream: &mut StreamState, + target: CloseTarget, + code: CloseCode, + payload: Vec, + events: &mut impl StreamEventSink, + ) { + let frame = StreamFrameClose { + stream_id, + target, + code, + payload, + }; + let mut changed = false; + if matches!(target, CloseTarget::Request | CloseTarget::Both) { + if let Some(inbound) = stream.inbound_mut(StreamSide::Request) { + if inbound.close() { + changed = true; + } + } + if let Some(outbound) = stream.outbound_mut(StreamSide::Request) { + if outbound.close() { + changed = true; + } + } + } + if matches!(target, CloseTarget::Response | CloseTarget::Both) { + if let Some(inbound) = stream.inbound_mut(StreamSide::Response) { + if inbound.close() { + changed = true; + } + } + if let Some(outbound) = stream.outbound_mut(StreamSide::Response) { + if outbound.close() { + changed = true; + } + } + } + if changed { + events.close(StreamCloseEvent { + kind: StreamCloseKind::Remote, + role: stream.local_role(), + frame, + }); + } + } + + fn fail_stream_by_id( + &mut self, + stream_id: StreamId, + error: StreamError, + events: &mut impl StreamEventSink, + ) { + let Some(stream) = self.streams.remove(&stream_id) else { + return; + }; + + match stream.role { + StreamRole::Initiator(_) => { + events.outbound_failed(stream_id, error.clone()); + events.inbound_failed(stream_id, error); + } + StreamRole::Responder(stream) => { + if !stream.opened { + events.reaped(stream_id); + return; + } + events.inbound_failed(stream_id, error.clone()); + if stream.response_started || stream.response.is_closed() { + events.outbound_failed(stream_id, error); + } + } + } + events.reaped(stream_id); + } +} + +fn min_deadline(current: Option, candidate: Instant) -> Option { + Some(match current { + Some(current) => current.min(candidate), + None => candidate, + }) +} + +fn close_frame( + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, +) -> StreamFrame { + StreamFrame::Close(StreamFrameClose { + stream_id, + target, + code, + payload, + }) +} diff --git a/ql-engine/src/stream/mod.rs b/ql-engine/src/stream/mod.rs new file mode 100644 index 00000000..481cbc8c --- /dev/null +++ b/ql-engine/src/stream/mod.rs @@ -0,0 +1,270 @@ +pub(crate) mod internal; +pub(crate) mod ring; +pub(crate) mod state; + +#[cfg(test)] +mod tests; + +use std::time::{Duration, Instant}; + +use thiserror::Error; + +use crate::{ + wire::{ + stream::{BodyChunk, CloseCode, CloseTarget, StreamBody, StreamFrameClose}, + StreamSeq, + }, + StreamId, +}; + +pub const STREAM_WINDOW_CAPACITY: usize = 8; +pub const STREAM_WINDOW_SIZE: u32 = STREAM_WINDOW_CAPACITY as u32; +pub const STREAM_ACK_EAGER_THRESHOLD: u32 = STREAM_WINDOW_SIZE / 2; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamNamespace { + Low, + High, +} + +impl StreamNamespace { + const BIT: u32 = 1 << 31; + + pub fn for_local(local: bc_components::XID, peer: bc_components::XID) -> Self { + match local.data().cmp(peer.data()) { + std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Low, + std::cmp::Ordering::Greater => Self::High, + } + } + + pub fn bit(self) -> u32 { + match self { + Self::Low => 0, + Self::High => Self::BIT, + } + } + + pub fn matches(self, stream_id: StreamId) -> bool { + (stream_id.0 & Self::BIT) == self.bit() + } + + pub fn remote(self) -> Self { + match self { + Self::Low => Self::High, + Self::High => Self::Low, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct StreamFsmConfig { + pub local_namespace: StreamNamespace, + pub ack_delay: Duration, + pub ack_timeout: Duration, + pub fast_retransmit_threshold: u8, + pub retry_limit: u8, +} + +impl Default for StreamFsmConfig { + fn default() -> Self { + Self { + local_namespace: StreamNamespace::Low, + ack_delay: Duration::from_millis(5), + ack_timeout: Duration::from_millis(150), + fast_retransmit_threshold: 2, + retry_limit: 5, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OutboundCompletion { + Ack { + stream_id: StreamId, + }, + Frame { + stream_id: StreamId, + tx_seq: StreamSeq, + issue_id: u64, + }, +} + +impl OutboundCompletion { + pub fn stream_id(self) -> StreamId { + match self { + Self::Ack { stream_id } | Self::Frame { stream_id, .. } => stream_id, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Outbound { + pub body: StreamBody, + pub completion: OutboundCompletion, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamCloseKind { + Acked, + Remote, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamLocalRole { + Initiator, + Responder, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamCloseEvent { + pub kind: StreamCloseKind, + pub role: StreamLocalRole, + pub frame: StreamFrameClose, +} + +pub trait StreamEventSink { + fn opened( + &mut self, + stream_id: StreamId, + request_head: Vec, + request_prefix: Option, + ); + + fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec); + + fn inbound_finished(&mut self, stream_id: StreamId); + + fn inbound_failed(&mut self, stream_id: StreamId, error: StreamError); + + fn close(&mut self, event: StreamCloseEvent); + + fn outbound_closed(&mut self, stream_id: StreamId); + + fn outbound_failed(&mut self, stream_id: StreamId, error: StreamError); + + fn reaped(&mut self, stream_id: StreamId); +} + +impl StreamEventSink for () { + fn opened( + &mut self, + _stream_id: StreamId, + _request_head: Vec, + _request_prefix: Option, + ) { + } + + fn inbound_data(&mut self, _stream_id: StreamId, _bytes: Vec) {} + + fn inbound_finished(&mut self, _stream_id: StreamId) {} + + fn inbound_failed(&mut self, _stream_id: StreamId, _error: StreamError) {} + + fn close(&mut self, _event: StreamCloseEvent) {} + + fn outbound_closed(&mut self, _stream_id: StreamId) {} + + fn outbound_failed(&mut self, _stream_id: StreamId, _error: StreamError) {} + + fn reaped(&mut self, _stream_id: StreamId) {} +} + +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum StreamError { + #[error("missing stream")] + MissingStream, + #[error("stream is not writable")] + NotWritable, + #[error("send failed")] + SendFailed, + #[error("timeout")] + Timeout, + #[error("cancelled")] + Cancelled, + #[error("stream protocol error")] + StreamProtocol, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)] +pub enum WriteError { + #[error("send failed")] + SendFailed, +} + +pub struct StreamFsm { + config: StreamFsmConfig, + pub(crate) streams: state::StreamStore, + next_stream_id: u32, + next_issue_id: u64, +} + +impl StreamFsm { + pub fn new(config: StreamFsmConfig) -> Self { + Self { + config, + streams: state::StreamStore::default(), + next_stream_id: 1, + next_issue_id: 1, + } + } + + pub fn open_stream( + &mut self, + request_head: Vec, + request_prefix: Option, + ) -> StreamId { + self.open_stream_inner(request_head, request_prefix) + } + + pub fn write_stream(&mut self, stream_id: StreamId, bytes: Vec) -> Result<(), StreamError> { + self.write_stream_inner(stream_id, bytes) + } + + pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { + self.finish_stream_inner(stream_id) + } + + pub fn close_stream( + &mut self, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + ) -> Result<(), StreamError> { + self.close_stream_inner(stream_id, target, code, payload) + } + + pub fn receive(&mut self, now: Instant, body: StreamBody, events: &mut impl StreamEventSink) { + self.receive_inner(now, body, events) + } + + pub fn next_outbound(&mut self, now: Instant, valid_until: u64) -> Option { + self.next_outbound_inner(now, valid_until) + } + + pub fn complete_outbound( + &mut self, + now: Instant, + completion: OutboundCompletion, + result: Result<(), WriteError>, + events: &mut impl StreamEventSink, + ) { + self.complete_outbound_inner(now, completion, result, events) + } + + pub fn on_timer(&mut self, now: Instant, events: &mut impl StreamEventSink) { + self.on_timer_inner(now, events) + } + + pub fn next_deadline(&self) -> Option { + self.next_deadline_inner() + } + + pub fn abort(&mut self, error: StreamError, events: &mut impl StreamEventSink) { + self.abort_inner(error, events); + } + + pub fn set_local_namespace(&mut self, local_namespace: StreamNamespace) { + self.config.local_namespace = local_namespace; + } +} diff --git a/ql-engine/src/stream/ring.rs b/ql-engine/src/stream/ring.rs new file mode 100644 index 00000000..ccefb1ce --- /dev/null +++ b/ql-engine/src/stream/ring.rs @@ -0,0 +1,194 @@ +use std::array; + +use crate::wire::StreamSeq; + +#[derive(Debug)] +pub enum SeqRingInsertError { + OutOfWindow, + Occupied, +} + +#[derive(Debug)] +pub struct SeqRing { + base_seq: StreamSeq, + head: usize, + len: usize, + slots: [Option; N], +} + +impl SeqRing { + pub fn new(base_seq: StreamSeq) -> Self { + Self { + base_seq, + head: 0, + len: 0, + slots: array::from_fn(|_| None), + } + } + + pub fn base_seq(&self) -> StreamSeq { + self.base_seq + } + + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn clear_with_base(&mut self, base_seq: StreamSeq) { + for slot in &mut self.slots { + let _ = slot.take(); + } + self.base_seq = base_seq; + self.head = 0; + self.len = 0; + } + + pub fn contains_key(&self, seq: &StreamSeq) -> bool { + self.get(seq).is_some() + } + + pub fn accepts_seq(&self, seq: StreamSeq) -> bool { + self.offset_for(seq).is_some() + } + + pub fn get(&self, seq: &StreamSeq) -> Option<&T> { + let index = self.index_for(*seq)?; + self.slots[index].as_ref() + } + + pub fn get_mut(&mut self, seq: &StreamSeq) -> Option<&mut T> { + let index = self.index_for(*seq)?; + self.slots[index].as_mut() + } + + pub fn insert(&mut self, seq: StreamSeq, value: T) -> Result<(), SeqRingInsertError> { + let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; + if self.slots[index].is_some() { + return Err(SeqRingInsertError::Occupied); + } + self.slots[index] = Some(value); + self.len += 1; + Ok(()) + } + + pub fn set(&mut self, seq: StreamSeq, value: T) -> Result, SeqRingInsertError> { + let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; + let previous = self.slots[index].replace(value); + if previous.is_none() { + self.len += 1; + } + Ok(previous) + } + + pub fn remove(&mut self, seq: &StreamSeq) -> Option { + let index = self.index_for(*seq)?; + let value = self.slots[index].take(); + if value.is_some() { + self.len -= 1; + } + value + } + + pub fn take_front(&mut self) -> Option<(StreamSeq, T)> { + let value = self.slots[self.head].take()?; + let seq = self.base_seq; + self.len -= 1; + self.head = self.next_index(self.head); + self.base_seq = self.base_seq.next(); + Some((seq, value)) + } + + pub fn advance_empty_front_until(&mut self, limit_seq: StreamSeq) { + while self.base_seq.serial_lt(limit_seq) && self.slots[self.head].is_none() { + self.head = self.next_index(self.head); + self.base_seq = self.base_seq.next(); + } + } + + pub fn iter(&self) -> SeqRingIter<'_, N, T> { + SeqRingIter { + ring: self, + offset: 0, + } + } + + pub fn bitmap(&self) -> u8 { + debug_assert!(N <= 8); + let mut bitmap = 0u8; + for offset in 0..N { + let index = self.index_for_offset(offset); + if self.slots[index].is_some() { + bitmap |= 1u8 << offset; + } + } + bitmap + } + + fn index_for(&self, seq: StreamSeq) -> Option { + let offset = self.offset_for(seq)?; + Some(self.index_for_offset(offset)) + } + + fn offset_for(&self, seq: StreamSeq) -> Option { + let offset = self.base_seq.forward_distance_to(seq)? as usize; + (offset < N).then_some(offset) + } + + fn index_for_offset(&self, offset: usize) -> usize { + (self.head + offset) % N + } + + fn next_index(&self, index: usize) -> usize { + (index + 1) % N + } +} + +pub struct SeqRingIter<'a, const N: usize, T> { + ring: &'a SeqRing, + offset: usize, +} + +impl<'a, const N: usize, T> Iterator for SeqRingIter<'a, N, T> { + type Item = (StreamSeq, &'a T); + + fn next(&mut self) -> Option { + while self.offset < N { + let offset = self.offset; + self.offset += 1; + let index = self.ring.index_for_offset(offset); + if let Some(value) = self.ring.slots[index].as_ref() { + return Some((self.ring.base_seq.add(offset as u32), value)); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn seq_ring_wraps_and_reuses_slots() { + let mut ring = SeqRing::<4, u64>::new(StreamSeq(1)); + ring.insert(StreamSeq(1), 1).unwrap(); + ring.insert(StreamSeq(2), 2).unwrap(); + ring.insert(StreamSeq(3), 3).unwrap(); + + assert_eq!(ring.take_front(), Some((StreamSeq(1), 1))); + assert_eq!(ring.take_front(), Some((StreamSeq(2), 2))); + + ring.insert(StreamSeq(4), 4).unwrap(); + ring.insert(StreamSeq(5), 5).unwrap(); + + let remaining: Vec<_> = ring.iter().map(|(seq, value)| (seq, *value)).collect(); + assert_eq!( + remaining, + vec![(StreamSeq(3), 3), (StreamSeq(4), 4), (StreamSeq(5), 5)] + ); + } +} diff --git a/ql-engine/src/engine/stream.rs b/ql-engine/src/stream/state.rs similarity index 69% rename from ql-engine/src/engine/stream.rs rename to ql-engine/src/stream/state.rs index 35095973..402b7686 100644 --- a/ql-engine/src/engine/stream.rs +++ b/ql-engine/src/stream/state.rs @@ -1,22 +1,169 @@ use std::{ collections::{HashMap, VecDeque}, - time::Instant, + time::{Duration, Instant}, }; -use super::{Token, ring::SeqRing}; +use super::{ + ring::SeqRing, StreamLocalRole, STREAM_ACK_EAGER_THRESHOLD, STREAM_WINDOW_CAPACITY, + STREAM_WINDOW_SIZE, +}; use crate::{ - StreamId, wire::{ + stream::{StreamAck, StreamFrame}, StreamSeq, - stream::{CloseCode, CloseTarget, StreamAck, StreamFrame, StreamFrameClose}, }, + StreamId, }; -// todo: need to figure out protocol behavior for: if the peer ACKs your Open and then stays silent forever, the stream will stay pending forever +#[derive(Debug, Default)] +pub struct StreamStore { + streams: HashMap, + order: Vec, + cursor: usize, +} + +impl StreamStore { + pub fn contains_key(&self, stream_id: &StreamId) -> bool { + self.streams.contains_key(stream_id) + } + + pub fn insert(&mut self, stream_id: StreamId, stream: StreamState) -> Option { + if !self.streams.contains_key(&stream_id) { + self.order.push(stream_id); + } + self.streams.insert(stream_id, stream) + } + + pub fn get(&self, stream_id: &StreamId) -> Option<&StreamState> { + self.streams.get(stream_id) + } -pub const STREAM_WINDOW_CAPACITY: usize = 8; -pub const STREAM_WINDOW_SIZE: u32 = STREAM_WINDOW_CAPACITY as u32; -pub const STREAM_ACK_EAGER_THRESHOLD: u32 = STREAM_WINDOW_SIZE / 2; + pub fn get_mut(&mut self, stream_id: &StreamId) -> Option<&mut StreamState> { + self.streams.get_mut(stream_id) + } + + pub fn remove(&mut self, stream_id: &StreamId) -> Option { + let removed = self.streams.remove(stream_id); + if removed.is_some() { + if let Some(index) = self.order.iter().position(|id| id == stream_id) { + self.order.remove(index); + if self.order.is_empty() { + self.cursor = 0; + } else if index < self.cursor { + self.cursor -= 1; + } else if self.cursor >= self.order.len() { + self.cursor = 0; + } + } + } + removed + } + + pub fn values(&self) -> impl Iterator { + self.streams.values() + } + + pub fn len(&self) -> usize { + self.order.len() + } + + pub fn id_at_offset(&self, offset: usize) -> Option { + let len = self.order.len(); + if len == 0 || offset >= len { + return None; + } + Some(self.order[(self.cursor + offset) % len]) + } + + pub fn ordered_id(&self, index: usize) -> Option { + self.order.get(index).copied() + } + + pub fn first_id(&self) -> Option { + self.order.first().copied() + } + + pub fn advance_cursor_after(&mut self, stream_id: StreamId) { + if let Some(index) = self.order.iter().position(|id| *id == stream_id) { + self.cursor = if self.order.is_empty() { + 0 + } else { + (index + 1) % self.order.len() + }; + } + } +} + +#[derive(Debug)] +pub struct StreamState { + pub control: StreamControl, + pub role: StreamRole, +} + +impl StreamState { + pub fn outbound_mut(&mut self, side: StreamSide) -> Option<&mut OutboundPhase> { + match &mut self.role { + StreamRole::Initiator(state) if side == StreamSide::Request => Some(&mut state.request), + StreamRole::Responder(state) if side == StreamSide::Response => { + Some(&mut state.response) + } + StreamRole::Initiator(_) | StreamRole::Responder(_) => None, + } + } + + pub fn inbound_mut(&mut self, side: StreamSide) -> Option<&mut InboundState> { + match &mut self.role { + StreamRole::Initiator(state) if side == StreamSide::Response => { + Some(&mut state.response) + } + StreamRole::Responder(state) if side == StreamSide::Request => Some(&mut state.request), + StreamRole::Initiator(_) | StreamRole::Responder(_) => None, + } + } + + pub fn outbound_side(&self) -> Option { + match self.role { + StreamRole::Initiator(_) => Some(StreamSide::Request), + StreamRole::Responder(_) => Some(StreamSide::Response), + } + } + + pub fn inbound_side(&self) -> Option { + match self.role { + StreamRole::Initiator(_) => Some(StreamSide::Response), + StreamRole::Responder(_) => Some(StreamSide::Request), + } + } + + pub fn awaiting_open(&self) -> bool { + matches!( + self.role, + StreamRole::Responder(ResponderStream { opened: false, .. }) + ) + } + + pub fn can_reap(&self) -> bool { + if !self.control.pending.is_empty() + || !self.control.in_flight.is_empty() + || !self.control.recv_buffer.is_empty() + || !matches!(self.control.ack_state, AckState::Idle) + { + return false; + } + + match self.role { + StreamRole::Initiator(state) => state.request.is_closed() && state.response.closed, + StreamRole::Responder(state) => state.request.closed && state.response.is_closed(), + } + } + + pub fn local_role(&self) -> StreamLocalRole { + match self.role { + StreamRole::Initiator(_) => StreamLocalRole::Initiator, + StreamRole::Responder(_) => StreamLocalRole::Responder, + } + } +} #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamSide { @@ -24,12 +171,6 @@ pub enum StreamSide { Response, } -#[derive(Debug)] -pub struct StreamMeta { - pub stream_id: StreamId, - pub last_activity: Instant, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum OutboundPhase { Ready, @@ -40,15 +181,19 @@ pub enum OutboundPhase { impl OutboundPhase { pub fn from_prefix(fin: bool) -> Self { - if fin { Self::FinQueued } else { Self::Ready } + if fin { + Self::FinQueued + } else { + Self::Ready + } } - pub fn is_closed(&self) -> bool { - *self == Self::Closed + pub fn is_closed(self) -> bool { + self == Self::Closed } - pub fn can_queue_data(&self) -> bool { - *self == Self::Ready + pub fn can_queue_data(self) -> bool { + self == Self::Ready } pub fn finish(&mut self) { @@ -76,7 +221,7 @@ impl OutboundPhase { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct InboundState { pub closed: bool, } @@ -97,11 +242,8 @@ impl InboundState { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum InFlightWriteState { - /// The frame has never been handed out to be written. Ready, - /// The frame was handed out and is awaiting `complete_write`. - Issued, - /// The frame write completed and is waiting for retransmit eligibility. + Issued { issue_id: u64 }, WaitingRetry { retry_at: Instant }, } @@ -113,7 +255,7 @@ pub struct InFlightFrame { pub write_state: InFlightWriteState, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BufferIncomingResult { Duplicate, AlreadyBuffered, @@ -121,17 +263,20 @@ pub enum BufferIncomingResult { OutOfWindow, } -// TODO: does it really make sense to have terminal control frames have sequence ids? +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AckState { + Idle, + Delayed { due_at: Instant }, + Immediate, +} + #[derive(Debug)] pub struct StreamControl { pub pending: VecDeque, pub in_flight: SeqRing, pub next_tx_seq: StreamSeq, pub recv_buffer: SeqRing, - pub ack_dirty: bool, - pub ack_immediate: bool, - pub ack_delay_token: Option, - pub ack_outbound_token: Option, + pub ack_state: AckState, pub last_sent_ack_base: StreamSeq, pub fast_recovery: Option, } @@ -143,10 +288,7 @@ impl Default for StreamControl { in_flight: SeqRing::new(StreamSeq::START), next_tx_seq: StreamSeq::START, recv_buffer: SeqRing::new(StreamSeq::START), - ack_dirty: false, - ack_immediate: false, - ack_delay_token: None, - ack_outbound_token: None, + ack_state: AckState::Idle, last_sent_ack_base: StreamSeq(0), fast_recovery: None, } @@ -168,27 +310,25 @@ impl StreamControl { self.recv_buffer.base_seq().prev() } - pub fn queue_frame_back(&mut self, frame: StreamFrame) { - self.pending.push_back(frame); - } - - pub fn queue_frame_front(&mut self, frame: StreamFrame) { - self.pending.push_front(frame); - } - - pub fn note_ack(&mut self, immediate: bool) { - self.ack_dirty = true; - self.ack_immediate |= immediate; + pub fn note_ack(&mut self, now: Instant, ack_delay: Duration, immediate: bool) { + self.ack_state = match self.ack_state { + AckState::Immediate => AckState::Immediate, + AckState::Delayed { due_at } if !immediate && !ack_delay.is_zero() => { + AckState::Delayed { due_at } + } + _ if immediate || ack_delay.is_zero() => AckState::Immediate, + _ => AckState::Delayed { + due_at: now + ack_delay, + }, + }; } pub fn clear_ack_schedule(&mut self) { - self.ack_dirty = false; - self.ack_immediate = false; - self.ack_delay_token = None; + self.ack_state = AckState::Idle; } pub fn maybe_force_ack_for_progress(&mut self) { - if !self.ack_dirty { + if matches!(self.ack_state, AckState::Idle) { return; } let committed = self.committed_rx_seq(); @@ -197,7 +337,7 @@ impl StreamControl { .forward_distance_to(committed) .unwrap_or(0); if progressed >= STREAM_ACK_EAGER_THRESHOLD { - self.ack_immediate = true; + self.ack_state = AckState::Immediate; } } @@ -207,8 +347,15 @@ impl StreamControl { } } + pub fn current_ack(&self) -> StreamAck { + StreamAck { + base: self.committed_rx_seq(), + bitmap: self.recv_buffer.bitmap(), + } + } + pub fn take_piggyback_ack(&mut self, inbound_alive: bool) -> StreamAck { - if !inbound_alive || !self.ack_dirty { + if !inbound_alive || matches!(self.ack_state, AckState::Idle) { return StreamAck::EMPTY; } let ack = self.current_ack(); @@ -217,10 +364,10 @@ impl StreamControl { ack } - pub fn current_ack(&self) -> StreamAck { - StreamAck { - base: self.committed_rx_seq(), - bitmap: self.recv_buffer.bitmap(), + pub fn ack_deadline(&self) -> Option { + match self.ack_state { + AckState::Delayed { due_at } => Some(due_at), + AckState::Idle | AckState::Immediate => None, } } @@ -284,28 +431,42 @@ impl StreamControl { } } - pub fn mark_write_issued(&mut self, tx_seq: StreamSeq) -> Option { + pub fn mark_write_issued(&mut self, tx_seq: StreamSeq, issue_id: u64) -> Option { let in_flight = self.in_flight.get_mut(&tx_seq)?; match in_flight.write_state { - InFlightWriteState::Issued => return None, + InFlightWriteState::Issued { .. } => return None, InFlightWriteState::WaitingRetry { .. } => { in_flight.attempt = in_flight.attempt.saturating_add(1); } InFlightWriteState::Ready => {} } - in_flight.write_state = InFlightWriteState::Issued; + in_flight.write_state = InFlightWriteState::Issued { issue_id }; Some(in_flight.frame.clone()) } - pub fn complete_write(&mut self, tx_seq: StreamSeq, retry_at: Instant) { - if let Some(in_flight) = self.in_flight.get_mut(&tx_seq) { - in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at }; - } + pub fn frame_write_is_issued(&self, tx_seq: StreamSeq, issue_id: u64) -> bool { + matches!( + self.in_flight.get(&tx_seq).map(|in_flight| in_flight.write_state), + Some(InFlightWriteState::Issued { + issue_id: current_issue_id, + }) if current_issue_id == issue_id + ) } - pub fn set_retry_deadline(&mut self, tx_seq: StreamSeq, retry_at: Instant) { - if let Some(in_flight) = self.in_flight.get_mut(&tx_seq) { - in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at }; + pub fn complete_write(&mut self, tx_seq: StreamSeq, issue_id: u64, retry_at: Instant) -> bool { + let Some(in_flight) = self.in_flight.get_mut(&tx_seq) else { + return false; + }; + match in_flight.write_state { + InFlightWriteState::Issued { + issue_id: current_issue_id, + } if current_issue_id == issue_id => { + in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at }; + true + } + InFlightWriteState::Ready + | InFlightWriteState::WaitingRetry { .. } + | InFlightWriteState::Issued { .. } => false, } } @@ -333,7 +494,6 @@ impl StreamControl { self.recv_buffer .clear_with_base(self.committed_rx_seq().next()); self.clear_ack_schedule(); - self.ack_outbound_token = None; self.fast_recovery = None; } @@ -351,213 +511,22 @@ impl StreamControl { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamRole { + Initiator(InitiatorStream), + Responder(ResponderStream), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct InitiatorStream { pub request: OutboundPhase, pub response: InboundState, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ResponderStream { + pub opened: bool, pub request: InboundState, pub response: OutboundPhase, pub response_started: bool, } - -#[derive(Debug)] -pub struct ProvisionalStream { - pub timeout_token: Token, -} - -#[derive(Debug)] -pub enum StreamRole { - Initiator(InitiatorStream), - Responder(ResponderStream), - Provisional(ProvisionalStream), -} - -#[derive(Debug)] -pub struct StreamState { - pub meta: StreamMeta, - pub control: StreamControl, - pub role: StreamRole, -} - -impl StreamState { - pub fn parts_mut(&mut self) -> (&mut StreamMeta, &mut StreamControl, &mut StreamRole) { - (&mut self.meta, &mut self.control, &mut self.role) - } - - pub fn outbound_mut(&mut self, side: StreamSide) -> Option<&mut OutboundPhase> { - match &mut self.role { - StreamRole::Initiator(state) if side == StreamSide::Request => Some(&mut state.request), - StreamRole::Responder(state) if side == StreamSide::Response => { - Some(&mut state.response) - } - _ => None, - } - } - - pub fn inbound_mut(&mut self, side: StreamSide) -> Option<&mut InboundState> { - match &mut self.role { - StreamRole::Initiator(state) if side == StreamSide::Response => { - Some(&mut state.response) - } - StreamRole::Responder(state) if side == StreamSide::Request => Some(&mut state.request), - _ => None, - } - } - - pub fn provisional_timeout_token(&self) -> Option { - match &self.role { - StreamRole::Provisional(state) => Some(state.timeout_token), - _ => None, - } - } - - pub fn outbound_side(&self) -> Option { - match &self.role { - StreamRole::Initiator(_) => Some(StreamSide::Request), - StreamRole::Responder(_) => Some(StreamSide::Response), - StreamRole::Provisional(_) => None, - } - } - - pub fn inbound_side(&self) -> Option { - match &self.role { - StreamRole::Initiator(_) => Some(StreamSide::Response), - StreamRole::Responder(_) => Some(StreamSide::Request), - StreamRole::Provisional(_) => None, - } - } - - pub fn is_provisional(&self) -> bool { - matches!(&self.role, StreamRole::Provisional(_)) - } - - pub fn can_reap(&self) -> bool { - if !self.control.pending.is_empty() - || !self.control.in_flight.is_empty() - || !self.control.recv_buffer.is_empty() - || self.control.ack_dirty - || self.control.ack_outbound_token.is_some() - { - return false; - } - match &self.role { - StreamRole::Initiator(state) => state.request.is_closed() && state.response.closed, - StreamRole::Responder(state) => state.request.closed && state.response.is_closed(), - StreamRole::Provisional(_) => false, - } - } -} - -#[derive(Debug, Default)] -pub struct StreamStore { - streams: HashMap, - order: Vec, - cursor: usize, -} - -impl StreamStore { - pub fn len(&self) -> usize { - self.streams.len() - } - - pub fn contains_key(&self, stream_id: &StreamId) -> bool { - self.streams.contains_key(stream_id) - } - - pub fn insert(&mut self, stream_id: StreamId, stream: StreamState) -> Option { - if !self.streams.contains_key(&stream_id) { - self.order.push(stream_id); - } - self.streams.insert(stream_id, stream) - } - - pub fn get(&self, stream_id: &StreamId) -> Option<&StreamState> { - self.streams.get(stream_id) - } - - pub fn get_mut(&mut self, stream_id: &StreamId) -> Option<&mut StreamState> { - self.streams.get_mut(stream_id) - } - - pub fn remove(&mut self, stream_id: &StreamId) -> Option { - let removed = self.streams.remove(stream_id); - if removed.is_some() { - if let Some(index) = self.order.iter().position(|id| id == stream_id) { - self.order.remove(index); - if self.order.is_empty() { - self.cursor = 0; - } else if index < self.cursor { - self.cursor -= 1; - } else if self.cursor >= self.order.len() { - self.cursor = 0; - } - } - } - removed - } - - pub fn values(&self) -> impl Iterator { - self.streams.values() - } - - pub fn values_mut(&mut self) -> impl Iterator { - self.streams.values_mut() - } - - pub fn iter(&self) -> impl Iterator { - self.streams.iter() - } - - pub fn into_inner(self) -> HashMap { - self.streams - } - - pub fn scan_from_cursor(&self) -> impl Iterator + '_ { - let len = self.order.len(); - (0..len).map(move |offset| self.order[(self.cursor + offset) % len]) - } - - pub fn advance_cursor_after(&mut self, stream_id: StreamId) { - if let Some(index) = self.order.iter().position(|id| *id == stream_id) { - self.cursor = if self.order.is_empty() { - 0 - } else { - (index + 1) % self.order.len() - }; - } - } - - pub fn stream_retry_deadline(&self) -> Option { - self.streams - .values() - .flat_map(|stream| { - stream - .control - .in_flight - .iter() - .filter_map(|(_, in_flight)| match in_flight.write_state { - InFlightWriteState::WaitingRetry { retry_at } => Some(retry_at), - InFlightWriteState::Ready | InFlightWriteState::Issued => None, - }) - }) - .min() - } -} - -pub fn close_frame( - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, -) -> StreamFrame { - StreamFrame::Close(StreamFrameClose { - stream_id, - target, - code, - payload, - }) -} diff --git a/ql-engine/src/stream/tests.rs b/ql-engine/src/stream/tests.rs new file mode 100644 index 00000000..31a1f5de --- /dev/null +++ b/ql-engine/src/stream/tests.rs @@ -0,0 +1,334 @@ +use std::time::Instant; + +use super::{ + Outbound, StreamCloseEvent, StreamCloseKind, StreamError, StreamEventSink, StreamFsm, + StreamFsmConfig, StreamLocalRole, StreamNamespace, WriteError, +}; +use crate::{ + wire::stream::{ + BodyChunk, CloseCode, CloseTarget, StreamAck, StreamAckBody, StreamBody, StreamFrame, + StreamFrameClose, StreamFrameData, StreamFrameOpen, StreamMessage, + }, + StreamId, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +struct OpenedStream { + stream_id: StreamId, + request_head: Vec, + request_prefix: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct InboundChunk { + stream_id: StreamId, + bytes: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct StreamFailure { + stream_id: StreamId, + error: StreamError, +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct Recorder { + opened: Vec, + closes: Vec, + inbound_data: Vec, + inbound_finished: Vec, + inbound_failed: Vec, + outbound_closed: Vec, + outbound_failed: Vec, + reaped: Vec, +} + +impl StreamEventSink for Recorder { + fn opened( + &mut self, + stream_id: StreamId, + request_head: Vec, + request_prefix: Option, + ) { + self.opened.push(OpenedStream { + stream_id, + request_head, + request_prefix, + }); + } + + fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { + self.inbound_data.push(InboundChunk { stream_id, bytes }); + } + + fn inbound_finished(&mut self, stream_id: StreamId) { + self.inbound_finished.push(stream_id); + } + + fn inbound_failed(&mut self, stream_id: StreamId, error: StreamError) { + self.inbound_failed.push(StreamFailure { stream_id, error }); + } + + fn close(&mut self, event: StreamCloseEvent) { + self.closes.push(event); + } + + fn outbound_closed(&mut self, stream_id: StreamId) { + self.outbound_closed.push(stream_id); + } + + fn outbound_failed(&mut self, stream_id: StreamId, error: StreamError) { + self.outbound_failed + .push(StreamFailure { stream_id, error }); + } + + fn reaped(&mut self, stream_id: StreamId) { + self.reaped.push(stream_id); + } +} + +fn data_packet(stream_id: StreamId, tx_seq: u32, byte: u8) -> StreamBody { + StreamBody::Message(StreamMessage { + tx_seq: crate::wire::StreamSeq(tx_seq), + ack: StreamAck::EMPTY, + valid_until: 0, + frame: StreamFrame::Data(StreamFrameData { + stream_id, + chunk: BodyChunk { + bytes: vec![byte], + fin: false, + }, + }), + }) +} + +#[test] +fn open_stream_enqueues_open_packet() { + let now = Instant::now(); + let mut stream = StreamFsm::new(StreamFsmConfig::default()); + let stream_id = stream.open_stream(b"open".to_vec(), None); + + let outbound = stream.next_outbound(now, 7).unwrap(); + assert_open(outbound, stream_id, b"open", 7); +} + +#[test] +fn out_of_order_remote_stream_buffers_until_open_arrives() { + let now = Instant::now(); + let mut stream = StreamFsm::new(StreamFsmConfig { + local_namespace: StreamNamespace::Low, + ..Default::default() + }); + let stream_id = StreamId(StreamNamespace::High.bit() | 1); + + let mut events = Recorder::default(); + stream.receive(now, data_packet(stream_id, 2, b'h'), &mut events); + assert!(events.opened.is_empty()); + assert!(events.inbound_data.is_empty()); + + stream.receive( + now, + StreamBody::Message(StreamMessage { + tx_seq: crate::wire::StreamSeq::START, + ack: StreamAck::EMPTY, + valid_until: 0, + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"late-open".to_vec(), + request_prefix: None, + }), + }), + &mut events, + ); + + assert_eq!( + events.opened, + vec![OpenedStream { + stream_id, + request_head: b"late-open".to_vec(), + request_prefix: None, + }] + ); + assert_eq!( + events.inbound_data, + vec![InboundChunk { + stream_id, + bytes: vec![b'h'], + }] + ); +} + +#[test] +fn ack_only_write_failure_requeues_without_spending_sequence_space() { + let now = Instant::now(); + let config = StreamFsmConfig::default(); + let mut stream = StreamFsm::new(config); + let stream_id = StreamId(StreamNamespace::High.bit() | 1); + + let mut events = Recorder::default(); + stream.receive( + now, + StreamBody::Message(StreamMessage { + tx_seq: crate::wire::StreamSeq::START, + ack: StreamAck::EMPTY, + valid_until: 0, + frame: StreamFrame::Open(StreamFrameOpen { + stream_id, + request_head: b"open".to_vec(), + request_prefix: None, + }), + }), + &mut events, + ); + assert_eq!(events.opened.len(), 1); + + stream.on_timer(now + config.ack_delay, &mut ()); + let ack_write = stream.next_outbound(now + config.ack_delay, 11).unwrap(); + assert!(matches!( + ack_write.body, + StreamBody::Ack(StreamAckBody { + stream_id: id, + ack: StreamAck { + base: crate::wire::StreamSeq::START, + bitmap: 0, + }, + valid_until: 11, + }) if id == stream_id + )); + + stream.complete_outbound( + now + config.ack_delay, + ack_write.completion, + Err(WriteError::SendFailed), + &mut (), + ); + let retry = stream.next_outbound(now + config.ack_delay, 12).unwrap(); + assert!(matches!(retry.body, StreamBody::Ack(_))); + + stream.complete_outbound(now + config.ack_delay, retry.completion, Ok(()), &mut ()); + stream.write_stream(stream_id, b"resp".to_vec()).unwrap(); + let response = stream.next_outbound(now, 13).unwrap(); + assert!(matches!( + response.body, + StreamBody::Message(StreamMessage { + tx_seq: crate::wire::StreamSeq::START, + valid_until: 13, + frame: StreamFrame::Data(StreamFrameData { + stream_id: id, + chunk: BodyChunk { bytes, fin: false }, + }), + .. + }) if id == stream_id && bytes == b"resp" + )); +} + +#[test] +fn fast_retransmit_resends_oldest_gap_when_threshold_met() { + let now = Instant::now(); + let mut stream = StreamFsm::new(StreamFsmConfig { + fast_retransmit_threshold: 2, + ..Default::default() + }); + let stream_id = stream.open_stream(b"open".to_vec(), None); + let open = stream.next_outbound(now, 1).unwrap(); + stream.complete_outbound(now, open.completion, Ok(()), &mut ()); + stream.write_stream(stream_id, b"a".to_vec()).unwrap(); + stream.write_stream(stream_id, b"b".to_vec()).unwrap(); + stream.write_stream(stream_id, b"c".to_vec()).unwrap(); + stream.write_stream(stream_id, b"d".to_vec()).unwrap(); + let first = stream.next_outbound(now, 2).unwrap(); + let second = stream.next_outbound(now, 3).unwrap(); + let third = stream.next_outbound(now, 4).unwrap(); + let fourth = stream.next_outbound(now, 5).unwrap(); + stream.complete_outbound(now, first.completion, Ok(()), &mut ()); + stream.complete_outbound(now, second.completion, Ok(()), &mut ()); + stream.complete_outbound(now, third.completion, Ok(()), &mut ()); + stream.complete_outbound(now, fourth.completion, Ok(()), &mut ()); + + stream.receive( + now, + StreamBody::Ack(StreamAckBody { + stream_id, + ack: StreamAck { + base: crate::wire::StreamSeq(2), + bitmap: 0b0000_0110, + }, + valid_until: 0, + }), + &mut (), + ); + + let retransmit = stream.next_outbound(now, 6).unwrap(); + assert!(matches!( + retransmit.body, + StreamBody::Message(StreamMessage { + tx_seq: crate::wire::StreamSeq(3), + frame: StreamFrame::Data(_), + .. + }) + )); +} + +#[test] +fn late_failed_write_after_remote_close_ack_is_ignored() { + let now = Instant::now(); + let mut stream = StreamFsm::new(StreamFsmConfig::default()); + let stream_id = stream.open_stream(b"open".to_vec(), None); + let open = stream.next_outbound(now, 1).unwrap(); + + let mut events = Recorder::default(); + stream.receive( + now, + StreamBody::Message(StreamMessage { + tx_seq: crate::wire::StreamSeq::START, + ack: StreamAck { + base: crate::wire::StreamSeq::START, + bitmap: 0, + }, + valid_until: 0, + frame: StreamFrame::Close(StreamFrameClose { + stream_id, + target: CloseTarget::Both, + code: CloseCode::PROTOCOL, + payload: Vec::new(), + }), + }), + &mut events, + ); + assert_eq!( + events.closes, + vec![StreamCloseEvent { + kind: StreamCloseKind::Remote, + role: StreamLocalRole::Initiator, + frame: StreamFrameClose { + stream_id, + target: CloseTarget::Both, + code: CloseCode::PROTOCOL, + payload: Vec::new(), + }, + }] + ); + assert!(events.outbound_failed.is_empty()); + assert!(events.inbound_failed.is_empty()); + + let mut late = Recorder::default(); + stream.complete_outbound(now, open.completion, Err(WriteError::SendFailed), &mut late); + assert!(late.outbound_failed.is_empty()); + assert!(late.inbound_failed.is_empty()); +} + +fn assert_open(outbound: Outbound, stream_id: StreamId, request_head: &[u8], valid_until: u64) { + assert!(matches!( + outbound.body, + StreamBody::Message(StreamMessage { + tx_seq: crate::wire::StreamSeq::START, + ack: StreamAck::EMPTY, + valid_until: expires_at, + frame: StreamFrame::Open(StreamFrameOpen { + stream_id: id, + request_head: actual_head, + request_prefix: None, + }), + }) if id == stream_id && actual_head == request_head && expires_at == valid_until + )); +} diff --git a/ql-engine/src/wire/handshake/crypto.rs b/ql-engine/src/wire/handshake/crypto.rs index f2960ae9..cc79f7be 100644 --- a/ql-engine/src/wire/handshake/crypto.rs +++ b/ql-engine/src/wire/handshake/crypto.rs @@ -14,8 +14,7 @@ use crate::{ wire::{ access_value, deserialize_value, encode_value, encrypted_message::{EncryptedMessage, NONCE_SIZE}, - ensure_not_expired, AsWireMlKemCiphertext, AsWireNonce, AsWireXid, ControlMeta, - QlHeader, + ensure_not_expired, AsWireMlKemCiphertext, AsWireNonce, AsWireXid, ControlMeta, QlHeader, }, QlError, }; diff --git a/ql-engine/src/wire/handshake/mod.rs b/ql-engine/src/wire/handshake/mod.rs index 756a3e78..a049af62 100644 --- a/ql-engine/src/wire/handshake/mod.rs +++ b/ql-engine/src/wire/handshake/mod.rs @@ -2,8 +2,8 @@ use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, Nonce}; use rkyv::{Archive, Deserialize, Serialize}; use super::{ - encrypted_message::EncryptedMessage, AsWireMlDsaSignature, AsWireMlKemCiphertext, - AsWireNonce, ControlMeta, + encrypted_message::EncryptedMessage, AsWireMlDsaSignature, AsWireMlKemCiphertext, AsWireNonce, + ControlMeta, }; use crate::QlError; diff --git a/ql-engine/src/wire/mod.rs b/ql-engine/src/wire/mod.rs index a1b7f548..bc889e58 100644 --- a/ql-engine/src/wire/mod.rs +++ b/ql-engine/src/wire/mod.rs @@ -196,7 +196,6 @@ mod test { } } - /* #[test] fn protocol_record_size_breakdown() { use crate::{ @@ -377,7 +376,6 @@ mod test { valid_until: now_secs().saturating_add(60), frame: stream::StreamFrame::Data(stream::StreamFrameData { stream_id: StreamId(2), - dir: stream::Direction::Request, chunk: stream::BodyChunk { bytes: vec![7, 8, 9, 10], fin: false, @@ -393,7 +391,6 @@ mod test { valid_until: now_secs().saturating_add(60), frame: stream::StreamFrame::Data(stream::StreamFrameData { stream_id: StreamId(2), - dir: stream::Direction::Request, chunk: stream::BodyChunk { bytes: vec![7, 8, 9, 10], fin: false, @@ -403,34 +400,6 @@ mod test { let stream_ack_size = stream_record_size(&stream_ack_body, 20); let stream_open_size = stream_record_size(&stream_open_body, 21); - let stream_accept_size = stream_record_size( - &stream::StreamBody::Message(stream::StreamMessage { - tx_seq: StreamSeq(22), - ack: stream::StreamAck::EMPTY, - valid_until: now_secs().saturating_add(60), - frame: stream::StreamFrame::Accept(stream::StreamFrameAccept { - stream_id: StreamId(2), - response_head: vec![4, 5, 6], - response_prefix: Some(stream::BodyChunk { - bytes: vec![1, 2], - fin: false, - }), - }), - }), - 22, - ); - let stream_reject_size = stream_record_size( - &stream::StreamBody::Message(stream::StreamMessage { - tx_seq: StreamSeq(23), - ack: stream::StreamAck::EMPTY, - valid_until: now_secs().saturating_add(60), - frame: stream::StreamFrame::Reject(stream::StreamFrameReject { - stream_id: StreamId(2), - code: stream::RejectCode::InvalidHead, - }), - }), - 23, - ); let stream_data_no_ack_size = stream_record_size(&stream_message_no_ack, 24); let stream_data_with_ack_size = stream_record_size(&stream_message_with_ack, 25); let stream_fin_size = stream_record_size( @@ -440,7 +409,6 @@ mod test { valid_until: now_secs().saturating_add(60), frame: stream::StreamFrame::Data(stream::StreamFrameData { stream_id: StreamId(2), - dir: stream::Direction::Response, chunk: stream::BodyChunk { bytes: Vec::new(), fin: true, @@ -454,10 +422,11 @@ mod test { tx_seq: StreamSeq(27), ack: stream::StreamAck::EMPTY, valid_until: now_secs().saturating_add(60), - frame: stream::StreamFrame::Reset(stream::StreamFrameReset { + frame: stream::StreamFrame::Close(stream::StreamFrameClose { stream_id: StreamId(2), - target: stream::ResetTarget::Both, - code: stream::ResetCode::Protocol, + target: stream::CloseTarget::Both, + code: stream::CloseCode::PROTOCOL, + payload: Vec::new(), }), }), 27, @@ -475,8 +444,6 @@ mod test { print_size("ql2 size unpair", unpair_size); print_size("ql2 size stream ack-only", stream_ack_size); print_size("ql2 size stream open", stream_open_size); - print_size("ql2 size stream accept", stream_accept_size); - print_size("ql2 size stream reject", stream_reject_size); print_size("ql2 size stream data no ack", stream_data_no_ack_size); print_size("ql2 size stream data w/ ack", stream_data_with_ack_size); print_size("ql2 size stream fin", stream_fin_size); @@ -502,21 +469,5 @@ mod test { stream_data_with_ack_size, stream_data_with_ack_size.saturating_sub(stream_data_no_ack_size), ); - - assert!(hello_size > 0); - assert!(reply_size > 0); - assert!(confirm_size > 0); - assert!(pair_size > 0); - assert!(heartbeat_size > 0); - assert!(unpair_size > 0); - assert!(stream_ack_size > 0); - assert!(stream_open_size > 0); - assert!(stream_accept_size > 0); - assert!(stream_reject_size > 0); - assert!(stream_data_no_ack_size > 0); - assert!(stream_data_with_ack_size > 0); - assert!(stream_fin_size > 0); - assert!(stream_reset_size > 0); } - */ } diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml new file mode 100644 index 00000000..98b0abed --- /dev/null +++ b/ql-fsm/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "ql-fsm" +version = "0.1.0" +edition = "2021" +description = "Quantum Link synchronous finite state machine" +license = "Proprietary" + +[dependencies] +bc-components = { version = "0.28.0", default-features = false, features = [ + "pqcrypto", +] } +ql-wire = { path = "../ql-wire" } +rkyv = { version = "0.8", default-features = false, features = [ + "std", + "bytecheck", + "little_endian", + "unaligned", + "pointer_width_32", +] } +thiserror = { version = "2" } diff --git a/ql-fsm/src/implementation/handshake.rs b/ql-fsm/src/implementation/handshake.rs new file mode 100644 index 00000000..5820a170 --- /dev/null +++ b/ql-fsm/src/implementation/handshake.rs @@ -0,0 +1,713 @@ +use std::{cmp::Ordering, time::Instant}; + +use bc_components::{MLDSAPublicKey, SymmetricKey}; +use ql_wire::{ + self as wire, + handshake::{Confirm, Hello, HelloReply, Ready}, + ControlMeta, QlCrypto, QlHeader, XID, +}; +use rkyv::api::low; + +use crate::{ + HandshakeInitiator, HandshakeResponder, Peer, PeerSession, QlFsm, QlFsmError, RecentReady, +}; + +#[derive(Debug)] +enum HelloAction { + StartResponder, + ResendReply { reply: HelloReply }, + Ignore, +} + +#[derive(Debug)] +enum HelloReplyAction { + Advance { + hello: Hello, + initiator_secret: SymmetricKey, + responder_signing_key: MLDSAPublicKey, + }, + ResendConfirm { + confirm: Confirm, + }, +} + +#[derive(Debug, Clone)] +enum RetryAction { + Hello { peer: XID, hello: Hello }, + HelloReply { peer: XID, reply: HelloReply }, + Confirm { peer: XID, confirm: Confirm }, +} + +pub fn handle_connect(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + start_initiator_handshake(fsm, crypto) +} + +pub fn handle_hello( + fsm: &mut QlFsm, + header: &QlHeader, + archived_hello: &wire::handshake::ArchivedHello, + crypto: &impl QlCrypto, +) -> Result<(), QlFsmError> { + let hello: Hello = deserialize_archived(archived_hello)?; + let action = { + let Some(entry) = fsm.peer.as_ref() else { + return Ok(()); + }; + if wire::handshake::verify_hello( + header.sender, + fsm.identity.xid, + &entry.peer.signing_key, + archived_hello, + ) + .is_err() + { + return Ok(()); + } + + match &entry.session { + PeerSession::Initiator { + hello: local_hello, .. + } => { + if peer_hello_wins(local_hello, fsm.identity.xid, &hello, header.sender) { + HelloAction::StartResponder + } else { + HelloAction::Ignore + } + } + PeerSession::Responder { + hello: stored, + reply, + stage: HandshakeResponder::WaitingConfirm { .. }, + .. + } => { + if same_hello(stored, &hello) { + HelloAction::ResendReply { + reply: reply.clone(), + } + } else { + HelloAction::StartResponder + } + } + PeerSession::Disconnected | PeerSession::Connected { .. } => { + HelloAction::StartResponder + } + } + }; + + match action { + HelloAction::Ignore => {} + HelloAction::ResendReply { reply } => { + fsm.enqueue_handshake( + header.sender, + wire::handshake::HandshakeRecord::HelloReply(reply), + ); + } + HelloAction::StartResponder => { + if fsm.is_replayed_control(header.sender, hello.meta) { + return Ok(()); + } + + let peer = fsm.peer.as_ref().map(|entry| entry.peer.clone()).unwrap(); + let reply_meta = fsm.next_control_meta(fsm.config.handshake_timeout); + let responder = wire::handshake::respond_hello( + &fsm.identity, + crypto, + peer.xid, + &peer.signing_key, + &peer.encapsulation_key, + archived_hello, + reply_meta, + ); + + let (reply, secrets) = match responder { + Ok(result) => result, + Err(_) => { + if let Some(entry) = fsm.peer.as_mut() { + entry.session = PeerSession::Disconnected; + } + fsm.emit_peer_status(); + return Ok(()); + } + }; + + let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; + let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); + if let Some(entry) = fsm.peer.as_mut() { + entry.session = PeerSession::Responder { + hello: hello.clone(), + reply: reply.clone(), + deadline, + stage: HandshakeResponder::WaitingConfirm { + secrets, + retry_count: 0, + retry_at, + }, + }; + } + fsm.enqueue_handshake( + header.sender, + wire::handshake::HandshakeRecord::HelloReply(reply), + ); + fsm.emit_peer_status(); + } + } + + Ok(()) +} + +pub fn handle_hello_reply( + fsm: &mut QlFsm, + header: &QlHeader, + archived_reply: &wire::handshake::ArchivedHelloReply, +) -> Result<(), QlFsmError> { + let reply: HelloReply = deserialize_archived(archived_reply)?; + let action = { + let Some(entry) = fsm.peer.as_ref() else { + return Ok(()); + }; + match &entry.session { + PeerSession::Initiator { + hello, + stage: + HandshakeInitiator::WaitingHelloReply { + initiator_secret, .. + }, + .. + } => HelloReplyAction::Advance { + hello: hello.clone(), + initiator_secret: initiator_secret.clone(), + responder_signing_key: entry.peer.signing_key.clone(), + }, + PeerSession::Initiator { + stage: + HandshakeInitiator::WaitingReady { + reply: stored, + confirm, + .. + }, + .. + } if same_reply(stored, &reply) => HelloReplyAction::ResendConfirm { + confirm: confirm.clone(), + }, + _ => return Ok(()), + } + }; + + match action { + HelloReplyAction::ResendConfirm { confirm } => { + fsm.enqueue_handshake( + header.sender, + wire::handshake::HandshakeRecord::Confirm(confirm), + ); + } + HelloReplyAction::Advance { + hello, + initiator_secret, + responder_signing_key, + } => { + let confirm_meta = fsm.next_control_meta(fsm.config.handshake_timeout); + let (confirm, session_key) = match wire::handshake::build_confirm( + &fsm.identity, + header.sender, + &responder_signing_key, + &hello, + archived_reply, + &initiator_secret, + confirm_meta, + ) { + Ok(result) => result, + Err(_) => return Ok(()), + }; + + if fsm.is_replayed_control(header.sender, reply.meta) { + return Ok(()); + } + + let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; + let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); + if let Some(entry) = fsm.peer.as_mut() { + entry.session = PeerSession::Initiator { + hello, + deadline, + stage: HandshakeInitiator::WaitingReady { + reply: reply.clone(), + confirm: confirm.clone(), + session_key, + retry_count: 0, + retry_at, + }, + }; + } + fsm.enqueue_handshake( + header.sender, + wire::handshake::HandshakeRecord::Confirm(confirm), + ); + } + } + + Ok(()) +} + +fn deserialize_archived( + value: &impl rkyv::Deserialize>, +) -> Result { + low::deserialize::(value).map_err(|_| QlFsmError::InvalidPayload) +} + +pub fn handle_confirm( + fsm: &mut QlFsm, + header: &QlHeader, + confirm: &wire::handshake::ArchivedConfirm, + crypto: &impl QlCrypto, +) -> Result<(), QlFsmError> { + if let Some(ready) = recent_ready_resend(fsm, header.sender, confirm) { + fsm.enqueue_handshake( + header.sender, + wire::handshake::HandshakeRecord::Ready(ready), + ); + return Ok(()); + } + + let outcome = { + let Some(entry) = fsm.peer.as_ref() else { + return Ok(()); + }; + let PeerSession::Responder { + hello, + reply, + deadline, + stage: HandshakeResponder::WaitingConfirm { secrets, .. }, + } = &entry.session + else { + return Ok(()); + }; + + wire::handshake::finalize_confirm( + header.sender, + fsm.identity.xid, + &entry.peer.signing_key, + hello, + reply, + confirm, + secrets, + ) + .map(|session_key| (hello.clone(), reply.clone(), *deadline, session_key)) + }; + + let (hello, reply, deadline, session_key) = match outcome { + Ok(result) => result, + Err(_) => return Ok(()), + }; + + let meta: ControlMeta = (&confirm.meta).into(); + if fsm.is_replayed_control(header.sender, meta) { + return Ok(()); + } + + let ready = wire::handshake::build_ready( + QlHeader { + sender: fsm.identity.xid, + recipient: header.sender, + }, + &session_key, + fsm.next_control_meta(fsm.config.handshake_timeout), + next_encrypted_nonce(crypto), + ); + + if let Some(entry) = fsm.peer.as_mut() { + entry.session = PeerSession::Connected { + session_key, + recent_ready: Some(RecentReady { + hello, + reply, + ready: ready.clone(), + expires_at: deadline, + }), + }; + } + + fsm.enqueue_handshake( + header.sender, + wire::handshake::HandshakeRecord::Ready(ready), + ); + fsm.emit_peer_status(); + Ok(()) +} + +pub fn handle_ready( + fsm: &mut QlFsm, + header: &QlHeader, + ready: &mut wire::handshake::ArchivedReady, +) -> Result<(), QlFsmError> { + let session_key = { + let Some(entry) = fsm.peer.as_ref() else { + return Ok(()); + }; + match &entry.session { + PeerSession::Initiator { + stage: HandshakeInitiator::WaitingReady { session_key, .. }, + .. + } => session_key.clone(), + _ => return Ok(()), + } + }; + + let body = match wire::handshake::decrypt_ready(header, ready, &session_key) { + Ok(body) => body, + Err(_) => return Ok(()), + }; + if fsm.is_replayed_control(header.sender, body.meta) { + return Ok(()); + } + + if let Some(entry) = fsm.peer.as_mut() { + entry.session = PeerSession::Connected { + session_key, + recent_ready: None, + }; + } + fsm.emit_peer_status(); + Ok(()) +} + +pub fn handle_timer(fsm: &mut QlFsm) { + let now = fsm.state.now.instant; + if let Some(PeerSession::Connected { + recent_ready: Some(recent_ready), + .. + }) = fsm.peer.as_mut().map(|entry| &mut entry.session) + { + if recent_ready.expires_at <= now { + if let Some(entry) = fsm.peer.as_mut() { + if let PeerSession::Connected { recent_ready, .. } = &mut entry.session { + *recent_ready = None; + } + } + } + } + + let mut retry_action = None; + let mut disconnected = false; + + if let Some(entry) = fsm.peer.as_mut() { + match &mut entry.session { + PeerSession::Initiator { + hello, + deadline, + stage, + } => { + if *deadline <= now { + entry.session = PeerSession::Disconnected; + disconnected = true; + } else { + retry_action = handle_initiator_retry( + &entry.peer, + hello, + stage, + now, + fsm.config.handshake_retry_interval, + fsm.config.max_handshake_retries, + ); + if retry_action.is_none() && initiator_retries_exhausted(stage) { + entry.session = PeerSession::Disconnected; + disconnected = true; + } + } + } + PeerSession::Responder { + reply, + deadline, + stage, + .. + } => { + if *deadline <= now { + entry.session = PeerSession::Disconnected; + disconnected = true; + } else { + retry_action = handle_responder_retry( + &entry.peer, + reply, + stage, + now, + fsm.config.handshake_retry_interval, + fsm.config.max_handshake_retries, + ); + if retry_action.is_none() && responder_retries_exhausted(stage) { + entry.session = PeerSession::Disconnected; + disconnected = true; + } + } + } + PeerSession::Disconnected | PeerSession::Connected { .. } => {} + } + } + + if disconnected { + fsm.emit_peer_status(); + } + + if let Some(action) = retry_action { + match action { + RetryAction::Hello { peer, hello } => { + fsm.enqueue_handshake(peer, wire::handshake::HandshakeRecord::Hello(hello)); + } + RetryAction::HelloReply { peer, reply } => { + fsm.enqueue_handshake(peer, wire::handshake::HandshakeRecord::HelloReply(reply)); + } + RetryAction::Confirm { peer, confirm } => { + fsm.enqueue_handshake(peer, wire::handshake::HandshakeRecord::Confirm(confirm)); + } + } + } +} + +pub fn next_deadline(fsm: &QlFsm) -> Option { + let mut deadline = None; + if let Some(entry) = fsm.peer.as_ref() { + match &entry.session { + PeerSession::Initiator { + deadline: session_deadline, + stage, + .. + } => { + deadline = Some(*session_deadline); + deadline = min_optional(deadline, initiator_retry_at(stage)); + } + PeerSession::Responder { + deadline: session_deadline, + stage, + .. + } => { + deadline = Some(*session_deadline); + deadline = min_optional(deadline, responder_retry_at(stage)); + } + PeerSession::Connected { + recent_ready: Some(recent_ready), + .. + } => { + deadline = Some(recent_ready.expires_at); + } + PeerSession::Disconnected | PeerSession::Connected { .. } => {} + } + } + deadline +} + +fn start_initiator_handshake(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + let Some(entry) = fsm.peer.as_ref() else { + return Err(QlFsmError::NoPeerBound); + }; + if !matches!(entry.session, PeerSession::Disconnected) { + return Ok(()); + } + + let peer = entry.peer.clone(); + let meta = fsm.next_control_meta(fsm.config.handshake_timeout); + let (hello, initiator_secret) = wire::handshake::build_hello( + &fsm.identity, + crypto, + peer.xid, + &peer.encapsulation_key, + meta, + )?; + let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; + let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); + + if let Some(entry) = fsm.peer.as_mut() { + entry.session = PeerSession::Initiator { + hello: hello.clone(), + deadline, + stage: HandshakeInitiator::WaitingHelloReply { + initiator_secret, + retry_count: 0, + retry_at, + }, + }; + } + + fsm.enqueue_handshake(peer.xid, wire::handshake::HandshakeRecord::Hello(hello)); + fsm.emit_peer_status(); + Ok(()) +} + +fn recent_ready_resend( + fsm: &QlFsm, + peer: XID, + confirm: &wire::handshake::ArchivedConfirm, +) -> Option { + let entry = fsm.peer.as_ref()?; + let PeerSession::Connected { + recent_ready: Some(recent_ready), + .. + } = &entry.session + else { + return None; + }; + if recent_ready.expires_at <= fsm.state.now.instant { + return None; + } + wire::handshake::verify_confirm( + peer, + fsm.identity.xid, + &entry.peer.signing_key, + &recent_ready.hello, + &recent_ready.reply, + confirm, + ) + .ok()?; + Some(recent_ready.ready.clone()) +} + +fn handle_initiator_retry( + peer: &Peer, + hello: &Hello, + stage: &mut HandshakeInitiator, + now: Instant, + retry_interval: std::time::Duration, + max_retries: u8, +) -> Option { + match stage { + HandshakeInitiator::WaitingHelloReply { + retry_count, + retry_at, + .. + } => { + if retry_due(*retry_at, now) { + if *retry_count >= max_retries { + *retry_at = None; + None + } else { + *retry_count = retry_count.saturating_add(1); + *retry_at = Some(now + retry_interval); + Some(RetryAction::Hello { + peer: peer.xid, + hello: hello.clone(), + }) + } + } else { + None + } + } + HandshakeInitiator::WaitingReady { + confirm, + retry_count, + retry_at, + .. + } => { + if retry_due(*retry_at, now) { + if *retry_count >= max_retries { + *retry_at = None; + None + } else { + *retry_count = retry_count.saturating_add(1); + *retry_at = Some(now + retry_interval); + Some(RetryAction::Confirm { + peer: peer.xid, + confirm: confirm.clone(), + }) + } + } else { + None + } + } + } +} + +fn handle_responder_retry( + peer: &Peer, + reply: &HelloReply, + stage: &mut HandshakeResponder, + now: Instant, + retry_interval: std::time::Duration, + max_retries: u8, +) -> Option { + match stage { + HandshakeResponder::WaitingConfirm { + retry_count, + retry_at, + .. + } => { + if retry_due(*retry_at, now) { + if *retry_count >= max_retries { + *retry_at = None; + None + } else { + *retry_count = retry_count.saturating_add(1); + *retry_at = Some(now + retry_interval); + Some(RetryAction::HelloReply { + peer: peer.xid, + reply: reply.clone(), + }) + } + } else { + None + } + } + } +} + +fn initiator_retries_exhausted(stage: &HandshakeInitiator) -> bool { + match stage { + HandshakeInitiator::WaitingHelloReply { retry_at, .. } + | HandshakeInitiator::WaitingReady { retry_at, .. } => retry_at.is_none(), + } +} + +fn responder_retries_exhausted(stage: &HandshakeResponder) -> bool { + match stage { + HandshakeResponder::WaitingConfirm { retry_at, .. } => retry_at.is_none(), + } +} + +fn initiator_retry_at(stage: &HandshakeInitiator) -> Option { + match stage { + HandshakeInitiator::WaitingHelloReply { retry_at, .. } + | HandshakeInitiator::WaitingReady { retry_at, .. } => *retry_at, + } +} + +fn responder_retry_at(stage: &HandshakeResponder) -> Option { + match stage { + HandshakeResponder::WaitingConfirm { retry_at, .. } => *retry_at, + } +} + +fn same_hello(stored: &Hello, incoming: &Hello) -> bool { + stored.meta.control_id == incoming.meta.control_id && stored.nonce == incoming.nonce +} + +fn same_reply(stored: &HelloReply, incoming: &HelloReply) -> bool { + stored.meta.control_id == incoming.meta.control_id && stored.nonce == incoming.nonce +} + +fn peer_hello_wins( + local_hello: &Hello, + local_sender: XID, + peer_hello: &Hello, + peer_sender: XID, +) -> bool { + match peer_hello.nonce.0.cmp(&local_hello.nonce.0) { + Ordering::Less => true, + Ordering::Greater => false, + Ordering::Equal => peer_sender.0.cmp(&local_sender.0) == Ordering::Less, + } +} + +fn next_encrypted_nonce(crypto: &impl QlCrypto) -> wire::Nonce { + let mut bytes = [0u8; wire::Nonce::NONCE_SIZE]; + crypto.fill_random_bytes(&mut bytes); + wire::Nonce(bytes) +} + +fn retry_due(retry_at: Option, now: Instant) -> bool { + retry_at.is_some_and(|deadline| deadline <= now) +} + +fn min_optional(current: Option, other: Option) -> Option { + match (current, other) { + (Some(left), Some(right)) => Some(left.min(right)), + (Some(left), None) => Some(left), + (None, Some(right)) => Some(right), + (None, None) => None, + } +} diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs new file mode 100644 index 00000000..af67a96d --- /dev/null +++ b/ql-fsm/src/implementation/mod.rs @@ -0,0 +1,135 @@ +pub mod handshake; +pub mod peer; + +use std::time::Duration; + +use ql_wire::{ + self as wire, handshake::ArchivedHandshakeRecord, ArchivedQlPayload, ControlId, ControlMeta, + QlCrypto, QlHeader, QlPayload, QlRecord, XID, +}; +use rkyv::api::low; + +use crate::{Peer, QlFsm, QlFsmError, QlFsmEvent}; + +impl QlFsm { + pub fn bind_peer_inner(&mut self, peer: Peer) { + peer::handle_bind_peer(self, peer); + } + + pub fn pair_inner(&mut self, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + peer::handle_pair_local(self, crypto) + } + + pub fn connect_inner(&mut self, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + handshake::handle_connect(self, crypto) + } + + pub fn receive_inner( + &mut self, + mut bytes: Vec, + crypto: &impl QlCrypto, + ) -> Result<(), QlFsmError> { + let archived = wire::access_record_mut(&mut bytes)?; + let archived = unsafe { archived.unseal_unchecked() }; + let header: QlHeader = deserialize_archived(&archived.header)?; + + if header.recipient != self.identity.xid { + return Ok(()); + } + if !matches!(&archived.payload, ArchivedQlPayload::Pair(_)) { + let Some(peer) = self.peer.as_ref().map(|entry| entry.peer.xid) else { + return Ok(()); + }; + if header.sender != peer { + return Ok(()); + } + } + + match &mut archived.payload { + ArchivedQlPayload::Pair(request) => { + peer::handle_pair(self, &header, request, crypto)?; + } + ArchivedQlPayload::Handshake(ArchivedHandshakeRecord::Hello(archived_hello)) => { + handshake::handle_hello(self, &header, archived_hello, crypto)?; + } + ArchivedQlPayload::Handshake(ArchivedHandshakeRecord::HelloReply(archived_reply)) => { + handshake::handle_hello_reply(self, &header, archived_reply)?; + } + ArchivedQlPayload::Handshake(ArchivedHandshakeRecord::Confirm(archived_confirm)) => { + handshake::handle_confirm(self, &header, archived_confirm, crypto)?; + } + ArchivedQlPayload::Handshake(ArchivedHandshakeRecord::Ready(archived_ready)) => { + handshake::handle_ready(self, &header, archived_ready)?; + } + ArchivedQlPayload::Encrypted(_) => {} + } + + Ok(()) + } + + pub fn on_timer_inner(&mut self) { + handshake::handle_timer(self); + } + + pub fn next_deadline_inner(&self) -> Option { + handshake::next_deadline(self) + } + + pub fn take_next_outbound_inner(&mut self) -> Option { + self.state.outbound.pop_front() + } + + pub fn take_next_event_inner(&mut self) -> Option { + self.state.events.pop_front() + } + + fn emit_peer_status(&mut self) { + if let Some(entry) = self.peer.as_ref() { + self.state.events.push_back(QlFsmEvent::PeerStatusChanged { + peer: entry.peer.xid, + status: entry.session.status(), + }); + } + } + + fn next_control_meta(&mut self, lifetime: Duration) -> ControlMeta { + let control_id = ControlId(self.state.next_control_id); + self.state.next_control_id = self.state.next_control_id.wrapping_add(1); + ControlMeta { + control_id, + valid_until: deadline_after_secs(self.state.now.unix_secs, lifetime), + } + } + + fn enqueue_handshake(&mut self, peer: XID, record: wire::handshake::HandshakeRecord) { + self.state.outbound.push_back(QlRecord { + header: QlHeader { + sender: self.identity.xid, + recipient: peer, + }, + payload: QlPayload::Handshake(record), + }); + } + + fn is_replayed_control(&mut self, peer: XID, meta: ControlMeta) -> bool { + self.state + .replay_cache + .check_and_store_valid_until(peer, meta, self.state.now.unix_secs) + } +} + +fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { + now_secs.saturating_add(duration_to_secs(duration)) +} + +fn duration_to_secs(duration: Duration) -> u64 { + duration + .as_secs() + .saturating_add(u64::from(duration.subsec_nanos() > 0)) +} + +fn deserialize_archived( + value: &impl rkyv::Deserialize>, +) -> Result { + low::deserialize::(value).map_err(|_| QlFsmError::InvalidPayload) +} diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs new file mode 100644 index 00000000..ec5a75c7 --- /dev/null +++ b/ql-fsm/src/implementation/peer.rs @@ -0,0 +1,56 @@ +use ql_wire::{self as wire, pair::ArchivedPairRequestRecord, QlCrypto, QlHeader}; + +use super::handshake; +use crate::{Peer, PeerRecord, QlFsm, QlFsmError, QlFsmEvent}; + +pub fn handle_bind_peer(fsm: &mut QlFsm, peer: Peer) { + bind_peer_record(fsm, peer); +} + +pub fn handle_pair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + let meta = fsm.next_control_meta(fsm.config.control_expiration); + let peer = fsm.peer.as_ref().ok_or(QlFsmError::NoPeerBound)?; + let record = wire::pair::build_pair_request( + &fsm.identity, + crypto, + peer.peer.xid, + &peer.peer.encapsulation_key, + meta, + )?; + fsm.state.outbound.push_back(record); + Ok(()) +} + +pub fn handle_pair( + fsm: &mut QlFsm, + header: &QlHeader, + request: &mut ArchivedPairRequestRecord, + crypto: &impl QlCrypto, +) -> Result<(), QlFsmError> { + let payload = match wire::pair::decrypt_pair_request(&fsm.identity, header, request) { + Ok(payload) => payload, + Err(_) => return Ok(()), + }; + let peer = Peer { + xid: ql_wire::XID::from_signing_public_key(&payload.signing_pub_key), + signing_key: payload.signing_pub_key, + encapsulation_key: payload.encapsulation_pub_key, + }; + if fsm.is_replayed_control(peer.xid, payload.meta) { + return Ok(()); + } + + match fsm.peer.as_ref() { + Some(existing) if existing.peer != peer => return Ok(()), + Some(_) => {} + None => bind_peer_record(fsm, peer.clone()), + } + + handshake::handle_connect(fsm, crypto) +} + +fn bind_peer_record(fsm: &mut QlFsm, peer: Peer) { + fsm.peer = Some(PeerRecord::new(peer.clone())); + fsm.state.events.push_back(QlFsmEvent::PersistPeer(peer)); + fsm.emit_peer_status(); +} diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs new file mode 100644 index 00000000..de3d552d --- /dev/null +++ b/ql-fsm/src/lib.rs @@ -0,0 +1,84 @@ +pub(crate) mod implementation; +pub(crate) mod replay_cache; +pub mod session; +pub(crate) mod state; + +use std::time::Instant; + +use ql_wire::{QlCrypto, QlIdentity, QlRecord}; +use state::{ + HandshakeInitiator, HandshakeResponder, Peer, PeerRecord, PeerSession, QlFsm, QlFsmConfig, + QlFsmError, QlFsmEvent, RecentReady, +}; + +use crate::{replay_cache::ReplayCache, state::QlFsmState}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct FsmTime { + pub instant: Instant, + pub unix_secs: u64, +} + +impl QlFsm { + pub fn new( + config: QlFsmConfig, + identity: QlIdentity, + peer: Option, + now: FsmTime, + ) -> Self { + let peer = peer.map(PeerRecord::new); + Self { + config, + identity, + peer, + state: QlFsmState { + replay_cache: ReplayCache::default(), + next_control_id: 1, + outbound: Default::default(), + events: Default::default(), + now, + }, + } + } + + pub fn bind_peer(&mut self, peer: Peer) { + self.bind_peer_inner(peer); + } + + pub fn pair(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + self.state.now = now; + self.pair_inner(crypto) + } + + pub fn connect(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + self.state.now = now; + self.connect_inner(crypto) + } + + pub fn receive( + &mut self, + now: FsmTime, + bytes: Vec, + crypto: &impl QlCrypto, + ) -> Result<(), QlFsmError> { + self.state.now = now; + self.receive_inner(bytes, crypto) + } + + pub fn on_timer(&mut self, now: FsmTime) { + self.state.now = now; + self.on_timer_inner(); + } + + pub fn next_deadline(&self) -> Option { + self.next_deadline_inner() + } + + pub fn take_next_outbound(&mut self) -> Option { + self.take_next_outbound_inner() + } + + pub fn take_next_event(&mut self) -> Option { + self.take_next_event_inner() + } +} diff --git a/ql-fsm/src/replay_cache.rs b/ql-fsm/src/replay_cache.rs new file mode 100644 index 00000000..4843c517 --- /dev/null +++ b/ql-fsm/src/replay_cache.rs @@ -0,0 +1,38 @@ +use std::collections::HashMap; + +use ql_wire::{ControlId, ControlMeta, XID}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct ReplayKey { + peer: XID, + control_id: ControlId, +} + +#[derive(Debug, Default)] +pub struct ReplayCache { + valid_until_by_key: HashMap, +} + +impl ReplayCache { + pub fn check_and_store_valid_until( + &mut self, + peer: XID, + meta: ControlMeta, + now_secs: u64, + ) -> bool { + self.valid_until_by_key + .retain(|_, valid_until| *valid_until > now_secs); + + let key = ReplayKey { + peer, + control_id: meta.control_id, + }; + + if self.valid_until_by_key.contains_key(&key) { + true + } else { + self.valid_until_by_key.insert(key, meta.valid_until); + false + } + } +} diff --git a/ql-fsm/src/session/internal.rs b/ql-fsm/src/session/internal.rs new file mode 100644 index 00000000..6a0c931b --- /dev/null +++ b/ql-fsm/src/session/internal.rs @@ -0,0 +1,628 @@ +use std::time::Instant; + +use ql_wire::{ + encrypted::{heartbeat::HeartbeatBody, unpair::UnpairBody}, + CloseCode, CloseTarget, SessionBody, SessionCloseBody, SessionEnvelope, SessionSeq, + StreamCloseFrame, StreamFrame, StreamId, +}; + +use super::{ + ring::SeqRingInsertError, + state::{ + AckState, PendingChunk, PendingSessionBody, PendingStreamBody, SessionFsmState, StreamRole, + StreamState, TxEntry, + }, + SessionEvent, SessionFsm, SessionFsmConfig, SessionState, StreamError, StreamIncoming, +}; + +impl SessionFsm { + pub fn new_inner(config: SessionFsmConfig) -> Self { + let now = Instant::now(); + Self { + config, + state: SessionFsmState { + now, + session_state: SessionState::Open, + next_stream_ordinal: 1, + next_seq: SessionSeq(1), + tx_ring: super::ring::SeqRing::new(SessionSeq(1)), + rx_ring: super::ring::SeqRing::new(SessionSeq(1)), + ack_state: AckState::Idle, + pending_control: Default::default(), + streams: Default::default(), + ready_streams: Default::default(), + events: Default::default(), + }, + } + } + + pub fn open_stream_inner(&mut self) -> Result { + self.ensure_session_open()?; + let stream_id = + StreamId(self.config.local_namespace.bit() | self.state.next_stream_ordinal); + self.state.next_stream_ordinal = self.state.next_stream_ordinal.saturating_add(1); + self.state + .streams + .insert(stream_id, StreamState::new(StreamRole::Initiator)); + Ok(stream_id) + } + + pub fn write_stream_inner( + &mut self, + stream_id: StreamId, + bytes: Vec, + ) -> Result<(), StreamError> { + self.ensure_session_open()?; + if bytes.is_empty() { + return Ok(()); + } + + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; + if !stream.is_writable() { + return Err(StreamError::NotWritable); + } + + let frame = StreamFrame { + stream_id, + offset: stream.next_send_offset, + bytes, + fin: false, + }; + stream.next_send_offset += frame.bytes.len() as u64; + stream + .send_queue + .push_back(PendingStreamBody::Stream(frame)); + Self::mark_stream_ready(&mut self.state, stream_id); + Ok(()) + } + + pub fn finish_stream_inner(&mut self, stream_id: StreamId) -> Result<(), StreamError> { + self.ensure_session_open()?; + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; + if !stream.is_writable() { + return Err(StreamError::NotWritable); + } + + stream.outbound_finished = true; + stream + .send_queue + .push_back(PendingStreamBody::Stream(StreamFrame { + stream_id, + offset: stream.next_send_offset, + bytes: Vec::new(), + fin: true, + })); + Self::mark_stream_ready(&mut self.state, stream_id); + Ok(()) + } + + pub fn close_stream_inner( + &mut self, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + ) -> Result<(), StreamError> { + self.ensure_session_open()?; + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; + + Self::apply_close_to_stream(stream, target); + stream + .send_queue + .push_back(PendingStreamBody::StreamClose(StreamCloseFrame { + stream_id, + target, + code, + payload, + })); + Self::mark_stream_ready(&mut self.state, stream_id); + Ok(()) + } + + pub fn queue_heartbeat_inner(&mut self) -> Result<(), StreamError> { + self.ensure_session_open()?; + self.state.pending_control.heartbeat = true; + Ok(()) + } + + pub fn queue_unpair_inner(&mut self) -> Result<(), StreamError> { + self.ensure_session_open()?; + self.state.pending_control.unpair = true; + Ok(()) + } + + pub fn close_session_inner(&mut self, code: CloseCode) { + self.fail_session(SessionCloseBody { code }); + } + + pub fn receive_inner(&mut self, envelope: SessionEnvelope) { + self.collect_timeouts(); + self.process_ack(envelope.ack); + + if self.state.session_state == SessionState::Closed { + return; + } + + let seq = envelope.seq; + if seq.0 < self.state.rx_ring.base_seq().0 || self.state.rx_ring.contains_key(&seq) { + self.schedule_ack(true); + return; + } + match self.state.rx_ring.insert(seq, ()) { + Ok(()) => { + let out_of_order = seq != self.state.rx_ring.base_seq(); + self.state.rx_ring.advance_occupied_front(); + self.schedule_ack(out_of_order); + } + Err(SeqRingInsertError::OutOfWindow) => { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + } + Err(SeqRingInsertError::Occupied) => { + self.schedule_ack(true); + return; + } + } + + match envelope.body { + SessionBody::Heartbeat(_) => {} + SessionBody::Unpair(_) => { + self.state.session_state = SessionState::Closed; + self.clear_streams(); + self.state.events.push_back(SessionEvent::Unpaired); + } + SessionBody::Close(close) => { + self.state.session_state = SessionState::Closed; + self.clear_streams(); + self.state + .events + .push_back(SessionEvent::SessionClosed(close)); + } + SessionBody::Stream(frame) => self.handle_stream_frame(frame), + SessionBody::StreamClose(frame) => self.handle_stream_close(frame), + } + } + + pub fn next_outbound_inner(&mut self) -> Option { + self.collect_timeouts(); + let pending = self.next_pending_body()?; + if !self.state.tx_ring.accepts_seq(self.state.next_seq) { + if pending.priority { + self.requeue_pending_front(pending); + } + return None; + } + + let seq = self.state.next_seq; + self.state.next_seq = SessionSeq(seq.0 + 1); + let ack = self.state.current_ack(); + self.state.clear_ack_schedule(); + let envelope = SessionEnvelope { + seq, + ack, + body: pending.body.clone(), + }; + let entry = TxEntry { + pending, + sent_at: self.state.now, + }; + let _ = self.state.tx_ring.insert(seq, entry); + Some(envelope) + } + + pub fn on_timer_inner(&mut self) { + self.collect_timeouts(); + if let AckState::Delayed { due_at } = self.state.ack_state { + if due_at <= self.state.now { + self.state.ack_state = AckState::Immediate; + } + } + } + + pub fn next_deadline_inner(&self) -> Option { + let ack_deadline = match self.state.ack_state { + AckState::Idle => None, + AckState::Immediate => Some(self.state.now), + AckState::Delayed { due_at } => Some(due_at), + }; + let retransmit_deadline = self + .state + .tx_ring + .iter() + .map(|(_, entry)| entry.sent_at + self.config.retransmit_timeout) + .min(); + [ack_deadline, retransmit_deadline] + .into_iter() + .flatten() + .min() + } + + pub fn take_next_event_inner(&mut self) -> Option { + self.state.events.pop_front() + } + + pub fn take_next_inbound_inner(&mut self, stream_id: StreamId) -> Option { + self.state + .streams + .get_mut(&stream_id) + .and_then(|stream| stream.inbound_queue.pop_front()) + } + + pub fn session_state_inner(&self) -> SessionState { + self.state.session_state + } + + fn next_pending_body(&mut self) -> Option { + if let Some(close) = self.state.pending_control.close.take() { + return Some(PendingSessionBody { + body: SessionBody::Close(close), + retransmit: true, + priority: true, + }); + } + if self.state.pending_control.unpair { + self.state.pending_control.unpair = false; + return Some(PendingSessionBody { + body: SessionBody::Unpair(UnpairBody), + retransmit: true, + priority: true, + }); + } + if self.state.pending_control.heartbeat { + self.state.pending_control.heartbeat = false; + return Some(PendingSessionBody { + body: SessionBody::Heartbeat(HeartbeatBody), + retransmit: false, + priority: true, + }); + } + + while let Some(stream_id) = self.state.ready_streams.pop_front() { + let Some(stream) = self.state.streams.get_mut(&stream_id) else { + continue; + }; + stream.ready_enqueued = false; + let Some(item) = stream.send_queue.pop_front() else { + continue; + }; + if !stream.send_queue.is_empty() { + Self::mark_stream_ready(&mut self.state, stream_id); + } + return Some(PendingSessionBody { + body: item.to_session_body(), + retransmit: true, + priority: true, + }); + } + + let ack_due = match self.state.ack_state { + AckState::Immediate => true, + AckState::Delayed { due_at } => due_at <= self.state.now, + AckState::Idle => false, + }; + ack_due.then_some(PendingSessionBody { + body: SessionBody::Heartbeat(HeartbeatBody), + retransmit: false, + priority: false, + }) + } + + fn ensure_session_open(&self) -> Result<(), StreamError> { + if self.state.session_state == SessionState::Closed { + Err(StreamError::SessionClosed) + } else { + Ok(()) + } + } + + fn process_ack(&mut self, ack: ql_wire::SessionAck) { + let acked: Vec<_> = self + .state + .tx_ring + .iter() + .filter_map(|(seq, _)| Self::ack_covers(ack, seq).then_some(seq)) + .collect(); + for seq in acked { + let _ = self.state.tx_ring.remove(&seq); + } + self.state + .tx_ring + .advance_empty_front_until(self.state.next_seq); + } + + fn ack_covers(ack: ql_wire::SessionAck, seq: SessionSeq) -> bool { + if seq.0 <= ack.base.0 { + return true; + } + let delta = seq.0 - ack.base.0; + if delta == 0 || delta > 64 { + return false; + } + (ack.bitmap & (1u64 << (delta - 1))) != 0 + } + + fn schedule_ack(&mut self, immediate: bool) { + self.state.ack_state = match self.state.ack_state { + AckState::Immediate => AckState::Immediate, + _ if immediate || self.config.ack_delay.is_zero() => AckState::Immediate, + AckState::Delayed { due_at } => AckState::Delayed { due_at }, + AckState::Idle => AckState::Delayed { + due_at: self.state.now + self.config.ack_delay, + }, + }; + } + + fn collect_timeouts(&mut self) { + let expired: Vec<_> = self + .state + .tx_ring + .iter() + .filter_map(|(seq, entry)| { + (entry.sent_at + self.config.retransmit_timeout <= self.state.now).then_some(seq) + }) + .collect(); + + for seq in expired { + if let Some(entry) = self.state.tx_ring.remove(&seq) { + if entry.pending.retransmit { + self.requeue_pending_front(entry.pending); + } + } + } + + self.state + .tx_ring + .advance_empty_front_until(self.state.next_seq); + } + + fn requeue_pending_front(&mut self, pending: PendingSessionBody) { + match pending.body { + SessionBody::Stream(frame) => { + if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { + let stream_id = frame.stream_id; + stream + .send_queue + .push_front(PendingStreamBody::Stream(frame)); + Self::mark_stream_ready_front(&mut self.state, stream_id); + } + } + SessionBody::StreamClose(frame) => { + if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { + let stream_id = frame.stream_id; + stream + .send_queue + .push_front(PendingStreamBody::StreamClose(frame)); + Self::mark_stream_ready_front(&mut self.state, stream_id); + } + } + body => match body { + SessionBody::Heartbeat(_) => self.state.pending_control.heartbeat = true, + SessionBody::Unpair(_) => self.state.pending_control.unpair = true, + SessionBody::Close(close) => self.state.pending_control.close = Some(close), + SessionBody::Stream(_) | SessionBody::StreamClose(_) => unreachable!(), + }, + } + } + + fn mark_stream_ready(state: &mut SessionFsmState, stream_id: StreamId) { + let Some(stream) = state.streams.get_mut(&stream_id) else { + return; + }; + if stream.ready_enqueued { + return; + } + stream.ready_enqueued = true; + state.ready_streams.push_back(stream_id); + } + + fn mark_stream_ready_front(state: &mut SessionFsmState, stream_id: StreamId) { + let Some(stream) = state.streams.get_mut(&stream_id) else { + return; + }; + if stream.ready_enqueued { + return; + } + stream.ready_enqueued = true; + state.ready_streams.push_front(stream_id); + } + + fn handle_stream_frame(&mut self, frame: StreamFrame) { + let stream_id = frame.stream_id; + let remote_namespace = self.config.local_namespace.remote(); + if !self.state.streams.contains_key(&stream_id) { + if !remote_namespace.matches(stream_id) || frame.offset != 0 { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + } + self.state + .streams + .insert(stream_id, StreamState::new(StreamRole::Responder)); + self.state.events.push_back(SessionEvent::Opened(stream_id)); + } + + let Some(stream) = self.state.streams.get_mut(&stream_id) else { + return; + }; + if stream.inbound_discarding { + return; + } + if stream.inbound_closed || stream.inbound_finished { + if frame.offset + frame.bytes.len() as u64 <= stream.next_recv_offset { + return; + } + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + } + + if frame.offset < stream.next_recv_offset { + let frame_end = frame.offset + frame.bytes.len() as u64; + if frame_end <= stream.next_recv_offset { + return; + } + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + } + + if frame.offset == stream.next_recv_offset { + Self::commit_inbound_frame(stream, frame); + Self::drain_pending_recv(stream); + self.state + .events + .push_back(SessionEvent::Readable(stream_id)); + return; + } + + if Self::insert_pending_chunk( + stream, + frame.offset, + PendingChunk { + bytes: frame.bytes, + fin: frame.fin, + }, + ) + .is_err() + { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + } + } + + fn handle_stream_close(&mut self, frame: StreamCloseFrame) { + let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + }; + + if Self::target_affects_inbound(stream.role, frame.target) { + stream.inbound_closed = true; + stream.inbound_discarding = false; + stream.pending_recv.clear(); + stream + .inbound_queue + .push_back(StreamIncoming::Closed(frame.clone())); + self.state + .events + .push_back(SessionEvent::Readable(frame.stream_id)); + } + if Self::target_affects_outbound(stream.role, frame.target) { + stream.outbound_closed = true; + stream.send_queue.clear(); + self.state + .events + .push_back(SessionEvent::WritableClosed(frame.stream_id)); + } + } + + fn apply_close_to_stream(stream: &mut StreamState, target: CloseTarget) { + if Self::target_affects_inbound(stream.role, target) { + stream.inbound_discarding = true; + stream.pending_recv.clear(); + } + if Self::target_affects_outbound(stream.role, target) { + stream.outbound_closed = true; + stream.outbound_finished = true; + stream.send_queue.clear(); + } + } + + fn target_affects_inbound(role: StreamRole, target: CloseTarget) -> bool { + matches!(target, CloseTarget::Both) || role.inbound_target() == target + } + + fn target_affects_outbound(role: StreamRole, target: CloseTarget) -> bool { + matches!(target, CloseTarget::Both) || role.outbound_target() == target + } + + fn commit_inbound_frame(stream: &mut StreamState, frame: StreamFrame) { + Self::commit_inbound_chunk(stream, frame.bytes, frame.fin); + } + + fn commit_inbound_chunk(stream: &mut StreamState, bytes: Vec, fin: bool) { + stream.next_recv_offset += bytes.len() as u64; + if !bytes.is_empty() { + stream.inbound_queue.push_back(StreamIncoming::Data(bytes)); + } + if fin { + stream.inbound_finished = true; + stream.inbound_queue.push_back(StreamIncoming::Finished); + } + } + + fn drain_pending_recv(stream: &mut StreamState) { + while let Some(chunk) = stream.pending_recv.remove(&stream.next_recv_offset) { + Self::commit_inbound_chunk(stream, chunk.bytes, chunk.fin); + if stream.inbound_finished { + break; + } + } + } + + fn insert_pending_chunk( + stream: &mut StreamState, + offset: u64, + chunk: PendingChunk, + ) -> Result<(), ()> { + let end = chunk.end_offset(offset); + + if let Some((&prev_offset, prev)) = stream.pending_recv.range(..=offset).next_back() { + let prev_end = prev.end_offset(prev_offset); + if prev_end > offset { + if prev_offset == offset && prev.bytes == chunk.bytes && prev.fin == chunk.fin { + return Ok(()); + } + return Err(()); + } + } + + if let Some((&next_offset, _)) = stream.pending_recv.range(offset..).next() { + if end > next_offset { + return Err(()); + } + } + + stream.pending_recv.insert(offset, chunk); + Ok(()) + } + + fn fail_session(&mut self, close: SessionCloseBody) { + if self.state.session_state == SessionState::Closed { + return; + } + + self.state.session_state = SessionState::Closed; + self.clear_streams(); + self.state.pending_control = Default::default(); + self.state.pending_control.close = Some(close.clone()); + self.state + .events + .push_back(SessionEvent::SessionClosed(close)); + } + + fn clear_streams(&mut self) { + self.state.ready_streams.clear(); + self.state.streams.clear(); + } +} diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs new file mode 100644 index 00000000..570527be --- /dev/null +++ b/ql-fsm/src/session/mod.rs @@ -0,0 +1,174 @@ +pub(crate) mod internal; +pub(crate) mod ring; +pub(crate) mod state; + +#[cfg(test)] +mod tests; + +use std::time::{Duration, Instant}; + +use ql_wire::{ + CloseCode, CloseTarget, SessionCloseBody, SessionEnvelope, StreamCloseFrame, StreamId, XID, +}; + +use self::state::SessionFsmState; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamNamespace { + Low, + High, +} + +impl StreamNamespace { + const BIT: u32 = 1 << 31; + + pub fn for_local(local: XID, peer: XID) -> Self { + match local.0.cmp(&peer.0) { + std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Low, + std::cmp::Ordering::Greater => Self::High, + } + } + + pub fn bit(self) -> u32 { + match self { + Self::Low => 0, + Self::High => Self::BIT, + } + } + + pub fn matches(self, stream_id: StreamId) -> bool { + (stream_id.0 & Self::BIT) == self.bit() + } + + pub fn remote(self) -> Self { + match self { + Self::Low => Self::High, + Self::High => Self::Low, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SessionFsmConfig { + pub local_namespace: StreamNamespace, + pub ack_delay: Duration, + pub retransmit_timeout: Duration, +} + +impl Default for SessionFsmConfig { + fn default() -> Self { + Self { + local_namespace: StreamNamespace::Low, + ack_delay: Duration::from_millis(5), + retransmit_timeout: Duration::from_millis(150), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionEvent { + Opened(StreamId), + Readable(StreamId), + WritableClosed(StreamId), + Unpaired, + SessionClosed(SessionCloseBody), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StreamIncoming { + Data(Vec), + Finished, + Closed(StreamCloseFrame), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionState { + Open, + Closed, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +pub enum StreamError { + #[error("missing stream")] + MissingStream, + #[error("stream is not writable")] + NotWritable, + #[error("session is closed")] + SessionClosed, +} + +pub struct SessionFsm { + config: SessionFsmConfig, + state: SessionFsmState, +} + +impl SessionFsm { + pub fn new(config: SessionFsmConfig) -> Self { + Self::new_inner(config) + } + + pub fn open_stream(&mut self) -> Result { + self.open_stream_inner() + } + + pub fn write_stream(&mut self, stream_id: StreamId, bytes: Vec) -> Result<(), StreamError> { + self.write_stream_inner(stream_id, bytes) + } + + pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { + self.finish_stream_inner(stream_id) + } + + pub fn close_stream( + &mut self, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + ) -> Result<(), StreamError> { + self.close_stream_inner(stream_id, target, code, payload) + } + + pub fn queue_heartbeat(&mut self) -> Result<(), StreamError> { + self.queue_heartbeat_inner() + } + + pub fn queue_unpair(&mut self) -> Result<(), StreamError> { + self.queue_unpair_inner() + } + + pub fn close_session(&mut self, code: CloseCode) { + self.close_session_inner(code); + } + + pub fn receive(&mut self, now: Instant, envelope: SessionEnvelope) { + self.state.now = now; + self.receive_inner(envelope); + } + + pub fn next_outbound(&mut self, now: Instant) -> Option { + self.state.now = now; + self.next_outbound_inner() + } + + pub fn on_timer(&mut self, now: Instant) { + self.state.now = now; + self.on_timer_inner(); + } + + pub fn next_deadline(&self) -> Option { + self.next_deadline_inner() + } + + pub fn take_next_event(&mut self) -> Option { + self.take_next_event_inner() + } + + pub fn take_next_inbound(&mut self, stream_id: StreamId) -> Option { + self.take_next_inbound_inner(stream_id) + } + + pub fn session_state(&self) -> SessionState { + self.session_state_inner() + } +} diff --git a/ql-fsm/src/session/ring.rs b/ql-fsm/src/session/ring.rs new file mode 100644 index 00000000..872c92a5 --- /dev/null +++ b/ql-fsm/src/session/ring.rs @@ -0,0 +1,141 @@ +use std::array; + +use ql_wire::SessionSeq; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SeqRingInsertError { + OutOfWindow, + Occupied, +} + +#[derive(Debug)] +pub struct SeqRing { + base_seq: SessionSeq, + head: usize, + len: usize, + slots: [Option; N], +} + +impl SeqRing { + pub fn new(base_seq: SessionSeq) -> Self { + Self { + base_seq, + head: 0, + len: 0, + slots: array::from_fn(|_| None), + } + } + + pub fn base_seq(&self) -> SessionSeq { + self.base_seq + } + + pub fn accepts_seq(&self, seq: SessionSeq) -> bool { + self.offset_for(seq).is_some() + } + + pub fn contains_key(&self, seq: &SessionSeq) -> bool { + self.get(seq).is_some() + } + + pub fn get(&self, seq: &SessionSeq) -> Option<&T> { + let index = self.index_for(*seq)?; + self.slots[index].as_ref() + } + + pub fn insert(&mut self, seq: SessionSeq, value: T) -> Result<(), SeqRingInsertError> { + let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; + if self.slots[index].is_some() { + return Err(SeqRingInsertError::Occupied); + } + self.slots[index] = Some(value); + self.len += 1; + Ok(()) + } + + pub fn remove(&mut self, seq: &SessionSeq) -> Option { + let index = self.index_for(*seq)?; + let value = self.slots[index].take(); + if value.is_some() { + self.len -= 1; + } + value + } + + pub fn advance_empty_front_until(&mut self, limit_seq: SessionSeq) { + while self.base_seq.0 < limit_seq.0 && self.slots[self.head].is_none() { + self.head = self.next_index(self.head); + self.base_seq = SessionSeq(self.base_seq.0 + 1); + } + } + + pub fn advance_occupied_front(&mut self) { + while self.slots[self.head].is_some() { + let _ = self.slots[self.head].take(); + self.len -= 1; + self.head = self.next_index(self.head); + self.base_seq = SessionSeq(self.base_seq.0 + 1); + } + } + + pub fn iter(&self) -> SeqRingIter<'_, N, T> { + SeqRingIter { + ring: self, + offset: 0, + } + } + + pub fn bitmap(&self) -> u64 { + debug_assert!(N <= 64); + let mut bitmap = 0u64; + for offset in 0..N { + let index = self.index_for_offset(offset); + if self.slots[index].is_some() { + bitmap |= 1u64 << offset; + } + } + bitmap + } + + fn index_for(&self, seq: SessionSeq) -> Option { + let offset = self.offset_for(seq)?; + Some(self.index_for_offset(offset)) + } + + fn offset_for(&self, seq: SessionSeq) -> Option { + if seq.0 < self.base_seq.0 { + return None; + } + let offset = (seq.0 - self.base_seq.0) as usize; + (offset < N).then_some(offset) + } + + fn index_for_offset(&self, offset: usize) -> usize { + (self.head + offset) % N + } + + fn next_index(&self, index: usize) -> usize { + (index + 1) % N + } +} + +pub struct SeqRingIter<'a, const N: usize, T> { + ring: &'a SeqRing, + offset: usize, +} + +impl<'a, const N: usize, T> Iterator for SeqRingIter<'a, N, T> { + type Item = (SessionSeq, &'a T); + + fn next(&mut self) -> Option { + while self.offset < N { + let offset = self.offset; + self.offset += 1; + let index = self.ring.index_for_offset(offset); + if let Some(value) = self.ring.slots[index].as_ref() { + return Some((SessionSeq(self.ring.base_seq.0 + offset as u64), value)); + } + } + None + } +} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs new file mode 100644 index 00000000..f922a8b0 --- /dev/null +++ b/ql-fsm/src/session/state.rs @@ -0,0 +1,156 @@ +use std::{ + collections::{BTreeMap, HashMap, VecDeque}, + time::Instant, +}; + +use ql_wire::{ + CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamCloseFrame, + StreamFrame, StreamId, +}; + +use super::ring::SeqRing; +use super::{SessionEvent, SessionState, StreamIncoming}; + +pub const SESSION_WINDOW_CAPACITY: usize = 64; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamRole { + Initiator, + Responder, +} + +impl StreamRole { + pub fn outbound_target(self) -> CloseTarget { + match self { + Self::Initiator => CloseTarget::Request, + Self::Responder => CloseTarget::Response, + } + } + + pub fn inbound_target(self) -> CloseTarget { + match self { + Self::Initiator => CloseTarget::Response, + Self::Responder => CloseTarget::Request, + } + } +} + +#[derive(Debug, Clone)] +pub struct PendingChunk { + pub bytes: Vec, + pub fin: bool, +} + +impl PendingChunk { + pub fn end_offset(&self, offset: u64) -> u64 { + offset + self.bytes.len() as u64 + } +} + +#[derive(Debug, Clone)] +pub enum PendingStreamBody { + Stream(StreamFrame), + StreamClose(StreamCloseFrame), +} + +impl PendingStreamBody { + pub fn to_session_body(&self) -> SessionBody { + match self { + Self::Stream(frame) => SessionBody::Stream(frame.clone()), + Self::StreamClose(frame) => SessionBody::StreamClose(frame.clone()), + } + } +} + +#[derive(Debug)] +pub struct StreamState { + pub role: StreamRole, + pub send_queue: VecDeque, + pub inbound_queue: VecDeque, + pub pending_recv: BTreeMap, + pub next_send_offset: u64, + pub next_recv_offset: u64, + pub outbound_finished: bool, + pub outbound_closed: bool, + pub inbound_finished: bool, + pub inbound_closed: bool, + pub inbound_discarding: bool, + pub ready_enqueued: bool, +} + +impl StreamState { + pub fn new(role: StreamRole) -> Self { + Self { + role, + send_queue: VecDeque::new(), + inbound_queue: VecDeque::new(), + pending_recv: BTreeMap::new(), + next_send_offset: 0, + next_recv_offset: 0, + outbound_finished: false, + outbound_closed: false, + inbound_finished: false, + inbound_closed: false, + inbound_discarding: false, + ready_enqueued: false, + } + } + + pub fn is_writable(&self) -> bool { + !self.outbound_finished && !self.outbound_closed + } +} + +#[derive(Debug, Clone)] +pub struct PendingSessionBody { + pub body: SessionBody, + pub retransmit: bool, + pub priority: bool, +} + +#[derive(Debug, Clone, Default)] +pub struct PendingSessionControl { + pub heartbeat: bool, + pub unpair: bool, + pub close: Option, +} + +#[derive(Debug, Clone)] +pub struct TxEntry { + pub pending: PendingSessionBody, + pub sent_at: Instant, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AckState { + Idle, + Delayed { due_at: Instant }, + Immediate, +} + +pub struct SessionFsmState { + pub now: Instant, + pub session_state: SessionState, + pub next_stream_ordinal: u32, + pub next_seq: SessionSeq, + pub tx_ring: SeqRing, + pub rx_ring: SeqRing, + pub ack_state: AckState, + pub pending_control: PendingSessionControl, + pub streams: HashMap, + pub ready_streams: VecDeque, + pub events: VecDeque, +} + +impl SessionFsmState { + pub fn current_ack(&self) -> SessionAck { + SessionAck { + base: SessionSeq(self.rx_ring.base_seq().0.saturating_sub(1)), + bitmap: self.rx_ring.bitmap(), + } + } + + pub fn clear_ack_schedule(&mut self) { + self.ack_state = AckState::Idle; + } +} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs new file mode 100644 index 00000000..e1dbe3cb --- /dev/null +++ b/ql-fsm/src/session/tests.rs @@ -0,0 +1,158 @@ +use std::time::{Duration, Instant}; + +use ql_wire::{ + encrypted::heartbeat::HeartbeatBody, CloseCode, CloseTarget, SessionAck, SessionBody, + SessionEnvelope, SessionSeq, StreamFrame, +}; + +use super::{SessionFsm, SessionFsmConfig, SessionState}; + +fn heartbeat(seq: u64, ack: SessionAck) -> SessionEnvelope { + SessionEnvelope { + seq: SessionSeq(seq), + ack, + body: SessionBody::Heartbeat(HeartbeatBody), + } +} + +#[test] +fn outbound_session_seq_increments_monotonically() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let stream_id = fsm.open_stream().unwrap(); + + fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); + let first = fsm.next_outbound(now).unwrap(); + + fsm.write_stream(stream_id, b"two".to_vec()).unwrap(); + let second = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); + + assert_eq!(first.seq, SessionSeq(1)); + assert_eq!(second.seq, SessionSeq(2)); +} + +#[test] +fn inbound_ack_removes_acked_tx_entries() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let stream_id = fsm.open_stream().unwrap(); + + fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); + let first = fsm.next_outbound(now).unwrap(); + assert_eq!(first.seq, SessionSeq(1)); + assert!(fsm.state.tx_ring.contains_key(&SessionSeq(1))); + + fsm.receive( + now + Duration::from_millis(1), + heartbeat( + 1, + SessionAck { + base: SessionSeq(1), + bitmap: 0, + }, + ), + ); + + assert!(!fsm.state.tx_ring.contains_key(&SessionSeq(1))); +} + +#[test] +fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + + fsm.receive(now, heartbeat(2, SessionAck::EMPTY)); + let gap_ack = fsm.next_outbound(now).unwrap(); + assert_eq!(gap_ack.seq, SessionSeq(1)); + assert_eq!( + gap_ack.ack, + SessionAck { + base: SessionSeq(0), + bitmap: 0b10, + } + ); + + fsm.receive( + now + Duration::from_millis(1), + heartbeat(1, SessionAck::EMPTY), + ); + let contiguous_ack = fsm.next_outbound(now + Duration::from_millis(10)).unwrap(); + assert_eq!(contiguous_ack.seq, SessionSeq(2)); + assert_eq!( + contiguous_ack.ack, + SessionAck { + base: SessionSeq(2), + bitmap: 0, + } + ); +} + +#[test] +fn retransmit_requeues_body_with_new_session_seq() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let stream_id = fsm.open_stream().unwrap(); + + fsm.write_stream(stream_id, b"retry-me".to_vec()).unwrap(); + let first = fsm.next_outbound(now).unwrap(); + + let retransmit_at = now + Duration::from_millis(200); + let retried = fsm.next_outbound(retransmit_at).unwrap(); + + assert_eq!(first.seq, SessionSeq(1)); + assert_eq!(retried.seq, SessionSeq(2)); + assert_eq!(retried.body, first.body); +} + +#[test] +fn repeated_outbound_messages_keep_reporting_latest_receive_ack() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let stream_id = fsm.open_stream().unwrap(); + + fsm.receive(now, heartbeat(1, SessionAck::EMPTY)); + + fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); + let first = fsm.next_outbound(now).unwrap(); + + fsm.write_stream(stream_id, b"two".to_vec()).unwrap(); + let second = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); + + assert_eq!(first.ack.base, SessionSeq(1)); + assert_eq!(second.ack.base, SessionSeq(1)); + assert_eq!(first.ack.bitmap, 0); + assert_eq!(second.ack.bitmap, 0); +} + +#[test] +fn local_inbound_close_ignores_late_remote_bytes() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let stream_id = fsm.open_stream().unwrap(); + + fsm.close_stream( + stream_id, + CloseTarget::Response, + CloseCode::CANCELLED, + Vec::new(), + ) + .unwrap(); + + fsm.receive( + now, + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamFrame { + stream_id, + offset: 0, + bytes: b"late".to_vec(), + fin: false, + }), + }, + ); + + assert_eq!(fsm.session_state(), SessionState::Open); + assert!(fsm.take_next_inbound(stream_id).is_none()); + assert!(fsm.take_next_event().is_none()); +} diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs new file mode 100644 index 00000000..5e9b108d --- /dev/null +++ b/ql-fsm/src/state.rs @@ -0,0 +1,177 @@ +use std::{ + collections::VecDeque, + time::{Duration, Instant}, +}; + +use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey}; +use ql_wire::{ + handshake::{Confirm, Hello, HelloReply, Ready, ResponderSecrets}, + QlIdentity, QlRecord, WireError, XID, +}; +use thiserror::Error; + +use crate::{replay_cache::ReplayCache, FsmTime}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Peer { + pub xid: XID, + pub signing_key: MLDSAPublicKey, + pub encapsulation_key: MLKEMPublicKey, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PeerStatus { + Disconnected, + Initiator, + Responder, + Connected, +} + +#[derive(Debug, Clone)] +pub enum QlFsmEvent { + PersistPeer(Peer), + ClearPeer, + PeerStatusChanged { peer: XID, status: PeerStatus }, +} + +#[derive(Debug, Clone, Copy)] +pub struct QlFsmConfig { + pub handshake_timeout: Duration, + pub handshake_retry_interval: Duration, + pub max_handshake_retries: u8, + pub control_expiration: Duration, +} + +impl Default for QlFsmConfig { + fn default() -> Self { + Self { + handshake_timeout: Duration::from_secs(5), + handshake_retry_interval: Duration::from_millis(750), + max_handshake_retries: 3, + control_expiration: Duration::from_secs(30), + } + } +} + +#[derive(Debug, Clone)] +pub enum HandshakeInitiator { + WaitingHelloReply { + initiator_secret: SymmetricKey, + retry_count: u8, + retry_at: Option, + }, + WaitingReady { + reply: HelloReply, + confirm: Confirm, + session_key: SymmetricKey, + retry_count: u8, + retry_at: Option, + }, +} + +#[derive(Debug, Clone)] +pub enum HandshakeResponder { + WaitingConfirm { + secrets: ResponderSecrets, + retry_count: u8, + retry_at: Option, + }, +} + +#[derive(Debug, Clone)] +pub struct RecentReady { + pub hello: Hello, + pub reply: HelloReply, + pub ready: Ready, + pub expires_at: Instant, +} + +#[derive(Debug, Clone)] +pub enum PeerSession { + Disconnected, + Initiator { + hello: Hello, + deadline: Instant, + stage: HandshakeInitiator, + }, + Responder { + hello: Hello, + reply: HelloReply, + deadline: Instant, + stage: HandshakeResponder, + }, + Connected { + session_key: SymmetricKey, + recent_ready: Option, + }, +} + +impl PeerSession { + pub fn status(&self) -> PeerStatus { + match self { + Self::Disconnected => PeerStatus::Disconnected, + Self::Initiator { .. } => PeerStatus::Initiator, + Self::Responder { .. } => PeerStatus::Responder, + Self::Connected { .. } => PeerStatus::Connected, + } + } + + pub fn session_key(&self) -> Option<&SymmetricKey> { + match self { + Self::Connected { session_key, .. } => Some(session_key), + _ => None, + } + } +} + +#[derive(Debug, Clone)] +pub struct PeerRecord { + pub peer: Peer, + pub session: PeerSession, +} + +impl PeerRecord { + pub fn new(peer: Peer) -> Self { + Self { + peer, + session: PeerSession::Disconnected, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum QlFsmError { + #[error("invalid payload")] + InvalidPayload, + #[error("invalid signature")] + InvalidSignature, + #[error("expired")] + Expired, + #[error("no peer bound")] + NoPeerBound, +} + +impl From for QlFsmError { + fn from(value: WireError) -> Self { + match value { + WireError::InvalidPayload => Self::InvalidPayload, + WireError::InvalidSignature => Self::InvalidSignature, + WireError::Expired => Self::Expired, + } + } +} + +pub struct QlFsm { + pub config: QlFsmConfig, + pub identity: QlIdentity, + pub peer: Option, + pub state: QlFsmState, +} + +pub struct QlFsmState { + pub replay_cache: ReplayCache, + pub next_control_id: u32, + pub outbound: VecDeque, + pub events: VecDeque, + pub now: FsmTime, +} diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 4c7eb536..40238f27 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -8,12 +8,12 @@ use std::{ use futures_lite::future::poll_fn; use crate::{ - engine::{Engine, EngineInput, EngineOutput, WriteId}, command::RuntimeCommand, + engine::{Engine, EngineEventSink, WriteId}, handle::{InboundByteStream, InboundStream, OutboundByteStream}, platform::{PlatformFuture, QlPlatform}, wire::stream::{BodyChunk, CloseCode, CloseTarget}, - HandlerEvent, InboundEvent, OpenedStreamDelivery, QlError, Runtime, StreamId, + HandlerEvent, InboundEvent, OpenedStreamDelivery, Peer, QlError, Runtime, StreamId, }; struct InFlightWrite<'a> { @@ -21,6 +21,22 @@ struct InFlightWrite<'a> { future: PlatformFuture<'a, Result<(), QlError>>, } +enum PendingAction { + CloseStream { + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + }, + OutboundData { + stream_id: StreamId, + bytes: Vec, + }, + OutboundFinished { + stream_id: StreamId, + }, +} + enum DriverEvent { Command(RuntimeCommand), WriteCompleted { @@ -51,7 +67,7 @@ impl OutboundIo { *self = Self::Closed; } - fn poll_pending(&mut self, stream_id: StreamId, pending_inputs: &mut VecDeque) { + fn poll_pending(&mut self, stream_id: StreamId, pending_inputs: &mut VecDeque) { let Self::Open { reader, finish_queued, @@ -66,13 +82,13 @@ impl OutboundIo { let read = reader.try_drain(&mut bytes); if read > 0 { bytes.truncate(read); - pending_inputs.push_back(EngineInput::OutboundData { stream_id, bytes }); + pending_inputs.push_back(PendingAction::OutboundData { stream_id, bytes }); } } if reader.is_closed() && !*finish_queued { *finish_queued = true; - pending_inputs.push_back(EngineInput::OutboundFinished { stream_id }); + pending_inputs.push_back(PendingAction::OutboundFinished { stream_id }); } } } @@ -92,9 +108,9 @@ impl InboundIo { stream_id: StreamId, target: CloseTarget, bytes: Vec, - ) -> Option { + ) -> Option { let Self::Open(tx) = self else { - return Some(EngineInput::CloseStream { + return Some(PendingAction::CloseStream { stream_id, target, code: CloseCode::CANCELLED, @@ -104,7 +120,7 @@ impl InboundIo { if tx.try_send(InboundEvent::Data(bytes)).is_err() { tx.close(); *self = Self::Closed; - return Some(EngineInput::CloseStream { + return Some(PendingAction::CloseStream { stream_id, target, code: CloseCode::CANCELLED, @@ -137,22 +153,6 @@ impl InboundIo { } *self = Self::Closed; } - - fn apply_prefix( - &mut self, - stream_id: StreamId, - target: CloseTarget, - prefix: &BodyChunk, - ) -> Option { - let mut input = None; - if !prefix.bytes.is_empty() { - input = self.write_or_close(stream_id, target, prefix.bytes.clone()); - } - if prefix.fin { - self.finish(); - } - input - } } enum DriverStreamIo { @@ -167,7 +167,7 @@ enum DriverStreamIo { } impl DriverStreamIo { - fn poll_pending(&mut self, stream_id: StreamId, pending_inputs: &mut VecDeque) { + fn poll_pending(&mut self, stream_id: StreamId, pending_inputs: &mut VecDeque) { match self { Self::Initiator { request, .. } => request.poll_pending(stream_id, pending_inputs), Self::Responder { response, .. } => response.poll_pending(stream_id, pending_inputs), @@ -209,9 +209,147 @@ impl DriverStreamIo { } } +struct DriverEventSink<'a, P> { + platform: &'a P, + runtime_tx: &'a async_channel::Sender, + stream_send_buffer_bytes: usize, + pending_inputs: &'a mut VecDeque, + streams: &'a mut HashMap, +} + +impl<'a, P> DriverEventSink<'a, P> { + fn new( + platform: &'a P, + runtime_tx: &'a async_channel::Sender, + stream_send_buffer_bytes: usize, + pending_inputs: &'a mut VecDeque, + streams: &'a mut HashMap, + ) -> Self { + Self { + platform, + runtime_tx, + stream_send_buffer_bytes, + pending_inputs, + streams, + } + } +} + +impl EngineEventSink for DriverEventSink<'_, P> { + fn peer_status_changed( + &mut self, + peer: bc_components::XID, + session: crate::engine::PeerSession, + ) { + self.platform.handle_peer_status(peer, &session); + } + + fn persist_peer(&mut self, peer: Peer) { + self.platform.persist_peer(peer); + } + + fn clear_peer(&mut self) { + self.platform.clear_peer(); + } + + fn inbound_stream_opened( + &mut self, + stream_id: StreamId, + request_head: Vec, + request_prefix: Option, + ) { + let (request_tx, request_rx) = async_channel::unbounded(); + let mut request = InboundIo::new(request_tx); + if let Some(prefix) = request_prefix.as_ref() { + if !prefix.bytes.is_empty() { + let InboundIo::Open(tx) = &request else { + unreachable!("fresh inbound stream must be open"); + }; + tx.try_send(InboundEvent::Data(prefix.bytes.clone())) + .expect("new inbound stream prefix send should succeed"); + } + if prefix.fin { + request.finish(); + } + } + + let (response_reader, response_writer) = piper::pipe(self.stream_send_buffer_bytes); + self.streams.insert( + stream_id, + DriverStreamIo::Responder { + request, + response: OutboundIo::new(response_reader), + }, + ); + + self.platform + .handle_inbound(HandlerEvent::Stream(InboundStream { + stream_id, + request_head, + request: InboundByteStream::new( + stream_id, + CloseTarget::Request, + request_rx, + self.runtime_tx.clone(), + ), + response: OutboundByteStream::new( + stream_id, + CloseTarget::Response, + response_writer, + self.runtime_tx.clone(), + ), + })); + } + + fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let target = stream.inbound_target(); + let inbound = stream.inbound_mut(); + if let Some(input) = inbound.write_or_close(stream_id, target, bytes) { + self.pending_inputs.push_back(input); + } + } + + fn inbound_finished(&mut self, stream_id: StreamId) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.inbound_mut().finish(); + } + + fn inbound_failed(&mut self, stream_id: StreamId, error: QlError) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.inbound_mut().fail(error); + } + + fn outbound_closed(&mut self, stream_id: StreamId) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.outbound_mut().close(); + } + + fn outbound_failed(&mut self, stream_id: StreamId, _error: QlError) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.outbound_mut().close(); + } + + fn stream_reaped(&mut self, stream_id: StreamId) { + if let Some(mut stream) = self.streams.remove(&stream_id) { + stream.close_all(); + } + } +} + struct DriverState { engine: Engine, - pending_inputs: VecDeque, + pending_inputs: VecDeque, streams: HashMap, runtime_tx: async_channel::Sender, stream_send_buffer_bytes: usize, @@ -227,19 +365,56 @@ impl DriverState { ) { match command { RuntimeCommand::BindPeer { peer } => { - self.drive_input(EngineInput::BindPeer(peer), platform, in_flight); + let now = Instant::now(); + let mut events = DriverEventSink::new( + platform, + &self.runtime_tx, + self.stream_send_buffer_bytes, + &mut self.pending_inputs, + &mut self.streams, + ); + self.engine.bind_peer(now, peer, &mut events); + self.finish_step(platform, in_flight); } RuntimeCommand::Pair => { - self.drive_input(EngineInput::Pair, platform, in_flight); + self.engine.pair(Instant::now(), platform); + self.finish_step(platform, in_flight); } RuntimeCommand::Connect => { - self.drive_input(EngineInput::Connect, platform, in_flight); + let now = Instant::now(); + let mut events = DriverEventSink::new( + platform, + &self.runtime_tx, + self.stream_send_buffer_bytes, + &mut self.pending_inputs, + &mut self.streams, + ); + self.engine.connect(now, platform, &mut events); + self.finish_step(platform, in_flight); } RuntimeCommand::Unpair => { - self.drive_input(EngineInput::Unpair, platform, in_flight); + let now = Instant::now(); + let mut events = DriverEventSink::new( + platform, + &self.runtime_tx, + self.stream_send_buffer_bytes, + &mut self.pending_inputs, + &mut self.streams, + ); + self.engine.unpair(now, &mut events); + self.finish_step(platform, in_flight); } RuntimeCommand::Incoming(bytes) => { - self.drive_input(EngineInput::Incoming(bytes), platform, in_flight); + let now = Instant::now(); + let mut events = DriverEventSink::new( + platform, + &self.runtime_tx, + self.stream_send_buffer_bytes, + &mut self.pending_inputs, + &mut self.streams, + ); + self.engine.receive(now, bytes, platform, &mut events); + self.finish_step(platform, in_flight); } RuntimeCommand::OpenStream { request_head, @@ -282,30 +457,14 @@ impl DriverState { code, payload, } => { - self.drive_input( - EngineInput::CloseStream { - stream_id, - target, - code, - payload, - }, - platform, - in_flight, - ); + let _ = self + .engine + .close_stream(Instant::now(), stream_id, target, code, payload); + self.finish_step(platform, in_flight); } } } - fn drive_input<'a, P: QlPlatform>( - &mut self, - input: EngineInput, - platform: &'a P, - in_flight: &mut Vec>, - ) { - self.pending_inputs.push_back(input); - self.drive_pending(platform, in_flight); - } - fn drive_write_completed<'a, P: QlPlatform>( &mut self, write_id: WriteId, @@ -314,23 +473,18 @@ impl DriverState { in_flight: &mut Vec>, ) { { - let runtime_tx = self.runtime_tx.clone(); - let stream_send_buffer_bytes = self.stream_send_buffer_bytes; - let pending_inputs = &mut self.pending_inputs; - let streams = &mut self.streams; - self.engine.complete_write(write_id, result, &mut |output| { - handle_output( - output, - platform, - &runtime_tx, - stream_send_buffer_bytes, - pending_inputs, - streams, - ) - }); + let now = self.engine.state.now; + let mut events = DriverEventSink::new( + platform, + &self.runtime_tx, + self.stream_send_buffer_bytes, + &mut self.pending_inputs, + &mut self.streams, + ); + self.engine + .complete_write(now, write_id, result, &mut events); } - self.fill_write_slots(platform, in_flight); - self.drive_pending(platform, in_flight); + self.finish_step(platform, in_flight); } fn drive_pending<'a, P: QlPlatform>( @@ -339,22 +493,24 @@ impl DriverState { in_flight: &mut Vec>, ) { while let Some(input) = self.pending_inputs.pop_front() { - { - let runtime_tx = &self.runtime_tx; - let stream_send_buffer_bytes = self.stream_send_buffer_bytes; - let pending_inputs = &mut self.pending_inputs; - let streams = &mut self.streams; - self.engine - .run_tick(Instant::now(), input, platform, &mut |output| { - handle_output( - output, - platform, - runtime_tx, - stream_send_buffer_bytes, - pending_inputs, - streams, - ) - }); + let now = Instant::now(); + match input { + PendingAction::CloseStream { + stream_id, + target, + code, + payload, + } => { + let _ = self + .engine + .close_stream(now, stream_id, target, code, payload); + } + PendingAction::OutboundData { stream_id, bytes } => { + let _ = self.engine.write_stream(now, stream_id, bytes); + } + PendingAction::OutboundFinished { stream_id } => { + let _ = self.engine.finish_stream(now, stream_id); + } } self.fill_write_slots(platform, in_flight); } @@ -362,13 +518,39 @@ impl DriverState { self.fill_write_slots(platform, in_flight); } + fn drive_timer<'a, P: QlPlatform>( + &mut self, + platform: &'a P, + in_flight: &mut Vec>, + ) { + let now = Instant::now(); + let mut events = DriverEventSink::new( + platform, + &self.runtime_tx, + self.stream_send_buffer_bytes, + &mut self.pending_inputs, + &mut self.streams, + ); + self.engine.on_timer(now, platform, &mut events); + self.finish_step(platform, in_flight); + } + + fn finish_step<'a, P: QlPlatform>( + &mut self, + platform: &'a P, + in_flight: &mut Vec>, + ) { + self.fill_write_slots(platform, in_flight); + self.drive_pending(platform, in_flight); + } + fn fill_write_slots<'a, P: QlPlatform>( &mut self, platform: &'a P, in_flight: &mut Vec>, ) { while in_flight.len() < self.max_concurrent_message_writes { - let Some(write) = self.engine.take_next_write(platform) else { + let Some(write) = self.engine.take_next_write(self.engine.state.now, platform) else { break; }; in_flight.push(InFlightWrite { @@ -386,96 +568,6 @@ impl DriverState { } } -fn handle_output( - output: EngineOutput, - platform: &P, - runtime_tx: &async_channel::Sender, - stream_send_buffer_bytes: usize, - pending_inputs: &mut VecDeque, - streams: &mut HashMap, -) { - match output { - EngineOutput::PeerStatusChanged { peer, session } => { - platform.handle_peer_status(peer, &session); - } - EngineOutput::PersistPeer(peer) => platform.persist_peer(peer), - EngineOutput::ClearPeer => platform.clear_peer(), - EngineOutput::InboundStreamOpened { - stream_id, - request_head, - request_prefix, - } => { - let (request_tx, request_rx) = async_channel::unbounded(); - let mut request = InboundIo::new(request_tx); - if let Some(prefix) = request_prefix.as_ref() { - if let Some(input) = request.apply_prefix(stream_id, CloseTarget::Request, prefix) { - pending_inputs.push_back(input); - } - } - - let (response_reader, response_writer) = piper::pipe(stream_send_buffer_bytes); - streams.insert( - stream_id, - DriverStreamIo::Responder { - request, - response: OutboundIo::new(response_reader), - }, - ); - - platform.handle_inbound(HandlerEvent::Stream(InboundStream { - stream_id, - request_head, - request: InboundByteStream::new( - stream_id, - CloseTarget::Request, - request_rx, - runtime_tx.clone(), - ), - response: OutboundByteStream::new( - stream_id, - CloseTarget::Response, - response_writer, - runtime_tx.clone(), - ), - })); - } - EngineOutput::InboundData { stream_id, bytes } => { - let Some(stream) = streams.get_mut(&stream_id) else { - return; - }; - let target = stream.inbound_target(); - let inbound = stream.inbound_mut(); - if let Some(input) = inbound.write_or_close(stream_id, target, bytes) { - pending_inputs.push_back(input); - } - } - EngineOutput::InboundFinished { stream_id } => { - let Some(stream) = streams.get_mut(&stream_id) else { - return; - }; - stream.inbound_mut().finish(); - } - EngineOutput::InboundFailed { stream_id, error } => { - let Some(stream) = streams.get_mut(&stream_id) else { - return; - }; - stream.inbound_mut().fail(error); - } - EngineOutput::OutboundClosed { stream_id } - | EngineOutput::OutboundFailed { stream_id, .. } => { - let Some(stream) = streams.get_mut(&stream_id) else { - return; - }; - stream.outbound_mut().close(); - } - EngineOutput::StreamReaped { stream_id } => { - if let Some(mut stream) = streams.remove(&stream_id) { - stream.close_all(); - } - } - } -} - async fn next_driver_event( rx: &async_channel::Receiver, platform: &P, @@ -555,7 +647,7 @@ impl Runtime

{ state.drive_write_completed(write_id, result, &platform, &mut in_flight); } DriverEvent::TimerExpired => { - state.drive_input(EngineInput::TimerExpired, &platform, &mut in_flight); + state.drive_timer(&platform, &mut in_flight); } DriverEvent::Closed => break, } diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index e7c3e65f..30aa2ff8 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -20,8 +20,8 @@ use crate::{ HandlerEvent, KeepAliveConfig, Peer, PeerSession, QlError, RuntimeConfig, RuntimeHandle, }; -mod heartbeat; mod handshake; +mod heartbeat; mod stream; mod unpair; @@ -93,13 +93,7 @@ impl TestPlatform { seed: u8, fail_stream_write_at: usize, ) -> (Self, Receiver>, Receiver) { - Self::new_inner( - seed, - None, - Some(fail_stream_write_at), - Duration::ZERO, - None, - ) + Self::new_inner(seed, None, Some(fail_stream_write_at), Duration::ZERO, None) } fn new_with_delayed_writes( diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml new file mode 100644 index 00000000..eb1a76a7 --- /dev/null +++ b/ql-wire/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "ql-wire" +version = "0.1.0" +edition = "2021" +description = "Quantum Link wire format types and crypto helpers" +license = "Proprietary" + +[dependencies] +bc-components = { version = "0.28.0", default-features = false, features = [ + "pqcrypto", +] } +chacha20poly1305 = { version = "0.10.1" } +rkyv = { version = "0.8", default-features = false, features = [ + "std", + "bytecheck", + "little_endian", + "unaligned", + "pointer_width_32", +] } +thiserror = { version = "2" } diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs new file mode 100644 index 00000000..ea1d76f0 --- /dev/null +++ b/ql-wire/src/codec.rs @@ -0,0 +1,280 @@ +use bc_components::{ + MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey, MLDSA, MLKEM, +}; +use rkyv::{ + rancor::{Fallible, Source}, + with::{ArchiveWith, DeserializeWith, SerializeWith}, + Archive, Archived, Deserialize, Place, Resolver, Serialize, +}; + +use crate::WireError; + +macro_rules! impl_wire_wrapper { + ($marker:ident, $external:ty, $wire:ty) => { + pub(crate) struct $marker; + + impl ArchiveWith<$external> for $marker { + type Archived = Archived<$wire>; + type Resolver = Resolver<$wire>; + + fn resolve_with( + field: &$external, + resolver: Self::Resolver, + out: Place, + ) { + <$wire>::from(field).resolve(resolver, out); + } + } + + impl SerializeWith<$external, S> for $marker + where + S: Fallible + ?Sized, + $wire: Serialize, + { + fn serialize_with( + field: &$external, + serializer: &mut S, + ) -> Result { + <$wire>::from(field).serialize(serializer) + } + } + + impl DeserializeWith, $external, D> for $marker + where + D: Fallible + ?Sized, + D::Error: Source, + Archived<$wire>: Deserialize<$wire, D>, + $wire: TryInto<$external, Error = WireError>, + { + fn deserialize_with( + field: &Archived<$wire>, + deserializer: &mut D, + ) -> Result<$external, D::Error> { + field + .deserialize(deserializer)? + .try_into() + .map_err(D::Error::new) + } + } + }; +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] +pub(crate) enum WireMlDsaLevel { + MlDsa44 = 2, + MlDsa65 = 3, + MlDsa87 = 5, +} + +impl TryFrom for MLDSA { + type Error = WireError; + + fn try_from(value: WireMlDsaLevel) -> Result { + Ok(match value { + WireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, + WireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, + WireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, + }) + } +} + +impl From for WireMlDsaLevel { + fn from(value: MLDSA) -> Self { + match value { + MLDSA::MLDSA44 => Self::MlDsa44, + MLDSA::MLDSA65 => Self::MlDsa65, + MLDSA::MLDSA87 => Self::MlDsa87, + } + } +} + +impl From<&ArchivedWireMlDsaLevel> for MLDSA { + fn from(value: &ArchivedWireMlDsaLevel) -> Self { + match value { + ArchivedWireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, + ArchivedWireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, + ArchivedWireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] +pub(crate) enum WireMlKemLevel { + MlKem512 = 1, + MlKem768 = 2, + MlKem1024 = 3, +} + +impl TryFrom for MLKEM { + type Error = WireError; + + fn try_from(value: WireMlKemLevel) -> Result { + Ok(match value { + WireMlKemLevel::MlKem512 => MLKEM::MLKEM512, + WireMlKemLevel::MlKem768 => MLKEM::MLKEM768, + WireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, + }) + } +} + +impl From for WireMlKemLevel { + fn from(value: MLKEM) -> Self { + match value { + MLKEM::MLKEM512 => Self::MlKem512, + MLKEM::MLKEM768 => Self::MlKem768, + MLKEM::MLKEM1024 => Self::MlKem1024, + } + } +} + +impl From<&ArchivedWireMlKemLevel> for MLKEM { + fn from(value: &ArchivedWireMlKemLevel) -> Self { + match value { + ArchivedWireMlKemLevel::MlKem512 => MLKEM::MLKEM512, + ArchivedWireMlKemLevel::MlKem768 => MLKEM::MLKEM768, + ArchivedWireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct WireMlDsaPublicKey { + pub(crate) level: WireMlDsaLevel, + pub(crate) bytes: Vec, +} + +impl TryFrom for MLDSAPublicKey { + type Error = WireError; + + fn try_from(value: WireMlDsaPublicKey) -> Result { + MLDSAPublicKey::from_bytes(value.level.try_into()?, &value.bytes) + .map_err(|_| WireError::InvalidPayload) + } +} + +impl From<&MLDSAPublicKey> for WireMlDsaPublicKey { + fn from(value: &MLDSAPublicKey) -> Self { + Self { + level: value.level().into(), + bytes: value.as_bytes().to_vec(), + } + } +} + +impl TryFrom<&ArchivedWireMlDsaPublicKey> for MLDSAPublicKey { + type Error = WireError; + + fn try_from(value: &ArchivedWireMlDsaPublicKey) -> Result { + MLDSAPublicKey::from_bytes((&value.level).into(), value.bytes.as_slice()) + .map_err(|_| WireError::InvalidPayload) + } +} + +impl_wire_wrapper!(AsWireMlDsaPublicKey, MLDSAPublicKey, WireMlDsaPublicKey); + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct WireMlDsaSignature { + pub(crate) level: WireMlDsaLevel, + pub(crate) bytes: Vec, +} + +impl TryFrom for MLDSASignature { + type Error = WireError; + + fn try_from(value: WireMlDsaSignature) -> Result { + MLDSASignature::from_bytes(value.level.try_into()?, &value.bytes) + .map_err(|_| WireError::InvalidPayload) + } +} + +impl From<&MLDSASignature> for WireMlDsaSignature { + fn from(value: &MLDSASignature) -> Self { + Self { + level: value.level().into(), + bytes: value.as_bytes().to_vec(), + } + } +} + +impl TryFrom<&ArchivedWireMlDsaSignature> for MLDSASignature { + type Error = WireError; + + fn try_from(value: &ArchivedWireMlDsaSignature) -> Result { + MLDSASignature::from_bytes((&value.level).into(), value.bytes.as_slice()) + .map_err(|_| WireError::InvalidPayload) + } +} + +impl_wire_wrapper!(AsWireMlDsaSignature, MLDSASignature, WireMlDsaSignature); + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct WireMlKemPublicKey { + pub(crate) level: WireMlKemLevel, + pub(crate) bytes: Vec, +} + +impl TryFrom for MLKEMPublicKey { + type Error = WireError; + + fn try_from(value: WireMlKemPublicKey) -> Result { + MLKEMPublicKey::from_bytes(value.level.try_into()?, &value.bytes) + .map_err(|_| WireError::InvalidPayload) + } +} + +impl From<&MLKEMPublicKey> for WireMlKemPublicKey { + fn from(value: &MLKEMPublicKey) -> Self { + Self { + level: value.level().into(), + bytes: value.as_bytes().to_vec(), + } + } +} + +impl TryFrom<&ArchivedWireMlKemPublicKey> for MLKEMPublicKey { + type Error = WireError; + + fn try_from(value: &ArchivedWireMlKemPublicKey) -> Result { + MLKEMPublicKey::from_bytes((&value.level).into(), value.bytes.as_slice()) + .map_err(|_| WireError::InvalidPayload) + } +} + +impl_wire_wrapper!(AsWireMlKemPublicKey, MLKEMPublicKey, WireMlKemPublicKey); + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct WireMlKemCiphertext { + pub(crate) level: WireMlKemLevel, + pub(crate) bytes: Vec, +} + +impl TryFrom for MLKEMCiphertext { + type Error = WireError; + + fn try_from(value: WireMlKemCiphertext) -> Result { + MLKEMCiphertext::from_bytes(value.level.try_into()?, &value.bytes) + .map_err(|_| WireError::InvalidPayload) + } +} + +impl From<&MLKEMCiphertext> for WireMlKemCiphertext { + fn from(value: &MLKEMCiphertext) -> Self { + Self { + level: value.level().into(), + bytes: value.as_bytes().to_vec(), + } + } +} + +impl TryFrom<&ArchivedWireMlKemCiphertext> for MLKEMCiphertext { + type Error = WireError; + + fn try_from(value: &ArchivedWireMlKemCiphertext) -> Result { + MLKEMCiphertext::from_bytes((&value.level).into(), value.bytes.as_slice()) + .map_err(|_| WireError::InvalidPayload) + } +} + +impl_wire_wrapper!(AsWireMlKemCiphertext, MLKEMCiphertext, WireMlKemCiphertext); diff --git a/ql-wire/src/encrypted/close/mod.rs b/ql-wire/src/encrypted/close/mod.rs new file mode 100644 index 00000000..02c75aaa --- /dev/null +++ b/ql-wire/src/encrypted/close/mod.rs @@ -0,0 +1,8 @@ +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::encrypted::stream::CloseCode; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct SessionCloseBody { + pub code: CloseCode, +} diff --git a/ql-wire/src/encrypted/heartbeat/mod.rs b/ql-wire/src/encrypted/heartbeat/mod.rs new file mode 100644 index 00000000..2bdd15f9 --- /dev/null +++ b/ql-wire/src/encrypted/heartbeat/mod.rs @@ -0,0 +1,4 @@ +use rkyv::{Archive, Deserialize, Serialize}; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct HeartbeatBody; diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs new file mode 100644 index 00000000..45680ed7 --- /dev/null +++ b/ql-wire/src/encrypted/mod.rs @@ -0,0 +1,68 @@ +use bc_components::SymmetricKey; +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::{ + access_value, deserialize_value, encode_value, + encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage}, + Nonce, QlHeader, QlPayload, QlRecord, SessionSeq, WireError, +}; + +pub mod close; +pub mod heartbeat; +pub mod stream; +pub mod unpair; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct SessionEnvelope { + pub seq: SessionSeq, + pub ack: SessionAck, + pub body: SessionBody, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct SessionAck { + pub base: SessionSeq, + pub bitmap: u64, +} + +impl SessionAck { + pub const EMPTY: Self = Self { + base: SessionSeq(0), + bitmap: 0, + }; +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub enum SessionBody { + Heartbeat(heartbeat::HeartbeatBody), + Unpair(unpair::UnpairBody), + Stream(stream::StreamFrame), + StreamClose(stream::StreamCloseFrame), + Close(close::SessionCloseBody), +} + +pub fn encrypt_record( + header: QlHeader, + session_key: &SymmetricKey, + body: &SessionEnvelope, + nonce: Nonce, +) -> QlRecord { + let aad = header.aad(); + let body_bytes = encode_value(body); + let encrypted = EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce); + QlRecord { + header, + payload: QlPayload::Encrypted(encrypted), + } +} + +pub fn decrypt_record( + header: &QlHeader, + encrypted: &mut ArchivedEncryptedMessage, + session_key: &SymmetricKey, +) -> Result { + let aad = header.aad(); + let plaintext = encrypted.decrypt(session_key, &aad)?; + let body = access_value::(plaintext)?; + deserialize_value(body) +} diff --git a/ql-wire/src/encrypted/stream/mod.rs b/ql-wire/src/encrypted/stream/mod.rs new file mode 100644 index 00000000..a21712ac --- /dev/null +++ b/ql-wire/src/encrypted/stream/mod.rs @@ -0,0 +1,60 @@ +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::StreamId; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct StreamFrame { + pub stream_id: StreamId, + pub offset: u64, + pub bytes: Vec, + pub fin: bool, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct StreamCloseFrame { + pub stream_id: StreamId, + pub target: CloseTarget, + pub code: CloseCode, + pub payload: Vec, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum CloseTarget { + Request = 1, + Response = 2, + Both = 3, +} + +impl From<&ArchivedCloseTarget> for CloseTarget { + fn from(value: &ArchivedCloseTarget) -> Self { + match value { + ArchivedCloseTarget::Request => Self::Request, + ArchivedCloseTarget::Response => Self::Response, + ArchivedCloseTarget::Both => Self::Both, + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct CloseCode(pub u16); + +impl CloseCode { + pub const CANCELLED: Self = Self(0); + pub const PROTOCOL: Self = Self(1); + pub const INVALID_DATA: Self = Self(2); + pub const TIMEOUT: Self = Self(3); + + pub const UNKNOWN: Self = Self(16); + pub const UNKNOWN_ROUTE: Self = Self(17); + pub const INVALID_HEAD: Self = Self(18); + pub const BUSY: Self = Self(19); + pub const UNHANDLED: Self = Self(20); +} + +impl From<&ArchivedCloseCode> for CloseCode { + fn from(value: &ArchivedCloseCode) -> Self { + Self(value.0.to_native()) + } +} diff --git a/ql-wire/src/encrypted/unpair/mod.rs b/ql-wire/src/encrypted/unpair/mod.rs new file mode 100644 index 00000000..70a65e63 --- /dev/null +++ b/ql-wire/src/encrypted/unpair/mod.rs @@ -0,0 +1,4 @@ +use rkyv::{Archive, Deserialize, Serialize}; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct UnpairBody; diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs new file mode 100644 index 00000000..bd29c3e2 --- /dev/null +++ b/ql-wire/src/encrypted_message.rs @@ -0,0 +1,63 @@ +use bc_components::SymmetricKey; +use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit}; +use rkyv::{seal::Seal, vec::ArchivedVec, Archive, Deserialize, Serialize}; + +use crate::WireError; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Nonce(pub [u8; Self::NONCE_SIZE]); + +impl Nonce { + pub const NONCE_SIZE: usize = 12; +} + +pub const AUTH_SIZE: usize = 16; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct EncryptedMessage { + ciphertext: Vec, + nonce: Nonce, + auth: [u8; AUTH_SIZE], +} + +impl EncryptedMessage { + pub fn encrypt(key: &SymmetricKey, mut plaintext: Vec, aad: &[u8], nonce: Nonce) -> Self { + let cipher = ChaCha20Poly1305::new(key.data().into()); + let auth = cipher + .encrypt_in_place_detached((&nonce.0).into(), aad, &mut plaintext) + .expect("chacha20poly1305 encryption should succeed"); + Self { + ciphertext: plaintext, + nonce, + auth: auth.into(), + } + } + + pub fn decrypt(&self, key: &SymmetricKey, aad: &[u8]) -> Result, WireError> { + let cipher = ChaCha20Poly1305::new(key.data().into()); + let mut plaintext = self.ciphertext.clone(); + cipher + .decrypt_in_place_detached( + (&self.nonce.0).into(), + aad, + &mut plaintext, + (&self.auth).into(), + ) + .map_err(|_| WireError::InvalidPayload)?; + Ok(plaintext) + } +} + +impl ArchivedEncryptedMessage { + pub fn decrypt(&mut self, key: &SymmetricKey, aad: &[u8]) -> Result<&[u8], WireError> { + let cipher = ChaCha20Poly1305::new(key.data().into()); + let nonce = &self.nonce; + let auth = self.auth; + let ciphertext = ArchivedVec::as_slice_seal(Seal::new(&mut self.ciphertext)); + let ciphertext = unsafe { ciphertext.unseal_unchecked() }; + cipher + .decrypt_in_place_detached((&nonce.0).into(), aad, ciphertext, (&auth).into()) + .map_err(|_| WireError::InvalidPayload)?; + Ok(ciphertext) + } +} diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs new file mode 100644 index 00000000..46bc850a --- /dev/null +++ b/ql-wire/src/handshake/crypto.rs @@ -0,0 +1,331 @@ +use bc_components::{ + Digest, MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey, SymmetricKey, +}; +use rkyv::{Archive, Serialize}; + +use super::{ + verify_signature, ArchivedConfirm, ArchivedHello, ArchivedHelloReply, ArchivedReady, Confirm, + Hello, HelloReply, Ready, ReadyBody, +}; +use crate::{ + access_value, deserialize_value, encode_value, encrypted_message::EncryptedMessage, + ensure_not_expired, AsWireMlKemCiphertext, ControlMeta, Nonce, QlCrypto, QlHeader, QlIdentity, + WireError, XID, +}; + +#[derive(Archive, Serialize)] +struct HelloProofData { + initiator: XID, + responder: XID, + meta: ControlMeta, + nonce: Nonce, + #[rkyv(with = AsWireMlKemCiphertext)] + kem_ct: bc_components::MLKEMCiphertext, +} + +#[derive(Archive, Serialize)] +struct HandshakeTranscript { + initiator: XID, + responder: XID, + hello_meta: ControlMeta, + initiator_nonce: Nonce, + responder_nonce: Nonce, + reply_meta: ControlMeta, + #[rkyv(with = AsWireMlKemCiphertext)] + initiator_kem_ct: bc_components::MLKEMCiphertext, + #[rkyv(with = AsWireMlKemCiphertext)] + responder_kem_ct: bc_components::MLKEMCiphertext, +} + +#[derive(Archive, Serialize)] +struct ConfirmProofData { + meta: ControlMeta, + transcript: Vec, +} + +#[derive(Archive, Serialize)] +struct SessionKeyMaterial { + initiator_secret: Vec, + responder_secret: Vec, + transcript: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ResponderSecrets { + pub initiator_secret: SymmetricKey, + pub responder_secret: SymmetricKey, +} + +pub fn build_hello( + identity: &QlIdentity, + crypto: &impl QlCrypto, + recipient: XID, + recipient_encapsulation_key: &MLKEMPublicKey, + meta: ControlMeta, +) -> Result<(Hello, SymmetricKey), WireError> { + let nonce = next_nonce(crypto); + let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); + let signature = identity.signing_private_key.sign(hello_proof_data( + identity.xid, + recipient, + &meta, + &nonce, + &kem_ct, + )); + Ok(( + Hello { + meta, + nonce, + kem_ct, + signature, + }, + session_key, + )) +} + +pub fn verify_hello( + initiator: XID, + responder: XID, + initiator_signing_key: &MLDSAPublicKey, + hello: &ArchivedHello, +) -> Result<(), WireError> { + let meta: ControlMeta = (&hello.meta).into(); + ensure_not_expired(meta.valid_until)?; + let signature = MLDSASignature::try_from(&hello.signature)?; + let nonce: Nonce = deserialize_value(&hello.nonce)?; + let kem_ct = MLKEMCiphertext::try_from(&hello.kem_ct)?; + let proof_data = hello_proof_data(initiator, responder, &meta, &nonce, &kem_ct); + verify_signature(initiator_signing_key, &signature, &proof_data) +} + +pub fn respond_hello( + identity: &QlIdentity, + crypto: &impl QlCrypto, + initiator: XID, + initiator_signing_key: &MLDSAPublicKey, + initiator_encapsulation_key: &MLKEMPublicKey, + hello: &ArchivedHello, + meta: ControlMeta, +) -> Result<(HelloReply, ResponderSecrets), WireError> { + verify_hello(initiator, identity.xid, initiator_signing_key, hello)?; + let hello_meta: ControlMeta = (&hello.meta).into(); + let initiator_nonce: Nonce = deserialize_value(&hello.nonce)?; + let initiator_kem_ct = MLKEMCiphertext::try_from(&hello.kem_ct)?; + let initiator_secret = identity + .encapsulation_private_key + .decapsulate_shared_secret(&initiator_kem_ct) + .map_err(|_| WireError::InvalidPayload)?; + let nonce = next_nonce(crypto); + let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(); + let transcript = handshake_transcript( + initiator, + identity.xid, + &hello_meta, + &initiator_nonce, + &initiator_kem_ct, + &meta, + &nonce, + &kem_ct, + ); + let signature = identity.signing_private_key.sign(&transcript); + let reply = HelloReply { + meta, + nonce, + kem_ct, + signature, + }; + Ok(( + reply, + ResponderSecrets { + initiator_secret, + responder_secret, + }, + )) +} + +pub fn build_confirm( + identity: &QlIdentity, + responder: XID, + responder_signing_key: &MLDSAPublicKey, + hello: &Hello, + reply: &ArchivedHelloReply, + initiator_secret: &SymmetricKey, + meta: ControlMeta, +) -> Result<(Confirm, SymmetricKey), WireError> { + let reply_meta: ControlMeta = (&reply.meta).into(); + ensure_not_expired(reply_meta.valid_until)?; + let reply_nonce: Nonce = deserialize_value(&reply.nonce)?; + let reply_kem_ct = MLKEMCiphertext::try_from(&reply.kem_ct)?; + let reply_signature = MLDSASignature::try_from(&reply.signature)?; + let transcript = handshake_transcript( + identity.xid, + responder, + &hello.meta, + &hello.nonce, + &hello.kem_ct, + &reply_meta, + &reply_nonce, + &reply_kem_ct, + ); + verify_signature(responder_signing_key, &reply_signature, &transcript)?; + let responder_secret = identity + .encapsulation_private_key + .decapsulate_shared_secret(&reply_kem_ct) + .map_err(|_| WireError::InvalidPayload)?; + let signature = identity + .signing_private_key + .sign(confirm_proof_data(&meta, &transcript)); + let confirm = Confirm { meta, signature }; + let session_key = derive_session_key(initiator_secret, &responder_secret, &transcript); + Ok((confirm, session_key)) +} + +pub fn finalize_confirm( + initiator: XID, + responder: XID, + initiator_signing_key: &MLDSAPublicKey, + hello: &Hello, + reply: &HelloReply, + confirm: &ArchivedConfirm, + secrets: &ResponderSecrets, +) -> Result { + verify_confirm( + initiator, + responder, + initiator_signing_key, + hello, + reply, + confirm, + )?; + Ok(derive_session_key( + &secrets.initiator_secret, + &secrets.responder_secret, + &handshake_transcript( + initiator, + responder, + &hello.meta, + &hello.nonce, + &hello.kem_ct, + &reply.meta, + &reply.nonce, + &reply.kem_ct, + ), + )) +} + +pub fn verify_confirm( + initiator: XID, + responder: XID, + initiator_signing_key: &MLDSAPublicKey, + hello: &Hello, + reply: &HelloReply, + confirm: &ArchivedConfirm, +) -> Result<(), WireError> { + let confirm_meta: ControlMeta = (&confirm.meta).into(); + ensure_not_expired(confirm_meta.valid_until)?; + let confirm_signature = MLDSASignature::try_from(&confirm.signature)?; + let transcript = handshake_transcript( + initiator, + responder, + &hello.meta, + &hello.nonce, + &hello.kem_ct, + &reply.meta, + &reply.nonce, + &reply.kem_ct, + ); + let proof_data = confirm_proof_data(&confirm_meta, &transcript); + verify_signature(initiator_signing_key, &confirm_signature, &proof_data)?; + Ok(()) +} + +pub fn build_ready( + header: QlHeader, + session_key: &SymmetricKey, + meta: ControlMeta, + nonce: Nonce, +) -> Ready { + let aad = header.aad(); + let body_bytes = encode_value(&ReadyBody { meta }); + Ready { + encrypted: EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce), + } +} + +pub fn decrypt_ready( + header: &QlHeader, + ready: &mut ArchivedReady, + session_key: &SymmetricKey, +) -> Result { + let aad = header.aad(); + let plaintext = ready.encrypted.decrypt(session_key, &aad)?; + let body = access_value::(plaintext)?; + let body = deserialize_value(body)?; + ensure_not_expired(body.meta.valid_until)?; + Ok(body) +} + +fn handshake_transcript( + initiator: XID, + responder: XID, + hello_meta: &ControlMeta, + initiator_nonce: &Nonce, + initiator_kem_ct: &bc_components::MLKEMCiphertext, + reply_meta: &ControlMeta, + responder_nonce: &Nonce, + responder_kem_ct: &bc_components::MLKEMCiphertext, +) -> Vec { + encode_value(&HandshakeTranscript { + initiator, + responder, + hello_meta: *hello_meta, + initiator_nonce: initiator_nonce.clone(), + responder_nonce: responder_nonce.clone(), + reply_meta: *reply_meta, + initiator_kem_ct: initiator_kem_ct.clone(), + responder_kem_ct: responder_kem_ct.clone(), + }) +} + +fn hello_proof_data( + initiator: XID, + responder: XID, + meta: &ControlMeta, + nonce: &Nonce, + kem_ct: &bc_components::MLKEMCiphertext, +) -> Vec { + encode_value(&HelloProofData { + initiator, + responder, + meta: *meta, + nonce: nonce.clone(), + kem_ct: kem_ct.clone(), + }) +} + +fn confirm_proof_data(meta: &ControlMeta, transcript: &[u8]) -> Vec { + encode_value(&ConfirmProofData { + meta: *meta, + transcript: transcript.to_vec(), + }) +} + +fn next_nonce(platform: &impl QlCrypto) -> Nonce { + let mut data = [0u8; Nonce::NONCE_SIZE]; + platform.fill_random_bytes(&mut data); + Nonce(data) +} + +fn derive_session_key( + initiator_secret: &SymmetricKey, + responder_secret: &SymmetricKey, + transcript: &[u8], +) -> SymmetricKey { + let payload = encode_value(&SessionKeyMaterial { + initiator_secret: initiator_secret.as_bytes().to_vec(), + responder_secret: responder_secret.as_bytes().to_vec(), + transcript: transcript.to_vec(), + }); + let digest = Digest::from_image(payload); + SymmetricKey::from_data(*digest.data()) +} diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs new file mode 100644 index 00000000..c40378ae --- /dev/null +++ b/ql-wire/src/handshake/mod.rs @@ -0,0 +1,66 @@ +use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext}; +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::{ + encrypted_message::EncryptedMessage, AsWireMlDsaSignature, AsWireMlKemCiphertext, ControlMeta, + Nonce, WireError, +}; + +mod crypto; +pub use crypto::*; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum HandshakeRecord { + Hello(Hello), + HelloReply(HelloReply), + Confirm(Confirm), + Ready(Ready), +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Hello { + pub meta: ControlMeta, + pub nonce: Nonce, + #[rkyv(with = AsWireMlKemCiphertext)] + pub kem_ct: MLKEMCiphertext, + #[rkyv(with = AsWireMlDsaSignature)] + pub signature: MLDSASignature, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct HelloReply { + pub meta: ControlMeta, + pub nonce: Nonce, + #[rkyv(with = AsWireMlKemCiphertext)] + pub kem_ct: MLKEMCiphertext, + #[rkyv(with = AsWireMlDsaSignature)] + pub signature: MLDSASignature, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Confirm { + pub meta: ControlMeta, + #[rkyv(with = AsWireMlDsaSignature)] + pub signature: MLDSASignature, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Ready { + pub encrypted: EncryptedMessage, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct ReadyBody { + pub meta: ControlMeta, +} + +pub fn verify_signature( + signing_key: &MLDSAPublicKey, + signature: &MLDSASignature, + proof_data: &[u8], +) -> Result<(), WireError> { + match signing_key.verify(signature, proof_data) { + Ok(true) => Ok(()), + _ => Err(WireError::InvalidSignature), + } +} diff --git a/ql-wire/src/id.rs b/ql-wire/src/id.rs new file mode 100644 index 00000000..f236514d --- /dev/null +++ b/ql-wire/src/id.rs @@ -0,0 +1,51 @@ +use std::fmt; + +use rkyv::{Archive, Deserialize, Serialize}; + +macro_rules! define_id { + ($name:ident, $ty:ty) => { + #[derive( + Archive, + Serialize, + Deserialize, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + )] + #[repr(transparent)] + pub struct $name(pub $ty); + + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } + } + }; +} + +define_id!(ControlId, u32); +define_id!(SessionSeq, u64); +define_id!(StreamId, u32); + +impl From<&ArchivedControlId> for ControlId { + fn from(value: &ArchivedControlId) -> Self { + Self(value.0.to_native()) + } +} + +impl From<&ArchivedSessionSeq> for SessionSeq { + fn from(value: &ArchivedSessionSeq) -> Self { + Self(value.0.to_native()) + } +} + +impl From<&ArchivedStreamId> for StreamId { + fn from(value: &ArchivedStreamId) -> Self { + Self(value.0.to_native()) + } +} diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs new file mode 100644 index 00000000..ac238055 --- /dev/null +++ b/ql-wire/src/lib.rs @@ -0,0 +1,228 @@ +//! quantum link protocol wire format +//! +//! naming conventions: +//! - *Record - unencrypted messages +//! - *Body - message content after decrypting + +use bc_components::{MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey}; +use rkyv::{ + api::{ + high::{to_bytes_in, HighSerializer, HighValidator}, + low::{self, LowDeserializer}, + }, + bytecheck::CheckBytes, + ser::allocator::ArenaHandle, + Archive, Deserialize, Portable, Serialize, +}; +use thiserror::Error; + +mod codec; +pub mod encrypted; +pub mod encrypted_message; +pub mod handshake; +mod id; +pub mod pair; +mod xid; + +pub(crate) use codec::*; +pub use encrypted::{ + close::SessionCloseBody, + stream::{CloseCode, CloseTarget, StreamCloseFrame, StreamFrame}, + SessionAck, SessionBody, SessionEnvelope, +}; +pub use encrypted_message::Nonce; +pub use id::{ControlId, SessionSeq, StreamId}; +pub use xid::XID; + +pub(crate) type WireArchiveError = rkyv::rancor::Error; + +#[derive(Debug, Clone)] +pub struct QlIdentity { + pub xid: XID, + pub signing_private_key: MLDSAPrivateKey, + pub signing_public_key: MLDSAPublicKey, + pub encapsulation_private_key: MLKEMPrivateKey, + pub encapsulation_public_key: MLKEMPublicKey, +} + +impl QlIdentity { + pub fn from_keys( + signing_private_key: MLDSAPrivateKey, + signing_public_key: MLDSAPublicKey, + encapsulation_private_key: MLKEMPrivateKey, + encapsulation_public_key: MLKEMPublicKey, + ) -> Self { + Self { + xid: XID::from_signing_public_key(&signing_public_key), + signing_private_key, + signing_public_key, + encapsulation_private_key, + encapsulation_public_key, + } + } +} + +pub trait QlCrypto { + fn fill_random_bytes(&self, data: &mut [u8]); +} + +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum WireError { + #[error("invalid payload")] + InvalidPayload, + #[error("invalid signature")] + InvalidSignature, + #[error("expired")] + Expired, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct QlRecord { + pub header: QlHeader, + pub payload: QlPayload, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct QlHeader { + pub sender: XID, + pub recipient: XID, +} + +impl QlHeader { + pub fn aad(&self) -> Vec { + encode_value(self) + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub struct ControlMeta { + pub control_id: ControlId, + pub valid_until: u64, +} + +impl From<&ArchivedControlMeta> for ControlMeta { + fn from(value: &ArchivedControlMeta) -> Self { + Self { + control_id: (&value.control_id).into(), + valid_until: value.valid_until.to_native(), + } + } +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum QlPayload { + Handshake(handshake::HandshakeRecord), + Pair(pair::PairRequestRecord), + Encrypted(encrypted_message::EncryptedMessage), +} + +pub fn encode_record(record: &QlRecord) -> Vec { + encode_value(record) +} + +pub fn access_record(bytes: &[u8]) -> Result<&ArchivedQlRecord, WireError> { + access_value(bytes) +} + +pub fn access_record_mut( + bytes: &mut [u8], +) -> Result, WireError> { + rkyv::access_mut::(bytes) + .map_err(|_| WireError::InvalidPayload) +} + +pub fn decode_record(bytes: &[u8]) -> Result { + deserialize_value(access_record(bytes)?) +} + +pub(crate) fn encode_value( + value: &impl for<'a> Serialize, ArenaHandle<'a>, WireArchiveError>>, +) -> Vec { + to_bytes_in::<_, WireArchiveError>(value, Vec::new()) + .expect("wire serialization should not fail") +} + +pub(crate) fn access_value(bytes: &[u8]) -> Result<&T, WireError> +where + T: Portable + for<'a> CheckBytes>, +{ + rkyv::access::(bytes).map_err(|_| WireError::InvalidPayload) +} + +pub(crate) fn deserialize_value( + value: &impl rkyv::Deserialize>, +) -> Result { + low::deserialize::(value).map_err(|_| WireError::InvalidPayload) +} + +pub(crate) fn ensure_not_expired(valid_until: u64) -> Result<(), WireError> { + if now_secs() > valid_until { + Err(WireError::Expired) + } else { + Ok(()) + } +} + +pub fn now_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use bc_components::SymmetricKey; + + use super::*; + + struct TestCrypto(std::sync::atomic::AtomicU8); + + impl TestCrypto { + fn new(seed: u8) -> Self { + Self(std::sync::atomic::AtomicU8::new(seed)) + } + } + + impl QlCrypto for TestCrypto { + fn fill_random_bytes(&self, data: &mut [u8]) { + let seed = self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + for (index, byte) in data.iter_mut().enumerate() { + *byte = seed.wrapping_add(index as u8); + } + } + } + + #[test] + fn ql_record_round_trip() { + let header = QlHeader { + sender: XID([1; XID::XID_SIZE]), + recipient: XID([2; XID::XID_SIZE]), + }; + let body = SessionEnvelope { + seq: SessionSeq(7), + ack: SessionAck { + base: SessionSeq(3), + bitmap: 0b101, + }, + body: SessionBody::Heartbeat(encrypted::heartbeat::HeartbeatBody), + }; + let record = encrypted::encrypt_record( + header.clone(), + &SymmetricKey::from_data([7; SymmetricKey::SYMMETRIC_KEY_SIZE]), + &body, + Nonce([8; Nonce::NONCE_SIZE]), + ); + + let bytes = encode_record(&record); + let decoded = decode_record(&bytes).unwrap(); + assert_eq!(decoded.header, header); + assert!(matches!(decoded.payload, QlPayload::Encrypted(_))); + } + + #[test] + fn now_secs_advances() { + let _ = TestCrypto::new(1); + assert!(now_secs() > 0); + } +} diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs new file mode 100644 index 00000000..7c9c1659 --- /dev/null +++ b/ql-wire/src/pair/crypto.rs @@ -0,0 +1,132 @@ +use bc_components::{MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, SymmetricKey}; +use rkyv::{Archive, Serialize}; + +use super::{PairRequestBody, PairRequestRecord}; +use crate::{ + access_value, deserialize_value, encode_value, + encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage}, + ensure_not_expired, AsWireMlDsaPublicKey, AsWireMlKemCiphertext, AsWireMlKemPublicKey, + ControlMeta, Nonce, QlCrypto, QlHeader, QlIdentity, QlPayload, QlRecord, WireError, XID, +}; + +#[derive(Archive, Serialize)] +struct PairingAad { + header: QlHeader, + #[rkyv(with = AsWireMlKemCiphertext)] + kem_ct: MLKEMCiphertext, +} + +#[derive(Archive, Serialize)] +struct PairingProofData { + aad: Vec, + meta: ControlMeta, + #[rkyv(with = AsWireMlDsaPublicKey)] + signing_pub_key: MLDSAPublicKey, + #[rkyv(with = AsWireMlKemPublicKey)] + encapsulation_pub_key: MLKEMPublicKey, +} + +pub fn build_pair_request( + identity: &QlIdentity, + crypto: &impl QlCrypto, + recipient: XID, + recipient_encapsulation_key: &MLKEMPublicKey, + meta: ControlMeta, +) -> Result { + let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); + let header = QlHeader { + sender: identity.xid, + recipient, + }; + let signing_pub_key = identity.signing_public_key.clone(); + let sender_encapsulation_key = identity.encapsulation_public_key.clone(); + let proof_data = pairing_proof_data( + &header, + &kem_ct, + &meta, + &signing_pub_key, + &sender_encapsulation_key, + ); + let proof = identity.signing_private_key.sign(&proof_data); + let body = PairRequestBody { + meta, + signing_pub_key, + encapsulation_pub_key: sender_encapsulation_key, + proof, + }; + let body_bytes = encode_value(&body); + let aad = pairing_aad(&header, &kem_ct); + let mut nonce_bytes = [0u8; Nonce::NONCE_SIZE]; + crypto.fill_random_bytes(&mut nonce_bytes); + let encrypted = EncryptedMessage::encrypt(&session_key, body_bytes, &aad, Nonce(nonce_bytes)); + Ok(QlRecord { + header, + payload: QlPayload::Pair(PairRequestRecord { kem_ct, encrypted }), + }) +} + +pub fn decrypt_pair_request( + identity: &QlIdentity, + header: &QlHeader, + request: &mut super::ArchivedPairRequestRecord, +) -> Result { + let kem_ct = MLKEMCiphertext::try_from(&request.kem_ct)?; + let aad = pairing_aad(header, &kem_ct); + let session_key = identity + .encapsulation_private_key + .decapsulate_shared_secret(&kem_ct) + .map_err(|_| WireError::InvalidPayload)?; + let decrypted = decrypt_body(&session_key, &mut request.encrypted, &aad)?; + ensure_not_expired(decrypted.meta.valid_until)?; + if XID::from_signing_public_key(&decrypted.signing_pub_key) != header.sender { + return Err(WireError::InvalidPayload); + } + let proof_data = pairing_proof_data( + header, + &kem_ct, + &decrypted.meta, + &decrypted.signing_pub_key, + &decrypted.encapsulation_pub_key, + ); + if decrypted + .signing_pub_key + .verify(&decrypted.proof, &proof_data) + .unwrap_or(false) + { + Ok(decrypted) + } else { + Err(WireError::InvalidSignature) + } +} + +fn pairing_proof_data( + header: &QlHeader, + kem_ct: &MLKEMCiphertext, + meta: &ControlMeta, + signing_pub_key: &MLDSAPublicKey, + encapsulation_pub_key: &MLKEMPublicKey, +) -> Vec { + encode_value(&PairingProofData { + aad: pairing_aad(header, kem_ct), + meta: *meta, + signing_pub_key: signing_pub_key.clone(), + encapsulation_pub_key: encapsulation_pub_key.clone(), + }) +} + +fn decrypt_body( + key: &SymmetricKey, + encrypted: &mut ArchivedEncryptedMessage, + aad: &[u8], +) -> Result { + let plaintext = encrypted.decrypt(key, aad)?; + let body = access_value::(plaintext)?; + deserialize_value(body) +} + +pub(crate) fn pairing_aad(header: &QlHeader, kem_ct: &MLKEMCiphertext) -> Vec { + encode_value(&PairingAad { + header: header.clone(), + kem_ct: kem_ct.clone(), + }) +} diff --git a/ql-wire/src/pair/mod.rs b/ql-wire/src/pair/mod.rs new file mode 100644 index 00000000..c2c06dd8 --- /dev/null +++ b/ql-wire/src/pair/mod.rs @@ -0,0 +1,28 @@ +use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey}; +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::{ + encrypted_message::EncryptedMessage, AsWireMlDsaPublicKey, AsWireMlDsaSignature, + AsWireMlKemCiphertext, AsWireMlKemPublicKey, ControlMeta, +}; + +mod crypto; +pub use crypto::*; + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct PairRequestRecord { + #[rkyv(with = AsWireMlKemCiphertext)] + pub kem_ct: MLKEMCiphertext, + pub encrypted: EncryptedMessage, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct PairRequestBody { + pub meta: ControlMeta, + #[rkyv(with = AsWireMlDsaPublicKey)] + pub signing_pub_key: MLDSAPublicKey, + #[rkyv(with = AsWireMlKemPublicKey)] + pub encapsulation_pub_key: MLKEMPublicKey, + #[rkyv(with = AsWireMlDsaSignature)] + pub proof: MLDSASignature, +} diff --git a/ql-wire/src/xid.rs b/ql-wire/src/xid.rs new file mode 100644 index 00000000..548da558 --- /dev/null +++ b/ql-wire/src/xid.rs @@ -0,0 +1,16 @@ +use bc_components::{MLDSAPublicKey, SigningPublicKey}; +use rkyv::{Archive, Deserialize, Serialize}; + +#[derive( + Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, +)] +pub struct XID(pub [u8; Self::XID_SIZE]); + +impl XID { + pub const XID_SIZE: usize = 32; + + pub fn from_signing_public_key(signing_public_key: &MLDSAPublicKey) -> Self { + let xid = bc_components::XID::new(SigningPublicKey::MLDSA(signing_public_key.clone())); + Self(*xid.data()) + } +} From ebcc8a248da8fb2724e6ac5885d4a23452410ed0 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:43 -0400 Subject: [PATCH 009/304] ql: move runtime onto fsm, switch wire serialization, and remove ql-engine --- Cargo.lock | 240 ++- Cargo.toml | 15 +- ql-engine/Cargo.toml | 20 - ql-engine/src/arena.rs | 194 -- .../src/engine/implementation/handshake.rs | 607 ------- ql-engine/src/engine/implementation/mod.rs | 762 -------- ql-engine/src/engine/implementation/peer.rs | 160 -- ql-engine/src/engine/implementation/stream.rs | 337 ---- ql-engine/src/engine/mod.rs | 231 --- ql-engine/src/engine/replay_cache.rs | 178 -- ql-engine/src/engine/state.rs | 270 --- ql-engine/src/engine/tests/handshake.rs | 827 --------- ql-engine/src/engine/tests/liveness.rs | 87 - ql-engine/src/engine/tests/mod.rs | 538 ------ ql-engine/src/engine/tests/peer.rs | 42 - ql-engine/src/engine/tests/stream.rs | 1554 ----------------- ql-engine/src/identity.rs | 29 - ql-engine/src/lib.rs | 40 - ql-engine/src/stream/internal.rs | 842 --------- ql-engine/src/stream/mod.rs | 270 --- ql-engine/src/stream/ring.rs | 194 -- ql-engine/src/stream/state.rs | 532 ------ ql-engine/src/stream/tests.rs | 334 ---- ql-engine/src/wire/codec.rs | 332 ---- ql-engine/src/wire/encrypted_message.rs | 63 - ql-engine/src/wire/handshake/crypto.rs | 344 ---- ql-engine/src/wire/handshake/mod.rs | 69 - ql-engine/src/wire/heartbeat/crypto.rs | 39 - ql-engine/src/wire/heartbeat/mod.rs | 11 - ql-engine/src/wire/id.rs | 44 - ql-engine/src/wire/mod.rs | 473 ----- ql-engine/src/wire/pair/crypto.rs | 139 -- ql-engine/src/wire/pair/mod.rs | 28 - ql-engine/src/wire/seq.rs | 97 - ql-engine/src/wire/stream/crypto.rs | 39 - ql-engine/src/wire/stream/mod.rs | 199 --- ql-engine/src/wire/unpair/crypto.rs | 50 - ql-engine/src/wire/unpair/mod.rs | 14 - ql-fsm/Cargo.toml | 15 +- ql-fsm/src/error.rs | 51 + ql-fsm/src/implementation/fsm.rs | 218 +++ ql-fsm/src/implementation/handshake.rs | 222 ++- ql-fsm/src/implementation/mod.rs | 229 +-- ql-fsm/src/implementation/peer.rs | 33 +- ql-fsm/src/lib.rs | 204 ++- ql-fsm/src/session/internal.rs | 628 ------- ql-fsm/src/session/mod.rs | 706 +++++++- ql-fsm/src/session/ring.rs | 56 + ql-fsm/src/session/state.rs | 41 +- ql-fsm/src/session/tests.rs | 356 +++- ql-fsm/src/state.rs | 101 +- ql-fsm/src/tests/handshake.rs | 315 ++++ ql-fsm/src/tests/mod.rs | 283 +++ ql-fsm/src/tests/session.rs | 352 ++++ ql-runtime/Cargo.toml | 9 +- ql-runtime/src/command.rs | 7 +- ql-runtime/src/driver.rs | 609 +++---- ql-runtime/src/handle.rs | 52 +- ql-runtime/src/lib.rs | 69 +- ql-runtime/src/platform.rs | 9 +- ql-runtime/src/tests/handshake.rs | 111 +- ql-runtime/src/tests/heartbeat.rs | 183 +- ql-runtime/src/tests/mod.rs | 245 ++- ql-runtime/src/tests/stream.rs | 91 +- ql-runtime/src/tests/unpair.rs | 51 +- ql-wire/Cargo.toml | 18 +- ql-wire/src/codec.rs | 311 +--- ql-wire/src/control.rs | 47 + ql-wire/src/encrypted/close.rs | 34 + ql-wire/src/encrypted/close/mod.rs | 8 - ql-wire/src/encrypted/heartbeat/mod.rs | 4 - ql-wire/src/encrypted/mod.rs | 210 ++- ql-wire/src/encrypted/ping.rs | 2 + ql-wire/src/encrypted/stream/mod.rs | 60 - ql-wire/src/encrypted/stream_chunk.rs | 63 + ql-wire/src/encrypted/stream_close.rs | 97 + ql-wire/src/encrypted/unpair.rs | 2 + ql-wire/src/encrypted/unpair/mod.rs | 4 - ql-wire/src/encrypted_message.rs | 142 +- ql-wire/src/error.rs | 17 + ql-wire/src/handshake/crypto.rs | 501 +++--- ql-wire/src/handshake/mod.rs | 153 +- ql-wire/src/header.rs | 55 + ql-wire/src/id.rs | 51 - ql-wire/src/identity.rs | 28 + ql-wire/src/lib.rs | 256 +-- ql-wire/src/nonce.rs | 7 + ql-wire/src/pair/crypto.rs | 148 +- ql-wire/src/pair/mod.rs | 106 +- ql-wire/src/pq.rs | 185 ++ ql-wire/src/record.rs | 152 ++ ql-wire/src/tests.rs | 378 ++++ ql-wire/src/xid.rs | 17 +- 93 files changed, 5692 insertions(+), 12854 deletions(-) delete mode 100644 ql-engine/Cargo.toml delete mode 100644 ql-engine/src/arena.rs delete mode 100644 ql-engine/src/engine/implementation/handshake.rs delete mode 100644 ql-engine/src/engine/implementation/mod.rs delete mode 100644 ql-engine/src/engine/implementation/peer.rs delete mode 100644 ql-engine/src/engine/implementation/stream.rs delete mode 100644 ql-engine/src/engine/mod.rs delete mode 100644 ql-engine/src/engine/replay_cache.rs delete mode 100644 ql-engine/src/engine/state.rs delete mode 100644 ql-engine/src/engine/tests/handshake.rs delete mode 100644 ql-engine/src/engine/tests/liveness.rs delete mode 100644 ql-engine/src/engine/tests/mod.rs delete mode 100644 ql-engine/src/engine/tests/peer.rs delete mode 100644 ql-engine/src/engine/tests/stream.rs delete mode 100644 ql-engine/src/identity.rs delete mode 100644 ql-engine/src/lib.rs delete mode 100644 ql-engine/src/stream/internal.rs delete mode 100644 ql-engine/src/stream/mod.rs delete mode 100644 ql-engine/src/stream/ring.rs delete mode 100644 ql-engine/src/stream/state.rs delete mode 100644 ql-engine/src/stream/tests.rs delete mode 100644 ql-engine/src/wire/codec.rs delete mode 100644 ql-engine/src/wire/encrypted_message.rs delete mode 100644 ql-engine/src/wire/handshake/crypto.rs delete mode 100644 ql-engine/src/wire/handshake/mod.rs delete mode 100644 ql-engine/src/wire/heartbeat/crypto.rs delete mode 100644 ql-engine/src/wire/heartbeat/mod.rs delete mode 100644 ql-engine/src/wire/id.rs delete mode 100644 ql-engine/src/wire/mod.rs delete mode 100644 ql-engine/src/wire/pair/crypto.rs delete mode 100644 ql-engine/src/wire/pair/mod.rs delete mode 100644 ql-engine/src/wire/seq.rs delete mode 100644 ql-engine/src/wire/stream/crypto.rs delete mode 100644 ql-engine/src/wire/stream/mod.rs delete mode 100644 ql-engine/src/wire/unpair/crypto.rs delete mode 100644 ql-engine/src/wire/unpair/mod.rs create mode 100644 ql-fsm/src/error.rs create mode 100644 ql-fsm/src/implementation/fsm.rs delete mode 100644 ql-fsm/src/session/internal.rs create mode 100644 ql-fsm/src/tests/handshake.rs create mode 100644 ql-fsm/src/tests/mod.rs create mode 100644 ql-fsm/src/tests/session.rs create mode 100644 ql-wire/src/control.rs create mode 100644 ql-wire/src/encrypted/close.rs delete mode 100644 ql-wire/src/encrypted/close/mod.rs delete mode 100644 ql-wire/src/encrypted/heartbeat/mod.rs create mode 100644 ql-wire/src/encrypted/ping.rs delete mode 100644 ql-wire/src/encrypted/stream/mod.rs create mode 100644 ql-wire/src/encrypted/stream_chunk.rs create mode 100644 ql-wire/src/encrypted/stream_close.rs create mode 100644 ql-wire/src/encrypted/unpair.rs delete mode 100644 ql-wire/src/encrypted/unpair/mod.rs create mode 100644 ql-wire/src/error.rs create mode 100644 ql-wire/src/header.rs delete mode 100644 ql-wire/src/id.rs create mode 100644 ql-wire/src/identity.rs create mode 100644 ql-wire/src/nonce.rs create mode 100644 ql-wire/src/pq.rs create mode 100644 ql-wire/src/record.rs create mode 100644 ql-wire/src/tests.rs diff --git a/Cargo.lock b/Cargo.lock index f09e182f..c0e954a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -559,6 +559,17 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core-models" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "657f625ff361906f779745d08375ae3cc9fef87a35fba5f22874cf773010daf4" +dependencies = [ + "hax-lib", + "pastey", + "rand 0.9.2", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -1154,6 +1165,43 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hax-lib" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "543f93241d32b3f00569201bfce9d7a93c92c6421b23c77864ac929dc947b9fc" +dependencies = [ + "hax-lib-macros", + "num-bigint", + "num-traits", +] + +[[package]] +name = "hax-lib-macros" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8755751e760b11021765bb04cb4a6c4e24742688d9f3aa14c2079638f537b0f" +dependencies = [ + "hax-lib-macros-types", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "hax-lib-macros-types" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f177c9ae8ea456e2f71ff3c1ea47bf4464f772a05133fcbba56cd5ba169035a2" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "serde_json", + "uuid", +] + [[package]] name = "hermit-abi" version = "0.5.2" @@ -1429,6 +1477,108 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libcrux-aesgcm" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99f2a019dab4097585a7d4f5b9deebe46cd1e628b16a5bc4cb0ce35e1da334e6" +dependencies = [ + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-traits", +] + +[[package]] +name = "libcrux-intrinsics" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1b5db005ff8001e026b73a6842ee81bbef8ec5ff0e1915a67ae65fd2a9fafa5" +dependencies = [ + "core-models", + "hax-lib", +] + +[[package]] +name = "libcrux-macros" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffd6aa2dcd5be681662001b81d493f1569c6d49a32361f470b0c955465cd0338" +dependencies = [ + "quote", + "syn 2.0.106", +] + +[[package]] +name = "libcrux-ml-dsa" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b34d977eb95b8fe93e6eb87197b55ee21e50e725bc3f206a7cb3a0d7d719c4b" +dependencies = [ + "core-models", + "hax-lib", + "libcrux-intrinsics", + "libcrux-macros", + "libcrux-platform", + "libcrux-sha3", +] + +[[package]] +name = "libcrux-ml-kem" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aca7de713c6dddcf7aaf76e8ef9dc0097c8d7ce23a8eadf04c8761734714e184" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-sha3", + "libcrux-traits", + "rand 0.9.2", + "tls_codec", +] + +[[package]] +name = "libcrux-platform" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d9e21d7ed31a92ac539bd69a8c970b183ee883872d2d19ce27036e24cb8ecc4" +dependencies = [ + "libc", +] + +[[package]] +name = "libcrux-secrets" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce650f3041b44ba40d4263852347d007cd2cd9d1cc856a6f6c8b2e10c3fd40b" +dependencies = [ + "hax-lib", +] + +[[package]] +name = "libcrux-sha3" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c50f6e04a184511b782c5cc1eb6a227c6d36f2c935e93d698655a93a99696b5" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-traits", +] + +[[package]] +name = "libcrux-traits" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e4fa89f3f5e34b47f928b22b1b78395a0d4ec23b1f583db635f128159d65f" +dependencies = [ + "libcrux-secrets", + "rand 0.9.2", +] + [[package]] name = "libm" version = "0.2.15" @@ -1542,6 +1692,16 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.6" @@ -1710,6 +1870,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pastey" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" + [[package]] name = "pbkdf2" version = "0.12.2" @@ -1906,6 +2072,28 @@ dependencies = [ "elliptic-curve", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -1959,23 +2147,14 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "ql-engine" -version = "0.1.0" -dependencies = [ - "bc-components", - "chacha20poly1305", - "rkyv", - "thiserror", -] - [[package]] name = "ql-fsm" version = "0.1.0" dependencies = [ - "bc-components", + "indexmap", + "libcrux-aesgcm", "ql-wire", - "rkyv", + "sha2", "thiserror", ] @@ -1984,11 +2163,14 @@ name = "ql-runtime" version = "0.1.0" dependencies = [ "async-channel", - "bc-components", "futures-lite", + "libcrux-aesgcm", "oneshot", "piper", - "ql-engine", + "ql-fsm", + "ql-wire", + "sha2", + "thiserror", "tokio", ] @@ -1996,10 +2178,12 @@ dependencies = [ name = "ql-wire" version = "0.1.0" dependencies = [ - "bc-components", - "chacha20poly1305", - "rkyv", + "libcrux-aesgcm", + "libcrux-ml-dsa", + "libcrux-ml-kem", + "sha2", "thiserror", + "zerocopy", ] [[package]] @@ -2575,6 +2759,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tls_codec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" +dependencies = [ + "tls_codec_derive", + "zeroize", +] + +[[package]] +name = "tls_codec_derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "tokio" version = "1.47.1" @@ -2668,6 +2873,7 @@ version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ + "getrandom 0.3.3", "js-sys", "wasm-bindgen", ] diff --git a/Cargo.toml b/Cargo.toml index 496d5574..d56111c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,14 +1,13 @@ [workspace] resolver = "2" members = [ - "api", - "backup-shard", - "btp", - "ql-fsm", - "ql-engine", - "ql-runtime", - "ql-wire", - "quantum-link-macros", + "api", + "backup-shard", + "btp", + "ql-fsm", + "ql-runtime", + "ql-wire", + "quantum-link-macros", ] [workspace.package] diff --git a/ql-engine/Cargo.toml b/ql-engine/Cargo.toml deleted file mode 100644 index 4803431f..00000000 --- a/ql-engine/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "ql-engine" -version = "0.1.0" -edition = "2021" -description = "Quantum Link v2 duplex stream prototype" -license = "Proprietary" - -[dependencies] -bc-components = { version = "0.28.0", default-features = false, features = [ - "pqcrypto", -] } -chacha20poly1305 = { version = "0.10.1" } -rkyv = { version = "0.8", default-features = false, features = [ - "std", - "bytecheck", - "little_endian", - "unaligned", - "pointer_width_32", -] } -thiserror = { version = "2" } diff --git a/ql-engine/src/arena.rs b/ql-engine/src/arena.rs deleted file mode 100644 index b7b5a4a4..00000000 --- a/ql-engine/src/arena.rs +++ /dev/null @@ -1,194 +0,0 @@ -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ArenaKey { - index: u32, - generation: u32, -} - -impl ArenaKey { - fn index(self) -> usize { - self.index as usize - } -} - -#[derive(Debug)] -struct Slot { - generation: u32, - value: Option, - next_free: Option, -} - -#[derive(Debug)] -pub struct GenerationalArena { - slots: Vec>, - free_head: Option, - len: usize, -} - -impl GenerationalArena { - pub fn new() -> Self { - Self { - slots: Vec::new(), - free_head: None, - len: 0, - } - } - - pub fn len(&self) -> usize { - self.len - } - - pub fn is_empty(&self) -> bool { - self.len == 0 - } - - pub fn contains(&self, key: ArenaKey) -> bool { - self.get(key).is_some() - } - - pub fn values(&self) -> impl Iterator { - self.slots.iter().filter_map(|slot| slot.value.as_ref()) - } - - pub fn clear(&mut self) { - self.slots.clear(); - self.free_head = None; - self.len = 0; - } - - pub fn insert(&mut self, value: T) -> ArenaKey { - self.len += 1; - - if let Some(index) = self.free_head { - let slot = &mut self.slots[index as usize]; - self.free_head = slot.next_free.take(); - slot.value = Some(value); - return ArenaKey { - index, - generation: slot.generation, - }; - } - - assert!(self.slots.len() < u32::MAX as usize); - let index = self.slots.len() as u32; - self.slots.push(Slot { - generation: 0, - value: Some(value), - next_free: None, - }); - ArenaKey { - index, - generation: 0, - } - } - - pub fn get(&self, key: ArenaKey) -> Option<&T> { - let slot = self.slots.get(key.index())?; - (slot.generation == key.generation) - .then_some(slot.value.as_ref()) - .flatten() - } - - pub fn get_mut(&mut self, key: ArenaKey) -> Option<&mut T> { - let slot = self.slots.get_mut(key.index())?; - (slot.generation == key.generation) - .then_some(slot.value.as_mut()) - .flatten() - } - - pub fn remove(&mut self, key: ArenaKey) -> Option { - let slot = self.slots.get_mut(key.index())?; - if slot.generation != key.generation { - return None; - } - - let value = slot.value.take()?; - slot.generation = slot.generation.wrapping_add(1); - slot.next_free = self.free_head; - self.free_head = Some(key.index); - self.len -= 1; - Some(value) - } - - pub fn retain(&mut self, mut f: impl FnMut(ArenaKey, &mut T) -> bool) { - for (index, slot) in self.slots.iter_mut().enumerate() { - let Some(value) = slot.value.as_mut() else { - continue; - }; - let key = ArenaKey { - index: index as u32, - generation: slot.generation, - }; - if f(key, value) { - continue; - } - let _ = slot.value.take(); - slot.generation = slot.generation.wrapping_add(1); - slot.next_free = self.free_head; - self.free_head = Some(index as u32); - self.len -= 1; - } - } -} - -impl Default for GenerationalArena { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::{ArenaKey, GenerationalArena}; - - #[test] - fn insert_get_remove_round_trips() { - let mut arena = GenerationalArena::new(); - let key = arena.insert("hello"); - - assert_eq!(arena.len(), 1); - assert_eq!(arena.get(key), Some(&"hello")); - assert!(arena.contains(key)); - - assert_eq!(arena.remove(key), Some("hello")); - assert!(arena.is_empty()); - assert_eq!(arena.get(key), None); - assert!(!arena.contains(key)); - } - - #[test] - fn stale_key_does_not_hit_reused_slot() { - let mut arena = GenerationalArena::new(); - let old = arena.insert(10); - assert_eq!(arena.remove(old), Some(10)); - - let new = arena.insert(20); - assert_eq!(old.index(), new.index()); - assert_ne!(old, new); - - assert_eq!(arena.get(old), None); - assert_eq!(arena.get(new), Some(&20)); - } - - #[test] - fn get_mut_updates_value() { - let mut arena = GenerationalArena::new(); - let key = arena.insert(String::from("a")); - - arena.get_mut(key).unwrap().push('b'); - - assert_eq!(arena.get(key).map(String::as_str), Some("ab")); - } - - #[test] - fn remove_rejects_wrong_generation() { - let mut arena = GenerationalArena::new(); - let key = arena.insert(1u32); - let wrong = ArenaKey { - index: key.index as u32, - generation: key.generation.wrapping_add(1), - }; - - assert_eq!(arena.remove(wrong), None); - assert_eq!(arena.get(key), Some(&1)); - } -} diff --git a/ql-engine/src/engine/implementation/handshake.rs b/ql-engine/src/engine/implementation/handshake.rs deleted file mode 100644 index 0acf91dc..00000000 --- a/ql-engine/src/engine/implementation/handshake.rs +++ /dev/null @@ -1,607 +0,0 @@ -use super::*; -use crate::{ - engine::{EngineConfig, EngineState, KeepAliveState}, - identity::QlIdentity, - wire::{handshake::HandshakeRecord, QlPayload, QlRecord}, -}; - -#[derive(Debug)] -enum HelloAction { - StartResponder, - ResendReply { - token: Token, - reply: wire::handshake::HelloReply, - deadline: Instant, - }, - Ignore, -} - -enum HelloReplyAction { - Advance { - hello: wire::handshake::Hello, - responder_signing_key: bc_components::MLDSAPublicKey, - initiator_secret: SymmetricKey, - }, - ResendConfirm { - token: Token, - confirm: wire::handshake::Confirm, - deadline: Instant, - }, -} - -pub fn handle_connect(engine: &mut Engine, crypto: &impl QlCrypto) { - let now = engine.state.now; - let Some(_) = engine.peer.as_ref() else { - return; - }; - let started = { - let config = &engine.config; - let identity = &engine.identity; - let state = &mut engine.state; - let Some(peer_record) = engine.peer.as_mut() else { - return; - }; - start_initiator_handshake(config, identity, state, peer_record, now, crypto) - }; - if started { - engine.emit_peer_status(); - } -} - -pub fn handle_hello( - engine: &mut Engine, - peer: XID, - hello: &wire::handshake::ArchivedHello, - crypto: &impl QlCrypto, -) { - let now = engine.state.now; - let action = match engine.peer.as_ref() { - Some(entry) => { - if wire::handshake::verify_hello(peer, engine.identity.xid, &entry.signing_key, hello) - .is_err() - { - return; - } - match &entry.session { - PeerSession::Initiator { - hello: local_hello, .. - } => { - if peer_hello_wins(local_hello, engine.identity.xid, hello, peer) { - HelloAction::StartResponder - } else { - HelloAction::Ignore - } - } - PeerSession::Responder { - handshake_token, - hello: stored, - reply, - deadline, - stage: HandshakeResponder::WaitingConfirm { .. }, - } => { - if same_hello(stored, hello) { - HelloAction::ResendReply { - token: *handshake_token, - reply: reply.clone(), - deadline: *deadline, - } - } else { - HelloAction::StartResponder - } - } - PeerSession::Responder { .. } - | PeerSession::Disconnected - | PeerSession::Connected { .. } => HelloAction::StartResponder, - } - } - None => return, - }; - - match action { - HelloAction::StartResponder => { - let meta: ControlMeta = (&hello.meta).into(); - if engine.is_replayed_control(peer, meta) { - return; - } - let changed = { - let config = &engine.config; - let identity = &engine.identity; - let state = &mut engine.state; - let Some(peer_record) = engine.peer.as_mut() else { - return; - }; - start_responder_handshake( - config, - identity, - state, - peer_record, - now, - peer, - hello, - crypto, - ) - }; - if changed { - engine.emit_peer_status(); - } - } - HelloAction::ResendReply { - token, - reply, - deadline, - } => { - if engine.handshake_write_pending(token) { - return; - } - engine.clear_handshake_retry_at(token); - enqueue_handshake_record( - engine, - token, - deadline, - peer, - HandshakeRecord::HelloReply(reply), - ); - } - HelloAction::Ignore => {} - } -} - -pub fn handle_hello_reply( - engine: &mut Engine, - peer: XID, - reply: &wire::handshake::ArchivedHelloReply, -) { - let now = engine.state.now; - let action = { - let Some(peer_record) = engine.peer.as_ref() else { - return; - }; - let PeerSession::Initiator { - handshake_token, - hello, - session_key, - stage, - deadline, - .. - } = &peer_record.session - else { - return; - }; - match stage { - HandshakeInitiator::WaitingHelloReply { .. } => HelloReplyAction::Advance { - hello: hello.clone(), - responder_signing_key: peer_record.signing_key.clone(), - initiator_secret: session_key.clone(), - }, - HandshakeInitiator::WaitingReady { - reply: stored_reply, - confirm, - .. - } if same_reply(stored_reply, reply) => HelloReplyAction::ResendConfirm { - token: *handshake_token, - confirm: confirm.clone(), - deadline: *deadline, - }, - HandshakeInitiator::WaitingReady { .. } => return, - } - }; - match action { - HelloReplyAction::Advance { - hello, - responder_signing_key, - initiator_secret, - } => { - let confirm_meta = engine.next_control_meta(engine.config.handshake_timeout); - let (confirm, session_key) = match wire::handshake::build_confirm( - &engine.identity, - peer, - &responder_signing_key, - &hello, - reply, - &initiator_secret, - confirm_meta, - ) { - Ok(result) => result, - Err(_) => return, - }; - let reply_meta: ControlMeta = (&reply.meta).into(); - if engine.is_replayed_control(peer, reply_meta) { - return; - } - let Ok(reply) = wire::deserialize_value(reply) else { - return; - }; - let deadline = now + engine.config.handshake_timeout; - let token = engine.state.next_token(); - let Some(peer_record) = engine.peer.as_mut() else { - return; - }; - peer_record.session = PeerSession::Initiator { - handshake_token: token, - hello, - session_key, - deadline, - stage: HandshakeInitiator::WaitingReady { - reply, - confirm: confirm.clone(), - retry_count: 0, - retry_at: None, - }, - }; - enqueue_handshake_record( - engine, - token, - deadline, - peer, - HandshakeRecord::Confirm(confirm), - ); - } - HelloReplyAction::ResendConfirm { - token, - confirm, - deadline, - } => { - if engine.handshake_write_pending(token) { - return; - } - engine.clear_handshake_retry_at(token); - enqueue_handshake_record( - engine, - token, - deadline, - peer, - HandshakeRecord::Confirm(confirm), - ); - } - } -} - -pub fn handle_confirm( - engine: &mut Engine, - peer: XID, - confirm: &wire::handshake::ArchivedConfirm, - crypto: &impl QlCrypto, -) { - let now = engine.state.now; - if let Some((ready, deadline, token)) = current_ready_resend(engine, now, peer, confirm) { - if engine.handshake_write_pending(token) { - return; - } - enqueue_handshake_record(engine, token, deadline, peer, HandshakeRecord::Ready(ready)); - return; - } - if let Some(ready) = recent_ready_resend(engine, now, peer, confirm) { - let record = QlRecord { - header: QlHeader { - sender: engine.identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::Ready(ready)), - }; - engine - .state - .enqueue_control(&engine.config, true, wire::encode_record(&record)); - return; - } - - let res = { - let Some(peer_record) = engine.peer.as_ref() else { - return; - }; - let PeerSession::Responder { - hello, - reply, - stage, - .. - } = &peer_record.session - else { - return; - }; - let HandshakeResponder::WaitingConfirm { secrets, .. } = stage else { - return; - }; - - wire::handshake::finalize_confirm( - peer, - engine.identity.xid, - &peer_record.signing_key, - hello, - reply, - confirm, - secrets, - ) - .map(|session_key| (hello.clone(), reply.clone(), session_key)) - }; - - match res { - Ok((hello, reply, session_key)) => { - let meta: ControlMeta = (&confirm.meta).into(); - if engine.is_replayed_control(peer, meta) { - return; - } - let deadline = now + engine.config.handshake_timeout; - let ready_meta = engine.next_control_meta(engine.config.handshake_timeout); - let ready = wire::handshake::build_ready( - QlHeader { - sender: engine.identity.xid, - recipient: peer, - }, - &session_key, - ready_meta, - encrypted_message_nonce(crypto), - ); - let token = engine.state.next_token(); - if let Some(peer_record) = engine.peer.as_mut() { - peer_record.session = PeerSession::Responder { - handshake_token: token, - hello, - reply, - deadline, - stage: HandshakeResponder::SendingReady { - session_key, - ready: ready.clone(), - }, - }; - } - enqueue_handshake_record(engine, token, deadline, peer, HandshakeRecord::Ready(ready)); - } - Err(_) => {} - } -} - -pub fn handle_ready( - engine: &mut Engine, - peer: XID, - header: &QlHeader, - ready: &mut wire::handshake::ArchivedReady, -) { - let session_key = { - let Some(peer_record) = engine.peer.as_ref() else { - return; - }; - let PeerSession::Initiator { - session_key, stage, .. - } = &peer_record.session - else { - return; - }; - let HandshakeInitiator::WaitingReady { .. } = stage else { - return; - }; - session_key.clone() - }; - - let Ok(body) = wire::handshake::decrypt_ready(header, ready, &session_key) else { - return; - }; - if engine.is_replayed_control(peer, body.meta) { - return; - } - - if let Some(peer_record) = engine.peer.as_mut() { - peer_record.session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - recent_ready: None, - }; - } - engine.record_activity(); - engine.emit_peer_status(); -} - -fn start_initiator_handshake( - config: &EngineConfig, - identity: &QlIdentity, - state: &mut EngineState, - peer_record: &mut PeerRecord, - now: Instant, - crypto: &impl QlCrypto, -) -> bool { - if !matches!(peer_record.session, PeerSession::Disconnected) { - return false; - } - - let meta = ControlMeta { - packet_id: state.next_packet_id(), - valid_until: wire::now_secs() + config.handshake_timeout.as_secs(), - }; - let peer = peer_record.peer; - let Ok((hello, session_key)) = - wire::handshake::build_hello(identity, crypto, peer, &peer_record.encapsulation_key, meta) - else { - return false; - }; - - let deadline = now + config.handshake_timeout; - let token = state.next_token(); - peer_record.session = PeerSession::Initiator { - handshake_token: token, - hello: hello.clone(), - session_key, - deadline, - stage: HandshakeInitiator::WaitingHelloReply { - retry_count: 0, - retry_at: None, - }, - }; - let record = QlRecord { - header: QlHeader { - sender: identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(hello)), - }; - state.enqueue_handshake_message(config, token, deadline, wire::encode_record(&record)); - true -} - -fn start_responder_handshake( - config: &EngineConfig, - identity: &QlIdentity, - state: &mut EngineState, - peer_record: &mut PeerRecord, - now: Instant, - peer: XID, - hello: &wire::handshake::ArchivedHello, - crypto: &impl QlCrypto, -) -> bool { - let reply_meta = ControlMeta { - packet_id: state.next_packet_id(), - valid_until: wire::now_secs() + config.handshake_timeout.as_secs(), - }; - let (reply, secrets) = match wire::handshake::respond_hello( - identity, - crypto, - peer, - &peer_record.signing_key, - &peer_record.encapsulation_key, - hello, - reply_meta, - ) { - Ok(result) => result, - Err(_) => { - peer_record.session = PeerSession::Disconnected; - return true; - } - }; - let Ok(hello) = wire::deserialize_value(hello) else { - peer_record.session = PeerSession::Disconnected; - return true; - }; - - let deadline = now + config.handshake_timeout; - let token = state.next_token(); - peer_record.session = PeerSession::Responder { - handshake_token: token, - hello, - reply: reply.clone(), - deadline, - stage: HandshakeResponder::WaitingConfirm { - secrets, - retry_count: 0, - retry_at: None, - }, - }; - - let record = QlRecord { - header: QlHeader { - sender: identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(HandshakeRecord::HelloReply(reply)), - }; - state.enqueue_handshake_message(config, token, deadline, wire::encode_record(&record)); - true -} - -pub(super) fn enqueue_handshake_record( - engine: &mut Engine, - token: Token, - deadline: Instant, - peer: XID, - record: HandshakeRecord, -) { - let record = QlRecord { - header: QlHeader { - sender: engine.identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(record), - }; - engine.state.enqueue_handshake_message( - &engine.config, - token, - deadline, - wire::encode_record(&record), - ); -} - -fn same_hello(stored: &wire::handshake::Hello, incoming: &wire::handshake::ArchivedHello) -> bool { - let meta: ControlMeta = (&incoming.meta).into(); - stored.meta.packet_id == meta.packet_id && stored.nonce == (&incoming.nonce).into() -} - -fn same_reply( - stored: &wire::handshake::HelloReply, - incoming: &wire::handshake::ArchivedHelloReply, -) -> bool { - let meta: ControlMeta = (&incoming.meta).into(); - stored.meta.packet_id == meta.packet_id && stored.nonce == (&incoming.nonce).into() -} - -fn current_ready_resend( - engine: &Engine, - now: Instant, - peer: XID, - confirm: &wire::handshake::ArchivedConfirm, -) -> Option<(wire::handshake::Ready, Instant, Token)> { - let peer_record = engine.peer.as_ref()?; - let PeerSession::Responder { - handshake_token, - hello, - reply, - deadline, - stage: HandshakeResponder::SendingReady { ready, .. }, - } = &peer_record.session - else { - return None; - }; - if *deadline <= now { - return None; - } - wire::handshake::verify_confirm( - peer, - engine.identity.xid, - &peer_record.signing_key, - hello, - reply, - confirm, - ) - .ok()?; - Some((ready.clone(), *deadline, *handshake_token)) -} - -fn recent_ready_resend( - engine: &Engine, - now: Instant, - peer: XID, - confirm: &wire::handshake::ArchivedConfirm, -) -> Option { - let peer_record = engine.peer.as_ref()?; - let PeerSession::Connected { - recent_ready: Some(recent_ready), - .. - } = &peer_record.session - else { - return None; - }; - if recent_ready.expires_at <= now { - return None; - } - wire::handshake::verify_confirm( - peer, - engine.identity.xid, - &peer_record.signing_key, - &recent_ready.hello, - &recent_ready.reply, - confirm, - ) - .ok()?; - Some(recent_ready.ready.clone()) -} - -fn peer_hello_wins( - local_hello: &wire::handshake::Hello, - local_sender: XID, - peer_hello: &wire::handshake::ArchivedHello, - peer_sender: XID, -) -> bool { - use std::cmp::Ordering; - - let peer_nonce: bc_components::Nonce = (&peer_hello.nonce).into(); - match peer_nonce.data().cmp(local_hello.nonce.data()) { - Ordering::Less => true, - Ordering::Greater => false, - Ordering::Equal => peer_sender.data().cmp(local_sender.data()) == Ordering::Less, - } -} diff --git a/ql-engine/src/engine/implementation/mod.rs b/ql-engine/src/engine/implementation/mod.rs deleted file mode 100644 index 4b747893..00000000 --- a/ql-engine/src/engine/implementation/mod.rs +++ /dev/null @@ -1,762 +0,0 @@ -pub mod handshake; -pub mod peer; -pub mod stream; - -use std::time::{Duration, Instant}; - -use bc_components::{SigningPublicKey, SymmetricKey, XID}; -use rkyv::access_mut; - -use crate::{ - engine::{ - replay_cache::ReplayKey, - state::{ActiveWrite, OutboundWriteKind, TimeoutKind}, - Engine, EngineEvent, HandshakeInitiator, HandshakeResponder, KeepAliveConfig, - KeepAliveState, OutboundWrite, PeerRecord, PeerSession, QlCrypto, RecentReady, - StreamConfig, Token, WriteId, - }, - wire::{ - self, - encrypted_message::{ArchivedEncryptedMessage, NONCE_SIZE}, - stream::{BodyChunk, CloseCode, CloseTarget}, - ControlMeta, QlHeader, - }, - Peer, QlError, StreamId, -}; - -impl Engine { - pub(crate) fn open_stream_inner( - &mut self, - request_head: Vec, - request_prefix: Option, - config: StreamConfig, - ) -> Result { - stream::open_stream(self, request_head, request_prefix, config) - } - - pub(crate) fn bind_peer_inner(&mut self, peer: Peer) { - peer::handle_bind_peer(self, peer); - } - - pub(crate) fn pair_inner(&mut self, crypto: &impl QlCrypto) { - peer::handle_pair_local(self, crypto); - } - - pub(crate) fn connect_inner(&mut self, crypto: &impl QlCrypto) { - handshake::handle_connect(self, crypto); - } - - pub(crate) fn unpair_inner(&mut self) { - peer::handle_unpair_local(self); - } - - pub(crate) fn write_stream_inner( - &mut self, - stream_id: StreamId, - bytes: Vec, - ) -> Result<(), QlError> { - stream::handle_outbound_data(self, stream_id, bytes) - } - - pub(crate) fn finish_stream_inner(&mut self, stream_id: StreamId) -> Result<(), QlError> { - stream::handle_outbound_finished(self, stream_id) - } - - pub(crate) fn close_stream_inner( - &mut self, - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - ) -> Result<(), QlError> { - stream::handle_close_stream(self, stream_id, target, code, payload) - } - - pub(crate) fn receive_inner(&mut self, bytes: Vec, crypto: &impl QlCrypto) { - self.handle_incoming(bytes, crypto); - } - - pub(crate) fn take_next_write_inner( - &mut self, - crypto: &impl QlCrypto, - ) -> Option { - self.take_next_control_write() - .or_else(|| stream::take_next_stream_write(self, crypto)) - } - - pub(crate) fn complete_write_inner(&mut self, write_id: WriteId, result: Result<(), QlError>) { - let Some(active) = self.state.active_writes.remove(write_id.0) else { - return; - }; - - if let Err(error) = result { - if let OutboundWriteKind::Stream(completion) = active.kind { - stream::complete_stream_write(self, completion, Err(error.clone())); - } - - if self.is_handshake_token(active.token) { - if let Some(entry) = self.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(); - self.drop_outbound(); - self.abort_streams(error); - } - - return; - } - - if let Some((session_key, recent_ready)) = self.connected_session_for_token(active.token) { - if let Some(entry) = self.peer.as_mut() { - entry.session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - recent_ready, - }; - } - self.emit_peer_status(); - self.record_activity(); - } - - if let Some(token) = active.token { - self.schedule_handshake_retry_after_write(token); - } - - if let OutboundWriteKind::Stream(completion) = active.kind { - stream::complete_stream_write(self, completion, Ok(())); - } - } - - pub(crate) fn on_timer_inner(&mut self, crypto: &impl QlCrypto) { - let now = self.state.now; - loop { - let Some(entry) = self - .state - .timeouts - .peek_mut() - .filter(|entry| entry.0.at <= now) - else { - break; - }; - let entry = std::collections::binary_heap::PeekMut::pop(entry).0; - match entry.kind { - TimeoutKind::Outbound { token } => { - self.state - .control_outbound - .retain(|message| message.token != token); - } - } - } - - stream::handle_stream_timeouts(self); - - if let Some(PeerRecord { - session: PeerSession::Connected { recent_ready, .. }, - .. - }) = self.peer.as_mut() - { - if recent_ready - .as_ref() - .is_some_and(|ready| ready.expires_at <= now) - { - *recent_ready = None; - } - } - - let handshake_due = self - .handshake_deadline() - .is_some_and(|deadline| deadline <= now); - if handshake_due { - self.fail_handshake(QlError::Timeout); - return; - } - - let handshake_retry_due = self - .handshake_retry_deadline() - .is_some_and(|deadline| deadline <= now); - if handshake_retry_due { - self.handle_handshake_retry_timeout(); - } - - let keepalive_due = self - .keep_alive_deadline() - .is_some_and(|deadline| deadline <= now); - if !keepalive_due { - return; - } - - let Some(entry) = self.peer.as_ref() else { - return; - }; - let PeerSession::Connected { keepalive, .. } = &entry.session else { - return; - }; - - if keepalive.pending { - if let Some(entry) = self.peer.as_mut() { - entry.session = PeerSession::Disconnected; - } - self.emit_peer_status(); - self.drop_outbound(); - self.abort_streams(QlError::SendFailed); - return; - } - - self.send_heartbeat_message(crypto); - if let Some(entry) = self.peer.as_mut() { - if let PeerSession::Connected { keepalive, .. } = &mut entry.session { - keepalive.pending = true; - keepalive.last_activity = Some(now); - } - } - } - - pub(crate) fn next_deadline_inner(&self) -> Option { - [ - self.state.next_deadline(), - self.streams.next_deadline(), - self.handshake_retry_deadline(), - self.handshake_deadline(), - self.keep_alive_deadline(), - ] - .into_iter() - .flatten() - .min() - } - - pub(crate) fn abort_inner(&mut self, error: QlError) { - self.abort_streams(error); - } -} - -impl Engine { - fn emit_peer_status(&mut self) { - let event = self - .peer - .as_ref() - .map(|peer| EngineEvent::PeerStatusChanged { - peer: peer.peer, - session: peer.session.clone(), - }); - if let Some(event) = event { - self.state.pending_events.push_back(event); - } - } - - fn handle_incoming(&mut self, mut bytes: Vec, crypto: &impl QlCrypto) { - let Ok(record) = access_mut::(&mut bytes) - else { - return; - }; - let record = unsafe { record.unseal_unchecked() }; - let sender: XID = (&record.header.sender).into(); - let recipient: XID = (&record.header.recipient).into(); - if recipient != self.identity.xid { - return; - } - if !matches!(&record.payload, wire::ArchivedQlPayload::Pair(_)) { - let Some(peer) = self.peer.as_ref().map(|peer| peer.peer) else { - return; - }; - if sender != peer { - return; - } - } - let Ok(header) = wire::deserialize_value(&record.header) else { - return; - }; - match &mut record.payload { - wire::ArchivedQlPayload::Handshake(message) => { - self.handle_handshake(sender, &header, message, crypto) - } - wire::ArchivedQlPayload::Stream(encrypted) => { - stream::handle_stream(self, sender, &header, encrypted) - } - wire::ArchivedQlPayload::Heartbeat(encrypted) => { - self.handle_heartbeat(&header, encrypted, crypto) - } - wire::ArchivedQlPayload::Pair(request) => { - peer::handle_pairing(self, &header, request, crypto) - } - wire::ArchivedQlPayload::Unpair(unpair_record) => { - peer::handle_unpair(self, sender, &header, unpair_record) - } - } - } - - fn handle_handshake( - &mut self, - peer: XID, - header: &QlHeader, - message: &mut wire::handshake::ArchivedHandshakeRecord, - crypto: &impl QlCrypto, - ) { - match message { - wire::handshake::ArchivedHandshakeRecord::Hello(hello) => { - handshake::handle_hello(self, peer, hello, crypto) - } - wire::handshake::ArchivedHandshakeRecord::HelloReply(reply) => { - handshake::handle_hello_reply(self, peer, reply) - } - wire::handshake::ArchivedHandshakeRecord::Confirm(confirm) => { - handshake::handle_confirm(self, peer, confirm, crypto) - } - wire::handshake::ArchivedHandshakeRecord::Ready(ready) => { - handshake::handle_ready(self, peer, header, ready) - } - } - } - - fn handle_heartbeat( - &mut self, - header: &QlHeader, - encrypted: &mut ArchivedEncryptedMessage, - crypto: &impl QlCrypto, - ) { - let (body, should_reply) = { - let Some(peer_record) = self.peer.as_ref() else { - return; - }; - let PeerSession::Connected { - session_key, - keepalive, - .. - } = &peer_record.session - else { - return; - }; - let Ok(body) = wire::heartbeat::decrypt_heartbeat(header, encrypted, session_key) - else { - return; - }; - (body, !keepalive.pending) - }; - if self.is_replayed_control(header.sender, body.meta) { - return; - } - self.record_activity(); - if should_reply { - self.send_heartbeat_message(crypto); - } - self.emit_peer_status(); - } - - fn fail_handshake(&mut self, error: QlError) { - if let Some(entry) = self.peer.as_mut() { - if matches!( - entry.session, - PeerSession::Initiator { .. } | PeerSession::Responder { .. } - ) { - entry.session = PeerSession::Disconnected; - } - } - self.emit_peer_status(); - self.drop_outbound(); - self.abort_streams(error); - } - - fn handle_handshake_retry_timeout(&mut self) { - enum RetryAction { - Resend { - token: Token, - peer: XID, - deadline: Instant, - record: wire::handshake::HandshakeRecord, - }, - Fail, - Ignore, - } - - let now = self.state.now; - let action = { - let Some(entry) = self.peer.as_mut() else { - return; - }; - let peer = entry.peer; - match &mut entry.session { - PeerSession::Initiator { - handshake_token, - hello, - deadline, - stage: - HandshakeInitiator::WaitingHelloReply { - retry_count, - retry_at, - }, - .. - } if retry_at.is_some_and(|at| at <= now) => { - let token = *handshake_token; - *retry_at = None; - if *retry_count >= self.config.max_handshake_retries { - RetryAction::Fail - } else { - *retry_count = retry_count.saturating_add(1); - RetryAction::Resend { - token, - peer, - deadline: *deadline, - record: wire::handshake::HandshakeRecord::Hello(hello.clone()), - } - } - } - PeerSession::Initiator { - handshake_token, - deadline, - stage: - HandshakeInitiator::WaitingReady { - confirm, - retry_count, - retry_at, - .. - }, - .. - } if retry_at.is_some_and(|at| at <= now) => { - let token = *handshake_token; - *retry_at = None; - if *retry_count >= self.config.max_handshake_retries { - RetryAction::Fail - } else { - *retry_count = retry_count.saturating_add(1); - RetryAction::Resend { - token, - peer, - deadline: *deadline, - record: wire::handshake::HandshakeRecord::Confirm(confirm.clone()), - } - } - } - PeerSession::Responder { - handshake_token, - reply, - deadline, - stage: - HandshakeResponder::WaitingConfirm { - retry_count, - retry_at, - .. - }, - .. - } if retry_at.is_some_and(|at| at <= now) => { - let token = *handshake_token; - *retry_at = None; - if *retry_count >= self.config.max_handshake_retries { - RetryAction::Fail - } else { - *retry_count = retry_count.saturating_add(1); - RetryAction::Resend { - token, - peer, - deadline: *deadline, - record: wire::handshake::HandshakeRecord::HelloReply(reply.clone()), - } - } - } - _ => RetryAction::Ignore, - } - }; - - match action { - RetryAction::Resend { - token, - peer, - deadline, - record, - } => { - if self.handshake_write_pending(token) { - return; - } - handshake::enqueue_handshake_record(self, token, deadline, peer, record); - } - RetryAction::Fail => self.fail_handshake(QlError::Timeout), - RetryAction::Ignore => {} - } - } - - fn abort_streams(&mut self, error: QlError) { - stream::abort_streams(self, error); - } - - fn next_control_meta(&self, valid_for: Duration) -> ControlMeta { - ControlMeta { - packet_id: self.state.next_packet_id(), - valid_until: wire::now_secs() + valid_for.as_secs(), - } - } - - fn keep_alive_deadline(&self) -> Option { - let config = self.keep_alive_config()?; - let entry = self.peer.as_ref()?; - let PeerSession::Connected { keepalive, .. } = &entry.session else { - return None; - }; - let base = keepalive.last_activity?; - Some( - base + if keepalive.pending { - config.timeout - } else { - config.interval - }, - ) - } - - fn handshake_deadline(&self) -> Option { - let entry = self.peer.as_ref()?; - match &entry.session { - PeerSession::Initiator { deadline, .. } | PeerSession::Responder { deadline, .. } => { - Some(*deadline) - } - PeerSession::Disconnected | PeerSession::Connected { .. } => None, - } - } - - fn handshake_retry_deadline(&self) -> Option { - let entry = self.peer.as_ref()?; - match &entry.session { - PeerSession::Initiator { - stage: HandshakeInitiator::WaitingHelloReply { retry_at, .. }, - .. - } - | PeerSession::Initiator { - stage: HandshakeInitiator::WaitingReady { retry_at, .. }, - .. - } - | PeerSession::Responder { - stage: HandshakeResponder::WaitingConfirm { retry_at, .. }, - .. - } => *retry_at, - PeerSession::Disconnected - | PeerSession::Responder { - stage: HandshakeResponder::SendingReady { .. }, - .. - } - | PeerSession::Connected { .. } => None, - } - } - - fn is_replayed_control(&mut self, peer: XID, meta: ControlMeta) -> bool { - self.state - .replay_cache - .check_and_store_valid_until(ReplayKey::new(peer, meta.packet_id), meta.valid_until) - } - - fn is_handshake_token(&self, token: Option) -> bool { - let Some(token) = token else { - return false; - }; - matches!(self.peer.as_ref().map(|entry| &entry.session), - Some(PeerSession::Initiator { handshake_token, .. }) if *handshake_token == token) - || matches!(self.peer.as_ref().map(|entry| &entry.session), - Some(PeerSession::Responder { handshake_token, .. }) if *handshake_token == token) - } - - fn connected_session_for_token( - &self, - token: Option, - ) -> Option<(SymmetricKey, Option)> { - let token = token?; - self.peer.as_ref().and_then(|entry| match &entry.session { - PeerSession::Responder { - hello, - reply, - deadline, - handshake_token, - stage: HandshakeResponder::SendingReady { session_key, ready }, - } if *handshake_token == token => Some(( - session_key.clone(), - Some(RecentReady { - hello: hello.clone(), - reply: reply.clone(), - ready: ready.clone(), - expires_at: *deadline, - }), - )), - _ => None, - }) - } - - fn handshake_write_pending(&self, token: Token) -> bool { - self.state - .active_writes - .values() - .any(|active| active.token == Some(token)) - || self - .state - .control_outbound - .iter() - .any(|message| message.token == token) - } - - fn clear_handshake_retry_at(&mut self, token: Token) { - let Some(entry) = self.peer.as_mut() else { - return; - }; - match &mut entry.session { - PeerSession::Initiator { - handshake_token, - stage: HandshakeInitiator::WaitingHelloReply { retry_at, .. }, - .. - } if *handshake_token == token => *retry_at = None, - PeerSession::Initiator { - handshake_token, - stage: HandshakeInitiator::WaitingReady { retry_at, .. }, - .. - } if *handshake_token == token => *retry_at = None, - PeerSession::Responder { - handshake_token, - stage: HandshakeResponder::WaitingConfirm { retry_at, .. }, - .. - } if *handshake_token == token => *retry_at = None, - _ => {} - } - } - - fn schedule_handshake_retry_after_write(&mut self, token: Token) { - if self.config.handshake_retry_interval.is_zero() || self.config.max_handshake_retries == 0 - { - return; - } - let now = self.state.now; - let retry_at = now + self.config.handshake_retry_interval; - let Some(entry) = self.peer.as_mut() else { - return; - }; - match &mut entry.session { - PeerSession::Initiator { - handshake_token, - stage: - HandshakeInitiator::WaitingHelloReply { - retry_at: stage_retry_at, - .. - }, - .. - } if *handshake_token == token => { - *stage_retry_at = Some(retry_at); - } - PeerSession::Initiator { - handshake_token, - stage: - HandshakeInitiator::WaitingReady { - retry_at: stage_retry_at, - .. - }, - .. - } if *handshake_token == token => { - *stage_retry_at = Some(retry_at); - } - PeerSession::Responder { - handshake_token, - stage: - HandshakeResponder::WaitingConfirm { - retry_at: stage_retry_at, - .. - }, - .. - } if *handshake_token == token => { - *stage_retry_at = Some(retry_at); - } - _ => {} - } - } - - fn peer_session(&self) -> Option<(XID, SymmetricKey)> { - self.peer.as_ref().and_then(|peer| { - peer.session - .session_key() - .map(|key| (peer.peer, key.clone())) - }) - } - - // todo: this is called in too many places - fn sync_stream_namespace(&mut self) { - use crate::stream::StreamNamespace; - let namespace = self - .peer - .as_ref() - .map(|peer| StreamNamespace::for_local(self.identity.xid, peer.peer)) - .unwrap_or(crate::stream::StreamNamespace::Low); - self.streams.set_local_namespace(namespace); - } - - fn issue_write( - &mut self, - kind: OutboundWriteKind, - token: Option, - bytes: Vec, - ) -> OutboundWrite { - let id = WriteId(self.state.active_writes.insert(ActiveWrite { token, kind })); - OutboundWrite { id, bytes } - } - - fn take_next_control_write(&mut self) -> Option { - while let Some(message) = self.state.control_outbound.pop_front() { - return Some(self.issue_write( - OutboundWriteKind::Control, - Some(message.token), - message.bytes, - )); - } - None - } - - fn send_heartbeat_message(&mut self, crypto: &impl QlCrypto) { - let Some(peer) = self.peer.as_ref().map(|peer| peer.peer) else { - return; - }; - let now = self.state.now; - let meta = self.next_control_meta(self.config.packet_expiration); - let token = self.state.next_token(); - let deadline = now + self.config.packet_expiration; - let message = { - let Some(peer_record) = self.peer.as_ref() else { - return; - }; - let PeerSession::Connected { session_key, .. } = &peer_record.session else { - return; - }; - wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: self.identity.xid, - recipient: peer, - }, - session_key, - wire::heartbeat::HeartbeatBody { meta }, - encrypted_message_nonce(crypto), - ) - }; - self.state.enqueue_handshake_message( - &self.config, - token, - deadline, - wire::encode_record(&message), - ); - } - - fn keep_alive_config(&self) -> Option { - self.config - .keep_alive - .filter(|config| !config.interval.is_zero() && !config.timeout.is_zero()) - } - - fn record_activity(&mut self) { - let now = self.state.now; - if let Some(PeerRecord { - session: PeerSession::Connected { keepalive, .. }, - .. - }) = self.peer.as_mut() - { - keepalive.last_activity = Some(now); - keepalive.pending = false; - } - } - - fn drop_outbound(&mut self) { - self.state.control_outbound.clear(); - self.state.active_writes.clear(); - } -} - -fn encrypted_message_nonce(crypto: &impl QlCrypto) -> [u8; NONCE_SIZE] { - let mut nonce = [0u8; NONCE_SIZE]; - crypto.fill_random_bytes(&mut nonce); - nonce -} diff --git a/ql-engine/src/engine/implementation/peer.rs b/ql-engine/src/engine/implementation/peer.rs deleted file mode 100644 index 05f7d76f..00000000 --- a/ql-engine/src/engine/implementation/peer.rs +++ /dev/null @@ -1,160 +0,0 @@ -use super::*; - -pub fn handle_bind_peer(engine: &mut Engine, peer: Peer) { - if let Some(peer) = engine.peer.as_ref().map(|existing| existing.peer) { - engine - .state - .pending_events - .push_back(EngineEvent::PeerStatusChanged { - peer, - session: PeerSession::Disconnected, - }); - } - bind_peer_record(engine, peer); -} - -pub fn handle_pair_local(engine: &mut Engine, crypto: &impl QlCrypto) { - let now = engine.state.now; - let Some(peer) = engine.peer.as_ref() else { - return; - }; - let meta = engine.next_control_meta(engine.config.packet_expiration); - let Ok(record) = wire::pair::build_pair_request( - &engine.identity, - crypto, - peer.peer, - &peer.encapsulation_key, - meta, - ) else { - return; - }; - let token = engine.state.next_token(); - engine.state.enqueue_handshake_message( - &engine.config, - token, - now + engine.config.packet_expiration, - wire::encode_record(&record), - ); -} - -pub fn handle_unpair_local(engine: &mut Engine) { - let now = engine.state.now; - let Some(peer) = engine.peer.as_ref().map(|peer| peer.peer) else { - return; - }; - let meta = engine.next_control_meta(engine.config.packet_expiration); - let record = wire::unpair::build_unpair_record( - &engine.identity, - QlHeader { - sender: engine.identity.xid, - recipient: peer, - }, - meta, - ); - unpair_peer(engine); - let token = engine.state.next_token(); - engine.state.enqueue_handshake_message( - &engine.config, - token, - now + engine.config.packet_expiration, - wire::encode_record(&record), - ); -} - -pub fn handle_pairing( - engine: &mut Engine, - header: &QlHeader, - request: &mut wire::pair::ArchivedPairRequestRecord, - crypto: &impl QlCrypto, -) { - let payload = match wire::pair::decrypt_pair_request(&engine.identity, header, request) { - Ok(payload) => payload, - Err(_) => return, - }; - let peer = XID::new(SigningPublicKey::MLDSA(payload.signing_pub_key.clone())); - if engine.is_replayed_control(peer, payload.meta) { - return; - } - if let Some(existing) = engine.peer.as_ref() { - if existing.peer != peer - || existing.signing_key != payload.signing_pub_key - || existing.encapsulation_key != payload.encapsulation_pub_key - { - return; - } - } else { - bind_peer_record( - engine, - Peer { - peer, - signing_key: payload.signing_pub_key, - encapsulation_key: payload.encapsulation_pub_key, - }, - ); - } - handshake::handle_connect(engine, crypto); -} - -pub fn handle_unpair( - engine: &mut Engine, - peer: XID, - header: &QlHeader, - record: &wire::unpair::ArchivedUnpairRecord, -) { - { - let Some(peer_record) = engine.peer.as_ref() else { - return; - }; - if wire::unpair::verify_unpair_record(header, record, &peer_record.signing_key).is_err() { - return; - } - } - let meta: ControlMeta = (&record.meta).into(); - if engine.is_replayed_control(peer, meta) { - return; - } - unpair_peer(engine); -} - -fn bind_peer_record(engine: &mut Engine, peer: Peer) { - reset_runtime(engine, QlError::Cancelled); - engine.peer = Some(PeerRecord::new( - peer.peer, - peer.signing_key, - peer.encapsulation_key, - )); - engine.emit_peer_status(); - if let Some(peer) = engine.peer.as_ref().map(PeerRecord::snapshot) { - engine - .state - .pending_events - .push_back(EngineEvent::PersistPeer(peer)); - } -} - -fn reset_runtime(engine: &mut Engine, error: QlError) { - engine.abort_streams(error); - engine.state.control_outbound.clear(); - engine.state.active_writes.clear(); - engine.state.timeouts.clear(); -} - -fn unpair_peer(engine: &mut Engine) { - let Some(peer) = engine.peer.as_ref().map(|peer| peer.peer) else { - return; - }; - engine.drop_outbound(); - engine.abort_streams(QlError::SendFailed); - engine.peer = None; - engine - .state - .pending_events - .push_back(EngineEvent::PeerStatusChanged { - peer, - session: PeerSession::Disconnected, - }); - engine - .state - .pending_events - .push_back(EngineEvent::ClearPeer); -} diff --git a/ql-engine/src/engine/implementation/stream.rs b/ql-engine/src/engine/implementation/stream.rs deleted file mode 100644 index 02dd25a7..00000000 --- a/ql-engine/src/engine/implementation/stream.rs +++ /dev/null @@ -1,337 +0,0 @@ -use super::*; -use crate::{ - engine::{state::OutboundWriteKind, Engine, EngineEvent, EngineState, QlCrypto}, - stream::{StreamCloseEvent, StreamCloseKind, StreamError, StreamEventSink, WriteError}, - wire::stream::*, -}; - -struct EngineStreamSink<'a> { - state: &'a mut EngineState, -} - -impl EngineStreamSink<'_> { - fn clear_active_writes_for_stream(&mut self, stream_id: StreamId) { - self.state - .active_writes - .retain(|_, active| match active.kind { - OutboundWriteKind::Control => true, - OutboundWriteKind::Stream(completion) => completion.stream_id() != stream_id, - }); - } - - fn emit_remote_close(&mut self, event: StreamCloseEvent) { - let error = QlError::StreamClosed { - target: event.frame.target, - code: event.frame.code, - payload: event.frame.payload, - }; - - match event.role { - crate::stream::StreamLocalRole::Initiator => { - if matches!(event.frame.target, CloseTarget::Request | CloseTarget::Both) { - self.state - .pending_events - .push_back(EngineEvent::OutboundFailed { - stream_id: event.frame.stream_id, - error: error.clone(), - }); - } - if matches!( - event.frame.target, - CloseTarget::Response | CloseTarget::Both - ) { - self.state - .pending_events - .push_back(EngineEvent::InboundFailed { - stream_id: event.frame.stream_id, - error, - }); - } - } - crate::stream::StreamLocalRole::Responder => { - if matches!(event.frame.target, CloseTarget::Request | CloseTarget::Both) { - self.state - .pending_events - .push_back(EngineEvent::InboundFailed { - stream_id: event.frame.stream_id, - error: error.clone(), - }); - } - if matches!( - event.frame.target, - CloseTarget::Response | CloseTarget::Both - ) { - self.state - .pending_events - .push_back(EngineEvent::OutboundFailed { - stream_id: event.frame.stream_id, - error, - }); - } - } - } - } - - fn emit_acked_close(&mut self, event: StreamCloseEvent) { - let affects_outbound = match event.role { - crate::stream::StreamLocalRole::Initiator => { - matches!(event.frame.target, CloseTarget::Request | CloseTarget::Both) - } - crate::stream::StreamLocalRole::Responder => { - matches!( - event.frame.target, - CloseTarget::Response | CloseTarget::Both - ) - } - }; - if !affects_outbound { - return; - } - - self.state - .pending_events - .push_back(EngineEvent::OutboundFailed { - stream_id: event.frame.stream_id, - error: QlError::StreamClosed { - target: event.frame.target, - code: event.frame.code, - payload: event.frame.payload, - }, - }); - } -} - -impl StreamEventSink for EngineStreamSink<'_> { - fn opened( - &mut self, - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, - ) { - self.state - .pending_events - .push_back(EngineEvent::InboundStreamOpened { - stream_id, - request_head, - request_prefix, - }); - } - - fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { - self.state - .pending_events - .push_back(EngineEvent::InboundData { stream_id, bytes }); - } - - fn inbound_finished(&mut self, stream_id: StreamId) { - self.state - .pending_events - .push_back(EngineEvent::InboundFinished { stream_id }); - } - - fn inbound_failed(&mut self, stream_id: StreamId, error: StreamError) { - self.state - .pending_events - .push_back(EngineEvent::InboundFailed { - stream_id, - error: stream_error(error), - }); - } - - fn close(&mut self, event: StreamCloseEvent) { - match event.kind { - StreamCloseKind::Acked => self.emit_acked_close(event), - StreamCloseKind::Remote => self.emit_remote_close(event), - } - } - - fn outbound_closed(&mut self, stream_id: StreamId) { - self.state - .pending_events - .push_back(EngineEvent::OutboundClosed { stream_id }); - } - - fn outbound_failed(&mut self, stream_id: StreamId, error: StreamError) { - self.state - .pending_events - .push_back(EngineEvent::OutboundFailed { - stream_id, - error: stream_error(error), - }); - } - - fn reaped(&mut self, stream_id: StreamId) { - self.clear_active_writes_for_stream(stream_id); - self.state - .pending_events - .push_back(EngineEvent::StreamReaped { stream_id }); - } -} - -pub fn open_stream( - engine: &mut Engine, - request_head: Vec, - request_prefix: Option, - _config: StreamConfig, -) -> Result { - let Some(entry) = engine.peer.as_ref() else { - return Err(QlError::NoPeerBound); - }; - if !entry.session.is_connected() { - return Err(QlError::MissingSession); - } - - engine.sync_stream_namespace(); - Ok(engine.streams.open_stream(request_head, request_prefix)) -} - -pub fn handle_close_stream( - engine: &mut Engine, - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, -) -> Result<(), QlError> { - engine - .streams - .close_stream(stream_id, target, code, payload) - .map_err(stream_error) -} - -pub fn handle_outbound_data( - engine: &mut Engine, - stream_id: StreamId, - bytes: Vec, -) -> Result<(), QlError> { - engine - .streams - .write_stream(stream_id, bytes) - .map_err(stream_error) -} - -pub fn handle_outbound_finished(engine: &mut Engine, stream_id: StreamId) -> Result<(), QlError> { - engine - .streams - .finish_stream(stream_id) - .map_err(stream_error) -} - -pub fn handle_stream( - engine: &mut Engine, - _peer: XID, - header: &QlHeader, - encrypted: &mut ArchivedEncryptedMessage, -) { - let now = engine.state.now; - let body = { - let Some(peer_record) = engine.peer.as_ref() else { - return; - }; - let PeerSession::Connected { session_key, .. } = &peer_record.session else { - return; - }; - match decrypt_stream(header, encrypted, session_key) { - Ok(body) => body, - Err(_) => return, - } - }; - - engine.record_activity(); - engine.sync_stream_namespace(); - - let mut sink = EngineStreamSink { - state: &mut engine.state, - }; - engine.streams.receive(now, body, &mut sink); -} - -pub fn take_next_stream_write( - engine: &mut Engine, - crypto: &impl QlCrypto, -) -> Option { - let (recipient, session_key) = engine.peer_session()?; - engine.sync_stream_namespace(); - - let outbound = engine.streams.next_outbound( - engine.state.now, - wire::now_secs().saturating_add(engine.config.packet_expiration.as_secs()), - )?; - let record = encrypt_stream( - QlHeader { - sender: engine.identity.xid, - recipient, - }, - &session_key, - &outbound.body, - encrypted_message_nonce(crypto), - ); - - Some(engine.issue_write( - OutboundWriteKind::Stream(outbound.completion), - None, - wire::encode_record(&record), - )) -} - -pub fn complete_stream_write( - engine: &mut Engine, - completion: crate::stream::OutboundCompletion, - result: Result<(), QlError>, -) { - let now = engine.state.now; - let mut sink = EngineStreamSink { - state: &mut engine.state, - }; - engine.streams.complete_outbound( - now, - completion, - result.map_err(|_| WriteError::SendFailed), - &mut sink, - ); -} - -pub fn handle_stream_timeouts(engine: &mut Engine) { - let now = engine.state.now; - if !engine - .streams - .next_deadline() - .is_some_and(|deadline| deadline <= now) - { - return; - } - - let mut sink = EngineStreamSink { - state: &mut engine.state, - }; - engine.streams.on_timer(now, &mut sink); -} - -pub fn abort_streams(engine: &mut Engine, error: QlError) { - let mut sink = EngineStreamSink { - state: &mut engine.state, - }; - engine.streams.abort(stream_error_inverse(error), &mut sink); -} - -fn stream_error(error: StreamError) -> QlError { - match error { - StreamError::MissingStream | StreamError::NotWritable => QlError::StreamProtocol, - StreamError::SendFailed => QlError::SendFailed, - StreamError::Timeout => QlError::Timeout, - StreamError::Cancelled => QlError::Cancelled, - StreamError::StreamProtocol => QlError::StreamProtocol, - } -} - -fn stream_error_inverse(error: QlError) -> StreamError { - match error { - QlError::SendFailed => StreamError::SendFailed, - QlError::Timeout => StreamError::Timeout, - QlError::Cancelled => StreamError::Cancelled, - QlError::StreamProtocol | QlError::StreamClosed { .. } => StreamError::StreamProtocol, - QlError::NoPeerBound - | QlError::MissingSession - | QlError::InvalidPayload - | QlError::InvalidSignature => StreamError::Cancelled, - } -} diff --git a/ql-engine/src/engine/mod.rs b/ql-engine/src/engine/mod.rs deleted file mode 100644 index ed4f964a..00000000 --- a/ql-engine/src/engine/mod.rs +++ /dev/null @@ -1,231 +0,0 @@ -mod implementation; -pub mod replay_cache; -mod state; -#[cfg(test)] -mod tests; - -use std::time::{Duration, Instant}; - -use bc_components::XID; -pub use state::{ - Engine, EngineState, HandshakeInitiator, HandshakeResponder, KeepAliveState, OutboundWrite, - PeerRecord, PeerSession, RecentReady, Token, WriteId, -}; - -use crate::{ - identity::QlIdentity, - stream, - wire::stream::{BodyChunk, CloseCode, CloseTarget}, - Peer, QlError, StreamId, -}; - -pub trait QlCrypto { - fn fill_random_bytes(&self, data: &mut [u8]); -} - -#[derive(Debug, Clone)] -pub enum EngineEvent { - PeerStatusChanged { - peer: XID, - session: PeerSession, - }, - PersistPeer(Peer), - ClearPeer, - - InboundStreamOpened { - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, - }, - InboundData { - stream_id: StreamId, - bytes: Vec, - }, - InboundFinished { - stream_id: StreamId, - }, - InboundFailed { - stream_id: StreamId, - error: QlError, - }, - - OutboundClosed { - stream_id: StreamId, - }, - OutboundFailed { - stream_id: StreamId, - error: QlError, - }, - - StreamReaped { - stream_id: StreamId, - }, -} - -#[derive(Debug, Clone, Copy)] -pub struct KeepAliveConfig { - pub interval: Duration, - pub timeout: Duration, -} - -#[derive(Debug, Clone, Copy, Default)] -pub struct StreamConfig {} - -#[derive(Debug, Clone, Copy)] -pub struct EngineConfig { - pub handshake_timeout: Duration, - pub handshake_retry_interval: Duration, - pub max_handshake_retries: u8, - pub packet_expiration: Duration, - pub stream_ack_delay: Duration, - pub stream_ack_timeout: Duration, - pub stream_fast_retransmit_threshold: u8, - pub stream_retry_limit: u8, - pub keep_alive: Option, -} - -impl Default for EngineConfig { - fn default() -> Self { - Self { - handshake_timeout: Duration::from_secs(5), - handshake_retry_interval: Duration::from_millis(750), - max_handshake_retries: 3, - packet_expiration: Duration::from_secs(30), - stream_ack_delay: Duration::from_millis(5), - stream_ack_timeout: Duration::from_millis(150), - stream_fast_retransmit_threshold: 2, - stream_retry_limit: 5, - keep_alive: None, - } - } -} - -impl Engine { - pub fn new(config: EngineConfig, identity: QlIdentity, peer: Option) -> Self { - let local_namespace = peer - .as_ref() - .map(|peer| stream::StreamNamespace::for_local(identity.xid, peer.peer)) - .map(|namespace| match namespace { - stream::StreamNamespace::Low => crate::stream::StreamNamespace::Low, - stream::StreamNamespace::High => crate::stream::StreamNamespace::High, - }) - .unwrap_or(crate::stream::StreamNamespace::Low); - Self { - config: config, - identity, - peer: peer - .map(|peer| PeerRecord::new(peer.peer, peer.signing_key, peer.encapsulation_key)), - state: EngineState::new(), - streams: stream::StreamFsm::new(stream::StreamFsmConfig { - local_namespace, - ack_delay: config.stream_ack_delay, - ack_timeout: config.stream_ack_timeout, - fast_retransmit_threshold: config.stream_fast_retransmit_threshold, - retry_limit: config.stream_retry_limit, - }), - } - } - - pub fn open_stream( - &mut self, - now: Instant, - request_head: Vec, - request_prefix: Option, - config: StreamConfig, - ) -> Result { - self.state.now = now; - self.open_stream_inner(request_head, request_prefix, config) - } - - pub fn bind_peer(&mut self, now: Instant, peer: Peer) { - self.state.now = now; - self.bind_peer_inner(peer); - } - - pub fn pair(&mut self, now: Instant, crypto: &impl QlCrypto) { - self.state.now = now; - self.pair_inner(crypto); - } - - pub fn connect(&mut self, now: Instant, crypto: &impl QlCrypto) { - self.state.now = now; - self.connect_inner(crypto); - } - - pub fn unpair(&mut self, now: Instant) { - self.state.now = now; - self.unpair_inner(); - } - - pub fn take_next_write( - &mut self, - now: Instant, - crypto: &impl QlCrypto, - ) -> Option { - self.state.now = now; - self.take_next_write_inner(crypto) - } - - pub fn complete_write(&mut self, now: Instant, write_id: WriteId, result: Result<(), QlError>) { - self.state.now = now; - self.complete_write_inner(write_id, result); - } - - pub fn write_stream( - &mut self, - now: Instant, - stream_id: StreamId, - bytes: Vec, - ) -> Result<(), QlError> { - self.state.now = now; - self.write_stream_inner(stream_id, bytes) - } - - pub fn finish_stream(&mut self, now: Instant, stream_id: StreamId) -> Result<(), QlError> { - self.state.now = now; - self.finish_stream_inner(stream_id) - } - - pub fn close_stream( - &mut self, - now: Instant, - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - ) -> Result<(), QlError> { - self.state.now = now; - self.close_stream_inner(stream_id, target, code, payload) - } - - pub fn receive(&mut self, now: Instant, bytes: Vec, crypto: &impl QlCrypto) { - self.state.now = now; - self.receive_inner(bytes, crypto); - } - - pub fn on_timer(&mut self, now: Instant, crypto: &impl QlCrypto) { - self.state.now = now; - self.on_timer_inner(crypto); - } - - pub fn next_deadline(&self) -> Option { - self.next_deadline_inner() - } - - pub fn take_next_event(&mut self) -> Option { - self.state.pending_events.pop_front() - } - - pub fn has_pending_events(&self) -> bool { - !self.state.pending_events.is_empty() - } - - pub fn drain_events(&mut self) -> std::collections::vec_deque::Drain<'_, EngineEvent> { - self.state.pending_events.drain(..) - } - - pub fn abort(&mut self, now: Instant, error: QlError) { - self.state.now = now; - self.abort_inner(error); - } -} diff --git a/ql-engine/src/engine/replay_cache.rs b/ql-engine/src/engine/replay_cache.rs deleted file mode 100644 index 8b5d5dc3..00000000 --- a/ql-engine/src/engine/replay_cache.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::{ - cmp::Reverse, - collections::{binary_heap::PeekMut, BinaryHeap, HashSet}, - time::{SystemTime, UNIX_EPOCH}, -}; - -use bc_components::XID; - -use crate::PacketId; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ReplayKey { - /// unfortunately we need this in the key, to avoid replay attacks of pair/unpair. - pub peer: XID, - pub packet_id: PacketId, -} - -impl ReplayKey { - pub const fn new(peer: XID, packet_id: PacketId) -> Self { - Self { peer, packet_id } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct ExpiryEntry { - expires_at: u64, - key: ReplayKey, -} - -impl Ord for ExpiryEntry { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.expires_at - .cmp(&other.expires_at) - .then_with(|| self.key.cmp(&other.key)) - } -} - -impl PartialOrd for ExpiryEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -#[derive(Debug, Default)] -pub struct ReplayCache { - entries: HashSet, - expirations: BinaryHeap>, -} - -impl ReplayCache { - pub fn new() -> Self { - Self { - entries: HashSet::new(), - expirations: BinaryHeap::new(), - } - } - - pub fn len(&self) -> usize { - self.entries.len() - } - - pub fn is_empty(&self) -> bool { - self.entries.is_empty() - } - - pub fn add(&mut self, key: ReplayKey, expires_at: u64) { - if self.entries.insert(key) { - self.expirations - .push(Reverse(ExpiryEntry { expires_at, key })); - } - } - - pub fn check_and_store(&mut self, key: ReplayKey, expires_at: u64) -> bool { - let now_secs = now_secs(); - self.check_and_store_at(key, expires_at, now_secs) - } - - pub fn check_and_store_valid_until(&mut self, key: ReplayKey, valid_until: u64) -> bool { - let now_secs = now_secs(); - self.check_and_store_at(key, valid_until, now_secs) - } - - pub fn purge_expired(&mut self) { - let now_secs = now_secs(); - self.purge_expired_at(now_secs); - } - - pub fn clear_peer(&mut self, peer: XID) { - self.entries.retain(|entry| entry.peer != peer); - self.expirations.retain(|entry| entry.0.key.peer != peer); - } - - fn check_and_store_at(&mut self, key: ReplayKey, expires_at: u64, now_secs: u64) -> bool { - self.purge_expired_at(now_secs); - if self.entries.contains(&key) { - return true; - } - self.entries.insert(key); - self.expirations - .push(Reverse(ExpiryEntry { expires_at, key })); - false - } - - fn purge_expired_at(&mut self, now_secs: u64) { - while let Some(entry) = self.expirations.peek_mut() { - if entry.0.expires_at > now_secs { - break; - } - let entry = PeekMut::pop(entry).0; - self.entries.remove(&entry.key); - } - } -} - -fn now_secs() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn peer_with_byte(byte: u8) -> XID { - XID::from_data([byte; XID::XID_SIZE]) - } - - #[test] - fn check_and_store_detects_replay() { - let mut cache = ReplayCache::new(); - let peer = peer_with_byte(1); - let key = ReplayKey::new(peer, PacketId(1)); - let now_secs = 100; - let expires_at = 110; - - assert!(!cache.check_and_store_at(key, expires_at, now_secs)); - assert!(cache.check_and_store_at(key, expires_at, now_secs)); - } - - #[test] - fn purge_expired_removes_old_entries() { - let mut cache = ReplayCache::new(); - let now_secs = 100; - let expired_at = 99; - let future_at = 110; - - let key_old = ReplayKey::new(peer_with_byte(2), PacketId(2)); - let key_new = ReplayKey::new(peer_with_byte(3), PacketId(3)); - - cache.add(key_old, expired_at); - cache.add(key_new, future_at); - - cache.purge_expired_at(now_secs); - assert_eq!(cache.len(), 1); - assert!(!cache.check_and_store_at(key_old, future_at, now_secs)); - } - - #[test] - fn clear_peer_removes_peer_entries() { - let mut cache = ReplayCache::new(); - let now_secs = 100; - let expires_at = 110; - - let peer_a = peer_with_byte(4); - let peer_b = peer_with_byte(5); - let key_a = ReplayKey::new(peer_a, PacketId(4)); - let key_b = ReplayKey::new(peer_b, PacketId(5)); - - cache.add(key_a, expires_at); - cache.add(key_b, expires_at); - - cache.clear_peer(peer_a); - assert_eq!(cache.len(), 1); - assert!(!cache.check_and_store_at(key_a, expires_at, now_secs)); - } -} diff --git a/ql-engine/src/engine/state.rs b/ql-engine/src/engine/state.rs deleted file mode 100644 index 481d4f06..00000000 --- a/ql-engine/src/engine/state.rs +++ /dev/null @@ -1,270 +0,0 @@ -use std::{ - cell::Cell, - cmp::Reverse, - collections::{BinaryHeap, VecDeque}, - time::Instant, -}; - -use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey, XID}; - -use super::{replay_cache::ReplayCache, EngineConfig, EngineEvent}; -use crate::{ - arena::{ArenaKey, GenerationalArena}, - identity::QlIdentity, - stream::{self, StreamFsm}, - wire::handshake::{Confirm, Hello, HelloReply, Ready, ResponderSecrets}, - PacketId, Peer, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Token(pub u64); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct WriteId(pub(crate) ArenaKey); - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum OutboundWriteKind { - Control, - Stream(stream::OutboundCompletion), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct OutboundWrite { - pub id: WriteId, - pub bytes: Vec, -} - -#[derive(Debug)] -pub struct ControlWrite { - pub token: Token, - pub bytes: Vec, -} - -#[derive(Debug, Clone, Copy)] -pub struct ActiveWrite { - pub token: Option, - pub kind: OutboundWriteKind, -} - -#[derive(Debug, Clone)] -pub struct KeepAliveState { - pub pending: bool, - pub last_activity: Option, -} - -impl Default for KeepAliveState { - fn default() -> Self { - Self { - pending: false, - last_activity: None, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum HandshakeInitiator { - WaitingHelloReply { - retry_count: u8, - retry_at: Option, - }, - WaitingReady { - reply: HelloReply, - confirm: Confirm, - retry_count: u8, - retry_at: Option, - }, -} - -#[derive(Debug, Clone)] -pub enum HandshakeResponder { - WaitingConfirm { - secrets: ResponderSecrets, - retry_count: u8, - retry_at: Option, - }, - SendingReady { - session_key: SymmetricKey, - ready: Ready, - }, -} - -#[derive(Debug, Clone)] -pub struct RecentReady { - pub hello: Hello, - pub reply: HelloReply, - pub ready: Ready, - pub expires_at: Instant, -} - -#[derive(Debug, Clone)] -pub enum PeerSession { - Disconnected, - Initiator { - handshake_token: Token, - hello: Hello, - session_key: SymmetricKey, - deadline: Instant, - stage: HandshakeInitiator, - }, - Responder { - handshake_token: Token, - hello: Hello, - reply: HelloReply, - deadline: Instant, - stage: HandshakeResponder, - }, - Connected { - session_key: SymmetricKey, - keepalive: KeepAliveState, - recent_ready: Option, - }, -} - -impl PeerSession { - pub fn is_connected(&self) -> bool { - matches!(self, Self::Connected { .. }) - } - - pub fn session_key(&self) -> Option<&SymmetricKey> { - match self { - Self::Connected { session_key, .. } => Some(session_key), - _ => None, - } - } -} - -#[derive(Debug, Clone)] -pub struct PeerRecord { - pub peer: XID, - pub signing_key: MLDSAPublicKey, - pub encapsulation_key: MLKEMPublicKey, - pub session: PeerSession, -} - -impl PeerRecord { - pub fn new(peer: XID, signing_key: MLDSAPublicKey, encapsulation_key: MLKEMPublicKey) -> Self { - Self { - peer, - signing_key, - encapsulation_key, - session: PeerSession::Disconnected, - } - } - - pub fn snapshot(&self) -> Peer { - Peer { - peer: self.peer, - signing_key: self.signing_key.clone(), - encapsulation_key: self.encapsulation_key.clone(), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TimeoutKind { - Outbound { token: Token }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct TimeoutEntry { - pub at: Instant, - pub kind: TimeoutKind, -} - -impl Ord for TimeoutEntry { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.at.cmp(&other.at) - } -} - -impl PartialOrd for TimeoutEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -pub struct Engine { - pub config: EngineConfig, - pub identity: QlIdentity, - pub peer: Option, - pub state: EngineState, - pub streams: StreamFsm, -} - -pub struct EngineState { - pub replay_cache: ReplayCache, - - pub next_token: Cell, - pub next_packet_id: Cell, - pub pending_events: VecDeque, - pub control_outbound: VecDeque, - pub active_writes: GenerationalArena, - pub timeouts: BinaryHeap>, - pub now: Instant, -} - -impl EngineState { - pub fn new() -> Self { - Self { - replay_cache: ReplayCache::new(), - next_token: Cell::new(1), - next_packet_id: Cell::new(1), - pending_events: VecDeque::new(), - control_outbound: VecDeque::new(), - active_writes: GenerationalArena::new(), - timeouts: BinaryHeap::new(), - now: Instant::now(), - } - } - - pub fn next_deadline(&self) -> Option { - self.timeouts.peek().map(|entry| entry.0.at) - } - - pub fn next_token(&self) -> Token { - let token = self.next_token.get(); - self.next_token.set(token.wrapping_add(1)); - Token(token) - } - - pub fn next_packet_id(&self) -> PacketId { - let id = self.next_packet_id.get(); - self.next_packet_id.set(id.wrapping_add(1)); - PacketId(id) - } - - pub fn enqueue_handshake_message( - &mut self, - _config: &EngineConfig, - token: Token, - deadline: Instant, - bytes: Vec, - ) { - self.control_outbound - .push_back(ControlWrite { token, bytes }); - self.timeouts.push(Reverse(TimeoutEntry { - at: deadline, - kind: TimeoutKind::Outbound { token }, - })); - } - - pub fn enqueue_control( - &mut self, - config: &EngineConfig, - priority: bool, - bytes: Vec, - ) -> Token { - let token = self.next_token(); - let message = ControlWrite { token, bytes }; - if priority { - self.control_outbound.push_front(message); - } else { - self.control_outbound.push_back(message); - } - self.timeouts.push(Reverse(TimeoutEntry { - at: self.now + config.packet_expiration, - kind: TimeoutKind::Outbound { token }, - })); - token - } -} diff --git a/ql-engine/src/engine/tests/handshake.rs b/ql-engine/src/engine/tests/handshake.rs deleted file mode 100644 index 399c5b37..00000000 --- a/ql-engine/src/engine/tests/handshake.rs +++ /dev/null @@ -1,827 +0,0 @@ -use super::*; - -fn handshake_bytes( - sender: XID, - recipient: XID, - record: wire::handshake::HandshakeRecord, -) -> Vec { - wire::encode_record(&QlRecord { - header: QlHeader { sender, recipient }, - payload: QlPayload::Handshake(record), - }) -} - -fn build_reply( - initiator_identity: &QlIdentity, - responder_identity: &QlIdentity, - responder_crypto: &TestCrypto, - hello: &wire::handshake::Hello, - packet_id: u32, -) -> wire::handshake::HelloReply { - let hello_bytes = wire::encode_value(hello); - let hello_view = wire::access_value::(&hello_bytes).unwrap(); - let (reply, _secrets) = wire::handshake::respond_hello( - responder_identity, - responder_crypto, - initiator_identity.xid, - &initiator_identity.signing_public_key, - &initiator_identity.encapsulation_public_key, - hello_view, - wire::ControlMeta { - packet_id: PacketId(packet_id), - valid_until: wire::now_secs().saturating_add(60), - }, - ) - .unwrap(); - reply -} - -fn build_confirm( - initiator_identity: &QlIdentity, - responder_identity: &QlIdentity, - hello: &wire::handshake::Hello, - reply: &wire::handshake::HelloReply, - initiator_secret: &SymmetricKey, - packet_id: u32, -) -> wire::handshake::Confirm { - let reply_bytes = wire::encode_value(reply); - let reply_view = - wire::access_value::(&reply_bytes).unwrap(); - let (confirm, _session_key) = wire::handshake::build_confirm( - initiator_identity, - responder_identity.xid, - &responder_identity.signing_public_key, - hello, - reply_view, - initiator_secret, - wire::ControlMeta { - packet_id: PacketId(packet_id), - valid_until: wire::now_secs().saturating_add(60), - }, - ) - .unwrap(); - confirm -} - -fn pump_between(a: &mut EngineWrapper, b: &mut EngineWrapper, now: Instant) { - loop { - let mut progressed = false; - - while let Some(write) = a.take_next_write() { - let bytes = write.bytes.clone(); - let _ = a.complete_write_collect(write.id, Ok(())); - let _ = b.run_tick_collect(now, EngineInput::Incoming(bytes)); - progressed = true; - } - - while let Some(write) = b.take_next_write() { - let bytes = write.bytes.clone(); - let _ = b.complete_write_collect(write.id, Ok(())); - let _ = a.run_tick_collect(now, EngineInput::Incoming(bytes)); - progressed = true; - } - - if !progressed { - break; - } - } -} - -#[test] -fn handshake_deadline_is_derived_from_peer_state() { - let mut config = EngineConfig::default(); - config.handshake_timeout = Duration::from_secs(5); - config.handshake_retry_interval = Duration::ZERO; - config.max_handshake_retries = 0; - - let identity = test_identity(); - let peer_identity = test_identity(); - let mut engine = EngineWrapper::new( - Engine::new( - config, - identity.clone(), - Some(peer_from_identity(&peer_identity)), - ), - TestCrypto::new(103), - ); - let now = Instant::now(); - - let _outputs = engine.run_tick_collect(now, EngineInput::Connect); - assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(5))); - - let write = engine.take_next_write().unwrap(); - let _outputs = engine.complete_write_collect(write.id, Ok(())); - assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(5))); - - let outputs = engine.run_tick_collect(now + Duration::from_secs(4), EngineInput::TimerExpired); - assert!(!outputs.iter().any(|output| { - matches!( - output, - EngineOutput::PeerStatusChanged { - session: PeerSession::Disconnected, - .. - } - ) - })); - assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(5))); - - let outputs = engine.run_tick_collect(now + Duration::from_secs(5), EngineInput::TimerExpired); - assert!(outputs.iter().any(|output| { - matches!( - output, - EngineOutput::PeerStatusChanged { - session: PeerSession::Disconnected, - .. - } - ) - })); -} - -#[test] -fn initiator_retries_hello_after_retry_interval() { - let mut config = EngineConfig::default(); - config.handshake_timeout = Duration::from_secs(5); - config.handshake_retry_interval = Duration::from_millis(250); - config.max_handshake_retries = 2; - - let identity = test_identity(); - let peer_identity = test_identity(); - let mut engine = EngineWrapper::new( - Engine::new(config, identity, Some(peer_from_identity(&peer_identity))), - TestCrypto::new(111), - ); - let now = Instant::now(); - - let _ = engine.run_tick_collect(now, EngineInput::Connect); - let hello_write = engine.take_next_write().unwrap(); - let hello_bytes = hello_write.bytes.clone(); - let _ = engine.complete_write_collect(hello_write.id, Ok(())); - - let _ = engine.run_tick_collect(now + Duration::from_millis(250), EngineInput::TimerExpired); - let retry_write = engine.take_next_write().unwrap(); - assert_eq!(retry_write.bytes, hello_bytes); - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Initiator { - stage: HandshakeInitiator::WaitingHelloReply { retry_count: 1, .. }, - .. - }) - )); -} - -#[test] -fn responder_retries_hello_reply_after_retry_interval() { - let mut config = EngineConfig::default(); - config.handshake_timeout = Duration::from_secs(5); - config.handshake_retry_interval = Duration::from_millis(250); - config.max_handshake_retries = 2; - - let responder_identity = test_identity(); - let initiator_identity = test_identity(); - let initiator_crypto = TestCrypto::new(112); - let responder_crypto = TestCrypto::new(113); - let mut engine = EngineWrapper::new( - Engine::new( - config, - responder_identity.clone(), - Some(peer_from_identity(&initiator_identity)), - ), - responder_crypto, - ); - let now = Instant::now(); - - let (hello, _secret) = wire::handshake::build_hello( - &initiator_identity, - &initiator_crypto, - responder_identity.xid, - &responder_identity.encapsulation_public_key, - wire::ControlMeta { - packet_id: PacketId(81), - valid_until: wire::now_secs().saturating_add(60), - }, - ) - .unwrap(); - - let _ = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - initiator_identity.xid, - responder_identity.xid, - wire::handshake::HandshakeRecord::Hello(hello), - )), - ); - let reply_write = engine.take_next_write().unwrap(); - let reply_bytes = reply_write.bytes.clone(); - let _ = engine.complete_write_collect(reply_write.id, Ok(())); - - let _ = engine.run_tick_collect(now + Duration::from_millis(250), EngineInput::TimerExpired); - let retry_write = engine.take_next_write().unwrap(); - assert_eq!(retry_write.bytes, reply_bytes); - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Responder { - stage: HandshakeResponder::WaitingConfirm { retry_count: 1, .. }, - .. - }) - )); -} - -#[test] -fn initiator_retries_confirm_after_retry_interval() { - let mut config = EngineConfig::default(); - config.handshake_timeout = Duration::from_secs(5); - config.handshake_retry_interval = Duration::from_millis(250); - config.max_handshake_retries = 2; - - let identity = test_identity(); - let peer_identity = test_identity(); - let responder_crypto = TestCrypto::new(114); - let mut engine = EngineWrapper::new( - Engine::new( - config, - identity.clone(), - Some(peer_from_identity(&peer_identity)), - ), - TestCrypto::new(115), - ); - let now = Instant::now(); - - let _ = engine.run_tick_collect(now, EngineInput::Connect); - let hello_write = engine.take_next_write().unwrap(); - let hello_record = wire::decode_record(&hello_write.bytes).unwrap(); - let QlPayload::Handshake(wire::handshake::HandshakeRecord::Hello(hello)) = hello_record.payload - else { - panic!("expected hello record"); - }; - let _ = engine.complete_write_collect(hello_write.id, Ok(())); - - let reply = build_reply(&identity, &peer_identity, &responder_crypto, &hello, 82); - let _ = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - peer_identity.xid, - identity.xid, - wire::handshake::HandshakeRecord::HelloReply(reply), - )), - ); - let confirm_write = engine.take_next_write().unwrap(); - let confirm_bytes = confirm_write.bytes.clone(); - let _ = engine.complete_write_collect(confirm_write.id, Ok(())); - - let _ = engine.run_tick_collect(now + Duration::from_millis(250), EngineInput::TimerExpired); - let retry_write = engine.take_next_write().unwrap(); - assert_eq!(retry_write.bytes, confirm_bytes); - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Initiator { - stage: HandshakeInitiator::WaitingReady { retry_count: 1, .. }, - .. - }) - )); -} - -#[test] -fn duplicate_hello_resends_hello_reply() { - let responder_identity = test_identity(); - let initiator_identity = test_identity(); - let initiator_crypto = TestCrypto::new(116); - let responder_crypto = TestCrypto::new(117); - let mut engine = EngineWrapper::new( - Engine::new( - EngineConfig::default(), - responder_identity.clone(), - Some(peer_from_identity(&initiator_identity)), - ), - responder_crypto, - ); - let now = Instant::now(); - - let (hello, _secret) = wire::handshake::build_hello( - &initiator_identity, - &initiator_crypto, - responder_identity.xid, - &responder_identity.encapsulation_public_key, - wire::ControlMeta { - packet_id: PacketId(83), - valid_until: wire::now_secs().saturating_add(60), - }, - ) - .unwrap(); - let hello_bytes = handshake_bytes( - initiator_identity.xid, - responder_identity.xid, - wire::handshake::HandshakeRecord::Hello(hello), - ); - - let _ = engine.run_tick_collect(now, EngineInput::Incoming(hello_bytes.clone())); - let reply_write = engine.take_next_write().unwrap(); - let reply_bytes = reply_write.bytes.clone(); - let _ = engine.complete_write_collect(reply_write.id, Ok(())); - - let _ = engine.run_tick_collect(now, EngineInput::Incoming(hello_bytes)); - let resent_reply = engine.take_next_write().unwrap(); - assert_eq!(resent_reply.bytes, reply_bytes); -} - -#[test] -fn duplicate_hello_reply_resends_confirm() { - let identity = test_identity(); - let peer_identity = test_identity(); - let responder_crypto = TestCrypto::new(118); - let mut engine = EngineWrapper::new( - Engine::new( - EngineConfig::default(), - identity.clone(), - Some(peer_from_identity(&peer_identity)), - ), - TestCrypto::new(119), - ); - let now = Instant::now(); - - let _ = engine.run_tick_collect(now, EngineInput::Connect); - let hello_write = engine.take_next_write().unwrap(); - let hello_record = wire::decode_record(&hello_write.bytes).unwrap(); - let QlPayload::Handshake(wire::handshake::HandshakeRecord::Hello(hello)) = hello_record.payload - else { - panic!("expected hello record"); - }; - let _ = engine.complete_write_collect(hello_write.id, Ok(())); - - let reply = build_reply(&identity, &peer_identity, &responder_crypto, &hello, 84); - let reply_bytes = handshake_bytes( - peer_identity.xid, - identity.xid, - wire::handshake::HandshakeRecord::HelloReply(reply.clone()), - ); - - let _ = engine.run_tick_collect(now, EngineInput::Incoming(reply_bytes.clone())); - let confirm_write = engine.take_next_write().unwrap(); - let confirm_bytes = confirm_write.bytes.clone(); - let _ = engine.complete_write_collect(confirm_write.id, Ok(())); - - let _ = engine.run_tick_collect(now, EngineInput::Incoming(reply_bytes)); - let resent_confirm = engine.take_next_write().unwrap(); - assert_eq!(resent_confirm.bytes, confirm_bytes); -} - -#[test] -fn responder_resends_ready_for_duplicate_confirm_after_connecting() { - let responder_identity = test_identity(); - let initiator_identity = test_identity(); - let initiator_crypto = TestCrypto::new(120); - let responder_crypto = TestCrypto::new(121); - let mut engine = EngineWrapper::new( - Engine::new( - EngineConfig::default(), - responder_identity.clone(), - Some(peer_from_identity(&initiator_identity)), - ), - responder_crypto, - ); - let now = Instant::now(); - - let (hello, initiator_secret) = wire::handshake::build_hello( - &initiator_identity, - &initiator_crypto, - responder_identity.xid, - &responder_identity.encapsulation_public_key, - wire::ControlMeta { - packet_id: PacketId(85), - valid_until: wire::now_secs().saturating_add(60), - }, - ) - .unwrap(); - let _ = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - initiator_identity.xid, - responder_identity.xid, - wire::handshake::HandshakeRecord::Hello(hello.clone()), - )), - ); - - let reply_write = engine.take_next_write().unwrap(); - let reply_record = wire::decode_record(&reply_write.bytes).unwrap(); - let QlPayload::Handshake(wire::handshake::HandshakeRecord::HelloReply(reply)) = - reply_record.payload - else { - panic!("expected hello reply"); - }; - let _ = engine.complete_write_collect(reply_write.id, Ok(())); - - let confirm = build_confirm( - &initiator_identity, - &responder_identity, - &hello, - &reply, - &initiator_secret, - 86, - ); - let confirm_bytes = handshake_bytes( - initiator_identity.xid, - responder_identity.xid, - wire::handshake::HandshakeRecord::Confirm(confirm.clone()), - ); - - let _ = engine.run_tick_collect(now, EngineInput::Incoming(confirm_bytes.clone())); - let ready_write = engine.take_next_write().unwrap(); - let ready_bytes = ready_write.bytes.clone(); - let _ = engine.complete_write_collect(ready_write.id, Ok(())); - - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Connected { - recent_ready: Some(_), - .. - }) - )); - - let _ = engine.run_tick_collect(now, EngineInput::Incoming(confirm_bytes)); - let resent_ready = engine.take_next_write().unwrap(); - assert_eq!(resent_ready.bytes, ready_bytes); -} - -#[test] -fn stale_hello_reply_does_not_abort_fresh_handshake() { - let identity = test_identity(); - let peer_identity = test_identity(); - let responder_crypto = TestCrypto::new(122); - let stale_initiator_crypto = TestCrypto::new(123); - let mut engine = EngineWrapper::new( - Engine::new( - EngineConfig::default(), - identity.clone(), - Some(peer_from_identity(&peer_identity)), - ), - TestCrypto::new(124), - ); - let now = Instant::now(); - - let (stale_hello, _stale_secret) = wire::handshake::build_hello( - &identity, - &stale_initiator_crypto, - peer_identity.xid, - &peer_identity.encapsulation_public_key, - wire::ControlMeta { - packet_id: PacketId(87), - valid_until: wire::now_secs().saturating_add(60), - }, - ) - .unwrap(); - let stale_reply = build_reply( - &identity, - &peer_identity, - &responder_crypto, - &stale_hello, - 88, - ); - - let _ = engine.run_tick_collect(now, EngineInput::Connect); - let hello_write = engine.take_next_write().unwrap(); - let hello_record = wire::decode_record(&hello_write.bytes).unwrap(); - let QlPayload::Handshake(wire::handshake::HandshakeRecord::Hello(current_hello)) = - hello_record.payload - else { - panic!("expected hello record"); - }; - let _ = engine.complete_write_collect(hello_write.id, Ok(())); - - let outputs = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - peer_identity.xid, - identity.xid, - wire::handshake::HandshakeRecord::HelloReply(stale_reply), - )), - ); - assert!(!outputs.iter().any(|output| matches!( - output, - EngineOutput::PeerStatusChanged { - session: PeerSession::Disconnected, - .. - } - ))); - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Initiator { - stage: HandshakeInitiator::WaitingHelloReply { .. }, - .. - }) - )); - - let current_reply = build_reply( - &identity, - &peer_identity, - &responder_crypto, - ¤t_hello, - 89, - ); - let _ = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - peer_identity.xid, - identity.xid, - wire::handshake::HandshakeRecord::HelloReply(current_reply), - )), - ); - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Initiator { - stage: HandshakeInitiator::WaitingReady { .. }, - .. - }) - )); - assert!(engine.take_next_write().is_some()); -} - -#[test] -fn stale_confirm_does_not_abort_fresh_handshake() { - let responder_identity = test_identity(); - let initiator_identity = test_identity(); - let responder_crypto = TestCrypto::new(125); - let initiator_crypto = TestCrypto::new(126); - let stale_initiator_crypto = TestCrypto::new(127); - let mut engine = EngineWrapper::new( - Engine::new( - EngineConfig::default(), - responder_identity.clone(), - Some(peer_from_identity(&initiator_identity)), - ), - responder_crypto, - ); - let now = Instant::now(); - - let (stale_hello, stale_secret) = wire::handshake::build_hello( - &initiator_identity, - &stale_initiator_crypto, - responder_identity.xid, - &responder_identity.encapsulation_public_key, - wire::ControlMeta { - packet_id: PacketId(90), - valid_until: wire::now_secs().saturating_add(60), - }, - ) - .unwrap(); - let stale_reply = build_reply( - &initiator_identity, - &responder_identity, - &TestCrypto::new(128), - &stale_hello, - 91, - ); - let stale_confirm = build_confirm( - &initiator_identity, - &responder_identity, - &stale_hello, - &stale_reply, - &stale_secret, - 92, - ); - - let (hello, initiator_secret) = wire::handshake::build_hello( - &initiator_identity, - &initiator_crypto, - responder_identity.xid, - &responder_identity.encapsulation_public_key, - wire::ControlMeta { - packet_id: PacketId(93), - valid_until: wire::now_secs().saturating_add(60), - }, - ) - .unwrap(); - let _ = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - initiator_identity.xid, - responder_identity.xid, - wire::handshake::HandshakeRecord::Hello(hello.clone()), - )), - ); - - let reply_write = engine.take_next_write().unwrap(); - let reply_record = wire::decode_record(&reply_write.bytes).unwrap(); - let QlPayload::Handshake(wire::handshake::HandshakeRecord::HelloReply(reply)) = - reply_record.payload - else { - panic!("expected hello reply"); - }; - let _ = engine.complete_write_collect(reply_write.id, Ok(())); - - let outputs = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - initiator_identity.xid, - responder_identity.xid, - wire::handshake::HandshakeRecord::Confirm(stale_confirm), - )), - ); - assert!(!outputs.iter().any(|output| matches!( - output, - EngineOutput::PeerStatusChanged { - session: PeerSession::Disconnected, - .. - } - ))); - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Responder { - stage: HandshakeResponder::WaitingConfirm { .. }, - .. - }) - )); - - let confirm = build_confirm( - &initiator_identity, - &responder_identity, - &hello, - &reply, - &initiator_secret, - 94, - ); - let _ = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - initiator_identity.xid, - responder_identity.xid, - wire::handshake::HandshakeRecord::Confirm(confirm), - )), - ); - assert!(engine.take_next_write().is_some()); -} - -#[test] -fn initiator_waits_for_ready_before_connecting() { - let config = EngineConfig::default(); - let identity = test_identity(); - let peer_identity = test_identity(); - let responder_crypto = TestCrypto::new(129); - let mut engine = EngineWrapper::new( - Engine::new( - config, - identity.clone(), - Some(peer_from_identity(&peer_identity)), - ), - TestCrypto::new(130), - ); - let now = Instant::now(); - - let _outputs = engine.run_tick_collect(now, EngineInput::Connect); - - let hello_write = engine.take_next_write().unwrap(); - let hello_record = wire::decode_record(&hello_write.bytes).unwrap(); - let QlPayload::Handshake(wire::handshake::HandshakeRecord::Hello(hello)) = hello_record.payload - else { - panic!("expected hello record"); - }; - let _outputs = engine.complete_write_collect(hello_write.id, Ok(())); - - let reply = build_reply(&identity, &peer_identity, &responder_crypto, &hello, 95); - let _outputs = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - peer_identity.xid, - identity.xid, - wire::handshake::HandshakeRecord::HelloReply(reply), - )), - ); - - let confirm_write = engine.take_next_write().unwrap(); - let _outputs = engine.complete_write_collect(confirm_write.id, Ok(())); - - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Initiator { - stage: HandshakeInitiator::WaitingReady { .. }, - .. - }) - )); - assert!(matches!( - engine.open_stream(now, Vec::new(), None, StreamConfig::default()), - Err(QlError::MissingSession) - )); - - let pending_session_key = match engine.peer.as_ref().map(|peer| &peer.session) { - Some(PeerSession::Initiator { session_key, .. }) => session_key.clone(), - other => panic!("expected pending initiator session, got {other:?}"), - }; - let outputs = engine.run_tick_collect( - now, - EngineInput::Incoming(handshake_bytes( - peer_identity.xid, - identity.xid, - wire::handshake::HandshakeRecord::Ready(wire::handshake::build_ready( - QlHeader { - sender: peer_identity.xid, - recipient: identity.xid, - }, - &pending_session_key, - wire::ControlMeta { - packet_id: PacketId(96), - valid_until: wire::now_secs().saturating_add(60), - }, - [9; wire::encrypted_message::NONCE_SIZE], - )), - )), - ); - - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Connected { .. }) - )); - assert!(outputs.iter().any(|output| matches!( - output, - EngineOutput::PeerStatusChanged { - session: PeerSession::Connected { .. }, - .. - } - ))); -} - -#[test] -fn handshake_retry_limit_disconnects_initiator() { - let mut config = EngineConfig::default(); - config.handshake_timeout = Duration::from_secs(5); - config.handshake_retry_interval = Duration::from_millis(250); - config.max_handshake_retries = 1; - - let identity = test_identity(); - let peer_identity = test_identity(); - let mut engine = EngineWrapper::new( - Engine::new(config, identity, Some(peer_from_identity(&peer_identity))), - TestCrypto::new(131), - ); - let now = Instant::now(); - - let _ = engine.run_tick_collect(now, EngineInput::Connect); - let hello_write = engine.take_next_write().unwrap(); - let hello_bytes = hello_write.bytes.clone(); - let _ = engine.complete_write_collect(hello_write.id, Ok(())); - - let _ = engine.run_tick_collect(now + Duration::from_millis(250), EngineInput::TimerExpired); - let retry_write = engine.take_next_write().unwrap(); - assert_eq!(retry_write.bytes, hello_bytes); - let _ = engine.complete_write_collect(retry_write.id, Ok(())); - - let outputs = - engine.run_tick_collect(now + Duration::from_millis(500), EngineInput::TimerExpired); - assert!(outputs.iter().any(|output| matches!( - output, - EngineOutput::PeerStatusChanged { - session: PeerSession::Disconnected, - .. - } - ))); - assert!(matches!( - engine.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Disconnected) - )); -} - -#[test] -fn simultaneous_connect_converges_to_connected_peers() { - let config = EngineConfig::default(); - let identity_a = test_identity(); - let identity_b = test_identity(); - let mut a = EngineWrapper::new( - Engine::new( - config, - identity_a.clone(), - Some(peer_from_identity(&identity_b)), - ), - TestCrypto::new(132), - ); - let mut b = EngineWrapper::new( - Engine::new( - config, - identity_b.clone(), - Some(peer_from_identity(&identity_a)), - ), - TestCrypto::new(133), - ); - let now = Instant::now(); - - let _ = a.run_tick_collect(now, EngineInput::Connect); - let _ = b.run_tick_collect(now, EngineInput::Connect); - - let hello_a = a.take_next_write().unwrap(); - let hello_a_bytes = hello_a.bytes.clone(); - let _ = a.complete_write_collect(hello_a.id, Ok(())); - - let hello_b = b.take_next_write().unwrap(); - let hello_b_bytes = hello_b.bytes.clone(); - let _ = b.complete_write_collect(hello_b.id, Ok(())); - - let _ = a.run_tick_collect(now, EngineInput::Incoming(hello_b_bytes)); - let _ = b.run_tick_collect(now, EngineInput::Incoming(hello_a_bytes)); - - pump_between(&mut a, &mut b, now); - - assert!(matches!( - a.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Connected { .. }) - )); - assert!(matches!( - b.peer.as_ref().map(|peer| &peer.session), - Some(PeerSession::Connected { .. }) - )); -} diff --git a/ql-engine/src/engine/tests/liveness.rs b/ql-engine/src/engine/tests/liveness.rs deleted file mode 100644 index 2c3ac9d4..00000000 --- a/ql-engine/src/engine/tests/liveness.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::*; - -#[test] -fn replayed_heartbeat_is_ignored() { - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(EngineConfig::default(), 101, 4); - let heartbeat = wire::heartbeat::encrypt_heartbeat( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - wire::heartbeat::HeartbeatBody { - meta: wire::ControlMeta { - packet_id: PacketId(7), - valid_until: wire::now_secs().saturating_add(60), - }, - }, - [3; wire::encrypted_message::NONCE_SIZE], - ); - let bytes = wire::encode_record(&heartbeat); - - let _first = engine.run_tick_collect(now, EngineInput::Incoming(bytes.clone())); - let first_write = engine.take_next_write().unwrap(); - let first_record = wire::decode_record(&first_write.bytes).unwrap(); - assert!(matches!(first_record.payload, QlPayload::Heartbeat(_))); - let _ = engine.complete_write_collect(first_write.id, Ok(())); - - let _second = engine.run_tick_collect(now, EngineInput::Incoming(bytes)); - assert!(engine.take_next_write().is_none()); -} - -#[test] -fn keepalive_deadline_is_derived_from_peer_state() { - let mut config = EngineConfig::default(); - config.keep_alive = Some(KeepAliveConfig { - interval: Duration::from_secs(5), - timeout: Duration::from_secs(7), - }); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 103, 6); - - let heartbeat = encrypt_heartbeat_record( - peer.xid, - engine.engine.identity.xid, - &session_key, - 1, - [7; wire::encrypted_message::NONCE_SIZE], - ); - let outputs = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&heartbeat))); - let _ = outputs; - assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(5))); - - let write = engine.take_next_write().unwrap(); - let record = wire::decode_record(&write.bytes).unwrap(); - assert!(matches!(record.payload, QlPayload::Heartbeat(_))); - let _ = engine.complete_write_collect(write.id, Ok(())); - - let outputs = engine.run_tick_collect(now + Duration::from_secs(5), EngineInput::TimerExpired); - let _ = outputs; - assert_eq!(engine.next_deadline(), Some(now + Duration::from_secs(12))); - - let write = engine.take_next_write().unwrap(); - let record = wire::decode_record(&write.bytes).unwrap(); - assert!(matches!(record.payload, QlPayload::Heartbeat(_))); - let _ = engine.complete_write_collect(write.id, Ok(())); - - let outputs = engine.run_tick_collect(now + Duration::from_secs(12), EngineInput::TimerExpired); - assert!(outputs.iter().any(|output| { - matches!( - output, - EngineOutput::PeerStatusChanged { - session: PeerSession::Disconnected, - .. - } - ) - })); -} diff --git a/ql-engine/src/engine/tests/mod.rs b/ql-engine/src/engine/tests/mod.rs deleted file mode 100644 index 2ab01583..00000000 --- a/ql-engine/src/engine/tests/mod.rs +++ /dev/null @@ -1,538 +0,0 @@ -mod handshake; -mod liveness; -mod peer; -mod stream; - -use std::{ - cell::Cell, - mem, - ops::{Deref, DerefMut}, - time::{Duration, Instant}, -}; - -use bc_components::{SymmetricKey, MLDSA, MLKEM, XID}; - -use crate::{ - engine::*, - identity::QlIdentity, - stream::{state::*, StreamNamespace}, - wire::{self, stream::*, QlHeader, QlPayload, QlRecord, StreamSeq}, - PacketId, Peer, -}; - -#[derive(Debug)] -pub enum EngineOutput { - PeerStatusChanged { - peer: XID, - session: PeerSession, - }, - PersistPeer(Peer), - ClearPeer, - - InboundStreamOpened { - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, - }, - InboundData { - stream_id: StreamId, - bytes: Vec, - }, - InboundFinished { - stream_id: StreamId, - }, - InboundFailed { - stream_id: StreamId, - error: QlError, - }, - - OutboundClosed { - stream_id: StreamId, - }, - OutboundFailed { - stream_id: StreamId, - error: QlError, - }, - - StreamReaped { - stream_id: StreamId, - }, -} - -impl EngineEventSink for Vec { - fn peer_status_changed(&mut self, peer: XID, session: PeerSession) { - self.push(EngineOutput::PeerStatusChanged { peer, session }); - } - - fn persist_peer(&mut self, peer: Peer) { - self.push(EngineOutput::PersistPeer(peer)); - } - - fn clear_peer(&mut self) { - self.push(EngineOutput::ClearPeer); - } - - fn inbound_stream_opened( - &mut self, - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, - ) { - self.push(EngineOutput::InboundStreamOpened { - stream_id, - request_head, - request_prefix, - }); - } - - fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { - self.push(EngineOutput::InboundData { stream_id, bytes }); - } - - fn inbound_finished(&mut self, stream_id: StreamId) { - self.push(EngineOutput::InboundFinished { stream_id }); - } - - fn inbound_failed(&mut self, stream_id: StreamId, error: QlError) { - self.push(EngineOutput::InboundFailed { stream_id, error }); - } - - fn outbound_closed(&mut self, stream_id: StreamId) { - self.push(EngineOutput::OutboundClosed { stream_id }); - } - - fn outbound_failed(&mut self, stream_id: StreamId, error: QlError) { - self.push(EngineOutput::OutboundFailed { stream_id, error }); - } - - fn stream_reaped(&mut self, stream_id: StreamId) { - self.push(EngineOutput::StreamReaped { stream_id }); - } -} - -#[derive(Clone)] -struct TestCrypto { - nonce_seed: u8, - nonce_counter: Cell, -} - -impl TestCrypto { - fn new(seed: u8) -> Self { - Self { - nonce_seed: seed, - nonce_counter: Cell::new(0), - } - } -} - -impl QlCrypto for TestCrypto { - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self.nonce_seed.wrapping_add(self.nonce_counter.get()); - self.nonce_counter - .set(self.nonce_counter.get().wrapping_add(1)); - data.fill(value); - } -} - -#[derive(Clone, Copy)] -enum Side { - A, - B, -} - -impl Side { - fn other(self) -> Self { - match self { - Side::A => Side::B, - Side::B => Side::A, - } - } -} - -#[allow(dead_code)] -enum EngineInput { - BindPeer(Peer), - Pair, - Connect, - Unpair, - CloseStream { - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - }, - OutboundData { - stream_id: StreamId, - bytes: Vec, - }, - OutboundFinished { - stream_id: StreamId, - }, - Incoming(Vec), - TimerExpired, -} - -struct Harness { - now: Instant, - a: EngineWrapper, - b: EngineWrapper, -} - -struct SingleEngineHarness { - now: Instant, - engine: EngineWrapper, - peer: QlIdentity, - session_key: SymmetricKey, -} - -impl SingleEngineHarness { - fn connected(config: EngineConfig, nonce_seed: u8, session_fill: u8) -> Self { - let local_identity = test_identity(); - let peer = test_identity(); - let session_key = SymmetricKey::from_data([session_fill; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut engine = Engine::new( - config, - local_identity.clone(), - Some(peer_from_identity(&peer)), - ); - engine.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key: session_key.clone(), - keepalive: KeepAliveState::default(), - recent_ready: None, - }; - Self { - now: Instant::now(), - engine: EngineWrapper::new(engine, TestCrypto::new(nonce_seed)), - peer, - session_key, - } - } -} - -impl Harness { - fn connected(config: EngineConfig) -> Self { - let identity_a = test_identity(); - let identity_b = test_identity(); - let peer_a = peer_from_identity(&identity_a); - let peer_b = peer_from_identity(&identity_b); - let crypto_a = TestCrypto::new(1); - let crypto_b = TestCrypto::new(2); - let session_key = SymmetricKey::from_data([7; SymmetricKey::SYMMETRIC_KEY_SIZE]); - let mut a = Engine::new(config, identity_a.clone(), Some(peer_b)); - let mut b = Engine::new(config, identity_b.clone(), Some(peer_a)); - a.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key: session_key.clone(), - keepalive: KeepAliveState::default(), - recent_ready: None, - }; - b.peer.as_mut().unwrap().session = PeerSession::Connected { - session_key, - keepalive: KeepAliveState::default(), - recent_ready: None, - }; - Self { - now: Instant::now(), - a: EngineWrapper::new(a, crypto_a), - b: EngineWrapper::new(b, crypto_b), - } - } - - fn run_side(&mut self, side: Side, input: EngineInput) { - match side { - Side::A => self.a.run_tick(self.now, input), - Side::B => self.b.run_tick(self.now, input), - } - - while let Some(write) = match side { - Side::A => self.a.take_next_write(), - Side::B => self.b.take_next_write(), - } { - let bytes = write.bytes.clone(); - self.complete_side_write(side, write.id, Ok(())); - self.run_side(side.other(), EngineInput::Incoming(bytes)); - } - } - - fn complete_side_write(&mut self, side: Side, write_id: WriteId, result: Result<(), QlError>) { - match side { - Side::A => self.a.complete_write(write_id, result), - Side::B => self.b.complete_write(write_id, result), - } - } -} - -struct EngineWrapper { - engine: Engine, - crypto: TestCrypto, - outputs: Vec, -} - -impl Deref for EngineWrapper { - type Target = Engine; - - fn deref(&self) -> &Self::Target { - &self.engine - } -} - -impl DerefMut for EngineWrapper { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.engine - } -} - -impl EngineWrapper { - fn new(engine: Engine, crypto: TestCrypto) -> Self { - Self { - engine, - crypto, - outputs: Vec::new(), - } - } - - fn run_tick(&mut self, now: Instant, input: EngineInput) { - match input { - EngineInput::BindPeer(peer) => self.engine.bind_peer(now, peer, &mut self.outputs), - EngineInput::Pair => self.engine.pair(now, &self.crypto), - EngineInput::Connect => self.engine.connect(now, &self.crypto, &mut self.outputs), - EngineInput::Unpair => self.engine.unpair(now, &mut self.outputs), - EngineInput::CloseStream { - stream_id, - target, - code, - payload, - } => { - let _ = self - .engine - .close_stream(now, stream_id, target, code, payload); - } - EngineInput::OutboundData { stream_id, bytes } => { - let _ = self.engine.write_stream(now, stream_id, bytes); - } - EngineInput::OutboundFinished { stream_id } => { - let _ = self.engine.finish_stream(now, stream_id); - } - EngineInput::Incoming(bytes) => { - self.engine - .receive(now, bytes, &self.crypto, &mut self.outputs); - } - EngineInput::TimerExpired => { - self.engine.on_timer(now, &self.crypto, &mut self.outputs); - } - } - } - - fn run_tick_collect(&mut self, now: Instant, input: EngineInput) -> Vec { - self.run_tick(now, input); - self.drain_outputs() - } - - fn complete_write(&mut self, write_id: WriteId, result: Result<(), QlError>) { - self.engine - .complete_write(self.engine.state.now, write_id, result, &mut self.outputs); - } - - fn take_next_write(&mut self) -> Option { - self.engine - .take_next_write(self.engine.state.now, &self.crypto) - } - - fn complete_write_collect( - &mut self, - write_id: WriteId, - result: Result<(), QlError>, - ) -> Vec { - self.complete_write(write_id, result); - self.drain_outputs() - } - - fn open_stream( - &mut self, - now: Instant, - request_head: Vec, - request_prefix: Option, - config: StreamConfig, - ) -> Result { - self.engine - .open_stream(now, request_head, request_prefix, config) - } - - fn drain_outputs(&mut self) -> Vec { - mem::take(&mut self.outputs) - } -} - -fn test_identity() -> QlIdentity { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - QlIdentity::from_keys( - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - ) -} - -fn peer_from_identity(identity: &QlIdentity) -> Peer { - Peer { - peer: identity.xid, - signing_key: identity.signing_public_key.clone(), - encapsulation_key: identity.encapsulation_public_key.clone(), - } -} - -fn decode_stream_body(bytes: &[u8], session_key: &SymmetricKey) -> (QlHeader, StreamBody) { - let record = wire::decode_record(bytes).unwrap(); - let aad = record.header.aad(); - let QlPayload::Stream(encrypted) = record.payload else { - panic!("expected stream payload"); - }; - let plaintext = encrypted.decrypt(session_key, &aad).unwrap(); - let body = wire::access_value::(&plaintext) - .and_then(wire::deserialize_value) - .unwrap(); - (record.header, body) -} - -fn encrypt_heartbeat_record( - sender: XID, - recipient: XID, - session_key: &SymmetricKey, - packet_id: u32, - nonce: [u8; wire::encrypted_message::NONCE_SIZE], -) -> QlRecord { - wire::heartbeat::encrypt_heartbeat( - QlHeader { sender, recipient }, - session_key, - wire::heartbeat::HeartbeatBody { - meta: crate::wire::ControlMeta { - packet_id: PacketId(packet_id), - valid_until: wire::now_secs().saturating_add(60), - }, - }, - nonce, - ) -} - -fn insert_inflight_gap_stream(engine: &mut EngineWrapper, stream_id: StreamId, now: Instant) { - let retry_at = now + Duration::from_secs(60); - let mut stream = StreamState { - control: StreamControl::default(), - role: StreamRole::Initiator(InitiatorStream { - request: OutboundPhase::from_prefix(false), - response: InboundState::new(), - }), - }; - let control = &mut stream.control; - control.next_tx_seq = StreamSeq(6); - control.insert_in_flight(InFlightFrame { - tx_seq: StreamSeq::START, - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - attempt: 0, - write_state: InFlightWriteState::WaitingRetry { retry_at }, - }); - for (tx_seq, byte) in [(2, b'a'), (3, b'b'), (4, b'c'), (5, b'd')] { - control.insert_in_flight(InFlightFrame { - tx_seq: StreamSeq(tx_seq), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: vec![byte], - fin: false, - }, - }), - attempt: 0, - write_state: InFlightWriteState::WaitingRetry { retry_at }, - }); - } - engine.streams.streams.insert(stream_id, stream); -} - -fn insert_inflight_stream_with_data( - engine: &mut EngineWrapper, - stream_id: StreamId, - now: Instant, - data_seqs: &[u32], -) { - let retry_at = now + Duration::from_secs(60); - let mut stream = StreamState { - control: StreamControl::default(), - role: StreamRole::Initiator(InitiatorStream { - request: OutboundPhase::from_prefix(false), - response: InboundState::new(), - }), - }; - let control = &mut stream.control; - control.next_tx_seq = StreamSeq(data_seqs.iter().copied().max().unwrap_or(1) + 1); - control.insert_in_flight(InFlightFrame { - tx_seq: StreamSeq::START, - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - attempt: 0, - write_state: InFlightWriteState::WaitingRetry { retry_at }, - }); - for &tx_seq in data_seqs { - control.insert_in_flight(InFlightFrame { - tx_seq: StreamSeq(tx_seq), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: vec![b'a' + (tx_seq as u8)], - fin: false, - }, - }), - attempt: 0, - write_state: InFlightWriteState::WaitingRetry { retry_at }, - }); - } - engine.streams.streams.insert(stream_id, stream); -} - -fn insert_unwritten_inflight_stream_with_data( - engine: &mut EngineWrapper, - stream_id: StreamId, - _now: Instant, - data_seqs: &[u32], -) { - let mut stream = StreamState { - control: StreamControl::default(), - role: StreamRole::Initiator(InitiatorStream { - request: OutboundPhase::from_prefix(false), - response: InboundState::new(), - }), - }; - let control = &mut stream.control; - control.next_tx_seq = StreamSeq(data_seqs.iter().copied().max().unwrap_or(1) + 1); - control.insert_in_flight(InFlightFrame { - tx_seq: StreamSeq::START, - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - attempt: 0, - write_state: InFlightWriteState::Ready, - }); - for &tx_seq in data_seqs { - control.insert_in_flight(InFlightFrame { - tx_seq: StreamSeq(tx_seq), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: vec![b'a' + (tx_seq as u8)], - fin: false, - }, - }), - attempt: 0, - write_state: InFlightWriteState::Ready, - }); - } - engine.streams.streams.insert(stream_id, stream); -} diff --git a/ql-engine/src/engine/tests/peer.rs b/ql-engine/src/engine/tests/peer.rs deleted file mode 100644 index 11ea08aa..00000000 --- a/ql-engine/src/engine/tests/peer.rs +++ /dev/null @@ -1,42 +0,0 @@ -use super::*; - -#[test] -fn replayed_unpair_is_ignored_after_rebind() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key: _session_key, - } = SingleEngineHarness::connected(config, 111, 5); - let peer_b = peer_from_identity(&peer); - let bytes = wire::encode_record(&wire::unpair::build_unpair_record( - &peer, - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - wire::ControlMeta { - packet_id: PacketId(9), - valid_until: wire::now_secs().saturating_add(60), - }, - )); - - let first = engine.run_tick_collect(now, EngineInput::Incoming(bytes.clone())); - assert!(first - .iter() - .any(|output| matches!(output, EngineOutput::ClearPeer))); - assert!(engine.peer.is_none()); - - let _ = engine.run_tick_collect(now, EngineInput::BindPeer(peer_b.clone())); - assert!(engine.peer.is_some()); - - let second = engine.run_tick_collect(now, EngineInput::Incoming(bytes)); - assert!(!second - .iter() - .any(|output| matches!(output, EngineOutput::ClearPeer))); - assert_eq!( - engine.peer.as_ref().map(|peer| peer.peer), - Some(peer_b.peer) - ); -} diff --git a/ql-engine/src/engine/tests/stream.rs b/ql-engine/src/engine/tests/stream.rs deleted file mode 100644 index 739f1736..00000000 --- a/ql-engine/src/engine/tests/stream.rs +++ /dev/null @@ -1,1554 +0,0 @@ -#![allow(clippy::too_many_lines)] - -use super::*; - -#[test] -fn simultaneous_opens_use_disjoint_stream_id_namespaces() { - let config = EngineConfig::default(); - let mut harness = Harness::connected(config); - let now = harness.now; - - let stream_id_a = harness - .a - .open_stream(now, b"a-open".to_vec(), None, StreamConfig::default()) - .unwrap(); - let stream_id_b = harness - .b - .open_stream(now, b"b-open".to_vec(), None, StreamConfig::default()) - .unwrap(); - - assert_ne!(stream_id_a, stream_id_b); - assert!(StreamNamespace::for_local( - harness.a.engine.identity.xid, - harness.b.engine.identity.xid - ) - .matches(stream_id_a)); - assert!(StreamNamespace::for_local( - harness.b.engine.identity.xid, - harness.a.engine.identity.xid - ) - .matches(stream_id_b)); - - let write_a = harness.a.take_next_write().unwrap(); - let write_b = harness.b.take_next_write().unwrap(); - - let _ = harness.a.complete_write_collect(write_a.id, Ok(())); - let _ = harness.b.complete_write_collect(write_b.id, Ok(())); - - let outputs_a_incoming = harness - .a - .run_tick_collect(now, EngineInput::Incoming(write_b.bytes)); - let outputs_b_incoming = harness - .b - .run_tick_collect(now, EngineInput::Incoming(write_a.bytes)); - - assert!(outputs_a_incoming.iter().any(|output| matches!( - output, - EngineOutput::InboundStreamOpened { - stream_id, - request_head, - .. - } if *stream_id == stream_id_b && request_head == b"b-open" - ))); - assert!(outputs_b_incoming.iter().any(|output| matches!( - output, - EngineOutput::InboundStreamOpened { - stream_id, - request_head, - .. - } if *stream_id == stream_id_a && request_head == b"a-open" - ))); - assert_eq!(harness.a.streams.streams.len(), 2); - assert_eq!(harness.b.streams.streams.len(), 2); -} - -#[test] -fn invalid_future_frame_does_not_ack_outstanding_open() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 31, 5); - let stream_id = engine - .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) - .unwrap(); - - let message = StreamMessage { - tx_seq: StreamSeq(2), - ack: crate::wire::stream::StreamAck { - base: StreamSeq(0), - bitmap: 0, - }, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"resp".to_vec(), - fin: false, - }, - }), - }; - - let body = StreamBody::Message(message); - let record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &body, - [9; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_incoming = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); - - assert!(!outputs_incoming - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - - let stream = engine.streams.streams.get(&stream_id).unwrap(); - assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); -} - -#[test] -fn ack_for_issued_open_is_applied_before_write_completion() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 33, 7); - let stream_id = engine - .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) - .unwrap(); - - let _open_write = engine.take_next_write().unwrap(); - - let message = StreamMessage { - tx_seq: StreamSeq::START, - ack: StreamAck { - base: StreamSeq::START, - bitmap: 0, - }, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"resp".to_vec(), - fin: false, - }, - }), - }; - let record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(message), - [10; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_incoming = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); - - assert!(outputs_incoming.iter().any(|output| matches!( - output, - EngineOutput::InboundData { - stream_id: id, - bytes, - } if *id == stream_id && bytes == b"resp" - ))); - let stream = engine.streams.streams.get(&stream_id).unwrap(); - assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); -} - -#[test] -fn ack_does_not_retire_ready_data() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 35, 8); - let stream_id = engine - .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) - .unwrap(); - - let _open_write = engine.take_next_write().unwrap(); - let _ = engine.run_tick_collect( - now, - EngineInput::OutboundData { - stream_id, - bytes: b"body".to_vec(), - }, - ); - - let message = StreamMessage { - tx_seq: StreamSeq::START, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0, - }, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"resp".to_vec(), - fin: false, - }, - }), - }; - let record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(message), - [11; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_incoming = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); - - assert!(outputs_incoming.iter().any(|output| matches!( - output, - EngineOutput::InboundData { - stream_id: id, - bytes, - } if *id == stream_id && bytes == b"resp" - ))); - - let stream = engine.streams.streams.get(&stream_id).unwrap(); - assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); - assert!(stream.control.in_flight.contains_key(&StreamSeq(2))); - - let write = engine.take_next_write().unwrap(); - let (_, body) = decode_stream_body(&write.bytes, &session_key); - assert!(matches!( - body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(2), - frame: StreamFrame::Data(StreamFrameData { - stream_id: id, - chunk: BodyChunk { bytes, fin: false }, - }), - .. - }) if id == stream_id && bytes == b"body" - )); -} - -#[test] -fn late_failed_write_after_remote_close_ack_is_ignored() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 37, 9); - let stream_id = engine - .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) - .unwrap(); - - let open_write = engine.take_next_write().unwrap(); - - let record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(StreamMessage { - tx_seq: StreamSeq::START, - ack: StreamAck { - base: StreamSeq::START, - bitmap: 0, - }, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Close(StreamFrameClose { - stream_id, - target: CloseTarget::Both, - code: CloseCode::PROTOCOL, - payload: Vec::new(), - }), - }), - [12; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_close = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); - - assert!(outputs_close.iter().any(|output| matches!( - output, - EngineOutput::OutboundFailed { - stream_id: id, - error: QlError::StreamClosed { - target: CloseTarget::Both, - code: CloseCode::PROTOCOL, - payload, - }, - } if *id == stream_id - && payload.is_empty() - ))); - assert!(outputs_close.iter().any(|output| matches!( - output, - EngineOutput::InboundFailed { - stream_id: id, - error: QlError::StreamClosed { - target: CloseTarget::Both, - code: CloseCode::PROTOCOL, - payload, - }, - } if *id == stream_id - && payload.is_empty() - ))); - let stream = engine.streams.streams.get(&stream_id).unwrap(); - assert!(!stream.control.in_flight.contains_key(&StreamSeq::START)); - - let outputs_late = engine.complete_write_collect(open_write.id, Err(QlError::SendFailed)); - assert!(outputs_late.is_empty()); - assert!(engine.streams.streams.contains_key(&stream_id)); -} - -#[test] -fn local_close_both_is_idempotent() { - let SingleEngineHarness { - now, - mut engine, - session_key, - .. - } = SingleEngineHarness::connected(EngineConfig::default(), 39, 10); - let stream_id = engine - .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) - .unwrap(); - - let open_write = engine.take_next_write().unwrap(); - let _ = engine.complete_write_collect(open_write.id, Ok(())); - - let _ = engine.run_tick_collect( - now, - EngineInput::CloseStream { - stream_id, - target: CloseTarget::Request, - code: CloseCode::CANCELLED, - payload: Vec::new(), - }, - ); - let request_close = engine.take_next_write().unwrap(); - let (_, request_close_body) = decode_stream_body(&request_close.bytes, &session_key); - assert!(matches!( - request_close_body, - StreamBody::Message(StreamMessage { - frame: StreamFrame::Close(StreamFrameClose { - stream_id: id, - target: CloseTarget::Request, - .. - }), - .. - }) if id == stream_id - )); - let _ = engine.complete_write_collect(request_close.id, Ok(())); - - let _ = engine.run_tick_collect( - now, - EngineInput::CloseStream { - stream_id, - target: CloseTarget::Both, - code: CloseCode::CANCELLED, - payload: Vec::new(), - }, - ); - let both_close = engine.take_next_write().unwrap(); - let (_, both_close_body) = decode_stream_body(&both_close.bytes, &session_key); - assert!(matches!( - both_close_body, - StreamBody::Message(StreamMessage { - frame: StreamFrame::Close(StreamFrameClose { - stream_id: id, - target: CloseTarget::Both, - .. - }), - .. - }) if id == stream_id - )); - let _ = engine.complete_write_collect(both_close.id, Ok(())); - - let _ = engine.run_tick_collect( - now, - EngineInput::CloseStream { - stream_id, - target: CloseTarget::Both, - code: CloseCode::CANCELLED, - payload: Vec::new(), - }, - ); - assert!(engine.take_next_write().is_none()); -} - -#[test] -fn out_of_order_remote_stream_buffers_until_open_arrives() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 41, 6); - let stream_id = - StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); - - let data_message = StreamMessage { - tx_seq: StreamSeq(2), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"hello".to_vec(), - fin: false, - }, - }), - }; - let data_body = StreamBody::Message(data_message); - let data_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &data_body, - [11; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_data = engine.run_tick_collect( - now, - EngineInput::Incoming(wire::encode_record(&data_record)), - ); - - assert!(!outputs_data - .iter() - .any(|output| matches!(output, EngineOutput::InboundStreamOpened { .. }))); - assert!(!outputs_data - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - assert!(engine.take_next_write().is_some()); - assert!(engine - .streams - .streams - .get(&stream_id) - .is_some_and(StreamState::awaiting_open)); - - let open_message = StreamMessage { - tx_seq: StreamSeq(1), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(crate::wire::stream::StreamFrameOpen { - stream_id, - request_head: b"late-open".to_vec(), - request_prefix: None, - }), - }; - let open_body = StreamBody::Message(open_message); - let open_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &open_body, - [12; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs_open = engine.run_tick_collect( - now, - EngineInput::Incoming(wire::encode_record(&open_record)), - ); - - assert!(outputs_open.iter().any(|output| matches!( - output, - EngineOutput::InboundStreamOpened { - stream_id: id, - request_head, - request_prefix: None, - } if *id == stream_id && request_head == b"late-open" - ))); - assert!(outputs_open.iter().any(|output| matches!( - output, - EngineOutput::InboundData { - stream_id: id, - bytes, - } if *id == stream_id && bytes == b"hello" - ))); -} - -#[test] -fn delayed_ack_only_does_not_consume_sequence_space() { - let mut harness = Harness::connected(EngineConfig::default()); - let stream_id = harness - .a - .open_stream( - harness.now, - b"open-head".to_vec(), - None, - StreamConfig::default(), - ) - .unwrap(); - let open_write = harness.a.take_next_write().unwrap(); - harness.complete_side_write(Side::A, open_write.id, Ok(())); - harness.run_side(Side::B, EngineInput::Incoming(open_write.bytes)); - - harness.now += EngineConfig::default().stream_ack_delay; - harness.run_side(Side::B, EngineInput::TimerExpired); - - let _outputs_b = harness.b.drain_outputs(); - - let stream = harness.b.streams.streams.get(&stream_id).unwrap(); - assert!(stream.control.in_flight.is_empty()); - assert_eq!(stream.control.next_tx_seq, StreamSeq::START); -} - -#[test] -fn half_window_progress_flushes_ack_before_timer() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 61, 8); - let stream_id = - StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); - let messages = [ - StreamMessage { - tx_seq: StreamSeq(1), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(crate::wire::stream::StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - }, - StreamMessage { - tx_seq: StreamSeq(2), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"a".to_vec(), - fin: false, - }, - }), - }, - StreamMessage { - tx_seq: StreamSeq(3), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"b".to_vec(), - fin: false, - }, - }), - }, - StreamMessage { - tx_seq: StreamSeq(4), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(crate::wire::stream::StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"c".to_vec(), - fin: false, - }, - }), - }, - ]; - - for message in messages.iter().take(3) { - let body = StreamBody::Message(message.clone()); - let record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &body, - [message.tx_seq.0 as u8; wire::encrypted_message::NONCE_SIZE], - ); - let _outputs = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); - assert!(engine.take_next_write().is_none()); - } - - let body = StreamBody::Message(messages[3].clone()); - let record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &body, - [4; wire::encrypted_message::NONCE_SIZE], - ); - let _outputs = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); - - let ack_write = engine.take_next_write().unwrap(); - let (_, ack_body) = decode_stream_body(&ack_write.bytes, &session_key); - assert!(matches!(ack_body, StreamBody::Ack(_))); -} - -#[test] -fn out_of_order_loss_reports_selective_ack_bitmap() { - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(EngineConfig::default(), 71, 3); - let stream_id = - StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); - let messages = [ - StreamMessage { - tx_seq: StreamSeq(1), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - }, - StreamMessage { - tx_seq: StreamSeq(2), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"a".to_vec(), - fin: false, - }, - }), - }, - StreamMessage { - tx_seq: StreamSeq(4), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"c".to_vec(), - fin: false, - }, - }), - }, - StreamMessage { - tx_seq: StreamSeq(5), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"d".to_vec(), - fin: false, - }, - }), - }, - ]; - - for message in &messages[..2] { - let record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(message.clone()), - [message.tx_seq.0 as u8; wire::encrypted_message::NONCE_SIZE], - ); - let _outputs = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); - assert!(engine.take_next_write().is_none()); - } - - let record4 = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(messages[2].clone()), - [4; wire::encrypted_message::NONCE_SIZE], - ); - let outputs4 = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record4))); - let ack_write4 = engine.take_next_write().unwrap(); - let (_, ack_body4) = decode_stream_body(&ack_write4.bytes, &session_key); - assert!(matches!( - ack_body4, - StreamBody::Ack(StreamAckBody { - stream_id: id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0010, - }, - .. - }) if id == stream_id - )); - assert!(!outputs4 - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - let _ = engine.complete_write_collect(ack_write4.id, Ok(())); - - let record5 = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(messages[3].clone()), - [5; wire::encrypted_message::NONCE_SIZE], - ); - let outputs5 = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record5))); - let ack_write5 = engine.take_next_write().unwrap(); - let (_, ack_body5) = decode_stream_body(&ack_write5.bytes, &session_key); - assert!(matches!( - ack_body5, - StreamBody::Ack(StreamAckBody { - stream_id: id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0110, - }, - .. - }) if id == stream_id - )); - assert!(!outputs5 - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); -} - -#[test] -fn selective_ack_only_body_retires_acked_gap_tail() { - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(EngineConfig::default(), 81, 2); - let stream_id = engine.streams.next_stream_id(); - insert_inflight_gap_stream(&mut engine, stream_id, now); - - let ack_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0110, - }, - valid_until: wire::now_secs().saturating_add(60), - }), - [9; wire::encrypted_message::NONCE_SIZE], - ); - - let outputs = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&ack_record))); - - assert!(!outputs - .iter() - .any(|output| matches!(output, EngineOutput::OutboundFailed { .. }))); - let stream = engine.streams.streams.get(&stream_id).unwrap(); - let remaining: Vec<_> = stream - .control - .in_flight - .iter() - .map(|(seq, _)| seq) - .collect(); - assert_eq!(remaining, vec![StreamSeq(3)]); - assert_eq!(stream.control.next_tx_seq, StreamSeq(6)); -} - -#[test] -fn fast_retransmit_resends_oldest_gap_when_threshold_met() { - let mut config = EngineConfig::default(); - config.stream_fast_retransmit_threshold = 2; - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 83, 9); - let stream_id = engine.streams.next_stream_id(); - insert_inflight_gap_stream(&mut engine, stream_id, now); - - let ack_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0110, - }, - valid_until: wire::now_secs().saturating_add(60), - }), - [10; wire::encrypted_message::NONCE_SIZE], - ); - - let _outputs = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&ack_record))); - - let write = engine.take_next_write().unwrap(); - let (_, body) = decode_stream_body(&write.bytes, &session_key); - assert!(matches!( - body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(3), - frame: StreamFrame::Data(StreamFrameData { .. }), - .. - }) - )); - - let stream = engine.streams.streams.get(&stream_id).unwrap(); - let remaining: Vec<_> = stream - .control - .in_flight - .iter() - .map(|(seq, _)| seq) - .collect(); - assert_eq!(remaining, vec![StreamSeq(3)]); - let frame = stream.control.in_flight.get(&StreamSeq(3)).unwrap(); - assert_eq!(frame.attempt, 1); - assert!(matches!( - frame.write_state, - InFlightWriteState::Issued { .. } - )); -} - -#[test] -fn fast_retransmit_respects_configured_threshold() { - let mut config = EngineConfig::default(); - config.stream_fast_retransmit_threshold = 3; - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 85, 10); - let stream_id = engine.streams.next_stream_id(); - insert_inflight_gap_stream(&mut engine, stream_id, now); - - let ack_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0110, - }, - valid_until: wire::now_secs().saturating_add(60), - }), - [11; wire::encrypted_message::NONCE_SIZE], - ); - - let _outputs = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&ack_record))); - - if let Some(write) = engine.take_next_write() { - let (_, body) = decode_stream_body(&write.bytes, &session_key); - assert!(matches!(body, StreamBody::Ack(_))); - } - - let stream = engine.streams.streams.get(&stream_id).unwrap(); - let remaining: Vec<_> = stream - .control - .in_flight - .iter() - .map(|(seq, _)| seq) - .collect(); - assert_eq!(remaining, vec![StreamSeq(3)]); - let frame = stream.control.in_flight.get(&StreamSeq(3)).unwrap(); - assert_eq!(frame.attempt, 0); - assert!(matches!( - frame.write_state, - InFlightWriteState::WaitingRetry { .. } - )); -} - -#[test] -fn timeout_retransmit_reuses_original_tx_seq_and_slot() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer: _, - session_key, - } = SingleEngineHarness::connected(config, 91, 1); - let tracked_stream_id = engine - .open_stream(now, b"open".to_vec(), None, StreamConfig::default()) - .unwrap(); - let write = engine.take_next_write().unwrap(); - let (_, initial_body) = decode_stream_body(&write.bytes, &session_key); - assert!(matches!( - &initial_body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(1), - frame: StreamFrame::Open(_), - .. - }) - )); - let _outputs_written = engine.complete_write_collect(write.id, Ok(())); - - let stream = engine.streams.streams.get(&tracked_stream_id).unwrap(); - assert_eq!(stream.control.in_flight.len(), 1); - assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); - assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); - - let _outputs_timeout = - engine.run_tick_collect(now + config.stream_ack_timeout, EngineInput::TimerExpired); - let retransmit_write = engine.take_next_write().unwrap(); - let (_, retransmit_body) = decode_stream_body(&retransmit_write.bytes, &session_key); - assert!(matches!( - retransmit_body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(1), - frame: StreamFrame::Open(StreamFrameOpen { stream_id, .. }), - .. - }) if stream_id == tracked_stream_id - )); - - let stream = engine.streams.streams.get(&tracked_stream_id).unwrap(); - assert_eq!(stream.control.in_flight.len(), 1); - assert!(stream.control.in_flight.contains_key(&StreamSeq::START)); - assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); - assert_eq!( - stream - .control - .in_flight - .get(&StreamSeq::START) - .unwrap() - .attempt, - 1 - ); -} - -#[test] -fn take_next_write_drains_multiple_stream_frames_before_completion() { - let SingleEngineHarness { - now, - mut engine, - peer: _, - session_key, - } = SingleEngineHarness::connected(EngineConfig::default(), 93, 12); - let stream_id = engine.streams.next_stream_id(); - insert_unwritten_inflight_stream_with_data(&mut engine, stream_id, now, &[2, 3]); - - let writes = { - let mut writes = Vec::new(); - while let Some(write) = engine.take_next_write() { - writes.push(write); - } - writes - }; - assert_eq!(writes.len(), 3); - - let tx_seqs: Vec<_> = writes - .iter() - .map( - |write| match decode_stream_body(&write.bytes, &session_key).1 { - StreamBody::Message(message) => message.tx_seq, - other => panic!("expected stream message, got {other:?}"), - }, - ) - .collect(); - assert_eq!(tx_seqs, vec![StreamSeq::START, StreamSeq(2), StreamSeq(3)]); - - let unique_ids: std::collections::HashSet<_> = writes.iter().map(|write| write.id).collect(); - assert_eq!(unique_ids.len(), writes.len()); - assert_eq!(engine.state.active_writes.len(), writes.len()); - assert!(engine.take_next_write().is_none()); - - let stream = engine.streams.streams.get(&stream_id).unwrap(); - assert!(stream - .control - .in_flight - .iter() - .all(|(_, in_flight)| matches!(in_flight.write_state, InFlightWriteState::Issued { .. }))); -} - -#[test] -fn take_next_write_does_not_reissue_outstanding_frame() { - let SingleEngineHarness { - now, - mut engine, - peer: _, - session_key: _session_key, - } = SingleEngineHarness::connected(EngineConfig::default(), 95, 13); - let stream_id = engine.streams.next_stream_id(); - insert_unwritten_inflight_stream_with_data(&mut engine, stream_id, now, &[]); - - let write = engine.take_next_write().unwrap(); - assert!(engine.take_next_write().is_none()); - assert!(engine.state.active_writes.contains(write.id.0)); -} - -#[test] -fn take_next_write_round_robins_across_ready_streams() { - let SingleEngineHarness { - now, - mut engine, - peer: _, - session_key, - } = SingleEngineHarness::connected(EngineConfig::default(), 97, 14); - let stream_id1 = engine.streams.next_stream_id(); - let stream_id2 = engine.streams.next_stream_id(); - insert_unwritten_inflight_stream_with_data(&mut engine, stream_id1, now, &[2]); - insert_unwritten_inflight_stream_with_data(&mut engine, stream_id2, now, &[2]); - - let scheduled: Vec<_> = { - let mut writes = Vec::new(); - while let Some(write) = engine.take_next_write() { - writes.push(write); - } - writes - } - .into_iter() - .map( - |write| match decode_stream_body(&write.bytes, &session_key).1 { - StreamBody::Message(message) => (message.frame.stream_id(), message.tx_seq), - other => panic!("expected stream message, got {other:?}"), - }, - ) - .collect(); - - assert_eq!( - scheduled, - vec![ - (stream_id1, StreamSeq::START), - (stream_id2, StreamSeq::START), - (stream_id1, StreamSeq(2)), - (stream_id2, StreamSeq(2)), - ] - ); -} - -#[test] -fn stale_ack_delay_timer_after_piggyback_does_not_emit_extra_ack_only() { - let mut harness = Harness::connected(EngineConfig::default()); - let stream_id = harness - .a - .open_stream( - harness.now, - b"open-head".to_vec(), - None, - StreamConfig::default(), - ) - .unwrap(); - let open_write = harness.a.take_next_write().unwrap(); - harness.complete_side_write(Side::A, open_write.id, Ok(())); - harness.run_side(Side::B, EngineInput::Incoming(open_write.bytes)); - let _ = harness.a.drain_outputs(); - let _ = harness.b.drain_outputs(); - - harness.run_side( - Side::B, - EngineInput::OutboundData { - stream_id, - bytes: b"resp".to_vec(), - }, - ); - let _ = harness.a.drain_outputs(); - let _ = harness.b.drain_outputs(); - - harness.now += EngineConfig::default().stream_ack_delay; - harness.run_side(Side::B, EngineInput::TimerExpired); - let _outputs_b_timer = harness.b.drain_outputs(); - - assert!(harness.b.take_next_write().is_none()); -} - -#[test] -fn late_opened_stream_ignores_unrelated_timer_tick() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 63, 11); - let stream_id = - StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); - - let early_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(2), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"hello".to_vec(), - fin: false, - }, - }), - }), - [31; wire::encrypted_message::NONCE_SIZE], - ); - let _ = engine.run_tick_collect( - now, - EngineInput::Incoming(wire::encode_record(&early_record)), - ); - - let open_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(StreamMessage { - tx_seq: StreamSeq::START, - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"late-open".to_vec(), - request_prefix: None, - }), - }), - [32; wire::encrypted_message::NONCE_SIZE], - ); - let outputs_open = engine.run_tick_collect( - now, - EngineInput::Incoming(wire::encode_record(&open_record)), - ); - assert!(outputs_open.iter().any(|output| matches!( - output, - EngineOutput::InboundStreamOpened { stream_id: id, .. } if *id == stream_id - ))); - - let _outputs_timeout = - engine.run_tick_collect(now + config.packet_expiration, EngineInput::TimerExpired); - - assert!(matches!( - engine - .streams - .streams - .get(&stream_id) - .map(|stream| &stream.role), - Some(StreamRole::Responder(_)) - )); - if let Some(write) = engine.take_next_write() { - let (_, body) = decode_stream_body(&write.bytes, &session_key); - assert!(!matches!( - body, - StreamBody::Message(StreamMessage { - frame: StreamFrame::Close(_), - .. - }) - )); - } -} - -#[test] -fn ack_only_write_failure_immediately_requeues_ack_without_spending_extra_seq() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 65, 12); - let stream_id = - StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); - let open_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(StreamMessage { - tx_seq: StreamSeq::START, - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - }), - [33; wire::encrypted_message::NONCE_SIZE], - ); - let outputs_open = engine.run_tick_collect( - now, - EngineInput::Incoming(wire::encode_record(&open_record)), - ); - assert!(outputs_open.iter().any(|output| matches!( - output, - EngineOutput::InboundStreamOpened { stream_id: id, .. } if *id == stream_id - ))); - - let _outputs_ack = - engine.run_tick_collect(now + config.stream_ack_delay, EngineInput::TimerExpired); - let ack_write = engine.take_next_write().unwrap(); - let (_, ack_body) = decode_stream_body(&ack_write.bytes, &session_key); - assert!(matches!( - ack_body, - StreamBody::Ack(StreamAckBody { - stream_id: id, - ack: StreamAck { - base: StreamSeq::START, - bitmap: 0, - }, - .. - }) if id == stream_id - )); - - let outputs_failed = engine.complete_write_collect(ack_write.id, Err(QlError::SendFailed)); - assert!(!outputs_failed - .iter() - .any(|output| matches!(output, EngineOutput::StreamReaped { .. }))); - let retry_write = engine.take_next_write().unwrap(); - let (_, retry_body) = decode_stream_body(&retry_write.bytes, &session_key); - assert!(matches!( - retry_body, - StreamBody::Ack(StreamAckBody { - stream_id: id, - ack: StreamAck { - base: StreamSeq::START, - bitmap: 0, - }, - .. - }) if id == stream_id - )); - - let _ = engine.complete_write_collect(retry_write.id, Ok(())); - - let _outputs_data = engine.run_tick_collect( - now + config.stream_ack_delay, - EngineInput::OutboundData { - stream_id, - bytes: b"resp".to_vec(), - }, - ); - let response_write = engine.take_next_write().unwrap(); - let (_, body) = decode_stream_body(&response_write.bytes, &session_key); - assert!(matches!( - body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq::START, - frame: StreamFrame::Data(StreamFrameData { - stream_id: id, - chunk: BodyChunk { bytes, fin: false }, - }), - .. - }) if id == stream_id && bytes == b"resp" - )); - let stream = engine.streams.streams.get(&stream_id).unwrap(); - assert_eq!(stream.control.next_tx_seq, StreamSeq(2)); -} - -#[test] -fn duplicate_committed_data_is_acked_without_redelivery() { - let config = EngineConfig::default(); - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 67, 13); - let stream_id = - StreamId(StreamNamespace::for_local(peer.xid, engine.engine.identity.xid).bit() | 1); - - for (nonce, body) in [ - ( - 34u8, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq::START, - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - }), - ), - ( - 35u8, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(2), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"hello".to_vec(), - fin: false, - }, - }), - }), - ), - ] { - let record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &body, - [nonce; wire::encrypted_message::NONCE_SIZE], - ); - let _ = engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&record))); - } - - let duplicate_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(2), - ack: StreamAck::EMPTY, - valid_until: wire::now_secs().saturating_add(60), - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: b"hello".to_vec(), - fin: false, - }, - }), - }), - [36; wire::encrypted_message::NONCE_SIZE], - ); - let outputs_dup = engine.run_tick_collect( - now, - EngineInput::Incoming(wire::encode_record(&duplicate_record)), - ); - - assert!(!outputs_dup - .iter() - .any(|output| matches!(output, EngineOutput::InboundData { .. }))); - let ack_write = engine.take_next_write().unwrap(); - let (_, body) = decode_stream_body(&ack_write.bytes, &session_key); - assert!(matches!( - body, - StreamBody::Ack(StreamAckBody { - stream_id: id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0, - }, - .. - }) if id == stream_id - )); -} - -#[test] -fn repeated_identical_gap_ack_only_fast_retransmits_once() { - let mut config = EngineConfig::default(); - config.stream_fast_retransmit_threshold = 2; - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 69, 14); - let stream_id = engine.streams.next_stream_id(); - insert_inflight_gap_stream(&mut engine, stream_id, now); - - let local_xid = engine.engine.identity.xid; - let remote_xid = peer.xid; - let ack_record = |nonce: u8| { - wire::stream::encrypt_stream( - QlHeader { - sender: remote_xid, - recipient: local_xid, - }, - &session_key, - &StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0110, - }, - valid_until: wire::now_secs().saturating_add(60), - }), - [nonce; wire::encrypted_message::NONCE_SIZE], - ) - }; - - let _outputs_first = engine.run_tick_collect( - now, - EngineInput::Incoming(wire::encode_record(&ack_record(37))), - ); - let write = engine.take_next_write().unwrap(); - let (_, body) = decode_stream_body(&write.bytes, &session_key); - assert!(matches!( - body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(3), - .. - }) - )); - - let _ = engine.complete_write_collect(write.id, Ok(())); - - let _outputs_second = engine.run_tick_collect( - now, - EngineInput::Incoming(wire::encode_record(&ack_record(38))), - ); - assert!(engine.take_next_write().is_none()); -} - -#[test] -fn fast_recovery_clears_after_gap_is_acked_and_allows_next_gap() { - let mut config = EngineConfig::default(); - config.stream_fast_retransmit_threshold = 1; - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 73, 15); - let stream_id = engine.streams.next_stream_id(); - insert_inflight_stream_with_data(&mut engine, stream_id, now, &[2, 3, 4, 5, 6]); - - let first_ack = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0010, - }, - valid_until: wire::now_secs().saturating_add(60), - }), - [39; wire::encrypted_message::NONCE_SIZE], - ); - let _outputs_first = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&first_ack))); - let write_first = engine.take_next_write().unwrap(); - let (_, first_body) = decode_stream_body(&write_first.bytes, &session_key); - assert!(matches!( - first_body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(3), - .. - }) - )); - - let _ = engine.complete_write_collect(write_first.id, Ok(())); - - let second_ack = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: StreamSeq(4), - bitmap: 0b0000_0010, - }, - valid_until: wire::now_secs().saturating_add(60), - }), - [40; wire::encrypted_message::NONCE_SIZE], - ); - let _outputs_second = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&second_ack))); - let write_second = engine.take_next_write().unwrap(); - let (_, second_body) = decode_stream_body(&write_second.bytes, &session_key); - assert!(matches!( - second_body, - StreamBody::Message(StreamMessage { - tx_seq: StreamSeq(5), - .. - }) - )); -} - -#[test] -fn fast_retransmit_and_retry_deadline_same_tick_only_send_once() { - let mut config = EngineConfig::default(); - config.stream_fast_retransmit_threshold = 2; - let SingleEngineHarness { - now, - mut engine, - peer, - session_key, - } = SingleEngineHarness::connected(config, 75, 16); - let stream_id = engine.streams.next_stream_id(); - insert_inflight_gap_stream(&mut engine, stream_id, now); - - { - let in_flight = engine - .streams - .streams - .get_mut(&stream_id) - .unwrap() - .control - .in_flight - .get_mut(&StreamSeq(3)) - .unwrap(); - in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at: now }; - } - - let ack_record = wire::stream::encrypt_stream( - QlHeader { - sender: peer.xid, - recipient: engine.engine.identity.xid, - }, - &session_key, - &StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: StreamSeq(2), - bitmap: 0b0000_0110, - }, - valid_until: wire::now_secs().saturating_add(60), - }), - [41; wire::encrypted_message::NONCE_SIZE], - ); - let _outputs_ack = - engine.run_tick_collect(now, EngineInput::Incoming(wire::encode_record(&ack_record))); - let _write = engine.take_next_write().unwrap(); - assert!(engine.take_next_write().is_none()); - - let _outputs_timeout = engine.run_tick_collect(now, EngineInput::TimerExpired); - assert!(engine.take_next_write().is_none()); -} diff --git a/ql-engine/src/identity.rs b/ql-engine/src/identity.rs deleted file mode 100644 index b4e12886..00000000 --- a/ql-engine/src/identity.rs +++ /dev/null @@ -1,29 +0,0 @@ -use bc_components::{ - MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey, SigningPublicKey, XID, -}; - -#[derive(Debug, Clone)] -pub struct QlIdentity { - pub xid: XID, - pub signing_private_key: MLDSAPrivateKey, - pub signing_public_key: MLDSAPublicKey, - pub encapsulation_private_key: MLKEMPrivateKey, - pub encapsulation_public_key: MLKEMPublicKey, -} - -impl QlIdentity { - pub fn from_keys( - signing_private_key: MLDSAPrivateKey, - signing_public_key: MLDSAPublicKey, - encapsulation_private_key: MLKEMPrivateKey, - encapsulation_public_key: MLKEMPublicKey, - ) -> Self { - Self { - xid: XID::new(SigningPublicKey::MLDSA(signing_public_key.clone())), - signing_private_key, - signing_public_key, - encapsulation_private_key, - encapsulation_public_key, - } - } -} diff --git a/ql-engine/src/lib.rs b/ql-engine/src/lib.rs deleted file mode 100644 index db9c878d..00000000 --- a/ql-engine/src/lib.rs +++ /dev/null @@ -1,40 +0,0 @@ -pub(crate) mod arena; -pub mod engine; -pub mod identity; -pub mod stream; -pub mod wire; - -pub use wire::{PacketId, StreamId}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Peer { - pub peer: bc_components::XID, - pub signing_key: bc_components::MLDSAPublicKey, - pub encapsulation_key: bc_components::MLKEMPublicKey, -} - -#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] -pub enum QlError { - #[error("invalid payload")] - InvalidPayload, - #[error("invalid signature")] - InvalidSignature, - #[error("missing session")] - MissingSession, - #[error("no peer bound")] - NoPeerBound, - #[error("timeout")] - Timeout, - #[error("send failed")] - SendFailed, - #[error("stream closed {code:?}")] - StreamClosed { - target: wire::stream::CloseTarget, - code: wire::stream::CloseCode, - payload: Vec, - }, - #[error("stream protocol error")] - StreamProtocol, - #[error("cancelled")] - Cancelled, -} diff --git a/ql-engine/src/stream/internal.rs b/ql-engine/src/stream/internal.rs deleted file mode 100644 index 358f0465..00000000 --- a/ql-engine/src/stream/internal.rs +++ /dev/null @@ -1,842 +0,0 @@ -use std::{collections::VecDeque, time::Instant}; - -use super::{state::*, *}; -use crate::{ - wire::{ - stream::{ - BodyChunk, CloseCode, CloseTarget, StreamAck, StreamAckBody, StreamBody, StreamFrame, - StreamFrameClose, StreamFrameData, StreamFrameOpen, StreamMessage, - }, - StreamSeq, - }, - StreamId, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum OutboundSelection { - Ack, - InitialFrame { tx_seq: StreamSeq }, - RetryFrame { tx_seq: StreamSeq }, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum StreamDisposition { - Keep, - Reap, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum TimerAction { - None, - Fail, -} - -impl StreamFsm { - pub fn open_stream_inner( - &mut self, - request_head: Vec, - request_prefix: Option, - ) -> StreamId { - let stream_id = self.next_stream_id(); - let request_prefix_fin = request_prefix.as_ref().is_some_and(|chunk| chunk.fin); - let mut stream = StreamState { - control: StreamControl { - pending: VecDeque::from([StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head, - request_prefix, - })]), - ..Default::default() - }, - role: StreamRole::Initiator(InitiatorStream { - request: OutboundPhase::from_prefix(request_prefix_fin), - response: InboundState::new(), - }), - }; - Self::drive_stream(&mut stream, stream_id); - self.streams.insert(stream_id, stream); - stream_id - } - - pub fn write_stream_inner( - &mut self, - stream_id: StreamId, - bytes: Vec, - ) -> Result<(), StreamError> { - if bytes.is_empty() { - return Ok(()); - } - - let Some(stream) = self.streams.get_mut(&stream_id) else { - return Err(StreamError::MissingStream); - }; - let Some(side) = stream.outbound_side() else { - return Err(StreamError::NotWritable); - }; - if let StreamRole::Responder(state) = &mut stream.role { - if side == StreamSide::Response { - state.response_started = true; - } - } - let Some(outbound) = stream.outbound_mut(side) else { - return Err(StreamError::NotWritable); - }; - if !outbound.can_queue_data() { - return Err(StreamError::NotWritable); - } - - stream - .control - .pending - .push_back(StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { bytes, fin: false }, - })); - Self::drive_stream(stream, stream_id); - Ok(()) - } - - pub fn finish_stream_inner(&mut self, stream_id: StreamId) -> Result<(), StreamError> { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return Err(StreamError::MissingStream); - }; - let Some(side) = stream.outbound_side() else { - return Err(StreamError::NotWritable); - }; - if let StreamRole::Responder(state) = &mut stream.role { - if side == StreamSide::Response { - state.response_started = true; - } - } - let Some(outbound) = stream.outbound_mut(side) else { - return Err(StreamError::NotWritable); - }; - outbound.finish(); - Self::drive_stream(stream, stream_id); - Ok(()) - } - - pub fn close_stream_inner( - &mut self, - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - ) -> Result<(), StreamError> { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return Err(StreamError::MissingStream); - }; - - let mut dirty = false; - if matches!(target, CloseTarget::Request | CloseTarget::Both) { - if let Some(inbound) = stream.inbound_mut(StreamSide::Request) { - dirty |= inbound.close(); - } - if let Some(outbound) = stream.outbound_mut(StreamSide::Request) { - dirty |= outbound.close(); - } - } - if matches!(target, CloseTarget::Response | CloseTarget::Both) { - if let Some(inbound) = stream.inbound_mut(StreamSide::Response) { - dirty |= inbound.close(); - } - if let Some(outbound) = stream.outbound_mut(StreamSide::Response) { - dirty |= outbound.close(); - } - } - - if dirty { - stream - .control - .pending - .push_front(close_frame(stream_id, target, code, payload)); - Self::drive_stream(stream, stream_id); - } - - Ok(()) - } - - pub fn receive_inner( - &mut self, - now: Instant, - body: StreamBody, - events: &mut impl StreamEventSink, - ) { - match body { - StreamBody::Ack(StreamAckBody { stream_id, ack, .. }) => { - self.process_ack(now, stream_id, ack, events) - } - StreamBody::Message(StreamMessage { - tx_seq, ack, frame, .. - }) => { - let stream_id = frame.stream_id(); - self.process_ack(now, stream_id, ack, events); - - if !self.streams.contains_key(&stream_id) { - if !self.config.local_namespace.remote().matches(stream_id) { - return; - } - self.streams.insert( - stream_id, - StreamState { - control: StreamControl::default(), - role: StreamRole::Responder(ResponderStream { - opened: false, - request: InboundState::new(), - response: OutboundPhase::Ready, - response_started: false, - }), - }, - ); - } - - let disposition = { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - - match stream.control.buffer_incoming(tx_seq, frame) { - BufferIncomingResult::OutOfWindow => { - Self::queue_protocol_close(stream_id, stream, events); - StreamDisposition::Keep - } - BufferIncomingResult::Duplicate | BufferIncomingResult::AlreadyBuffered => { - stream.control.note_ack(now, self.config.ack_delay, true); - StreamDisposition::Keep - } - BufferIncomingResult::Buffered { out_of_order } => { - stream - .control - .note_ack(now, self.config.ack_delay, out_of_order); - Self::drain_committed_frames(stream_id, stream, events) - } - } - }; - - match disposition { - StreamDisposition::Keep => {} - StreamDisposition::Reap => { - self.streams.remove(&stream_id); - events.reaped(stream_id); - } - } - } - } - } - - pub fn next_outbound_inner(&mut self, now: Instant, valid_until: u64) -> Option { - for offset in 0..self.streams.len() { - let stream_id = self.streams.id_at_offset(offset)?; - let selection = { - let stream = self.streams.get(&stream_id)?; - self.select_outbound(stream, now) - }; - let Some(selection) = selection else { - continue; - }; - - let outbound = match selection { - OutboundSelection::Ack => { - let stream = self.streams.get_mut(&stream_id)?; - let ack = stream.control.current_ack(); - stream.control.clear_ack_schedule(); - stream.control.note_ack_sent(ack); - Outbound { - body: StreamBody::Ack(StreamAckBody { - stream_id, - ack, - valid_until, - }), - completion: OutboundCompletion::Ack { stream_id }, - } - } - OutboundSelection::InitialFrame { tx_seq } - | OutboundSelection::RetryFrame { tx_seq } => { - let issue_id = self.next_issue_id(); - let stream = self.streams.get_mut(&stream_id)?; - let inbound_alive = match stream.role { - StreamRole::Initiator(state) => !state.response.closed, - StreamRole::Responder(state) => !state.request.closed, - }; - let ack = stream.control.take_piggyback_ack(inbound_alive); - let frame = stream.control.mark_write_issued(tx_seq, issue_id)?; - Outbound { - body: StreamBody::Message(StreamMessage { - tx_seq, - ack, - valid_until, - frame, - }), - completion: OutboundCompletion::Frame { - stream_id, - tx_seq, - issue_id, - }, - } - } - }; - - self.streams.advance_cursor_after(stream_id); - return Some(outbound); - } - - None - } - - pub fn complete_outbound_inner( - &mut self, - now: Instant, - completion: OutboundCompletion, - result: Result<(), WriteError>, - events: &mut impl StreamEventSink, - ) { - match completion { - OutboundCompletion::Ack { stream_id } => { - if let Some(stream) = self.streams.get_mut(&stream_id) { - if result.is_err() { - stream.control.note_ack(now, self.config.ack_delay, true); - } - if stream.can_reap() { - self.streams.remove(&stream_id); - events.reaped(stream_id); - } - } - } - OutboundCompletion::Frame { - stream_id, - tx_seq, - issue_id, - } => match result { - Ok(()) => { - if let Some(stream) = self.streams.get_mut(&stream_id) { - let _ = stream.control.complete_write( - tx_seq, - issue_id, - now + self.config.ack_timeout, - ); - } - } - Err(WriteError::SendFailed) => { - let should_fail = self.streams.get(&stream_id).is_some_and(|stream| { - stream.control.frame_write_is_issued(tx_seq, issue_id) - }); - if should_fail { - self.fail_stream_by_id(stream_id, StreamError::SendFailed, events); - } - } - }, - } - } - - pub fn on_timer_inner(&mut self, now: Instant, events: &mut impl StreamEventSink) { - let mut index = 0; - while let Some(stream_id) = self.streams.ordered_id(index) { - let action = { - let stream = self - .streams - .get(&stream_id) - .expect("ordered stream id should exist"); - if stream.control.in_flight.iter().any(|(_, in_flight)| { - matches!( - in_flight.write_state, - InFlightWriteState::WaitingRetry { retry_at } - if retry_at <= now && in_flight.attempt >= self.config.retry_limit - ) - }) { - TimerAction::Fail - } else { - TimerAction::None - } - }; - - match action { - TimerAction::Fail => { - self.fail_stream_by_id(stream_id, StreamError::Timeout, events); - } - TimerAction::None => { - if let Some(stream) = self.streams.get_mut(&stream_id) { - if stream - .control - .ack_deadline() - .is_some_and(|due_at| due_at <= now) - { - stream.control.ack_state = AckState::Immediate; - } - } - index += 1; - } - } - } - } - - pub fn next_deadline_inner(&self) -> Option { - let mut next = None; - for stream in self.streams.values() { - if let Some(deadline) = stream.control.ack_deadline() { - next = min_deadline(next, deadline); - } - for (_, in_flight) in stream.control.in_flight.iter() { - if let InFlightWriteState::WaitingRetry { retry_at } = in_flight.write_state { - next = min_deadline(next, retry_at); - } - } - } - next - } - - pub fn abort_inner(&mut self, error: StreamError, events: &mut impl StreamEventSink) { - while let Some(stream_id) = self.streams.first_id() { - self.fail_stream_by_id(stream_id, error.clone(), events); - } - } -} - -impl StreamFsm { - pub(crate) fn next_stream_id(&mut self) -> StreamId { - let seq = self.next_stream_id; - self.next_stream_id = seq.wrapping_add(1); - StreamId((seq & !StreamNamespace::BIT) | self.config.local_namespace.bit()) - } - - fn next_issue_id(&mut self) -> u64 { - let id = self.next_issue_id; - self.next_issue_id = id.wrapping_add(1); - id - } - - fn select_outbound(&self, stream: &StreamState, now: Instant) -> Option { - if let Some(tx_seq) = stream - .control - .in_flight - .iter() - .find_map(|(tx_seq, in_flight)| { - matches!( - in_flight.write_state, - InFlightWriteState::WaitingRetry { retry_at } - if retry_at <= now && in_flight.attempt < self.config.retry_limit - ) - .then_some(tx_seq) - }) - { - return Some(OutboundSelection::RetryFrame { tx_seq }); - } - if let Some(tx_seq) = stream - .control - .in_flight - .iter() - .find_map(|(tx_seq, in_flight)| { - matches!(in_flight.write_state, InFlightWriteState::Ready).then_some(tx_seq) - }) - { - return Some(OutboundSelection::InitialFrame { tx_seq }); - } - - matches!(stream.control.ack_state, AckState::Immediate).then_some(OutboundSelection::Ack) - } - - fn process_ack( - &mut self, - now: Instant, - stream_id: StreamId, - ack: StreamAck, - events: &mut impl StreamEventSink, - ) { - if ack == StreamAck::EMPTY { - return; - } - - let should_reap = { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - stream.control.clear_fast_recovery(ack.base); - let fast_retransmit = stream - .control - .fast_retransmit_candidate(ack, self.config.fast_retransmit_threshold); - - loop { - let acked_tx_seq = - stream - .control - .in_flight - .iter() - .find_map(|(tx_seq, in_flight)| match in_flight.write_state { - InFlightWriteState::Ready => None, - InFlightWriteState::Issued { .. } - | InFlightWriteState::WaitingRetry { .. } => { - StreamControl::ack_covers(ack, tx_seq).then_some(tx_seq) - } - }); - let Some(tx_seq) = acked_tx_seq else { - break; - }; - let Some(in_flight) = stream.control.remove_in_flight(tx_seq) else { - continue; - }; - - match in_flight.frame { - StreamFrame::Open(StreamFrameOpen { request_prefix, .. }) => { - if let StreamRole::Initiator(state) = &mut stream.role { - if request_prefix.as_ref().is_some_and(|chunk| chunk.fin) - && state.request.close() - { - events.outbound_closed(stream_id); - } - } - } - StreamFrame::Data(StreamFrameData { - chunk: BodyChunk { fin: true, .. }, - .. - }) => { - if let Some(side) = stream.outbound_side() { - if let Some(outbound) = stream.outbound_mut(side) { - if outbound.close() { - events.outbound_closed(stream_id); - } - } - } - } - StreamFrame::Close(frame) => { - let mut changed = false; - for side in [StreamSide::Request, StreamSide::Response] { - let affects_outbound = matches!( - (frame.target, side), - (CloseTarget::Request, StreamSide::Request) - | (CloseTarget::Response, StreamSide::Response) - | (CloseTarget::Both, _) - ); - if affects_outbound { - if let Some(outbound) = stream.outbound_mut(side) { - if outbound.close() { - changed = true; - } - } - } - } - if changed { - events.close(StreamCloseEvent { - kind: StreamCloseKind::Acked, - role: stream.local_role(), - frame, - }); - } - } - StreamFrame::Data(_) => {} - } - } - - if let Some(tx_seq) = fast_retransmit { - stream.control.schedule_fast_retransmit(tx_seq, now); - } - Self::drive_stream(stream, stream_id); - stream.can_reap() - }; - - if should_reap { - self.streams.remove(&stream_id); - events.reaped(stream_id); - } - } - - fn drain_committed_frames( - stream_id: StreamId, - stream: &mut StreamState, - events: &mut impl StreamEventSink, - ) -> StreamDisposition { - loop { - let Some((tx_seq, frame)) = stream.control.pop_next_committable() else { - break; - }; - - if stream.awaiting_open() - && (tx_seq != StreamSeq::START || !matches!(frame, StreamFrame::Open(_))) - { - Self::queue_protocol_close(stream_id, stream, events); - return StreamDisposition::Keep; - } - - match frame { - StreamFrame::Open(frame) => { - Self::handle_stream_open(stream_id, stream, frame, events) - } - StreamFrame::Close(frame) => { - Self::handle_stream_close_from_peer(stream_id, stream, frame, events) - } - StreamFrame::Data(frame) => { - Self::handle_stream_data(stream_id, stream, frame, events) - } - } - } - - stream.control.maybe_force_ack_for_progress(); - if stream.can_reap() { - StreamDisposition::Reap - } else { - StreamDisposition::Keep - } - } - - fn handle_stream_open( - stream_id: StreamId, - stream: &mut StreamState, - frame: StreamFrameOpen, - events: &mut impl StreamEventSink, - ) { - let StreamFrameOpen { - request_head, - request_prefix, - .. - } = frame; - - let StreamRole::Responder(state) = &mut stream.role else { - Self::queue_protocol_close(stream_id, stream, events); - return; - }; - if state.opened { - Self::queue_protocol_close(stream_id, stream, events); - return; - } - - let request_fin = request_prefix.as_ref().is_some_and(|chunk| chunk.fin); - state.opened = true; - if request_fin { - let _ = stream - .inbound_mut(StreamSide::Request) - .expect("responder request side should exist") - .close(); - } - events.opened(stream_id, request_head, request_prefix); - } - - fn handle_stream_close_from_peer( - stream_id: StreamId, - stream: &mut StreamState, - frame: StreamFrameClose, - events: &mut impl StreamEventSink, - ) { - let StreamFrameClose { - target, - code, - payload, - .. - } = frame; - Self::apply_remote_close(stream_id, stream, target, code, payload, events); - } - - fn handle_stream_data( - stream_id: StreamId, - stream: &mut StreamState, - frame: StreamFrameData, - events: &mut impl StreamEventSink, - ) { - let Some(side) = stream.inbound_side() else { - Self::queue_protocol_close(stream_id, stream, events); - return; - }; - let Some(inbound) = stream.inbound_mut(side) else { - Self::queue_protocol_close(stream_id, stream, events); - return; - }; - if inbound.closed { - Self::queue_protocol_close(stream_id, stream, events); - return; - } - - let BodyChunk { bytes, fin } = frame.chunk; - if !bytes.is_empty() { - events.inbound_data(stream_id, bytes); - } - if fin && inbound.close() { - events.inbound_finished(stream_id); - } - } - - fn drive_stream(stream: &mut StreamState, stream_id: StreamId) { - match &mut stream.role { - StreamRole::Initiator(state) => Self::drive_stream_outbound( - stream_id, - &mut stream.control, - Some(&mut state.request), - ), - StreamRole::Responder(state) => Self::drive_stream_outbound( - stream_id, - &mut stream.control, - Some(&mut state.response), - ), - } - } - - fn drive_stream_outbound( - stream_id: StreamId, - control: &mut StreamControl, - mut outbound: Option<&mut OutboundPhase>, - ) { - loop { - if control.send_window_has_space() { - if let Some(frame) = control.pending.pop_front() { - let tx_seq = control.take_tx_seq(); - control.insert_in_flight(InFlightFrame { - tx_seq, - frame, - attempt: 0, - write_state: InFlightWriteState::Ready, - }); - continue; - } - } - if !control.send_window_has_space() { - return; - } - let Some(outbound) = outbound.as_deref_mut() else { - return; - }; - if outbound.queue_fin() { - let tx_seq = control.take_tx_seq(); - control.insert_in_flight(InFlightFrame { - tx_seq, - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: Vec::new(), - fin: true, - }, - }), - attempt: 0, - write_state: InFlightWriteState::Ready, - }); - continue; - } - return; - } - } - - fn queue_protocol_close( - stream_id: StreamId, - stream: &mut StreamState, - events: &mut impl StreamEventSink, - ) { - let opened = !stream.awaiting_open(); - stream.control.clear_transient_buffers(); - stream.control.pending.push_front(close_frame( - stream_id, - CloseTarget::Both, - CloseCode::PROTOCOL, - Vec::new(), - )); - for side in [StreamSide::Request, StreamSide::Response] { - if let Some(outbound) = stream.outbound_mut(side) { - if outbound.close() { - if opened { - events.outbound_failed(stream_id, StreamError::StreamProtocol); - } - } - } - if let Some(inbound) = stream.inbound_mut(side) { - if inbound.close() { - if opened { - events.inbound_failed(stream_id, StreamError::StreamProtocol); - } - } - } - } - Self::drive_stream(stream, stream_id); - } - - fn apply_remote_close( - stream_id: StreamId, - stream: &mut StreamState, - target: CloseTarget, - code: CloseCode, - payload: Vec, - events: &mut impl StreamEventSink, - ) { - let frame = StreamFrameClose { - stream_id, - target, - code, - payload, - }; - let mut changed = false; - if matches!(target, CloseTarget::Request | CloseTarget::Both) { - if let Some(inbound) = stream.inbound_mut(StreamSide::Request) { - if inbound.close() { - changed = true; - } - } - if let Some(outbound) = stream.outbound_mut(StreamSide::Request) { - if outbound.close() { - changed = true; - } - } - } - if matches!(target, CloseTarget::Response | CloseTarget::Both) { - if let Some(inbound) = stream.inbound_mut(StreamSide::Response) { - if inbound.close() { - changed = true; - } - } - if let Some(outbound) = stream.outbound_mut(StreamSide::Response) { - if outbound.close() { - changed = true; - } - } - } - if changed { - events.close(StreamCloseEvent { - kind: StreamCloseKind::Remote, - role: stream.local_role(), - frame, - }); - } - } - - fn fail_stream_by_id( - &mut self, - stream_id: StreamId, - error: StreamError, - events: &mut impl StreamEventSink, - ) { - let Some(stream) = self.streams.remove(&stream_id) else { - return; - }; - - match stream.role { - StreamRole::Initiator(_) => { - events.outbound_failed(stream_id, error.clone()); - events.inbound_failed(stream_id, error); - } - StreamRole::Responder(stream) => { - if !stream.opened { - events.reaped(stream_id); - return; - } - events.inbound_failed(stream_id, error.clone()); - if stream.response_started || stream.response.is_closed() { - events.outbound_failed(stream_id, error); - } - } - } - events.reaped(stream_id); - } -} - -fn min_deadline(current: Option, candidate: Instant) -> Option { - Some(match current { - Some(current) => current.min(candidate), - None => candidate, - }) -} - -fn close_frame( - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, -) -> StreamFrame { - StreamFrame::Close(StreamFrameClose { - stream_id, - target, - code, - payload, - }) -} diff --git a/ql-engine/src/stream/mod.rs b/ql-engine/src/stream/mod.rs deleted file mode 100644 index 481cbc8c..00000000 --- a/ql-engine/src/stream/mod.rs +++ /dev/null @@ -1,270 +0,0 @@ -pub(crate) mod internal; -pub(crate) mod ring; -pub(crate) mod state; - -#[cfg(test)] -mod tests; - -use std::time::{Duration, Instant}; - -use thiserror::Error; - -use crate::{ - wire::{ - stream::{BodyChunk, CloseCode, CloseTarget, StreamBody, StreamFrameClose}, - StreamSeq, - }, - StreamId, -}; - -pub const STREAM_WINDOW_CAPACITY: usize = 8; -pub const STREAM_WINDOW_SIZE: u32 = STREAM_WINDOW_CAPACITY as u32; -pub const STREAM_ACK_EAGER_THRESHOLD: u32 = STREAM_WINDOW_SIZE / 2; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamNamespace { - Low, - High, -} - -impl StreamNamespace { - const BIT: u32 = 1 << 31; - - pub fn for_local(local: bc_components::XID, peer: bc_components::XID) -> Self { - match local.data().cmp(peer.data()) { - std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Low, - std::cmp::Ordering::Greater => Self::High, - } - } - - pub fn bit(self) -> u32 { - match self { - Self::Low => 0, - Self::High => Self::BIT, - } - } - - pub fn matches(self, stream_id: StreamId) -> bool { - (stream_id.0 & Self::BIT) == self.bit() - } - - pub fn remote(self) -> Self { - match self { - Self::Low => Self::High, - Self::High => Self::Low, - } - } -} - -#[derive(Debug, Clone, Copy)] -pub struct StreamFsmConfig { - pub local_namespace: StreamNamespace, - pub ack_delay: Duration, - pub ack_timeout: Duration, - pub fast_retransmit_threshold: u8, - pub retry_limit: u8, -} - -impl Default for StreamFsmConfig { - fn default() -> Self { - Self { - local_namespace: StreamNamespace::Low, - ack_delay: Duration::from_millis(5), - ack_timeout: Duration::from_millis(150), - fast_retransmit_threshold: 2, - retry_limit: 5, - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum OutboundCompletion { - Ack { - stream_id: StreamId, - }, - Frame { - stream_id: StreamId, - tx_seq: StreamSeq, - issue_id: u64, - }, -} - -impl OutboundCompletion { - pub fn stream_id(self) -> StreamId { - match self { - Self::Ack { stream_id } | Self::Frame { stream_id, .. } => stream_id, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Outbound { - pub body: StreamBody, - pub completion: OutboundCompletion, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamCloseKind { - Acked, - Remote, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamLocalRole { - Initiator, - Responder, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamCloseEvent { - pub kind: StreamCloseKind, - pub role: StreamLocalRole, - pub frame: StreamFrameClose, -} - -pub trait StreamEventSink { - fn opened( - &mut self, - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, - ); - - fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec); - - fn inbound_finished(&mut self, stream_id: StreamId); - - fn inbound_failed(&mut self, stream_id: StreamId, error: StreamError); - - fn close(&mut self, event: StreamCloseEvent); - - fn outbound_closed(&mut self, stream_id: StreamId); - - fn outbound_failed(&mut self, stream_id: StreamId, error: StreamError); - - fn reaped(&mut self, stream_id: StreamId); -} - -impl StreamEventSink for () { - fn opened( - &mut self, - _stream_id: StreamId, - _request_head: Vec, - _request_prefix: Option, - ) { - } - - fn inbound_data(&mut self, _stream_id: StreamId, _bytes: Vec) {} - - fn inbound_finished(&mut self, _stream_id: StreamId) {} - - fn inbound_failed(&mut self, _stream_id: StreamId, _error: StreamError) {} - - fn close(&mut self, _event: StreamCloseEvent) {} - - fn outbound_closed(&mut self, _stream_id: StreamId) {} - - fn outbound_failed(&mut self, _stream_id: StreamId, _error: StreamError) {} - - fn reaped(&mut self, _stream_id: StreamId) {} -} - -#[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum StreamError { - #[error("missing stream")] - MissingStream, - #[error("stream is not writable")] - NotWritable, - #[error("send failed")] - SendFailed, - #[error("timeout")] - Timeout, - #[error("cancelled")] - Cancelled, - #[error("stream protocol error")] - StreamProtocol, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)] -pub enum WriteError { - #[error("send failed")] - SendFailed, -} - -pub struct StreamFsm { - config: StreamFsmConfig, - pub(crate) streams: state::StreamStore, - next_stream_id: u32, - next_issue_id: u64, -} - -impl StreamFsm { - pub fn new(config: StreamFsmConfig) -> Self { - Self { - config, - streams: state::StreamStore::default(), - next_stream_id: 1, - next_issue_id: 1, - } - } - - pub fn open_stream( - &mut self, - request_head: Vec, - request_prefix: Option, - ) -> StreamId { - self.open_stream_inner(request_head, request_prefix) - } - - pub fn write_stream(&mut self, stream_id: StreamId, bytes: Vec) -> Result<(), StreamError> { - self.write_stream_inner(stream_id, bytes) - } - - pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { - self.finish_stream_inner(stream_id) - } - - pub fn close_stream( - &mut self, - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - ) -> Result<(), StreamError> { - self.close_stream_inner(stream_id, target, code, payload) - } - - pub fn receive(&mut self, now: Instant, body: StreamBody, events: &mut impl StreamEventSink) { - self.receive_inner(now, body, events) - } - - pub fn next_outbound(&mut self, now: Instant, valid_until: u64) -> Option { - self.next_outbound_inner(now, valid_until) - } - - pub fn complete_outbound( - &mut self, - now: Instant, - completion: OutboundCompletion, - result: Result<(), WriteError>, - events: &mut impl StreamEventSink, - ) { - self.complete_outbound_inner(now, completion, result, events) - } - - pub fn on_timer(&mut self, now: Instant, events: &mut impl StreamEventSink) { - self.on_timer_inner(now, events) - } - - pub fn next_deadline(&self) -> Option { - self.next_deadline_inner() - } - - pub fn abort(&mut self, error: StreamError, events: &mut impl StreamEventSink) { - self.abort_inner(error, events); - } - - pub fn set_local_namespace(&mut self, local_namespace: StreamNamespace) { - self.config.local_namespace = local_namespace; - } -} diff --git a/ql-engine/src/stream/ring.rs b/ql-engine/src/stream/ring.rs deleted file mode 100644 index ccefb1ce..00000000 --- a/ql-engine/src/stream/ring.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::array; - -use crate::wire::StreamSeq; - -#[derive(Debug)] -pub enum SeqRingInsertError { - OutOfWindow, - Occupied, -} - -#[derive(Debug)] -pub struct SeqRing { - base_seq: StreamSeq, - head: usize, - len: usize, - slots: [Option; N], -} - -impl SeqRing { - pub fn new(base_seq: StreamSeq) -> Self { - Self { - base_seq, - head: 0, - len: 0, - slots: array::from_fn(|_| None), - } - } - - pub fn base_seq(&self) -> StreamSeq { - self.base_seq - } - - pub fn is_empty(&self) -> bool { - self.len == 0 - } - - pub fn len(&self) -> usize { - self.len - } - - pub fn clear_with_base(&mut self, base_seq: StreamSeq) { - for slot in &mut self.slots { - let _ = slot.take(); - } - self.base_seq = base_seq; - self.head = 0; - self.len = 0; - } - - pub fn contains_key(&self, seq: &StreamSeq) -> bool { - self.get(seq).is_some() - } - - pub fn accepts_seq(&self, seq: StreamSeq) -> bool { - self.offset_for(seq).is_some() - } - - pub fn get(&self, seq: &StreamSeq) -> Option<&T> { - let index = self.index_for(*seq)?; - self.slots[index].as_ref() - } - - pub fn get_mut(&mut self, seq: &StreamSeq) -> Option<&mut T> { - let index = self.index_for(*seq)?; - self.slots[index].as_mut() - } - - pub fn insert(&mut self, seq: StreamSeq, value: T) -> Result<(), SeqRingInsertError> { - let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; - if self.slots[index].is_some() { - return Err(SeqRingInsertError::Occupied); - } - self.slots[index] = Some(value); - self.len += 1; - Ok(()) - } - - pub fn set(&mut self, seq: StreamSeq, value: T) -> Result, SeqRingInsertError> { - let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; - let previous = self.slots[index].replace(value); - if previous.is_none() { - self.len += 1; - } - Ok(previous) - } - - pub fn remove(&mut self, seq: &StreamSeq) -> Option { - let index = self.index_for(*seq)?; - let value = self.slots[index].take(); - if value.is_some() { - self.len -= 1; - } - value - } - - pub fn take_front(&mut self) -> Option<(StreamSeq, T)> { - let value = self.slots[self.head].take()?; - let seq = self.base_seq; - self.len -= 1; - self.head = self.next_index(self.head); - self.base_seq = self.base_seq.next(); - Some((seq, value)) - } - - pub fn advance_empty_front_until(&mut self, limit_seq: StreamSeq) { - while self.base_seq.serial_lt(limit_seq) && self.slots[self.head].is_none() { - self.head = self.next_index(self.head); - self.base_seq = self.base_seq.next(); - } - } - - pub fn iter(&self) -> SeqRingIter<'_, N, T> { - SeqRingIter { - ring: self, - offset: 0, - } - } - - pub fn bitmap(&self) -> u8 { - debug_assert!(N <= 8); - let mut bitmap = 0u8; - for offset in 0..N { - let index = self.index_for_offset(offset); - if self.slots[index].is_some() { - bitmap |= 1u8 << offset; - } - } - bitmap - } - - fn index_for(&self, seq: StreamSeq) -> Option { - let offset = self.offset_for(seq)?; - Some(self.index_for_offset(offset)) - } - - fn offset_for(&self, seq: StreamSeq) -> Option { - let offset = self.base_seq.forward_distance_to(seq)? as usize; - (offset < N).then_some(offset) - } - - fn index_for_offset(&self, offset: usize) -> usize { - (self.head + offset) % N - } - - fn next_index(&self, index: usize) -> usize { - (index + 1) % N - } -} - -pub struct SeqRingIter<'a, const N: usize, T> { - ring: &'a SeqRing, - offset: usize, -} - -impl<'a, const N: usize, T> Iterator for SeqRingIter<'a, N, T> { - type Item = (StreamSeq, &'a T); - - fn next(&mut self) -> Option { - while self.offset < N { - let offset = self.offset; - self.offset += 1; - let index = self.ring.index_for_offset(offset); - if let Some(value) = self.ring.slots[index].as_ref() { - return Some((self.ring.base_seq.add(offset as u32), value)); - } - } - None - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn seq_ring_wraps_and_reuses_slots() { - let mut ring = SeqRing::<4, u64>::new(StreamSeq(1)); - ring.insert(StreamSeq(1), 1).unwrap(); - ring.insert(StreamSeq(2), 2).unwrap(); - ring.insert(StreamSeq(3), 3).unwrap(); - - assert_eq!(ring.take_front(), Some((StreamSeq(1), 1))); - assert_eq!(ring.take_front(), Some((StreamSeq(2), 2))); - - ring.insert(StreamSeq(4), 4).unwrap(); - ring.insert(StreamSeq(5), 5).unwrap(); - - let remaining: Vec<_> = ring.iter().map(|(seq, value)| (seq, *value)).collect(); - assert_eq!( - remaining, - vec![(StreamSeq(3), 3), (StreamSeq(4), 4), (StreamSeq(5), 5)] - ); - } -} diff --git a/ql-engine/src/stream/state.rs b/ql-engine/src/stream/state.rs deleted file mode 100644 index 402b7686..00000000 --- a/ql-engine/src/stream/state.rs +++ /dev/null @@ -1,532 +0,0 @@ -use std::{ - collections::{HashMap, VecDeque}, - time::{Duration, Instant}, -}; - -use super::{ - ring::SeqRing, StreamLocalRole, STREAM_ACK_EAGER_THRESHOLD, STREAM_WINDOW_CAPACITY, - STREAM_WINDOW_SIZE, -}; -use crate::{ - wire::{ - stream::{StreamAck, StreamFrame}, - StreamSeq, - }, - StreamId, -}; - -#[derive(Debug, Default)] -pub struct StreamStore { - streams: HashMap, - order: Vec, - cursor: usize, -} - -impl StreamStore { - pub fn contains_key(&self, stream_id: &StreamId) -> bool { - self.streams.contains_key(stream_id) - } - - pub fn insert(&mut self, stream_id: StreamId, stream: StreamState) -> Option { - if !self.streams.contains_key(&stream_id) { - self.order.push(stream_id); - } - self.streams.insert(stream_id, stream) - } - - pub fn get(&self, stream_id: &StreamId) -> Option<&StreamState> { - self.streams.get(stream_id) - } - - pub fn get_mut(&mut self, stream_id: &StreamId) -> Option<&mut StreamState> { - self.streams.get_mut(stream_id) - } - - pub fn remove(&mut self, stream_id: &StreamId) -> Option { - let removed = self.streams.remove(stream_id); - if removed.is_some() { - if let Some(index) = self.order.iter().position(|id| id == stream_id) { - self.order.remove(index); - if self.order.is_empty() { - self.cursor = 0; - } else if index < self.cursor { - self.cursor -= 1; - } else if self.cursor >= self.order.len() { - self.cursor = 0; - } - } - } - removed - } - - pub fn values(&self) -> impl Iterator { - self.streams.values() - } - - pub fn len(&self) -> usize { - self.order.len() - } - - pub fn id_at_offset(&self, offset: usize) -> Option { - let len = self.order.len(); - if len == 0 || offset >= len { - return None; - } - Some(self.order[(self.cursor + offset) % len]) - } - - pub fn ordered_id(&self, index: usize) -> Option { - self.order.get(index).copied() - } - - pub fn first_id(&self) -> Option { - self.order.first().copied() - } - - pub fn advance_cursor_after(&mut self, stream_id: StreamId) { - if let Some(index) = self.order.iter().position(|id| *id == stream_id) { - self.cursor = if self.order.is_empty() { - 0 - } else { - (index + 1) % self.order.len() - }; - } - } -} - -#[derive(Debug)] -pub struct StreamState { - pub control: StreamControl, - pub role: StreamRole, -} - -impl StreamState { - pub fn outbound_mut(&mut self, side: StreamSide) -> Option<&mut OutboundPhase> { - match &mut self.role { - StreamRole::Initiator(state) if side == StreamSide::Request => Some(&mut state.request), - StreamRole::Responder(state) if side == StreamSide::Response => { - Some(&mut state.response) - } - StreamRole::Initiator(_) | StreamRole::Responder(_) => None, - } - } - - pub fn inbound_mut(&mut self, side: StreamSide) -> Option<&mut InboundState> { - match &mut self.role { - StreamRole::Initiator(state) if side == StreamSide::Response => { - Some(&mut state.response) - } - StreamRole::Responder(state) if side == StreamSide::Request => Some(&mut state.request), - StreamRole::Initiator(_) | StreamRole::Responder(_) => None, - } - } - - pub fn outbound_side(&self) -> Option { - match self.role { - StreamRole::Initiator(_) => Some(StreamSide::Request), - StreamRole::Responder(_) => Some(StreamSide::Response), - } - } - - pub fn inbound_side(&self) -> Option { - match self.role { - StreamRole::Initiator(_) => Some(StreamSide::Response), - StreamRole::Responder(_) => Some(StreamSide::Request), - } - } - - pub fn awaiting_open(&self) -> bool { - matches!( - self.role, - StreamRole::Responder(ResponderStream { opened: false, .. }) - ) - } - - pub fn can_reap(&self) -> bool { - if !self.control.pending.is_empty() - || !self.control.in_flight.is_empty() - || !self.control.recv_buffer.is_empty() - || !matches!(self.control.ack_state, AckState::Idle) - { - return false; - } - - match self.role { - StreamRole::Initiator(state) => state.request.is_closed() && state.response.closed, - StreamRole::Responder(state) => state.request.closed && state.response.is_closed(), - } - } - - pub fn local_role(&self) -> StreamLocalRole { - match self.role { - StreamRole::Initiator(_) => StreamLocalRole::Initiator, - StreamRole::Responder(_) => StreamLocalRole::Responder, - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamSide { - Request, - Response, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum OutboundPhase { - Ready, - FinPending, - FinQueued, - Closed, -} - -impl OutboundPhase { - pub fn from_prefix(fin: bool) -> Self { - if fin { - Self::FinQueued - } else { - Self::Ready - } - } - - pub fn is_closed(self) -> bool { - self == Self::Closed - } - - pub fn can_queue_data(self) -> bool { - self == Self::Ready - } - - pub fn finish(&mut self) { - *self = match *self { - Self::Ready | Self::FinPending => Self::FinPending, - Self::FinQueued => Self::FinQueued, - Self::Closed => Self::Closed, - }; - } - - pub fn queue_fin(&mut self) -> bool { - if *self != Self::FinPending { - return false; - } - *self = Self::FinQueued; - true - } - - pub fn close(&mut self) -> bool { - if *self == Self::Closed { - return false; - } - *self = Self::Closed; - true - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct InboundState { - pub closed: bool, -} - -impl InboundState { - pub fn new() -> Self { - Self { closed: false } - } - - pub fn close(&mut self) -> bool { - if self.closed { - return false; - } - self.closed = true; - true - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum InFlightWriteState { - Ready, - Issued { issue_id: u64 }, - WaitingRetry { retry_at: Instant }, -} - -#[derive(Debug)] -pub struct InFlightFrame { - pub tx_seq: StreamSeq, - pub frame: StreamFrame, - pub attempt: u8, - pub write_state: InFlightWriteState, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum BufferIncomingResult { - Duplicate, - AlreadyBuffered, - Buffered { out_of_order: bool }, - OutOfWindow, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum AckState { - Idle, - Delayed { due_at: Instant }, - Immediate, -} - -#[derive(Debug)] -pub struct StreamControl { - pub pending: VecDeque, - pub in_flight: SeqRing, - pub next_tx_seq: StreamSeq, - pub recv_buffer: SeqRing, - pub ack_state: AckState, - pub last_sent_ack_base: StreamSeq, - pub fast_recovery: Option, -} - -impl Default for StreamControl { - fn default() -> Self { - Self { - pending: VecDeque::new(), - in_flight: SeqRing::new(StreamSeq::START), - next_tx_seq: StreamSeq::START, - recv_buffer: SeqRing::new(StreamSeq::START), - ack_state: AckState::Idle, - last_sent_ack_base: StreamSeq(0), - fast_recovery: None, - } - } -} - -impl StreamControl { - pub fn take_tx_seq(&mut self) -> StreamSeq { - let tx_seq = self.next_tx_seq; - self.next_tx_seq = self.next_tx_seq.next(); - tx_seq - } - - pub fn send_window_has_space(&self) -> bool { - self.in_flight.accepts_seq(self.next_tx_seq) - } - - pub fn committed_rx_seq(&self) -> StreamSeq { - self.recv_buffer.base_seq().prev() - } - - pub fn note_ack(&mut self, now: Instant, ack_delay: Duration, immediate: bool) { - self.ack_state = match self.ack_state { - AckState::Immediate => AckState::Immediate, - AckState::Delayed { due_at } if !immediate && !ack_delay.is_zero() => { - AckState::Delayed { due_at } - } - _ if immediate || ack_delay.is_zero() => AckState::Immediate, - _ => AckState::Delayed { - due_at: now + ack_delay, - }, - }; - } - - pub fn clear_ack_schedule(&mut self) { - self.ack_state = AckState::Idle; - } - - pub fn maybe_force_ack_for_progress(&mut self) { - if matches!(self.ack_state, AckState::Idle) { - return; - } - let committed = self.committed_rx_seq(); - let progressed = self - .last_sent_ack_base - .forward_distance_to(committed) - .unwrap_or(0); - if progressed >= STREAM_ACK_EAGER_THRESHOLD { - self.ack_state = AckState::Immediate; - } - } - - pub fn note_ack_sent(&mut self, ack: StreamAck) { - if ack.base.serial_gt(self.last_sent_ack_base) { - self.last_sent_ack_base = ack.base; - } - } - - pub fn current_ack(&self) -> StreamAck { - StreamAck { - base: self.committed_rx_seq(), - bitmap: self.recv_buffer.bitmap(), - } - } - - pub fn take_piggyback_ack(&mut self, inbound_alive: bool) -> StreamAck { - if !inbound_alive || matches!(self.ack_state, AckState::Idle) { - return StreamAck::EMPTY; - } - let ack = self.current_ack(); - self.clear_ack_schedule(); - self.note_ack_sent(ack); - ack - } - - pub fn ack_deadline(&self) -> Option { - match self.ack_state { - AckState::Delayed { due_at } => Some(due_at), - AckState::Idle | AckState::Immediate => None, - } - } - - pub fn buffer_incoming( - &mut self, - tx_seq: StreamSeq, - frame: StreamFrame, - ) -> BufferIncomingResult { - if tx_seq.serial_lt(self.recv_buffer.base_seq()) { - return BufferIncomingResult::Duplicate; - } - if !self.recv_buffer.accepts_seq(tx_seq) { - return BufferIncomingResult::OutOfWindow; - } - if self.recv_buffer.contains_key(&tx_seq) { - return BufferIncomingResult::AlreadyBuffered; - } - - let out_of_order = tx_seq != self.recv_buffer.base_seq(); - let _ = self.recv_buffer.insert(tx_seq, frame); - BufferIncomingResult::Buffered { out_of_order } - } - - pub fn pop_next_committable(&mut self) -> Option<(StreamSeq, StreamFrame)> { - self.recv_buffer.take_front() - } - - pub fn insert_in_flight(&mut self, frame: InFlightFrame) { - let _ = self.in_flight.set(frame.tx_seq, frame); - } - - pub fn fast_retransmit_candidate(&self, ack: StreamAck, threshold: u8) -> Option { - if threshold == 0 { - return None; - } - - let hole = self - .in_flight - .iter() - .map(|(tx_seq, _)| tx_seq) - .find(|tx_seq| !Self::ack_covers(ack, *tx_seq))?; - - if self.fast_recovery == Some(hole) { - return None; - } - - let later_acked = self - .in_flight - .iter() - .map(|(tx_seq, _)| tx_seq) - .filter(|tx_seq| tx_seq.serial_gt(hole) && Self::ack_covers(ack, *tx_seq)) - .count(); - - (later_acked >= threshold as usize).then_some(hole) - } - - pub fn schedule_fast_retransmit(&mut self, tx_seq: StreamSeq, now: Instant) { - if let Some(in_flight) = self.in_flight.get_mut(&tx_seq) { - in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at: now }; - self.fast_recovery = Some(tx_seq); - } - } - - pub fn mark_write_issued(&mut self, tx_seq: StreamSeq, issue_id: u64) -> Option { - let in_flight = self.in_flight.get_mut(&tx_seq)?; - match in_flight.write_state { - InFlightWriteState::Issued { .. } => return None, - InFlightWriteState::WaitingRetry { .. } => { - in_flight.attempt = in_flight.attempt.saturating_add(1); - } - InFlightWriteState::Ready => {} - } - in_flight.write_state = InFlightWriteState::Issued { issue_id }; - Some(in_flight.frame.clone()) - } - - pub fn frame_write_is_issued(&self, tx_seq: StreamSeq, issue_id: u64) -> bool { - matches!( - self.in_flight.get(&tx_seq).map(|in_flight| in_flight.write_state), - Some(InFlightWriteState::Issued { - issue_id: current_issue_id, - }) if current_issue_id == issue_id - ) - } - - pub fn complete_write(&mut self, tx_seq: StreamSeq, issue_id: u64, retry_at: Instant) -> bool { - let Some(in_flight) = self.in_flight.get_mut(&tx_seq) else { - return false; - }; - match in_flight.write_state { - InFlightWriteState::Issued { - issue_id: current_issue_id, - } if current_issue_id == issue_id => { - in_flight.write_state = InFlightWriteState::WaitingRetry { retry_at }; - true - } - InFlightWriteState::Ready - | InFlightWriteState::WaitingRetry { .. } - | InFlightWriteState::Issued { .. } => false, - } - } - - pub fn clear_fast_recovery(&mut self, ack_base: StreamSeq) { - let should_clear = self.fast_recovery.is_some_and(|tx_seq| { - tx_seq.serial_lte(ack_base) || !self.in_flight.contains_key(&tx_seq) - }); - if should_clear { - self.fast_recovery = None; - } - } - - pub fn remove_in_flight(&mut self, tx_seq: StreamSeq) -> Option { - let removed = self.in_flight.remove(&tx_seq); - self.in_flight.advance_empty_front_until(self.next_tx_seq); - if self.fast_recovery == Some(tx_seq) { - self.fast_recovery = None; - } - removed - } - - pub fn clear_transient_buffers(&mut self) { - self.pending.clear(); - self.in_flight.clear_with_base(self.next_tx_seq); - self.recv_buffer - .clear_with_base(self.committed_rx_seq().next()); - self.clear_ack_schedule(); - self.fast_recovery = None; - } - - pub fn ack_covers(ack: StreamAck, tx_seq: StreamSeq) -> bool { - if tx_seq.serial_lte(ack.base) { - return true; - } - let Some(delta) = ack.base.forward_distance_to(tx_seq) else { - return false; - }; - if !(1..=STREAM_WINDOW_SIZE).contains(&delta) { - return false; - } - (ack.bitmap & (1u8 << (delta - 1))) != 0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamRole { - Initiator(InitiatorStream), - Responder(ResponderStream), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct InitiatorStream { - pub request: OutboundPhase, - pub response: InboundState, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct ResponderStream { - pub opened: bool, - pub request: InboundState, - pub response: OutboundPhase, - pub response_started: bool, -} diff --git a/ql-engine/src/stream/tests.rs b/ql-engine/src/stream/tests.rs deleted file mode 100644 index 31a1f5de..00000000 --- a/ql-engine/src/stream/tests.rs +++ /dev/null @@ -1,334 +0,0 @@ -use std::time::Instant; - -use super::{ - Outbound, StreamCloseEvent, StreamCloseKind, StreamError, StreamEventSink, StreamFsm, - StreamFsmConfig, StreamLocalRole, StreamNamespace, WriteError, -}; -use crate::{ - wire::stream::{ - BodyChunk, CloseCode, CloseTarget, StreamAck, StreamAckBody, StreamBody, StreamFrame, - StreamFrameClose, StreamFrameData, StreamFrameOpen, StreamMessage, - }, - StreamId, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -struct OpenedStream { - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct InboundChunk { - stream_id: StreamId, - bytes: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct StreamFailure { - stream_id: StreamId, - error: StreamError, -} - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -struct Recorder { - opened: Vec, - closes: Vec, - inbound_data: Vec, - inbound_finished: Vec, - inbound_failed: Vec, - outbound_closed: Vec, - outbound_failed: Vec, - reaped: Vec, -} - -impl StreamEventSink for Recorder { - fn opened( - &mut self, - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, - ) { - self.opened.push(OpenedStream { - stream_id, - request_head, - request_prefix, - }); - } - - fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { - self.inbound_data.push(InboundChunk { stream_id, bytes }); - } - - fn inbound_finished(&mut self, stream_id: StreamId) { - self.inbound_finished.push(stream_id); - } - - fn inbound_failed(&mut self, stream_id: StreamId, error: StreamError) { - self.inbound_failed.push(StreamFailure { stream_id, error }); - } - - fn close(&mut self, event: StreamCloseEvent) { - self.closes.push(event); - } - - fn outbound_closed(&mut self, stream_id: StreamId) { - self.outbound_closed.push(stream_id); - } - - fn outbound_failed(&mut self, stream_id: StreamId, error: StreamError) { - self.outbound_failed - .push(StreamFailure { stream_id, error }); - } - - fn reaped(&mut self, stream_id: StreamId) { - self.reaped.push(stream_id); - } -} - -fn data_packet(stream_id: StreamId, tx_seq: u32, byte: u8) -> StreamBody { - StreamBody::Message(StreamMessage { - tx_seq: crate::wire::StreamSeq(tx_seq), - ack: StreamAck::EMPTY, - valid_until: 0, - frame: StreamFrame::Data(StreamFrameData { - stream_id, - chunk: BodyChunk { - bytes: vec![byte], - fin: false, - }, - }), - }) -} - -#[test] -fn open_stream_enqueues_open_packet() { - let now = Instant::now(); - let mut stream = StreamFsm::new(StreamFsmConfig::default()); - let stream_id = stream.open_stream(b"open".to_vec(), None); - - let outbound = stream.next_outbound(now, 7).unwrap(); - assert_open(outbound, stream_id, b"open", 7); -} - -#[test] -fn out_of_order_remote_stream_buffers_until_open_arrives() { - let now = Instant::now(); - let mut stream = StreamFsm::new(StreamFsmConfig { - local_namespace: StreamNamespace::Low, - ..Default::default() - }); - let stream_id = StreamId(StreamNamespace::High.bit() | 1); - - let mut events = Recorder::default(); - stream.receive(now, data_packet(stream_id, 2, b'h'), &mut events); - assert!(events.opened.is_empty()); - assert!(events.inbound_data.is_empty()); - - stream.receive( - now, - StreamBody::Message(StreamMessage { - tx_seq: crate::wire::StreamSeq::START, - ack: StreamAck::EMPTY, - valid_until: 0, - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"late-open".to_vec(), - request_prefix: None, - }), - }), - &mut events, - ); - - assert_eq!( - events.opened, - vec![OpenedStream { - stream_id, - request_head: b"late-open".to_vec(), - request_prefix: None, - }] - ); - assert_eq!( - events.inbound_data, - vec![InboundChunk { - stream_id, - bytes: vec![b'h'], - }] - ); -} - -#[test] -fn ack_only_write_failure_requeues_without_spending_sequence_space() { - let now = Instant::now(); - let config = StreamFsmConfig::default(); - let mut stream = StreamFsm::new(config); - let stream_id = StreamId(StreamNamespace::High.bit() | 1); - - let mut events = Recorder::default(); - stream.receive( - now, - StreamBody::Message(StreamMessage { - tx_seq: crate::wire::StreamSeq::START, - ack: StreamAck::EMPTY, - valid_until: 0, - frame: StreamFrame::Open(StreamFrameOpen { - stream_id, - request_head: b"open".to_vec(), - request_prefix: None, - }), - }), - &mut events, - ); - assert_eq!(events.opened.len(), 1); - - stream.on_timer(now + config.ack_delay, &mut ()); - let ack_write = stream.next_outbound(now + config.ack_delay, 11).unwrap(); - assert!(matches!( - ack_write.body, - StreamBody::Ack(StreamAckBody { - stream_id: id, - ack: StreamAck { - base: crate::wire::StreamSeq::START, - bitmap: 0, - }, - valid_until: 11, - }) if id == stream_id - )); - - stream.complete_outbound( - now + config.ack_delay, - ack_write.completion, - Err(WriteError::SendFailed), - &mut (), - ); - let retry = stream.next_outbound(now + config.ack_delay, 12).unwrap(); - assert!(matches!(retry.body, StreamBody::Ack(_))); - - stream.complete_outbound(now + config.ack_delay, retry.completion, Ok(()), &mut ()); - stream.write_stream(stream_id, b"resp".to_vec()).unwrap(); - let response = stream.next_outbound(now, 13).unwrap(); - assert!(matches!( - response.body, - StreamBody::Message(StreamMessage { - tx_seq: crate::wire::StreamSeq::START, - valid_until: 13, - frame: StreamFrame::Data(StreamFrameData { - stream_id: id, - chunk: BodyChunk { bytes, fin: false }, - }), - .. - }) if id == stream_id && bytes == b"resp" - )); -} - -#[test] -fn fast_retransmit_resends_oldest_gap_when_threshold_met() { - let now = Instant::now(); - let mut stream = StreamFsm::new(StreamFsmConfig { - fast_retransmit_threshold: 2, - ..Default::default() - }); - let stream_id = stream.open_stream(b"open".to_vec(), None); - let open = stream.next_outbound(now, 1).unwrap(); - stream.complete_outbound(now, open.completion, Ok(()), &mut ()); - stream.write_stream(stream_id, b"a".to_vec()).unwrap(); - stream.write_stream(stream_id, b"b".to_vec()).unwrap(); - stream.write_stream(stream_id, b"c".to_vec()).unwrap(); - stream.write_stream(stream_id, b"d".to_vec()).unwrap(); - let first = stream.next_outbound(now, 2).unwrap(); - let second = stream.next_outbound(now, 3).unwrap(); - let third = stream.next_outbound(now, 4).unwrap(); - let fourth = stream.next_outbound(now, 5).unwrap(); - stream.complete_outbound(now, first.completion, Ok(()), &mut ()); - stream.complete_outbound(now, second.completion, Ok(()), &mut ()); - stream.complete_outbound(now, third.completion, Ok(()), &mut ()); - stream.complete_outbound(now, fourth.completion, Ok(()), &mut ()); - - stream.receive( - now, - StreamBody::Ack(StreamAckBody { - stream_id, - ack: StreamAck { - base: crate::wire::StreamSeq(2), - bitmap: 0b0000_0110, - }, - valid_until: 0, - }), - &mut (), - ); - - let retransmit = stream.next_outbound(now, 6).unwrap(); - assert!(matches!( - retransmit.body, - StreamBody::Message(StreamMessage { - tx_seq: crate::wire::StreamSeq(3), - frame: StreamFrame::Data(_), - .. - }) - )); -} - -#[test] -fn late_failed_write_after_remote_close_ack_is_ignored() { - let now = Instant::now(); - let mut stream = StreamFsm::new(StreamFsmConfig::default()); - let stream_id = stream.open_stream(b"open".to_vec(), None); - let open = stream.next_outbound(now, 1).unwrap(); - - let mut events = Recorder::default(); - stream.receive( - now, - StreamBody::Message(StreamMessage { - tx_seq: crate::wire::StreamSeq::START, - ack: StreamAck { - base: crate::wire::StreamSeq::START, - bitmap: 0, - }, - valid_until: 0, - frame: StreamFrame::Close(StreamFrameClose { - stream_id, - target: CloseTarget::Both, - code: CloseCode::PROTOCOL, - payload: Vec::new(), - }), - }), - &mut events, - ); - assert_eq!( - events.closes, - vec![StreamCloseEvent { - kind: StreamCloseKind::Remote, - role: StreamLocalRole::Initiator, - frame: StreamFrameClose { - stream_id, - target: CloseTarget::Both, - code: CloseCode::PROTOCOL, - payload: Vec::new(), - }, - }] - ); - assert!(events.outbound_failed.is_empty()); - assert!(events.inbound_failed.is_empty()); - - let mut late = Recorder::default(); - stream.complete_outbound(now, open.completion, Err(WriteError::SendFailed), &mut late); - assert!(late.outbound_failed.is_empty()); - assert!(late.inbound_failed.is_empty()); -} - -fn assert_open(outbound: Outbound, stream_id: StreamId, request_head: &[u8], valid_until: u64) { - assert!(matches!( - outbound.body, - StreamBody::Message(StreamMessage { - tx_seq: crate::wire::StreamSeq::START, - ack: StreamAck::EMPTY, - valid_until: expires_at, - frame: StreamFrame::Open(StreamFrameOpen { - stream_id: id, - request_head: actual_head, - request_prefix: None, - }), - }) if id == stream_id && actual_head == request_head && expires_at == valid_until - )); -} diff --git a/ql-engine/src/wire/codec.rs b/ql-engine/src/wire/codec.rs deleted file mode 100644 index c9edf009..00000000 --- a/ql-engine/src/wire/codec.rs +++ /dev/null @@ -1,332 +0,0 @@ -use bc_components::{ - MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey, Nonce, MLDSA, MLKEM, XID, -}; -use rkyv::{ - rancor::{Fallible, Source}, - with::{ArchiveWith, DeserializeWith, SerializeWith}, - Archive, Archived, Deserialize, Place, Resolver, Serialize, -}; - -use crate::QlError; - -macro_rules! impl_wire_wrapper { - ($marker:ident, $external:ty, $wire:ty) => { - pub(crate) struct $marker; - - impl ArchiveWith<$external> for $marker { - type Archived = Archived<$wire>; - type Resolver = Resolver<$wire>; - - fn resolve_with( - field: &$external, - resolver: Self::Resolver, - out: Place, - ) { - <$wire>::from(field).resolve(resolver, out); - } - } - - impl SerializeWith<$external, S> for $marker - where - S: Fallible + ?Sized, - $wire: Serialize, - { - fn serialize_with( - field: &$external, - serializer: &mut S, - ) -> Result { - <$wire>::from(field).serialize(serializer) - } - } - - impl DeserializeWith, $external, D> for $marker - where - D: Fallible + ?Sized, - D::Error: Source, - Archived<$wire>: Deserialize<$wire, D>, - $wire: TryInto<$external, Error = QlError>, - { - fn deserialize_with( - field: &Archived<$wire>, - deserializer: &mut D, - ) -> Result<$external, D::Error> { - field - .deserialize(deserializer)? - .try_into() - .map_err(D::Error::new) - } - } - }; -} - -#[derive( - Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, -)] -pub(crate) struct WireXid(pub(crate) [u8; XID::XID_SIZE]); - -impl From<&XID> for WireXid { - fn from(value: &XID) -> Self { - Self(*value.data()) - } -} - -impl TryFrom for XID { - type Error = QlError; - - fn try_from(value: WireXid) -> Result { - Ok(XID::from_data(value.0)) - } -} - -impl From<&ArchivedWireXid> for XID { - fn from(value: &ArchivedWireXid) -> Self { - XID::from_data(value.0) - } -} - -impl_wire_wrapper!(AsWireXid, XID, WireXid); - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) struct WireNonce(pub(crate) [u8; Nonce::NONCE_SIZE]); - -impl From<&Nonce> for WireNonce { - fn from(value: &Nonce) -> Self { - Self(*value.data()) - } -} - -impl TryFrom for Nonce { - type Error = QlError; - - fn try_from(value: WireNonce) -> Result { - Ok(Nonce::from_data(value.0)) - } -} - -impl From<&ArchivedWireNonce> for Nonce { - fn from(value: &ArchivedWireNonce) -> Self { - Nonce::from_data(value.0) - } -} - -impl_wire_wrapper!(AsWireNonce, Nonce, WireNonce); - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(u8)] -pub(crate) enum WireMlDsaLevel { - MlDsa44 = 2, - MlDsa65 = 3, - MlDsa87 = 5, -} - -impl TryFrom for MLDSA { - type Error = QlError; - - fn try_from(value: WireMlDsaLevel) -> Result { - Ok(match value { - WireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, - WireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, - WireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, - }) - } -} - -impl From for WireMlDsaLevel { - fn from(value: MLDSA) -> Self { - match value { - MLDSA::MLDSA44 => Self::MlDsa44, - MLDSA::MLDSA65 => Self::MlDsa65, - MLDSA::MLDSA87 => Self::MlDsa87, - } - } -} - -impl From<&ArchivedWireMlDsaLevel> for MLDSA { - fn from(value: &ArchivedWireMlDsaLevel) -> Self { - match value { - ArchivedWireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, - ArchivedWireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, - ArchivedWireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(u8)] -pub(crate) enum WireMlKemLevel { - MlKem512 = 1, - MlKem768 = 2, - MlKem1024 = 3, -} - -impl TryFrom for MLKEM { - type Error = QlError; - - fn try_from(value: WireMlKemLevel) -> Result { - Ok(match value { - WireMlKemLevel::MlKem512 => MLKEM::MLKEM512, - WireMlKemLevel::MlKem768 => MLKEM::MLKEM768, - WireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, - }) - } -} - -impl From for WireMlKemLevel { - fn from(value: MLKEM) -> Self { - match value { - MLKEM::MLKEM512 => Self::MlKem512, - MLKEM::MLKEM768 => Self::MlKem768, - MLKEM::MLKEM1024 => Self::MlKem1024, - } - } -} - -impl From<&ArchivedWireMlKemLevel> for MLKEM { - fn from(value: &ArchivedWireMlKemLevel) -> Self { - match value { - ArchivedWireMlKemLevel::MlKem512 => MLKEM::MLKEM512, - ArchivedWireMlKemLevel::MlKem768 => MLKEM::MLKEM768, - ArchivedWireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct WireMlDsaPublicKey { - pub(crate) level: WireMlDsaLevel, - pub(crate) bytes: Vec, -} - -impl TryFrom for MLDSAPublicKey { - type Error = QlError; - - fn try_from(value: WireMlDsaPublicKey) -> Result { - MLDSAPublicKey::from_bytes(value.level.try_into()?, &value.bytes) - .map_err(|_| QlError::InvalidPayload) - } -} - -impl From<&MLDSAPublicKey> for WireMlDsaPublicKey { - fn from(value: &MLDSAPublicKey) -> Self { - Self { - level: value.level().into(), - bytes: value.as_bytes().to_vec(), - } - } -} - -impl TryFrom<&ArchivedWireMlDsaPublicKey> for MLDSAPublicKey { - type Error = QlError; - - fn try_from(value: &ArchivedWireMlDsaPublicKey) -> Result { - MLDSAPublicKey::from_bytes((&value.level).into(), value.bytes.as_slice()) - .map_err(|_| QlError::InvalidPayload) - } -} - -impl_wire_wrapper!(AsWireMlDsaPublicKey, MLDSAPublicKey, WireMlDsaPublicKey); - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct WireMlDsaSignature { - pub(crate) level: WireMlDsaLevel, - pub(crate) bytes: Vec, -} - -impl TryFrom for MLDSASignature { - type Error = QlError; - - fn try_from(value: WireMlDsaSignature) -> Result { - MLDSASignature::from_bytes(value.level.try_into()?, &value.bytes) - .map_err(|_| QlError::InvalidPayload) - } -} - -impl From<&MLDSASignature> for WireMlDsaSignature { - fn from(value: &MLDSASignature) -> Self { - Self { - level: value.level().into(), - bytes: value.as_bytes().to_vec(), - } - } -} - -impl TryFrom<&ArchivedWireMlDsaSignature> for MLDSASignature { - type Error = QlError; - - fn try_from(value: &ArchivedWireMlDsaSignature) -> Result { - MLDSASignature::from_bytes((&value.level).into(), value.bytes.as_slice()) - .map_err(|_| QlError::InvalidPayload) - } -} - -impl_wire_wrapper!(AsWireMlDsaSignature, MLDSASignature, WireMlDsaSignature); - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct WireMlKemPublicKey { - pub(crate) level: WireMlKemLevel, - pub(crate) bytes: Vec, -} - -impl TryFrom for MLKEMPublicKey { - type Error = QlError; - - fn try_from(value: WireMlKemPublicKey) -> Result { - MLKEMPublicKey::from_bytes(value.level.try_into()?, &value.bytes) - .map_err(|_| QlError::InvalidPayload) - } -} - -impl From<&MLKEMPublicKey> for WireMlKemPublicKey { - fn from(value: &MLKEMPublicKey) -> Self { - Self { - level: value.level().into(), - bytes: value.as_bytes().to_vec(), - } - } -} - -impl TryFrom<&ArchivedWireMlKemPublicKey> for MLKEMPublicKey { - type Error = QlError; - - fn try_from(value: &ArchivedWireMlKemPublicKey) -> Result { - MLKEMPublicKey::from_bytes((&value.level).into(), value.bytes.as_slice()) - .map_err(|_| QlError::InvalidPayload) - } -} - -impl_wire_wrapper!(AsWireMlKemPublicKey, MLKEMPublicKey, WireMlKemPublicKey); - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct WireMlKemCiphertext { - pub(crate) level: WireMlKemLevel, - pub(crate) bytes: Vec, -} - -impl TryFrom for MLKEMCiphertext { - type Error = QlError; - - fn try_from(value: WireMlKemCiphertext) -> Result { - MLKEMCiphertext::from_bytes(value.level.try_into()?, &value.bytes) - .map_err(|_| QlError::InvalidPayload) - } -} - -impl From<&MLKEMCiphertext> for WireMlKemCiphertext { - fn from(value: &MLKEMCiphertext) -> Self { - Self { - level: value.level().into(), - bytes: value.as_bytes().to_vec(), - } - } -} - -impl TryFrom<&ArchivedWireMlKemCiphertext> for MLKEMCiphertext { - type Error = QlError; - - fn try_from(value: &ArchivedWireMlKemCiphertext) -> Result { - MLKEMCiphertext::from_bytes((&value.level).into(), value.bytes.as_slice()) - .map_err(|_| QlError::InvalidPayload) - } -} - -impl_wire_wrapper!(AsWireMlKemCiphertext, MLKEMCiphertext, WireMlKemCiphertext); diff --git a/ql-engine/src/wire/encrypted_message.rs b/ql-engine/src/wire/encrypted_message.rs deleted file mode 100644 index f79e7a8d..00000000 --- a/ql-engine/src/wire/encrypted_message.rs +++ /dev/null @@ -1,63 +0,0 @@ -use bc_components::SymmetricKey; -use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit}; -use rkyv::{seal::Seal, vec::ArchivedVec, Archive, Deserialize, Serialize}; - -use crate::QlError; - -pub const NONCE_SIZE: usize = 12; -pub const AUTH_SIZE: usize = 16; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct EncryptedMessage { - ciphertext: Vec, - nonce: [u8; NONCE_SIZE], - auth: [u8; AUTH_SIZE], -} - -impl EncryptedMessage { - pub fn encrypt( - key: &SymmetricKey, - mut plaintext: Vec, - aad: &[u8], - nonce: [u8; NONCE_SIZE], - ) -> Self { - let cipher = ChaCha20Poly1305::new(key.data().into()); - let auth = cipher - .encrypt_in_place_detached((&nonce).into(), aad, &mut plaintext) - .expect("chacha20poly1305 encryption should succeed"); - Self { - ciphertext: plaintext, - nonce, - auth: auth.into(), - } - } - - pub fn decrypt(&self, key: &SymmetricKey, aad: &[u8]) -> Result, QlError> { - let cipher = ChaCha20Poly1305::new(key.data().into()); - let mut plaintext = self.ciphertext.clone(); - cipher - .decrypt_in_place_detached( - (&self.nonce).into(), - aad, - &mut plaintext, - (&self.auth).into(), - ) - .map_err(|_| QlError::InvalidPayload)?; - Ok(plaintext) - } -} - -impl ArchivedEncryptedMessage { - pub fn decrypt(&mut self, key: &SymmetricKey, aad: &[u8]) -> Result<&[u8], QlError> { - let cipher = ChaCha20Poly1305::new(key.data().into()); - let nonce = self.nonce; - let auth = self.auth; - let ciphertext = ArchivedVec::as_slice_seal(Seal::new(&mut self.ciphertext)); - // SAFETY: decryption only overwrites initialized u8 bytes in place. - let ciphertext = unsafe { ciphertext.unseal_unchecked() }; - cipher - .decrypt_in_place_detached((&nonce).into(), aad, ciphertext, (&auth).into()) - .map_err(|_| QlError::InvalidPayload)?; - Ok(ciphertext) - } -} diff --git a/ql-engine/src/wire/handshake/crypto.rs b/ql-engine/src/wire/handshake/crypto.rs deleted file mode 100644 index cc79f7be..00000000 --- a/ql-engine/src/wire/handshake/crypto.rs +++ /dev/null @@ -1,344 +0,0 @@ -use bc_components::{ - Digest, MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey, Nonce, SymmetricKey, - XID, -}; -use rkyv::{Archive, Serialize}; - -use super::{ - verify_signature, ArchivedConfirm, ArchivedHello, ArchivedHelloReply, ArchivedReady, Confirm, - Hello, HelloReply, Ready, ReadyBody, -}; -use crate::{ - engine::QlCrypto, - identity::QlIdentity, - wire::{ - access_value, deserialize_value, encode_value, - encrypted_message::{EncryptedMessage, NONCE_SIZE}, - ensure_not_expired, AsWireMlKemCiphertext, AsWireNonce, AsWireXid, ControlMeta, QlHeader, - }, - QlError, -}; - -#[derive(Archive, Serialize)] -struct HelloProofData { - #[rkyv(with = AsWireXid)] - initiator: XID, - #[rkyv(with = AsWireXid)] - responder: XID, - meta: ControlMeta, - #[rkyv(with = AsWireNonce)] - nonce: Nonce, - #[rkyv(with = AsWireMlKemCiphertext)] - kem_ct: bc_components::MLKEMCiphertext, -} - -#[derive(Archive, Serialize)] -struct HandshakeTranscript { - #[rkyv(with = AsWireXid)] - initiator: XID, - #[rkyv(with = AsWireXid)] - responder: XID, - hello_meta: ControlMeta, - #[rkyv(with = AsWireNonce)] - initiator_nonce: Nonce, - #[rkyv(with = AsWireNonce)] - responder_nonce: Nonce, - reply_meta: ControlMeta, - #[rkyv(with = AsWireMlKemCiphertext)] - initiator_kem_ct: bc_components::MLKEMCiphertext, - #[rkyv(with = AsWireMlKemCiphertext)] - responder_kem_ct: bc_components::MLKEMCiphertext, -} - -#[derive(Archive, Serialize)] -struct ConfirmProofData { - meta: ControlMeta, - transcript: Vec, -} - -#[derive(Archive, Serialize)] -struct SessionKeyMaterial { - initiator_secret: Vec, - responder_secret: Vec, - transcript: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ResponderSecrets { - pub initiator_secret: SymmetricKey, - pub responder_secret: SymmetricKey, -} - -pub fn build_hello( - identity: &QlIdentity, - crypto: &impl QlCrypto, - recipient: XID, - recipient_encapsulation_key: &MLKEMPublicKey, - meta: ControlMeta, -) -> Result<(Hello, SymmetricKey), QlError> { - let nonce = next_nonce(crypto); - let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); - let signature = identity.signing_private_key.sign(hello_proof_data( - identity.xid, - recipient, - &meta, - &nonce, - &kem_ct, - )); - Ok(( - Hello { - meta, - nonce, - kem_ct, - signature, - }, - session_key, - )) -} - -pub fn verify_hello( - initiator: XID, - responder: XID, - initiator_signing_key: &MLDSAPublicKey, - hello: &ArchivedHello, -) -> Result<(), QlError> { - let meta: ControlMeta = (&hello.meta).into(); - ensure_not_expired(meta.valid_until)?; - let signature = MLDSASignature::try_from(&hello.signature)?; - let nonce: Nonce = (&hello.nonce).into(); - let kem_ct = MLKEMCiphertext::try_from(&hello.kem_ct)?; - let proof_data = hello_proof_data(initiator, responder, &meta, &nonce, &kem_ct); - verify_signature(initiator_signing_key, &signature, &proof_data) -} - -pub fn respond_hello( - identity: &QlIdentity, - crypto: &impl QlCrypto, - initiator: XID, - initiator_signing_key: &MLDSAPublicKey, - initiator_encapsulation_key: &MLKEMPublicKey, - hello: &ArchivedHello, - meta: ControlMeta, -) -> Result<(HelloReply, ResponderSecrets), QlError> { - verify_hello(initiator, identity.xid, initiator_signing_key, hello)?; - let hello_meta: ControlMeta = (&hello.meta).into(); - let initiator_nonce: Nonce = (&hello.nonce).into(); - let initiator_kem_ct = MLKEMCiphertext::try_from(&hello.kem_ct)?; - let initiator_secret = identity - .encapsulation_private_key - .decapsulate_shared_secret(&initiator_kem_ct) - .map_err(|_| QlError::InvalidPayload)?; - let nonce = next_nonce(crypto); - let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(); - let transcript = handshake_transcript( - initiator, - identity.xid, - &hello_meta, - &initiator_nonce, - &initiator_kem_ct, - &meta, - &nonce, - &kem_ct, - ); - let signature = identity.signing_private_key.sign(&transcript); - let reply = HelloReply { - meta, - nonce, - kem_ct, - signature, - }; - Ok(( - reply, - ResponderSecrets { - initiator_secret, - responder_secret, - }, - )) -} - -pub fn build_confirm( - identity: &QlIdentity, - responder: XID, - responder_signing_key: &MLDSAPublicKey, - hello: &Hello, - reply: &ArchivedHelloReply, - initiator_secret: &SymmetricKey, - meta: ControlMeta, -) -> Result<(Confirm, SymmetricKey), QlError> { - let reply_meta: ControlMeta = (&reply.meta).into(); - ensure_not_expired(reply_meta.valid_until)?; - let reply_nonce: Nonce = (&reply.nonce).into(); - let reply_kem_ct = MLKEMCiphertext::try_from(&reply.kem_ct)?; - let reply_signature = MLDSASignature::try_from(&reply.signature)?; - let transcript = handshake_transcript( - identity.xid, - responder, - &hello.meta, - &hello.nonce, - &hello.kem_ct, - &reply_meta, - &reply_nonce, - &reply_kem_ct, - ); - verify_signature(responder_signing_key, &reply_signature, &transcript)?; - let responder_secret = identity - .encapsulation_private_key - .decapsulate_shared_secret(&reply_kem_ct) - .map_err(|_| QlError::InvalidPayload)?; - let signature = identity - .signing_private_key - .sign(confirm_proof_data(&meta, &transcript)); - let confirm = Confirm { meta, signature }; - let session_key = derive_session_key(initiator_secret, &responder_secret, &transcript); - Ok((confirm, session_key)) -} - -pub fn finalize_confirm( - initiator: XID, - responder: XID, - initiator_signing_key: &MLDSAPublicKey, - hello: &Hello, - reply: &HelloReply, - confirm: &ArchivedConfirm, - secrets: &ResponderSecrets, -) -> Result { - verify_confirm( - initiator, - responder, - initiator_signing_key, - hello, - reply, - confirm, - )?; - Ok(derive_session_key( - &secrets.initiator_secret, - &secrets.responder_secret, - &handshake_transcript( - initiator, - responder, - &hello.meta, - &hello.nonce, - &hello.kem_ct, - &reply.meta, - &reply.nonce, - &reply.kem_ct, - ), - )) -} - -pub fn verify_confirm( - initiator: XID, - responder: XID, - initiator_signing_key: &MLDSAPublicKey, - hello: &Hello, - reply: &HelloReply, - confirm: &ArchivedConfirm, -) -> Result<(), QlError> { - let confirm_meta: ControlMeta = (&confirm.meta).into(); - ensure_not_expired(confirm_meta.valid_until)?; - let confirm_signature = MLDSASignature::try_from(&confirm.signature)?; - let transcript = handshake_transcript( - initiator, - responder, - &hello.meta, - &hello.nonce, - &hello.kem_ct, - &reply.meta, - &reply.nonce, - &reply.kem_ct, - ); - let proof_data = confirm_proof_data(&confirm_meta, &transcript); - verify_signature(initiator_signing_key, &confirm_signature, &proof_data)?; - Ok(()) -} - -pub fn build_ready( - header: QlHeader, - session_key: &SymmetricKey, - meta: ControlMeta, - nonce: [u8; NONCE_SIZE], -) -> Ready { - let aad = header.aad(); - let body_bytes = encode_value(&ReadyBody { meta }); - Ready { - encrypted: EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce), - } -} - -pub fn decrypt_ready( - header: &QlHeader, - ready: &mut ArchivedReady, - session_key: &SymmetricKey, -) -> Result { - let aad = header.aad(); - let plaintext = ready.encrypted.decrypt(session_key, &aad)?; - let body = access_value::(plaintext)?; - let body = deserialize_value(body)?; - ensure_not_expired(body.meta.valid_until)?; - Ok(body) -} - -fn handshake_transcript( - initiator: XID, - responder: XID, - hello_meta: &ControlMeta, - initiator_nonce: &Nonce, - initiator_kem_ct: &bc_components::MLKEMCiphertext, - reply_meta: &ControlMeta, - responder_nonce: &Nonce, - responder_kem_ct: &bc_components::MLKEMCiphertext, -) -> Vec { - encode_value(&HandshakeTranscript { - initiator, - responder, - hello_meta: *hello_meta, - initiator_nonce: initiator_nonce.clone(), - responder_nonce: responder_nonce.clone(), - reply_meta: *reply_meta, - initiator_kem_ct: initiator_kem_ct.clone(), - responder_kem_ct: responder_kem_ct.clone(), - }) -} - -fn hello_proof_data( - initiator: XID, - responder: XID, - meta: &ControlMeta, - nonce: &Nonce, - kem_ct: &bc_components::MLKEMCiphertext, -) -> Vec { - encode_value(&HelloProofData { - initiator, - responder, - meta: *meta, - nonce: nonce.clone(), - kem_ct: kem_ct.clone(), - }) -} - -fn confirm_proof_data(meta: &ControlMeta, transcript: &[u8]) -> Vec { - encode_value(&ConfirmProofData { - meta: *meta, - transcript: transcript.to_vec(), - }) -} - -fn next_nonce(platform: &impl QlCrypto) -> Nonce { - let mut data = [0u8; Nonce::NONCE_SIZE]; - platform.fill_random_bytes(&mut data); - Nonce::from_data(data) -} - -fn derive_session_key( - initiator_secret: &SymmetricKey, - responder_secret: &SymmetricKey, - transcript: &[u8], -) -> SymmetricKey { - let payload = encode_value(&SessionKeyMaterial { - initiator_secret: initiator_secret.as_bytes().to_vec(), - responder_secret: responder_secret.as_bytes().to_vec(), - transcript: transcript.to_vec(), - }); - let digest = Digest::from_image(payload); - SymmetricKey::from_data(*digest.data()) -} diff --git a/ql-engine/src/wire/handshake/mod.rs b/ql-engine/src/wire/handshake/mod.rs deleted file mode 100644 index a049af62..00000000 --- a/ql-engine/src/wire/handshake/mod.rs +++ /dev/null @@ -1,69 +0,0 @@ -use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, Nonce}; -use rkyv::{Archive, Deserialize, Serialize}; - -use super::{ - encrypted_message::EncryptedMessage, AsWireMlDsaSignature, AsWireMlKemCiphertext, AsWireNonce, - ControlMeta, -}; -use crate::QlError; - -mod crypto; -pub use crypto::*; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum HandshakeRecord { - Hello(Hello), - HelloReply(HelloReply), - Confirm(Confirm), - Ready(Ready), -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Hello { - pub meta: ControlMeta, - #[rkyv(with = AsWireNonce)] - pub nonce: Nonce, - #[rkyv(with = AsWireMlKemCiphertext)] - pub kem_ct: MLKEMCiphertext, - #[rkyv(with = AsWireMlDsaSignature)] - pub signature: MLDSASignature, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct HelloReply { - pub meta: ControlMeta, - #[rkyv(with = AsWireNonce)] - pub nonce: Nonce, - #[rkyv(with = AsWireMlKemCiphertext)] - pub kem_ct: MLKEMCiphertext, - #[rkyv(with = AsWireMlDsaSignature)] - pub signature: MLDSASignature, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Confirm { - pub meta: ControlMeta, - #[rkyv(with = AsWireMlDsaSignature)] - pub signature: MLDSASignature, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Ready { - pub encrypted: EncryptedMessage, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct ReadyBody { - pub meta: ControlMeta, -} - -pub fn verify_signature( - signing_key: &MLDSAPublicKey, - signature: &MLDSASignature, - proof_data: &[u8], -) -> Result<(), QlError> { - match signing_key.verify(signature, proof_data) { - Ok(true) => Ok(()), - _ => Err(QlError::InvalidSignature), - } -} diff --git a/ql-engine/src/wire/heartbeat/crypto.rs b/ql-engine/src/wire/heartbeat/crypto.rs deleted file mode 100644 index 0002e979..00000000 --- a/ql-engine/src/wire/heartbeat/crypto.rs +++ /dev/null @@ -1,39 +0,0 @@ -use bc_components::SymmetricKey; - -use super::HeartbeatBody; -use crate::{ - wire::{ - access_value, deserialize_value, encode_value, - encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage, NONCE_SIZE}, - ensure_not_expired, QlHeader, QlPayload, QlRecord, - }, - QlError, -}; - -pub fn encrypt_heartbeat( - header: QlHeader, - session_key: &SymmetricKey, - body: HeartbeatBody, - nonce: [u8; NONCE_SIZE], -) -> QlRecord { - let aad = header.aad(); - let body_bytes = encode_value(&body); - let encrypted = EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce); - QlRecord { - header, - payload: QlPayload::Heartbeat(encrypted), - } -} - -pub(crate) fn decrypt_heartbeat( - header: &QlHeader, - encrypted: &mut ArchivedEncryptedMessage, - session_key: &SymmetricKey, -) -> Result { - let aad = header.aad(); - let plaintext = encrypted.decrypt(session_key, &aad)?; - let body = access_value::(plaintext)?; - let body = deserialize_value(body)?; - ensure_not_expired(body.meta.valid_until)?; - Ok(body) -} diff --git a/ql-engine/src/wire/heartbeat/mod.rs b/ql-engine/src/wire/heartbeat/mod.rs deleted file mode 100644 index 8a7810f6..00000000 --- a/ql-engine/src/wire/heartbeat/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -use rkyv::{Archive, Deserialize, Serialize}; - -use super::ControlMeta; - -mod crypto; -pub use crypto::*; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct HeartbeatBody { - pub meta: ControlMeta, -} diff --git a/ql-engine/src/wire/id.rs b/ql-engine/src/wire/id.rs deleted file mode 100644 index 1c32f62c..00000000 --- a/ql-engine/src/wire/id.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::fmt; - -use rkyv::{Archive, Deserialize, Serialize}; - -macro_rules! define_id { - ($name:ident, $ty:ty) => { - #[derive( - Archive, - Serialize, - Deserialize, - Debug, - Clone, - Copy, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - )] - #[repr(transparent)] - pub struct $name(pub $ty); - - impl fmt::Display for $name { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } - } - }; -} - -define_id!(PacketId, u32); -define_id!(StreamId, u32); - -impl From<&ArchivedPacketId> for PacketId { - fn from(value: &ArchivedPacketId) -> Self { - Self(value.0.to_native()) - } -} - -impl From<&ArchivedStreamId> for StreamId { - fn from(value: &ArchivedStreamId) -> Self { - Self(value.0.to_native()) - } -} diff --git a/ql-engine/src/wire/mod.rs b/ql-engine/src/wire/mod.rs deleted file mode 100644 index bc889e58..00000000 --- a/ql-engine/src/wire/mod.rs +++ /dev/null @@ -1,473 +0,0 @@ -//! quantum link protocol wire format -//! -//! naming conventions: -//! - *Record - unencrypted messages -//! - *Body - message content after decrypting -//! - -use bc_components::XID; -use rkyv::{ - api::{ - high::{to_bytes_in, HighSerializer, HighValidator}, - low::{self, LowDeserializer}, - }, - bytecheck::CheckBytes, - ser::allocator::ArenaHandle, - Archive, Deserialize, Portable, Serialize, -}; - -pub mod encrypted_message; -pub mod handshake; -pub mod heartbeat; -mod id; -pub mod pair; -pub mod seq; -pub mod stream; -pub mod unpair; - -pub use id::*; -pub use seq::StreamSeq; - -mod codec; - -pub(crate) use codec::*; - -use self::{ - encrypted_message::EncryptedMessage, handshake::HandshakeRecord, pair::PairRequestRecord, - unpair::UnpairRecord, -}; -use crate::QlError; - -pub(crate) type WireArchiveError = rkyv::rancor::Error; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct QlRecord { - pub header: QlHeader, - pub payload: QlPayload, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct QlHeader { - #[rkyv(with = AsWireXid)] - pub sender: XID, - #[rkyv(with = AsWireXid)] - pub recipient: XID, -} - -impl QlHeader { - pub fn aad(&self) -> Vec { - encode_value(self) - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct ControlMeta { - pub packet_id: PacketId, - pub valid_until: u64, -} - -impl From<&ArchivedControlMeta> for ControlMeta { - fn from(value: &ArchivedControlMeta) -> Self { - Self { - packet_id: (&value.packet_id).into(), - valid_until: value.valid_until.to_native(), - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum QlPayload { - Handshake(HandshakeRecord), - Pair(PairRequestRecord), - Unpair(UnpairRecord), - Heartbeat(EncryptedMessage), - Stream(EncryptedMessage), -} - -pub fn encode_record(record: &QlRecord) -> Vec { - encode_value(record) -} - -pub fn access_record(bytes: &[u8]) -> Result<&ArchivedQlRecord, QlError> { - access_value(bytes) -} - -pub fn decode_record(bytes: &[u8]) -> Result { - deserialize_value(access_record(bytes)?) -} - -pub(crate) fn encode_value( - value: &impl for<'a> Serialize, ArenaHandle<'a>, WireArchiveError>>, -) -> Vec { - to_bytes_in::<_, WireArchiveError>(value, Vec::new()) - .expect("wire serialization should not fail") -} - -pub(crate) fn access_value(bytes: &[u8]) -> Result<&T, QlError> -where - T: Portable + for<'a> CheckBytes>, -{ - rkyv::access::(bytes).map_err(|_| QlError::InvalidPayload) -} - -pub(crate) fn deserialize_value( - value: &impl rkyv::Deserialize>, -) -> Result { - low::deserialize::(value).map_err(|_| QlError::InvalidPayload) -} - -pub(crate) fn ensure_not_expired(valid_until: u64) -> Result<(), QlError> { - if now_secs() > valid_until { - Err(QlError::Timeout) - } else { - Ok(()) - } -} - -pub(crate) fn now_secs() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) -} - -#[test] -fn ql_record_round_trip() { - let record = QlRecord { - header: QlHeader { - sender: XID::from_data([1; XID::XID_SIZE]), - recipient: XID::from_data([2; XID::XID_SIZE]), - }, - payload: QlPayload::Heartbeat(encrypted_message::EncryptedMessage::encrypt( - &bc_components::SymmetricKey::from_data( - [7; bc_components::SymmetricKey::SYMMETRIC_KEY_SIZE], - ), - vec![3u8, 4, 5], - b"roundtrip", - [8; encrypted_message::NONCE_SIZE], - )), - }; - - let bytes = encode_record(&record); - let decoded = decode_record(&bytes).unwrap(); - - assert_eq!(decoded, record); -} - -#[cfg(test)] -mod test { - use super::*; - use crate::{engine::QlCrypto, identity::QlIdentity}; - - struct SizeTestCrypto(std::sync::atomic::AtomicU8); - - impl SizeTestCrypto { - fn new(seed: u8) -> Self { - Self(std::sync::atomic::AtomicU8::new(seed)) - } - } - - impl QlCrypto for SizeTestCrypto { - fn fill_random_bytes(&self, data: &mut [u8]) { - let seed = self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - for (index, byte) in data.iter_mut().enumerate() { - *byte = seed.wrapping_add(index as u8); - } - } - } - - fn size_test_identity() -> QlIdentity { - use bc_components::{MLDSA, MLKEM}; - - let (signing_private_key, signing_public_key) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private_key, encapsulation_public_key) = MLKEM::MLKEM512.keypair(); - QlIdentity::from_keys( - signing_private_key, - signing_public_key, - encapsulation_private_key, - encapsulation_public_key, - ) - } - - fn size_test_meta(packet_id: u32) -> ControlMeta { - ControlMeta { - packet_id: PacketId(packet_id), - valid_until: now_secs().saturating_add(60), - } - } - - #[test] - fn protocol_record_size_breakdown() { - use crate::{ - wire::{handshake::HandshakeRecord, heartbeat::HeartbeatBody}, - StreamId, - }; - - let identity_a = size_test_identity(); - let identity_b = size_test_identity(); - let crypto_a = SizeTestCrypto::new(1); - let crypto_b = SizeTestCrypto::new(2); - - let initiator = identity_a.xid; - let responder = identity_b.xid; - - let (hello, initiator_secret) = handshake::build_hello( - &identity_a, - &crypto_a, - responder, - &identity_b.encapsulation_public_key, - size_test_meta(1), - ) - .unwrap(); - let hello_record = QlRecord { - header: QlHeader { - sender: initiator, - recipient: responder, - }, - payload: QlPayload::Handshake(HandshakeRecord::Hello(hello.clone())), - }; - let hello_size = encode_record(&hello_record).len(); - let hello_bytes = encode_value(&hello); - let hello_view = access_value::(&hello_bytes).unwrap(); - - let (hello_reply, responder_secrets) = handshake::respond_hello( - &identity_b, - &crypto_b, - initiator, - &identity_a.signing_public_key, - &identity_a.encapsulation_public_key, - hello_view, - size_test_meta(2), - ) - .unwrap(); - let reply_record = QlRecord { - header: QlHeader { - sender: responder, - recipient: initiator, - }, - payload: QlPayload::Handshake(HandshakeRecord::HelloReply(hello_reply.clone())), - }; - let reply_size = encode_record(&reply_record).len(); - let reply_bytes = encode_value(&hello_reply); - let reply_view = access_value::(&reply_bytes).unwrap(); - - let (confirm, session_key) = handshake::build_confirm( - &identity_a, - responder, - &identity_b.signing_public_key, - &hello, - reply_view, - &initiator_secret, - size_test_meta(3), - ) - .unwrap(); - let confirm_size = encode_record(&QlRecord { - header: QlHeader { - sender: initiator, - recipient: responder, - }, - payload: QlPayload::Handshake(HandshakeRecord::Confirm(confirm.clone())), - }) - .len(); - - let confirm_bytes = encode_value(&confirm); - let confirm_view = access_value::(&confirm_bytes).unwrap(); - let _session_key_b = handshake::finalize_confirm( - initiator, - responder, - &identity_a.signing_public_key, - &hello, - &hello_reply, - confirm_view, - &responder_secrets, - ) - .unwrap(); - - let pair_size = encode_record( - &pair::build_pair_request( - &identity_a, - &crypto_a, - responder, - &identity_b.encapsulation_public_key, - size_test_meta(11), - ) - .unwrap(), - ) - .len(); - - let heartbeat_size = encode_record(&heartbeat::encrypt_heartbeat( - QlHeader { - sender: initiator, - recipient: responder, - }, - &session_key, - HeartbeatBody { - meta: size_test_meta(12), - }, - [12; encrypted_message::NONCE_SIZE], - )) - .len(); - - let unpair_size = encode_record(&unpair::build_unpair_record( - &identity_a, - QlHeader { - sender: initiator, - recipient: responder, - }, - size_test_meta(13), - )) - .len(); - - let stream_header = QlHeader { - sender: initiator, - recipient: responder, - }; - let stream_record_size = |body: &stream::StreamBody, nonce: u8| { - encode_record(&stream::encrypt_stream( - stream_header.clone(), - &session_key, - body, - [nonce; encrypted_message::NONCE_SIZE], - )) - .len() - }; - - let stream_ack_body = stream::StreamBody::Ack(stream::StreamAckBody { - stream_id: StreamId(2), - ack: stream::StreamAck { - base: StreamSeq(19), - bitmap: 0b0000_0110, - }, - valid_until: now_secs().saturating_add(60), - }); - let stream_ack_record = stream::encrypt_stream( - stream_header.clone(), - &session_key, - &stream_ack_body, - [20; encrypted_message::NONCE_SIZE], - ); - let stream_ack_encrypted = match &stream_ack_record.payload { - QlPayload::Stream(encrypted) => encrypted, - _ => unreachable!(), - }; - let stream_ack_header_size = encode_value(&stream_header).len(); - let stream_ack_body_size = encode_value(&stream_ack_body).len(); - let stream_ack_envelope_size = encode_value(stream_ack_encrypted).len(); - let stream_ack_payload_size = encode_value(&stream_ack_record.payload).len(); - - let stream_open_body = stream::StreamBody::Message(stream::StreamMessage { - tx_seq: StreamSeq(21), - ack: stream::StreamAck::EMPTY, - valid_until: now_secs().saturating_add(60), - frame: stream::StreamFrame::Open(stream::StreamFrameOpen { - stream_id: StreamId(2), - request_head: vec![1, 2, 3], - request_prefix: Some(stream::BodyChunk { - bytes: vec![9, 9, 9], - fin: false, - }), - }), - }); - let stream_open_body_size = encode_value(&stream_open_body).len(); - - let stream_message_no_ack = stream::StreamBody::Message(stream::StreamMessage { - tx_seq: StreamSeq(20), - ack: stream::StreamAck::EMPTY, - valid_until: now_secs().saturating_add(60), - frame: stream::StreamFrame::Data(stream::StreamFrameData { - stream_id: StreamId(2), - chunk: stream::BodyChunk { - bytes: vec![7, 8, 9, 10], - fin: false, - }, - }), - }); - let stream_message_with_ack = stream::StreamBody::Message(stream::StreamMessage { - tx_seq: StreamSeq(20), - ack: stream::StreamAck { - base: StreamSeq(19), - bitmap: 0b0000_0110, - }, - valid_until: now_secs().saturating_add(60), - frame: stream::StreamFrame::Data(stream::StreamFrameData { - stream_id: StreamId(2), - chunk: stream::BodyChunk { - bytes: vec![7, 8, 9, 10], - fin: false, - }, - }), - }); - - let stream_ack_size = stream_record_size(&stream_ack_body, 20); - let stream_open_size = stream_record_size(&stream_open_body, 21); - let stream_data_no_ack_size = stream_record_size(&stream_message_no_ack, 24); - let stream_data_with_ack_size = stream_record_size(&stream_message_with_ack, 25); - let stream_fin_size = stream_record_size( - &stream::StreamBody::Message(stream::StreamMessage { - tx_seq: StreamSeq(26), - ack: stream::StreamAck::EMPTY, - valid_until: now_secs().saturating_add(60), - frame: stream::StreamFrame::Data(stream::StreamFrameData { - stream_id: StreamId(2), - chunk: stream::BodyChunk { - bytes: Vec::new(), - fin: true, - }, - }), - }), - 26, - ); - let stream_reset_size = stream_record_size( - &stream::StreamBody::Message(stream::StreamMessage { - tx_seq: StreamSeq(27), - ack: stream::StreamAck::EMPTY, - valid_until: now_secs().saturating_add(60), - frame: stream::StreamFrame::Close(stream::StreamFrameClose { - stream_id: StreamId(2), - target: stream::CloseTarget::Both, - code: stream::CloseCode::PROTOCOL, - payload: Vec::new(), - }), - }), - 27, - ); - - let print_size = |label: &str, size: usize| { - println!("{label:<28}: {size} bytes"); - }; - - print_size("ql2 size hello", hello_size); - print_size("ql2 size hello_reply", reply_size); - print_size("ql2 size confirm", confirm_size); - print_size("ql2 size pair", pair_size); - print_size("ql2 size heartbeat", heartbeat_size); - print_size("ql2 size unpair", unpair_size); - print_size("ql2 size stream ack-only", stream_ack_size); - print_size("ql2 size stream open", stream_open_size); - print_size("ql2 size stream data no ack", stream_data_no_ack_size); - print_size("ql2 size stream data w/ ack", stream_data_with_ack_size); - print_size("ql2 size stream fin", stream_fin_size); - print_size("ql2 size stream reset", stream_reset_size); - println!( - "ql2 stream ack breakdown : header={} aad={} plaintext={} envelope={} payload={} full={}", - stream_ack_header_size, - stream_header.aad().len(), - stream_ack_body_size, - stream_ack_envelope_size, - stream_ack_payload_size, - stream_ack_size, - ); - println!( - "ql2 stream open delta : open_body={} ack_body={} (+{} bytes)", - stream_open_body_size, - stream_ack_body_size, - stream_open_body_size.saturating_sub(stream_ack_body_size), - ); - println!( - "ql2 stream data ack delta : no_ack={} with_ack={} (+{} bytes)", - stream_data_no_ack_size, - stream_data_with_ack_size, - stream_data_with_ack_size.saturating_sub(stream_data_no_ack_size), - ); - } -} diff --git a/ql-engine/src/wire/pair/crypto.rs b/ql-engine/src/wire/pair/crypto.rs deleted file mode 100644 index ab7e4033..00000000 --- a/ql-engine/src/wire/pair/crypto.rs +++ /dev/null @@ -1,139 +0,0 @@ -use bc_components::{ - MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, SigningPublicKey, SymmetricKey, XID, -}; -use rkyv::{Archive, Serialize}; - -use super::{PairRequestBody, PairRequestRecord}; -use crate::{ - engine::QlCrypto, - identity::QlIdentity, - wire::{ - access_value, deserialize_value, encode_value, - encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage, NONCE_SIZE}, - ensure_not_expired, AsWireMlDsaPublicKey, AsWireMlKemCiphertext, AsWireMlKemPublicKey, - ControlMeta, QlHeader, QlPayload, QlRecord, - }, - QlError, -}; - -#[derive(Archive, Serialize)] -struct PairingAad { - header: QlHeader, - #[rkyv(with = AsWireMlKemCiphertext)] - kem_ct: MLKEMCiphertext, -} - -#[derive(Archive, Serialize)] -struct PairingProofData { - aad: Vec, - meta: ControlMeta, - #[rkyv(with = AsWireMlDsaPublicKey)] - signing_pub_key: MLDSAPublicKey, - #[rkyv(with = AsWireMlKemPublicKey)] - encapsulation_pub_key: MLKEMPublicKey, -} - -pub fn build_pair_request( - identity: &QlIdentity, - crypto: &impl QlCrypto, - recipient: XID, - recipient_encapsulation_key: &MLKEMPublicKey, - meta: ControlMeta, -) -> Result { - let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); - let header = QlHeader { - sender: identity.xid, - recipient, - }; - let signing_pub_key = identity.signing_public_key.clone(); - let sender_encapsulation_key = identity.encapsulation_public_key.clone(); - let proof_data = pairing_proof_data( - &header, - &kem_ct, - &meta, - &signing_pub_key, - &sender_encapsulation_key, - ); - let proof = identity.signing_private_key.sign(&proof_data); - let body = PairRequestBody { - meta, - signing_pub_key, - encapsulation_pub_key: sender_encapsulation_key, - proof, - }; - let body_bytes = encode_value(&body); - let aad = pairing_aad(&header, &kem_ct); - let mut nonce = [0u8; NONCE_SIZE]; - crypto.fill_random_bytes(&mut nonce); - let encrypted = EncryptedMessage::encrypt(&session_key, body_bytes, &aad, nonce); - Ok(QlRecord { - header, - payload: QlPayload::Pair(PairRequestRecord { kem_ct, encrypted }), - }) -} - -pub fn decrypt_pair_request( - identity: &QlIdentity, - header: &QlHeader, - request: &mut super::ArchivedPairRequestRecord, -) -> Result { - let kem_ct = MLKEMCiphertext::try_from(&request.kem_ct)?; - let aad = pairing_aad(header, &kem_ct); - let session_key = identity - .encapsulation_private_key - .decapsulate_shared_secret(&kem_ct) - .map_err(|_| QlError::InvalidPayload)?; - let decrypted = decrypt_body(&session_key, &mut request.encrypted, &aad)?; - ensure_not_expired(decrypted.meta.valid_until)?; - if XID::new(SigningPublicKey::MLDSA(decrypted.signing_pub_key.clone())) != header.sender { - return Err(QlError::InvalidPayload); - } - let proof_data = pairing_proof_data( - header, - &kem_ct, - &decrypted.meta, - &decrypted.signing_pub_key, - &decrypted.encapsulation_pub_key, - ); - if decrypted - .signing_pub_key - .verify(&decrypted.proof, &proof_data) - .unwrap_or(false) - { - Ok(decrypted) - } else { - Err(QlError::InvalidSignature) - } -} - -fn pairing_proof_data( - header: &QlHeader, - kem_ct: &MLKEMCiphertext, - meta: &ControlMeta, - signing_pub_key: &MLDSAPublicKey, - encapsulation_pub_key: &MLKEMPublicKey, -) -> Vec { - encode_value(&PairingProofData { - aad: pairing_aad(header, kem_ct), - meta: *meta, - signing_pub_key: signing_pub_key.clone(), - encapsulation_pub_key: encapsulation_pub_key.clone(), - }) -} - -fn decrypt_body( - key: &SymmetricKey, - encrypted: &mut ArchivedEncryptedMessage, - aad: &[u8], -) -> Result { - let plaintext = encrypted.decrypt(key, aad)?; - let body = access_value::(plaintext)?; - deserialize_value(body) -} - -pub(crate) fn pairing_aad(header: &QlHeader, kem_ct: &MLKEMCiphertext) -> Vec { - encode_value(&PairingAad { - header: header.clone(), - kem_ct: kem_ct.clone(), - }) -} diff --git a/ql-engine/src/wire/pair/mod.rs b/ql-engine/src/wire/pair/mod.rs deleted file mode 100644 index 7bb5f488..00000000 --- a/ql-engine/src/wire/pair/mod.rs +++ /dev/null @@ -1,28 +0,0 @@ -use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey}; -use rkyv::{Archive, Deserialize, Serialize}; - -use super::{ - encrypted_message::EncryptedMessage, AsWireMlDsaPublicKey, AsWireMlDsaSignature, - AsWireMlKemCiphertext, AsWireMlKemPublicKey, ControlMeta, -}; - -mod crypto; -pub use crypto::*; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct PairRequestRecord { - #[rkyv(with = AsWireMlKemCiphertext)] - pub kem_ct: MLKEMCiphertext, - pub encrypted: EncryptedMessage, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct PairRequestBody { - pub meta: ControlMeta, - #[rkyv(with = AsWireMlDsaPublicKey)] - pub signing_pub_key: MLDSAPublicKey, - #[rkyv(with = AsWireMlKemPublicKey)] - pub encapsulation_pub_key: MLKEMPublicKey, - #[rkyv(with = AsWireMlDsaSignature)] - pub proof: MLDSASignature, -} diff --git a/ql-engine/src/wire/seq.rs b/ql-engine/src/wire/seq.rs deleted file mode 100644 index c3cc1dd9..00000000 --- a/ql-engine/src/wire/seq.rs +++ /dev/null @@ -1,97 +0,0 @@ -use std::{cmp::Ordering, fmt}; - -use rkyv::{Archive, Deserialize, Serialize}; - -#[derive( - Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, -)] -#[repr(transparent)] -pub struct StreamSeq(pub u32); - -impl fmt::Display for StreamSeq { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<&ArchivedStreamSeq> for StreamSeq { - fn from(value: &ArchivedStreamSeq) -> Self { - Self(value.0.to_native()) - } -} - -impl StreamSeq { - const HALF_RANGE: u32 = 1 << 31; - pub const START: Self = Self(1); - - pub fn next(self) -> Self { - Self(self.0.wrapping_add(1)) - } - - pub fn prev(self) -> Self { - Self(self.0.wrapping_sub(1)) - } - - pub fn add(self, delta: u32) -> Self { - Self(self.0.wrapping_add(delta)) - } - - pub fn serial_cmp(self, other: Self) -> Ordering { - if self == other { - return Ordering::Equal; - } - - let delta = self.0.wrapping_sub(other.0); - if delta < Self::HALF_RANGE { - Ordering::Greater - } else { - Ordering::Less - } - } - - pub fn serial_lt(self, other: Self) -> bool { - self.serial_cmp(other) == Ordering::Less - } - - pub fn serial_lte(self, other: Self) -> bool { - !self.serial_gt(other) - } - - pub fn serial_gt(self, other: Self) -> bool { - self.serial_cmp(other) == Ordering::Greater - } - - pub fn forward_distance_to(self, other: Self) -> Option { - match other.serial_cmp(self) { - Ordering::Less => None, - Ordering::Equal => Some(0), - Ordering::Greater => Some(other.0.wrapping_sub(self.0)), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn stream_seq_serial_order_wraps() { - assert!(StreamSeq(0).serial_gt(StreamSeq(u32::MAX))); - assert!(StreamSeq(1).serial_gt(StreamSeq(u32::MAX))); - assert!(StreamSeq(u32::MAX).serial_lt(StreamSeq(0))); - assert!(StreamSeq(u32::MAX - 1).serial_lt(StreamSeq(1))); - } - - #[test] - fn stream_seq_forward_distance_wraps() { - assert_eq!( - StreamSeq(u32::MAX - 1).forward_distance_to(StreamSeq(1)), - Some(3) - ); - assert_eq!( - StreamSeq(u32::MAX).forward_distance_to(StreamSeq(2)), - Some(3) - ); - assert_eq!(StreamSeq(1).forward_distance_to(StreamSeq(u32::MAX)), None); - } -} diff --git a/ql-engine/src/wire/stream/crypto.rs b/ql-engine/src/wire/stream/crypto.rs deleted file mode 100644 index 48ea3522..00000000 --- a/ql-engine/src/wire/stream/crypto.rs +++ /dev/null @@ -1,39 +0,0 @@ -use bc_components::SymmetricKey; - -use super::StreamBody; -use crate::{ - wire::{ - access_value, deserialize_value, encode_value, - encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage, NONCE_SIZE}, - ensure_not_expired, QlHeader, QlPayload, QlRecord, - }, - QlError, -}; - -pub fn encrypt_stream( - header: QlHeader, - session_key: &SymmetricKey, - body: &StreamBody, - nonce: [u8; NONCE_SIZE], -) -> QlRecord { - let aad = header.aad(); - let body_bytes = encode_value(body); - let encrypted = EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce); - QlRecord { - header, - payload: QlPayload::Stream(encrypted), - } -} - -pub(crate) fn decrypt_stream( - header: &QlHeader, - encrypted: &mut ArchivedEncryptedMessage, - session_key: &SymmetricKey, -) -> Result { - let aad = header.aad(); - let plaintext = encrypted.decrypt(session_key, &aad)?; - let body = access_value::(plaintext)?; - let body = deserialize_value(body)?; - ensure_not_expired(body.valid_until())?; - Ok(body) -} diff --git a/ql-engine/src/wire/stream/mod.rs b/ql-engine/src/wire/stream/mod.rs deleted file mode 100644 index b6bfeab5..00000000 --- a/ql-engine/src/wire/stream/mod.rs +++ /dev/null @@ -1,199 +0,0 @@ -use rkyv::{Archive, Deserialize, Serialize}; - -use crate::{wire::StreamSeq, StreamId}; - -mod crypto; -pub use crypto::*; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub enum StreamBody { - Ack(StreamAckBody), - Message(StreamMessage), -} - -impl StreamBody { - pub fn stream_id(&self) -> StreamId { - match self { - Self::Ack(StreamAckBody { stream_id, .. }) => *stream_id, - Self::Message(message) => message.frame.stream_id(), - } - } - - pub fn valid_until(&self) -> u64 { - match self { - Self::Ack(body) => body.valid_until, - Self::Message(message) => message.valid_until, - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamAckBody { - pub stream_id: StreamId, - pub ack: StreamAck, - pub valid_until: u64, -} - -impl From<&ArchivedStreamAckBody> for StreamAckBody { - fn from(value: &ArchivedStreamAckBody) -> Self { - Self { - stream_id: (&value.stream_id).into(), - ack: (&value.ack).into(), - valid_until: value.valid_until.to_native(), - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct StreamMessage { - pub tx_seq: StreamSeq, - pub ack: StreamAck, - pub valid_until: u64, - pub frame: StreamFrame, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamAck { - pub base: StreamSeq, - pub bitmap: u8, -} - -impl From<&ArchivedStreamAck> for StreamAck { - fn from(value: &ArchivedStreamAck) -> Self { - Self { - base: (&value.base).into(), - bitmap: value.bitmap, - } - } -} - -impl StreamAck { - pub const EMPTY: Self = Self { - base: StreamSeq(0), - bitmap: 0, - }; -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub enum StreamFrame { - Open(StreamFrameOpen), - Data(StreamFrameData), - Close(StreamFrameClose), -} - -impl StreamFrame { - pub fn stream_id(&self) -> StreamId { - match self { - StreamFrame::Open(StreamFrameOpen { stream_id, .. }) - | StreamFrame::Data(StreamFrameData { stream_id, .. }) - | StreamFrame::Close(StreamFrameClose { stream_id, .. }) => *stream_id, - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct BodyChunk { - pub bytes: Vec, - pub fin: bool, -} - -impl From<&ArchivedBodyChunk> for BodyChunk { - fn from(value: &ArchivedBodyChunk) -> Self { - Self { - bytes: value.bytes.as_slice().to_vec(), - fin: value.fin, - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct StreamFrameOpen { - pub stream_id: StreamId, - pub request_head: Vec, - pub request_prefix: Option, -} - -impl From<&ArchivedStreamFrameOpen> for StreamFrameOpen { - fn from(value: &ArchivedStreamFrameOpen) -> Self { - Self { - stream_id: (&value.stream_id).into(), - request_head: value.request_head.as_slice().to_vec(), - request_prefix: value.request_prefix.as_ref().map(Into::into), - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct StreamFrameData { - pub stream_id: StreamId, - pub chunk: BodyChunk, -} - -impl From<&ArchivedStreamFrameData> for StreamFrameData { - fn from(value: &ArchivedStreamFrameData) -> Self { - Self { - stream_id: (&value.stream_id).into(), - chunk: (&value.chunk).into(), - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct StreamFrameClose { - pub stream_id: StreamId, - pub target: CloseTarget, - pub code: CloseCode, - pub payload: Vec, -} - -impl From<&ArchivedStreamFrameClose> for StreamFrameClose { - fn from(value: &ArchivedStreamFrameClose) -> Self { - Self { - stream_id: (&value.stream_id).into(), - target: (&value.target).into(), - code: (&value.code).into(), - payload: value.payload.as_slice().to_vec(), - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum CloseTarget { - Request = 1, - Response = 2, - Both = 3, -} - -impl From<&ArchivedCloseTarget> for CloseTarget { - fn from(value: &ArchivedCloseTarget) -> Self { - match value { - ArchivedCloseTarget::Request => Self::Request, - ArchivedCloseTarget::Response => Self::Response, - ArchivedCloseTarget::Both => Self::Both, - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(transparent)] -pub struct CloseCode(pub u16); - -impl CloseCode { - pub const CANCELLED: Self = Self(0); - pub const PROTOCOL: Self = Self(1); - pub const INVALID_DATA: Self = Self(2); - pub const TIMEOUT: Self = Self(3); - - pub const UNKNOWN: Self = Self(16); - pub const UNKNOWN_ROUTE: Self = Self(17); - pub const INVALID_HEAD: Self = Self(18); - pub const BUSY: Self = Self(19); - pub const UNHANDLED: Self = Self(20); -} - -impl From<&ArchivedCloseCode> for CloseCode { - fn from(value: &ArchivedCloseCode) -> Self { - Self(value.0.to_native()) - } -} diff --git a/ql-engine/src/wire/unpair/crypto.rs b/ql-engine/src/wire/unpair/crypto.rs deleted file mode 100644 index 05df157d..00000000 --- a/ql-engine/src/wire/unpair/crypto.rs +++ /dev/null @@ -1,50 +0,0 @@ -use bc_components::{MLDSAPublicKey, MLDSASignature}; -use rkyv::{Archive, Serialize}; - -use super::UnpairRecord; -use crate::{ - identity::QlIdentity, - wire::{encode_value, ensure_not_expired, ControlMeta, QlHeader, QlPayload, QlRecord}, - QlError, -}; - -#[derive(Archive, Serialize)] -struct UnpairProofData { - domain: Vec, - header: QlHeader, - meta: ControlMeta, -} - -pub fn build_unpair_record(identity: &QlIdentity, header: QlHeader, meta: ControlMeta) -> QlRecord { - let signature = identity - .signing_private_key - .sign(unpair_proof_data(&header, &meta)); - QlRecord { - header, - payload: QlPayload::Unpair(UnpairRecord { meta, signature }), - } -} - -pub fn verify_unpair_record( - header: &QlHeader, - record: &super::ArchivedUnpairRecord, - signing_key: &MLDSAPublicKey, -) -> Result<(), QlError> { - let meta: ControlMeta = (&record.meta).into(); - let signature = MLDSASignature::try_from(&record.signature)?; - ensure_not_expired(meta.valid_until)?; - let proof_data = unpair_proof_data(header, &meta); - if signing_key.verify(&signature, &proof_data).unwrap_or(false) { - Ok(()) - } else { - Err(QlError::InvalidSignature) - } -} - -fn unpair_proof_data(header: &QlHeader, meta: &ControlMeta) -> Vec { - encode_value(&UnpairProofData { - domain: b"ql-unpair-v1".to_vec(), - header: header.clone(), - meta: *meta, - }) -} diff --git a/ql-engine/src/wire/unpair/mod.rs b/ql-engine/src/wire/unpair/mod.rs deleted file mode 100644 index 62781e8f..00000000 --- a/ql-engine/src/wire/unpair/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -use bc_components::MLDSASignature; -use rkyv::{Archive, Deserialize, Serialize}; - -use super::{AsWireMlDsaSignature, ControlMeta}; - -mod crypto; -pub use crypto::*; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct UnpairRecord { - pub meta: ControlMeta, - #[rkyv(with = AsWireMlDsaSignature)] - pub signature: MLDSASignature, -} diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml index 98b0abed..d3dba528 100644 --- a/ql-fsm/Cargo.toml +++ b/ql-fsm/Cargo.toml @@ -6,15 +6,10 @@ description = "Quantum Link synchronous finite state machine" license = "Proprietary" [dependencies] -bc-components = { version = "0.28.0", default-features = false, features = [ - "pqcrypto", -] } +indexmap = "2" ql-wire = { path = "../ql-wire" } -rkyv = { version = "0.8", default-features = false, features = [ - "std", - "bytecheck", - "little_endian", - "unaligned", - "pointer_width_32", -] } thiserror = { version = "2" } + +[dev-dependencies] +libcrux-aesgcm = "0.0.7" +sha2 = "0.10" diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs new file mode 100644 index 00000000..9a25a943 --- /dev/null +++ b/ql-fsm/src/error.rs @@ -0,0 +1,51 @@ +use ql_wire::WireError; +use thiserror::Error; + +use crate::session::StreamError; + +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum QlFsmError { + #[error("invalid payload")] + InvalidPayload, + #[error("invalid signature")] + InvalidSignature, + #[error("expired")] + Expired, + #[error("signing failed")] + SigningFailed, + #[error("encryption failed")] + EncryptFailed, + #[error("decryption failed")] + DecryptFailed, + #[error("missing stream")] + MissingStream, + #[error("stream is not writable")] + NotWritable, + #[error("session is closed")] + SessionClosed, + #[error("no peer bound")] + NoPeerBound, +} + +impl From for QlFsmError { + fn from(value: WireError) -> Self { + match value { + WireError::InvalidPayload => Self::InvalidPayload, + WireError::InvalidSignature => Self::InvalidSignature, + WireError::Expired => Self::Expired, + WireError::SigningFailed => Self::SigningFailed, + WireError::EncryptFailed => Self::EncryptFailed, + WireError::DecryptFailed => Self::DecryptFailed, + } + } +} + +impl From for QlFsmError { + fn from(value: StreamError) -> Self { + match value { + StreamError::MissingStream => Self::MissingStream, + StreamError::NotWritable => Self::NotWritable, + StreamError::SessionClosed => Self::SessionClosed, + } + } +} diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs new file mode 100644 index 00000000..403ab79d --- /dev/null +++ b/ql-fsm/src/implementation/fsm.rs @@ -0,0 +1,218 @@ +use std::time::Instant; + +use ql_wire::{self as wire, CloseCode, CloseTarget, Nonce, QlCrypto, QlPayloadRef, StreamId}; + +use crate::{OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId}; + +pub fn receive( + fsm: &mut QlFsm, + mut bytes: Vec, + crypto: &impl QlCrypto, +) -> Result<(), QlFsmError> { + let wire::QlRecordRef { header, payload } = wire::QlRecord::parse_mut(&mut bytes)?; + + if header.recipient != fsm.identity.xid { + return Ok(()); + } + if !matches!(&payload, QlPayloadRef::PairRequest(_)) { + let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.xid) else { + return Ok(()); + }; + if header.sender != peer { + return Ok(()); + } + } + + match payload { + QlPayloadRef::PairRequest(mut request) => { + super::handle_pair(fsm, crypto, &header, &mut request)?; + } + QlPayloadRef::Hello(hello) => { + super::handle_hello(fsm, crypto, &header, &hello)?; + } + QlPayloadRef::HelloReply(reply) => { + super::handle_hello_reply(fsm, crypto, &header, &reply)?; + } + QlPayloadRef::Confirm(confirm) => { + super::handle_confirm(fsm, crypto, &header, &confirm)?; + } + QlPayloadRef::Ready(mut ready) => { + super::handle_ready(fsm, crypto, &header, &mut ready)?; + } + QlPayloadRef::Session(mut encrypted) => { + let Some((_, session_key)) = super::peer_session(fsm) else { + return Ok(()); + }; + let envelope = match wire::decrypt_record(crypto, &header, &mut encrypted, &session_key) + .and_then(|envelope| envelope.to_session_envelope()) + { + Ok(envelope) => envelope, + Err(_) => return Ok(()), + }; + fsm.session.receive(fsm.state.now.instant, envelope); + super::drain_session_events(fsm); + } + } + + Ok(()) +} + +pub fn on_timer(fsm: &mut QlFsm) { + super::handle_timer(fsm); + if super::peer_session(fsm).is_some() { + fsm.session.on_timer(fsm.state.now.instant); + super::drain_session_events(fsm); + } +} + +pub fn next_deadline(fsm: &QlFsm) -> Option { + [ + super::next_handshake_deadline(fsm), + super::peer_session(fsm).and_then(|_| fsm.session.next_deadline()), + ] + .into_iter() + .flatten() + .min() +} + +pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { + if let Some(record) = fsm.state.outbound.pop_front() { + return Some(OutboundWrite { + record, + session_write_id: None, + }); + } + + if matches!( + fsm.peer.as_ref().map(|entry| &entry.session), + Some(crate::state::ConnectionState::Disconnected) + ) && fsm.session.has_pending_stream_work() + { + let _ = super::handle_connect(fsm, crypto); + if let Some(record) = fsm.state.outbound.pop_front() { + return Some(OutboundWrite { + record, + session_write_id: None, + }); + } + } + + let (recipient, session_key) = super::peer_session(fsm)?; + let envelope = fsm.session.take_next_write(fsm.state.now.instant)?; + let mut nonce = [0u8; Nonce::SIZE]; + crypto.fill_random_bytes(&mut nonce); + Some(OutboundWrite { + record: wire::encrypt_record( + crypto, + wire::QlHeader { + sender: fsm.identity.xid, + recipient, + }, + &session_key, + &envelope, + Nonce(nonce), + ) + .ok()?, + session_write_id: Some(SessionWriteId(envelope.seq)), + }) +} + +pub fn confirm_session_write(fsm: &mut QlFsm, write_id: SessionWriteId) { + fsm.session.confirm_write(fsm.state.now.instant, write_id.0); +} + +pub fn return_session_write(fsm: &mut QlFsm, write_id: SessionWriteId) { + fsm.session.return_write(write_id.0); +} + +pub fn kill_session(fsm: &mut QlFsm, code: CloseCode) { + let Some(entry) = fsm.peer.as_mut() else { + return; + }; + if !matches!( + entry.session, + crate::state::ConnectionState::Connected { .. } + ) { + return; + } + + entry.session = crate::state::ConnectionState::Disconnected; + super::emit_peer_status(fsm); + super::reset_session(fsm); + fsm.state + .session_events + .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionCloseBody { + code, + })); +} + +pub fn take_next_event(fsm: &mut QlFsm) -> Option { + fsm.state.events.pop_front() +} + +pub fn take_next_session_event(fsm: &mut QlFsm) -> Option { + fsm.state.session_events.pop_front() +} + +pub fn open_stream(fsm: &mut QlFsm) -> Result { + ensure_peer_bound(fsm)?; + fsm.session.open_stream().map_err(Into::into) +} + +pub fn write_stream( + fsm: &mut QlFsm, + stream_id: StreamId, + bytes: Vec, +) -> Result<(), QlFsmError> { + ensure_peer_bound(fsm)?; + fsm.session + .write_stream(stream_id, bytes) + .map_err(Into::into) +} + +pub fn finish_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result<(), QlFsmError> { + ensure_peer_bound(fsm)?; + fsm.session.finish_stream(stream_id).map_err(Into::into) +} + +pub fn close_stream( + fsm: &mut QlFsm, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, +) -> Result<(), QlFsmError> { + ensure_peer_bound(fsm)?; + fsm.session + .close_stream(stream_id, target, code, payload) + .map_err(Into::into) +} + +pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { + ensure_session_open(fsm)?; + fsm.session.queue_ping().map_err(Into::into) +} + +pub fn queue_unpair(fsm: &mut QlFsm) -> Result<(), QlFsmError> { + ensure_session_open(fsm)?; + // TODO: keep local peer/session state alive until this queued unpair is acked or times out, + // then clear it locally. Right now this only requests remote unpair. + fsm.session.queue_unpair().map_err(Into::into) +} + +fn ensure_peer_bound(fsm: &QlFsm) -> Result<(), QlFsmError> { + fsm.peer.as_ref().map(|_| ()).ok_or(QlFsmError::NoPeerBound) +} + +fn ensure_session_open(fsm: &QlFsm) -> Result<(), QlFsmError> { + ensure_peer_bound(fsm)?; + if fsm + .peer + .as_ref() + .and_then(|entry| entry.session.session_key()) + .is_none() + { + return Err(QlFsmError::SessionClosed); + } + Ok(()) +} diff --git a/ql-fsm/src/implementation/handshake.rs b/ql-fsm/src/implementation/handshake.rs index 5820a170..35fd2763 100644 --- a/ql-fsm/src/implementation/handshake.rs +++ b/ql-fsm/src/implementation/handshake.rs @@ -1,15 +1,17 @@ use std::{cmp::Ordering, time::Instant}; -use bc_components::{MLDSAPublicKey, SymmetricKey}; use ql_wire::{ - self as wire, - handshake::{Confirm, Hello, HelloReply, Ready}, - ControlMeta, QlCrypto, QlHeader, XID, + self as wire, Confirm, Hello, HelloReply, MlDsaPublicKey, Nonce, QlCrypto, QlHeader, QlPayload, + Ready, ReadyRef, SessionKey, XID, }; -use rkyv::api::low; +use super::{ + emit_peer_status, enqueue_handshake, fail_pending_connect_session, is_replayed_control, + next_control_meta, +}; use crate::{ - HandshakeInitiator, HandshakeResponder, Peer, PeerSession, QlFsm, QlFsmError, RecentReady, + state::{ConnectionState, HandshakeInitiator, HandshakeResponder, RecentReady}, + Peer, QlFsm, QlFsmError, }; #[derive(Debug)] @@ -23,8 +25,8 @@ enum HelloAction { enum HelloReplyAction { Advance { hello: Hello, - initiator_secret: SymmetricKey, - responder_signing_key: MLDSAPublicKey, + initiator_secret: SessionKey, + responder_signing_key: MlDsaPublicKey, }, ResendConfirm { confirm: Confirm, @@ -44,20 +46,21 @@ pub fn handle_connect(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlF pub fn handle_hello( fsm: &mut QlFsm, - header: &QlHeader, - archived_hello: &wire::handshake::ArchivedHello, crypto: &impl QlCrypto, + header: &QlHeader, + hello: &Hello, ) -> Result<(), QlFsmError> { - let hello: Hello = deserialize_archived(archived_hello)?; let action = { let Some(entry) = fsm.peer.as_ref() else { return Ok(()); }; - if wire::handshake::verify_hello( + if wire::verify_hello( + crypto, header.sender, fsm.identity.xid, &entry.peer.signing_key, - archived_hello, + hello, + fsm.state.now.unix_secs, ) .is_err() { @@ -65,7 +68,7 @@ pub fn handle_hello( } match &entry.session { - PeerSession::Initiator { + ConnectionState::Initiator { hello: local_hello, .. } => { if peer_hello_wins(local_hello, fsm.identity.xid, &hello, header.sender) { @@ -74,7 +77,7 @@ pub fn handle_hello( HelloAction::Ignore } } - PeerSession::Responder { + ConnectionState::Responder { hello: stored, reply, stage: HandshakeResponder::WaitingConfirm { .. }, @@ -88,7 +91,7 @@ pub fn handle_hello( HelloAction::StartResponder } } - PeerSession::Disconnected | PeerSession::Connected { .. } => { + ConnectionState::Disconnected | ConnectionState::Connected { .. } => { HelloAction::StartResponder } } @@ -97,35 +100,33 @@ pub fn handle_hello( match action { HelloAction::Ignore => {} HelloAction::ResendReply { reply } => { - fsm.enqueue_handshake( - header.sender, - wire::handshake::HandshakeRecord::HelloReply(reply), - ); + enqueue_handshake(fsm, header.sender, QlPayload::HelloReply(reply)); } HelloAction::StartResponder => { - if fsm.is_replayed_control(header.sender, hello.meta) { + if is_replayed_control(fsm, header.sender, hello.meta) { return Ok(()); } let peer = fsm.peer.as_ref().map(|entry| entry.peer.clone()).unwrap(); - let reply_meta = fsm.next_control_meta(fsm.config.handshake_timeout); - let responder = wire::handshake::respond_hello( - &fsm.identity, + let reply_meta = next_control_meta(fsm, fsm.config.handshake_timeout); + let responder = wire::respond_hello( crypto, + &fsm.identity, peer.xid, &peer.signing_key, &peer.encapsulation_key, - archived_hello, + hello, reply_meta, + fsm.state.now.unix_secs, ); let (reply, secrets) = match responder { Ok(result) => result, Err(_) => { if let Some(entry) = fsm.peer.as_mut() { - entry.session = PeerSession::Disconnected; + entry.session = ConnectionState::Disconnected; } - fsm.emit_peer_status(); + emit_peer_status(fsm); return Ok(()); } }; @@ -133,7 +134,7 @@ pub fn handle_hello( let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); if let Some(entry) = fsm.peer.as_mut() { - entry.session = PeerSession::Responder { + entry.session = ConnectionState::Responder { hello: hello.clone(), reply: reply.clone(), deadline, @@ -144,11 +145,8 @@ pub fn handle_hello( }, }; } - fsm.enqueue_handshake( - header.sender, - wire::handshake::HandshakeRecord::HelloReply(reply), - ); - fsm.emit_peer_status(); + enqueue_handshake(fsm, header.sender, QlPayload::HelloReply(reply)); + emit_peer_status(fsm); } } @@ -157,16 +155,16 @@ pub fn handle_hello( pub fn handle_hello_reply( fsm: &mut QlFsm, + crypto: &impl QlCrypto, header: &QlHeader, - archived_reply: &wire::handshake::ArchivedHelloReply, + reply: &HelloReply, ) -> Result<(), QlFsmError> { - let reply: HelloReply = deserialize_archived(archived_reply)?; let action = { let Some(entry) = fsm.peer.as_ref() else { return Ok(()); }; match &entry.session { - PeerSession::Initiator { + ConnectionState::Initiator { hello, stage: HandshakeInitiator::WaitingHelloReply { @@ -178,7 +176,7 @@ pub fn handle_hello_reply( initiator_secret: initiator_secret.clone(), responder_signing_key: entry.peer.signing_key.clone(), }, - PeerSession::Initiator { + ConnectionState::Initiator { stage: HandshakeInitiator::WaitingReady { reply: stored, @@ -195,38 +193,37 @@ pub fn handle_hello_reply( match action { HelloReplyAction::ResendConfirm { confirm } => { - fsm.enqueue_handshake( - header.sender, - wire::handshake::HandshakeRecord::Confirm(confirm), - ); + enqueue_handshake(fsm, header.sender, QlPayload::Confirm(confirm)); } HelloReplyAction::Advance { hello, initiator_secret, responder_signing_key, } => { - let confirm_meta = fsm.next_control_meta(fsm.config.handshake_timeout); - let (confirm, session_key) = match wire::handshake::build_confirm( + let confirm_meta = next_control_meta(fsm, fsm.config.handshake_timeout); + let (confirm, session_key) = match wire::build_confirm( + crypto, &fsm.identity, header.sender, &responder_signing_key, &hello, - archived_reply, + reply, &initiator_secret, confirm_meta, + fsm.state.now.unix_secs, ) { Ok(result) => result, Err(_) => return Ok(()), }; - if fsm.is_replayed_control(header.sender, reply.meta) { + if is_replayed_control(fsm, header.sender, reply.meta) { return Ok(()); } let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); if let Some(entry) = fsm.peer.as_mut() { - entry.session = PeerSession::Initiator { + entry.session = ConnectionState::Initiator { hello, deadline, stage: HandshakeInitiator::WaitingReady { @@ -238,33 +235,21 @@ pub fn handle_hello_reply( }, }; } - fsm.enqueue_handshake( - header.sender, - wire::handshake::HandshakeRecord::Confirm(confirm), - ); + enqueue_handshake(fsm, header.sender, QlPayload::Confirm(confirm)); } } Ok(()) } -fn deserialize_archived( - value: &impl rkyv::Deserialize>, -) -> Result { - low::deserialize::(value).map_err(|_| QlFsmError::InvalidPayload) -} - pub fn handle_confirm( fsm: &mut QlFsm, - header: &QlHeader, - confirm: &wire::handshake::ArchivedConfirm, crypto: &impl QlCrypto, + header: &QlHeader, + confirm: &Confirm, ) -> Result<(), QlFsmError> { - if let Some(ready) = recent_ready_resend(fsm, header.sender, confirm) { - fsm.enqueue_handshake( - header.sender, - wire::handshake::HandshakeRecord::Ready(ready), - ); + if let Some(ready) = recent_ready_resend(fsm, crypto, header.sender, confirm) { + enqueue_handshake(fsm, header.sender, QlPayload::Ready(ready)); return Ok(()); } @@ -272,7 +257,7 @@ pub fn handle_confirm( let Some(entry) = fsm.peer.as_ref() else { return Ok(()); }; - let PeerSession::Responder { + let ConnectionState::Responder { hello, reply, deadline, @@ -282,7 +267,8 @@ pub fn handle_confirm( return Ok(()); }; - wire::handshake::finalize_confirm( + wire::finalize_confirm( + crypto, header.sender, fsm.identity.xid, &entry.peer.signing_key, @@ -290,6 +276,7 @@ pub fn handle_confirm( reply, confirm, secrets, + fsm.state.now.unix_secs, ) .map(|session_key| (hello.clone(), reply.clone(), *deadline, session_key)) }; @@ -299,23 +286,23 @@ pub fn handle_confirm( Err(_) => return Ok(()), }; - let meta: ControlMeta = (&confirm.meta).into(); - if fsm.is_replayed_control(header.sender, meta) { + if is_replayed_control(fsm, header.sender, confirm.meta) { return Ok(()); } - let ready = wire::handshake::build_ready( + let ready = wire::build_ready( + crypto, QlHeader { sender: fsm.identity.xid, recipient: header.sender, }, &session_key, - fsm.next_control_meta(fsm.config.handshake_timeout), + next_control_meta(fsm, fsm.config.handshake_timeout), next_encrypted_nonce(crypto), - ); + )?; if let Some(entry) = fsm.peer.as_mut() { - entry.session = PeerSession::Connected { + entry.session = ConnectionState::Connected { session_key, recent_ready: Some(RecentReady { hello, @@ -326,25 +313,23 @@ pub fn handle_confirm( }; } - fsm.enqueue_handshake( - header.sender, - wire::handshake::HandshakeRecord::Ready(ready), - ); - fsm.emit_peer_status(); + enqueue_handshake(fsm, header.sender, QlPayload::Ready(ready)); + emit_peer_status(fsm); Ok(()) } pub fn handle_ready( fsm: &mut QlFsm, + crypto: &impl QlCrypto, header: &QlHeader, - ready: &mut wire::handshake::ArchivedReady, + ready: &mut ReadyRef<&mut [u8]>, ) -> Result<(), QlFsmError> { let session_key = { let Some(entry) = fsm.peer.as_ref() else { return Ok(()); }; match &entry.session { - PeerSession::Initiator { + ConnectionState::Initiator { stage: HandshakeInitiator::WaitingReady { session_key, .. }, .. } => session_key.clone(), @@ -352,34 +337,35 @@ pub fn handle_ready( } }; - let body = match wire::handshake::decrypt_ready(header, ready, &session_key) { - Ok(body) => body, - Err(_) => return Ok(()), - }; - if fsm.is_replayed_control(header.sender, body.meta) { + let body = + match wire::decrypt_ready(crypto, header, ready, &session_key, fsm.state.now.unix_secs) { + Ok(body) => body, + Err(_) => return Ok(()), + }; + if is_replayed_control(fsm, header.sender, body.meta) { return Ok(()); } if let Some(entry) = fsm.peer.as_mut() { - entry.session = PeerSession::Connected { + entry.session = ConnectionState::Connected { session_key, recent_ready: None, }; } - fsm.emit_peer_status(); + emit_peer_status(fsm); Ok(()) } pub fn handle_timer(fsm: &mut QlFsm) { let now = fsm.state.now.instant; - if let Some(PeerSession::Connected { + if let Some(ConnectionState::Connected { recent_ready: Some(recent_ready), .. }) = fsm.peer.as_mut().map(|entry| &mut entry.session) { if recent_ready.expires_at <= now { if let Some(entry) = fsm.peer.as_mut() { - if let PeerSession::Connected { recent_ready, .. } = &mut entry.session { + if let ConnectionState::Connected { recent_ready, .. } = &mut entry.session { *recent_ready = None; } } @@ -391,13 +377,13 @@ pub fn handle_timer(fsm: &mut QlFsm) { if let Some(entry) = fsm.peer.as_mut() { match &mut entry.session { - PeerSession::Initiator { + ConnectionState::Initiator { hello, deadline, stage, } => { if *deadline <= now { - entry.session = PeerSession::Disconnected; + entry.session = ConnectionState::Disconnected; disconnected = true; } else { retry_action = handle_initiator_retry( @@ -409,19 +395,19 @@ pub fn handle_timer(fsm: &mut QlFsm) { fsm.config.max_handshake_retries, ); if retry_action.is_none() && initiator_retries_exhausted(stage) { - entry.session = PeerSession::Disconnected; + entry.session = ConnectionState::Disconnected; disconnected = true; } } } - PeerSession::Responder { + ConnectionState::Responder { reply, deadline, stage, .. } => { if *deadline <= now { - entry.session = PeerSession::Disconnected; + entry.session = ConnectionState::Disconnected; disconnected = true; } else { retry_action = handle_responder_retry( @@ -433,39 +419,40 @@ pub fn handle_timer(fsm: &mut QlFsm) { fsm.config.max_handshake_retries, ); if retry_action.is_none() && responder_retries_exhausted(stage) { - entry.session = PeerSession::Disconnected; + entry.session = ConnectionState::Disconnected; disconnected = true; } } } - PeerSession::Disconnected | PeerSession::Connected { .. } => {} + ConnectionState::Disconnected | ConnectionState::Connected { .. } => {} } } if disconnected { - fsm.emit_peer_status(); + fail_pending_connect_session(fsm, ql_wire::CloseCode::TIMEOUT); + emit_peer_status(fsm); } if let Some(action) = retry_action { match action { RetryAction::Hello { peer, hello } => { - fsm.enqueue_handshake(peer, wire::handshake::HandshakeRecord::Hello(hello)); + enqueue_handshake(fsm, peer, QlPayload::Hello(hello)); } RetryAction::HelloReply { peer, reply } => { - fsm.enqueue_handshake(peer, wire::handshake::HandshakeRecord::HelloReply(reply)); + enqueue_handshake(fsm, peer, QlPayload::HelloReply(reply)); } RetryAction::Confirm { peer, confirm } => { - fsm.enqueue_handshake(peer, wire::handshake::HandshakeRecord::Confirm(confirm)); + enqueue_handshake(fsm, peer, QlPayload::Confirm(confirm)); } } } } -pub fn next_deadline(fsm: &QlFsm) -> Option { +pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { let mut deadline = None; if let Some(entry) = fsm.peer.as_ref() { match &entry.session { - PeerSession::Initiator { + ConnectionState::Initiator { deadline: session_deadline, stage, .. @@ -473,7 +460,7 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { deadline = Some(*session_deadline); deadline = min_optional(deadline, initiator_retry_at(stage)); } - PeerSession::Responder { + ConnectionState::Responder { deadline: session_deadline, stage, .. @@ -481,13 +468,13 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { deadline = Some(*session_deadline); deadline = min_optional(deadline, responder_retry_at(stage)); } - PeerSession::Connected { + ConnectionState::Connected { recent_ready: Some(recent_ready), .. } => { deadline = Some(recent_ready.expires_at); } - PeerSession::Disconnected | PeerSession::Connected { .. } => {} + ConnectionState::Disconnected | ConnectionState::Connected { .. } => {} } } deadline @@ -497,15 +484,15 @@ fn start_initiator_handshake(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result< let Some(entry) = fsm.peer.as_ref() else { return Err(QlFsmError::NoPeerBound); }; - if !matches!(entry.session, PeerSession::Disconnected) { + if !matches!(entry.session, ConnectionState::Disconnected) { return Ok(()); } let peer = entry.peer.clone(); - let meta = fsm.next_control_meta(fsm.config.handshake_timeout); - let (hello, initiator_secret) = wire::handshake::build_hello( - &fsm.identity, + let meta = next_control_meta(fsm, fsm.config.handshake_timeout); + let (hello, initiator_secret) = wire::build_hello( crypto, + &fsm.identity, peer.xid, &peer.encapsulation_key, meta, @@ -514,7 +501,7 @@ fn start_initiator_handshake(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result< let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); if let Some(entry) = fsm.peer.as_mut() { - entry.session = PeerSession::Initiator { + entry.session = ConnectionState::Initiator { hello: hello.clone(), deadline, stage: HandshakeInitiator::WaitingHelloReply { @@ -525,18 +512,19 @@ fn start_initiator_handshake(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result< }; } - fsm.enqueue_handshake(peer.xid, wire::handshake::HandshakeRecord::Hello(hello)); - fsm.emit_peer_status(); + enqueue_handshake(fsm, peer.xid, QlPayload::Hello(hello)); + emit_peer_status(fsm); Ok(()) } fn recent_ready_resend( fsm: &QlFsm, + crypto: &impl QlCrypto, peer: XID, - confirm: &wire::handshake::ArchivedConfirm, + confirm: &Confirm, ) -> Option { let entry = fsm.peer.as_ref()?; - let PeerSession::Connected { + let ConnectionState::Connected { recent_ready: Some(recent_ready), .. } = &entry.session @@ -546,13 +534,15 @@ fn recent_ready_resend( if recent_ready.expires_at <= fsm.state.now.instant { return None; } - wire::handshake::verify_confirm( + wire::verify_confirm( + crypto, peer, fsm.identity.xid, &entry.peer.signing_key, &recent_ready.hello, &recent_ready.reply, confirm, + fsm.state.now.unix_secs, ) .ok()?; Some(recent_ready.ready.clone()) @@ -693,10 +683,10 @@ fn peer_hello_wins( } } -fn next_encrypted_nonce(crypto: &impl QlCrypto) -> wire::Nonce { - let mut bytes = [0u8; wire::Nonce::NONCE_SIZE]; +fn next_encrypted_nonce(crypto: &impl QlCrypto) -> Nonce { + let mut bytes = [0u8; Nonce::SIZE]; crypto.fill_random_bytes(&mut bytes); - wire::Nonce(bytes) + Nonce(bytes) } fn retry_due(retry_at: Option, now: Instant) -> bool { diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index af67a96d..e1114d98 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -1,120 +1,145 @@ -pub mod handshake; -pub mod peer; +mod fsm; +mod handshake; +mod peer; use std::time::Duration; -use ql_wire::{ - self as wire, handshake::ArchivedHandshakeRecord, ArchivedQlPayload, ControlId, ControlMeta, - QlCrypto, QlHeader, QlPayload, QlRecord, XID, -}; -use rkyv::api::low; +pub use fsm::*; +pub use handshake::*; +pub use peer::*; +use ql_wire::{ControlId, ControlMeta, QlHeader, QlPayload, QlRecord, SessionKey, XID}; -use crate::{Peer, QlFsm, QlFsmError, QlFsmEvent}; +use crate::{ + session::{SessionEvent, SessionFsmConfig, StreamIncoming, StreamNamespace}, + QlFsm, QlFsmEvent, QlSessionEvent, +}; -impl QlFsm { - pub fn bind_peer_inner(&mut self, peer: Peer) { - peer::handle_bind_peer(self, peer); +fn emit_peer_status(fsm: &mut QlFsm) { + if let Some(entry) = fsm.peer.as_ref() { + fsm.state.events.push_back(QlFsmEvent::PeerStatusChanged { + peer: entry.peer.xid, + status: entry.session.status(), + }); } +} - pub fn pair_inner(&mut self, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - peer::handle_pair_local(self, crypto) +fn next_control_meta(fsm: &mut QlFsm, lifetime: Duration) -> ControlMeta { + let control_id = ControlId(fsm.state.next_control_id); + fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); + ControlMeta { + control_id, + valid_until: deadline_after_secs(fsm.state.now.unix_secs, lifetime), } +} - pub fn connect_inner(&mut self, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - handshake::handle_connect(self, crypto) - } +fn enqueue_handshake(fsm: &mut QlFsm, peer: XID, payload: QlPayload) { + fsm.state.outbound.push_back(QlRecord { + header: QlHeader { + sender: fsm.identity.xid, + recipient: peer, + }, + payload, + }); +} - pub fn receive_inner( - &mut self, - mut bytes: Vec, - crypto: &impl QlCrypto, - ) -> Result<(), QlFsmError> { - let archived = wire::access_record_mut(&mut bytes)?; - let archived = unsafe { archived.unseal_unchecked() }; - let header: QlHeader = deserialize_archived(&archived.header)?; +fn is_replayed_control(fsm: &mut QlFsm, peer: XID, meta: ControlMeta) -> bool { + fsm.state + .replay_cache + .check_and_store_valid_until(peer, meta, fsm.state.now.unix_secs) +} - if header.recipient != self.identity.xid { - return Ok(()); - } - if !matches!(&archived.payload, ArchivedQlPayload::Pair(_)) { - let Some(peer) = self.peer.as_ref().map(|entry| entry.peer.xid) else { - return Ok(()); - }; - if header.sender != peer { - return Ok(()); - } - } +fn peer_session(fsm: &QlFsm) -> Option<(XID, SessionKey)> { + let entry = fsm.peer.as_ref()?; + let session_key = entry.session.session_key()?.clone(); + Some((entry.peer.xid, session_key)) +} + +fn reset_session(fsm: &mut QlFsm) { + let local_namespace = fsm + .peer + .as_ref() + .map(|peer| StreamNamespace::for_local(fsm.identity.xid, peer.peer.xid)) + .unwrap_or(StreamNamespace::Low); + fsm.session = crate::session::SessionFsm::new( + SessionFsmConfig { + local_namespace, + ack_delay: fsm.config.session_ack_delay, + retransmit_timeout: fsm.config.session_retransmit_timeout, + keepalive_interval: fsm.config.session_keepalive_interval, + peer_timeout: fsm.config.session_peer_timeout, + }, + fsm.state.now.instant, + ); +} + +fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::CloseCode) { + if !fsm.session.has_pending_stream_work() { + return; + } + reset_session(fsm); + fsm.state + .session_events + .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionCloseBody { + code, + })); +} - match &mut archived.payload { - ArchivedQlPayload::Pair(request) => { - peer::handle_pair(self, &header, request, crypto)?; +fn drain_session_events(fsm: &mut QlFsm) { + while let Some(event) = fsm.session.take_next_event() { + match event { + SessionEvent::Opened(stream_id) => { + fsm.state + .session_events + .push_back(QlSessionEvent::Opened(stream_id)); } - ArchivedQlPayload::Handshake(ArchivedHandshakeRecord::Hello(archived_hello)) => { - handshake::handle_hello(self, &header, archived_hello, crypto)?; + SessionEvent::Readable(stream_id) => { + while let Some(incoming) = fsm.session.take_next_inbound(stream_id) { + match incoming { + StreamIncoming::Data(bytes) => { + fsm.state + .session_events + .push_back(QlSessionEvent::Data { stream_id, bytes }); + } + StreamIncoming::Finished => { + fsm.state + .session_events + .push_back(QlSessionEvent::Finished(stream_id)); + } + StreamIncoming::Closed(frame) => { + fsm.state + .session_events + .push_back(QlSessionEvent::Closed(frame)); + } + } + } } - ArchivedQlPayload::Handshake(ArchivedHandshakeRecord::HelloReply(archived_reply)) => { - handshake::handle_hello_reply(self, &header, archived_reply)?; + SessionEvent::WritableClosed(stream_id) => { + fsm.state + .session_events + .push_back(QlSessionEvent::WritableClosed(stream_id)); } - ArchivedQlPayload::Handshake(ArchivedHandshakeRecord::Confirm(archived_confirm)) => { - handshake::handle_confirm(self, &header, archived_confirm, crypto)?; + SessionEvent::Unpaired => { + fsm.state.session_events.push_back(QlSessionEvent::Unpaired); + fsm.peer = None; + reset_session(fsm); + fsm.state.events.push_back(QlFsmEvent::ClearPeer); } - ArchivedQlPayload::Handshake(ArchivedHandshakeRecord::Ready(archived_ready)) => { - handshake::handle_ready(self, &header, archived_ready)?; + SessionEvent::SessionClosed(close) => { + fsm.state + .session_events + .push_back(QlSessionEvent::SessionClosed(close.clone())); + if let Some(entry) = fsm.peer.as_mut() { + if matches!( + entry.session, + crate::state::ConnectionState::Connected { .. } + ) { + entry.session = crate::state::ConnectionState::Disconnected; + emit_peer_status(fsm); + } + } + reset_session(fsm); } - ArchivedQlPayload::Encrypted(_) => {} } - - Ok(()) - } - - pub fn on_timer_inner(&mut self) { - handshake::handle_timer(self); - } - - pub fn next_deadline_inner(&self) -> Option { - handshake::next_deadline(self) - } - - pub fn take_next_outbound_inner(&mut self) -> Option { - self.state.outbound.pop_front() - } - - pub fn take_next_event_inner(&mut self) -> Option { - self.state.events.pop_front() - } - - fn emit_peer_status(&mut self) { - if let Some(entry) = self.peer.as_ref() { - self.state.events.push_back(QlFsmEvent::PeerStatusChanged { - peer: entry.peer.xid, - status: entry.session.status(), - }); - } - } - - fn next_control_meta(&mut self, lifetime: Duration) -> ControlMeta { - let control_id = ControlId(self.state.next_control_id); - self.state.next_control_id = self.state.next_control_id.wrapping_add(1); - ControlMeta { - control_id, - valid_until: deadline_after_secs(self.state.now.unix_secs, lifetime), - } - } - - fn enqueue_handshake(&mut self, peer: XID, record: wire::handshake::HandshakeRecord) { - self.state.outbound.push_back(QlRecord { - header: QlHeader { - sender: self.identity.xid, - recipient: peer, - }, - payload: QlPayload::Handshake(record), - }); - } - - fn is_replayed_control(&mut self, peer: XID, meta: ControlMeta) -> bool { - self.state - .replay_cache - .check_and_store_valid_until(peer, meta, self.state.now.unix_secs) } } @@ -127,9 +152,3 @@ fn duration_to_secs(duration: Duration) -> u64 { .as_secs() .saturating_add(u64::from(duration.subsec_nanos() > 0)) } - -fn deserialize_archived( - value: &impl rkyv::Deserialize>, -) -> Result { - low::deserialize::(value).map_err(|_| QlFsmError::InvalidPayload) -} diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs index ec5a75c7..cf0632b8 100644 --- a/ql-fsm/src/implementation/peer.rs +++ b/ql-fsm/src/implementation/peer.rs @@ -1,18 +1,18 @@ -use ql_wire::{self as wire, pair::ArchivedPairRequestRecord, QlCrypto, QlHeader}; +use ql_wire::{self as wire, PairRequestRecordRef, QlCrypto, QlHeader}; -use super::handshake; -use crate::{Peer, PeerRecord, QlFsm, QlFsmError, QlFsmEvent}; +use super::{emit_peer_status, handshake, is_replayed_control, next_control_meta, reset_session}; +use crate::{state::PeerRecord, Peer, QlFsm, QlFsmError, QlFsmEvent}; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: Peer) { bind_peer_record(fsm, peer); } pub fn handle_pair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - let meta = fsm.next_control_meta(fsm.config.control_expiration); + let meta = next_control_meta(fsm, fsm.config.control_expiration); let peer = fsm.peer.as_ref().ok_or(QlFsmError::NoPeerBound)?; - let record = wire::pair::build_pair_request( - &fsm.identity, + let record = wire::build_pair_request( crypto, + &fsm.identity, peer.peer.xid, &peer.peer.encapsulation_key, meta, @@ -23,20 +23,26 @@ pub fn handle_pair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), pub fn handle_pair( fsm: &mut QlFsm, - header: &QlHeader, - request: &mut ArchivedPairRequestRecord, crypto: &impl QlCrypto, + header: &QlHeader, + request: &mut PairRequestRecordRef<&mut [u8]>, ) -> Result<(), QlFsmError> { - let payload = match wire::pair::decrypt_pair_request(&fsm.identity, header, request) { + let payload = match wire::decrypt_pair_request( + crypto, + &fsm.identity, + header, + request, + fsm.state.now.unix_secs, + ) { Ok(payload) => payload, Err(_) => return Ok(()), }; let peer = Peer { - xid: ql_wire::XID::from_signing_public_key(&payload.signing_pub_key), + xid: payload.xid, signing_key: payload.signing_pub_key, encapsulation_key: payload.encapsulation_pub_key, }; - if fsm.is_replayed_control(peer.xid, payload.meta) { + if is_replayed_control(fsm, peer.xid, payload.meta) { return Ok(()); } @@ -51,6 +57,7 @@ pub fn handle_pair( fn bind_peer_record(fsm: &mut QlFsm, peer: Peer) { fsm.peer = Some(PeerRecord::new(peer.clone())); - fsm.state.events.push_back(QlFsmEvent::PersistPeer(peer)); - fsm.emit_peer_status(); + reset_session(fsm); + fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); + emit_peer_status(fsm); } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index de3d552d..0ab099b9 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -1,17 +1,24 @@ +mod error; pub(crate) mod implementation; pub(crate) mod replay_cache; -pub mod session; +mod session; pub(crate) mod state; +#[cfg(test)] +mod tests; -use std::time::Instant; +use std::time::{Duration, Instant}; -use ql_wire::{QlCrypto, QlIdentity, QlRecord}; -use state::{ - HandshakeInitiator, HandshakeResponder, Peer, PeerRecord, PeerSession, QlFsm, QlFsmConfig, - QlFsmError, QlFsmEvent, RecentReady, +pub use error::QlFsmError; +use ql_wire::{ + CloseCode, CloseTarget, MlDsaPublicKey, MlKemPublicKey, QlCrypto, QlIdentity, QlRecord, + SessionCloseBody, SessionSeq, StreamClose, StreamId, XID, }; -use crate::{replay_cache::ReplayCache, state::QlFsmState}; +use crate::{ + replay_cache::ReplayCache, + session::SessionFsm, + state::{PeerRecord, QlFsmState}, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct FsmTime { @@ -19,40 +26,122 @@ pub struct FsmTime { pub unix_secs: u64, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Peer { + pub xid: XID, + pub signing_key: MlDsaPublicKey, + pub encapsulation_key: MlKemPublicKey, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PeerStatus { + Disconnected, + Initiator, + Responder, + Connected, +} + +#[derive(Debug, Clone)] +pub enum QlFsmEvent { + NewPeer(Peer), + ClearPeer, + PeerStatusChanged { peer: XID, status: PeerStatus }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlSessionEvent { + Opened(StreamId), + Data { stream_id: StreamId, bytes: Vec }, + Finished(StreamId), + Closed(StreamClose), + WritableClosed(StreamId), + Unpaired, + SessionClosed(SessionCloseBody), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SessionWriteId(pub SessionSeq); + +#[derive(Debug, Clone, PartialEq)] +pub struct OutboundWrite { + pub record: QlRecord, + pub session_write_id: Option, +} + +#[derive(Debug, Clone, Copy)] +pub struct QlFsmConfig { + pub handshake_timeout: Duration, + pub handshake_retry_interval: Duration, + pub max_handshake_retries: u8, + pub control_expiration: Duration, + pub session_ack_delay: Duration, + pub session_retransmit_timeout: Duration, + pub session_keepalive_interval: Duration, + pub session_peer_timeout: Duration, +} + +impl Default for QlFsmConfig { + fn default() -> Self { + Self { + handshake_timeout: Duration::from_secs(5), + handshake_retry_interval: Duration::from_millis(750), + max_handshake_retries: 3, + control_expiration: Duration::from_secs(30), + session_ack_delay: Duration::from_millis(5), + session_retransmit_timeout: Duration::from_millis(150), + session_keepalive_interval: Duration::from_secs(10), + session_peer_timeout: Duration::from_secs(30), + } + } +} + +pub struct QlFsm { + pub config: QlFsmConfig, + pub identity: QlIdentity, + pub(crate) peer: Option, + pub(crate) session: SessionFsm, + pub(crate) state: QlFsmState, +} + impl QlFsm { - pub fn new( - config: QlFsmConfig, - identity: QlIdentity, - peer: Option, - now: FsmTime, - ) -> Self { - let peer = peer.map(PeerRecord::new); + pub fn new(config: QlFsmConfig, identity: QlIdentity, now: FsmTime) -> Self { Self { config, identity, - peer, + peer: None, + session: session::SessionFsm::new( + session::SessionFsmConfig { + local_namespace: session::StreamNamespace::Low, + ack_delay: config.session_ack_delay, + retransmit_timeout: config.session_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + }, + now.instant, + ), state: QlFsmState { replay_cache: ReplayCache::default(), next_control_id: 1, outbound: Default::default(), events: Default::default(), + session_events: Default::default(), now, }, } } pub fn bind_peer(&mut self, peer: Peer) { - self.bind_peer_inner(peer); + implementation::handle_bind_peer(self, peer); } pub fn pair(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; - self.pair_inner(crypto) + implementation::handle_pair_local(self, crypto) } pub fn connect(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; - self.connect_inner(crypto) + implementation::handle_connect(self, crypto) } pub fn receive( @@ -62,23 +151,90 @@ impl QlFsm { crypto: &impl QlCrypto, ) -> Result<(), QlFsmError> { self.state.now = now; - self.receive_inner(bytes, crypto) + implementation::receive(self, bytes, crypto) } pub fn on_timer(&mut self, now: FsmTime) { self.state.now = now; - self.on_timer_inner(); + implementation::on_timer(self); } pub fn next_deadline(&self) -> Option { - self.next_deadline_inner() + implementation::next_deadline(self) + } + + /// Returns the next outbound record. + /// + /// If `session_write_id` is `Some`, it must be followed by exactly one of + /// [`Self::confirm_session_write`] or [`Self::return_session_write`]. + /// + /// If `session_write_id` is `None`, the record is fire-and-forget. + pub fn take_next_write( + &mut self, + now: FsmTime, + crypto: &impl QlCrypto, + ) -> Option { + self.state.now = now; + implementation::take_next_write(self, crypto) + } + + /// Marks a previously issued session write as successfully handed to the transport. + /// + /// This must be called at most once for a `SessionWriteId` returned by + /// [`Self::take_next_write`] whose `session_write_id` was `Some`. + pub fn confirm_session_write(&mut self, now: FsmTime, write_id: SessionWriteId) { + self.state.now = now; + implementation::confirm_session_write(self, write_id); } - pub fn take_next_outbound(&mut self) -> Option { - self.take_next_outbound_inner() + /// Reports that a previously issued session write was not accepted by the transport. + /// + /// This must be called at most once for a `SessionWriteId` returned by + /// [`Self::take_next_write`] whose `session_write_id` was `Some`. + pub fn reject_session_write(&mut self, write_id: SessionWriteId) { + implementation::return_session_write(self, write_id); + } + + /// Aborts the current encrypted session locally. + pub fn kill_session(&mut self, code: CloseCode) { + implementation::kill_session(self, code); } pub fn take_next_event(&mut self) -> Option { - self.take_next_event_inner() + implementation::take_next_event(self) + } + + pub fn open_stream(&mut self) -> Result { + implementation::open_stream(self) + } + + pub fn write_stream(&mut self, stream_id: StreamId, bytes: Vec) -> Result<(), QlFsmError> { + implementation::write_stream(self, stream_id, bytes) + } + + pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), QlFsmError> { + implementation::finish_stream(self, stream_id) + } + + pub fn close_stream( + &mut self, + stream_id: StreamId, + target: CloseTarget, + code: CloseCode, + payload: Vec, + ) -> Result<(), QlFsmError> { + implementation::close_stream(self, stream_id, target, code, payload) + } + + pub fn queue_ping(&mut self) -> Result<(), QlFsmError> { + implementation::queue_ping(self) + } + + pub fn queue_unpair(&mut self) -> Result<(), QlFsmError> { + implementation::queue_unpair(self) + } + + pub fn take_next_session_event(&mut self) -> Option { + implementation::take_next_session_event(self) } } diff --git a/ql-fsm/src/session/internal.rs b/ql-fsm/src/session/internal.rs deleted file mode 100644 index 6a0c931b..00000000 --- a/ql-fsm/src/session/internal.rs +++ /dev/null @@ -1,628 +0,0 @@ -use std::time::Instant; - -use ql_wire::{ - encrypted::{heartbeat::HeartbeatBody, unpair::UnpairBody}, - CloseCode, CloseTarget, SessionBody, SessionCloseBody, SessionEnvelope, SessionSeq, - StreamCloseFrame, StreamFrame, StreamId, -}; - -use super::{ - ring::SeqRingInsertError, - state::{ - AckState, PendingChunk, PendingSessionBody, PendingStreamBody, SessionFsmState, StreamRole, - StreamState, TxEntry, - }, - SessionEvent, SessionFsm, SessionFsmConfig, SessionState, StreamError, StreamIncoming, -}; - -impl SessionFsm { - pub fn new_inner(config: SessionFsmConfig) -> Self { - let now = Instant::now(); - Self { - config, - state: SessionFsmState { - now, - session_state: SessionState::Open, - next_stream_ordinal: 1, - next_seq: SessionSeq(1), - tx_ring: super::ring::SeqRing::new(SessionSeq(1)), - rx_ring: super::ring::SeqRing::new(SessionSeq(1)), - ack_state: AckState::Idle, - pending_control: Default::default(), - streams: Default::default(), - ready_streams: Default::default(), - events: Default::default(), - }, - } - } - - pub fn open_stream_inner(&mut self) -> Result { - self.ensure_session_open()?; - let stream_id = - StreamId(self.config.local_namespace.bit() | self.state.next_stream_ordinal); - self.state.next_stream_ordinal = self.state.next_stream_ordinal.saturating_add(1); - self.state - .streams - .insert(stream_id, StreamState::new(StreamRole::Initiator)); - Ok(stream_id) - } - - pub fn write_stream_inner( - &mut self, - stream_id: StreamId, - bytes: Vec, - ) -> Result<(), StreamError> { - self.ensure_session_open()?; - if bytes.is_empty() { - return Ok(()); - } - - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; - if !stream.is_writable() { - return Err(StreamError::NotWritable); - } - - let frame = StreamFrame { - stream_id, - offset: stream.next_send_offset, - bytes, - fin: false, - }; - stream.next_send_offset += frame.bytes.len() as u64; - stream - .send_queue - .push_back(PendingStreamBody::Stream(frame)); - Self::mark_stream_ready(&mut self.state, stream_id); - Ok(()) - } - - pub fn finish_stream_inner(&mut self, stream_id: StreamId) -> Result<(), StreamError> { - self.ensure_session_open()?; - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; - if !stream.is_writable() { - return Err(StreamError::NotWritable); - } - - stream.outbound_finished = true; - stream - .send_queue - .push_back(PendingStreamBody::Stream(StreamFrame { - stream_id, - offset: stream.next_send_offset, - bytes: Vec::new(), - fin: true, - })); - Self::mark_stream_ready(&mut self.state, stream_id); - Ok(()) - } - - pub fn close_stream_inner( - &mut self, - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - ) -> Result<(), StreamError> { - self.ensure_session_open()?; - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; - - Self::apply_close_to_stream(stream, target); - stream - .send_queue - .push_back(PendingStreamBody::StreamClose(StreamCloseFrame { - stream_id, - target, - code, - payload, - })); - Self::mark_stream_ready(&mut self.state, stream_id); - Ok(()) - } - - pub fn queue_heartbeat_inner(&mut self) -> Result<(), StreamError> { - self.ensure_session_open()?; - self.state.pending_control.heartbeat = true; - Ok(()) - } - - pub fn queue_unpair_inner(&mut self) -> Result<(), StreamError> { - self.ensure_session_open()?; - self.state.pending_control.unpair = true; - Ok(()) - } - - pub fn close_session_inner(&mut self, code: CloseCode) { - self.fail_session(SessionCloseBody { code }); - } - - pub fn receive_inner(&mut self, envelope: SessionEnvelope) { - self.collect_timeouts(); - self.process_ack(envelope.ack); - - if self.state.session_state == SessionState::Closed { - return; - } - - let seq = envelope.seq; - if seq.0 < self.state.rx_ring.base_seq().0 || self.state.rx_ring.contains_key(&seq) { - self.schedule_ack(true); - return; - } - match self.state.rx_ring.insert(seq, ()) { - Ok(()) => { - let out_of_order = seq != self.state.rx_ring.base_seq(); - self.state.rx_ring.advance_occupied_front(); - self.schedule_ack(out_of_order); - } - Err(SeqRingInsertError::OutOfWindow) => { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; - } - Err(SeqRingInsertError::Occupied) => { - self.schedule_ack(true); - return; - } - } - - match envelope.body { - SessionBody::Heartbeat(_) => {} - SessionBody::Unpair(_) => { - self.state.session_state = SessionState::Closed; - self.clear_streams(); - self.state.events.push_back(SessionEvent::Unpaired); - } - SessionBody::Close(close) => { - self.state.session_state = SessionState::Closed; - self.clear_streams(); - self.state - .events - .push_back(SessionEvent::SessionClosed(close)); - } - SessionBody::Stream(frame) => self.handle_stream_frame(frame), - SessionBody::StreamClose(frame) => self.handle_stream_close(frame), - } - } - - pub fn next_outbound_inner(&mut self) -> Option { - self.collect_timeouts(); - let pending = self.next_pending_body()?; - if !self.state.tx_ring.accepts_seq(self.state.next_seq) { - if pending.priority { - self.requeue_pending_front(pending); - } - return None; - } - - let seq = self.state.next_seq; - self.state.next_seq = SessionSeq(seq.0 + 1); - let ack = self.state.current_ack(); - self.state.clear_ack_schedule(); - let envelope = SessionEnvelope { - seq, - ack, - body: pending.body.clone(), - }; - let entry = TxEntry { - pending, - sent_at: self.state.now, - }; - let _ = self.state.tx_ring.insert(seq, entry); - Some(envelope) - } - - pub fn on_timer_inner(&mut self) { - self.collect_timeouts(); - if let AckState::Delayed { due_at } = self.state.ack_state { - if due_at <= self.state.now { - self.state.ack_state = AckState::Immediate; - } - } - } - - pub fn next_deadline_inner(&self) -> Option { - let ack_deadline = match self.state.ack_state { - AckState::Idle => None, - AckState::Immediate => Some(self.state.now), - AckState::Delayed { due_at } => Some(due_at), - }; - let retransmit_deadline = self - .state - .tx_ring - .iter() - .map(|(_, entry)| entry.sent_at + self.config.retransmit_timeout) - .min(); - [ack_deadline, retransmit_deadline] - .into_iter() - .flatten() - .min() - } - - pub fn take_next_event_inner(&mut self) -> Option { - self.state.events.pop_front() - } - - pub fn take_next_inbound_inner(&mut self, stream_id: StreamId) -> Option { - self.state - .streams - .get_mut(&stream_id) - .and_then(|stream| stream.inbound_queue.pop_front()) - } - - pub fn session_state_inner(&self) -> SessionState { - self.state.session_state - } - - fn next_pending_body(&mut self) -> Option { - if let Some(close) = self.state.pending_control.close.take() { - return Some(PendingSessionBody { - body: SessionBody::Close(close), - retransmit: true, - priority: true, - }); - } - if self.state.pending_control.unpair { - self.state.pending_control.unpair = false; - return Some(PendingSessionBody { - body: SessionBody::Unpair(UnpairBody), - retransmit: true, - priority: true, - }); - } - if self.state.pending_control.heartbeat { - self.state.pending_control.heartbeat = false; - return Some(PendingSessionBody { - body: SessionBody::Heartbeat(HeartbeatBody), - retransmit: false, - priority: true, - }); - } - - while let Some(stream_id) = self.state.ready_streams.pop_front() { - let Some(stream) = self.state.streams.get_mut(&stream_id) else { - continue; - }; - stream.ready_enqueued = false; - let Some(item) = stream.send_queue.pop_front() else { - continue; - }; - if !stream.send_queue.is_empty() { - Self::mark_stream_ready(&mut self.state, stream_id); - } - return Some(PendingSessionBody { - body: item.to_session_body(), - retransmit: true, - priority: true, - }); - } - - let ack_due = match self.state.ack_state { - AckState::Immediate => true, - AckState::Delayed { due_at } => due_at <= self.state.now, - AckState::Idle => false, - }; - ack_due.then_some(PendingSessionBody { - body: SessionBody::Heartbeat(HeartbeatBody), - retransmit: false, - priority: false, - }) - } - - fn ensure_session_open(&self) -> Result<(), StreamError> { - if self.state.session_state == SessionState::Closed { - Err(StreamError::SessionClosed) - } else { - Ok(()) - } - } - - fn process_ack(&mut self, ack: ql_wire::SessionAck) { - let acked: Vec<_> = self - .state - .tx_ring - .iter() - .filter_map(|(seq, _)| Self::ack_covers(ack, seq).then_some(seq)) - .collect(); - for seq in acked { - let _ = self.state.tx_ring.remove(&seq); - } - self.state - .tx_ring - .advance_empty_front_until(self.state.next_seq); - } - - fn ack_covers(ack: ql_wire::SessionAck, seq: SessionSeq) -> bool { - if seq.0 <= ack.base.0 { - return true; - } - let delta = seq.0 - ack.base.0; - if delta == 0 || delta > 64 { - return false; - } - (ack.bitmap & (1u64 << (delta - 1))) != 0 - } - - fn schedule_ack(&mut self, immediate: bool) { - self.state.ack_state = match self.state.ack_state { - AckState::Immediate => AckState::Immediate, - _ if immediate || self.config.ack_delay.is_zero() => AckState::Immediate, - AckState::Delayed { due_at } => AckState::Delayed { due_at }, - AckState::Idle => AckState::Delayed { - due_at: self.state.now + self.config.ack_delay, - }, - }; - } - - fn collect_timeouts(&mut self) { - let expired: Vec<_> = self - .state - .tx_ring - .iter() - .filter_map(|(seq, entry)| { - (entry.sent_at + self.config.retransmit_timeout <= self.state.now).then_some(seq) - }) - .collect(); - - for seq in expired { - if let Some(entry) = self.state.tx_ring.remove(&seq) { - if entry.pending.retransmit { - self.requeue_pending_front(entry.pending); - } - } - } - - self.state - .tx_ring - .advance_empty_front_until(self.state.next_seq); - } - - fn requeue_pending_front(&mut self, pending: PendingSessionBody) { - match pending.body { - SessionBody::Stream(frame) => { - if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { - let stream_id = frame.stream_id; - stream - .send_queue - .push_front(PendingStreamBody::Stream(frame)); - Self::mark_stream_ready_front(&mut self.state, stream_id); - } - } - SessionBody::StreamClose(frame) => { - if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { - let stream_id = frame.stream_id; - stream - .send_queue - .push_front(PendingStreamBody::StreamClose(frame)); - Self::mark_stream_ready_front(&mut self.state, stream_id); - } - } - body => match body { - SessionBody::Heartbeat(_) => self.state.pending_control.heartbeat = true, - SessionBody::Unpair(_) => self.state.pending_control.unpair = true, - SessionBody::Close(close) => self.state.pending_control.close = Some(close), - SessionBody::Stream(_) | SessionBody::StreamClose(_) => unreachable!(), - }, - } - } - - fn mark_stream_ready(state: &mut SessionFsmState, stream_id: StreamId) { - let Some(stream) = state.streams.get_mut(&stream_id) else { - return; - }; - if stream.ready_enqueued { - return; - } - stream.ready_enqueued = true; - state.ready_streams.push_back(stream_id); - } - - fn mark_stream_ready_front(state: &mut SessionFsmState, stream_id: StreamId) { - let Some(stream) = state.streams.get_mut(&stream_id) else { - return; - }; - if stream.ready_enqueued { - return; - } - stream.ready_enqueued = true; - state.ready_streams.push_front(stream_id); - } - - fn handle_stream_frame(&mut self, frame: StreamFrame) { - let stream_id = frame.stream_id; - let remote_namespace = self.config.local_namespace.remote(); - if !self.state.streams.contains_key(&stream_id) { - if !remote_namespace.matches(stream_id) || frame.offset != 0 { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; - } - self.state - .streams - .insert(stream_id, StreamState::new(StreamRole::Responder)); - self.state.events.push_back(SessionEvent::Opened(stream_id)); - } - - let Some(stream) = self.state.streams.get_mut(&stream_id) else { - return; - }; - if stream.inbound_discarding { - return; - } - if stream.inbound_closed || stream.inbound_finished { - if frame.offset + frame.bytes.len() as u64 <= stream.next_recv_offset { - return; - } - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; - } - - if frame.offset < stream.next_recv_offset { - let frame_end = frame.offset + frame.bytes.len() as u64; - if frame_end <= stream.next_recv_offset { - return; - } - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; - } - - if frame.offset == stream.next_recv_offset { - Self::commit_inbound_frame(stream, frame); - Self::drain_pending_recv(stream); - self.state - .events - .push_back(SessionEvent::Readable(stream_id)); - return; - } - - if Self::insert_pending_chunk( - stream, - frame.offset, - PendingChunk { - bytes: frame.bytes, - fin: frame.fin, - }, - ) - .is_err() - { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - } - } - - fn handle_stream_close(&mut self, frame: StreamCloseFrame) { - let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; - }; - - if Self::target_affects_inbound(stream.role, frame.target) { - stream.inbound_closed = true; - stream.inbound_discarding = false; - stream.pending_recv.clear(); - stream - .inbound_queue - .push_back(StreamIncoming::Closed(frame.clone())); - self.state - .events - .push_back(SessionEvent::Readable(frame.stream_id)); - } - if Self::target_affects_outbound(stream.role, frame.target) { - stream.outbound_closed = true; - stream.send_queue.clear(); - self.state - .events - .push_back(SessionEvent::WritableClosed(frame.stream_id)); - } - } - - fn apply_close_to_stream(stream: &mut StreamState, target: CloseTarget) { - if Self::target_affects_inbound(stream.role, target) { - stream.inbound_discarding = true; - stream.pending_recv.clear(); - } - if Self::target_affects_outbound(stream.role, target) { - stream.outbound_closed = true; - stream.outbound_finished = true; - stream.send_queue.clear(); - } - } - - fn target_affects_inbound(role: StreamRole, target: CloseTarget) -> bool { - matches!(target, CloseTarget::Both) || role.inbound_target() == target - } - - fn target_affects_outbound(role: StreamRole, target: CloseTarget) -> bool { - matches!(target, CloseTarget::Both) || role.outbound_target() == target - } - - fn commit_inbound_frame(stream: &mut StreamState, frame: StreamFrame) { - Self::commit_inbound_chunk(stream, frame.bytes, frame.fin); - } - - fn commit_inbound_chunk(stream: &mut StreamState, bytes: Vec, fin: bool) { - stream.next_recv_offset += bytes.len() as u64; - if !bytes.is_empty() { - stream.inbound_queue.push_back(StreamIncoming::Data(bytes)); - } - if fin { - stream.inbound_finished = true; - stream.inbound_queue.push_back(StreamIncoming::Finished); - } - } - - fn drain_pending_recv(stream: &mut StreamState) { - while let Some(chunk) = stream.pending_recv.remove(&stream.next_recv_offset) { - Self::commit_inbound_chunk(stream, chunk.bytes, chunk.fin); - if stream.inbound_finished { - break; - } - } - } - - fn insert_pending_chunk( - stream: &mut StreamState, - offset: u64, - chunk: PendingChunk, - ) -> Result<(), ()> { - let end = chunk.end_offset(offset); - - if let Some((&prev_offset, prev)) = stream.pending_recv.range(..=offset).next_back() { - let prev_end = prev.end_offset(prev_offset); - if prev_end > offset { - if prev_offset == offset && prev.bytes == chunk.bytes && prev.fin == chunk.fin { - return Ok(()); - } - return Err(()); - } - } - - if let Some((&next_offset, _)) = stream.pending_recv.range(offset..).next() { - if end > next_offset { - return Err(()); - } - } - - stream.pending_recv.insert(offset, chunk); - Ok(()) - } - - fn fail_session(&mut self, close: SessionCloseBody) { - if self.state.session_state == SessionState::Closed { - return; - } - - self.state.session_state = SessionState::Closed; - self.clear_streams(); - self.state.pending_control = Default::default(); - self.state.pending_control.close = Some(close.clone()); - self.state - .events - .push_back(SessionEvent::SessionClosed(close)); - } - - fn clear_streams(&mut self) { - self.state.ready_streams.clear(); - self.state.streams.clear(); - } -} diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 570527be..e8328c29 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,4 +1,3 @@ -pub(crate) mod internal; pub(crate) mod ring; pub(crate) mod state; @@ -8,10 +7,17 @@ mod tests; use std::time::{Duration, Instant}; use ql_wire::{ - CloseCode, CloseTarget, SessionCloseBody, SessionEnvelope, StreamCloseFrame, StreamId, XID, + CloseCode, CloseTarget, PingBody, SessionBody, SessionCloseBody, SessionEnvelope, SessionSeq, + StreamChunk, StreamClose, StreamId, UnpairBody, XID, }; -use self::state::SessionFsmState; +use self::{ + ring::SeqRingInsertError, + state::{ + AckState, PendingChunk, PendingSessionBody, PendingStreamBody, SessionFsmState, StreamRole, + StreamState, TxEntry, TxState, + }, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamNamespace { @@ -53,6 +59,8 @@ pub struct SessionFsmConfig { pub local_namespace: StreamNamespace, pub ack_delay: Duration, pub retransmit_timeout: Duration, + pub keepalive_interval: Duration, + pub peer_timeout: Duration, } impl Default for SessionFsmConfig { @@ -61,6 +69,8 @@ impl Default for SessionFsmConfig { local_namespace: StreamNamespace::Low, ack_delay: Duration::from_millis(5), retransmit_timeout: Duration::from_millis(150), + keepalive_interval: Duration::from_secs(10), + peer_timeout: Duration::from_secs(30), } } } @@ -78,7 +88,7 @@ pub enum SessionEvent { pub enum StreamIncoming { Data(Vec), Finished, - Closed(StreamCloseFrame), + Closed(StreamClose), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -103,20 +113,85 @@ pub struct SessionFsm { } impl SessionFsm { - pub fn new(config: SessionFsmConfig) -> Self { - Self::new_inner(config) + pub fn new(config: SessionFsmConfig, now: Instant) -> Self { + Self { + config, + state: SessionFsmState { + now, + last_activity_at: now, + last_inbound_at: now, + session_state: SessionState::Open, + next_stream_ordinal: 1, + next_seq: SessionSeq(1), + tx_ring: ring::SeqRing::new(SessionSeq(1)), + rx_ring: ring::SeqRing::new(SessionSeq(1)), + ack_state: AckState::Idle, + pending_control: Default::default(), + streams: Default::default(), + next_stream_index: 0, + events: Default::default(), + }, + } } pub fn open_stream(&mut self) -> Result { - self.open_stream_inner() + self.ensure_session_open()?; + let stream_id = + StreamId(self.config.local_namespace.bit() | self.state.next_stream_ordinal); + self.state.next_stream_ordinal = self.state.next_stream_ordinal.saturating_add(1); + self.state + .streams + .insert(stream_id, StreamState::new(StreamRole::Initiator)); + Ok(stream_id) } pub fn write_stream(&mut self, stream_id: StreamId, bytes: Vec) -> Result<(), StreamError> { - self.write_stream_inner(stream_id, bytes) + self.ensure_session_open()?; + if bytes.is_empty() { + return Ok(()); + } + + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; + if !stream.is_writable() { + return Err(StreamError::NotWritable); + } + + let frame = StreamChunk { + stream_id, + offset: stream.next_send_offset, + bytes, + fin: false, + }; + stream.next_send_offset += frame.bytes.len() as u64; + stream.send_queue.push_back(PendingStreamBody::Chunk(frame)); + Ok(()) } pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { - self.finish_stream_inner(stream_id) + self.ensure_session_open()?; + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; + if !stream.is_writable() { + return Err(StreamError::NotWritable); + } + + stream.outbound_finished = true; + stream + .send_queue + .push_back(PendingStreamBody::Chunk(StreamChunk { + stream_id, + offset: stream.next_send_offset, + bytes: Vec::new(), + fin: true, + })); + Ok(()) } pub fn close_stream( @@ -126,49 +201,632 @@ impl SessionFsm { code: CloseCode, payload: Vec, ) -> Result<(), StreamError> { - self.close_stream_inner(stream_id, target, code, payload) + self.ensure_session_open()?; + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; + + Self::apply_close_to_stream(stream, target); + stream + .send_queue + .push_back(PendingStreamBody::Close(StreamClose { + stream_id, + target, + code, + payload, + })); + Ok(()) } - pub fn queue_heartbeat(&mut self) -> Result<(), StreamError> { - self.queue_heartbeat_inner() + pub fn queue_ping(&mut self) -> Result<(), StreamError> { + self.ensure_session_open()?; + self.state.pending_control.ping = true; + Ok(()) } pub fn queue_unpair(&mut self) -> Result<(), StreamError> { - self.queue_unpair_inner() + self.ensure_session_open()?; + self.state.pending_control.unpair = true; + Ok(()) } - pub fn close_session(&mut self, code: CloseCode) { - self.close_session_inner(code); + pub fn receive(&mut self, now: Instant, envelope: SessionEnvelope) { + self.state.now = now; + self.collect_timeouts(); + self.process_ack(envelope.ack); + + if self.state.session_state == SessionState::Closed { + return; + } + + self.state.last_activity_at = self.state.now; + self.state.last_inbound_at = self.state.now; + + let seq = envelope.seq; + if seq.0 < self.state.rx_ring.base_seq().0 || self.state.rx_ring.contains_key(&seq) { + if !matches!(envelope.body, SessionBody::Ack) { + self.schedule_ack(true); + } + return; + } + match self.state.rx_ring.insert(seq, ()) { + Ok(()) => { + let out_of_order = seq != self.state.rx_ring.base_seq(); + self.state.rx_ring.advance_occupied_front(); + if !matches!(envelope.body, SessionBody::Ack) { + self.schedule_ack(out_of_order); + } + } + Err(SeqRingInsertError::OutOfWindow) => { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + } + Err(SeqRingInsertError::Occupied) => { + if !matches!(envelope.body, SessionBody::Ack) { + self.schedule_ack(true); + } + return; + } + } + + match envelope.body { + SessionBody::Ack => {} + SessionBody::Ping(_) => {} + SessionBody::Unpair(_) => { + self.state.session_state = SessionState::Closed; + self.clear_streams(); + self.state.events.push_back(SessionEvent::Unpaired); + } + SessionBody::Close(close) => { + self.state.session_state = SessionState::Closed; + self.clear_streams(); + self.state + .events + .push_back(SessionEvent::SessionClosed(close)); + } + SessionBody::Stream(frame) => self.handle_stream_frame(frame), + SessionBody::StreamClose(frame) => self.handle_stream_close(frame), + } } - pub fn receive(&mut self, now: Instant, envelope: SessionEnvelope) { + pub fn take_next_write(&mut self, now: Instant) -> Option { self.state.now = now; - self.receive_inner(envelope); + self.collect_timeouts(); + let ack = self.state.current_ack(); + if let Some(seq) = self + .state + .tx_ring + .iter() + .find_map(|(seq, entry)| matches!(entry.state, TxState::Pending).then_some(seq)) + { + let Some(entry) = self.state.tx_ring.get_mut(&seq) else { + return None; + }; + entry.state = TxState::Issued; + return Some(SessionEnvelope { + seq, + ack, + body: entry.pending.body.clone(), + }); + } + + if !self.state.tx_ring.accepts_seq(self.state.next_seq) { + return None; + } + + let pending = self.next_pending_body()?; + let seq = self.state.next_seq; + self.state.next_seq = SessionSeq(seq.0 + 1); + let body = pending.body.clone(); + self.state + .tx_ring + .insert( + seq, + TxEntry { + pending, + state: TxState::Issued, + }, + ) + .unwrap(); + + Some(SessionEnvelope { seq, ack, body }) } - pub fn next_outbound(&mut self, now: Instant) -> Option { + pub fn confirm_write(&mut self, now: Instant, seq: SessionSeq) { self.state.now = now; - self.next_outbound_inner() + let Some((retransmit, should_clear_ack)) = self.state.tx_ring.get(&seq).map(|entry| { + ( + entry.pending.retransmit, + matches!(entry.pending.body, SessionBody::Ack), + ) + }) else { + return; + }; + debug_assert!(matches!( + self.state.tx_ring.get(&seq).map(|entry| entry.state), + Some(TxState::Issued) + )); + if !matches!( + self.state.tx_ring.get(&seq).map(|entry| entry.state), + Some(TxState::Issued) + ) { + return; + } + + self.state.last_activity_at = self.state.now; + if retransmit { + if let Some(entry) = self.state.tx_ring.get_mut(&seq) { + entry.state = TxState::Sent { + sent_at: self.state.now, + }; + } + } else { + let _ = self.state.tx_ring.remove(&seq); + self.state + .tx_ring + .advance_empty_front_until(self.state.next_seq); + if should_clear_ack { + self.state.clear_ack_schedule(); + } + } + } + + pub fn return_write(&mut self, seq: SessionSeq) { + debug_assert!(matches!( + self.state.tx_ring.get(&seq).map(|entry| entry.state), + Some(TxState::Issued) + )); + let Some(entry) = self.state.tx_ring.get_mut(&seq) else { + return; + }; + if !matches!(entry.state, TxState::Issued) { + return; + } + entry.state = TxState::Pending; + } + + #[cfg(test)] + pub fn next_outbound(&mut self, now: Instant) -> Option { + let envelope = self.take_next_write(now)?; + self.confirm_write(now, envelope.seq); + Some(envelope) } pub fn on_timer(&mut self, now: Instant) { self.state.now = now; - self.on_timer_inner(); + self.collect_timeouts(); + if self.state.session_state == SessionState::Closed { + return; + } + if let AckState::Delayed { due_at } = self.state.ack_state { + if due_at <= self.state.now { + self.state.ack_state = AckState::Immediate; + } + } + if !self.config.peer_timeout.is_zero() + && self.state.last_inbound_at + self.config.peer_timeout <= self.state.now + { + self.fail_session(SessionCloseBody { + code: CloseCode::TIMEOUT, + }); + return; + } + if !self.config.keepalive_interval.is_zero() + && self.state.last_activity_at + self.config.keepalive_interval <= self.state.now + { + self.state.pending_control.ping = true; + } } pub fn next_deadline(&self) -> Option { - self.next_deadline_inner() + let ack_deadline = match self.state.ack_state { + AckState::Idle => None, + AckState::Immediate => Some(self.state.now), + AckState::Delayed { due_at } => Some(due_at), + }; + let retransmit_deadline = self + .state + .tx_ring + .iter() + .filter_map(|(_, entry)| match entry.state { + TxState::Sent { sent_at } => Some(sent_at + self.config.retransmit_timeout), + TxState::Pending | TxState::Issued => None, + }) + .min(); + let keepalive_deadline = (self.state.session_state == SessionState::Open + && !self.config.keepalive_interval.is_zero() + && !self.state.pending_control.ping) + .then_some(self.state.last_activity_at + self.config.keepalive_interval); + let peer_timeout_deadline = (self.state.session_state == SessionState::Open + && !self.config.peer_timeout.is_zero()) + .then_some(self.state.last_inbound_at + self.config.peer_timeout); + [ + ack_deadline, + retransmit_deadline, + keepalive_deadline, + peer_timeout_deadline, + ] + .into_iter() + .flatten() + .min() } pub fn take_next_event(&mut self) -> Option { - self.take_next_event_inner() + self.state.events.pop_front() } pub fn take_next_inbound(&mut self, stream_id: StreamId) -> Option { - self.take_next_inbound_inner(stream_id) + self.state + .streams + .get_mut(&stream_id) + .and_then(|stream| stream.inbound_queue.pop_front()) } + #[cfg(test)] pub fn session_state(&self) -> SessionState { - self.session_state_inner() + self.state.session_state + } + + pub fn has_pending_stream_work(&self) -> bool { + self.state + .streams + .values() + .any(|stream| !stream.send_queue.is_empty()) + } + + fn next_pending_body(&mut self) -> Option { + if let Some(close) = self.state.pending_control.close.take() { + return Some(PendingSessionBody { + body: SessionBody::Close(close), + retransmit: true, + }); + } + if self.state.pending_control.unpair { + self.state.pending_control.unpair = false; + return Some(PendingSessionBody { + body: SessionBody::Unpair(UnpairBody), + retransmit: true, + }); + } + if self.state.pending_control.ping { + self.state.pending_control.ping = false; + return Some(PendingSessionBody { + body: SessionBody::Ping(PingBody), + retransmit: false, + }); + } + + let len = self.state.streams.len(); + if len > 0 { + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let has_pending = self + .state + .streams + .get_index(index) + .is_some_and(|(_, stream)| !stream.send_queue.is_empty()); + if !has_pending { + continue; + } + + let item = { + let Some((_, stream)) = self.state.streams.get_index_mut(index) else { + continue; + }; + let Some(item) = stream.send_queue.pop_front() else { + continue; + }; + item + }; + self.state.next_stream_index = (index + 1) % len; + return Some(PendingSessionBody { + body: item.to_session_body(), + retransmit: true, + }); + } + } + + let ack_due = match self.state.ack_state { + AckState::Immediate => true, + AckState::Delayed { due_at } => due_at <= self.state.now, + AckState::Idle => false, + }; + ack_due.then_some(PendingSessionBody { + body: SessionBody::Ack, + retransmit: false, + }) + } + + fn ensure_session_open(&self) -> Result<(), StreamError> { + if self.state.session_state == SessionState::Closed { + Err(StreamError::SessionClosed) + } else { + Ok(()) + } + } + + fn process_ack(&mut self, ack: ql_wire::SessionAck) { + let acked: Vec<_> = self + .state + .tx_ring + .iter() + .filter_map(|(seq, entry)| { + (matches!(entry.state, TxState::Sent { .. }) && Self::ack_covers(ack, seq)) + .then_some(seq) + }) + .collect(); + for seq in acked { + let _ = self.state.tx_ring.remove(&seq); + } + self.state + .tx_ring + .advance_empty_front_until(self.state.next_seq); + } + + fn ack_covers(ack: ql_wire::SessionAck, seq: SessionSeq) -> bool { + if seq.0 <= ack.base.0 { + return true; + } + let delta = seq.0 - ack.base.0; + if delta == 0 || delta > 64 { + return false; + } + (ack.bitmap & (1u64 << (delta - 1))) != 0 + } + + fn schedule_ack(&mut self, immediate: bool) { + self.state.ack_state = match self.state.ack_state { + AckState::Immediate => AckState::Immediate, + _ if immediate || self.config.ack_delay.is_zero() => AckState::Immediate, + AckState::Delayed { due_at } => AckState::Delayed { due_at }, + AckState::Idle => AckState::Delayed { + due_at: self.state.now + self.config.ack_delay, + }, + }; + } + + fn collect_timeouts(&mut self) { + let expired: Vec<_> = self + .state + .tx_ring + .iter() + .filter_map(|(seq, entry)| match entry.state { + TxState::Sent { sent_at } + if sent_at + self.config.retransmit_timeout <= self.state.now => + { + Some(seq) + } + TxState::Pending | TxState::Issued | TxState::Sent { .. } => None, + }) + .collect(); + + for seq in expired { + if let Some(entry) = self.state.tx_ring.remove(&seq) { + if entry.pending.retransmit { + self.requeue_pending_front(entry.pending); + } + } + } + + self.state + .tx_ring + .advance_empty_front_until(self.state.next_seq); + } + + fn requeue_pending_front(&mut self, pending: PendingSessionBody) { + match pending.body { + SessionBody::Stream(frame) => { + if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { + stream + .send_queue + .push_front(PendingStreamBody::Chunk(frame)); + } + } + SessionBody::StreamClose(frame) => { + if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { + stream + .send_queue + .push_front(PendingStreamBody::Close(frame)); + } + } + body => match body { + SessionBody::Ack => {} + SessionBody::Ping(_) => self.state.pending_control.ping = true, + SessionBody::Unpair(_) => self.state.pending_control.unpair = true, + SessionBody::Close(close) => self.state.pending_control.close = Some(close), + SessionBody::Stream(_) | SessionBody::StreamClose(_) => unreachable!(), + }, + } + } + + fn handle_stream_frame(&mut self, frame: StreamChunk) { + let stream_id = frame.stream_id; + let remote_namespace = self.config.local_namespace.remote(); + if !self.state.streams.contains_key(&stream_id) { + if !remote_namespace.matches(stream_id) { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + } + self.state + .streams + .insert(stream_id, StreamState::new(StreamRole::Responder)); + self.state.events.push_back(SessionEvent::Opened(stream_id)); + } + + let Some(stream) = self.state.streams.get_mut(&stream_id) else { + return; + }; + if stream.inbound_discarding { + return; + } + if stream.inbound_closed || stream.inbound_finished { + if frame.offset + frame.bytes.len() as u64 <= stream.next_recv_offset { + return; + } + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + } + + if frame.offset < stream.next_recv_offset { + let frame_end = frame.offset + frame.bytes.len() as u64; + if frame_end <= stream.next_recv_offset { + return; + } + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + } + + if frame.offset == stream.next_recv_offset { + Self::commit_inbound_frame(stream, frame); + Self::drain_pending_recv(stream); + self.state + .events + .push_back(SessionEvent::Readable(stream_id)); + return; + } + + if Self::insert_pending_chunk( + stream, + frame.offset, + PendingChunk { + bytes: frame.bytes, + fin: frame.fin, + }, + ) + .is_err() + { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + } + } + + fn handle_stream_close(&mut self, frame: StreamClose) { + let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; + }; + + if Self::target_affects_inbound(stream.role, frame.target) && !stream.inbound_closed { + stream.inbound_closed = true; + stream.inbound_discarding = false; + stream.pending_recv.clear(); + stream + .inbound_queue + .push_back(StreamIncoming::Closed(frame.clone())); + self.state + .events + .push_back(SessionEvent::Readable(frame.stream_id)); + } + if Self::target_affects_outbound(stream.role, frame.target) && !stream.outbound_closed { + stream.outbound_closed = true; + stream.send_queue.clear(); + self.state + .events + .push_back(SessionEvent::WritableClosed(frame.stream_id)); + } + } + + fn apply_close_to_stream(stream: &mut StreamState, target: CloseTarget) { + if Self::target_affects_inbound(stream.role, target) { + stream.inbound_discarding = true; + stream.pending_recv.clear(); + } + if Self::target_affects_outbound(stream.role, target) { + stream.outbound_closed = true; + stream.outbound_finished = true; + stream.send_queue.clear(); + } + } + + fn target_affects_inbound(role: StreamRole, target: CloseTarget) -> bool { + matches!(target, CloseTarget::Both) || role.inbound_target() == target + } + + fn target_affects_outbound(role: StreamRole, target: CloseTarget) -> bool { + matches!(target, CloseTarget::Both) || role.outbound_target() == target + } + + fn commit_inbound_frame(stream: &mut StreamState, frame: StreamChunk) { + Self::commit_inbound_chunk(stream, frame.bytes, frame.fin); + } + + fn commit_inbound_chunk(stream: &mut StreamState, bytes: Vec, fin: bool) { + stream.next_recv_offset += bytes.len() as u64; + if !bytes.is_empty() { + stream.inbound_queue.push_back(StreamIncoming::Data(bytes)); + } + if fin { + stream.inbound_finished = true; + stream.inbound_queue.push_back(StreamIncoming::Finished); + } + } + + fn drain_pending_recv(stream: &mut StreamState) { + while let Some(chunk) = stream.pending_recv.remove(&stream.next_recv_offset) { + Self::commit_inbound_chunk(stream, chunk.bytes, chunk.fin); + if stream.inbound_finished { + break; + } + } + } + + fn insert_pending_chunk( + stream: &mut StreamState, + offset: u64, + chunk: PendingChunk, + ) -> Result<(), ()> { + let end = chunk.end_offset(offset); + + if let Some((&prev_offset, prev)) = stream.pending_recv.range(..=offset).next_back() { + let prev_end = prev.end_offset(prev_offset); + if prev_end > offset { + if prev_offset == offset && prev.bytes == chunk.bytes && prev.fin == chunk.fin { + return Ok(()); + } + return Err(()); + } + } + + if let Some((&next_offset, _)) = stream.pending_recv.range(offset..).next() { + if end > next_offset { + return Err(()); + } + } + + stream.pending_recv.insert(offset, chunk); + Ok(()) + } + + fn fail_session(&mut self, close: SessionCloseBody) { + if self.state.session_state == SessionState::Closed { + return; + } + + self.state.session_state = SessionState::Closed; + self.clear_streams(); + self.state.pending_control = Default::default(); + self.state.pending_control.close = Some(close.clone()); + self.state + .events + .push_back(SessionEvent::SessionClosed(close)); + } + + fn clear_streams(&mut self) { + self.state.next_stream_index = 0; + self.state.streams.clear(); } } diff --git a/ql-fsm/src/session/ring.rs b/ql-fsm/src/session/ring.rs index 872c92a5..b6aad72b 100644 --- a/ql-fsm/src/session/ring.rs +++ b/ql-fsm/src/session/ring.rs @@ -43,6 +43,11 @@ impl SeqRing { self.slots[index].as_ref() } + pub fn get_mut(&mut self, seq: &SessionSeq) -> Option<&mut T> { + let index = self.index_for(*seq)?; + self.slots[index].as_mut() + } + pub fn insert(&mut self, seq: SessionSeq, value: T) -> Result<(), SeqRingInsertError> { let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; if self.slots[index].is_some() { @@ -139,3 +144,54 @@ impl<'a, const N: usize, T> Iterator for SeqRingIter<'a, N, T> { None } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn insert_iter_and_bitmap() { + let mut ring = SeqRing::<4, u8>::new(SessionSeq(10)); + + ring.insert(SessionSeq(10), 1).unwrap(); + ring.insert(SessionSeq(12), 3).unwrap(); + + assert_eq!(ring.bitmap(), 0b0101); + assert_eq!( + ring.iter() + .map(|(seq, value)| (seq, *value)) + .collect::>(), + vec![(SessionSeq(10), 1), (SessionSeq(12), 3)] + ); + } + + #[test] + fn advance_fronts() { + let mut ring = SeqRing::<4, u8>::new(SessionSeq(10)); + + ring.insert(SessionSeq(11), 2).unwrap(); + ring.advance_empty_front_until(SessionSeq(11)); + assert_eq!(ring.base_seq(), SessionSeq(11)); + assert_eq!(ring.get(&SessionSeq(11)), Some(&2)); + + ring.advance_occupied_front(); + assert_eq!(ring.base_seq(), SessionSeq(12)); + assert!(ring.get(&SessionSeq(11)).is_none()); + } + + #[test] + fn insert_errors() { + let mut ring = SeqRing::<2, u8>::new(SessionSeq(5)); + + ring.insert(SessionSeq(5), 1).unwrap(); + + assert_eq!( + ring.insert(SessionSeq(5), 2), + Err(SeqRingInsertError::Occupied) + ); + assert_eq!( + ring.insert(SessionSeq(7), 3), + Err(SeqRingInsertError::OutOfWindow) + ); + } +} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index f922a8b0..d08e7758 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -1,15 +1,15 @@ use std::{ - collections::{BTreeMap, HashMap, VecDeque}, + collections::{BTreeMap, VecDeque}, time::Instant, }; +use indexmap::IndexMap; use ql_wire::{ - CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamCloseFrame, - StreamFrame, StreamId, + CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamChunk, StreamClose, + StreamId, }; -use super::ring::SeqRing; -use super::{SessionEvent, SessionState, StreamIncoming}; +use super::{ring::SeqRing, SessionEvent, SessionState, StreamIncoming}; pub const SESSION_WINDOW_CAPACITY: usize = 64; @@ -49,15 +49,15 @@ impl PendingChunk { #[derive(Debug, Clone)] pub enum PendingStreamBody { - Stream(StreamFrame), - StreamClose(StreamCloseFrame), + Chunk(StreamChunk), + Close(StreamClose), } impl PendingStreamBody { pub fn to_session_body(&self) -> SessionBody { match self { - Self::Stream(frame) => SessionBody::Stream(frame.clone()), - Self::StreamClose(frame) => SessionBody::StreamClose(frame.clone()), + Self::Chunk(frame) => SessionBody::Stream(frame.clone()), + Self::Close(frame) => SessionBody::StreamClose(frame.clone()), } } } @@ -75,7 +75,6 @@ pub struct StreamState { pub inbound_finished: bool, pub inbound_closed: bool, pub inbound_discarding: bool, - pub ready_enqueued: bool, } impl StreamState { @@ -92,7 +91,6 @@ impl StreamState { inbound_finished: false, inbound_closed: false, inbound_discarding: false, - ready_enqueued: false, } } @@ -104,13 +102,13 @@ impl StreamState { #[derive(Debug, Clone)] pub struct PendingSessionBody { pub body: SessionBody, + /// whether the body should be retransmitted after a confirmed send times out without ack pub retransmit: bool, - pub priority: bool, } #[derive(Debug, Clone, Default)] pub struct PendingSessionControl { - pub heartbeat: bool, + pub ping: bool, pub unpair: bool, pub close: Option, } @@ -118,7 +116,14 @@ pub struct PendingSessionControl { #[derive(Debug, Clone)] pub struct TxEntry { pub pending: PendingSessionBody, - pub sent_at: Instant, + pub state: TxState, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TxState { + Pending, + Issued, + Sent { sent_at: Instant }, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -130,6 +135,8 @@ pub enum AckState { pub struct SessionFsmState { pub now: Instant, + pub last_activity_at: Instant, + pub last_inbound_at: Instant, pub session_state: SessionState, pub next_stream_ordinal: u32, pub next_seq: SessionSeq, @@ -137,8 +144,10 @@ pub struct SessionFsmState { pub rx_ring: SeqRing, pub ack_state: AckState, pub pending_control: PendingSessionControl, - pub streams: HashMap, - pub ready_streams: VecDeque, + /// `IndexMap` has stable (and fast) iteration order for round-robin + /// scheduling, so we do not need a separate ready queue + pub streams: IndexMap, + pub next_stream_index: usize, pub events: VecDeque, } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index e1dbe3cb..a6a050c0 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -1,24 +1,32 @@ use std::time::{Duration, Instant}; use ql_wire::{ - encrypted::heartbeat::HeartbeatBody, CloseCode, CloseTarget, SessionAck, SessionBody, - SessionEnvelope, SessionSeq, StreamFrame, + CloseCode, CloseTarget, PingBody, SessionAck, SessionBody, SessionEnvelope, SessionSeq, + StreamChunk, StreamClose, }; use super::{SessionFsm, SessionFsmConfig, SessionState}; -fn heartbeat(seq: u64, ack: SessionAck) -> SessionEnvelope { +fn ack(seq: u64, ack: SessionAck) -> SessionEnvelope { SessionEnvelope { seq: SessionSeq(seq), ack, - body: SessionBody::Heartbeat(HeartbeatBody), + body: SessionBody::Ack, + } +} + +fn ping(seq: u64, ack: SessionAck) -> SessionEnvelope { + SessionEnvelope { + seq: SessionSeq(seq), + ack, + body: SessionBody::Ping(PingBody), } } #[test] fn outbound_session_seq_increments_monotonically() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); @@ -34,7 +42,7 @@ fn outbound_session_seq_increments_monotonically() { #[test] fn inbound_ack_removes_acked_tx_entries() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); @@ -44,7 +52,7 @@ fn inbound_ack_removes_acked_tx_entries() { fsm.receive( now + Duration::from_millis(1), - heartbeat( + ack( 1, SessionAck { base: SessionSeq(1), @@ -59,9 +67,23 @@ fn inbound_ack_removes_acked_tx_entries() { #[test] fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id_a = ql_wire::StreamId(super::StreamNamespace::High.bit() | 1); + let stream_id_b = ql_wire::StreamId(super::StreamNamespace::High.bit() | 2); - fsm.receive(now, heartbeat(2, SessionAck::EMPTY)); + fsm.receive( + now, + SessionEnvelope { + seq: SessionSeq(2), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id: stream_id_a, + offset: 0, + bytes: b"a".to_vec(), + fin: false, + }), + }, + ); let gap_ack = fsm.next_outbound(now).unwrap(); assert_eq!(gap_ack.seq, SessionSeq(1)); assert_eq!( @@ -74,7 +96,16 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { fsm.receive( now + Duration::from_millis(1), - heartbeat(1, SessionAck::EMPTY), + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id: stream_id_b, + offset: 0, + bytes: b"b".to_vec(), + fin: false, + }), + }, ); let contiguous_ack = fsm.next_outbound(now + Duration::from_millis(10)).unwrap(); assert_eq!(contiguous_ack.seq, SessionSeq(2)); @@ -90,7 +121,7 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { #[test] fn retransmit_requeues_body_with_new_session_seq() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); fsm.write_stream(stream_id, b"retry-me".to_vec()).unwrap(); @@ -107,10 +138,10 @@ fn retransmit_requeues_body_with_new_session_seq() { #[test] fn repeated_outbound_messages_keep_reporting_latest_receive_ack() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); - fsm.receive(now, heartbeat(1, SessionAck::EMPTY)); + fsm.receive(now, ack(1, SessionAck::EMPTY)); fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); let first = fsm.next_outbound(now).unwrap(); @@ -127,7 +158,7 @@ fn repeated_outbound_messages_keep_reporting_latest_receive_ack() { #[test] fn local_inbound_close_ignores_late_remote_bytes() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default()); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); fsm.close_stream( @@ -143,7 +174,7 @@ fn local_inbound_close_ignores_late_remote_bytes() { SessionEnvelope { seq: SessionSeq(1), ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamFrame { + body: SessionBody::Stream(StreamChunk { stream_id, offset: 0, bytes: b"late".to_vec(), @@ -156,3 +187,298 @@ fn local_inbound_close_ignores_late_remote_bytes() { assert!(fsm.take_next_inbound(stream_id).is_none()); assert!(fsm.take_next_event().is_none()); } + +#[test] +fn out_of_order_remote_stream_buffers_until_initial_bytes_arrive() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 7); + + fsm.receive( + now, + SessionEnvelope { + seq: SessionSeq(2), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + offset: 1, + bytes: b"b".to_vec(), + fin: false, + }), + }, + ); + + assert_eq!(fsm.session_state(), SessionState::Open); + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Opened(stream_id)) + ); + assert!(fsm.take_next_inbound(stream_id).is_none()); + + fsm.receive( + now + Duration::from_millis(1), + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + offset: 0, + bytes: b"a".to_vec(), + fin: false, + }), + }, + ); + + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Readable(stream_id)) + ); + assert_eq!( + fsm.take_next_inbound(stream_id), + Some(super::StreamIncoming::Data(b"a".to_vec())) + ); + assert_eq!( + fsm.take_next_inbound(stream_id), + Some(super::StreamIncoming::Data(b"b".to_vec())) + ); +} + +#[test] +fn duplicate_committed_data_is_not_redelivered() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 9); + let body = SessionBody::Stream(StreamChunk { + stream_id, + offset: 0, + bytes: b"dup".to_vec(), + fin: false, + }); + + fsm.receive( + now, + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: body.clone(), + }, + ); + let _ = fsm.take_next_event(); + let _ = fsm.take_next_event(); + let _ = fsm.take_next_inbound(stream_id); + + fsm.receive( + now + Duration::from_millis(1), + SessionEnvelope { + seq: SessionSeq(2), + ack: SessionAck::EMPTY, + body, + }, + ); + + assert!(fsm.take_next_event().is_none()); + assert!(fsm.take_next_inbound(stream_id).is_none()); +} + +#[test] +fn next_outbound_round_robins_across_ready_streams() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id_a = fsm.open_stream().unwrap(); + let stream_id_b = fsm.open_stream().unwrap(); + + fsm.write_stream(stream_id_a, b"a-1".to_vec()).unwrap(); + fsm.write_stream(stream_id_b, b"b-1".to_vec()).unwrap(); + fsm.write_stream(stream_id_a, b"a-2".to_vec()).unwrap(); + fsm.write_stream(stream_id_b, b"b-2".to_vec()).unwrap(); + + let scheduled: Vec<_> = (0..4) + .map(|_| match fsm.next_outbound(now).unwrap().body { + SessionBody::Stream(frame) => frame.stream_id, + other => panic!("expected stream frame, got {other:?}"), + }) + .collect(); + + assert_eq!( + scheduled, + vec![stream_id_a, stream_id_b, stream_id_a, stream_id_b] + ); +} + +#[test] +fn idle_session_sends_ping_after_keepalive_interval() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionFsmConfig { + keepalive_interval: Duration::from_millis(50), + ..SessionFsmConfig::default() + }, + now, + ); + + assert_eq!(fsm.next_deadline(), Some(now + Duration::from_millis(50))); + assert!(fsm.next_outbound(now + Duration::from_millis(49)).is_none()); + fsm.on_timer(now + Duration::from_millis(50)); + + let envelope = fsm.next_outbound(now + Duration::from_millis(50)).unwrap(); + assert!(matches!(envelope.body, SessionBody::Ping(PingBody))); +} + +#[test] +fn receive_ping_schedules_ack_without_ping_pong() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + + fsm.receive(now, ping(1, SessionAck::EMPTY)); + + let ack_envelope = fsm.next_outbound(now + Duration::from_millis(10)).unwrap(); + assert_eq!(ack_envelope.body, SessionBody::Ack); + + fsm.receive(now + Duration::from_millis(20), ack(2, SessionAck::EMPTY)); + assert!(fsm.next_outbound(now + Duration::from_millis(30)).is_none()); +} + +#[test] +fn tx_selective_ack_keeps_front_gap_pinned() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = fsm.open_stream().unwrap(); + + for byte in 0..64u8 { + fsm.write_stream(stream_id, vec![byte]).unwrap(); + let _ = fsm + .next_outbound(now + Duration::from_millis(byte as u64)) + .unwrap(); + } + + fsm.receive( + now + Duration::from_millis(100), + ack( + 1, + SessionAck { + base: SessionSeq(0), + bitmap: u64::MAX ^ 1, + }, + ), + ); + + assert!(fsm.state.tx_ring.contains_key(&SessionSeq(1))); + assert!(!fsm.state.tx_ring.contains_key(&SessionSeq(2))); + + fsm.write_stream(stream_id, b"x".to_vec()).unwrap(); + assert!(fsm + .next_outbound(now + Duration::from_millis(101)) + .is_none()); + + fsm.receive( + now + Duration::from_millis(102), + ack( + 2, + SessionAck { + base: SessionSeq(1), + bitmap: 0, + }, + ), + ); + + assert_eq!( + fsm.next_outbound(now + Duration::from_millis(103)) + .unwrap() + .seq, + SessionSeq(65) + ); +} + +#[test] +fn rx_seq_past_window_closes_protocol() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + + fsm.receive(now, ping(65, SessionAck::EMPTY)); + + assert_eq!(fsm.session_state(), SessionState::Closed); + assert!(matches!( + fsm.take_next_event(), + Some(super::SessionEvent::SessionClosed(close)) if close.code == CloseCode::PROTOCOL + )); +} + +#[test] +fn duplicate_old_packet_seq_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 11); + let body = SessionBody::Stream(StreamChunk { + stream_id, + offset: 0, + bytes: b"x".to_vec(), + fin: false, + }); + + fsm.receive( + now, + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: body.clone(), + }, + ); + let _ = fsm.take_next_event(); + let _ = fsm.take_next_event(); + let _ = fsm.take_next_inbound(stream_id); + + fsm.receive( + now + Duration::from_millis(1), + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body, + }, + ); + + assert!(fsm.take_next_event().is_none()); + assert!(fsm.take_next_inbound(stream_id).is_none()); +} + +#[test] +fn retransmitted_stream_close_is_idempotent() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = fsm.open_stream().unwrap(); + let frame = StreamClose { + stream_id, + target: CloseTarget::Response, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }; + + fsm.receive( + now, + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: SessionBody::StreamClose(frame.clone()), + }, + ); + + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Readable(stream_id)) + ); + assert_eq!( + fsm.take_next_inbound(stream_id), + Some(super::StreamIncoming::Closed(frame.clone())) + ); + + fsm.receive( + now + Duration::from_millis(1), + SessionEnvelope { + seq: SessionSeq(2), + ack: SessionAck::EMPTY, + body: SessionBody::StreamClose(frame), + }, + ); + + assert!(fsm.take_next_event().is_none()); + assert!(fsm.take_next_inbound(stream_id).is_none()); +} diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 5e9b108d..e09d821c 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -1,69 +1,20 @@ -use std::{ - collections::VecDeque, - time::{Duration, Instant}, -}; +use std::{collections::VecDeque, time::Instant}; -use bc_components::{MLDSAPublicKey, MLKEMPublicKey, SymmetricKey}; -use ql_wire::{ - handshake::{Confirm, Hello, HelloReply, Ready, ResponderSecrets}, - QlIdentity, QlRecord, WireError, XID, -}; -use thiserror::Error; +use ql_wire::{Confirm, Hello, HelloReply, QlRecord, Ready, ResponderSecrets, SessionKey}; -use crate::{replay_cache::ReplayCache, FsmTime}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Peer { - pub xid: XID, - pub signing_key: MLDSAPublicKey, - pub encapsulation_key: MLKEMPublicKey, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum PeerStatus { - Disconnected, - Initiator, - Responder, - Connected, -} - -#[derive(Debug, Clone)] -pub enum QlFsmEvent { - PersistPeer(Peer), - ClearPeer, - PeerStatusChanged { peer: XID, status: PeerStatus }, -} - -#[derive(Debug, Clone, Copy)] -pub struct QlFsmConfig { - pub handshake_timeout: Duration, - pub handshake_retry_interval: Duration, - pub max_handshake_retries: u8, - pub control_expiration: Duration, -} - -impl Default for QlFsmConfig { - fn default() -> Self { - Self { - handshake_timeout: Duration::from_secs(5), - handshake_retry_interval: Duration::from_millis(750), - max_handshake_retries: 3, - control_expiration: Duration::from_secs(30), - } - } -} +use crate::{replay_cache::ReplayCache, FsmTime, Peer, PeerStatus, QlFsmEvent, QlSessionEvent}; #[derive(Debug, Clone)] pub enum HandshakeInitiator { WaitingHelloReply { - initiator_secret: SymmetricKey, + initiator_secret: SessionKey, retry_count: u8, retry_at: Option, }, WaitingReady { reply: HelloReply, confirm: Confirm, - session_key: SymmetricKey, + session_key: SessionKey, retry_count: u8, retry_at: Option, }, @@ -87,7 +38,7 @@ pub struct RecentReady { } #[derive(Debug, Clone)] -pub enum PeerSession { +pub enum ConnectionState { Disconnected, Initiator { hello: Hello, @@ -101,12 +52,12 @@ pub enum PeerSession { stage: HandshakeResponder, }, Connected { - session_key: SymmetricKey, + session_key: SessionKey, recent_ready: Option, }, } -impl PeerSession { +impl ConnectionState { pub fn status(&self) -> PeerStatus { match self { Self::Disconnected => PeerStatus::Disconnected, @@ -116,7 +67,7 @@ impl PeerSession { } } - pub fn session_key(&self) -> Option<&SymmetricKey> { + pub fn session_key(&self) -> Option<&SessionKey> { match self { Self::Connected { session_key, .. } => Some(session_key), _ => None, @@ -127,51 +78,23 @@ impl PeerSession { #[derive(Debug, Clone)] pub struct PeerRecord { pub peer: Peer, - pub session: PeerSession, + pub session: ConnectionState, } impl PeerRecord { pub fn new(peer: Peer) -> Self { Self { peer, - session: PeerSession::Disconnected, + session: ConnectionState::Disconnected, } } } -#[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum QlFsmError { - #[error("invalid payload")] - InvalidPayload, - #[error("invalid signature")] - InvalidSignature, - #[error("expired")] - Expired, - #[error("no peer bound")] - NoPeerBound, -} - -impl From for QlFsmError { - fn from(value: WireError) -> Self { - match value { - WireError::InvalidPayload => Self::InvalidPayload, - WireError::InvalidSignature => Self::InvalidSignature, - WireError::Expired => Self::Expired, - } - } -} - -pub struct QlFsm { - pub config: QlFsmConfig, - pub identity: QlIdentity, - pub peer: Option, - pub state: QlFsmState, -} - pub struct QlFsmState { pub replay_cache: ReplayCache, pub next_control_id: u32, pub outbound: VecDeque, pub events: VecDeque, + pub session_events: VecDeque, pub now: FsmTime, } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs new file mode 100644 index 00000000..2e573572 --- /dev/null +++ b/ql-fsm/src/tests/handshake.rs @@ -0,0 +1,315 @@ +use std::time::Duration; + +use ql_wire::QlPayload; + +use super::*; +use crate::state::{ConnectionState, HandshakeInitiator, HandshakeResponder}; + +#[test] +fn handshake_deadline_is_derived_from_peer_state() { + let config = QlFsmConfig { + handshake_timeout: Duration::from_secs(5), + handshake_retry_interval: Duration::from_secs(10), + max_handshake_retries: 0, + session_keepalive_interval: Duration::from_millis(1), + session_peer_timeout: Duration::from_millis(2), + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired(config); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + assert_eq!( + harness.a.fsm.next_deadline(), + Some(harness.now + config.handshake_timeout) + ); + + let _hello = harness.next_outbound_a().unwrap(); + harness.advance(Duration::from_secs(4)); + harness.a.fsm.on_timer(harness.time()); + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Initiator { .. }) + )); + + harness.advance(Duration::from_secs(1)); + harness.a.fsm.on_timer(harness.time()); + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Disconnected) + )); +} + +#[test] +fn initiator_retries_hello_after_retry_interval() { + let config = QlFsmConfig { + handshake_retry_interval: Duration::from_millis(250), + max_handshake_retries: 2, + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired(config); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + + harness.advance(config.handshake_retry_interval); + harness.a.fsm.on_timer(harness.time()); + + assert_eq!(harness.next_outbound_a(), Some(hello)); + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Initiator { + stage: HandshakeInitiator::WaitingHelloReply { retry_count: 1, .. }, + .. + }) + )); +} + +#[test] +fn responder_retries_hello_reply_after_retry_interval() { + let config = QlFsmConfig { + handshake_retry_interval: Duration::from_millis(250), + max_handshake_retries: 2, + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired(config); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(hello); + let reply = harness.next_outbound_b().unwrap(); + + harness.advance(config.handshake_retry_interval); + harness.b.fsm.on_timer(harness.time()); + + assert_eq!(harness.next_outbound_b(), Some(reply)); + assert!(matches!( + harness.b.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Responder { + stage: HandshakeResponder::WaitingConfirm { retry_count: 1, .. }, + .. + }) + )); +} + +#[test] +fn initiator_retries_confirm_after_retry_interval() { + let config = QlFsmConfig { + handshake_retry_interval: Duration::from_millis(250), + max_handshake_retries: 2, + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired(config); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(hello); + let reply = harness.next_outbound_b().unwrap(); + harness.deliver_to_a(reply); + let confirm = harness.next_outbound_a().unwrap(); + + harness.advance(config.handshake_retry_interval); + harness.a.fsm.on_timer(harness.time()); + + assert_eq!(harness.next_outbound_a(), Some(confirm)); + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Initiator { + stage: HandshakeInitiator::WaitingReady { retry_count: 1, .. }, + .. + }) + )); +} + +#[test] +fn duplicate_hello_resends_hello_reply() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + + harness.deliver_to_b(hello.clone()); + let reply = harness.next_outbound_b().unwrap(); + + harness.deliver_to_b(hello); + assert_eq!(harness.next_outbound_b(), Some(reply)); +} + +#[test] +fn duplicate_hello_reply_resends_confirm() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(hello); + let reply = harness.next_outbound_b().unwrap(); + + harness.deliver_to_a(reply.clone()); + let confirm = harness.next_outbound_a().unwrap(); + + harness.deliver_to_a(reply); + assert_eq!(harness.next_outbound_a(), Some(confirm)); +} + +#[test] +fn responder_resends_ready_for_duplicate_confirm_after_connecting() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(hello); + let reply = harness.next_outbound_b().unwrap(); + harness.deliver_to_a(reply); + let confirm = harness.next_outbound_a().unwrap(); + + harness.deliver_to_b(confirm.clone()); + let ready = harness.next_outbound_b().unwrap(); + + assert!(matches!( + harness.b.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected { + recent_ready: Some(_), + .. + }) + )); + + harness.deliver_to_b(confirm); + assert_eq!(harness.next_outbound_b(), Some(ready)); +} + +#[test] +fn initiator_waits_for_ready_before_connecting() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(hello); + let reply = harness.next_outbound_b().unwrap(); + harness.deliver_to_a(reply); + + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Initiator { + stage: HandshakeInitiator::WaitingReady { .. }, + .. + }) + )); + let stream_id = harness.a.fsm.open_stream().unwrap(); + harness + .a + .fsm + .write_stream(stream_id, b"queued".to_vec()) + .unwrap(); + + let confirm = harness.next_outbound_a().unwrap(); + assert!(matches!(confirm.payload, QlPayload::Confirm(_))); + harness.deliver_to_b(confirm); + let ready = harness.next_outbound_b().unwrap(); + + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Initiator { + stage: HandshakeInitiator::WaitingReady { .. }, + .. + }) + )); + + harness.deliver_to_a(ready); + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected { .. }) + )); + let record = harness.next_outbound_a().unwrap(); + assert!(matches!(record.payload, QlPayload::Session(_))); +} + +#[test] +fn handshake_retry_limit_disconnects_initiator() { + let config = QlFsmConfig { + handshake_retry_interval: Duration::from_millis(250), + max_handshake_retries: 1, + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired(config); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + + harness.advance(config.handshake_retry_interval); + harness.a.fsm.on_timer(harness.time()); + assert_eq!(harness.next_outbound_a(), Some(hello)); + + harness.advance(config.handshake_retry_interval); + harness.a.fsm.on_timer(harness.time()); + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Disconnected) + )); +} + +#[test] +fn simultaneous_connect_converges_to_connected_peers() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + harness + .b + .fsm + .connect(harness.time(), &harness.b.crypto) + .unwrap(); + + let hello_a = harness.next_outbound_a().unwrap(); + let hello_b = harness.next_outbound_b().unwrap(); + + harness.deliver_to_a(hello_b); + harness.deliver_to_b(hello_a); + harness.pump(); + + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected { .. }) + )); + assert!(matches!( + harness.b.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected { .. }) + )); +} diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs new file mode 100644 index 00000000..28a94bdd --- /dev/null +++ b/ql-fsm/src/tests/mod.rs @@ -0,0 +1,283 @@ +mod handshake; +mod session; + +use std::{ + cell::Cell, + time::{Duration, Instant}, +}; + +use libcrux_aesgcm::AesGcm256Key; +use ql_wire::{ + self, generate_ml_dsa_keypair, generate_ml_kem_keypair, EncryptedMessage, QlCrypto, QlIdentity, + QlPayload, QlRecord, SessionEnvelope, SessionKey, XID, +}; +use sha2::{Digest, Sha256}; + +use crate::{ + session::{SessionFsm, SessionFsmConfig, StreamNamespace}, + state::ConnectionState, + FsmTime, OutboundWrite, Peer, QlFsm, QlFsmConfig, SessionWriteId, +}; + +#[derive(Clone)] +struct TestCrypto { + seed: u8, + counter: Cell, +} + +impl TestCrypto { + fn new(seed: u8) -> Self { + Self { + seed, + counter: Cell::new(0), + } + } +} + +impl QlCrypto for TestCrypto { + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self.seed.wrapping_add(self.counter.get()); + self.counter.set(self.counter.get().wrapping_add(1)); + data.fill(value); + } + + fn hash(&self, parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() + } + + fn encrypt_with_aead( + &self, + key: &SessionKey, + nonce: &ql_wire::Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]> { + let key: AesGcm256Key = (*key.data()).into(); + let plaintext = buffer.to_vec(); + let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; + key.encrypt( + buffer, + (&mut auth).into(), + (&nonce.0).into(), + aad, + &plaintext, + ) + .ok()?; + Some(auth) + } + + fn decrypt_with_aead( + &self, + key: &SessionKey, + nonce: &ql_wire::Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + ) -> bool { + let key: AesGcm256Key = (*key.data()).into(); + let ciphertext = buffer.to_vec(); + key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) + .is_ok() + } +} + +struct Node { + fsm: QlFsm, + crypto: TestCrypto, +} + +struct Harness { + now: Instant, + unix_secs: u64, + a: Node, + b: Node, +} + +impl Harness { + fn paired(config: QlFsmConfig) -> Self { + let identity_a = test_identity(11); + let identity_b = test_identity(73); + let peer_a = peer_from_identity(&identity_b); + let peer_b = peer_from_identity(&identity_a); + let now = Instant::now(); + let time = FsmTime { + instant: now, + unix_secs: 1_700_000_000, + }; + + let mut harness = Self { + now, + unix_secs: time.unix_secs, + a: Node { + fsm: QlFsm::new(config, identity_a, time), + crypto: TestCrypto::new(1), + }, + b: Node { + fsm: QlFsm::new(config, identity_b, time), + crypto: TestCrypto::new(2), + }, + }; + + harness.a.fsm.bind_peer(peer_a); + harness.b.fsm.bind_peer(peer_b); + while harness.a.fsm.take_next_event().is_some() {} + while harness.b.fsm.take_next_event().is_some() {} + + harness + } + + fn connected(config: QlFsmConfig) -> Self { + let mut harness = Self::paired(config); + let session_key = SessionKey::from_data([7; SessionKey::SIZE]); + + harness.a.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected { + session_key: session_key.clone(), + recent_ready: None, + }; + harness.b.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected { + session_key, + recent_ready: None, + }; + harness.a.fsm.session = SessionFsm::new( + SessionFsmConfig { + local_namespace: StreamNamespace::for_local( + harness.a.fsm.identity.xid, + harness.a.fsm.peer.as_ref().unwrap().peer.xid, + ), + ack_delay: config.session_ack_delay, + retransmit_timeout: config.session_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + }, + harness.now, + ); + harness.b.fsm.session = SessionFsm::new( + SessionFsmConfig { + local_namespace: StreamNamespace::for_local( + harness.b.fsm.identity.xid, + harness.b.fsm.peer.as_ref().unwrap().peer.xid, + ), + ack_delay: config.session_ack_delay, + retransmit_timeout: config.session_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + }, + harness.now, + ); + harness + } + + fn time(&self) -> FsmTime { + FsmTime { + instant: self.now, + unix_secs: self.unix_secs, + } + } + + fn advance(&mut self, duration: Duration) { + self.now += duration; + self.unix_secs = self.unix_secs.saturating_add(duration.as_secs()); + } + + fn next_outbound_a(&mut self) -> Option { + let write = self.a.fsm.take_next_write(self.time(), &self.a.crypto)?; + if let Some(id) = write.session_write_id { + self.a.fsm.confirm_session_write(self.time(), id); + } + Some(write.record) + } + + fn next_outbound_b(&mut self) -> Option { + let write = self.b.fsm.take_next_write(self.time(), &self.b.crypto)?; + if let Some(id) = write.session_write_id { + self.b.fsm.confirm_session_write(self.time(), id); + } + Some(write.record) + } + + fn next_write_a(&mut self) -> Option { + self.a.fsm.take_next_write(self.time(), &self.a.crypto) + } + + fn deliver_to_a(&mut self, record: QlRecord) { + self.a + .fsm + .receive(self.time(), record.encode(), &self.a.crypto) + .unwrap(); + } + + fn deliver_to_b(&mut self, record: QlRecord) { + self.b + .fsm + .receive(self.time(), record.encode(), &self.b.crypto) + .unwrap(); + } + + fn confirm_write_a(&mut self, write_id: SessionWriteId) { + self.a.fsm.confirm_session_write(self.time(), write_id); + } + + fn return_write_a(&mut self, write_id: SessionWriteId) { + self.a.fsm.reject_session_write(write_id); + } + + fn pump(&mut self) { + for _ in 0..128 { + let mut progressed = false; + + while let Some(record) = self.next_outbound_a() { + progressed = true; + self.deliver_to_b(record); + } + + while let Some(record) = self.next_outbound_b() { + progressed = true; + self.deliver_to_a(record); + } + + if !progressed { + return; + } + } + + panic!("pump did not quiesce"); + } +} + +fn test_identity(seed: u8) -> QlIdentity { + let crypto = TestCrypto::new(seed); + let (signing_private, signing_public) = generate_ml_dsa_keypair(&crypto); + let (encapsulation_private, encapsulation_public) = generate_ml_kem_keypair(&crypto); + QlIdentity::new( + XID([seed; XID::SIZE]), + signing_private, + signing_public, + encapsulation_private, + encapsulation_public, + ) +} + +fn peer_from_identity(identity: &QlIdentity) -> Peer { + Peer { + xid: identity.xid, + signing_key: identity.signing_public_key.clone(), + encapsulation_key: identity.encapsulation_public_key.clone(), + } +} + +fn decrypt_envelope( + crypto: &impl QlCrypto, + record: &QlRecord, + session_key: &SessionKey, +) -> ql_wire::SessionEnvelope { + let aad = record.header.aad(); + let QlPayload::Session(encrypted) = &record.payload else { + panic!("expected encrypted payload"); + }; + let plaintext = encrypted.decrypt(crypto, session_key, &aad).unwrap(); + SessionEnvelope::decode(&plaintext).unwrap() +} diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs new file mode 100644 index 00000000..ced78785 --- /dev/null +++ b/ql-fsm/src/tests/session.rs @@ -0,0 +1,352 @@ +use std::time::Duration; + +use ql_wire::SessionCloseBody; + +use super::*; +use crate::{session::StreamNamespace, QlFsmEvent, QlSessionEvent}; + +#[test] +fn connected_fsms_deliver_stream_data() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + harness + .a + .fsm + .write_stream(stream_id, b"hello".to_vec()) + .unwrap(); + harness.a.fsm.finish_stream(stream_id).unwrap(); + + harness.pump(); + + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Opened(stream_id)) + ); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Data { + stream_id, + bytes: b"hello".to_vec(), + }) + ); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Finished(stream_id)) + ); +} + +#[test] +fn lost_encrypted_record_is_retried_and_acked() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + harness + .a + .fsm + .write_stream(stream_id, b"retry".to_vec()) + .unwrap(); + + let first = harness.next_outbound_a().unwrap(); + let session_key = harness + .b + .fsm + .peer + .as_ref() + .unwrap() + .session + .session_key() + .unwrap() + .clone(); + let first_body = decrypt_envelope(&harness.b.crypto, &first, &session_key); + + harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); + + let retried = harness.next_outbound_a().unwrap(); + let retried_body = decrypt_envelope(&harness.b.crypto, &retried, &session_key); + + assert_ne!(first_body.seq, retried_body.seq); + assert_eq!(first_body.body, retried_body.body); + + harness.deliver_to_b(retried); + harness.pump(); + + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Opened(stream_id)) + ); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Data { + stream_id, + bytes: b"retry".to_vec(), + }) + ); + + harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); + assert!(harness.next_outbound_a().is_none()); +} + +#[test] +fn remote_unpair_clears_peer() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness.a.fsm.queue_unpair().unwrap(); + harness.pump(); + + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Unpaired) + ); + assert!(harness.b.fsm.peer.is_none()); + assert!(matches!( + harness.b.fsm.take_next_event(), + Some(QlFsmEvent::ClearPeer) + )); + assert!(harness.a.fsm.peer.is_some()); +} + +#[test] +fn simultaneous_opens_use_disjoint_stream_id_namespaces() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id_a = harness.a.fsm.open_stream().unwrap(); + let stream_id_b = harness.b.fsm.open_stream().unwrap(); + + assert_ne!(stream_id_a, stream_id_b); + assert!( + StreamNamespace::for_local(harness.a.fsm.identity.xid, harness.b.fsm.identity.xid) + .matches(stream_id_a) + ); + assert!( + StreamNamespace::for_local(harness.b.fsm.identity.xid, harness.a.fsm.identity.xid) + .matches(stream_id_b) + ); + + harness + .a + .fsm + .write_stream(stream_id_a, b"from-a".to_vec()) + .unwrap(); + harness + .b + .fsm + .write_stream(stream_id_b, b"from-b".to_vec()) + .unwrap(); + + harness.pump(); + + assert_eq!( + harness.a.fsm.take_next_session_event(), + Some(QlSessionEvent::Opened(stream_id_b)) + ); + assert_eq!( + harness.a.fsm.take_next_session_event(), + Some(QlSessionEvent::Data { + stream_id: stream_id_b, + bytes: b"from-b".to_vec(), + }) + ); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Opened(stream_id_a)) + ); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Data { + stream_id: stream_id_a, + bytes: b"from-a".to_vec(), + }) + ); +} + +#[test] +fn queued_stream_work_auto_connects_and_drains_after_handshake() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + harness + .a + .fsm + .write_stream(stream_id, b"queued".to_vec()) + .unwrap(); + harness.a.fsm.finish_stream(stream_id).unwrap(); + + harness.pump(); + + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(crate::state::ConnectionState::Connected { .. }) + )); + assert!(matches!( + harness.b.fsm.peer.as_ref().map(|entry| &entry.session), + Some(crate::state::ConnectionState::Connected { .. }) + )); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Opened(stream_id)) + ); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Data { + stream_id, + bytes: b"queued".to_vec(), + }) + ); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Finished(stream_id)) + ); +} + +#[test] +fn queued_stream_work_is_failed_when_handshake_times_out() { + let config = QlFsmConfig { + handshake_retry_interval: Duration::from_millis(50), + max_handshake_retries: 0, + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired(config); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + harness + .a + .fsm + .write_stream(stream_id, b"queued".to_vec()) + .unwrap(); + + let _hello = harness.next_outbound_a().unwrap(); + + harness.advance(config.handshake_retry_interval); + harness.a.fsm.on_timer(harness.time()); + + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(crate::state::ConnectionState::Disconnected) + )); + assert_eq!( + harness.a.fsm.take_next_session_event(), + Some(QlSessionEvent::SessionClosed(SessionCloseBody { + code: ql_wire::CloseCode::TIMEOUT + })) + ); + assert!(harness.next_outbound_a().is_none()); +} + +#[test] +fn returned_session_write_is_reissued_with_same_seq() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + harness + .a + .fsm + .write_stream(stream_id, b"retry".to_vec()) + .unwrap(); + + let write = harness.next_write_a().unwrap(); + let id = write.session_write_id.expect("expected session write"); + let record = write.record; + let session_key = harness + .b + .fsm + .peer + .as_ref() + .unwrap() + .session + .session_key() + .unwrap() + .clone(); + let first = decrypt_envelope(&harness.b.crypto, &record, &session_key); + + harness.return_write_a(id); + + let write = harness.next_write_a().unwrap(); + let reissued_id = write + .session_write_id + .expect("expected reissued session write"); + let record = write.record; + let reissued = decrypt_envelope(&harness.b.crypto, &record, &session_key); + + assert_eq!(reissued_id, id); + assert_eq!(reissued.seq, first.seq); + assert_eq!(reissued.body, first.body); + + harness.confirm_write_a(reissued_id); + harness.deliver_to_b(record); + harness.pump(); + + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Opened(stream_id)) + ); + assert_eq!( + harness.b.fsm.take_next_session_event(), + Some(QlSessionEvent::Data { + stream_id, + bytes: b"retry".to_vec(), + }) + ); +} + +#[test] +fn unconfirmed_session_write_does_not_start_retransmit_timer() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + harness + .a + .fsm + .write_stream(stream_id, b"retry".to_vec()) + .unwrap(); + + let write = harness.next_write_a().unwrap(); + let id = write.session_write_id.expect("expected session write"); + let record = write.record; + let session_key = harness + .b + .fsm + .peer + .as_ref() + .unwrap() + .session + .session_key() + .unwrap() + .clone(); + let first = decrypt_envelope(&harness.b.crypto, &record, &session_key); + + harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); + harness.a.fsm.on_timer(harness.time()); + assert!(harness.next_write_a().is_none()); + + harness.confirm_write_a(id); + harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); + + let write = harness.next_write_a().unwrap(); + assert!(write.session_write_id.is_some(), "expected retransmit"); + let record = write.record; + let retried = decrypt_envelope(&harness.b.crypto, &record, &session_key); + + assert_ne!(retried.seq, first.seq); + assert_eq!(retried.body, first.body); +} + +#[test] +fn kill_session_disconnects_locally() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness.a.fsm.kill_session(ql_wire::CloseCode::CANCELLED); + + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(crate::state::ConnectionState::Disconnected) + )); + assert_eq!( + harness.a.fsm.take_next_session_event(), + Some(QlSessionEvent::SessionClosed(SessionCloseBody { + code: ql_wire::CloseCode::CANCELLED + })) + ); +} diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index f7eda226..03988f66 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -7,13 +7,14 @@ license = "Proprietary" [dependencies] async-channel = { version = "2.5" } -bc-components = { version = "0.28.0", default-features = false, features = [ - "pqcrypto", -] } futures-lite = { version = "2.5" } oneshot = { version = "0.1.11" } piper = { version = "0.2.4" } -ql-engine = { path = "../ql-engine" } +ql-fsm = { path = "../ql-fsm" } +ql-wire = { path = "../ql-wire" } +thiserror = { version = "2" } [dev-dependencies] +libcrux-aesgcm = "0.0.7" +sha2 = "0.10" tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 6c5a5844..9686e85a 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,7 +1,6 @@ use crate::{ - OpenedStreamDelivery, StreamConfig, - wire::stream::{CloseCode, CloseTarget}, - Peer, QlError, StreamId, + wire::{CloseCode, CloseTarget}, + OpenedStreamDelivery, Peer, QlError, StreamId, }; pub(crate) enum RuntimeCommand { @@ -12,10 +11,8 @@ pub(crate) enum RuntimeCommand { Connect, Unpair, OpenStream { - request_head: Vec, request_reader: piper::Reader, start: oneshot::Sender>, - config: StreamConfig, }, PollStream { stream_id: StreamId, diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 40238f27..62345e0c 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -1,50 +1,34 @@ use std::{ - collections::{HashMap, VecDeque}, + collections::HashMap, future::Future, task::Poll, - time::Instant, + time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; use futures_lite::future::poll_fn; +use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, QlSessionEvent, SessionWriteId}; use crate::{ command::RuntimeCommand, - engine::{Engine, EngineEventSink, WriteId}, - handle::{InboundByteStream, InboundStream, OutboundByteStream}, + handle::{ByteReader, ByteWriter, InboundStream}, platform::{PlatformFuture, QlPlatform}, - wire::stream::{BodyChunk, CloseCode, CloseTarget}, - HandlerEvent, InboundEvent, OpenedStreamDelivery, Peer, QlError, Runtime, StreamId, + CloseCode, CloseTarget, HandlerEvent, InboundEvent, OpenedStreamDelivery, QlError, Runtime, + StreamId, }; struct InFlightWrite<'a> { - id: WriteId, + session_write_id: Option, future: PlatformFuture<'a, Result<(), QlError>>, } -enum PendingAction { - CloseStream { - stream_id: StreamId, - target: CloseTarget, - code: CloseCode, - payload: Vec, - }, - OutboundData { - stream_id: StreamId, - bytes: Vec, - }, - OutboundFinished { - stream_id: StreamId, - }, -} - enum DriverEvent { Command(RuntimeCommand), WriteCompleted { - write_id: WriteId, + index: usize, result: Result<(), QlError>, }, TimerExpired, - Closed, + CommandsClosed, } enum OutboundIo { @@ -67,29 +51,33 @@ impl OutboundIo { *self = Self::Closed; } - fn poll_pending(&mut self, stream_id: StreamId, pending_inputs: &mut VecDeque) { + fn take_pending(&mut self) -> (Option>, bool) { let Self::Open { reader, finish_queued, } = self else { - return; + return (None, false); }; + let mut drained = None; let available = reader.len(); if available > 0 { let mut bytes = vec![0; available]; let read = reader.try_drain(&mut bytes); if read > 0 { bytes.truncate(read); - pending_inputs.push_back(PendingAction::OutboundData { stream_id, bytes }); + drained = Some(bytes); } } + let mut finished = false; if reader.is_closed() && !*finish_queued { *finish_queued = true; - pending_inputs.push_back(PendingAction::OutboundFinished { stream_id }); + finished = true; } + + (drained, finished) } } @@ -103,31 +91,18 @@ impl InboundIo { Self::Open(tx) } - fn write_or_close( - &mut self, - stream_id: StreamId, - target: CloseTarget, - bytes: Vec, - ) -> Option { + fn write_or_close(&mut self, bytes: Vec) -> bool { let Self::Open(tx) = self else { - return Some(PendingAction::CloseStream { - stream_id, - target, - code: CloseCode::CANCELLED, - payload: Vec::new(), - }); + return true; }; + if tx.try_send(InboundEvent::Data(bytes)).is_err() { tx.close(); *self = Self::Closed; - return Some(PendingAction::CloseStream { - stream_id, - target, - code: CloseCode::CANCELLED, - payload: Vec::new(), - }); + return true; } - None + + false } fn finish(&mut self) { @@ -145,14 +120,6 @@ impl InboundIo { } *self = Self::Closed; } - - fn close(&mut self) { - if let Self::Open(tx) = self { - let _ = tx.try_send(InboundEvent::Failed(QlError::Cancelled)); - tx.close(); - } - *self = Self::Closed; - } } enum DriverStreamIo { @@ -167,13 +134,6 @@ enum DriverStreamIo { } impl DriverStreamIo { - fn poll_pending(&mut self, stream_id: StreamId, pending_inputs: &mut VecDeque) { - match self { - Self::Initiator { request, .. } => request.poll_pending(stream_id, pending_inputs), - Self::Responder { response, .. } => response.poll_pending(stream_id, pending_inputs), - } - } - fn outbound_mut(&mut self) -> &mut OutboundIo { match self { Self::Initiator { request, .. } => request, @@ -195,161 +155,29 @@ impl DriverStreamIo { } } - fn close_all(&mut self) { + fn outbound_target(&self) -> CloseTarget { + match self { + Self::Initiator { .. } => CloseTarget::Request, + Self::Responder { .. } => CloseTarget::Response, + } + } + + fn fail_all(&mut self, error: QlError) { match self { Self::Initiator { request, response } => { request.close(); - response.close(); + response.fail(error); } Self::Responder { request, response } => { - request.close(); + request.fail(error); response.close(); } } } } -struct DriverEventSink<'a, P> { - platform: &'a P, - runtime_tx: &'a async_channel::Sender, - stream_send_buffer_bytes: usize, - pending_inputs: &'a mut VecDeque, - streams: &'a mut HashMap, -} - -impl<'a, P> DriverEventSink<'a, P> { - fn new( - platform: &'a P, - runtime_tx: &'a async_channel::Sender, - stream_send_buffer_bytes: usize, - pending_inputs: &'a mut VecDeque, - streams: &'a mut HashMap, - ) -> Self { - Self { - platform, - runtime_tx, - stream_send_buffer_bytes, - pending_inputs, - streams, - } - } -} - -impl EngineEventSink for DriverEventSink<'_, P> { - fn peer_status_changed( - &mut self, - peer: bc_components::XID, - session: crate::engine::PeerSession, - ) { - self.platform.handle_peer_status(peer, &session); - } - - fn persist_peer(&mut self, peer: Peer) { - self.platform.persist_peer(peer); - } - - fn clear_peer(&mut self) { - self.platform.clear_peer(); - } - - fn inbound_stream_opened( - &mut self, - stream_id: StreamId, - request_head: Vec, - request_prefix: Option, - ) { - let (request_tx, request_rx) = async_channel::unbounded(); - let mut request = InboundIo::new(request_tx); - if let Some(prefix) = request_prefix.as_ref() { - if !prefix.bytes.is_empty() { - let InboundIo::Open(tx) = &request else { - unreachable!("fresh inbound stream must be open"); - }; - tx.try_send(InboundEvent::Data(prefix.bytes.clone())) - .expect("new inbound stream prefix send should succeed"); - } - if prefix.fin { - request.finish(); - } - } - - let (response_reader, response_writer) = piper::pipe(self.stream_send_buffer_bytes); - self.streams.insert( - stream_id, - DriverStreamIo::Responder { - request, - response: OutboundIo::new(response_reader), - }, - ); - - self.platform - .handle_inbound(HandlerEvent::Stream(InboundStream { - stream_id, - request_head, - request: InboundByteStream::new( - stream_id, - CloseTarget::Request, - request_rx, - self.runtime_tx.clone(), - ), - response: OutboundByteStream::new( - stream_id, - CloseTarget::Response, - response_writer, - self.runtime_tx.clone(), - ), - })); - } - - fn inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let target = stream.inbound_target(); - let inbound = stream.inbound_mut(); - if let Some(input) = inbound.write_or_close(stream_id, target, bytes) { - self.pending_inputs.push_back(input); - } - } - - fn inbound_finished(&mut self, stream_id: StreamId) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - stream.inbound_mut().finish(); - } - - fn inbound_failed(&mut self, stream_id: StreamId, error: QlError) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - stream.inbound_mut().fail(error); - } - - fn outbound_closed(&mut self, stream_id: StreamId) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - stream.outbound_mut().close(); - } - - fn outbound_failed(&mut self, stream_id: StreamId, _error: QlError) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - stream.outbound_mut().close(); - } - - fn stream_reaped(&mut self, stream_id: StreamId) { - if let Some(mut stream) = self.streams.remove(&stream_id) { - stream.close_all(); - } - } -} - struct DriverState { - engine: Engine, - pending_inputs: VecDeque, + fsm: QlFsm, streams: HashMap, runtime_tx: async_channel::Sender, stream_send_buffer_bytes: usize, @@ -365,91 +193,52 @@ impl DriverState { ) { match command { RuntimeCommand::BindPeer { peer } => { - let now = Instant::now(); - let mut events = DriverEventSink::new( - platform, - &self.runtime_tx, - self.stream_send_buffer_bytes, - &mut self.pending_inputs, - &mut self.streams, - ); - self.engine.bind_peer(now, peer, &mut events); + self.fsm.bind_peer(peer); self.finish_step(platform, in_flight); } RuntimeCommand::Pair => { - self.engine.pair(Instant::now(), platform); + let _ = self.fsm.pair(now(), platform); self.finish_step(platform, in_flight); } RuntimeCommand::Connect => { - let now = Instant::now(); - let mut events = DriverEventSink::new( - platform, - &self.runtime_tx, - self.stream_send_buffer_bytes, - &mut self.pending_inputs, - &mut self.streams, - ); - self.engine.connect(now, platform, &mut events); + let _ = self.fsm.connect(now(), platform); self.finish_step(platform, in_flight); } RuntimeCommand::Unpair => { - let now = Instant::now(); - let mut events = DriverEventSink::new( - platform, - &self.runtime_tx, - self.stream_send_buffer_bytes, - &mut self.pending_inputs, - &mut self.streams, - ); - self.engine.unpair(now, &mut events); + let _ = self.fsm.queue_unpair(); self.finish_step(platform, in_flight); } RuntimeCommand::Incoming(bytes) => { - let now = Instant::now(); - let mut events = DriverEventSink::new( - platform, - &self.runtime_tx, - self.stream_send_buffer_bytes, - &mut self.pending_inputs, - &mut self.streams, - ); - self.engine.receive(now, bytes, platform, &mut events); + let _ = self.fsm.receive(now(), bytes, platform); self.finish_step(platform, in_flight); } RuntimeCommand::OpenStream { - request_head, request_reader, start, - config, - } => { - match self - .engine - .open_stream(Instant::now(), request_head, None, config) - { - Ok(stream_id) => { - let (response_tx, response_rx) = async_channel::unbounded(); - self.streams.insert( - stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::new(request_reader), - response: InboundIo::new(response_tx), - }, - ); - let _ = start.send(Ok(OpenedStreamDelivery { - stream_id, - response: response_rx, - })); - self.poll_stream(stream_id); - self.drive_pending(platform, in_flight); - } - Err(error) => { - let _ = start.send(Err(error)); - } + } => match self.fsm.open_stream().map_err(QlError::from) { + Ok(stream_id) => { + let (response_tx, response_rx) = async_channel::unbounded(); + self.streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::new(request_reader), + response: InboundIo::new(response_tx), + }, + ); + let _ = start.send(Ok(OpenedStreamDelivery { + stream_id, + response: response_rx, + })); + self.poll_stream(stream_id); + self.finish_step(platform, in_flight); } - } + Err(error) => { + let _ = start.send(Err(error)); + } + }, RuntimeCommand::PollStream { stream_id } => { self.poll_stream(stream_id); - self.drive_pending(platform, in_flight); + self.finish_step(platform, in_flight); } RuntimeCommand::CloseStream { stream_id, @@ -457,9 +246,7 @@ impl DriverState { code, payload, } => { - let _ = self - .engine - .close_stream(Instant::now(), stream_id, target, code, payload); + let _ = self.fsm.close_stream(stream_id, target, code, payload); self.finish_step(platform, in_flight); } } @@ -467,104 +254,184 @@ impl DriverState { fn drive_write_completed<'a, P: QlPlatform>( &mut self, - write_id: WriteId, + session_write_id: Option, result: Result<(), QlError>, platform: &'a P, in_flight: &mut Vec>, ) { - { - let now = self.engine.state.now; - let mut events = DriverEventSink::new( - platform, - &self.runtime_tx, - self.stream_send_buffer_bytes, - &mut self.pending_inputs, - &mut self.streams, - ); - self.engine - .complete_write(now, write_id, result, &mut events); + if let Some(write_id) = session_write_id { + match result { + Ok(()) => self.fsm.confirm_session_write(now(), write_id), + Err(_) => self.fsm.reject_session_write(write_id), + } } self.finish_step(platform, in_flight); } - fn drive_pending<'a, P: QlPlatform>( + fn finish_step<'a, P: QlPlatform>( &mut self, platform: &'a P, in_flight: &mut Vec>, ) { - while let Some(input) = self.pending_inputs.pop_front() { - let now = Instant::now(); - match input { - PendingAction::CloseStream { - stream_id, - target, - code, - payload, - } => { - let _ = self - .engine - .close_stream(now, stream_id, target, code, payload); - } - PendingAction::OutboundData { stream_id, bytes } => { - let _ = self.engine.write_stream(now, stream_id, bytes); + loop { + let mut progressed = false; + + progressed |= self.drain_fsm(platform); + progressed |= self.fill_write_slots(platform, in_flight); + + if !progressed { + break; + } + } + } + + fn drain_fsm(&mut self, platform: &P) -> bool { + let mut progressed = false; + + while let Some(event) = self.fsm.take_next_event() { + progressed = true; + match event { + QlFsmEvent::NewPeer(peer) => platform.persist_peer(peer), + QlFsmEvent::ClearPeer => platform.clear_peer(), + QlFsmEvent::PeerStatusChanged { peer, status } => { + platform.handle_peer_status(peer, status) } - PendingAction::OutboundFinished { stream_id } => { - let _ = self.engine.finish_stream(now, stream_id); + } + } + + while let Some(event) = self.fsm.take_next_session_event() { + progressed = true; + match event { + QlSessionEvent::Opened(stream_id) => self.handle_opened_stream(platform, stream_id), + QlSessionEvent::Data { stream_id, bytes } => { + self.handle_inbound_data(stream_id, bytes) } + QlSessionEvent::Finished(stream_id) => self.handle_inbound_finished(stream_id), + QlSessionEvent::Closed(frame) => self.handle_closed_stream(frame), + QlSessionEvent::WritableClosed(stream_id) => self.handle_writable_closed(stream_id), + QlSessionEvent::Unpaired => self.fail_all_streams(QlError::Cancelled), + QlSessionEvent::SessionClosed(_) => self.fail_all_streams(QlError::SessionClosed), } - self.fill_write_slots(platform, in_flight); } - self.fill_write_slots(platform, in_flight); + progressed } - fn drive_timer<'a, P: QlPlatform>( - &mut self, - platform: &'a P, - in_flight: &mut Vec>, - ) { - let now = Instant::now(); - let mut events = DriverEventSink::new( - platform, - &self.runtime_tx, - self.stream_send_buffer_bytes, - &mut self.pending_inputs, - &mut self.streams, + fn handle_opened_stream(&mut self, platform: &P, stream_id: StreamId) { + let (request_tx, request_rx) = async_channel::unbounded(); + let (response_reader, response_writer) = piper::pipe(self.stream_send_buffer_bytes); + + self.streams.insert( + stream_id, + DriverStreamIo::Responder { + request: InboundIo::new(request_tx), + response: OutboundIo::new(response_reader), + }, ); - self.engine.on_timer(now, platform, &mut events); - self.finish_step(platform, in_flight); + + platform.handle_inbound(HandlerEvent::Stream(InboundStream { + stream_id, + request: ByteReader::new( + stream_id, + CloseTarget::Request, + request_rx, + self.runtime_tx.clone(), + ), + response: ByteWriter::new( + stream_id, + CloseTarget::Response, + response_writer, + self.runtime_tx.clone(), + ), + })); } - fn finish_step<'a, P: QlPlatform>( - &mut self, - platform: &'a P, - in_flight: &mut Vec>, - ) { - self.fill_write_slots(platform, in_flight); - self.drive_pending(platform, in_flight); + fn handle_inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + + let target = stream.inbound_target(); + let should_close = stream.inbound_mut().write_or_close(bytes); + if should_close { + let _ = self + .fsm + .close_stream(stream_id, target, CloseCode::CANCELLED, Vec::new()); + } + } + + fn handle_inbound_finished(&mut self, stream_id: StreamId) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.inbound_mut().finish(); + } + + fn handle_closed_stream(&mut self, frame: ql_wire::StreamClose) { + let Some(stream) = self.streams.get_mut(&frame.stream_id) else { + return; + }; + + let error = QlError::StreamClosed { + target: frame.target, + code: frame.code, + payload: frame.payload.clone(), + }; + + if frame.target == CloseTarget::Both || frame.target == stream.inbound_target() { + stream.inbound_mut().fail(error); + } + if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { + stream.outbound_mut().close(); + } + } + + fn handle_writable_closed(&mut self, stream_id: StreamId) { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + stream.outbound_mut().close(); + } + + fn fail_all_streams(&mut self, error: QlError) { + for stream in self.streams.values_mut() { + stream.fail_all(error.clone()); + } + self.streams.clear(); } fn fill_write_slots<'a, P: QlPlatform>( &mut self, platform: &'a P, in_flight: &mut Vec>, - ) { + ) -> bool { + let mut progressed = false; + while in_flight.len() < self.max_concurrent_message_writes { - let Some(write) = self.engine.take_next_write(self.engine.state.now, platform) else { + let Some(write) = self.fsm.take_next_write(now(), platform) else { break; }; + progressed = true; in_flight.push(InFlightWrite { - id: write.id, - future: platform.write_message(write.bytes), + session_write_id: write.session_write_id, + future: platform.write_message(write.record.encode()), }); } + + progressed } fn poll_stream(&mut self, stream_id: StreamId) { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - stream.poll_pending(stream_id, &mut self.pending_inputs); + let (bytes, finished) = stream.outbound_mut().take_pending(); + if let Some(bytes) = bytes { + let _ = self.fsm.write_stream(stream_id, bytes); + } + if finished { + let _ = self.fsm.finish_stream(stream_id); + } } } @@ -572,23 +439,18 @@ async fn next_driver_event( rx: &async_channel::Receiver, platform: &P, next_timer: Option, - in_flight: &mut Vec>, + in_flight: &mut [InFlightWrite<'_>], ) -> DriverEvent { - let recv_future = rx.recv(); - futures_lite::pin!(recv_future); - + let mut recv_future = (!rx.is_closed()).then(|| Box::pin(rx.recv())); let mut sleep_future = next_timer.map(|deadline| { let timeout = deadline.saturating_duration_since(Instant::now()); platform.sleep(timeout) }); poll_fn(|cx| { - for write in in_flight.iter_mut() { + for (index, write) in in_flight.iter_mut().enumerate() { if let Poll::Ready(result) = write.future.as_mut().poll(cx) { - return Poll::Ready(DriverEvent::WriteCompleted { - write_id: write.id, - result, - }); + return Poll::Ready(DriverEvent::WriteCompleted { index, result }); } } @@ -598,10 +460,16 @@ async fn next_driver_event( } } - recv_future.as_mut().poll(cx).map(|res| match res { - Ok(command) => DriverEvent::Command(command), - Err(_) => DriverEvent::Closed, - }) + if let Some(future) = recv_future.as_mut() { + if let Poll::Ready(res) = future.as_mut().poll(cx) { + return Poll::Ready(match res { + Ok(command) => DriverEvent::Command(command), + Err(_) => DriverEvent::CommandsClosed, + }); + } + } + + Poll::Pending }) .await } @@ -615,11 +483,15 @@ impl Runtime

{ rx, tx, } = self; - let peer = platform.load_peer().await; + let runtime_tx = tx.upgrade().expect("runtime tx"); + let mut fsm = QlFsm::new(config.fsm, identity, now()); + if let Some(peer) = platform.load_peer().await { + fsm.bind_peer(peer); + } + let mut state = DriverState { - engine: Engine::new(config.engine, identity, peer), - pending_inputs: VecDeque::new(), + fsm, streams: HashMap::new(), runtime_tx, stream_send_buffer_bytes: config.stream_send_buffer_bytes, @@ -628,29 +500,46 @@ impl Runtime

{ let mut in_flight = Vec::new(); loop { - state.drive_pending(&platform, &mut in_flight); + state.finish_step(&platform, &mut in_flight); - if rx.is_closed() && state.pending_inputs.is_empty() && in_flight.is_empty() { + if rx.is_closed() && in_flight.is_empty() { break; } - match next_driver_event(&rx, &platform, state.engine.next_deadline(), &mut in_flight) - .await + match next_driver_event(&rx, &platform, state.fsm.next_deadline(), &mut in_flight).await { DriverEvent::Command(command) => { - state.drive_command(command, &platform, &mut in_flight); + state.drive_command(command, &platform, &mut in_flight) } - DriverEvent::WriteCompleted { write_id, result } => { - if let Some(index) = in_flight.iter().position(|write| write.id == write_id) { - in_flight.swap_remove(index); - } - state.drive_write_completed(write_id, result, &platform, &mut in_flight); + DriverEvent::WriteCompleted { index, result } => { + let write = in_flight.swap_remove(index); + state.drive_write_completed( + write.session_write_id, + result, + &platform, + &mut in_flight, + ); } DriverEvent::TimerExpired => { - state.drive_timer(&platform, &mut in_flight); + state.fsm.on_timer(now()); + state.finish_step(&platform, &mut in_flight); } - DriverEvent::Closed => break, + DriverEvent::CommandsClosed => {} } } } } + +fn now() -> FsmTime { + FsmTime { + instant: Instant::now(), + unix_secs: unix_now_secs(), + } +} + +fn unix_now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_secs() +} diff --git a/ql-runtime/src/handle.rs b/ql-runtime/src/handle.rs index 1a2618a9..10b69a22 100644 --- a/ql-runtime/src/handle.rs +++ b/ql-runtime/src/handle.rs @@ -2,9 +2,8 @@ use async_channel::{Receiver, Sender}; use futures_lite::future::poll_fn; use crate::{ - command::RuntimeCommand, InboundEvent, OpenedStreamDelivery, StreamConfig, - wire::stream::{CloseCode, CloseTarget}, - Peer, QlError, StreamId, + command::RuntimeCommand, CloseCode, CloseTarget, InboundEvent, OpenedStreamDelivery, Peer, + QlError, StreamId, }; #[derive(Clone)] @@ -13,21 +12,21 @@ pub struct RuntimeHandle { pub(crate) stream_send_buffer_bytes: usize, } -pub struct DuplexStream { +#[derive(Debug)] +pub struct OutboundStream { pub stream_id: StreamId, - pub request: OutboundByteStream, - pub response: InboundByteStream, + pub request: ByteWriter, + pub response: ByteReader, } #[derive(Debug)] pub struct InboundStream { pub stream_id: StreamId, - pub request_head: Vec, - pub request: InboundByteStream, - pub response: OutboundByteStream, + pub request: ByteReader, + pub response: ByteWriter, } -pub struct InboundByteStream { +pub struct ByteReader { stream_id: StreamId, target: CloseTarget, rx: Receiver, @@ -35,7 +34,7 @@ pub struct InboundByteStream { finished: bool, } -impl std::fmt::Debug for InboundByteStream { +impl std::fmt::Debug for ByteReader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("InboundByteStream") .field("stream_id", &self.stream_id) @@ -45,14 +44,14 @@ impl std::fmt::Debug for InboundByteStream { } } -pub struct OutboundByteStream { +pub struct ByteWriter { stream_id: StreamId, target: CloseTarget, writer: Option, tx: Sender, } -impl std::fmt::Debug for OutboundByteStream { +impl std::fmt::Debug for ByteWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OutboundByteStream") .field("stream_id", &self.stream_id) @@ -62,7 +61,7 @@ impl std::fmt::Debug for OutboundByteStream { } } -impl InboundByteStream { +impl ByteReader { pub(crate) fn new( stream_id: StreamId, target: CloseTarget, @@ -116,7 +115,7 @@ impl InboundByteStream { } } -impl Drop for InboundByteStream { +impl Drop for ByteReader { fn drop(&mut self) { if self.finished { return; @@ -130,7 +129,7 @@ impl Drop for InboundByteStream { } } -impl OutboundByteStream { +impl ByteWriter { pub(crate) fn new( stream_id: StreamId, target: CloseTarget, @@ -202,7 +201,7 @@ impl OutboundByteStream { } } -impl Drop for OutboundByteStream { +impl Drop for ByteWriter { fn drop(&mut self) { if self.writer.take().is_none() { return; @@ -243,20 +242,14 @@ impl RuntimeHandle { self.send(RuntimeCommand::Incoming(bytes)) } - pub async fn open_stream( - &self, - request_head: Vec, - config: StreamConfig, - ) -> Result { + pub async fn open_stream(&self) -> Result { let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); let (start_tx, start_rx) = oneshot::channel(); self.tx .send(RuntimeCommand::OpenStream { - request_head, request_reader, start: start_tx, - config, }) .await .map_err(|_| QlError::Cancelled)?; @@ -266,20 +259,15 @@ impl RuntimeHandle { response, } = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; - Ok(DuplexStream { + Ok(OutboundStream { stream_id, - request: OutboundByteStream::new( + request: ByteWriter::new( stream_id, CloseTarget::Request, request_writer, self.tx.clone(), ), - response: InboundByteStream::new( - stream_id, - CloseTarget::Response, - response, - self.tx.clone(), - ), + response: ByteReader::new(stream_id, CloseTarget::Response, response, self.tx.clone()), }) } } diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 14ac5f98..590d16b1 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,11 +1,6 @@ -pub use handle::{ - DuplexStream, InboundByteStream, InboundStream, OutboundByteStream, RuntimeHandle, -}; -pub use ql_engine::{engine, identity, wire, PacketId, Peer, QlError, StreamId}; - -pub use crate::engine::{ - EngineConfig, HandshakeInitiator, KeepAliveConfig, PeerSession, StreamConfig, -}; +pub use handle::{ByteReader, ByteWriter, InboundStream, OutboundStream, RuntimeHandle}; +pub use ql_fsm::{Peer, PeerStatus, QlFsmConfig, QlFsmError, SessionWriteId}; +pub use ql_wire::{self as wire, CloseCode, CloseTarget, QlIdentity, StreamId, XID}; pub(crate) mod command; pub(crate) mod driver; @@ -15,12 +10,64 @@ pub mod platform; #[cfg(test)] mod tests; +use thiserror::Error; + use self::platform::QlPlatform; -use crate::identity::QlIdentity; + +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum QlError { + #[error("invalid payload")] + InvalidPayload, + #[error("invalid signature")] + InvalidSignature, + #[error("expired")] + Expired, + #[error("signing failed")] + SigningFailed, + #[error("encryption failed")] + EncryptFailed, + #[error("decryption failed")] + DecryptFailed, + #[error("missing stream")] + MissingStream, + #[error("stream is not writable")] + NotWritable, + #[error("session is closed")] + SessionClosed, + #[error("no peer bound")] + NoPeerBound, + #[error("send failed")] + SendFailed, + #[error("stream closed {code:?}")] + StreamClosed { + target: CloseTarget, + code: CloseCode, + payload: Vec, + }, + #[error("cancelled")] + Cancelled, +} + +impl From for QlError { + fn from(value: QlFsmError) -> Self { + match value { + QlFsmError::InvalidPayload => Self::InvalidPayload, + QlFsmError::InvalidSignature => Self::InvalidSignature, + QlFsmError::Expired => Self::Expired, + QlFsmError::SigningFailed => Self::SigningFailed, + QlFsmError::EncryptFailed => Self::EncryptFailed, + QlFsmError::DecryptFailed => Self::DecryptFailed, + QlFsmError::MissingStream => Self::MissingStream, + QlFsmError::NotWritable => Self::NotWritable, + QlFsmError::SessionClosed => Self::SessionClosed, + QlFsmError::NoPeerBound => Self::NoPeerBound, + } + } +} #[derive(Debug, Clone, Copy)] pub struct RuntimeConfig { - pub engine: EngineConfig, + pub fsm: QlFsmConfig, pub stream_send_buffer_bytes: usize, pub max_concurrent_message_writes: usize, } @@ -28,7 +75,7 @@ pub struct RuntimeConfig { impl Default for RuntimeConfig { fn default() -> Self { Self { - engine: EngineConfig::default(), + fsm: QlFsmConfig::default(), stream_send_buffer_bytes: 64 * 1024, max_concurrent_message_writes: 4, } diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 8a6ce873..08660674 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -1,11 +1,8 @@ use std::{future::Future, pin::Pin, time::Duration}; -use bc_components::XID; +use ql_wire::QlCrypto; -use crate::{ - engine::{PeerSession, QlCrypto}, - Peer, QlError, -}; +use crate::{Peer, PeerStatus, QlError, XID}; pub type PlatformFuture<'a, T> = Pin + 'a>>; @@ -17,6 +14,6 @@ pub trait QlPlatform: QlCrypto { fn persist_peer(&self, peer: Peer); fn clear_peer(&self); - fn handle_peer_status(&self, peer: XID, session: &PeerSession); + fn handle_peer_status(&self, peer: XID, status: PeerStatus); fn handle_inbound(&self, event: super::HandlerEvent); } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 096f38e6..c95db6e5 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -8,8 +8,8 @@ async fn connect_round_trip_changes_peer_status() { let config = default_runtime_config(); let (platform_a, outbound_a, status_a) = TestPlatform::new(1); let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -29,20 +29,68 @@ async fn connect_round_trip_changes_peer_status() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn opening_stream_auto_connects() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(11); + let identity_b = new_identity(73); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + + let responder = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { + HandlerEvent::Stream(stream) => stream, + }; + let request = read_all(stream.request).await.unwrap(); + stream.response.finish().await.unwrap(); + request + }); + + let mut stream = handle_a.open_stream().await.unwrap(); + stream.request.write_all(b"auto-connect").await.unwrap(); + stream.request.finish().await.unwrap(); + assert_eq!(stream.response.next_chunk().await.unwrap(), None); + + assert_eq!( + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(), + b"auto-connect".to_vec() + ); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn handshake_timeout_disconnects() { run_local_test(async { let config = RuntimeConfig { - engine: crate::engine::EngineConfig { + fsm: QlFsmConfig { handshake_timeout: Duration::from_millis(60), - ..default_runtime_config().engine + ..default_runtime_config().fsm }, ..default_runtime_config() }; let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -59,13 +107,13 @@ async fn handshake_timeout_disconnects() { } #[tokio::test(flavor = "current_thread")] -async fn confirm_write_failure_disconnects_initiator() { +async fn rejected_session_write_is_reissued() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new_with_stream_write_failure(1, 1); + let (platform_a, outbound_a, status_a) = TestPlatform::new_with_session_write_failure(1, 1); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -82,23 +130,27 @@ async fn confirm_write_failure_disconnects_initiator() { await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; - let responder_task = tokio::task::spawn_local(async move { - let second = match inbound_b.recv().await.unwrap() { + let responder = tokio::task::spawn_local(async move { + let stream = match inbound_b.recv().await.unwrap() { HandlerEvent::Stream(stream) => stream, }; - let mut second_request = second.request; - let mut second_response = second.response; - assert_eq!(second_request.next_chunk().await.unwrap(), None); - second_response.write_all(b"ok").await.unwrap(); - second_response.finish().await.unwrap(); + let request = read_all(stream.request).await.unwrap(); + stream.response.finish().await.unwrap(); + request }); - let mut first = handle_a - .open_stream(Vec::new(), crate::StreamConfig::default()) - .await - .unwrap(); - let _ = first.request.finish().await; - let _ = first.response.next_chunk().await; + let mut stream = handle_a.open_stream().await.unwrap(); + stream.request.write_all(b"retry").await.unwrap(); + stream.request.finish().await.unwrap(); + assert_eq!(stream.response.next_chunk().await.unwrap(), None); + + assert_eq!( + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(), + b"retry".to_vec() + ); assert_no_status_for( &status_a, @@ -107,19 +159,6 @@ async fn confirm_write_failure_disconnects_initiator() { Duration::from_millis(150), ) .await; - - let mut second = handle_a - .open_stream(Vec::new(), crate::StreamConfig::default()) - .await - .unwrap(); - second.request.finish().await.unwrap(); - assert_eq!(second.response.next_chunk().await.unwrap(), Some(b"ok".to_vec())); - assert_eq!(second.response.next_chunk().await.unwrap(), None); - - tokio::time::timeout(Duration::from_secs(2), responder_task) - .await - .unwrap() - .unwrap(); }) .await; } diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 57ff7e53..2d1ba718 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -9,170 +9,21 @@ use std::{ use super::*; #[tokio::test(flavor = "current_thread")] -async fn keepalive_disabled_no_heartbeat() { +async fn session_timeout_disconnects_and_fails_pending_open() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let identity_a = new_identity(); - let identity_b = new_identity(); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); - - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; - - let result = tokio::time::timeout(Duration::from_millis(120), heartbeat_rx.recv()).await; - assert!(result.is_err(), "unexpected heartbeat while disabled"); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_sent_after_idle() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(30), - timeout: Duration::from_millis(80), - }; - let config_a = RuntimeConfig { - engine: crate::engine::EngineConfig { - keep_alive: Some(keep_alive), - ..default_runtime_config().engine - }, - ..default_runtime_config() - }; - let config_b = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let identity_a = new_identity(); - let identity_b = new_identity(); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); - - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn stream_activity_prevents_keepalive_timeout() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(120), - timeout: Duration::from_millis(40), - }; - let config_a = RuntimeConfig { - engine: crate::engine::EngineConfig { - keep_alive: Some(keep_alive), - ..default_runtime_config().engine - }, - ..default_runtime_config() - }; - let config_b = default_runtime_config(); - let (platform_a, outbound_a, status_a, inbound_a) = TestPlatform::new_with_inbound(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let identity_a = new_identity(); - let identity_b = new_identity(); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded(); - spawn_heartbeat_tap_forwarder(outbound_a, handle_b.clone(), heartbeat_tx); - spawn_drop_heartbeat_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); - - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; - - tokio::time::timeout(Duration::from_millis(200), heartbeat_rx.recv()) - .await - .unwrap() - .unwrap(); - - let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_a.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; - let response = stream.response; - response.finish().await.unwrap(); - }); - - let stream = handle_b.open_stream(Vec::new(), crate::StreamConfig::default()).await; - let mut stream = stream.unwrap(); - stream.request.finish().await.unwrap(); - assert_eq!(stream.response.next_chunk().await.unwrap(), None); - - let disconnect = tokio::time::timeout(keep_alive.timeout + Duration::from_millis(20), async { - loop { - if let Ok(event) = status_a.recv().await { - if event.peer == identity_b.xid && event.stage == PeerStage::Disconnected { - return; - } - } - } - }) - .await; - assert!(disconnect.is_err(), "unexpected disconnect"); - - let _ = responder_task.await; - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn heartbeat_timeout_disconnects_and_fails_pending_open() { - run_local_test(async { - let keep_alive = KeepAliveConfig { - interval: Duration::from_millis(80), - timeout: Duration::from_millis(60), - }; let config_a = RuntimeConfig { - engine: crate::engine::EngineConfig { - keep_alive: Some(keep_alive), - ..default_runtime_config().engine + fsm: QlFsmConfig { + session_keepalive_interval: Duration::from_millis(40), + session_peer_timeout: Duration::from_millis(60), + ..default_runtime_config().fsm }, ..default_runtime_config() }; let config_b = default_runtime_config(); let (platform_a, outbound_a, status_a) = TestPlatform::new(2); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(1); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); @@ -194,22 +45,26 @@ async fn heartbeat_timeout_disconnects_and_fails_pending_open() { let stream = match inbound_b.recv().await.unwrap() { HandlerEvent::Stream(stream) => stream, }; + let _ = read_all(stream.request).await; let response = stream.response; - response.finish().await.unwrap(); + let _ = response.finish().await; }); drop_flag.store(true, Ordering::Relaxed); - let mut pending = handle_a - .open_stream(Vec::new(), crate::StreamConfig::default()) - .await - .unwrap(); + let mut pending = handle_a.open_stream().await.unwrap(); + pending.request.finish().await.unwrap(); await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; - let result = tokio::time::timeout(Duration::from_millis(300), pending.response.next_chunk()) - .await; - assert!(result.is_ok(), "pending stream never resolved after disconnect"); + let result = + tokio::time::timeout(Duration::from_millis(300), pending.response.next_chunk()) + .await + .unwrap(); + assert!(matches!( + result, + Err(QlError::SessionClosed) | Err(QlError::Cancelled) + )); responder_task.abort(); }) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 30aa2ff8..d37d1802 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -1,4 +1,5 @@ use std::{ + cell::Cell, future::Future, sync::{ atomic::{AtomicU8, AtomicUsize, Ordering}, @@ -8,16 +9,17 @@ use std::{ }; use async_channel::{Receiver, Sender}; -use bc_components::{MLDSA, MLKEM}; +use libcrux_aesgcm::AesGcm256Key; +use ql_wire::{ + generate_ml_dsa_keypair, generate_ml_kem_keypair, EncryptedMessage, Nonce, QlCrypto, + QlIdentity, QlPayload, QlRecord, SessionKey, XID, +}; +use sha2::{Digest, Sha256}; use tokio::task::LocalSet; use crate::{ - engine::QlCrypto, - identity::QlIdentity, - new_runtime, - platform::PlatformFuture, - wire::{self, QlPayload}, - HandlerEvent, KeepAliveConfig, Peer, PeerSession, QlError, RuntimeConfig, RuntimeHandle, + new_runtime, platform::PlatformFuture, HandlerEvent, Peer, PeerStatus, QlError, QlFsmConfig, + RuntimeConfig, RuntimeHandle, }; mod handshake; @@ -35,7 +37,7 @@ enum PeerStage { #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct StatusEvent { - peer: bc_components::XID, + peer: XID, stage: PeerStage, } @@ -58,14 +60,79 @@ impl WriteStats { } } +struct DeterministicCrypto { + seed: u8, + counter: Cell, +} + +impl DeterministicCrypto { + fn new(seed: u8) -> Self { + Self { + seed, + counter: Cell::new(0), + } + } +} + +impl QlCrypto for DeterministicCrypto { + fn fill_random_bytes(&self, data: &mut [u8]) { + let value = self.seed.wrapping_add(self.counter.get()); + self.counter.set(self.counter.get().wrapping_add(1)); + data.fill(value); + } + + fn hash(&self, parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() + } + + fn encrypt_with_aead( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]> { + let key: AesGcm256Key = (*key.data()).into(); + let plaintext = buffer.to_vec(); + let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; + key.encrypt( + buffer, + (&mut auth).into(), + (&nonce.0).into(), + aad, + &plaintext, + ) + .ok()?; + Some(auth) + } + + fn decrypt_with_aead( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + ) -> bool { + let key: AesGcm256Key = (*key.data()).into(); + let ciphertext = buffer.to_vec(); + key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) + .is_ok() + } +} + struct TestPlatform { outbound: Sender>, status: Sender, inbound: Option>, nonce_seed: u8, nonce_counter: AtomicU8, - stream_write_counter: AtomicUsize, - fail_stream_write_at: Option, + encrypted_write_counter: AtomicUsize, + fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, } @@ -89,11 +156,17 @@ impl TestPlatform { (platform, outbound_rx, status_rx, inbound_rx) } - fn new_with_stream_write_failure( + fn new_with_session_write_failure( seed: u8, - fail_stream_write_at: usize, + fail_encrypted_write_at: usize, ) -> (Self, Receiver>, Receiver) { - Self::new_inner(seed, None, Some(fail_stream_write_at), Duration::ZERO, None) + Self::new_inner( + seed, + None, + Some(fail_encrypted_write_at), + Duration::ZERO, + None, + ) } fn new_with_delayed_writes( @@ -107,7 +180,7 @@ impl TestPlatform { fn new_inner( seed: u8, inbound: Option>, - fail_stream_write_at: Option, + fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, ) -> (Self, Receiver>, Receiver) { @@ -120,8 +193,8 @@ impl TestPlatform { inbound, nonce_seed: seed, nonce_counter: AtomicU8::new(0), - stream_write_counter: AtomicUsize::new(0), - fail_stream_write_at, + encrypted_write_counter: AtomicUsize::new(0), + fail_encrypted_write_at, write_delay, write_stats, }, @@ -138,13 +211,56 @@ impl QlCrypto for TestPlatform { .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); data.fill(value); } + + fn hash(&self, parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() + } + + fn encrypt_with_aead( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]> { + let key: AesGcm256Key = (*key.data()).into(); + let plaintext = buffer.to_vec(); + let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; + key.encrypt( + buffer, + (&mut auth).into(), + (&nonce.0).into(), + aad, + &plaintext, + ) + .ok()?; + Some(auth) + } + + fn decrypt_with_aead( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + ) -> bool { + let key: AesGcm256Key = (*key.data()).into(); + let ciphertext = buffer.to_vec(); + key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) + .is_ok() + } } impl crate::platform::QlPlatform for TestPlatform { fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { let outbound = self.outbound.clone(); let write_delay = self.write_delay; - let fail_stream_write_at = self.fail_stream_write_at; + let fail_encrypted_write_at = self.fail_encrypted_write_at; let write_stats = self.write_stats.clone(); Box::pin(async move { @@ -158,9 +274,9 @@ impl crate::platform::QlPlatform for TestPlatform { } let mut should_fail = false; - if is_stream_payload(&message) { - let count = self.stream_write_counter.fetch_add(1, Ordering::Relaxed) + 1; - should_fail = fail_stream_write_at == Some(count); + if is_encrypted_payload(&message) { + let count = self.encrypted_write_counter.fetch_add(1, Ordering::Relaxed) + 1; + should_fail = fail_encrypted_write_at == Some(count); } let result = if should_fail { @@ -192,12 +308,12 @@ impl crate::platform::QlPlatform for TestPlatform { fn clear_peer(&self) {} - fn handle_peer_status(&self, peer: bc_components::XID, session: &PeerSession) { - let stage = match session { - PeerSession::Disconnected => PeerStage::Disconnected, - PeerSession::Initiator { .. } => PeerStage::Initiator, - PeerSession::Responder { .. } => PeerStage::Responder, - PeerSession::Connected { .. } => PeerStage::Connected, + fn handle_peer_status(&self, peer: XID, status: PeerStatus) { + let stage = match status { + PeerStatus::Disconnected => PeerStage::Disconnected, + PeerStatus::Initiator => PeerStage::Initiator, + PeerStatus::Responder => PeerStage::Responder, + PeerStatus::Connected => PeerStage::Connected, }; let _ = self.status.try_send(StatusEvent { peer, stage }); } @@ -209,16 +325,18 @@ impl crate::platform::QlPlatform for TestPlatform { } } -fn is_stream_payload(bytes: &[u8]) -> bool { - wire::decode_record(bytes) +fn is_encrypted_payload(bytes: &[u8]) -> bool { + QlRecord::decode(bytes) .ok() - .is_some_and(|record| matches!(record.payload, QlPayload::Stream(_))) + .is_some_and(|record| matches!(record.payload, QlPayload::Session(_))) } -fn new_identity() -> QlIdentity { - let (signing_private, signing_public) = MLDSA::MLDSA44.keypair(); - let (encapsulation_private, encapsulation_public) = MLKEM::MLKEM512.keypair(); - QlIdentity::from_keys( +fn new_identity(seed: u8) -> QlIdentity { + let crypto = DeterministicCrypto::new(seed); + let (signing_private, signing_public) = generate_ml_dsa_keypair(&crypto); + let (encapsulation_private, encapsulation_public) = generate_ml_kem_keypair(&crypto); + QlIdentity::new( + XID([seed; XID::SIZE]), signing_private, signing_public, encapsulation_private, @@ -228,7 +346,7 @@ fn new_identity() -> QlIdentity { fn peer_from_identity(identity: &QlIdentity) -> Peer { Peer { - peer: identity.xid, + xid: identity.xid, signing_key: identity.signing_public_key.clone(), encapsulation_key: identity.encapsulation_public_key.clone(), } @@ -252,17 +370,17 @@ fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { }); } -fn spawn_drop_every_nth_stream_forwarder( +fn spawn_drop_every_nth_encrypted_forwarder( outbound: Receiver>, handle: RuntimeHandle, nth: usize, ) { tokio::task::spawn_local(async move { - let mut stream_count = 0usize; + let mut encrypted_count = 0usize; while let Ok(bytes) = outbound.recv().await { - if nth > 0 && is_stream_payload(&bytes) { - stream_count = stream_count.saturating_add(1); - if stream_count % nth == 0 { + if nth > 0 && is_encrypted_payload(&bytes) { + encrypted_count = encrypted_count.saturating_add(1); + if encrypted_count % nth == 0 { continue; } } @@ -271,38 +389,6 @@ fn spawn_drop_every_nth_stream_forwarder( }); } -fn is_heartbeat(bytes: &[u8]) -> bool { - wire::decode_record(bytes) - .ok() - .is_some_and(|record| matches!(record.payload, QlPayload::Heartbeat(_))) -} - -fn spawn_heartbeat_tap_forwarder( - outbound: Receiver>, - handle: RuntimeHandle, - heartbeat_tx: Sender<()>, -) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if is_heartbeat(&bytes) { - let _ = heartbeat_tx.send(()).await; - } - handle.send_incoming(bytes); - } - }); -} - -fn spawn_drop_heartbeat_forwarder(outbound: Receiver>, handle: RuntimeHandle) { - tokio::task::spawn_local(async move { - while let Ok(bytes) = outbound.recv().await { - if is_heartbeat(&bytes) { - continue; - } - handle.send_incoming(bytes); - } - }); -} - fn spawn_gated_forwarder( outbound: Receiver>, handle: RuntimeHandle, @@ -326,11 +412,7 @@ where local.run_until(future).await; } -async fn await_status( - receiver: &Receiver, - peer: bc_components::XID, - stage: PeerStage, -) { +async fn await_status(receiver: &Receiver, peer: XID, stage: PeerStage) { tokio::time::timeout(Duration::from_secs(2), async { loop { if let Ok(event) = receiver.recv().await { @@ -346,7 +428,7 @@ async fn await_status( async fn assert_no_status_for( receiver: &Receiver, - peer: bc_components::XID, + peer: XID, stage: PeerStage, window: Duration, ) { @@ -362,7 +444,7 @@ async fn assert_no_status_for( assert!(res.is_err(), "unexpected status event: {stage:?}"); } -async fn read_all(mut stream: crate::InboundByteStream) -> Result, QlError> { +async fn read_all(mut stream: crate::ByteReader) -> Result, QlError> { let mut data = Vec::new(); while let Some(chunk) = stream.next_chunk().await? { data.extend_from_slice(&chunk); @@ -372,10 +454,11 @@ async fn read_all(mut stream: crate::InboundByteStream) -> Result, QlErr fn default_runtime_config() -> RuntimeConfig { RuntimeConfig { - engine: crate::engine::EngineConfig { + fsm: QlFsmConfig { handshake_timeout: Duration::from_millis(300), - stream_ack_timeout: Duration::from_millis(30), - stream_retry_limit: 8, + session_retransmit_timeout: Duration::from_millis(30), + session_keepalive_interval: Duration::ZERO, + session_peer_timeout: Duration::ZERO, ..Default::default() }, ..Default::default() diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 476999ef..2a092dc3 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -1,10 +1,7 @@ use std::time::Duration; use super::*; -use crate::{ - StreamConfig, - wire::stream::{CloseCode, CloseTarget}, -}; +use crate::{CloseCode, CloseTarget}; #[tokio::test(flavor = "current_thread")] async fn open_stream_duplex_happy_path() { @@ -12,8 +9,8 @@ async fn open_stream_duplex_happy_path() { let config = default_runtime_config(); let (platform_a, outbound_a, status_a) = TestPlatform::new(1); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -34,7 +31,6 @@ async fn open_stream_duplex_happy_path() { let inbound = match inbound_b.recv().await.unwrap() { HandlerEvent::Stream(stream) => stream, }; - assert_eq!(inbound.request_head, b"req-head".to_vec()); let mut request = inbound.request; let mut response = inbound.response; @@ -47,15 +43,15 @@ async fn open_stream_duplex_happy_path() { response.finish().await.unwrap(); }); - let mut stream = handle_a - .open_stream(b"req-head".to_vec(), StreamConfig::default()) - .await - .unwrap(); + let mut stream = handle_a.open_stream().await.unwrap(); stream.request.write_all(&[1, 2]).await.unwrap(); assert_eq!(stream.response.next_chunk().await.unwrap(), Some(vec![9])); stream.request.write_all(&[3, 4]).await.unwrap(); stream.request.finish().await.unwrap(); - assert_eq!(stream.response.next_chunk().await.unwrap(), Some(vec![8, 7])); + assert_eq!( + stream.response.next_chunk().await.unwrap(), + Some(vec![8, 7]) + ); assert_eq!(stream.response.next_chunk().await.unwrap(), None); tokio::time::timeout(Duration::from_secs(2), responder) @@ -77,8 +73,8 @@ async fn stream_backpressure_with_small_runtime_buffer() { let (platform_a, outbound_a, status_a) = TestPlatform::new(1); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (done_tx, done_rx) = async_channel::bounded(1); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -105,10 +101,7 @@ async fn stream_backpressure_with_small_runtime_buffer() { done_tx.send(request_data).await.unwrap(); }); - let mut stream = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); + let mut stream = handle_a.open_stream().await.unwrap(); stream.request.write_all(&payload).await.unwrap(); stream.request.finish().await.unwrap(); assert_eq!(stream.response.next_chunk().await.unwrap(), None); @@ -128,13 +121,13 @@ async fn stream_backpressure_with_small_runtime_buffer() { } #[tokio::test(flavor = "current_thread")] -async fn dropping_responder_rejects_as_unhandled() { +async fn dropping_responder_closes_initiator_response() { run_local_test(async { let config = default_runtime_config(); let (platform_a, outbound_a, status_a) = TestPlatform::new(1); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -158,10 +151,7 @@ async fn dropping_responder_rejects_as_unhandled() { drop(stream.response); }); - let mut stream = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); + let mut stream = handle_a.open_stream().await.unwrap(); stream.request.finish().await.unwrap(); let err = stream.response.next_chunk().await.unwrap_err(); @@ -191,8 +181,8 @@ async fn dropping_inbound_reader_cancels_remote_writer() { }; let (platform_a, outbound_a, status_a) = TestPlatform::new(1); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (go_tx, go_rx) = async_channel::bounded(1); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -223,12 +213,12 @@ async fn dropping_inbound_reader_cancels_remote_writer() { assert!(matches!(err, QlError::Cancelled)); }); - let mut stream = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); + let mut stream = handle_a.open_stream().await.unwrap(); stream.request.finish().await.unwrap(); - assert_eq!(stream.response.next_chunk().await.unwrap(), Some(vec![1, 2, 3, 4])); + assert_eq!( + stream.response.next_chunk().await.unwrap(), + Some(vec![1, 2, 3, 4]) + ); drop(stream.response); go_tx.send(()).await.unwrap(); @@ -251,8 +241,8 @@ async fn max_concurrent_message_writes_is_respected() { let (platform_a, outbound_a, status_a) = TestPlatform::new_with_delayed_writes(1, Duration::from_millis(40), stats.clone()); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -274,10 +264,8 @@ async fn max_concurrent_message_writes_is_respected() { let stream = match inbound_b.recv().await.unwrap() { HandlerEvent::Stream(stream) => stream, }; - let request = stream.request; - let response = stream.response; - let _ = read_all(request).await; - let _ = response.finish().await; + let _ = read_all(stream.request).await; + let _ = stream.response.finish().await; } }); @@ -285,10 +273,7 @@ async fn max_concurrent_message_writes_is_respected() { for i in 0..4u8 { let handle = handle_a.clone(); tasks.push(tokio::task::spawn_local(async move { - let mut stream = handle - .open_stream(vec![i], StreamConfig::default()) - .await - .unwrap(); + let mut stream = handle.open_stream().await.unwrap(); stream.request.write_all(&[i; 8]).await.unwrap(); stream.request.finish().await.unwrap(); assert_eq!(stream.response.next_chunk().await.unwrap(), None); @@ -317,21 +302,20 @@ async fn max_concurrent_message_writes_is_respected() { } #[tokio::test(flavor = "current_thread")] -async fn stream_round_trip_survives_packet_drops() { +async fn stream_round_trip_survives_encrypted_packet_drops() { run_local_test(async { let config = RuntimeConfig { - engine: crate::engine::EngineConfig { - stream_retry_limit: 12, - stream_ack_timeout: Duration::from_millis(20), - ..default_runtime_config().engine + fsm: QlFsmConfig { + session_retransmit_timeout: Duration::from_millis(20), + ..default_runtime_config().fsm }, stream_send_buffer_bytes: 4, ..default_runtime_config() }; let (platform_a, outbound_a, status_a) = TestPlatform::new(1); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let request_payload: Vec = (0..32).collect(); let response_payload: Vec = (100..132).collect(); @@ -343,8 +327,8 @@ async fn stream_round_trip_survives_packet_drops() { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_drop_every_nth_stream_forwarder(outbound_a, handle_b.clone(), 3); - spawn_drop_every_nth_stream_forwarder(outbound_b, handle_a.clone(), 3); + spawn_drop_every_nth_encrypted_forwarder(outbound_a, handle_b.clone(), 3); + spawn_drop_every_nth_encrypted_forwarder(outbound_b, handle_a.clone(), 3); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect().unwrap(); @@ -363,10 +347,7 @@ async fn stream_round_trip_survives_packet_drops() { received_request }); - let mut stream = handle_a - .open_stream(Vec::new(), StreamConfig::default()) - .await - .unwrap(); + let mut stream = handle_a.open_stream().await.unwrap(); stream.request.write_all(&request_payload).await.unwrap(); stream.request.finish().await.unwrap(); diff --git a/ql-runtime/src/tests/unpair.rs b/ql-runtime/src/tests/unpair.rs index e73be578..24791416 100644 --- a/ql-runtime/src/tests/unpair.rs +++ b/ql-runtime/src/tests/unpair.rs @@ -1,13 +1,13 @@ use super::*; #[tokio::test(flavor = "current_thread")] -async fn unpair_aborts_active_stream_and_clears_peer() { +async fn unpair_clears_remote_peer_and_aborts_active_stream() { run_local_test(async { let config = default_runtime_config(); let (platform_a, outbound_a, status_a) = TestPlatform::new(1); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(); - let identity_b = new_identity(); + let identity_a = new_identity(11); + let identity_b = new_identity(73); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -29,48 +29,31 @@ async fn unpair_aborts_active_stream_and_clears_peer() { HandlerEvent::Stream(stream) => stream, }; let mut request = stream.request; - let _response = stream.response; - let first = request.next_chunk().await; - assert!(matches!(first, Ok(Some(_)) | Ok(None) | Err(_))); + let _ = request.next_chunk().await; let second = request.next_chunk().await; - assert!(matches!( - second, - Ok(None) - | Err(QlError::Cancelled) - | Err(QlError::SendFailed) - | Err(QlError::StreamClosed { .. }) - | Err(QlError::StreamProtocol) - )); + assert!(matches!(second, Ok(None) | Err(QlError::Cancelled))); }); - let mut stream = handle_a - .open_stream(Vec::new(), crate::StreamConfig::default()) - .await - .unwrap(); + let mut stream = handle_a.open_stream().await.unwrap(); stream.request.write_all(&[1, 2, 3, 4]).await.unwrap(); handle_a.unpair().unwrap(); - await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; - await_status(&status_b, identity_a.xid, PeerStage::Disconnected).await; - - let write_err = stream.request.write_all(&[5, 6, 7, 8]).await.unwrap_err(); - assert!(matches!(write_err, QlError::Cancelled)); - - let open_err_a = handle_a - .open_stream(Vec::new(), crate::StreamConfig::default()) - .await; - let open_err_b = handle_b - .open_stream(Vec::new(), crate::StreamConfig::default()) - .await; - - assert!(matches!(open_err_a, Err(QlError::NoPeerBound))); - assert!(matches!(open_err_b, Err(QlError::NoPeerBound))); - tokio::time::timeout(std::time::Duration::from_secs(2), responder) .await .unwrap() .unwrap(); + + let open_err_b = tokio::time::timeout(std::time::Duration::from_secs(2), async { + loop { + match handle_b.open_stream().await { + Err(QlError::NoPeerBound) => return, + _ => tokio::time::sleep(std::time::Duration::from_millis(10)).await, + } + } + }) + .await; + assert!(open_err_b.is_ok(), "remote peer was not cleared"); }) .await; } diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml index eb1a76a7..cd579826 100644 --- a/ql-wire/Cargo.toml +++ b/ql-wire/Cargo.toml @@ -6,15 +6,11 @@ description = "Quantum Link wire format types and crypto helpers" license = "Proprietary" [dependencies] -bc-components = { version = "0.28.0", default-features = false, features = [ - "pqcrypto", -] } -chacha20poly1305 = { version = "0.10.1" } -rkyv = { version = "0.8", default-features = false, features = [ - "std", - "bytecheck", - "little_endian", - "unaligned", - "pointer_width_32", -] } +libcrux-ml-dsa = { version = "0.0.7", default-features = false, features = ["std", "mldsa87"] } +libcrux-ml-kem = { version = "0.0.7", default-features = false, features = ["std", "mlkem1024"] } thiserror = { version = "2" } +zerocopy = { version = "0.8", features = ["derive"] } + +[dev-dependencies] +libcrux-aesgcm = "0.0.7" +sha2 = "0.10" diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index ea1d76f0..63eb81d9 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -1,280 +1,75 @@ -use bc_components::{ - MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey, MLDSA, MLKEM, +use zerocopy::{ + byte_slice::{ByteSlice, SplitByteSlice}, + byteorder::little_endian, + FromBytes, Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, }; -use rkyv::{ - rancor::{Fallible, Source}, - with::{ArchiveWith, DeserializeWith, SerializeWith}, - Archive, Archived, Deserialize, Place, Resolver, Serialize, -}; - -use crate::WireError; - -macro_rules! impl_wire_wrapper { - ($marker:ident, $external:ty, $wire:ty) => { - pub(crate) struct $marker; - - impl ArchiveWith<$external> for $marker { - type Archived = Archived<$wire>; - type Resolver = Resolver<$wire>; - - fn resolve_with( - field: &$external, - resolver: Self::Resolver, - out: Place, - ) { - <$wire>::from(field).resolve(resolver, out); - } - } - - impl SerializeWith<$external, S> for $marker - where - S: Fallible + ?Sized, - $wire: Serialize, - { - fn serialize_with( - field: &$external, - serializer: &mut S, - ) -> Result { - <$wire>::from(field).serialize(serializer) - } - } - - impl DeserializeWith, $external, D> for $marker - where - D: Fallible + ?Sized, - D::Error: Source, - Archived<$wire>: Deserialize<$wire, D>, - $wire: TryInto<$external, Error = WireError>, - { - fn deserialize_with( - field: &Archived<$wire>, - deserializer: &mut D, - ) -> Result<$external, D::Error> { - field - .deserialize(deserializer)? - .try_into() - .map_err(D::Error::new) - } - } - }; -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(u8)] -pub(crate) enum WireMlDsaLevel { - MlDsa44 = 2, - MlDsa65 = 3, - MlDsa87 = 5, -} - -impl TryFrom for MLDSA { - type Error = WireError; - - fn try_from(value: WireMlDsaLevel) -> Result { - Ok(match value { - WireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, - WireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, - WireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, - }) - } -} - -impl From for WireMlDsaLevel { - fn from(value: MLDSA) -> Self { - match value { - MLDSA::MLDSA44 => Self::MlDsa44, - MLDSA::MLDSA65 => Self::MlDsa65, - MLDSA::MLDSA87 => Self::MlDsa87, - } - } -} -impl From<&ArchivedWireMlDsaLevel> for MLDSA { - fn from(value: &ArchivedWireMlDsaLevel) -> Self { - match value { - ArchivedWireMlDsaLevel::MlDsa44 => MLDSA::MLDSA44, - ArchivedWireMlDsaLevel::MlDsa65 => MLDSA::MLDSA65, - ArchivedWireMlDsaLevel::MlDsa87 => MLDSA::MLDSA87, - } - } -} +use crate::{QlHeader, WireError}; -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(u8)] -pub(crate) enum WireMlKemLevel { - MlKem512 = 1, - MlKem768 = 2, - MlKem1024 = 3, -} +pub type U16Le = little_endian::U16; +pub type U32Le = little_endian::U32; +pub type U64Le = little_endian::U64; -impl TryFrom for MLKEM { - type Error = WireError; - - fn try_from(value: WireMlKemLevel) -> Result { - Ok(match value { - WireMlKemLevel::MlKem512 => MLKEM::MLKEM512, - WireMlKemLevel::MlKem768 => MLKEM::MLKEM768, - WireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, - }) - } +pub fn push_value(out: &mut Vec, value: &T) +where + T: IntoBytes + Immutable + ?Sized, +{ + out.extend_from_slice(value.as_bytes()); } -impl From for WireMlKemLevel { - fn from(value: MLKEM) -> Self { - match value { - MLKEM::MLKEM512 => Self::MlKem512, - MLKEM::MLKEM768 => Self::MlKem768, - MLKEM::MLKEM1024 => Self::MlKem1024, - } - } +pub fn read_exact(bytes: &[u8]) -> Result +where + T: FromBytes + Copy, +{ + T::read_from_bytes(bytes).map_err(|_| WireError::InvalidPayload) } -impl From<&ArchivedWireMlKemLevel> for MLKEM { - fn from(value: &ArchivedWireMlKemLevel) -> Self { - match value { - ArchivedWireMlKemLevel::MlKem512 => MLKEM::MLKEM512, - ArchivedWireMlKemLevel::MlKem768 => MLKEM::MLKEM768, - ArchivedWireMlKemLevel::MlKem1024 => MLKEM::MLKEM1024, - } - } +pub fn read_byte(byte: u8) -> Result +where + T: TryFromBytes + Copy, +{ + T::try_read_from_bytes(core::slice::from_ref(&byte)).map_err(|_| WireError::InvalidPayload) } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct WireMlDsaPublicKey { - pub(crate) level: WireMlDsaLevel, - pub(crate) bytes: Vec, +pub fn read_prefix(bytes: B) -> Result<(T, B), WireError> +where + B: SplitByteSlice, + T: FromBytes + KnownLayout + Immutable + Copy, +{ + let (value, rest) = Ref::<_, T>::from_prefix(bytes).map_err(|_| WireError::InvalidPayload)?; + Ok((*value, rest)) } -impl TryFrom for MLDSAPublicKey { - type Error = WireError; - - fn try_from(value: WireMlDsaPublicKey) -> Result { - MLDSAPublicKey::from_bytes(value.level.try_into()?, &value.bytes) - .map_err(|_| WireError::InvalidPayload) - } +pub fn parse(bytes: B) -> Result, WireError> +where + B: ByteSlice, + T: KnownLayout + Immutable + ?Sized, +{ + Ref::<_, T>::from_bytes(bytes).map_err(|_| WireError::InvalidPayload) } -impl From<&MLDSAPublicKey> for WireMlDsaPublicKey { - fn from(value: &MLDSAPublicKey) -> Self { - Self { - level: value.level().into(), - bytes: value.as_bytes().to_vec(), - } +pub fn ensure_empty(bytes: &[u8]) -> Result<(), WireError> { + if bytes.is_empty() { + Ok(()) + } else { + Err(WireError::InvalidPayload) } } -impl TryFrom<&ArchivedWireMlDsaPublicKey> for MLDSAPublicKey { - type Error = WireError; - - fn try_from(value: &ArchivedWireMlDsaPublicKey) -> Result { - MLDSAPublicKey::from_bytes((&value.level).into(), value.bytes.as_slice()) - .map_err(|_| WireError::InvalidPayload) - } -} - -impl_wire_wrapper!(AsWireMlDsaPublicKey, MLDSAPublicKey, WireMlDsaPublicKey); - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct WireMlDsaSignature { - pub(crate) level: WireMlDsaLevel, - pub(crate) bytes: Vec, +pub fn append_field(out: &mut Vec, label: &[u8], value: &[u8]) { + append_framed_bytes(out, label); + append_framed_bytes(out, value); } -impl TryFrom for MLDSASignature { - type Error = WireError; - - fn try_from(value: WireMlDsaSignature) -> Result { - MLDSASignature::from_bytes(value.level.try_into()?, &value.bytes) - .map_err(|_| WireError::InvalidPayload) - } +pub fn append_framed_bytes(out: &mut Vec, value: &[u8]) { + out.extend_from_slice(&u64::try_from(value.len()).unwrap().to_le_bytes()); + out.extend_from_slice(value); } -impl From<&MLDSASignature> for WireMlDsaSignature { - fn from(value: &MLDSASignature) -> Self { - Self { - level: value.level().into(), - bytes: value.as_bytes().to_vec(), - } - } +pub fn header_aad(header: &QlHeader) -> Vec { + let mut aad = Vec::new(); + append_field(&mut aad, b"domain", b"ql-wire:header-aad:v1"); + append_field(&mut aad, b"sender", &header.sender.0); + append_field(&mut aad, b"recipient", &header.recipient.0); + aad } - -impl TryFrom<&ArchivedWireMlDsaSignature> for MLDSASignature { - type Error = WireError; - - fn try_from(value: &ArchivedWireMlDsaSignature) -> Result { - MLDSASignature::from_bytes((&value.level).into(), value.bytes.as_slice()) - .map_err(|_| WireError::InvalidPayload) - } -} - -impl_wire_wrapper!(AsWireMlDsaSignature, MLDSASignature, WireMlDsaSignature); - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct WireMlKemPublicKey { - pub(crate) level: WireMlKemLevel, - pub(crate) bytes: Vec, -} - -impl TryFrom for MLKEMPublicKey { - type Error = WireError; - - fn try_from(value: WireMlKemPublicKey) -> Result { - MLKEMPublicKey::from_bytes(value.level.try_into()?, &value.bytes) - .map_err(|_| WireError::InvalidPayload) - } -} - -impl From<&MLKEMPublicKey> for WireMlKemPublicKey { - fn from(value: &MLKEMPublicKey) -> Self { - Self { - level: value.level().into(), - bytes: value.as_bytes().to_vec(), - } - } -} - -impl TryFrom<&ArchivedWireMlKemPublicKey> for MLKEMPublicKey { - type Error = WireError; - - fn try_from(value: &ArchivedWireMlKemPublicKey) -> Result { - MLKEMPublicKey::from_bytes((&value.level).into(), value.bytes.as_slice()) - .map_err(|_| WireError::InvalidPayload) - } -} - -impl_wire_wrapper!(AsWireMlKemPublicKey, MLKEMPublicKey, WireMlKemPublicKey); - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct WireMlKemCiphertext { - pub(crate) level: WireMlKemLevel, - pub(crate) bytes: Vec, -} - -impl TryFrom for MLKEMCiphertext { - type Error = WireError; - - fn try_from(value: WireMlKemCiphertext) -> Result { - MLKEMCiphertext::from_bytes(value.level.try_into()?, &value.bytes) - .map_err(|_| WireError::InvalidPayload) - } -} - -impl From<&MLKEMCiphertext> for WireMlKemCiphertext { - fn from(value: &MLKEMCiphertext) -> Self { - Self { - level: value.level().into(), - bytes: value.as_bytes().to_vec(), - } - } -} - -impl TryFrom<&ArchivedWireMlKemCiphertext> for MLKEMCiphertext { - type Error = WireError; - - fn try_from(value: &ArchivedWireMlKemCiphertext) -> Result { - MLKEMCiphertext::from_bytes((&value.level).into(), value.bytes.as_slice()) - .map_err(|_| WireError::InvalidPayload) - } -} - -impl_wire_wrapper!(AsWireMlKemCiphertext, MLKEMCiphertext, WireMlKemCiphertext); diff --git a/ql-wire/src/control.rs b/ql-wire/src/control.rs new file mode 100644 index 00000000..b0519c9c --- /dev/null +++ b/ql-wire/src/control.rs @@ -0,0 +1,47 @@ +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +use crate::{ + codec::{U32Le, U64Le}, + WireError, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct ControlId(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ControlMeta { + pub control_id: ControlId, + pub valid_until: u64, +} + +impl ControlMeta { + pub fn ensure_not_expired(&self, now_seconds: u64) -> Result<(), WireError> { + if now_seconds > self.valid_until { + Err(WireError::Expired) + } else { + Ok(()) + } + } +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub(crate) struct ControlMetaWire { + pub(crate) control_id: U32Le, + pub(crate) valid_until: U64Le, +} + +pub(crate) fn control_meta_to_wire(meta: &ControlMeta) -> ControlMetaWire { + ControlMetaWire { + control_id: U32Le::new(meta.control_id.0), + valid_until: U64Le::new(meta.valid_until), + } +} + +pub(crate) fn control_meta_from_wire(meta: ControlMetaWire) -> ControlMeta { + ControlMeta { + control_id: ControlId(meta.control_id.get()), + valid_until: meta.valid_until.get(), + } +} diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs new file mode 100644 index 00000000..83dc9ef5 --- /dev/null +++ b/ql-wire/src/encrypted/close.rs @@ -0,0 +1,34 @@ +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +use super::CloseCode; +use crate::{ + codec::{push_value, read_exact, U16Le}, + WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionCloseBody { + pub code: CloseCode, +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct SessionCloseBodyWire { + code: U16Le, +} + +impl SessionCloseBody { + pub(crate) fn encode_into(&self, out: &mut Vec) { + let wire = SessionCloseBodyWire { + code: U16Le::new(self.code.0), + }; + push_value(out, &wire); + } + + pub(crate) fn decode(bytes: &[u8]) -> Result { + let wire: SessionCloseBodyWire = read_exact(bytes)?; + Ok(Self { + code: CloseCode(wire.code.get()), + }) + } +} diff --git a/ql-wire/src/encrypted/close/mod.rs b/ql-wire/src/encrypted/close/mod.rs deleted file mode 100644 index 02c75aaa..00000000 --- a/ql-wire/src/encrypted/close/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -use rkyv::{Archive, Deserialize, Serialize}; - -use crate::encrypted::stream::CloseCode; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct SessionCloseBody { - pub code: CloseCode, -} diff --git a/ql-wire/src/encrypted/heartbeat/mod.rs b/ql-wire/src/encrypted/heartbeat/mod.rs deleted file mode 100644 index 2bdd15f9..00000000 --- a/ql-wire/src/encrypted/heartbeat/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -use rkyv::{Archive, Deserialize, Serialize}; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)] -pub struct HeartbeatBody; diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 45680ed7..4d7846b2 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,25 +1,42 @@ -use bc_components::SymmetricKey; -use rkyv::{Archive, Deserialize, Serialize}; +use zerocopy::{ + byte_slice::{ByteSlice, ByteSliceMut}, + FromBytes, Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, Unaligned, +}; use crate::{ - access_value, deserialize_value, encode_value, - encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage}, - Nonce, QlHeader, QlPayload, QlRecord, SessionSeq, WireError, + codec::{parse, push_value, U64Le}, + encrypted_message::{EncryptedMessage, EncryptedMessageRef}, + Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, SessionKey, WireError, }; -pub mod close; -pub mod heartbeat; -pub mod stream; -pub mod unpair; +mod close; +mod ping; +mod stream_chunk; +mod stream_close; +mod unpair; + +pub use close::*; +pub use ping::*; +pub use stream_chunk::*; +pub use stream_close::*; +pub use unpair::*; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct SessionSeq(pub u64); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct StreamId(pub u32); -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionEnvelope { pub seq: SessionSeq, pub ack: SessionAck, pub body: SessionBody, } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SessionAck { pub base: SessionSeq, pub bitmap: u64, @@ -32,37 +49,172 @@ impl SessionAck { }; } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionBody { - Heartbeat(heartbeat::HeartbeatBody), + Ack, + Ping(ping::PingBody), Unpair(unpair::UnpairBody), - Stream(stream::StreamFrame), - StreamClose(stream::StreamCloseFrame), + Stream(StreamChunk), + StreamClose(StreamClose), Close(close::SessionCloseBody), } +pub enum SessionBodyRef { + Ack, + Ping, + Unpair, + Stream(StreamChunkRef), + StreamClose(StreamCloseRef), + Close(close::SessionCloseBody), +} + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, TryFromBytes, KnownLayout, Immutable, IntoBytes, Unaligned, +)] +#[repr(u8)] +enum SessionBodyKind { + Ack = 1, + Ping = 2, + Unpair = 3, + Stream = 4, + StreamClose = 5, + Close = 6, +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct SessionEnvelopeWire { + pub seq: U64Le, + pub ack_base: U64Le, + pub ack_bitmap: U64Le, + pub kind: u8, + pub body: [u8], +} + +pub type SessionEnvelopeRef = Ref; + +impl SessionEnvelopeWire { + pub fn parse(bytes: B) -> Result, WireError> { + parse(bytes) + } + + fn body_kind(&self) -> Result { + crate::codec::read_byte(self.kind) + } + + pub fn to_session_envelope(&self) -> Result { + let body = match parse_session_body(self.body_kind()?, &self.body)? { + SessionBodyRef::Ack => SessionBody::Ack, + SessionBodyRef::Ping => SessionBody::Ping(ping::PingBody), + SessionBodyRef::Unpair => SessionBody::Unpair(unpair::UnpairBody), + SessionBodyRef::Stream(frame) => SessionBody::Stream(frame.to_stream_chunk()?), + SessionBodyRef::StreamClose(frame) => { + SessionBody::StreamClose(frame.to_stream_close()?) + } + SessionBodyRef::Close(body) => SessionBody::Close(body), + }; + Ok(SessionEnvelope { + seq: SessionSeq(self.seq.get()), + ack: SessionAck { + base: SessionSeq(self.ack_base.get()), + bitmap: self.ack_bitmap.get(), + }, + body, + }) + } +} + +fn parse_session_body( + kind: SessionBodyKind, + body: B, +) -> Result, WireError> { + match kind { + SessionBodyKind::Ack => { + crate::codec::ensure_empty(&body)?; + Ok(SessionBodyRef::Ack) + } + SessionBodyKind::Ping => { + crate::codec::ensure_empty(&body)?; + Ok(SessionBodyRef::Ping) + } + SessionBodyKind::Unpair => { + crate::codec::ensure_empty(&body)?; + Ok(SessionBodyRef::Unpair) + } + SessionBodyKind::Stream => Ok(SessionBodyRef::Stream(StreamChunkWire::parse(body)?)), + SessionBodyKind::StreamClose => { + Ok(SessionBodyRef::StreamClose(StreamCloseWire::parse(body)?)) + } + SessionBodyKind::Close => Ok(SessionBodyRef::Close(close::SessionCloseBody::decode( + &body, + )?)), + } +} + +impl SessionEnvelope { + pub fn encode(&self) -> Vec { + let mut out = Vec::new(); + let kind = match &self.body { + SessionBody::Ack => SessionBodyKind::Ack, + SessionBody::Ping(_) => SessionBodyKind::Ping, + SessionBody::Unpair(_) => SessionBodyKind::Unpair, + SessionBody::Stream(_) => SessionBodyKind::Stream, + SessionBody::StreamClose(_) => SessionBodyKind::StreamClose, + SessionBody::Close(_) => SessionBodyKind::Close, + }; + let header = SessionEnvelopeHeaderWire { + seq: U64Le::new(self.seq.0), + ack_base: U64Le::new(self.ack.base.0), + ack_bitmap: U64Le::new(self.ack.bitmap), + kind: kind as u8, + }; + push_value(&mut out, &header); + match &self.body { + SessionBody::Ack | SessionBody::Ping(_) | SessionBody::Unpair(_) => {} + SessionBody::Stream(frame) => frame.encode_into(&mut out), + SessionBody::StreamClose(frame) => frame.encode_into(&mut out), + SessionBody::Close(body) => body.encode_into(&mut out), + } + out + } + + pub fn decode(bytes: &[u8]) -> Result { + SessionEnvelopeWire::parse(bytes)?.to_session_envelope() + } +} + pub fn encrypt_record( + crypto: &impl QlCrypto, header: QlHeader, - session_key: &SymmetricKey, + session_key: &SessionKey, body: &SessionEnvelope, nonce: Nonce, -) -> QlRecord { +) -> Result { let aad = header.aad(); - let body_bytes = encode_value(body); - let encrypted = EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce); - QlRecord { + let body_bytes = body.encode(); + let encrypted = EncryptedMessage::encrypt(crypto, session_key, body_bytes, &aad, nonce)?; + Ok(QlRecord { header, - payload: QlPayload::Encrypted(encrypted), - } + payload: QlPayload::Session(encrypted), + }) } -pub fn decrypt_record( +pub fn decrypt_record<'a, B: ByteSliceMut>( + crypto: &impl QlCrypto, header: &QlHeader, - encrypted: &mut ArchivedEncryptedMessage, - session_key: &SymmetricKey, -) -> Result { + encrypted: &'a mut EncryptedMessageRef, + session_key: &SessionKey, +) -> Result, WireError> { let aad = header.aad(); - let plaintext = encrypted.decrypt(session_key, &aad)?; - let body = access_value::(plaintext)?; - deserialize_value(body) + let plaintext = encrypted.decrypt(crypto, session_key, &aad)?; + SessionEnvelopeWire::parse(plaintext) +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct SessionEnvelopeHeaderWire { + seq: U64Le, + ack_base: U64Le, + ack_bitmap: U64Le, + kind: u8, } diff --git a/ql-wire/src/encrypted/ping.rs b/ql-wire/src/encrypted/ping.rs new file mode 100644 index 00000000..e0dd3fd2 --- /dev/null +++ b/ql-wire/src/encrypted/ping.rs @@ -0,0 +1,2 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct PingBody; diff --git a/ql-wire/src/encrypted/stream/mod.rs b/ql-wire/src/encrypted/stream/mod.rs deleted file mode 100644 index a21712ac..00000000 --- a/ql-wire/src/encrypted/stream/mod.rs +++ /dev/null @@ -1,60 +0,0 @@ -use rkyv::{Archive, Deserialize, Serialize}; - -use crate::StreamId; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct StreamFrame { - pub stream_id: StreamId, - pub offset: u64, - pub bytes: Vec, - pub fin: bool, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct StreamCloseFrame { - pub stream_id: StreamId, - pub target: CloseTarget, - pub code: CloseCode, - pub payload: Vec, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum CloseTarget { - Request = 1, - Response = 2, - Both = 3, -} - -impl From<&ArchivedCloseTarget> for CloseTarget { - fn from(value: &ArchivedCloseTarget) -> Self { - match value { - ArchivedCloseTarget::Request => Self::Request, - ArchivedCloseTarget::Response => Self::Response, - ArchivedCloseTarget::Both => Self::Both, - } - } -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(transparent)] -pub struct CloseCode(pub u16); - -impl CloseCode { - pub const CANCELLED: Self = Self(0); - pub const PROTOCOL: Self = Self(1); - pub const INVALID_DATA: Self = Self(2); - pub const TIMEOUT: Self = Self(3); - - pub const UNKNOWN: Self = Self(16); - pub const UNKNOWN_ROUTE: Self = Self(17); - pub const INVALID_HEAD: Self = Self(18); - pub const BUSY: Self = Self(19); - pub const UNHANDLED: Self = Self(20); -} - -impl From<&ArchivedCloseCode> for CloseCode { - fn from(value: &ArchivedCloseCode) -> Self { - Self(value.0.to_native()) - } -} diff --git a/ql-wire/src/encrypted/stream_chunk.rs b/ql-wire/src/encrypted/stream_chunk.rs new file mode 100644 index 00000000..60c98415 --- /dev/null +++ b/ql-wire/src/encrypted/stream_chunk.rs @@ -0,0 +1,63 @@ +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +use super::StreamId; +use crate::{ + codec::{parse, push_value, U32Le, U64Le}, + WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamChunk { + pub stream_id: StreamId, + pub offset: u64, + pub fin: bool, + pub bytes: Vec, +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct StreamChunkWire { + pub stream_id: U32Le, + pub offset: U64Le, + pub fin: u8, + pub bytes: [u8], +} + +pub type StreamChunkRef = Ref; + +impl StreamChunkWire { + pub fn parse(bytes: B) -> Result, WireError> { + parse(bytes) + } + + pub fn to_stream_chunk(&self) -> Result { + Ok(StreamChunk { + stream_id: StreamId(self.stream_id.get()), + offset: self.offset.get(), + bytes: self.bytes.to_vec(), + fin: crate::codec::read_byte(self.fin)?, + }) + } +} + +impl StreamChunk { + pub(crate) fn encode_into(&self, out: &mut Vec) { + let header = StreamChunkHeaderWire { + stream_id: U32Le::new(self.stream_id.0), + offset: U64Le::new(self.offset), + fin: u8::from(self.fin), + }; + push_value(out, &header); + out.extend_from_slice(&self.bytes); + } +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct StreamChunkHeaderWire { + stream_id: U32Le, + offset: U64Le, + fin: u8, +} diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs new file mode 100644 index 00000000..dd314f47 --- /dev/null +++ b/ql-wire/src/encrypted/stream_close.rs @@ -0,0 +1,97 @@ +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, + Unaligned, +}; + +use super::StreamId; +use crate::{ + codec::{parse, push_value, U16Le, U32Le}, + WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamClose { + pub stream_id: StreamId, + pub target: CloseTarget, + pub code: CloseCode, + pub payload: Vec, +} + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, TryFromBytes, KnownLayout, Immutable, IntoBytes, Unaligned, +)] +#[repr(u8)] +pub enum CloseTarget { + Request = 1, + Response = 2, + Both = 3, +} + +impl CloseTarget { + pub(crate) const fn to_wire(self) -> u8 { + self as u8 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct CloseCode(pub u16); + +impl CloseCode { + pub const CANCELLED: Self = Self(0); + pub const PROTOCOL: Self = Self(1); + pub const INVALID_DATA: Self = Self(2); + pub const TIMEOUT: Self = Self(3); + + pub const UNKNOWN: Self = Self(16); + pub const UNKNOWN_ROUTE: Self = Self(17); + pub const INVALID_HEAD: Self = Self(18); + pub const BUSY: Self = Self(19); + pub const UNHANDLED: Self = Self(20); +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct StreamCloseWire { + pub stream_id: U32Le, + pub target: u8, + pub code: U16Le, + pub payload: [u8], +} + +pub type StreamCloseRef = Ref; + +impl StreamCloseWire { + pub fn parse(bytes: B) -> Result, WireError> { + parse(bytes) + } + + pub fn to_stream_close(&self) -> Result { + Ok(StreamClose { + stream_id: StreamId(self.stream_id.get()), + target: crate::codec::read_byte(self.target)?, + code: CloseCode(self.code.get()), + payload: self.payload.to_vec(), + }) + } +} + +impl StreamClose { + pub(crate) fn encode_into(&self, out: &mut Vec) { + let header = StreamCloseHeaderWire { + stream_id: U32Le::new(self.stream_id.0), + target: self.target.to_wire(), + code: U16Le::new(self.code.0), + }; + push_value(out, &header); + out.extend_from_slice(&self.payload); + } +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct StreamCloseHeaderWire { + stream_id: U32Le, + target: u8, + code: U16Le, +} diff --git a/ql-wire/src/encrypted/unpair.rs b/ql-wire/src/encrypted/unpair.rs new file mode 100644 index 00000000..a638b045 --- /dev/null +++ b/ql-wire/src/encrypted/unpair.rs @@ -0,0 +1,2 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct UnpairBody; diff --git a/ql-wire/src/encrypted/unpair/mod.rs b/ql-wire/src/encrypted/unpair/mod.rs deleted file mode 100644 index 70a65e63..00000000 --- a/ql-wire/src/encrypted/unpair/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -use rkyv::{Archive, Deserialize, Serialize}; - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)] -pub struct UnpairBody; diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index bd29c3e2..1a08bcab 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -1,63 +1,115 @@ -use bc_components::SymmetricKey; -use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit}; -use rkyv::{seal::Seal, vec::ArchivedVec, Archive, Deserialize, Serialize}; +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; -use crate::WireError; +use crate::{ + codec::{parse, push_value}, + Nonce, QlCrypto, SessionKey, WireError, +}; -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Nonce(pub [u8; Self::NONCE_SIZE]); - -impl Nonce { - pub const NONCE_SIZE: usize = 12; +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct EncryptedMessageWire { + pub nonce: [u8; Nonce::SIZE], + pub auth: [u8; EncryptedMessage::AUTH_SIZE], + pub ciphertext: [u8], } -pub const AUTH_SIZE: usize = 16; +pub type EncryptedMessageRef = Ref; -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct EncryptedMessage { - ciphertext: Vec, - nonce: Nonce, - auth: [u8; AUTH_SIZE], + pub nonce: Nonce, + pub auth: [u8; Self::AUTH_SIZE], + pub ciphertext: Vec, +} + +impl EncryptedMessageWire { + pub fn parse(bytes: B) -> Result, WireError> { + parse(bytes) + } + + pub fn to_encrypted_message(&self) -> EncryptedMessage { + EncryptedMessage { + nonce: Nonce(self.nonce), + auth: self.auth, + ciphertext: self.ciphertext.to_vec(), + } + } + + pub fn decrypt<'a>( + &'a mut self, + crypto: &impl QlCrypto, + key: &SessionKey, + aad: &[u8], + ) -> Result<&'a mut [u8], WireError> { + let nonce = Nonce(self.nonce); + let auth = self.auth; + if !crypto.decrypt_with_aead(key, &nonce, aad, &mut self.ciphertext, &auth) { + return Err(WireError::DecryptFailed); + } + Ok(&mut self.ciphertext) + } } impl EncryptedMessage { - pub fn encrypt(key: &SymmetricKey, mut plaintext: Vec, aad: &[u8], nonce: Nonce) -> Self { - let cipher = ChaCha20Poly1305::new(key.data().into()); - let auth = cipher - .encrypt_in_place_detached((&nonce.0).into(), aad, &mut plaintext) - .expect("chacha20poly1305 encryption should succeed"); - Self { - ciphertext: plaintext, + pub const AUTH_SIZE: usize = 16; + + pub fn encode(&self) -> Vec { + let mut out = Vec::with_capacity(Nonce::SIZE + Self::AUTH_SIZE + self.ciphertext.len()); + self.encode_into(&mut out); + out + } + + pub fn decode(bytes: &[u8]) -> Result { + Ok(EncryptedMessageWire::parse(bytes)?.to_encrypted_message()) + } + + pub fn encode_into(&self, out: &mut Vec) { + push_value( + out, + &EncryptedMessageHeaderWire { + nonce: self.nonce.0, + auth: self.auth, + }, + ); + out.extend_from_slice(&self.ciphertext); + } + + pub fn encrypt( + crypto: &impl QlCrypto, + key: &SessionKey, + mut plaintext: Vec, + aad: &[u8], + nonce: Nonce, + ) -> Result { + let auth = crypto + .encrypt_with_aead(key, &nonce, aad, &mut plaintext) + .ok_or(WireError::EncryptFailed)?; + Ok(Self { nonce, - auth: auth.into(), - } + auth, + ciphertext: plaintext, + }) } - pub fn decrypt(&self, key: &SymmetricKey, aad: &[u8]) -> Result, WireError> { - let cipher = ChaCha20Poly1305::new(key.data().into()); + pub fn decrypt( + &self, + crypto: &impl QlCrypto, + key: &SessionKey, + aad: &[u8], + ) -> Result, WireError> { let mut plaintext = self.ciphertext.clone(); - cipher - .decrypt_in_place_detached( - (&self.nonce.0).into(), - aad, - &mut plaintext, - (&self.auth).into(), - ) - .map_err(|_| WireError::InvalidPayload)?; + if !crypto.decrypt_with_aead(key, &self.nonce, aad, &mut plaintext, &self.auth) { + return Err(WireError::DecryptFailed); + } Ok(plaintext) } } -impl ArchivedEncryptedMessage { - pub fn decrypt(&mut self, key: &SymmetricKey, aad: &[u8]) -> Result<&[u8], WireError> { - let cipher = ChaCha20Poly1305::new(key.data().into()); - let nonce = &self.nonce; - let auth = self.auth; - let ciphertext = ArchivedVec::as_slice_seal(Seal::new(&mut self.ciphertext)); - let ciphertext = unsafe { ciphertext.unseal_unchecked() }; - cipher - .decrypt_in_place_detached((&nonce.0).into(), aad, ciphertext, (&auth).into()) - .map_err(|_| WireError::InvalidPayload)?; - Ok(ciphertext) - } +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct EncryptedMessageHeaderWire { + nonce: [u8; Nonce::SIZE], + auth: [u8; EncryptedMessage::AUTH_SIZE], } diff --git a/ql-wire/src/error.rs b/ql-wire/src/error.rs new file mode 100644 index 00000000..2a84ce38 --- /dev/null +++ b/ql-wire/src/error.rs @@ -0,0 +1,17 @@ +use thiserror::Error; + +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum WireError { + #[error("invalid payload")] + InvalidPayload, + #[error("invalid signature")] + InvalidSignature, + #[error("expired")] + Expired, + #[error("signing failed")] + SigningFailed, + #[error("encryption failed")] + EncryptFailed, + #[error("decryption failed")] + DecryptFailed, +} diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs index 46bc850a..f8fbc4bb 100644 --- a/ql-wire/src/handshake/crypto.rs +++ b/ql-wire/src/handshake/crypto.rs @@ -1,77 +1,29 @@ -use bc_components::{ - Digest, MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey, SymmetricKey, -}; -use rkyv::{Archive, Serialize}; +use zerocopy::byte_slice::ByteSliceMut; -use super::{ - verify_signature, ArchivedConfirm, ArchivedHello, ArchivedHelloReply, ArchivedReady, Confirm, - Hello, HelloReply, Ready, ReadyBody, -}; +use super::{verify_signature, Confirm, Hello, HelloReply, Ready, ReadyBody, ReadyRef}; use crate::{ - access_value, deserialize_value, encode_value, encrypted_message::EncryptedMessage, - ensure_not_expired, AsWireMlKemCiphertext, ControlMeta, Nonce, QlCrypto, QlHeader, QlIdentity, - WireError, XID, + pq::ML_KEM_SUITE_TAG, ControlMeta, MlDsaPublicKey, MlKemCiphertext, MlKemPublicKey, Nonce, + QlCrypto, QlHeader, QlIdentity, SessionKey, WireError, XID, }; -#[derive(Archive, Serialize)] -struct HelloProofData { - initiator: XID, - responder: XID, - meta: ControlMeta, - nonce: Nonce, - #[rkyv(with = AsWireMlKemCiphertext)] - kem_ct: bc_components::MLKEMCiphertext, -} - -#[derive(Archive, Serialize)] -struct HandshakeTranscript { - initiator: XID, - responder: XID, - hello_meta: ControlMeta, - initiator_nonce: Nonce, - responder_nonce: Nonce, - reply_meta: ControlMeta, - #[rkyv(with = AsWireMlKemCiphertext)] - initiator_kem_ct: bc_components::MLKEMCiphertext, - #[rkyv(with = AsWireMlKemCiphertext)] - responder_kem_ct: bc_components::MLKEMCiphertext, -} - -#[derive(Archive, Serialize)] -struct ConfirmProofData { - meta: ControlMeta, - transcript: Vec, -} - -#[derive(Archive, Serialize)] -struct SessionKeyMaterial { - initiator_secret: Vec, - responder_secret: Vec, - transcript: Vec, -} - #[derive(Debug, Clone, PartialEq, Eq)] pub struct ResponderSecrets { - pub initiator_secret: SymmetricKey, - pub responder_secret: SymmetricKey, + pub initiator_secret: SessionKey, + pub responder_secret: SessionKey, } pub fn build_hello( - identity: &QlIdentity, crypto: &impl QlCrypto, + identity: &QlIdentity, recipient: XID, - recipient_encapsulation_key: &MLKEMPublicKey, + recipient_encapsulation_key: &MlKemPublicKey, meta: ControlMeta, -) -> Result<(Hello, SymmetricKey), WireError> { +) -> Result<(Hello, SessionKey), WireError> { let nonce = next_nonce(crypto); - let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); - let signature = identity.signing_private_key.sign(hello_proof_data( - identity.xid, - recipient, - &meta, - &nonce, - &kem_ct, - )); + let (session_key, kem_ct) = + recipient_encapsulation_key.encapsulate_new_shared_secret(crypto)?; + let proof_data = hash_hello_proof_data(crypto, identity.xid, recipient, &meta, &nonce, &kem_ct); + let signature = identity.signing_private_key.sign(crypto, &proof_data)?; Ok(( Hello { meta, @@ -84,58 +36,68 @@ pub fn build_hello( } pub fn verify_hello( + crypto: &impl QlCrypto, initiator: XID, responder: XID, - initiator_signing_key: &MLDSAPublicKey, - hello: &ArchivedHello, + initiator_signing_key: &MlDsaPublicKey, + hello: &Hello, + now_seconds: u64, ) -> Result<(), WireError> { - let meta: ControlMeta = (&hello.meta).into(); - ensure_not_expired(meta.valid_until)?; - let signature = MLDSASignature::try_from(&hello.signature)?; - let nonce: Nonce = deserialize_value(&hello.nonce)?; - let kem_ct = MLKEMCiphertext::try_from(&hello.kem_ct)?; - let proof_data = hello_proof_data(initiator, responder, &meta, &nonce, &kem_ct); - verify_signature(initiator_signing_key, &signature, &proof_data) + hello.meta.ensure_not_expired(now_seconds)?; + let proof_data = hash_hello_proof_data( + crypto, + initiator, + responder, + &hello.meta, + &hello.nonce, + &hello.kem_ct, + ); + verify_signature(initiator_signing_key, &hello.signature, &proof_data) } pub fn respond_hello( - identity: &QlIdentity, crypto: &impl QlCrypto, + identity: &QlIdentity, initiator: XID, - initiator_signing_key: &MLDSAPublicKey, - initiator_encapsulation_key: &MLKEMPublicKey, - hello: &ArchivedHello, + initiator_signing_key: &MlDsaPublicKey, + initiator_encapsulation_key: &MlKemPublicKey, + hello: &Hello, meta: ControlMeta, + now_seconds: u64, ) -> Result<(HelloReply, ResponderSecrets), WireError> { - verify_hello(initiator, identity.xid, initiator_signing_key, hello)?; - let hello_meta: ControlMeta = (&hello.meta).into(); - let initiator_nonce: Nonce = deserialize_value(&hello.nonce)?; - let initiator_kem_ct = MLKEMCiphertext::try_from(&hello.kem_ct)?; + verify_hello( + crypto, + initiator, + identity.xid, + initiator_signing_key, + hello, + now_seconds, + )?; let initiator_secret = identity .encapsulation_private_key - .decapsulate_shared_secret(&initiator_kem_ct) - .map_err(|_| WireError::InvalidPayload)?; + .decapsulate_shared_secret(&hello.kem_ct)?; let nonce = next_nonce(crypto); - let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(); - let transcript = handshake_transcript( + let (responder_secret, kem_ct) = + initiator_encapsulation_key.encapsulate_new_shared_secret(crypto)?; + let transcript = hash_handshake_transcript( + crypto, initiator, identity.xid, - &hello_meta, - &initiator_nonce, - &initiator_kem_ct, + &hello.meta, + &hello.nonce, + &hello.kem_ct, &meta, &nonce, &kem_ct, ); - let signature = identity.signing_private_key.sign(&transcript); - let reply = HelloReply { - meta, - nonce, - kem_ct, - signature, - }; + let signature = identity.signing_private_key.sign(crypto, &transcript)?; Ok(( - reply, + HelloReply { + meta, + nonce, + kem_ct, + signature, + }, ResponderSecrets { initiator_secret, responder_secret, @@ -144,87 +106,111 @@ pub fn respond_hello( } pub fn build_confirm( + crypto: &impl QlCrypto, identity: &QlIdentity, responder: XID, - responder_signing_key: &MLDSAPublicKey, + responder_signing_key: &MlDsaPublicKey, hello: &Hello, - reply: &ArchivedHelloReply, - initiator_secret: &SymmetricKey, + reply: &HelloReply, + initiator_secret: &SessionKey, meta: ControlMeta, -) -> Result<(Confirm, SymmetricKey), WireError> { - let reply_meta: ControlMeta = (&reply.meta).into(); - ensure_not_expired(reply_meta.valid_until)?; - let reply_nonce: Nonce = deserialize_value(&reply.nonce)?; - let reply_kem_ct = MLKEMCiphertext::try_from(&reply.kem_ct)?; - let reply_signature = MLDSASignature::try_from(&reply.signature)?; - let transcript = handshake_transcript( + now_seconds: u64, +) -> Result<(Confirm, SessionKey), WireError> { + reply.meta.ensure_not_expired(now_seconds)?; + let transcript = hash_handshake_transcript( + crypto, identity.xid, responder, &hello.meta, &hello.nonce, &hello.kem_ct, - &reply_meta, - &reply_nonce, - &reply_kem_ct, + &reply.meta, + &reply.nonce, + &reply.kem_ct, ); - verify_signature(responder_signing_key, &reply_signature, &transcript)?; + verify_signature(responder_signing_key, &reply.signature, &transcript)?; let responder_secret = identity .encapsulation_private_key - .decapsulate_shared_secret(&reply_kem_ct) - .map_err(|_| WireError::InvalidPayload)?; - let signature = identity - .signing_private_key - .sign(confirm_proof_data(&meta, &transcript)); - let confirm = Confirm { meta, signature }; - let session_key = derive_session_key(initiator_secret, &responder_secret, &transcript); - Ok((confirm, session_key)) + .decapsulate_shared_secret(&reply.kem_ct)?; + let proof_data = hash_confirm_proof_data( + crypto, + &meta, + identity.xid, + responder, + &hello.meta, + &hello.nonce, + &hello.kem_ct, + &reply.meta, + &reply.nonce, + &reply.kem_ct, + ); + let signature = identity.signing_private_key.sign(crypto, &proof_data)?; + let session_key = derive_session_key( + crypto, + initiator_secret, + &responder_secret, + identity.xid, + responder, + &hello.meta, + &hello.nonce, + &hello.kem_ct, + &reply.meta, + &reply.nonce, + &reply.kem_ct, + ); + Ok((Confirm { meta, signature }, session_key)) } pub fn finalize_confirm( + crypto: &impl QlCrypto, initiator: XID, responder: XID, - initiator_signing_key: &MLDSAPublicKey, + initiator_signing_key: &MlDsaPublicKey, hello: &Hello, reply: &HelloReply, - confirm: &ArchivedConfirm, + confirm: &Confirm, secrets: &ResponderSecrets, -) -> Result { + now_seconds: u64, +) -> Result { verify_confirm( + crypto, initiator, responder, initiator_signing_key, hello, reply, confirm, + now_seconds, )?; Ok(derive_session_key( + crypto, &secrets.initiator_secret, &secrets.responder_secret, - &handshake_transcript( - initiator, - responder, - &hello.meta, - &hello.nonce, - &hello.kem_ct, - &reply.meta, - &reply.nonce, - &reply.kem_ct, - ), + initiator, + responder, + &hello.meta, + &hello.nonce, + &hello.kem_ct, + &reply.meta, + &reply.nonce, + &reply.kem_ct, )) } pub fn verify_confirm( + crypto: &impl QlCrypto, initiator: XID, responder: XID, - initiator_signing_key: &MLDSAPublicKey, + initiator_signing_key: &MlDsaPublicKey, hello: &Hello, reply: &HelloReply, - confirm: &ArchivedConfirm, + confirm: &Confirm, + now_seconds: u64, ) -> Result<(), WireError> { - let confirm_meta: ControlMeta = (&confirm.meta).into(); - ensure_not_expired(confirm_meta.valid_until)?; - let confirm_signature = MLDSASignature::try_from(&confirm.signature)?; - let transcript = handshake_transcript( + confirm.meta.ensure_not_expired(now_seconds)?; + let proof_data = hash_confirm_proof_data( + crypto, + &confirm.meta, initiator, responder, &hello.meta, @@ -234,98 +220,219 @@ pub fn verify_confirm( &reply.nonce, &reply.kem_ct, ); - let proof_data = confirm_proof_data(&confirm_meta, &transcript); - verify_signature(initiator_signing_key, &confirm_signature, &proof_data)?; - Ok(()) + verify_signature(initiator_signing_key, &confirm.signature, &proof_data) } pub fn build_ready( + crypto: &impl QlCrypto, header: QlHeader, - session_key: &SymmetricKey, + session_key: &SessionKey, meta: ControlMeta, nonce: Nonce, -) -> Ready { +) -> Result { let aad = header.aad(); - let body_bytes = encode_value(&ReadyBody { meta }); - Ready { - encrypted: EncryptedMessage::encrypt(session_key, body_bytes, &aad, nonce), - } + let body_bytes = ReadyBody { meta }.encode(); + Ok(Ready { + encrypted: crate::encrypted_message::EncryptedMessage::encrypt( + crypto, + session_key, + body_bytes, + &aad, + nonce, + )?, + }) } -pub fn decrypt_ready( +pub fn decrypt_ready( + crypto: &impl QlCrypto, header: &QlHeader, - ready: &mut ArchivedReady, - session_key: &SymmetricKey, + ready: &mut ReadyRef, + session_key: &SessionKey, + now_seconds: u64, ) -> Result { let aad = header.aad(); - let plaintext = ready.encrypted.decrypt(session_key, &aad)?; - let body = access_value::(plaintext)?; - let body = deserialize_value(body)?; - ensure_not_expired(body.meta.valid_until)?; + let plaintext = ready.decrypt(crypto, session_key, &aad)?; + let body = ReadyBody::decode(plaintext)?; + body.meta.ensure_not_expired(now_seconds)?; Ok(body) } -fn handshake_transcript( +fn hash_hello_proof_data( + crypto: &impl QlCrypto, + initiator: XID, + responder: XID, + meta: &ControlMeta, + nonce: &Nonce, + kem_ct: &MlKemCiphertext, +) -> [u8; 32] { + let control_id = meta.control_id.0.to_le_bytes(); + let valid_until = meta.valid_until.to_le_bytes(); + crypto.hash(&[ + b"ql-wire:hello-proof:v1", + b"initiator", + &initiator.0, + b"responder", + &responder.0, + b"control-id", + &control_id, + b"valid-until", + &valid_until, + b"nonce", + &nonce.0, + b"kem-suite", + ML_KEM_SUITE_TAG, + b"kem-ct", + kem_ct.as_bytes(), + ]) +} + +fn hash_handshake_transcript( + crypto: &impl QlCrypto, initiator: XID, responder: XID, hello_meta: &ControlMeta, initiator_nonce: &Nonce, - initiator_kem_ct: &bc_components::MLKEMCiphertext, + initiator_kem_ct: &MlKemCiphertext, reply_meta: &ControlMeta, responder_nonce: &Nonce, - responder_kem_ct: &bc_components::MLKEMCiphertext, -) -> Vec { - encode_value(&HandshakeTranscript { - initiator, - responder, - hello_meta: *hello_meta, - initiator_nonce: initiator_nonce.clone(), - responder_nonce: responder_nonce.clone(), - reply_meta: *reply_meta, - initiator_kem_ct: initiator_kem_ct.clone(), - responder_kem_ct: responder_kem_ct.clone(), - }) + responder_kem_ct: &MlKemCiphertext, +) -> [u8; 32] { + let hello_control_id = hello_meta.control_id.0.to_le_bytes(); + let hello_valid_until = hello_meta.valid_until.to_le_bytes(); + let reply_control_id = reply_meta.control_id.0.to_le_bytes(); + let reply_valid_until = reply_meta.valid_until.to_le_bytes(); + crypto.hash(&[ + b"ql-wire:handshake-transcript:v1", + b"initiator", + &initiator.0, + b"responder", + &responder.0, + b"hello-control-id", + &hello_control_id, + b"hello-valid-until", + &hello_valid_until, + b"initiator-nonce", + &initiator_nonce.0, + b"initiator-kem-suite", + ML_KEM_SUITE_TAG, + b"initiator-kem-ct", + initiator_kem_ct.as_bytes(), + b"reply-control-id", + &reply_control_id, + b"reply-valid-until", + &reply_valid_until, + b"responder-nonce", + &responder_nonce.0, + b"responder-kem-suite", + ML_KEM_SUITE_TAG, + b"responder-kem-ct", + responder_kem_ct.as_bytes(), + ]) } -fn hello_proof_data( +fn hash_confirm_proof_data( + crypto: &impl QlCrypto, + confirm_meta: &ControlMeta, initiator: XID, responder: XID, - meta: &ControlMeta, - nonce: &Nonce, - kem_ct: &bc_components::MLKEMCiphertext, -) -> Vec { - encode_value(&HelloProofData { - initiator, - responder, - meta: *meta, - nonce: nonce.clone(), - kem_ct: kem_ct.clone(), - }) -} - -fn confirm_proof_data(meta: &ControlMeta, transcript: &[u8]) -> Vec { - encode_value(&ConfirmProofData { - meta: *meta, - transcript: transcript.to_vec(), - }) + hello_meta: &ControlMeta, + initiator_nonce: &Nonce, + initiator_kem_ct: &MlKemCiphertext, + reply_meta: &ControlMeta, + responder_nonce: &Nonce, + responder_kem_ct: &MlKemCiphertext, +) -> [u8; 32] { + let confirm_control_id = confirm_meta.control_id.0.to_le_bytes(); + let confirm_valid_until = confirm_meta.valid_until.to_le_bytes(); + let hello_control_id = hello_meta.control_id.0.to_le_bytes(); + let hello_valid_until = hello_meta.valid_until.to_le_bytes(); + let reply_control_id = reply_meta.control_id.0.to_le_bytes(); + let reply_valid_until = reply_meta.valid_until.to_le_bytes(); + crypto.hash(&[ + b"ql-wire:confirm-proof:v1", + b"confirm-control-id", + &confirm_control_id, + b"confirm-valid-until", + &confirm_valid_until, + b"initiator", + &initiator.0, + b"responder", + &responder.0, + b"hello-control-id", + &hello_control_id, + b"hello-valid-until", + &hello_valid_until, + b"initiator-nonce", + &initiator_nonce.0, + b"initiator-kem-suite", + ML_KEM_SUITE_TAG, + b"initiator-kem-ct", + initiator_kem_ct.as_bytes(), + b"reply-control-id", + &reply_control_id, + b"reply-valid-until", + &reply_valid_until, + b"responder-nonce", + &responder_nonce.0, + b"responder-kem-suite", + ML_KEM_SUITE_TAG, + b"responder-kem-ct", + responder_kem_ct.as_bytes(), + ]) } -fn next_nonce(platform: &impl QlCrypto) -> Nonce { - let mut data = [0u8; Nonce::NONCE_SIZE]; - platform.fill_random_bytes(&mut data); +fn next_nonce(crypto: &impl QlCrypto) -> Nonce { + let mut data = [0u8; Nonce::SIZE]; + crypto.fill_random_bytes(&mut data); Nonce(data) } fn derive_session_key( - initiator_secret: &SymmetricKey, - responder_secret: &SymmetricKey, - transcript: &[u8], -) -> SymmetricKey { - let payload = encode_value(&SessionKeyMaterial { - initiator_secret: initiator_secret.as_bytes().to_vec(), - responder_secret: responder_secret.as_bytes().to_vec(), - transcript: transcript.to_vec(), - }); - let digest = Digest::from_image(payload); - SymmetricKey::from_data(*digest.data()) + crypto: &impl QlCrypto, + initiator_secret: &SessionKey, + responder_secret: &SessionKey, + initiator: XID, + responder: XID, + hello_meta: &ControlMeta, + initiator_nonce: &Nonce, + initiator_kem_ct: &MlKemCiphertext, + reply_meta: &ControlMeta, + responder_nonce: &Nonce, + responder_kem_ct: &MlKemCiphertext, +) -> SessionKey { + let hello_control_id = hello_meta.control_id.0.to_le_bytes(); + let hello_valid_until = hello_meta.valid_until.to_le_bytes(); + let reply_control_id = reply_meta.control_id.0.to_le_bytes(); + let reply_valid_until = reply_meta.valid_until.to_le_bytes(); + SessionKey::from_data(crypto.hash(&[ + b"ql-wire:session-key:v1", + b"initiator-secret", + initiator_secret.as_bytes(), + b"responder-secret", + responder_secret.as_bytes(), + b"initiator", + &initiator.0, + b"responder", + &responder.0, + b"hello-control-id", + &hello_control_id, + b"hello-valid-until", + &hello_valid_until, + b"initiator-nonce", + &initiator_nonce.0, + b"initiator-kem-suite", + ML_KEM_SUITE_TAG, + b"initiator-kem-ct", + initiator_kem_ct.as_bytes(), + b"reply-control-id", + &reply_control_id, + b"reply-valid-until", + &reply_valid_until, + b"responder-nonce", + &responder_nonce.0, + b"responder-kem-suite", + ML_KEM_SUITE_TAG, + b"responder-kem-ct", + responder_kem_ct.as_bytes(), + ])) } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index c40378ae..33adbae7 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,66 +1,155 @@ -use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext}; -use rkyv::{Archive, Deserialize, Serialize}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; use crate::{ - encrypted_message::EncryptedMessage, AsWireMlDsaSignature, AsWireMlKemCiphertext, ControlMeta, - Nonce, WireError, + codec::{push_value, read_exact}, + control::{control_meta_from_wire, control_meta_to_wire, ControlMetaWire}, + encrypted_message::{EncryptedMessage, EncryptedMessageRef}, + ControlMeta, MlDsaPublicKey, MlDsaSignature, MlKemCiphertext, Nonce, WireError, }; mod crypto; pub use crypto::*; -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum HandshakeRecord { - Hello(Hello), - HelloReply(HelloReply), - Confirm(Confirm), - Ready(Ready), -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Hello { pub meta: ControlMeta, pub nonce: Nonce, - #[rkyv(with = AsWireMlKemCiphertext)] - pub kem_ct: MLKEMCiphertext, - #[rkyv(with = AsWireMlDsaSignature)] - pub signature: MLDSASignature, + pub kem_ct: MlKemCiphertext, + pub signature: MlDsaSignature, } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct HelloReply { pub meta: ControlMeta, pub nonce: Nonce, - #[rkyv(with = AsWireMlKemCiphertext)] - pub kem_ct: MLKEMCiphertext, - #[rkyv(with = AsWireMlDsaSignature)] - pub signature: MLDSASignature, + pub kem_ct: MlKemCiphertext, + pub signature: MlDsaSignature, } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Confirm { pub meta: ControlMeta, - #[rkyv(with = AsWireMlDsaSignature)] - pub signature: MLDSASignature, + pub signature: MlDsaSignature, } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Ready { pub encrypted: EncryptedMessage, } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ReadyBody { pub meta: ControlMeta, } +pub type ReadyRef = EncryptedMessageRef; + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct HelloWire { + meta: ControlMetaWire, + nonce: [u8; Nonce::SIZE], + kem_ct: [u8; MlKemCiphertext::SIZE], + signature: [u8; MlDsaSignature::SIZE], +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct ConfirmWire { + meta: ControlMetaWire, + signature: [u8; MlDsaSignature::SIZE], +} + +impl Hello { + pub(crate) fn encode_into(&self, out: &mut Vec) { + let wire = HelloWire { + meta: control_meta_to_wire(&self.meta), + nonce: self.nonce.0, + kem_ct: *self.kem_ct.as_bytes(), + signature: *self.signature.as_bytes(), + }; + push_value(out, &wire); + } + + pub(crate) fn decode(bytes: &[u8]) -> Result { + let wire: HelloWire = read_exact(bytes)?; + Ok(Self { + meta: control_meta_from_wire(wire.meta), + nonce: Nonce(wire.nonce), + kem_ct: MlKemCiphertext::from_data(wire.kem_ct), + signature: MlDsaSignature::from_data(wire.signature), + }) + } +} + +impl HelloReply { + pub(crate) fn encode_into(&self, out: &mut Vec) { + Hello { + meta: self.meta, + nonce: self.nonce, + kem_ct: self.kem_ct, + signature: self.signature, + } + .encode_into(out); + } + + pub(crate) fn decode(bytes: &[u8]) -> Result { + let hello = Hello::decode(bytes)?; + Ok(Self { + meta: hello.meta, + nonce: hello.nonce, + kem_ct: hello.kem_ct, + signature: hello.signature, + }) + } +} + +impl Confirm { + pub(crate) fn encode_into(&self, out: &mut Vec) { + let wire = ConfirmWire { + meta: control_meta_to_wire(&self.meta), + signature: *self.signature.as_bytes(), + }; + push_value(out, &wire); + } + + pub(crate) fn decode(bytes: &[u8]) -> Result { + let wire: ConfirmWire = read_exact(bytes)?; + Ok(Self { + meta: control_meta_from_wire(wire.meta), + signature: MlDsaSignature::from_data(wire.signature), + }) + } +} + +impl Ready { + pub(crate) fn encode_into(&self, out: &mut Vec) { + self.encrypted.encode_into(out); + } +} + +impl ReadyBody { + pub(crate) fn encode(&self) -> Vec { + let wire = control_meta_to_wire(&self.meta); + wire.as_bytes().to_vec() + } + + pub(crate) fn decode(bytes: &[u8]) -> Result { + let wire: ControlMetaWire = read_exact(bytes)?; + Ok(Self { + meta: control_meta_from_wire(wire), + }) + } +} + pub fn verify_signature( - signing_key: &MLDSAPublicKey, - signature: &MLDSASignature, + signing_key: &MlDsaPublicKey, + signature: &MlDsaSignature, proof_data: &[u8], ) -> Result<(), WireError> { - match signing_key.verify(signature, proof_data) { - Ok(true) => Ok(()), - _ => Err(WireError::InvalidSignature), + if signing_key.verify(signature, proof_data) { + Ok(()) + } else { + Err(WireError::InvalidSignature) } } diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs new file mode 100644 index 00000000..7fe1813a --- /dev/null +++ b/ql-wire/src/header.rs @@ -0,0 +1,55 @@ +use zerocopy::{ + byte_slice::SplitByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned, +}; + +use crate::{codec, record::RecordKind, WireError, XID}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct QlHeader { + pub sender: XID, + pub recipient: XID, +} + +impl QlHeader { + pub fn aad(&self) -> Vec { + codec::header_aad(self) + } +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub(crate) struct QlRecordHeaderWire { + pub(crate) kind: u8, + pub(crate) sender: [u8; XID::SIZE], + pub(crate) recipient: [u8; XID::SIZE], +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct DecodedRecordHeader { + pub(crate) kind: RecordKind, + pub(crate) header: QlHeader, +} + +pub(crate) fn encode_record_header(header: &QlHeader, kind: RecordKind) -> QlRecordHeaderWire { + QlRecordHeaderWire { + kind: kind as u8, + sender: header.sender.0, + recipient: header.recipient.0, + } +} + +pub(crate) fn decode_record_header( + bytes: B, +) -> Result<(DecodedRecordHeader, B), WireError> { + let (wire, payload_bytes) = codec::read_prefix::(bytes)?; + Ok(( + DecodedRecordHeader { + kind: codec::read_byte(wire.kind)?, + header: QlHeader { + sender: XID(wire.sender), + recipient: XID(wire.recipient), + }, + }, + payload_bytes, + )) +} diff --git a/ql-wire/src/id.rs b/ql-wire/src/id.rs deleted file mode 100644 index f236514d..00000000 --- a/ql-wire/src/id.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::fmt; - -use rkyv::{Archive, Deserialize, Serialize}; - -macro_rules! define_id { - ($name:ident, $ty:ty) => { - #[derive( - Archive, - Serialize, - Deserialize, - Debug, - Clone, - Copy, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - )] - #[repr(transparent)] - pub struct $name(pub $ty); - - impl fmt::Display for $name { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } - } - }; -} - -define_id!(ControlId, u32); -define_id!(SessionSeq, u64); -define_id!(StreamId, u32); - -impl From<&ArchivedControlId> for ControlId { - fn from(value: &ArchivedControlId) -> Self { - Self(value.0.to_native()) - } -} - -impl From<&ArchivedSessionSeq> for SessionSeq { - fn from(value: &ArchivedSessionSeq) -> Self { - Self(value.0.to_native()) - } -} - -impl From<&ArchivedStreamId> for StreamId { - fn from(value: &ArchivedStreamId) -> Self { - Self(value.0.to_native()) - } -} diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs new file mode 100644 index 00000000..574031cd --- /dev/null +++ b/ql-wire/src/identity.rs @@ -0,0 +1,28 @@ +use crate::{MlDsaPrivateKey, MlDsaPublicKey, MlKemPrivateKey, MlKemPublicKey, XID}; + +#[derive(Debug, Clone)] +pub struct QlIdentity { + pub xid: XID, + pub signing_private_key: MlDsaPrivateKey, + pub signing_public_key: MlDsaPublicKey, + pub encapsulation_private_key: MlKemPrivateKey, + pub encapsulation_public_key: MlKemPublicKey, +} + +impl QlIdentity { + pub fn new( + xid: XID, + signing_private_key: MlDsaPrivateKey, + signing_public_key: MlDsaPublicKey, + encapsulation_private_key: MlKemPrivateKey, + encapsulation_public_key: MlKemPublicKey, + ) -> Self { + Self { + xid, + signing_private_key, + signing_public_key, + encapsulation_private_key, + encapsulation_public_key, + } + } +} diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index ac238055..034cf80d 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -1,228 +1,56 @@ //! quantum link protocol wire format -//! -//! naming conventions: -//! - *Record - unencrypted messages -//! - *Body - message content after decrypting - -use bc_components::{MLDSAPrivateKey, MLDSAPublicKey, MLKEMPrivateKey, MLKEMPublicKey}; -use rkyv::{ - api::{ - high::{to_bytes_in, HighSerializer, HighValidator}, - low::{self, LowDeserializer}, - }, - bytecheck::CheckBytes, - ser::allocator::ArenaHandle, - Archive, Deserialize, Portable, Serialize, -}; -use thiserror::Error; mod codec; -pub mod encrypted; -pub mod encrypted_message; -pub mod handshake; -mod id; -pub mod pair; +mod control; +mod encrypted; +mod encrypted_message; +mod error; +mod handshake; +mod header; +mod identity; +mod nonce; +mod pair; +mod pq; +mod record; mod xid; -pub(crate) use codec::*; -pub use encrypted::{ - close::SessionCloseBody, - stream::{CloseCode, CloseTarget, StreamCloseFrame, StreamFrame}, - SessionAck, SessionBody, SessionEnvelope, -}; -pub use encrypted_message::Nonce; -pub use id::{ControlId, SessionSeq, StreamId}; -pub use xid::XID; - -pub(crate) type WireArchiveError = rkyv::rancor::Error; - -#[derive(Debug, Clone)] -pub struct QlIdentity { - pub xid: XID, - pub signing_private_key: MLDSAPrivateKey, - pub signing_public_key: MLDSAPublicKey, - pub encapsulation_private_key: MLKEMPrivateKey, - pub encapsulation_public_key: MLKEMPublicKey, -} - -impl QlIdentity { - pub fn from_keys( - signing_private_key: MLDSAPrivateKey, - signing_public_key: MLDSAPublicKey, - encapsulation_private_key: MLKEMPrivateKey, - encapsulation_public_key: MLKEMPublicKey, - ) -> Self { - Self { - xid: XID::from_signing_public_key(&signing_public_key), - signing_private_key, - signing_public_key, - encapsulation_private_key, - encapsulation_public_key, - } - } -} +pub use control::*; +pub use encrypted::*; +pub use encrypted_message::*; +pub use error::*; +pub use handshake::*; +pub use header::*; +pub use identity::*; +pub use nonce::*; +pub use pair::*; +pub use pq::*; +pub use record::*; +pub use xid::*; + +pub const QL_WIRE_VERSION: u8 = 1; pub trait QlCrypto { fn fill_random_bytes(&self, data: &mut [u8]); -} - -#[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum WireError { - #[error("invalid payload")] - InvalidPayload, - #[error("invalid signature")] - InvalidSignature, - #[error("expired")] - Expired, -} - -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct QlRecord { - pub header: QlHeader, - pub payload: QlPayload, -} -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct QlHeader { - pub sender: XID, - pub recipient: XID, -} - -impl QlHeader { - pub fn aad(&self) -> Vec { - encode_value(self) - } -} + fn hash(&self, parts: &[&[u8]]) -> [u8; 32]; -#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -pub struct ControlMeta { - pub control_id: ControlId, - pub valid_until: u64, -} - -impl From<&ArchivedControlMeta> for ControlMeta { - fn from(value: &ArchivedControlMeta) -> Self { - Self { - control_id: (&value.control_id).into(), - valid_until: value.valid_until.to_native(), - } - } -} + fn encrypt_with_aead( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]>; -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum QlPayload { - Handshake(handshake::HandshakeRecord), - Pair(pair::PairRequestRecord), - Encrypted(encrypted_message::EncryptedMessage), -} - -pub fn encode_record(record: &QlRecord) -> Vec { - encode_value(record) -} - -pub fn access_record(bytes: &[u8]) -> Result<&ArchivedQlRecord, WireError> { - access_value(bytes) -} - -pub fn access_record_mut( - bytes: &mut [u8], -) -> Result, WireError> { - rkyv::access_mut::(bytes) - .map_err(|_| WireError::InvalidPayload) -} - -pub fn decode_record(bytes: &[u8]) -> Result { - deserialize_value(access_record(bytes)?) -} - -pub(crate) fn encode_value( - value: &impl for<'a> Serialize, ArenaHandle<'a>, WireArchiveError>>, -) -> Vec { - to_bytes_in::<_, WireArchiveError>(value, Vec::new()) - .expect("wire serialization should not fail") -} - -pub(crate) fn access_value(bytes: &[u8]) -> Result<&T, WireError> -where - T: Portable + for<'a> CheckBytes>, -{ - rkyv::access::(bytes).map_err(|_| WireError::InvalidPayload) -} - -pub(crate) fn deserialize_value( - value: &impl rkyv::Deserialize>, -) -> Result { - low::deserialize::(value).map_err(|_| WireError::InvalidPayload) -} - -pub(crate) fn ensure_not_expired(valid_until: u64) -> Result<(), WireError> { - if now_secs() > valid_until { - Err(WireError::Expired) - } else { - Ok(()) - } -} - -pub fn now_secs() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(0) + fn decrypt_with_aead( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + ) -> bool; } #[cfg(test)] -mod tests { - use bc_components::SymmetricKey; - - use super::*; - - struct TestCrypto(std::sync::atomic::AtomicU8); - - impl TestCrypto { - fn new(seed: u8) -> Self { - Self(std::sync::atomic::AtomicU8::new(seed)) - } - } - - impl QlCrypto for TestCrypto { - fn fill_random_bytes(&self, data: &mut [u8]) { - let seed = self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - for (index, byte) in data.iter_mut().enumerate() { - *byte = seed.wrapping_add(index as u8); - } - } - } - - #[test] - fn ql_record_round_trip() { - let header = QlHeader { - sender: XID([1; XID::XID_SIZE]), - recipient: XID([2; XID::XID_SIZE]), - }; - let body = SessionEnvelope { - seq: SessionSeq(7), - ack: SessionAck { - base: SessionSeq(3), - bitmap: 0b101, - }, - body: SessionBody::Heartbeat(encrypted::heartbeat::HeartbeatBody), - }; - let record = encrypted::encrypt_record( - header.clone(), - &SymmetricKey::from_data([7; SymmetricKey::SYMMETRIC_KEY_SIZE]), - &body, - Nonce([8; Nonce::NONCE_SIZE]), - ); - - let bytes = encode_record(&record); - let decoded = decode_record(&bytes).unwrap(); - assert_eq!(decoded.header, header); - assert!(matches!(decoded.payload, QlPayload::Encrypted(_))); - } - - #[test] - fn now_secs_advances() { - let _ = TestCrypto::new(1); - assert!(now_secs() > 0); - } -} +mod tests; diff --git a/ql-wire/src/nonce.rs b/ql-wire/src/nonce.rs new file mode 100644 index 00000000..fd913400 --- /dev/null +++ b/ql-wire/src/nonce.rs @@ -0,0 +1,7 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct Nonce(pub [u8; Self::SIZE]); + +impl Nonce { + pub const SIZE: usize = 12; +} diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs index 7c9c1659..a10d52d1 100644 --- a/ql-wire/src/pair/crypto.rs +++ b/ql-wire/src/pair/crypto.rs @@ -1,97 +1,92 @@ -use bc_components::{MLDSAPublicKey, MLKEMCiphertext, MLKEMPublicKey, SymmetricKey}; -use rkyv::{Archive, Serialize}; +use zerocopy::byte_slice::ByteSliceMut; -use super::{PairRequestBody, PairRequestRecord}; +use super::{PairRequestBody, PairRequestRecordRef}; use crate::{ - access_value, deserialize_value, encode_value, - encrypted_message::{ArchivedEncryptedMessage, EncryptedMessage}, - ensure_not_expired, AsWireMlDsaPublicKey, AsWireMlKemCiphertext, AsWireMlKemPublicKey, - ControlMeta, Nonce, QlCrypto, QlHeader, QlIdentity, QlPayload, QlRecord, WireError, XID, + pq::ML_KEM_SUITE_TAG, ControlMeta, MlDsaPublicKey, MlKemCiphertext, MlKemPublicKey, QlCrypto, + QlHeader, QlIdentity, QlPayload, QlRecord, WireError, XID, }; -#[derive(Archive, Serialize)] -struct PairingAad { - header: QlHeader, - #[rkyv(with = AsWireMlKemCiphertext)] - kem_ct: MLKEMCiphertext, -} - -#[derive(Archive, Serialize)] -struct PairingProofData { - aad: Vec, - meta: ControlMeta, - #[rkyv(with = AsWireMlDsaPublicKey)] - signing_pub_key: MLDSAPublicKey, - #[rkyv(with = AsWireMlKemPublicKey)] - encapsulation_pub_key: MLKEMPublicKey, -} - pub fn build_pair_request( - identity: &QlIdentity, crypto: &impl QlCrypto, + identity: &QlIdentity, recipient: XID, - recipient_encapsulation_key: &MLKEMPublicKey, + recipient_encapsulation_key: &MlKemPublicKey, meta: ControlMeta, ) -> Result { - let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(); + let (session_key, kem_ct) = + recipient_encapsulation_key.encapsulate_new_shared_secret(crypto)?; let header = QlHeader { sender: identity.xid, recipient, }; - let signing_pub_key = identity.signing_public_key.clone(); - let sender_encapsulation_key = identity.encapsulation_public_key.clone(); - let proof_data = pairing_proof_data( + let signing_pub_key = identity.signing_public_key; + let sender_encapsulation_key = identity.encapsulation_public_key; + let proof_data = hash_pairing_proof_data( + crypto, &header, &kem_ct, &meta, + identity.xid, &signing_pub_key, &sender_encapsulation_key, ); - let proof = identity.signing_private_key.sign(&proof_data); + let proof = identity.signing_private_key.sign(crypto, &proof_data)?; let body = PairRequestBody { meta, + xid: identity.xid, signing_pub_key, encapsulation_pub_key: sender_encapsulation_key, proof, }; - let body_bytes = encode_value(&body); + let body_bytes = body.encode(); let aad = pairing_aad(&header, &kem_ct); - let mut nonce_bytes = [0u8; Nonce::NONCE_SIZE]; - crypto.fill_random_bytes(&mut nonce_bytes); - let encrypted = EncryptedMessage::encrypt(&session_key, body_bytes, &aad, Nonce(nonce_bytes)); + let mut nonce = [0u8; crate::Nonce::SIZE]; + crypto.fill_random_bytes(&mut nonce); + let encrypted = crate::encrypted_message::EncryptedMessage::encrypt( + crypto, + &session_key, + body_bytes, + &aad, + crate::Nonce(nonce), + )?; Ok(QlRecord { header, - payload: QlPayload::Pair(PairRequestRecord { kem_ct, encrypted }), + payload: QlPayload::PairRequest(super::PairRequestRecord { kem_ct, encrypted }), }) } -pub fn decrypt_pair_request( +pub fn decrypt_pair_request( + crypto: &impl QlCrypto, identity: &QlIdentity, header: &QlHeader, - request: &mut super::ArchivedPairRequestRecord, + request: &mut PairRequestRecordRef, + now_seconds: u64, ) -> Result { - let kem_ct = MLKEMCiphertext::try_from(&request.kem_ct)?; + let kem_ct = MlKemCiphertext::from_data(request.kem_ct); let aad = pairing_aad(header, &kem_ct); let session_key = identity .encapsulation_private_key - .decapsulate_shared_secret(&kem_ct) - .map_err(|_| WireError::InvalidPayload)?; - let decrypted = decrypt_body(&session_key, &mut request.encrypted, &aad)?; - ensure_not_expired(decrypted.meta.valid_until)?; - if XID::from_signing_public_key(&decrypted.signing_pub_key) != header.sender { + .decapsulate_shared_secret(&kem_ct)?; + let mut encrypted = + crate::encrypted_message::EncryptedMessageWire::parse(&mut request.encrypted)?; + let plaintext = encrypted.decrypt(crypto, &session_key, &aad)?; + let decrypted = PairRequestBody::decode(plaintext)?; + decrypted.meta.ensure_not_expired(now_seconds)?; + if decrypted.xid != header.sender { return Err(WireError::InvalidPayload); } - let proof_data = pairing_proof_data( + let proof_data = hash_pairing_proof_data( + crypto, header, &kem_ct, &decrypted.meta, + decrypted.xid, &decrypted.signing_pub_key, &decrypted.encapsulation_pub_key, ); if decrypted .signing_pub_key .verify(&decrypted.proof, &proof_data) - .unwrap_or(false) { Ok(decrypted) } else { @@ -99,34 +94,43 @@ pub fn decrypt_pair_request( } } -fn pairing_proof_data( +fn hash_pairing_proof_data( + crypto: &impl QlCrypto, header: &QlHeader, - kem_ct: &MLKEMCiphertext, + kem_ct: &MlKemCiphertext, meta: &ControlMeta, - signing_pub_key: &MLDSAPublicKey, - encapsulation_pub_key: &MLKEMPublicKey, -) -> Vec { - encode_value(&PairingProofData { - aad: pairing_aad(header, kem_ct), - meta: *meta, - signing_pub_key: signing_pub_key.clone(), - encapsulation_pub_key: encapsulation_pub_key.clone(), - }) + xid: XID, + signing_pub_key: &MlDsaPublicKey, + encapsulation_pub_key: &MlKemPublicKey, +) -> [u8; 32] { + let aad = pairing_aad(header, kem_ct); + let control_id = meta.control_id.0.to_le_bytes(); + let valid_until = meta.valid_until.to_le_bytes(); + crypto.hash(&[ + b"ql-wire:pair-proof:v1", + b"aad", + &aad, + b"control-id", + &control_id, + b"valid-until", + &valid_until, + b"xid", + &xid.0, + b"signing-pub-key", + signing_pub_key.as_bytes(), + b"encapsulation-pub-key-suite", + ML_KEM_SUITE_TAG, + b"encapsulation-pub-key", + encapsulation_pub_key.as_bytes(), + ]) } -fn decrypt_body( - key: &SymmetricKey, - encrypted: &mut ArchivedEncryptedMessage, - aad: &[u8], -) -> Result { - let plaintext = encrypted.decrypt(key, aad)?; - let body = access_value::(plaintext)?; - deserialize_value(body) -} - -pub(crate) fn pairing_aad(header: &QlHeader, kem_ct: &MLKEMCiphertext) -> Vec { - encode_value(&PairingAad { - header: header.clone(), - kem_ct: kem_ct.clone(), - }) +pub(crate) fn pairing_aad(header: &QlHeader, kem_ct: &MlKemCiphertext) -> Vec { + let mut aad = Vec::new(); + crate::codec::append_field(&mut aad, b"domain", b"ql-wire:pair-aad:v1"); + crate::codec::append_field(&mut aad, b"sender", &header.sender.0); + crate::codec::append_field(&mut aad, b"recipient", &header.recipient.0); + crate::codec::append_field(&mut aad, b"kem-suite", ML_KEM_SUITE_TAG); + crate::codec::append_field(&mut aad, b"kem-ct", kem_ct.as_bytes()); + aad } diff --git a/ql-wire/src/pair/mod.rs b/ql-wire/src/pair/mod.rs index c2c06dd8..5973b22e 100644 --- a/ql-wire/src/pair/mod.rs +++ b/ql-wire/src/pair/mod.rs @@ -1,28 +1,106 @@ -use bc_components::{MLDSAPublicKey, MLDSASignature, MLKEMCiphertext, MLKEMPublicKey}; -use rkyv::{Archive, Deserialize, Serialize}; +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; use crate::{ - encrypted_message::EncryptedMessage, AsWireMlDsaPublicKey, AsWireMlDsaSignature, - AsWireMlKemCiphertext, AsWireMlKemPublicKey, ControlMeta, + codec::{parse, push_value, read_exact}, + control::{control_meta_from_wire, control_meta_to_wire, ControlMetaWire}, + encrypted_message::{EncryptedMessage, EncryptedMessageWire}, + ControlMeta, MlDsaPublicKey, MlDsaSignature, MlKemCiphertext, MlKemPublicKey, WireError, XID, }; mod crypto; pub use crypto::*; -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct PairRequestRecord { - #[rkyv(with = AsWireMlKemCiphertext)] - pub kem_ct: MLKEMCiphertext, + pub kem_ct: MlKemCiphertext, pub encrypted: EncryptedMessage, } -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct PairRequestBody { pub meta: ControlMeta, - #[rkyv(with = AsWireMlDsaPublicKey)] - pub signing_pub_key: MLDSAPublicKey, - #[rkyv(with = AsWireMlKemPublicKey)] - pub encapsulation_pub_key: MLKEMPublicKey, - #[rkyv(with = AsWireMlDsaSignature)] - pub proof: MLDSASignature, + pub xid: XID, + pub signing_pub_key: MlDsaPublicKey, + pub encapsulation_pub_key: MlKemPublicKey, + pub proof: MlDsaSignature, +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct PairRequestRecordWire { + pub kem_ct: [u8; MlKemCiphertext::SIZE], + pub encrypted: [u8], +} + +pub type PairRequestRecordRef = Ref; + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct PairRequestBodyWire { + meta: ControlMetaWire, + xid: [u8; XID::SIZE], + signing_pub_key: [u8; MlDsaPublicKey::SIZE], + encapsulation_pub_key: [u8; MlKemPublicKey::SIZE], + proof: [u8; MlDsaSignature::SIZE], +} + +impl PairRequestRecordWire { + pub fn parse(bytes: B) -> Result, WireError> { + let record: PairRequestRecordRef = parse(bytes)?; + let _ = EncryptedMessageWire::parse(&record.encrypted)?; + Ok(record) + } + + pub fn to_pair_request_record(&self) -> PairRequestRecord { + PairRequestRecord { + kem_ct: MlKemCiphertext::from_data(self.kem_ct), + encrypted: EncryptedMessageWire::parse(&self.encrypted) + .expect("validated pair request") + .to_encrypted_message(), + } + } +} + +impl PairRequestRecord { + pub(crate) fn encode_into(&self, out: &mut Vec) { + push_value( + out, + &PairRequestHeaderWire { + kem_ct: *self.kem_ct.as_bytes(), + }, + ); + out.extend_from_slice(&self.encrypted.encode()); + } +} + +impl PairRequestBody { + pub(crate) fn encode(&self) -> Vec { + let wire = PairRequestBodyWire { + meta: control_meta_to_wire(&self.meta), + xid: self.xid.0, + signing_pub_key: *self.signing_pub_key.as_bytes(), + encapsulation_pub_key: *self.encapsulation_pub_key.as_bytes(), + proof: *self.proof.as_bytes(), + }; + wire.as_bytes().to_vec() + } + + pub(crate) fn decode(bytes: &[u8]) -> Result { + let wire: PairRequestBodyWire = read_exact(bytes)?; + Ok(Self { + meta: control_meta_from_wire(wire.meta), + xid: XID(wire.xid), + signing_pub_key: MlDsaPublicKey::from_data(wire.signing_pub_key), + encapsulation_pub_key: MlKemPublicKey::from_data(wire.encapsulation_pub_key), + proof: MlDsaSignature::from_data(wire.proof), + }) + } +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +struct PairRequestHeaderWire { + kem_ct: [u8; MlKemCiphertext::SIZE], } diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs new file mode 100644 index 00000000..61e9cca5 --- /dev/null +++ b/ql-wire/src/pq.rs @@ -0,0 +1,185 @@ +use libcrux_ml_dsa::{ml_dsa_87, KEY_GENERATION_RANDOMNESS_SIZE, SIGNING_RANDOMNESS_SIZE}; +use libcrux_ml_kem::{mlkem1024, KEY_GENERATION_SEED_SIZE, SHARED_SECRET_SIZE}; + +use crate::{QlCrypto, WireError}; + +pub(crate) const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SessionKey([u8; Self::SIZE]); + +impl SessionKey { + pub const SIZE: usize = SHARED_SECRET_SIZE; + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn data(&self) -> &[u8; Self::SIZE] { + &self.0 + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +impl AsRef<[u8]> for SessionKey { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlDsaPrivateKey([u8; MlDsaPrivateKey::SIZE]); + +impl MlDsaPrivateKey { + pub const SIZE: usize = ml_dsa_87::MLDSA87SigningKey::len(); + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } + + pub fn sign( + &self, + crypto: &impl QlCrypto, + message: &[u8], + ) -> Result { + let mut randomness = [0u8; SIGNING_RANDOMNESS_SIZE]; + crypto.fill_random_bytes(&mut randomness); + let signing_key = ml_dsa_87::MLDSA87SigningKey::new(self.0); + let signature = ml_dsa_87::sign(&signing_key, message, b"", randomness) + .map_err(|_| WireError::SigningFailed)?; + Ok(MlDsaSignature::from_data(*signature.as_ref())) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct MlDsaPublicKey([u8; MlDsaPublicKey::SIZE]); + +impl MlDsaPublicKey { + pub const SIZE: usize = ml_dsa_87::MLDSA87VerificationKey::len(); + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } + + pub fn verify(&self, signature: &MlDsaSignature, message: &[u8]) -> bool { + let verification_key = ml_dsa_87::MLDSA87VerificationKey::new(self.0); + let signature = ml_dsa_87::MLDSA87Signature::new(*signature.as_bytes()); + ml_dsa_87::verify(&verification_key, message, b"", &signature).is_ok() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct MlDsaSignature([u8; MlDsaSignature::SIZE]); + +impl MlDsaSignature { + pub const SIZE: usize = ml_dsa_87::MLDSA87Signature::len(); + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct MlKemPublicKey([u8; MlKemPublicKey::SIZE]); + +impl MlKemPublicKey { + pub const SIZE: usize = mlkem1024::MlKem1024PublicKey::len(); + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } + + pub fn encapsulate_new_shared_secret( + &self, + crypto: &impl QlCrypto, + ) -> Result<(SessionKey, MlKemCiphertext), WireError> { + let mut randomness = [0u8; SHARED_SECRET_SIZE]; + crypto.fill_random_bytes(&mut randomness); + let public_key = mlkem1024::MlKem1024PublicKey::from(self.as_bytes()); + let (ciphertext, shared_secret) = mlkem1024::encapsulate(&public_key, randomness); + Ok(( + SessionKey::from_data(shared_secret), + MlKemCiphertext::from_data(*ciphertext.as_slice()), + )) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemPrivateKey([u8; MlKemPrivateKey::SIZE]); + +impl MlKemPrivateKey { + pub const SIZE: usize = mlkem1024::MlKem1024PrivateKey::len(); + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } + + pub fn decapsulate_shared_secret( + &self, + ciphertext: &MlKemCiphertext, + ) -> Result { + let private_key = mlkem1024::MlKem1024PrivateKey::from(self.as_bytes()); + let ciphertext = mlkem1024::MlKem1024Ciphertext::from(ciphertext.as_bytes()); + let shared_secret = mlkem1024::decapsulate(&private_key, &ciphertext); + Ok(SessionKey::from_data(shared_secret)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct MlKemCiphertext([u8; MlKemCiphertext::SIZE]); + +impl MlKemCiphertext { + pub const SIZE: usize = mlkem1024::MlKem1024Ciphertext::len(); + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +pub fn generate_ml_dsa_keypair(crypto: &impl QlCrypto) -> (MlDsaPrivateKey, MlDsaPublicKey) { + let mut randomness = [0u8; KEY_GENERATION_RANDOMNESS_SIZE]; + crypto.fill_random_bytes(&mut randomness); + let key_pair = ml_dsa_87::generate_key_pair(randomness); + ( + MlDsaPrivateKey::from_data(*key_pair.signing_key.as_ref()), + MlDsaPublicKey::from_data(*key_pair.verification_key.as_ref()), + ) +} + +pub fn generate_ml_kem_keypair(crypto: &impl QlCrypto) -> (MlKemPrivateKey, MlKemPublicKey) { + let mut randomness = [0u8; KEY_GENERATION_SEED_SIZE]; + crypto.fill_random_bytes(&mut randomness); + let key_pair = mlkem1024::generate_key_pair(randomness); + let (private_key, public_key) = key_pair.into_parts(); + ( + MlKemPrivateKey::from_data(*private_key.as_slice()), + MlKemPublicKey::from_data(*public_key.as_slice()), + ) +} diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs new file mode 100644 index 00000000..79e67c58 --- /dev/null +++ b/ql-wire/src/record.rs @@ -0,0 +1,152 @@ +use zerocopy::{ + byte_slice::{ByteSlice, SplitByteSlice}, + Immutable, IntoBytes, KnownLayout, TryFromBytes, Unaligned, +}; + +use crate::{ + codec, + encrypted_message::{EncryptedMessage, EncryptedMessageRef, EncryptedMessageWire}, + handshake, + header::{decode_record_header, encode_record_header, QlHeader}, + pair, WireError, QL_WIRE_VERSION, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QlRecord { + pub header: QlHeader, + pub payload: QlPayload, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlPayload { + PairRequest(pair::PairRequestRecord), + Hello(handshake::Hello), + HelloReply(handshake::HelloReply), + Confirm(handshake::Confirm), + Ready(handshake::Ready), + Session(EncryptedMessage), +} + +pub struct QlRecordRef { + pub header: QlHeader, + pub payload: QlPayloadRef, +} + +pub enum QlPayloadRef { + PairRequest(pair::PairRequestRecordRef), + Hello(handshake::Hello), + HelloReply(handshake::HelloReply), + Confirm(handshake::Confirm), + Ready(handshake::ReadyRef), + Session(EncryptedMessageRef), +} + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, TryFromBytes, KnownLayout, Immutable, IntoBytes, Unaligned, +)] +#[repr(u8)] +pub(crate) enum RecordKind { + PairRequest = 1, + Hello = 2, + HelloReply = 3, + Confirm = 4, + Ready = 5, + Session = 6, +} + +impl RecordKind { + fn for_payload(payload: &QlPayload) -> Self { + match payload { + QlPayload::PairRequest(_) => Self::PairRequest, + QlPayload::Hello(_) => Self::Hello, + QlPayload::HelloReply(_) => Self::HelloReply, + QlPayload::Confirm(_) => Self::Confirm, + QlPayload::Ready(_) => Self::Ready, + QlPayload::Session(_) => Self::Session, + } + } +} + +impl QlRecord { + pub fn encode(&self) -> Vec { + let mut out = Vec::new(); + out.push(QL_WIRE_VERSION); + let header = encode_record_header(&self.header, RecordKind::for_payload(&self.payload)); + codec::push_value(&mut out, &header); + match &self.payload { + QlPayload::PairRequest(request) => request.encode_into(&mut out), + QlPayload::Hello(hello) => hello.encode_into(&mut out), + QlPayload::HelloReply(reply) => reply.encode_into(&mut out), + QlPayload::Confirm(confirm) => confirm.encode_into(&mut out), + QlPayload::Ready(ready) => ready.encode_into(&mut out), + QlPayload::Session(encrypted) => encrypted.encode_into(&mut out), + } + out + } + + pub fn decode(bytes: &[u8]) -> Result { + Ok(Self::parse(bytes)?.to_owned()) + } + + pub fn parse(bytes: &[u8]) -> Result, WireError> { + QlRecordRef::parse(bytes) + } + + pub fn parse_mut(bytes: &mut [u8]) -> Result, WireError> { + QlRecordRef::parse(bytes) + } +} + +impl QlRecordRef { + pub fn parse(bytes: B) -> Result { + let (version, payload_bytes) = codec::read_prefix::(bytes)?; + if version != QL_WIRE_VERSION { + return Err(WireError::InvalidPayload); + } + let (header, payload_bytes) = decode_record_header(payload_bytes)?; + let payload = parse_payload(header.kind, payload_bytes)?; + Ok(Self { + header: header.header, + payload, + }) + } +} + +impl QlRecordRef { + pub fn to_owned(&self) -> QlRecord { + QlRecord { + header: self.header, + payload: self.payload.to_owned(), + } + } +} + +impl QlPayloadRef { + pub fn to_owned(&self) -> QlPayload { + match self { + Self::PairRequest(request) => QlPayload::PairRequest(request.to_pair_request_record()), + Self::Hello(hello) => QlPayload::Hello(hello.clone()), + Self::HelloReply(reply) => QlPayload::HelloReply(reply.clone()), + Self::Confirm(confirm) => QlPayload::Confirm(confirm.clone()), + Self::Ready(ready) => QlPayload::Ready(handshake::Ready { + encrypted: ready.to_encrypted_message(), + }), + Self::Session(encrypted) => QlPayload::Session(encrypted.to_encrypted_message()), + } + } +} + +fn parse_payload(kind: RecordKind, payload: B) -> Result, WireError> { + match kind { + RecordKind::PairRequest => Ok(QlPayloadRef::PairRequest( + pair::PairRequestRecordWire::parse(payload)?, + )), + RecordKind::Hello => Ok(QlPayloadRef::Hello(handshake::Hello::decode(&payload)?)), + RecordKind::HelloReply => Ok(QlPayloadRef::HelloReply(handshake::HelloReply::decode( + &payload, + )?)), + RecordKind::Confirm => Ok(QlPayloadRef::Confirm(handshake::Confirm::decode(&payload)?)), + RecordKind::Ready => Ok(QlPayloadRef::Ready(EncryptedMessageWire::parse(payload)?)), + RecordKind::Session => Ok(QlPayloadRef::Session(EncryptedMessageWire::parse(payload)?)), + } +} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs new file mode 100644 index 00000000..f0f574bf --- /dev/null +++ b/ql-wire/src/tests.rs @@ -0,0 +1,378 @@ +use std::sync::atomic::{AtomicU8, Ordering}; + +use libcrux_aesgcm::AesGcm256Key; +use sha2::{Digest, Sha256}; + +use super::*; + +struct TestCrypto(AtomicU8); + +impl TestCrypto { + fn new(seed: u8) -> Self { + Self(AtomicU8::new(seed)) + } +} + +impl QlCrypto for TestCrypto { + fn fill_random_bytes(&self, data: &mut [u8]) { + let seed = self.0.fetch_add(1, Ordering::Relaxed); + for (index, byte) in data.iter_mut().enumerate() { + *byte = seed.wrapping_add(index as u8); + } + } + + fn hash(&self, parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() + } + + fn encrypt_with_aead( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]> { + let key: AesGcm256Key = (*key.data()).into(); + let plaintext = buffer.to_vec(); + let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; + key.encrypt( + buffer, + (&mut auth).into(), + (&nonce.0).into(), + aad, + &plaintext, + ) + .ok()?; + Some(auth) + } + + fn decrypt_with_aead( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + ) -> bool { + let key: AesGcm256Key = (*key.data()).into(); + let ciphertext = buffer.to_vec(); + key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) + .is_ok() + } +} + +#[test] +fn encrypted_session_record_round_trip_and_decrypt() { + let crypto = TestCrypto::new(1); + let header = QlHeader { + sender: XID([1; XID::SIZE]), + recipient: XID([2; XID::SIZE]), + }; + let body = SessionEnvelope { + seq: SessionSeq(7), + ack: SessionAck { + base: SessionSeq(3), + bitmap: 0b101, + }, + body: SessionBody::Stream(StreamChunk { + stream_id: StreamId(9), + offset: 11, + bytes: b"hello".to_vec(), + fin: true, + }), + }; + let session_key = SessionKey::from_data([7; SessionKey::SIZE]); + let record = encrypted::encrypt_record( + &crypto, + header, + &session_key, + &body, + Nonce([8; Nonce::SIZE]), + ) + .unwrap(); + + let bytes = record.encode(); + let decoded = QlRecord::decode(&bytes).unwrap(); + assert_eq!(decoded.header, header); + assert!(matches!(decoded.payload, QlPayload::Session(_))); + + let parsed = QlRecord::parse(&bytes).unwrap(); + assert_eq!(parsed.to_owned(), record); + + let mut bytes = bytes; + let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); + let QlPayloadRef::Session(mut encrypted) = payload else { + panic!("expected session payload"); + }; + let decrypted = + encrypted::decrypt_record(&crypto, &header, &mut encrypted, &session_key).unwrap(); + assert_eq!(decrypted.to_session_envelope().unwrap(), body); +} + +#[test] +fn pair_request_round_trip_and_decrypt() { + let crypto = TestCrypto::new(9); + let sender_signing = generate_ml_dsa_keypair(&crypto); + let sender_kem = generate_ml_kem_keypair(&crypto); + let recipient_signing = generate_ml_dsa_keypair(&crypto); + let recipient_kem = generate_ml_kem_keypair(&crypto); + + let sender = QlIdentity::new( + XID([3; XID::SIZE]), + sender_signing.0, + sender_signing.1, + sender_kem.0, + sender_kem.1, + ); + let recipient = QlIdentity::new( + XID([4; XID::SIZE]), + recipient_signing.0, + recipient_signing.1, + recipient_kem.0, + recipient_kem.1, + ); + let meta = ControlMeta { + control_id: ControlId(55), + valid_until: 999, + }; + let record = pair::build_pair_request( + &crypto, + &sender, + recipient.xid, + &recipient.encapsulation_public_key, + meta, + ) + .unwrap(); + + let mut bytes = record.encode(); + let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); + let QlPayloadRef::PairRequest(mut request) = payload else { + panic!("expected pair request"); + }; + let body = pair::decrypt_pair_request(&crypto, &recipient, &header, &mut request, 100).unwrap(); + assert_eq!(body.meta, meta); + assert_eq!(body.xid, sender.xid); + assert_eq!(body.signing_pub_key, sender.signing_public_key); + assert_eq!(body.encapsulation_pub_key, sender.encapsulation_public_key); +} + +#[test] +fn ready_round_trip_and_decrypt() { + let crypto = TestCrypto::new(30); + let header = QlHeader { + sender: XID([5; XID::SIZE]), + recipient: XID([6; XID::SIZE]), + }; + let session_key = SessionKey::from_data([11; SessionKey::SIZE]); + let meta = ControlMeta { + control_id: ControlId(77), + valid_until: 500, + }; + let ready = handshake::build_ready( + &crypto, + header, + &session_key, + meta, + Nonce([12; Nonce::SIZE]), + ) + .unwrap(); + let record = QlRecord { + header, + payload: QlPayload::Ready(ready), + }; + + let mut bytes = record.encode(); + let parsed = QlRecord::decode(&bytes).unwrap(); + assert_eq!(parsed, record); + + let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); + let QlPayloadRef::Ready(mut ready) = payload else { + panic!("expected ready payload"); + }; + let body = handshake::decrypt_ready(&crypto, &header, &mut ready, &session_key, 100).unwrap(); + assert_eq!(body.meta, meta); +} + +#[test] +fn protocol_record_size_breakdown() { + fn meta(id: u32) -> ControlMeta { + ControlMeta { + control_id: ControlId(id), + valid_until: 1_000, + } + } + + fn header() -> QlHeader { + QlHeader { + sender: XID([1; XID::SIZE]), + recipient: XID([2; XID::SIZE]), + } + } + + fn encrypted(tag: u8, ciphertext_len: usize) -> EncryptedMessage { + EncryptedMessage { + nonce: Nonce([tag; Nonce::SIZE]), + auth: [tag; EncryptedMessage::AUTH_SIZE], + ciphertext: vec![tag; ciphertext_len], + } + } + + fn session_record(header: QlHeader, tag: u8, body: SessionEnvelope) -> QlRecord { + let ciphertext_len = body.encode().len(); + QlRecord { + header, + payload: QlPayload::Session(encrypted(tag, ciphertext_len)), + } + } + + let header = header(); + let hello = QlRecord { + header, + payload: QlPayload::Hello(handshake::Hello { + meta: meta(1), + nonce: Nonce([3; Nonce::SIZE]), + kem_ct: MlKemCiphertext::from_data([4; MlKemCiphertext::SIZE]), + signature: MlDsaSignature::from_data([5; MlDsaSignature::SIZE]), + }), + }; + let hello_reply = QlRecord { + header, + payload: QlPayload::HelloReply(handshake::HelloReply { + meta: meta(2), + nonce: Nonce([6; Nonce::SIZE]), + kem_ct: MlKemCiphertext::from_data([7; MlKemCiphertext::SIZE]), + signature: MlDsaSignature::from_data([8; MlDsaSignature::SIZE]), + }), + }; + let confirm = QlRecord { + header, + payload: QlPayload::Confirm(handshake::Confirm { + meta: meta(3), + signature: MlDsaSignature::from_data([9; MlDsaSignature::SIZE]), + }), + }; + let pair_request = QlRecord { + header, + payload: QlPayload::PairRequest(pair::PairRequestRecord { + kem_ct: MlKemCiphertext::from_data([10; MlKemCiphertext::SIZE]), + encrypted: encrypted(11, 0), + }), + }; + let ready = QlRecord { + header, + payload: QlPayload::Ready(handshake::Ready { + encrypted: encrypted(12, 0), + }), + }; + + let session_ack = session_record( + header, + 13, + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: SessionBody::Ack, + }, + ); + let session_ping = session_record( + header, + 14, + SessionEnvelope { + seq: SessionSeq(2), + ack: SessionAck::EMPTY, + body: SessionBody::Ping(PingBody), + }, + ); + let session_unpair = session_record( + header, + 15, + SessionEnvelope { + seq: SessionSeq(3), + ack: SessionAck::EMPTY, + body: SessionBody::Unpair(UnpairBody), + }, + ); + let session_stream_empty = session_record( + header, + 16, + SessionEnvelope { + seq: SessionSeq(4), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id: StreamId(1), + offset: 0, + fin: false, + bytes: Vec::new(), + }), + }, + ); + let session_stream_fin = session_record( + header, + 17, + SessionEnvelope { + seq: SessionSeq(5), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id: StreamId(1), + offset: 0, + fin: true, + bytes: Vec::new(), + }), + }, + ); + let session_stream_close = session_record( + header, + 18, + SessionEnvelope { + seq: SessionSeq(6), + ack: SessionAck::EMPTY, + body: SessionBody::StreamClose(StreamClose { + stream_id: StreamId(1), + target: CloseTarget::Both, + code: CloseCode::PROTOCOL, + payload: Vec::new(), + }), + }, + ); + let session_close = session_record( + header, + 19, + SessionEnvelope { + seq: SessionSeq(7), + ack: SessionAck::EMPTY, + body: SessionBody::Close(SessionCloseBody { + code: CloseCode::PROTOCOL, + }), + }, + ); + + let print_size = |label: &str, size: usize| { + println!("{label:<32}: {size} bytes"); + }; + + print_size("ql-wire hello", hello.encode().len()); + print_size("ql-wire hello_reply", hello_reply.encode().len()); + print_size("ql-wire confirm", confirm.encode().len()); + print_size("ql-wire pair_request empty", pair_request.encode().len()); + print_size("ql-wire ready empty", ready.encode().len()); + print_size("ql-wire session ack", session_ack.encode().len()); + print_size("ql-wire session ping", session_ping.encode().len()); + print_size("ql-wire session unpair", session_unpair.encode().len()); + print_size( + "ql-wire session stream empty", + session_stream_empty.encode().len(), + ); + print_size( + "ql-wire session stream fin", + session_stream_fin.encode().len(), + ); + print_size( + "ql-wire session stream close", + session_stream_close.encode().len(), + ); + print_size("ql-wire session close", session_close.encode().len()); +} diff --git a/ql-wire/src/xid.rs b/ql-wire/src/xid.rs index 548da558..164b32e0 100644 --- a/ql-wire/src/xid.rs +++ b/ql-wire/src/xid.rs @@ -1,16 +1,7 @@ -use bc_components::{MLDSAPublicKey, SigningPublicKey}; -use rkyv::{Archive, Deserialize, Serialize}; - -#[derive( - Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, -)] -pub struct XID(pub [u8; Self::XID_SIZE]); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct XID(pub [u8; Self::SIZE]); impl XID { - pub const XID_SIZE: usize = 32; - - pub fn from_signing_public_key(signing_public_key: &MLDSAPublicKey) -> Self { - let xid = bc_components::XID::new(SigningPublicKey::MLDSA(signing_public_key.clone())); - Self(*xid.data()) - } + pub const SIZE: usize = 32; } From 79b4480cedbc925fd3e76ad93012117c0311f42e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:16:43 -0400 Subject: [PATCH 010/304] ql: stabilize post-engine runtime and stream lifecycle behavior --- ql-fsm/src/implementation/fsm.rs | 32 ++- ql-fsm/src/implementation/mod.rs | 33 +-- ql-fsm/src/lib.rs | 19 +- ql-fsm/src/session/mod.rs | 476 +++++++++++++++++++++++-------- ql-fsm/src/session/state.rs | 70 +++-- ql-fsm/src/session/tests.rs | 290 ++++++++++++++++--- ql-fsm/src/tests/mod.rs | 2 + ql-fsm/src/tests/session.rs | 64 +++-- ql-runtime/src/driver.rs | 247 +++++++++++++++- ql-runtime/src/rpc/client.rs | 70 ----- ql-runtime/src/rpc/mod.rs | 153 ---------- ql-runtime/src/rpc/modality.rs | 35 --- ql-runtime/src/rpc/server.rs | 1 - ql-runtime/src/tests/mod.rs | 2 +- 14 files changed, 969 insertions(+), 525 deletions(-) delete mode 100644 ql-runtime/src/rpc/client.rs delete mode 100644 ql-runtime/src/rpc/mod.rs delete mode 100644 ql-runtime/src/rpc/modality.rs delete mode 100644 ql-runtime/src/rpc/server.rs diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 403ab79d..12293775 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -121,8 +121,8 @@ pub fn confirm_session_write(fsm: &mut QlFsm, write_id: SessionWriteId) { fsm.session.confirm_write(fsm.state.now.instant, write_id.0); } -pub fn return_session_write(fsm: &mut QlFsm, write_id: SessionWriteId) { - fsm.session.return_write(write_id.0); +pub fn reject_session_write(fsm: &mut QlFsm, write_id: SessionWriteId) { + fsm.session.reject_write(write_id.0); } pub fn kill_session(fsm: &mut QlFsm, code: CloseCode) { @@ -156,7 +156,7 @@ pub fn take_next_session_event(fsm: &mut QlFsm) -> Option { pub fn open_stream(fsm: &mut QlFsm) -> Result { ensure_peer_bound(fsm)?; - fsm.session.open_stream().map_err(Into::into) + Ok(fsm.session.open_stream()?) } pub fn write_stream( @@ -165,14 +165,24 @@ pub fn write_stream( bytes: Vec, ) -> Result<(), QlFsmError> { ensure_peer_bound(fsm)?; - fsm.session - .write_stream(stream_id, bytes) - .map_err(Into::into) + Ok(fsm.session.write_stream(stream_id, bytes)?) +} + +pub fn read_stream( + fsm: &mut QlFsm, + stream_id: StreamId, + out: &mut [u8], +) -> Result { + Ok(fsm.session.read_stream(stream_id, out)?) +} + +pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Result { + Ok(fsm.session.stream_available_bytes(stream_id)?) } pub fn finish_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result<(), QlFsmError> { ensure_peer_bound(fsm)?; - fsm.session.finish_stream(stream_id).map_err(Into::into) + Ok(fsm.session.finish_stream(stream_id)?) } pub fn close_stream( @@ -183,21 +193,19 @@ pub fn close_stream( payload: Vec, ) -> Result<(), QlFsmError> { ensure_peer_bound(fsm)?; - fsm.session - .close_stream(stream_id, target, code, payload) - .map_err(Into::into) + Ok(fsm.session.close_stream(stream_id, target, code, payload)?) } pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { ensure_session_open(fsm)?; - fsm.session.queue_ping().map_err(Into::into) + Ok(fsm.session.queue_ping()?) } pub fn queue_unpair(fsm: &mut QlFsm) -> Result<(), QlFsmError> { ensure_session_open(fsm)?; // TODO: keep local peer/session state alive until this queued unpair is acked or times out, // then clear it locally. Right now this only requests remote unpair. - fsm.session.queue_unpair().map_err(Into::into) + Ok(fsm.session.queue_unpair()?) } fn ensure_peer_bound(fsm: &QlFsm) -> Result<(), QlFsmError> { diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index e1114d98..69e2bf9c 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -10,7 +10,7 @@ pub use peer::*; use ql_wire::{ControlId, ControlMeta, QlHeader, QlPayload, QlRecord, SessionKey, XID}; use crate::{ - session::{SessionEvent, SessionFsmConfig, StreamIncoming, StreamNamespace}, + session::{SessionEvent, SessionFsmConfig, StreamNamespace}, QlFsm, QlFsmEvent, QlSessionEvent, }; @@ -63,6 +63,7 @@ fn reset_session(fsm: &mut QlFsm) { fsm.session = crate::session::SessionFsm::new( SessionFsmConfig { local_namespace, + stream_chunk_size: fsm.config.session_stream_chunk_size, ack_delay: fsm.config.session_ack_delay, retransmit_timeout: fsm.config.session_retransmit_timeout, keepalive_interval: fsm.config.session_keepalive_interval, @@ -93,26 +94,18 @@ fn drain_session_events(fsm: &mut QlFsm) { .push_back(QlSessionEvent::Opened(stream_id)); } SessionEvent::Readable(stream_id) => { - while let Some(incoming) = fsm.session.take_next_inbound(stream_id) { - match incoming { - StreamIncoming::Data(bytes) => { - fsm.state - .session_events - .push_back(QlSessionEvent::Data { stream_id, bytes }); - } - StreamIncoming::Finished => { - fsm.state - .session_events - .push_back(QlSessionEvent::Finished(stream_id)); - } - StreamIncoming::Closed(frame) => { - fsm.state - .session_events - .push_back(QlSessionEvent::Closed(frame)); - } - } - } + fsm.state + .session_events + .push_back(QlSessionEvent::Readable(stream_id)); } + SessionEvent::Finished(stream_id) => fsm + .state + .session_events + .push_back(QlSessionEvent::Finished(stream_id)), + SessionEvent::Closed(frame) => fsm + .state + .session_events + .push_back(QlSessionEvent::Closed(frame)), SessionEvent::WritableClosed(stream_id) => { fsm.state .session_events diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 0ab099b9..1f6fc695 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -51,7 +51,7 @@ pub enum QlFsmEvent { #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlSessionEvent { Opened(StreamId), - Data { stream_id: StreamId, bytes: Vec }, + Readable(StreamId), Finished(StreamId), Closed(StreamClose), WritableClosed(StreamId), @@ -78,6 +78,7 @@ pub struct QlFsmConfig { pub session_retransmit_timeout: Duration, pub session_keepalive_interval: Duration, pub session_peer_timeout: Duration, + pub session_stream_chunk_size: usize, } impl Default for QlFsmConfig { @@ -91,6 +92,7 @@ impl Default for QlFsmConfig { session_retransmit_timeout: Duration::from_millis(150), session_keepalive_interval: Duration::from_secs(10), session_peer_timeout: Duration::from_secs(30), + session_stream_chunk_size: 16 * 1024, } } } @@ -112,6 +114,7 @@ impl QlFsm { session: session::SessionFsm::new( session::SessionFsmConfig { local_namespace: session::StreamNamespace::Low, + stream_chunk_size: config.session_stream_chunk_size, ack_delay: config.session_ack_delay, retransmit_timeout: config.session_retransmit_timeout, keepalive_interval: config.session_keepalive_interval, @@ -192,7 +195,7 @@ impl QlFsm { /// This must be called at most once for a `SessionWriteId` returned by /// [`Self::take_next_write`] whose `session_write_id` was `Some`. pub fn reject_session_write(&mut self, write_id: SessionWriteId) { - implementation::return_session_write(self, write_id); + implementation::reject_session_write(self, write_id); } /// Aborts the current encrypted session locally. @@ -212,6 +215,18 @@ impl QlFsm { implementation::write_stream(self, stream_id, bytes) } + pub fn read_stream( + &mut self, + stream_id: StreamId, + out: &mut [u8], + ) -> Result { + implementation::read_stream(self, stream_id, out) + } + + pub fn stream_available_bytes(&self, stream_id: StreamId) -> Result { + implementation::stream_available_bytes(self, stream_id) + } + pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), QlFsmError> { implementation::finish_stream(self, stream_id) } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index e8328c29..38d69e3f 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -14,8 +14,8 @@ use ql_wire::{ use self::{ ring::SeqRingInsertError, state::{ - AckState, PendingChunk, PendingSessionBody, PendingStreamBody, SessionFsmState, StreamRole, - StreamState, TxEntry, TxState, + AckState, InboundState, OutboundState, PendingRxChunk, PendingSessionBody, SessionFsmState, + StreamOpenState, StreamRole, StreamState, TxEntry, TxState, }, }; @@ -57,6 +57,7 @@ impl StreamNamespace { #[derive(Debug, Clone, Copy)] pub struct SessionFsmConfig { pub local_namespace: StreamNamespace, + pub stream_chunk_size: usize, pub ack_delay: Duration, pub retransmit_timeout: Duration, pub keepalive_interval: Duration, @@ -67,6 +68,7 @@ impl Default for SessionFsmConfig { fn default() -> Self { Self { local_namespace: StreamNamespace::Low, + stream_chunk_size: 16 * 1024, ack_delay: Duration::from_millis(5), retransmit_timeout: Duration::from_millis(150), keepalive_interval: Duration::from_secs(10), @@ -79,18 +81,13 @@ impl Default for SessionFsmConfig { pub enum SessionEvent { Opened(StreamId), Readable(StreamId), + Finished(StreamId), + Closed(StreamClose), WritableClosed(StreamId), Unpaired, SessionClosed(SessionCloseBody), } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum StreamIncoming { - Data(Vec), - Finished, - Closed(StreamClose), -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SessionState { Open, @@ -113,7 +110,8 @@ pub struct SessionFsm { } impl SessionFsm { - pub fn new(config: SessionFsmConfig, now: Instant) -> Self { + pub fn new(mut config: SessionFsmConfig, now: Instant) -> Self { + config.stream_chunk_size = config.stream_chunk_size.max(1); Self { config, state: SessionFsmState { @@ -160,14 +158,7 @@ impl SessionFsm { return Err(StreamError::NotWritable); } - let frame = StreamChunk { - stream_id, - offset: stream.next_send_offset, - bytes, - fin: false, - }; - stream.next_send_offset += frame.bytes.len() as u64; - stream.send_queue.push_back(PendingStreamBody::Chunk(frame)); + stream.send_buf.extend(bytes); Ok(()) } @@ -182,15 +173,7 @@ impl SessionFsm { return Err(StreamError::NotWritable); } - stream.outbound_finished = true; - stream - .send_queue - .push_back(PendingStreamBody::Chunk(StreamChunk { - stream_id, - offset: stream.next_send_offset, - bytes: Vec::new(), - fin: true, - })); + stream.outbound_state = OutboundState::FinQueued; Ok(()) } @@ -202,24 +185,68 @@ impl SessionFsm { payload: Vec, ) -> Result<(), StreamError> { self.ensure_session_open()?; - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; + { + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; - Self::apply_close_to_stream(stream, target); - stream - .send_queue - .push_back(PendingStreamBody::Close(StreamClose { + Self::apply_close_to_stream(stream, target); + stream.pending_close = Some(StreamClose { stream_id, target, code, payload, - })); + }); + } + self.try_reap_stream(stream_id); Ok(()) } + pub fn read_stream( + &mut self, + stream_id: StreamId, + out: &mut [u8], + ) -> Result { + let written = { + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; + if out.is_empty() || stream.recv_buf.is_empty() { + return Ok(0); + } + + let (front, back) = stream.recv_buf.as_slices(); + let front_len = front.len().min(out.len()); + out[..front_len].copy_from_slice(&front[..front_len]); + + let mut written = front_len; + let remaining = out.len() - front_len; + if remaining > 0 { + let back_len = back.len().min(remaining); + out[written..written + back_len].copy_from_slice(&back[..back_len]); + written += back_len; + } + + stream.recv_buf.drain(..written); + written + }; + self.try_reap_stream(stream_id); + Ok(written) + } + + pub fn stream_available_bytes(&self, stream_id: StreamId) -> Result { + let stream = self + .state + .streams + .get(&stream_id) + .ok_or(StreamError::MissingStream)?; + Ok(stream.recv_buf.len()) + } + pub fn queue_ping(&mut self) -> Result<(), StreamError> { self.ensure_session_open()?; self.state.pending_control.ping = true; @@ -297,21 +324,35 @@ impl SessionFsm { self.state.now = now; self.collect_timeouts(); let ack = self.state.current_ack(); - if let Some(seq) = self - .state - .tx_ring - .iter() - .find_map(|(seq, entry)| matches!(entry.state, TxState::Pending).then_some(seq)) - { + loop { + let Some(seq) = + self.state.tx_ring.iter().find_map(|(seq, entry)| { + matches!(entry.state, TxState::Pending).then_some(seq) + }) + else { + break; + }; + let Some(body) = self + .state + .tx_ring + .get(&seq) + .map(|entry| entry.pending.body.clone()) + else { + return None; + }; + if !self.should_retry_body(&body) { + let _ = self.state.tx_ring.remove(&seq); + self.state + .tx_ring + .advance_empty_front_until(self.state.next_seq); + continue; + } + let Some(entry) = self.state.tx_ring.get_mut(&seq) else { return None; }; entry.state = TxState::Issued; - return Some(SessionEnvelope { - seq, - ack, - body: entry.pending.body.clone(), - }); + return Some(SessionEnvelope { seq, ack, body }); } if !self.state.tx_ring.accepts_seq(self.state.next_seq) { @@ -375,7 +416,7 @@ impl SessionFsm { } } - pub fn return_write(&mut self, seq: SessionSeq) { + pub fn reject_write(&mut self, seq: SessionSeq) { debug_assert!(matches!( self.state.tx_ring.get(&seq).map(|entry| entry.state), Some(TxState::Issued) @@ -459,23 +500,17 @@ impl SessionFsm { self.state.events.pop_front() } - pub fn take_next_inbound(&mut self, stream_id: StreamId) -> Option { - self.state - .streams - .get_mut(&stream_id) - .and_then(|stream| stream.inbound_queue.pop_front()) - } - #[cfg(test)] pub fn session_state(&self) -> SessionState { self.state.session_state } pub fn has_pending_stream_work(&self) -> bool { - self.state - .streams - .values() - .any(|stream| !stream.send_queue.is_empty()) + self.state.streams.values().any(|stream| { + stream.pending_close.is_some() + || !stream.send_buf.is_empty() + || matches!(stream.outbound_state, OutboundState::FinQueued) + }) } fn next_pending_body(&mut self) -> Option { @@ -509,23 +544,53 @@ impl SessionFsm { .state .streams .get_index(index) - .is_some_and(|(_, stream)| !stream.send_queue.is_empty()); + .is_some_and(|(_, stream)| { + stream.pending_close.is_some() + || !stream.send_buf.is_empty() + || matches!(stream.outbound_state, OutboundState::FinQueued) + }); if !has_pending { continue; } - let item = { - let Some((_, stream)) = self.state.streams.get_index_mut(index) else { + let body = { + let Some((&stream_id, stream)) = self.state.streams.get_index_mut(index) else { continue; }; - let Some(item) = stream.send_queue.pop_front() else { - continue; - }; - item + match stream.open_state { + StreamOpenState::PendingSend => { + let body = Self::take_stream_frame( + stream, + stream_id, + self.config.stream_chunk_size, + ) + .map(SessionBody::Stream); + if body.is_some() { + stream.open_state = StreamOpenState::WaitingForAck; + } + body + } + StreamOpenState::WaitingForAck => None, + StreamOpenState::Opened => { + if let Some(close) = stream.pending_close.take() { + Some(SessionBody::StreamClose(close)) + } else { + Self::take_stream_frame( + stream, + stream_id, + self.config.stream_chunk_size, + ) + .map(SessionBody::Stream) + } + } + } + }; + let Some(body) = body else { + continue; }; self.state.next_stream_index = (index + 1) % len; return Some(PendingSessionBody { - body: item.to_session_body(), + body, retransmit: true, }); } @@ -551,17 +616,36 @@ impl SessionFsm { } fn process_ack(&mut self, ack: ql_wire::SessionAck) { - let acked: Vec<_> = self - .state - .tx_ring - .iter() - .filter_map(|(seq, entry)| { - (matches!(entry.state, TxState::Sent { .. }) && Self::ack_covers(ack, seq)) - .then_some(seq) - }) - .collect(); - for seq in acked { + loop { + let Some((seq, stream_id, opens_stream)) = + self.state.tx_ring.iter().find_map(|(seq, entry)| { + if !matches!(entry.state, TxState::Sent { .. }) || !Self::ack_covers(ack, seq) { + return None; + } + + let (stream_id, opens_stream) = match &entry.pending.body { + SessionBody::Stream(frame) => (Some(frame.stream_id), frame.offset == 0), + SessionBody::StreamClose(frame) => (Some(frame.stream_id), false), + _ => (None, false), + }; + + Some((seq, stream_id, opens_stream)) + }) + else { + break; + }; + let _ = self.state.tx_ring.remove(&seq); + if let Some(stream_id) = stream_id { + if opens_stream { + if let Some(stream) = self.state.streams.get_mut(&stream_id) { + if matches!(stream.open_state, StreamOpenState::WaitingForAck) { + stream.open_state = StreamOpenState::Opened; + } + } + } + self.try_reap_stream(stream_id); + } } self.state .tx_ring @@ -606,9 +690,22 @@ impl SessionFsm { .collect(); for seq in expired { - if let Some(entry) = self.state.tx_ring.remove(&seq) { - if entry.pending.retransmit { - self.requeue_pending_front(entry.pending); + let Some((retransmit, body)) = self + .state + .tx_ring + .get(&seq) + .map(|entry| (entry.pending.retransmit, entry.pending.body.clone())) + else { + continue; + }; + if retransmit && self.should_retry_body(&body) { + if let Some(entry) = self.state.tx_ring.get_mut(&seq) { + entry.state = TxState::Pending; + } + } else { + let _ = self.state.tx_ring.remove(&seq); + if matches!(body, SessionBody::Ack) { + self.state.clear_ack_schedule(); } } } @@ -618,29 +715,29 @@ impl SessionFsm { .advance_empty_front_until(self.state.next_seq); } - fn requeue_pending_front(&mut self, pending: PendingSessionBody) { - match pending.body { + fn should_retry_body(&self, body: &SessionBody) -> bool { + match body { + SessionBody::Ack => true, + SessionBody::Ping(_) | SessionBody::Unpair(_) => { + self.state.session_state == SessionState::Open + } + SessionBody::Close(_) => true, SessionBody::Stream(frame) => { - if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { - stream - .send_queue - .push_front(PendingStreamBody::Chunk(frame)); - } + self.state.session_state == SessionState::Open + && self + .state + .streams + .get(&frame.stream_id) + .is_some_and(|stream| { + !matches!(stream.outbound_state, OutboundState::Closed) + || (matches!(stream.open_state, StreamOpenState::WaitingForAck) + && frame.offset == 0) + }) } SessionBody::StreamClose(frame) => { - if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { - stream - .send_queue - .push_front(PendingStreamBody::Close(frame)); - } + self.state.session_state == SessionState::Open + && self.state.streams.contains_key(&frame.stream_id) } - body => match body { - SessionBody::Ack => {} - SessionBody::Ping(_) => self.state.pending_control.ping = true, - SessionBody::Unpair(_) => self.state.pending_control.unpair = true, - SessionBody::Close(close) => self.state.pending_control.close = Some(close), - SessionBody::Stream(_) | SessionBody::StreamClose(_) => unreachable!(), - }, } } @@ -654,6 +751,9 @@ impl SessionFsm { }); return; } + if frame.offset != 0 { + return; + } self.state .streams .insert(stream_id, StreamState::new(StreamRole::Responder)); @@ -663,10 +763,13 @@ impl SessionFsm { let Some(stream) = self.state.streams.get_mut(&stream_id) else { return; }; - if stream.inbound_discarding { + if matches!(stream.inbound_state, InboundState::Discarding) { return; } - if stream.inbound_closed || stream.inbound_finished { + if matches!( + stream.inbound_state, + InboundState::Closed(_) | InboundState::Finished + ) { if frame.offset + frame.bytes.len() as u64 <= stream.next_recv_offset { return; } @@ -688,18 +791,31 @@ impl SessionFsm { } if frame.offset == stream.next_recv_offset { + let was_readable = !stream.recv_buf.is_empty(); + let was_finished = matches!(stream.inbound_state, InboundState::Finished); Self::commit_inbound_frame(stream, frame); Self::drain_pending_recv(stream); - self.state - .events - .push_back(SessionEvent::Readable(stream_id)); + let became_readable = !was_readable && !stream.recv_buf.is_empty(); + let became_finished = + !was_finished && matches!(stream.inbound_state, InboundState::Finished); + if became_readable { + self.state + .events + .push_back(SessionEvent::Readable(stream_id)); + } + if became_finished { + self.state + .events + .push_back(SessionEvent::Finished(stream_id)); + } + self.try_reap_stream(stream_id); return; } if Self::insert_pending_chunk( stream, frame.offset, - PendingChunk { + PendingRxChunk { bytes: frame.bytes, fin: frame.fin, }, @@ -720,35 +836,41 @@ impl SessionFsm { return; }; - if Self::target_affects_inbound(stream.role, frame.target) && !stream.inbound_closed { - stream.inbound_closed = true; - stream.inbound_discarding = false; + if Self::target_affects_inbound(stream.role, frame.target) + && !matches!( + stream.inbound_state, + InboundState::Closed(_) | InboundState::Discarding + ) + { + stream.inbound_state = InboundState::Closed(frame.clone()); + stream.recv_buf.clear(); stream.pending_recv.clear(); - stream - .inbound_queue - .push_back(StreamIncoming::Closed(frame.clone())); self.state .events - .push_back(SessionEvent::Readable(frame.stream_id)); + .push_back(SessionEvent::Closed(frame.clone())); } - if Self::target_affects_outbound(stream.role, frame.target) && !stream.outbound_closed { - stream.outbound_closed = true; - stream.send_queue.clear(); + if Self::target_affects_outbound(stream.role, frame.target) + && !matches!(stream.outbound_state, OutboundState::Closed) + { + stream.outbound_state = OutboundState::Closed; + stream.send_buf.clear(); + stream.pending_close = None; self.state .events .push_back(SessionEvent::WritableClosed(frame.stream_id)); } + self.try_reap_stream(frame.stream_id); } fn apply_close_to_stream(stream: &mut StreamState, target: CloseTarget) { if Self::target_affects_inbound(stream.role, target) { - stream.inbound_discarding = true; + stream.inbound_state = InboundState::Discarding; + stream.recv_buf.clear(); stream.pending_recv.clear(); } if Self::target_affects_outbound(stream.role, target) { - stream.outbound_closed = true; - stream.outbound_finished = true; - stream.send_queue.clear(); + stream.outbound_state = OutboundState::Closed; + stream.send_buf.clear(); } } @@ -767,18 +889,17 @@ impl SessionFsm { fn commit_inbound_chunk(stream: &mut StreamState, bytes: Vec, fin: bool) { stream.next_recv_offset += bytes.len() as u64; if !bytes.is_empty() { - stream.inbound_queue.push_back(StreamIncoming::Data(bytes)); + stream.recv_buf.extend(bytes); } if fin { - stream.inbound_finished = true; - stream.inbound_queue.push_back(StreamIncoming::Finished); + stream.inbound_state = InboundState::Finished; } } fn drain_pending_recv(stream: &mut StreamState) { while let Some(chunk) = stream.pending_recv.remove(&stream.next_recv_offset) { Self::commit_inbound_chunk(stream, chunk.bytes, chunk.fin); - if stream.inbound_finished { + if matches!(stream.inbound_state, InboundState::Finished) { break; } } @@ -787,7 +908,7 @@ impl SessionFsm { fn insert_pending_chunk( stream: &mut StreamState, offset: u64, - chunk: PendingChunk, + chunk: PendingRxChunk, ) -> Result<(), ()> { let end = chunk.end_offset(offset); @@ -811,6 +932,111 @@ impl SessionFsm { Ok(()) } + fn take_stream_frame( + stream: &mut StreamState, + stream_id: StreamId, + chunk_size: usize, + ) -> Option { + if !stream.send_buf.is_empty() { + let len = stream.send_buf.len().min(chunk_size); + let bytes: Vec<_> = stream.send_buf.drain(..len).collect(); + let fin = if stream.send_buf.is_empty() + && matches!(stream.outbound_state, OutboundState::FinQueued) + { + stream.outbound_state = OutboundState::Finished; + true + } else { + false + }; + let frame = StreamChunk { + stream_id, + offset: stream.next_send_offset, + bytes, + fin, + }; + stream.next_send_offset += frame.bytes.len() as u64; + return Some(frame); + } + + if matches!(stream.outbound_state, OutboundState::FinQueued) { + stream.outbound_state = OutboundState::Finished; + return Some(StreamChunk { + stream_id, + offset: stream.next_send_offset, + bytes: Vec::new(), + fin: true, + }); + } + + None + } + + fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { + let tx_ring_references_stream = + self.state + .tx_ring + .iter() + .any(|(_, entry)| match &entry.pending.body { + SessionBody::Stream(frame) => frame.stream_id == stream_id, + SessionBody::StreamClose(frame) => frame.stream_id == stream_id, + _ => false, + }); + + if tx_ring_references_stream { + return false; + } + + if !stream.send_buf.is_empty() + || !stream.recv_buf.is_empty() + || !stream.pending_recv.is_empty() + { + return false; + } + + match stream.open_state { + StreamOpenState::WaitingForAck => false, + StreamOpenState::PendingSend => matches!(stream.outbound_state, OutboundState::Closed), + StreamOpenState::Opened => { + stream.pending_close.is_none() + && matches!( + stream.inbound_state, + InboundState::Finished | InboundState::Closed(_) | InboundState::Discarding + ) + && matches!( + stream.outbound_state, + OutboundState::Finished | OutboundState::Closed + ) + } + } + } + + fn try_reap_stream(&mut self, stream_id: StreamId) { + let should_reap = self + .state + .streams + .get(&stream_id) + .is_some_and(|stream| self.stream_is_reapable(stream_id, stream)); + if !should_reap { + return; + } + + let Some(index) = self.state.streams.get_index_of(&stream_id) else { + return; + }; + self.state.streams.shift_remove(&stream_id); + + if self.state.streams.is_empty() { + self.state.next_stream_index = 0; + return; + } + if index < self.state.next_stream_index { + self.state.next_stream_index -= 1; + } + if self.state.next_stream_index >= self.state.streams.len() { + self.state.next_stream_index %= self.state.streams.len(); + } + } + fn fail_session(&mut self, close: SessionCloseBody) { if self.state.session_state == SessionState::Closed { return; diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index d08e7758..209a8700 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -5,11 +5,10 @@ use std::{ use indexmap::IndexMap; use ql_wire::{ - CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamChunk, StreamClose, - StreamId, + CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamClose, StreamId, }; -use super::{ring::SeqRing, SessionEvent, SessionState, StreamIncoming}; +use super::{ring::SeqRing, SessionEvent, SessionState}; pub const SESSION_WINDOW_CAPACITY: usize = 64; @@ -36,66 +35,75 @@ impl StreamRole { } #[derive(Debug, Clone)] -pub struct PendingChunk { +pub struct PendingRxChunk { pub bytes: Vec, pub fin: bool, } -impl PendingChunk { +impl PendingRxChunk { pub fn end_offset(&self, offset: u64) -> u64 { offset + self.bytes.len() as u64 } } #[derive(Debug, Clone)] -pub enum PendingStreamBody { - Chunk(StreamChunk), - Close(StreamClose), +pub enum OutboundState { + Open, + FinQueued, + Finished, + Closed, } -impl PendingStreamBody { - pub fn to_session_body(&self) -> SessionBody { - match self { - Self::Chunk(frame) => SessionBody::Stream(frame.clone()), - Self::Close(frame) => SessionBody::StreamClose(frame.clone()), - } - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamOpenState { + PendingSend, + WaitingForAck, + Opened, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InboundState { + Open, + Finished, + Closed(StreamClose), + Discarding, } #[derive(Debug)] pub struct StreamState { pub role: StreamRole, - pub send_queue: VecDeque, - pub inbound_queue: VecDeque, - pub pending_recv: BTreeMap, + pub open_state: StreamOpenState, + pub send_buf: VecDeque, + pub pending_close: Option, + pub recv_buf: VecDeque, + pub pending_recv: BTreeMap, pub next_send_offset: u64, pub next_recv_offset: u64, - pub outbound_finished: bool, - pub outbound_closed: bool, - pub inbound_finished: bool, - pub inbound_closed: bool, - pub inbound_discarding: bool, + pub outbound_state: OutboundState, + pub inbound_state: InboundState, } impl StreamState { pub fn new(role: StreamRole) -> Self { Self { role, - send_queue: VecDeque::new(), - inbound_queue: VecDeque::new(), + open_state: match role { + StreamRole::Initiator => StreamOpenState::PendingSend, + StreamRole::Responder => StreamOpenState::Opened, + }, + send_buf: VecDeque::new(), + pending_close: None, + recv_buf: VecDeque::new(), pending_recv: BTreeMap::new(), next_send_offset: 0, next_recv_offset: 0, - outbound_finished: false, - outbound_closed: false, - inbound_finished: false, - inbound_closed: false, - inbound_discarding: false, + outbound_state: OutboundState::Open, + inbound_state: InboundState::Open, } } pub fn is_writable(&self) -> bool { - !self.outbound_finished && !self.outbound_closed + matches!(self.outbound_state, OutboundState::Open) } } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index a6a050c0..d41fa7bc 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -7,6 +7,19 @@ use ql_wire::{ use super::{SessionFsm, SessionFsmConfig, SessionState}; +fn read_stream_all(fsm: &mut SessionFsm, stream_id: ql_wire::StreamId) -> Vec { + let mut out = Vec::new(); + let mut buf = [0u8; 64]; + loop { + let read = fsm.read_stream(stream_id, &mut buf).unwrap(); + if read == 0 { + break; + } + out.extend_from_slice(&buf[..read]); + } + out +} + fn ack(seq: u64, ack: SessionAck) -> SessionEnvelope { SessionEnvelope { seq: SessionSeq(seq), @@ -32,8 +45,19 @@ fn outbound_session_seq_increments_monotonically() { fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); let first = fsm.next_outbound(now).unwrap(); + fsm.receive( + now + Duration::from_millis(1), + ack( + 1, + SessionAck { + base: SessionSeq(1), + bitmap: 0, + }, + ), + ); + fsm.write_stream(stream_id, b"two".to_vec()).unwrap(); - let second = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); + let second = fsm.next_outbound(now + Duration::from_millis(2)).unwrap(); assert_eq!(first.seq, SessionSeq(1)); assert_eq!(second.seq, SessionSeq(2)); @@ -119,7 +143,7 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { } #[test] -fn retransmit_requeues_body_with_new_session_seq() { +fn retransmit_reuses_session_seq() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); @@ -131,7 +155,7 @@ fn retransmit_requeues_body_with_new_session_seq() { let retried = fsm.next_outbound(retransmit_at).unwrap(); assert_eq!(first.seq, SessionSeq(1)); - assert_eq!(retried.seq, SessionSeq(2)); + assert_eq!(retried.seq, SessionSeq(1)); assert_eq!(retried.body, first.body); } @@ -139,14 +163,15 @@ fn retransmit_requeues_body_with_new_session_seq() { fn repeated_outbound_messages_keep_reporting_latest_receive_ack() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = fsm.open_stream().unwrap(); + let stream_id_a = fsm.open_stream().unwrap(); + let stream_id_b = fsm.open_stream().unwrap(); fsm.receive(now, ack(1, SessionAck::EMPTY)); - fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); + fsm.write_stream(stream_id_a, b"one".to_vec()).unwrap(); let first = fsm.next_outbound(now).unwrap(); - fsm.write_stream(stream_id, b"two".to_vec()).unwrap(); + fsm.write_stream(stream_id_b, b"two".to_vec()).unwrap(); let second = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); assert_eq!(first.ack.base, SessionSeq(1)); @@ -184,12 +209,12 @@ fn local_inbound_close_ignores_late_remote_bytes() { ); assert_eq!(fsm.session_state(), SessionState::Open); - assert!(fsm.take_next_inbound(stream_id).is_none()); + assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); assert!(fsm.take_next_event().is_none()); } #[test] -fn out_of_order_remote_stream_buffers_until_initial_bytes_arrive() { +fn missing_stream_nonzero_offset_is_ignored_until_offset_zero_arrives() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 7); @@ -197,7 +222,7 @@ fn out_of_order_remote_stream_buffers_until_initial_bytes_arrive() { fsm.receive( now, SessionEnvelope { - seq: SessionSeq(2), + seq: SessionSeq(1), ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id, @@ -209,38 +234,207 @@ fn out_of_order_remote_stream_buffers_until_initial_bytes_arrive() { ); assert_eq!(fsm.session_state(), SessionState::Open); + assert!(fsm.take_next_event().is_none()); + assert!(!fsm.state.streams.contains_key(&stream_id)); + + fsm.receive( + now + Duration::from_millis(1), + SessionEnvelope { + seq: SessionSeq(2), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + offset: 0, + bytes: b"a".to_vec(), + fin: false, + }), + }, + ); + assert_eq!( fsm.take_next_event(), Some(super::SessionEvent::Opened(stream_id)) ); - assert!(fsm.take_next_inbound(stream_id).is_none()); + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Readable(stream_id)) + ); + assert_eq!(read_stream_all(&mut fsm, stream_id), b"a".to_vec()); +} + +#[test] +fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionFsmConfig { + stream_chunk_size: 2, + ..SessionFsmConfig::default() + }, + now, + ); + let stream_id = fsm.open_stream().unwrap(); + + fsm.write_stream(stream_id, b"hello".to_vec()).unwrap(); + + let first = fsm.next_outbound(now).unwrap(); + assert_eq!( + first.body, + SessionBody::Stream(StreamChunk { + stream_id, + offset: 0, + bytes: b"he".to_vec(), + fin: false, + }) + ); + assert!(fsm.next_outbound(now + Duration::from_millis(1)).is_none()); fsm.receive( - now + Duration::from_millis(1), + now + Duration::from_millis(2), + ack( + 1, + SessionAck { + base: SessionSeq(1), + bitmap: 0, + }, + ), + ); + + let second = fsm.next_outbound(now + Duration::from_millis(3)).unwrap(); + assert_eq!( + second.body, + SessionBody::Stream(StreamChunk { + stream_id, + offset: 2, + bytes: b"ll".to_vec(), + fin: false, + }) + ); +} + +#[test] +fn stream_is_reaped_after_terminal_state_and_last_stream_ack() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 13); + + fsm.receive( + now, SessionEnvelope { seq: SessionSeq(1), ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id, offset: 0, - bytes: b"a".to_vec(), - fin: false, + bytes: b"hi".to_vec(), + fin: true, }), }, ); + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Opened(stream_id)) + ); assert_eq!( fsm.take_next_event(), Some(super::SessionEvent::Readable(stream_id)) ); + assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); assert_eq!( - fsm.take_next_inbound(stream_id), - Some(super::StreamIncoming::Data(b"a".to_vec())) + fsm.take_next_event(), + Some(super::SessionEvent::Finished(stream_id)) ); + assert!(fsm.state.streams.contains_key(&stream_id)); + + fsm.finish_stream(stream_id).unwrap(); + let fin = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); assert_eq!( - fsm.take_next_inbound(stream_id), - Some(super::StreamIncoming::Data(b"b".to_vec())) + fin.body, + SessionBody::Stream(StreamChunk { + stream_id, + offset: 0, + bytes: Vec::new(), + fin: true, + }) ); + assert!(fsm.state.streams.contains_key(&stream_id)); + + fsm.receive( + now + Duration::from_millis(2), + ack( + 2, + SessionAck { + base: SessionSeq(2), + bitmap: 0, + }, + ), + ); + + assert!(!fsm.state.streams.contains_key(&stream_id)); +} + +#[test] +fn replayed_remote_open_does_not_recreate_reaped_stream() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 17); + let opener = SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + offset: 0, + bytes: b"hi".to_vec(), + fin: true, + }), + }; + + fsm.receive(now, opener.clone()); + + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Opened(stream_id)) + ); + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Readable(stream_id)) + ); + assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Finished(stream_id)) + ); + + fsm.finish_stream(stream_id).unwrap(); + let fin = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); + assert_eq!( + fin.body, + SessionBody::Stream(StreamChunk { + stream_id, + offset: 0, + bytes: Vec::new(), + fin: true, + }) + ); + + fsm.receive( + now + Duration::from_millis(2), + ack( + 2, + SessionAck { + base: SessionSeq(1), + bitmap: 0, + }, + ), + ); + + assert!(!fsm.state.streams.contains_key(&stream_id)); + + fsm.receive(now + Duration::from_millis(3), opener); + + assert_eq!(fsm.session_state(), SessionState::Open); + assert!(!fsm.state.streams.contains_key(&stream_id)); + assert!(fsm.take_next_event().is_none()); } #[test] @@ -265,7 +459,7 @@ fn duplicate_committed_data_is_not_redelivered() { ); let _ = fsm.take_next_event(); let _ = fsm.take_next_event(); - let _ = fsm.take_next_inbound(stream_id); + let _ = read_stream_all(&mut fsm, stream_id); fsm.receive( now + Duration::from_millis(1), @@ -277,13 +471,19 @@ fn duplicate_committed_data_is_not_redelivered() { ); assert!(fsm.take_next_event().is_none()); - assert!(fsm.take_next_inbound(stream_id).is_none()); + assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); } #[test] fn next_outbound_round_robins_across_ready_streams() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new( + SessionFsmConfig { + stream_chunk_size: 3, + ..SessionFsmConfig::default() + }, + now, + ); let stream_id_a = fsm.open_stream().unwrap(); let stream_id_b = fsm.open_stream().unwrap(); @@ -292,17 +492,39 @@ fn next_outbound_round_robins_across_ready_streams() { fsm.write_stream(stream_id_a, b"a-2".to_vec()).unwrap(); fsm.write_stream(stream_id_b, b"b-2".to_vec()).unwrap(); - let scheduled: Vec<_> = (0..4) + let first_round: Vec<_> = (0..2) .map(|_| match fsm.next_outbound(now).unwrap().body { SessionBody::Stream(frame) => frame.stream_id, other => panic!("expected stream frame, got {other:?}"), }) .collect(); - assert_eq!( - scheduled, - vec![stream_id_a, stream_id_b, stream_id_a, stream_id_b] + fsm.receive( + now + Duration::from_millis(1), + ack( + 1, + SessionAck { + base: SessionSeq(2), + bitmap: 0, + }, + ), ); + + let second_round: Vec<_> = (0..2) + .map(|_| { + match fsm + .next_outbound(now + Duration::from_millis(2)) + .unwrap() + .body + { + SessionBody::Stream(frame) => frame.stream_id, + other => panic!("expected stream frame, got {other:?}"), + } + }) + .collect(); + + assert_eq!(first_round, vec![stream_id_a, stream_id_b]); + assert_eq!(second_round, vec![stream_id_a, stream_id_b]); } #[test] @@ -342,9 +564,9 @@ fn receive_ping_schedules_ack_without_ping_pong() { fn tx_selective_ack_keeps_front_gap_pinned() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = fsm.open_stream().unwrap(); + let stream_ids: Vec<_> = (0..64).map(|_| fsm.open_stream().unwrap()).collect(); - for byte in 0..64u8 { + for (byte, stream_id) in (0..64u8).zip(stream_ids.iter().copied()) { fsm.write_stream(stream_id, vec![byte]).unwrap(); let _ = fsm .next_outbound(now + Duration::from_millis(byte as u64)) @@ -365,7 +587,8 @@ fn tx_selective_ack_keeps_front_gap_pinned() { assert!(fsm.state.tx_ring.contains_key(&SessionSeq(1))); assert!(!fsm.state.tx_ring.contains_key(&SessionSeq(2))); - fsm.write_stream(stream_id, b"x".to_vec()).unwrap(); + let extra_stream = fsm.open_stream().unwrap(); + fsm.write_stream(extra_stream, b"x".to_vec()).unwrap(); assert!(fsm .next_outbound(now + Duration::from_millis(101)) .is_none()); @@ -425,7 +648,7 @@ fn duplicate_old_packet_seq_is_ignored() { ); let _ = fsm.take_next_event(); let _ = fsm.take_next_event(); - let _ = fsm.take_next_inbound(stream_id); + let _ = read_stream_all(&mut fsm, stream_id); fsm.receive( now + Duration::from_millis(1), @@ -437,7 +660,7 @@ fn duplicate_old_packet_seq_is_ignored() { ); assert!(fsm.take_next_event().is_none()); - assert!(fsm.take_next_inbound(stream_id).is_none()); + assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); } #[test] @@ -463,12 +686,9 @@ fn retransmitted_stream_close_is_idempotent() { assert_eq!( fsm.take_next_event(), - Some(super::SessionEvent::Readable(stream_id)) - ); - assert_eq!( - fsm.take_next_inbound(stream_id), - Some(super::StreamIncoming::Closed(frame.clone())) + Some(super::SessionEvent::Closed(frame.clone())) ); + assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); fsm.receive( now + Duration::from_millis(1), @@ -480,5 +700,5 @@ fn retransmitted_stream_close_is_idempotent() { ); assert!(fsm.take_next_event().is_none()); - assert!(fsm.take_next_inbound(stream_id).is_none()); + assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 28a94bdd..ac253221 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -148,6 +148,7 @@ impl Harness { harness.a.fsm.identity.xid, harness.a.fsm.peer.as_ref().unwrap().peer.xid, ), + stream_chunk_size: config.session_stream_chunk_size, ack_delay: config.session_ack_delay, retransmit_timeout: config.session_retransmit_timeout, keepalive_interval: config.session_keepalive_interval, @@ -161,6 +162,7 @@ impl Harness { harness.b.fsm.identity.xid, harness.b.fsm.peer.as_ref().unwrap().peer.xid, ), + stream_chunk_size: config.session_stream_chunk_size, ack_delay: config.session_ack_delay, retransmit_timeout: config.session_retransmit_timeout, keepalive_interval: config.session_keepalive_interval, diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index ced78785..44cc8b15 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -1,10 +1,23 @@ use std::time::Duration; -use ql_wire::SessionCloseBody; +use ql_wire::{SessionCloseBody, StreamId}; use super::*; use crate::{session::StreamNamespace, QlFsmEvent, QlSessionEvent}; +fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { + let mut out = Vec::new(); + let mut buf = [0u8; 64]; + loop { + let read = fsm.read_stream(stream_id, &mut buf).unwrap(); + if read == 0 { + break; + } + out.extend_from_slice(&buf[..read]); + } + out +} + #[test] fn connected_fsms_deliver_stream_data() { let mut harness = Harness::connected(QlFsmConfig::default()); @@ -25,11 +38,9 @@ fn connected_fsms_deliver_stream_data() { ); assert_eq!( harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Data { - stream_id, - bytes: b"hello".to_vec(), - }) + Some(QlSessionEvent::Readable(stream_id)) ); + assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"hello".to_vec()); assert_eq!( harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Finished(stream_id)) @@ -66,10 +77,13 @@ fn lost_encrypted_record_is_retried_and_acked() { let retried = harness.next_outbound_a().unwrap(); let retried_body = decrypt_envelope(&harness.b.crypto, &retried, &session_key); - assert_ne!(first_body.seq, retried_body.seq); + assert_eq!(first_body.seq, retried_body.seq); assert_eq!(first_body.body, retried_body.body); harness.deliver_to_b(retried); + harness.advance(config.session_ack_delay); + harness.a.fsm.on_timer(harness.time()); + harness.b.fsm.on_timer(harness.time()); harness.pump(); assert_eq!( @@ -78,11 +92,9 @@ fn lost_encrypted_record_is_retried_and_acked() { ); assert_eq!( harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Data { - stream_id, - bytes: b"retry".to_vec(), - }) + Some(QlSessionEvent::Readable(stream_id)) ); + assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"retry".to_vec()); harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); assert!(harness.next_outbound_a().is_none()); @@ -143,10 +155,11 @@ fn simultaneous_opens_use_disjoint_stream_id_namespaces() { ); assert_eq!( harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::Data { - stream_id: stream_id_b, - bytes: b"from-b".to_vec(), - }) + Some(QlSessionEvent::Readable(stream_id_b)) + ); + assert_eq!( + read_stream_all(&mut harness.a.fsm, stream_id_b), + b"from-b".to_vec() ); assert_eq!( harness.b.fsm.take_next_session_event(), @@ -154,10 +167,11 @@ fn simultaneous_opens_use_disjoint_stream_id_namespaces() { ); assert_eq!( harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Data { - stream_id: stream_id_a, - bytes: b"from-a".to_vec(), - }) + Some(QlSessionEvent::Readable(stream_id_a)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id_a), + b"from-a".to_vec() ); } @@ -189,11 +203,9 @@ fn queued_stream_work_auto_connects_and_drains_after_handshake() { ); assert_eq!( harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Data { - stream_id, - bytes: b"queued".to_vec(), - }) + Some(QlSessionEvent::Readable(stream_id)) ); + assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"queued".to_vec()); assert_eq!( harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Finished(stream_id)) @@ -283,11 +295,9 @@ fn returned_session_write_is_reissued_with_same_seq() { ); assert_eq!( harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Data { - stream_id, - bytes: b"retry".to_vec(), - }) + Some(QlSessionEvent::Readable(stream_id)) ); + assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"retry".to_vec()); } #[test] @@ -329,7 +339,7 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let record = write.record; let retried = decrypt_envelope(&harness.b.crypto, &record, &session_key); - assert_ne!(retried.seq, first.seq); + assert_eq!(retried.seq, first.seq); assert_eq!(retried.body, first.body); } diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 62345e0c..9eb73f53 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -91,6 +91,13 @@ impl InboundIo { Self::Open(tx) } + fn close(&mut self) { + if let Self::Open(tx) = self { + tx.close(); + } + *self = Self::Closed; + } + fn write_or_close(&mut self, bytes: Vec) -> bool { let Self::Open(tx) = self else { return true; @@ -246,7 +253,16 @@ impl DriverState { code, payload, } => { + if let Some(stream) = self.streams.get_mut(&stream_id) { + if target == CloseTarget::Both || target == stream.inbound_target() { + stream.inbound_mut().close(); + } + if target == CloseTarget::Both || target == stream.outbound_target() { + stream.outbound_mut().close(); + } + } let _ = self.fsm.close_stream(stream_id, target, code, payload); + self.try_reap_stream(stream_id); self.finish_step(platform, in_flight); } } @@ -303,9 +319,7 @@ impl DriverState { progressed = true; match event { QlSessionEvent::Opened(stream_id) => self.handle_opened_stream(platform, stream_id), - QlSessionEvent::Data { stream_id, bytes } => { - self.handle_inbound_data(stream_id, bytes) - } + QlSessionEvent::Readable(stream_id) => self.handle_inbound_readable(stream_id), QlSessionEvent::Finished(stream_id) => self.handle_inbound_finished(stream_id), QlSessionEvent::Closed(frame) => self.handle_closed_stream(frame), QlSessionEvent::WritableClosed(stream_id) => self.handle_writable_closed(stream_id), @@ -346,17 +360,36 @@ impl DriverState { })); } - fn handle_inbound_data(&mut self, stream_id: StreamId, bytes: Vec) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; + fn handle_inbound_readable(&mut self, stream_id: StreamId) { + loop { + let max_len = self.fsm.config.session_stream_chunk_size.max(1); + let available = match self.fsm.stream_available_bytes(stream_id) { + Ok(available) => available, + Err(_) => return, + }; + if available == 0 { + break; + } + + let mut bytes = vec![0; available.min(max_len)]; + let read = match self.fsm.read_stream(stream_id, &mut bytes) { + Ok(read) => read, + Err(_) => return, + }; + bytes.truncate(read); - let target = stream.inbound_target(); - let should_close = stream.inbound_mut().write_or_close(bytes); - if should_close { - let _ = self - .fsm - .close_stream(stream_id, target, CloseCode::CANCELLED, Vec::new()); + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let target = stream.inbound_target(); + let should_close = stream.inbound_mut().write_or_close(bytes); + if should_close { + let _ = self + .fsm + .close_stream(stream_id, target, CloseCode::CANCELLED, Vec::new()); + self.try_reap_stream(stream_id); + break; + } } } @@ -365,6 +398,7 @@ impl DriverState { return; }; stream.inbound_mut().finish(); + self.try_reap_stream(stream_id); } fn handle_closed_stream(&mut self, frame: ql_wire::StreamClose) { @@ -384,6 +418,7 @@ impl DriverState { if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { stream.outbound_mut().close(); } + self.try_reap_stream(frame.stream_id); } fn handle_writable_closed(&mut self, stream_id: StreamId) { @@ -391,6 +426,7 @@ impl DriverState { return; }; stream.outbound_mut().close(); + self.try_reap_stream(stream_id); } fn fail_all_streams(&mut self, error: QlError) { @@ -431,6 +467,24 @@ impl DriverState { } if finished { let _ = self.fsm.finish_stream(stream_id); + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.outbound_mut().close(); + } + self.try_reap_stream(stream_id); + } + } + + fn try_reap_stream(&mut self, stream_id: StreamId) { + let should_reap = self.streams.get(&stream_id).is_some_and(|stream| match stream { + DriverStreamIo::Initiator { request, response } => { + matches!(request, OutboundIo::Closed) && matches!(response, InboundIo::Closed) + } + DriverStreamIo::Responder { request, response } => { + matches!(request, InboundIo::Closed) && matches!(response, OutboundIo::Closed) + } + }); + if should_reap { + self.streams.remove(&stream_id); } } } @@ -543,3 +597,170 @@ fn unix_now_secs() -> u64 { .unwrap_or(Duration::ZERO) .as_secs() } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::new_identity; + use ql_wire::{CloseCode, StreamClose, XID}; + use ql_fsm::Peer; + + struct NoopPlatform; + + impl ql_wire::QlCrypto for NoopPlatform { + fn fill_random_bytes(&self, data: &mut [u8]) { + data.fill(0); + } + + fn hash(&self, _parts: &[&[u8]]) -> [u8; 32] { + [0; 32] + } + + fn encrypt_with_aead( + &self, + _key: &ql_wire::SessionKey, + _nonce: &ql_wire::Nonce, + _aad: &[u8], + _buffer: &mut [u8], + ) -> Option<[u8; ql_wire::EncryptedMessage::AUTH_SIZE]> { + None + } + + fn decrypt_with_aead( + &self, + _key: &ql_wire::SessionKey, + _nonce: &ql_wire::Nonce, + _aad: &[u8], + _buffer: &mut [u8], + _auth_tag: &[u8; ql_wire::EncryptedMessage::AUTH_SIZE], + ) -> bool { + false + } + } + + impl QlPlatform for NoopPlatform { + fn write_message(&self, _message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + Box::pin(async { Ok(()) }) + } + + fn sleep(&self, _duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(async {}) + } + + fn load_peer(&self) -> PlatformFuture<'_, Option> { + Box::pin(async { None }) + } + + fn persist_peer(&self, _peer: Peer) {} + + fn clear_peer(&self) {} + + fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} + + fn handle_inbound(&self, _event: HandlerEvent) {} + } + + fn new_driver_state() -> DriverState { + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + DriverState { + fsm: QlFsm::new(ql_fsm::QlFsmConfig::default(), new_identity(7), now()), + streams: HashMap::new(), + runtime_tx, + stream_send_buffer_bytes: 16, + max_concurrent_message_writes: 1, + } + } + + #[test] + fn handle_inbound_finished_reaps_closed_initiator_stream() { + let mut state = new_driver_state(); + let stream_id = StreamId(1); + let (response_tx, _response_rx) = async_channel::unbounded(); + + state.streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::Closed, + response: InboundIo::new(response_tx), + }, + ); + + state.handle_inbound_finished(stream_id); + + assert!(!state.streams.contains_key(&stream_id)); + } + + #[test] + fn handle_closed_stream_reaps_when_both_halves_close() { + let mut state = new_driver_state(); + let stream_id = StreamId(2); + let (request_tx, _request_rx) = async_channel::unbounded(); + let (response_reader, _response_writer) = piper::pipe(1); + + state.streams.insert( + stream_id, + DriverStreamIo::Responder { + request: InboundIo::new(request_tx), + response: OutboundIo::new(response_reader), + }, + ); + + state.handle_closed_stream(StreamClose { + stream_id, + target: CloseTarget::Both, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }); + + assert!(!state.streams.contains_key(&stream_id)); + } + + #[test] + fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { + let mut state = new_driver_state(); + let stream_id = StreamId(3); + let (request_reader, request_writer) = piper::pipe(1); + + drop(request_writer); + state.streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::new(request_reader), + response: InboundIo::Closed, + }, + ); + + state.poll_stream(stream_id); + + assert!(!state.streams.contains_key(&stream_id)); + } + + #[test] + fn local_close_command_reaps_when_other_half_is_already_closed() { + let mut state = new_driver_state(); + let stream_id = StreamId(4); + let (request_reader, _request_writer) = piper::pipe(1); + let mut in_flight = Vec::new(); + + state.streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::new(request_reader), + response: InboundIo::Closed, + }, + ); + + state.drive_command( + RuntimeCommand::CloseStream { + stream_id, + target: CloseTarget::Request, + code: CloseCode::CANCELLED, + payload: Vec::new(), + }, + &NoopPlatform, + &mut in_flight, + ); + + assert!(!state.streams.contains_key(&stream_id)); + } +} diff --git a/ql-runtime/src/rpc/client.rs b/ql-runtime/src/rpc/client.rs deleted file mode 100644 index 0ff532ce..00000000 --- a/ql-runtime/src/rpc/client.rs +++ /dev/null @@ -1,70 +0,0 @@ -use dcbor::CBOR; - -use super::{modality::RequestResponse, RpcError, RpcRequestHead, RpcResponseHead}; -use crate::runtime::{RuntimeHandle, StreamConfig}; - -#[derive(Clone)] -pub struct RpcHandle { - inner: RuntimeHandle, -} - -impl RpcHandle { - pub fn new(inner: RuntimeHandle) -> Self { - Self { inner } - } - - pub fn runtime(&self) -> &RuntimeHandle { - &self.inner - } - - pub async fn request( - &self, - request: M, - config: StreamConfig, - ) -> Result { - let request_body = Into::::into(request).to_cbor_data(); - let request_head = CBOR::from(RpcRequestHead::new( - M::METHOD, - Some(request_body.len() as u64), - )) - .to_cbor_data(); - - let crate::runtime::PendingStream { - mut request, - accepted, - } = self.inner.open_stream(request_head, config).await?; - let accepted = accepted.await?; - request.write_all(&request_body).await?; - request.finish().await?; - - let response_head = - RpcResponseHead::try_from(CBOR::try_from_data(&accepted.response_head)?)?; - if response_head.version != super::RPC_VERSION { - return Err(RpcError::BadVersion(response_head.version)); - } - - let response_body = - read_stream_to_end(accepted.response, response_head.content_length).await?; - Ok(CBOR::try_from_data(&response_body)?.try_into()?) - } -} - -async fn read_stream_to_end( - mut stream: crate::runtime::InboundByteStream, - content_length: Option, -) -> Result, RpcError> { - let mut body = match content_length.and_then(|length| usize::try_from(length).ok()) { - Some(length) => Vec::with_capacity(length), - None => Vec::new(), - }; - while let Some(chunk) = stream.next_chunk().await? { - body.extend_from_slice(&chunk); - } - if let Some(expected) = content_length { - let actual = body.len() as u64; - if actual != expected { - return Err(RpcError::ContentLengthMismatch { expected, actual }); - } - } - Ok(body) -} diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs deleted file mode 100644 index f19b95b4..00000000 --- a/ql-runtime/src/rpc/mod.rs +++ /dev/null @@ -1,153 +0,0 @@ -mod server; - -pub mod client; -pub mod modality; - -pub use client::RpcHandle; -use dcbor::CBOR; -pub use modality::{MethodId, QlCodec, RequestResponse}; - -use crate::QlError; - -pub(crate) const RPC_VERSION: u16 = 1; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct RpcRequestHead { - pub version: u16, - pub method: MethodId, - pub content_length: Option, -} - -impl RpcRequestHead { - pub fn new(method: MethodId, content_length: Option) -> Self { - Self { - version: RPC_VERSION, - method, - content_length, - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct RpcResponseHead { - pub version: u16, - pub content_length: Option, -} - -impl RpcResponseHead { - pub fn new(content_length: Option) -> Self { - Self { - version: RPC_VERSION, - content_length, - } - } -} - -impl Default for RpcResponseHead { - fn default() -> Self { - Self::new(None) - } -} - -#[derive(Debug, thiserror::Error)] -pub enum RpcError { - #[error(transparent)] - Transport(#[from] QlError), - #[error(transparent)] - Decode(#[from] dcbor::Error), - #[error("unsupported rpc version {0}")] - BadVersion(u16), - #[error("rpc content length mismatch: expected {expected}, got {actual}")] - ContentLengthMismatch { expected: u64, actual: u64 }, -} - -impl From for CBOR { - fn from(value: RpcRequestHead) -> Self { - CBOR::from(vec![ - CBOR::from(value.version), - CBOR::from(value.method), - value - .content_length - .map(CBOR::from) - .unwrap_or_else(CBOR::null), - ]) - } -} - -impl TryFrom for RpcRequestHead { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let [version, method, content_length] = take_fields(value.try_into_array()?.into_iter())?; - Ok(Self { - version: version.try_into()?, - method: method.try_into()?, - content_length: if content_length.is_null() { - None - } else { - Some(content_length.try_into()?) - }, - }) - } -} - -impl From for CBOR { - fn from(value: RpcResponseHead) -> Self { - CBOR::from(vec![ - CBOR::from(value.version), - value - .content_length - .map(CBOR::from) - .unwrap_or_else(CBOR::null), - ]) - } -} - -impl TryFrom for RpcResponseHead { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - let [version, content_length] = take_fields(value.try_into_array()?.into_iter())?; - Ok(Self { - version: version.try_into()?, - content_length: if content_length.is_null() { - None - } else { - Some(content_length.try_into()?) - }, - }) - } -} - -fn take_fields( - mut iter: impl Iterator, -) -> Result<[CBOR; N], dcbor::Error> { - use std::mem::MaybeUninit; - - let mut fields: [MaybeUninit; N] = [const { MaybeUninit::uninit() }; N]; - for (index, slot) in fields.iter_mut().enumerate() { - let Some(value) = iter.next() else { - for init in &mut fields[..index] { - unsafe { init.assume_init_drop() }; - } - return Err(dcbor::Error::msg("array too short")); - }; - slot.write(value); - } - let result = unsafe { std::ptr::read(&fields as *const _ as *const [CBOR; N]) }; - if iter.next().is_some() { - return Err(dcbor::Error::msg("array too long")); - } - Ok(result) -} - -#[test] -fn take_fields_reads_exact_count() { - let values = vec![CBOR::from(1u8), CBOR::from(2u8), CBOR::from(3u8)]; - let mut iter = values.into_iter(); - let [first, second, third] = take_fields(&mut iter).unwrap(); - assert_eq!(u8::try_from(first).unwrap(), 1); - assert_eq!(u8::try_from(second).unwrap(), 2); - assert_eq!(u8::try_from(third).unwrap(), 3); - assert!(iter.next().is_none()); -} diff --git a/ql-runtime/src/rpc/modality.rs b/ql-runtime/src/rpc/modality.rs deleted file mode 100644 index 533ece93..00000000 --- a/ql-runtime/src/rpc/modality.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::fmt; - -use dcbor::CBOR; - -pub trait QlCodec: Into + TryFrom {} - -impl QlCodec for T where T: Into + TryFrom {} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct MethodId(pub u64); - -impl fmt::Display for MethodId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for CBOR { - fn from(value: MethodId) -> Self { - CBOR::from(value.0) - } -} - -impl TryFrom for MethodId { - type Error = dcbor::Error; - - fn try_from(value: CBOR) -> Result { - Ok(Self(u64::try_from(value)?)) - } -} - -pub trait RequestResponse: QlCodec { - const METHOD: MethodId; - type Response: QlCodec; -} diff --git a/ql-runtime/src/rpc/server.rs b/ql-runtime/src/rpc/server.rs deleted file mode 100644 index 8b137891..00000000 --- a/ql-runtime/src/rpc/server.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index d37d1802..b789b41a 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -331,7 +331,7 @@ fn is_encrypted_payload(bytes: &[u8]) -> bool { .is_some_and(|record| matches!(record.payload, QlPayload::Session(_))) } -fn new_identity(seed: u8) -> QlIdentity { +pub(crate) fn new_identity(seed: u8) -> QlIdentity { let crypto = DeterministicCrypto::new(seed); let (signing_private, signing_public) = generate_ml_dsa_keypair(&crypto); let (encapsulation_private, encapsulation_public) = generate_ml_kem_keypair(&crypto); From 3f545ad567cfe592fb87ee6a7eb7f380b1bc390e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 01:29:12 -0400 Subject: [PATCH 011/304] ql: fmt --- ql-fsm/src/tests/session.rs | 20 ++++++++++++++++---- ql-runtime/src/driver.rs | 24 ++++++++++++++---------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 44cc8b15..1c9159de 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -40,7 +40,10 @@ fn connected_fsms_deliver_stream_data() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id)) ); - assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"hello".to_vec()); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"hello".to_vec() + ); assert_eq!( harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Finished(stream_id)) @@ -94,7 +97,10 @@ fn lost_encrypted_record_is_retried_and_acked() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id)) ); - assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"retry".to_vec()); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"retry".to_vec() + ); harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); assert!(harness.next_outbound_a().is_none()); @@ -205,7 +211,10 @@ fn queued_stream_work_auto_connects_and_drains_after_handshake() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id)) ); - assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"queued".to_vec()); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"queued".to_vec() + ); assert_eq!( harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Finished(stream_id)) @@ -297,7 +306,10 @@ fn returned_session_write_is_reissued_with_same_seq() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id)) ); - assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"retry".to_vec()); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"retry".to_vec() + ); } #[test] diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 9eb73f53..f5467f03 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -475,14 +475,17 @@ impl DriverState { } fn try_reap_stream(&mut self, stream_id: StreamId) { - let should_reap = self.streams.get(&stream_id).is_some_and(|stream| match stream { - DriverStreamIo::Initiator { request, response } => { - matches!(request, OutboundIo::Closed) && matches!(response, InboundIo::Closed) - } - DriverStreamIo::Responder { request, response } => { - matches!(request, InboundIo::Closed) && matches!(response, OutboundIo::Closed) - } - }); + let should_reap = self + .streams + .get(&stream_id) + .is_some_and(|stream| match stream { + DriverStreamIo::Initiator { request, response } => { + matches!(request, OutboundIo::Closed) && matches!(response, InboundIo::Closed) + } + DriverStreamIo::Responder { request, response } => { + matches!(request, InboundIo::Closed) && matches!(response, OutboundIo::Closed) + } + }); if should_reap { self.streams.remove(&stream_id); } @@ -600,10 +603,11 @@ fn unix_now_secs() -> u64 { #[cfg(test)] mod tests { + use ql_fsm::Peer; + use ql_wire::{CloseCode, StreamClose, XID}; + use super::*; use crate::tests::new_identity; - use ql_wire::{CloseCode, StreamClose, XID}; - use ql_fsm::Peer; struct NoopPlatform; From 3443f9e39f0c576b91e9966743cdbb737cd26048 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 08:05:38 -0400 Subject: [PATCH 012/304] ql: clippy wip --- ql-fsm/src/implementation/mod.rs | 2 +- ql-fsm/src/replay_cache.rs | 13 +++++++------ ql-fsm/src/session/mod.rs | 11 +++-------- ql-fsm/src/tests/mod.rs | 6 +++--- ql-fsm/src/tests/session.rs | 12 ++++++------ ql-wire/src/lib.rs | 4 ++++ ql-wire/src/pair/crypto.rs | 2 +- 7 files changed, 25 insertions(+), 25 deletions(-) diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index 69e2bf9c..fbb002d8 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -50,7 +50,7 @@ fn is_replayed_control(fsm: &mut QlFsm, peer: XID, meta: ControlMeta) -> bool { fn peer_session(fsm: &QlFsm) -> Option<(XID, SessionKey)> { let entry = fsm.peer.as_ref()?; - let session_key = entry.session.session_key()?.clone(); + let session_key = *entry.session.session_key()?; Some((entry.peer.xid, session_key)) } diff --git a/ql-fsm/src/replay_cache.rs b/ql-fsm/src/replay_cache.rs index 4843c517..5ffbb03d 100644 --- a/ql-fsm/src/replay_cache.rs +++ b/ql-fsm/src/replay_cache.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{hash_map::Entry, HashMap}; use ql_wire::{ControlId, ControlMeta, XID}; @@ -28,11 +28,12 @@ impl ReplayCache { control_id: meta.control_id, }; - if self.valid_until_by_key.contains_key(&key) { - true - } else { - self.valid_until_by_key.insert(key, meta.valid_until); - false + match self.valid_until_by_key.entry(key) { + Entry::Occupied(_) => true, + Entry::Vacant(entry) => { + entry.insert(meta.valid_until); + false + } } } } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 38d69e3f..ab78203f 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -332,14 +332,11 @@ impl SessionFsm { else { break; }; - let Some(body) = self + let body = self .state .tx_ring .get(&seq) - .map(|entry| entry.pending.body.clone()) - else { - return None; - }; + .map(|entry| entry.pending.body.clone())?; if !self.should_retry_body(&body) { let _ = self.state.tx_ring.remove(&seq); self.state @@ -348,9 +345,7 @@ impl SessionFsm { continue; } - let Some(entry) = self.state.tx_ring.get_mut(&seq) else { - return None; - }; + let entry = self.state.tx_ring.get_mut(&seq)?; entry.state = TxState::Issued; return Some(SessionEnvelope { seq, ack, body }); } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index ac253221..32d4e619 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -135,7 +135,7 @@ impl Harness { let session_key = SessionKey::from_data([7; SessionKey::SIZE]); harness.a.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected { - session_key: session_key.clone(), + session_key, recent_ready: None, }; harness.b.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected { @@ -266,8 +266,8 @@ fn test_identity(seed: u8) -> QlIdentity { fn peer_from_identity(identity: &QlIdentity) -> Peer { Peer { xid: identity.xid, - signing_key: identity.signing_public_key.clone(), - encapsulation_key: identity.encapsulation_public_key.clone(), + signing_key: identity.signing_public_key, + encapsulation_key: identity.encapsulation_public_key, } } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 1c9159de..58f05752 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -71,8 +71,8 @@ fn lost_encrypted_record_is_retried_and_acked() { .unwrap() .session .session_key() - .unwrap() - .clone(); + .unwrap(); + let session_key = *session_key; let first_body = decrypt_envelope(&harness.b.crypto, &first, &session_key); harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); @@ -277,8 +277,8 @@ fn returned_session_write_is_reissued_with_same_seq() { .unwrap() .session .session_key() - .unwrap() - .clone(); + .unwrap(); + let session_key = *session_key; let first = decrypt_envelope(&harness.b.crypto, &record, &session_key); harness.return_write_a(id); @@ -335,8 +335,8 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { .unwrap() .session .session_key() - .unwrap() - .clone(); + .unwrap(); + let session_key = *session_key; let first = decrypt_envelope(&harness.b.crypto, &record, &session_key); harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 034cf80d..36d8ff7c 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -1,4 +1,8 @@ +//! //! quantum link protocol wire format +//! + +#![allow(clippy::too_many_arguments)] mod codec; mod control; diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs index a10d52d1..690b26ca 100644 --- a/ql-wire/src/pair/crypto.rs +++ b/ql-wire/src/pair/crypto.rs @@ -125,7 +125,7 @@ fn hash_pairing_proof_data( ]) } -pub(crate) fn pairing_aad(header: &QlHeader, kem_ct: &MlKemCiphertext) -> Vec { +fn pairing_aad(header: &QlHeader, kem_ct: &MlKemCiphertext) -> Vec { let mut aad = Vec::new(); crate::codec::append_field(&mut aad, b"domain", b"ql-wire:pair-aad:v1"); crate::codec::append_field(&mut aad, b"sender", &header.sender.0); From db5dfed1b0cedf991e26cdceaffdea164432a935 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 09:21:44 -0400 Subject: [PATCH 013/304] ql-wire: box pq types --- ql-fsm/src/implementation/handshake.rs | 10 ++-- ql-fsm/src/tests/mod.rs | 4 +- ql-wire/src/handshake/mod.rs | 4 +- ql-wire/src/pair/crypto.rs | 4 +- ql-wire/src/pq.rs | 72 +++++++++++++------------- 5 files changed, 47 insertions(+), 47 deletions(-) diff --git a/ql-fsm/src/implementation/handshake.rs b/ql-fsm/src/implementation/handshake.rs index 35fd2763..e452d13c 100644 --- a/ql-fsm/src/implementation/handshake.rs +++ b/ql-fsm/src/implementation/handshake.rs @@ -71,7 +71,7 @@ pub fn handle_hello( ConnectionState::Initiator { hello: local_hello, .. } => { - if peer_hello_wins(local_hello, fsm.identity.xid, &hello, header.sender) { + if peer_hello_wins(local_hello, fsm.identity.xid, hello, header.sender) { HelloAction::StartResponder } else { HelloAction::Ignore @@ -83,7 +83,7 @@ pub fn handle_hello( stage: HandshakeResponder::WaitingConfirm { .. }, .. } => { - if same_hello(stored, &hello) { + if same_hello(stored, hello) { HelloAction::ResendReply { reply: reply.clone(), } @@ -173,7 +173,7 @@ pub fn handle_hello_reply( .. } => HelloReplyAction::Advance { hello: hello.clone(), - initiator_secret: initiator_secret.clone(), + initiator_secret: *initiator_secret, responder_signing_key: entry.peer.signing_key.clone(), }, ConnectionState::Initiator { @@ -184,7 +184,7 @@ pub fn handle_hello_reply( .. }, .. - } if same_reply(stored, &reply) => HelloReplyAction::ResendConfirm { + } if same_reply(stored, reply) => HelloReplyAction::ResendConfirm { confirm: confirm.clone(), }, _ => return Ok(()), @@ -332,7 +332,7 @@ pub fn handle_ready( ConnectionState::Initiator { stage: HandshakeInitiator::WaitingReady { session_key, .. }, .. - } => session_key.clone(), + } => *session_key, _ => return Ok(()), } }; diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 32d4e619..14b2a0f2 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -266,8 +266,8 @@ fn test_identity(seed: u8) -> QlIdentity { fn peer_from_identity(identity: &QlIdentity) -> Peer { Peer { xid: identity.xid, - signing_key: identity.signing_public_key, - encapsulation_key: identity.encapsulation_public_key, + signing_key: identity.signing_public_key.clone(), + encapsulation_key: identity.encapsulation_public_key.clone(), } } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 33adbae7..c98b5f33 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -87,8 +87,8 @@ impl HelloReply { Hello { meta: self.meta, nonce: self.nonce, - kem_ct: self.kem_ct, - signature: self.signature, + kem_ct: self.kem_ct.clone(), + signature: self.signature.clone(), } .encode_into(out); } diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs index 690b26ca..466e935a 100644 --- a/ql-wire/src/pair/crypto.rs +++ b/ql-wire/src/pair/crypto.rs @@ -19,8 +19,8 @@ pub fn build_pair_request( sender: identity.xid, recipient, }; - let signing_pub_key = identity.signing_public_key; - let sender_encapsulation_key = identity.encapsulation_public_key; + let signing_pub_key = identity.signing_public_key.clone(); + let sender_encapsulation_key = identity.encapsulation_public_key.clone(); let proof_data = hash_pairing_proof_data( crypto, &header, diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index 61e9cca5..9f831668 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -31,17 +31,17 @@ impl AsRef<[u8]> for SessionKey { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct MlDsaPrivateKey([u8; MlDsaPrivateKey::SIZE]); +pub struct MlDsaPrivateKey(Box<[u8; MlDsaPrivateKey::SIZE]>); impl MlDsaPrivateKey { pub const SIZE: usize = ml_dsa_87::MLDSA87SigningKey::len(); - pub const fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(data) + pub fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(Box::new(data)) } - pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { - &self.0 + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() } pub fn sign( @@ -51,61 +51,61 @@ impl MlDsaPrivateKey { ) -> Result { let mut randomness = [0u8; SIGNING_RANDOMNESS_SIZE]; crypto.fill_random_bytes(&mut randomness); - let signing_key = ml_dsa_87::MLDSA87SigningKey::new(self.0); + let signing_key = ml_dsa_87::MLDSA87SigningKey::new(*self.as_bytes()); let signature = ml_dsa_87::sign(&signing_key, message, b"", randomness) .map_err(|_| WireError::SigningFailed)?; Ok(MlDsaSignature::from_data(*signature.as_ref())) } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct MlDsaPublicKey([u8; MlDsaPublicKey::SIZE]); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlDsaPublicKey(Box<[u8; MlDsaPublicKey::SIZE]>); impl MlDsaPublicKey { pub const SIZE: usize = ml_dsa_87::MLDSA87VerificationKey::len(); - pub const fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(data) + pub fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(Box::new(data)) } - pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { - &self.0 + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() } pub fn verify(&self, signature: &MlDsaSignature, message: &[u8]) -> bool { - let verification_key = ml_dsa_87::MLDSA87VerificationKey::new(self.0); + let verification_key = ml_dsa_87::MLDSA87VerificationKey::new(*self.as_bytes()); let signature = ml_dsa_87::MLDSA87Signature::new(*signature.as_bytes()); ml_dsa_87::verify(&verification_key, message, b"", &signature).is_ok() } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct MlDsaSignature([u8; MlDsaSignature::SIZE]); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlDsaSignature(Box<[u8; MlDsaSignature::SIZE]>); impl MlDsaSignature { pub const SIZE: usize = ml_dsa_87::MLDSA87Signature::len(); - pub const fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(data) + pub fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(Box::new(data)) } - pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { - &self.0 + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct MlKemPublicKey([u8; MlKemPublicKey::SIZE]); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemPublicKey(Box<[u8; MlKemPublicKey::SIZE]>); impl MlKemPublicKey { pub const SIZE: usize = mlkem1024::MlKem1024PublicKey::len(); - pub const fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(data) + pub fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(Box::new(data)) } - pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { - &self.0 + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() } pub fn encapsulate_new_shared_secret( @@ -124,17 +124,17 @@ impl MlKemPublicKey { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct MlKemPrivateKey([u8; MlKemPrivateKey::SIZE]); +pub struct MlKemPrivateKey(Box<[u8; MlKemPrivateKey::SIZE]>); impl MlKemPrivateKey { pub const SIZE: usize = mlkem1024::MlKem1024PrivateKey::len(); - pub const fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(data) + pub fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(Box::new(data)) } - pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { - &self.0 + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() } pub fn decapsulate_shared_secret( @@ -148,18 +148,18 @@ impl MlKemPrivateKey { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct MlKemCiphertext([u8; MlKemCiphertext::SIZE]); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemCiphertext(Box<[u8; MlKemCiphertext::SIZE]>); impl MlKemCiphertext { pub const SIZE: usize = mlkem1024::MlKem1024Ciphertext::len(); - pub const fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(data) + pub fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(Box::new(data)) } - pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { - &self.0 + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() } } From 96a838c26b2747b091a4ef911b328d35ca409f15 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 09:53:49 -0400 Subject: [PATCH 014/304] ql-wire: consistent wire types --- ql-fsm/src/implementation/fsm.rs | 2 +- ql-fsm/src/implementation/handshake.rs | 49 +++--- ql-fsm/src/implementation/peer.rs | 4 +- ql-wire/src/control.rs | 32 ++-- ql-wire/src/encrypted/close.rs | 27 +-- ql-wire/src/encrypted/mod.rs | 110 ++++++------ ql-wire/src/encrypted/stream_chunk.rs | 28 ++- ql-wire/src/encrypted/stream_close.rs | 30 ++-- ql-wire/src/encrypted_message.rs | 61 ++++--- ql-wire/src/handshake/crypto.rs | 185 +++++++++++--------- ql-wire/src/handshake/mod.rs | 229 +++++++++++++++---------- ql-wire/src/header.rs | 8 +- ql-wire/src/lib.rs | 3 + ql-wire/src/pair/crypto.rs | 16 +- ql-wire/src/pair/mod.rs | 78 +++++---- ql-wire/src/pq.rs | 15 +- ql-wire/src/record.rs | 57 +++--- ql-wire/src/tests.rs | 2 +- 18 files changed, 512 insertions(+), 424 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 12293775..ac388313 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -44,7 +44,7 @@ pub fn receive( return Ok(()); }; let envelope = match wire::decrypt_record(crypto, &header, &mut encrypted, &session_key) - .and_then(|envelope| envelope.to_session_envelope()) + .and_then(|envelope| wire::SessionEnvelope::from_wire(&envelope)) { Ok(envelope) => envelope, Err(_) => return Ok(()), diff --git a/ql-fsm/src/implementation/handshake.rs b/ql-fsm/src/implementation/handshake.rs index e452d13c..bdbc2d3e 100644 --- a/ql-fsm/src/implementation/handshake.rs +++ b/ql-fsm/src/implementation/handshake.rs @@ -1,8 +1,9 @@ use std::{cmp::Ordering, time::Instant}; use ql_wire::{ - self as wire, Confirm, Hello, HelloReply, MlDsaPublicKey, Nonce, QlCrypto, QlHeader, QlPayload, - Ready, ReadyRef, SessionKey, XID, + self as wire, Confirm, ConfirmWire, EncryptedMessageWire, Hello, HelloReply, HelloReplyWire, + HelloWire, MlDsaPublicKey, Nonce, QlCrypto, QlHeader, QlPayload, Ready, RefMut, SessionKey, + XID, }; use super::{ @@ -48,7 +49,7 @@ pub fn handle_hello( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - hello: &Hello, + hello: &RefMut<'_, HelloWire>, ) -> Result<(), QlFsmError> { let action = { let Some(entry) = fsm.peer.as_ref() else { @@ -71,7 +72,7 @@ pub fn handle_hello( ConnectionState::Initiator { hello: local_hello, .. } => { - if peer_hello_wins(local_hello, fsm.identity.xid, hello, header.sender) { + if peer_hello_wins_ref(local_hello, fsm.identity.xid, hello, header.sender) { HelloAction::StartResponder } else { HelloAction::Ignore @@ -83,7 +84,7 @@ pub fn handle_hello( stage: HandshakeResponder::WaitingConfirm { .. }, .. } => { - if same_hello(stored, hello) { + if same_hello_ref(stored, hello) { HelloAction::ResendReply { reply: reply.clone(), } @@ -103,7 +104,7 @@ pub fn handle_hello( enqueue_handshake(fsm, header.sender, QlPayload::HelloReply(reply)); } HelloAction::StartResponder => { - if is_replayed_control(fsm, header.sender, hello.meta) { + if is_replayed_control(fsm, header.sender, wire::ControlMeta::from_wire(hello.meta)) { return Ok(()); } @@ -135,7 +136,7 @@ pub fn handle_hello( let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); if let Some(entry) = fsm.peer.as_mut() { entry.session = ConnectionState::Responder { - hello: hello.clone(), + hello: wire::Hello::from_wire(hello), reply: reply.clone(), deadline, stage: HandshakeResponder::WaitingConfirm { @@ -157,7 +158,7 @@ pub fn handle_hello_reply( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - reply: &HelloReply, + reply: &RefMut<'_, HelloReplyWire>, ) -> Result<(), QlFsmError> { let action = { let Some(entry) = fsm.peer.as_ref() else { @@ -184,7 +185,7 @@ pub fn handle_hello_reply( .. }, .. - } if same_reply(stored, reply) => HelloReplyAction::ResendConfirm { + } if same_reply_ref(stored, reply) => HelloReplyAction::ResendConfirm { confirm: confirm.clone(), }, _ => return Ok(()), @@ -216,7 +217,7 @@ pub fn handle_hello_reply( Err(_) => return Ok(()), }; - if is_replayed_control(fsm, header.sender, reply.meta) { + if is_replayed_control(fsm, header.sender, wire::ControlMeta::from_wire(reply.meta)) { return Ok(()); } @@ -227,7 +228,7 @@ pub fn handle_hello_reply( hello, deadline, stage: HandshakeInitiator::WaitingReady { - reply: reply.clone(), + reply: wire::HelloReply::from_wire(reply), confirm: confirm.clone(), session_key, retry_count: 0, @@ -246,7 +247,7 @@ pub fn handle_confirm( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - confirm: &Confirm, + confirm: &RefMut<'_, ConfirmWire>, ) -> Result<(), QlFsmError> { if let Some(ready) = recent_ready_resend(fsm, crypto, header.sender, confirm) { enqueue_handshake(fsm, header.sender, QlPayload::Ready(ready)); @@ -286,7 +287,11 @@ pub fn handle_confirm( Err(_) => return Ok(()), }; - if is_replayed_control(fsm, header.sender, confirm.meta) { + if is_replayed_control( + fsm, + header.sender, + wire::ControlMeta::from_wire(confirm.meta), + ) { return Ok(()); } @@ -322,7 +327,7 @@ pub fn handle_ready( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - ready: &mut ReadyRef<&mut [u8]>, + ready: &mut RefMut<'_, EncryptedMessageWire>, ) -> Result<(), QlFsmError> { let session_key = { let Some(entry) = fsm.peer.as_ref() else { @@ -521,7 +526,7 @@ fn recent_ready_resend( fsm: &QlFsm, crypto: &impl QlCrypto, peer: XID, - confirm: &Confirm, + confirm: &RefMut<'_, ConfirmWire>, ) -> Option { let entry = fsm.peer.as_ref()?; let ConnectionState::Connected { @@ -662,21 +667,21 @@ fn responder_retry_at(stage: &HandshakeResponder) -> Option { } } -fn same_hello(stored: &Hello, incoming: &Hello) -> bool { - stored.meta.control_id == incoming.meta.control_id && stored.nonce == incoming.nonce +fn same_hello_ref(stored: &Hello, incoming: &RefMut<'_, HelloWire>) -> bool { + stored.meta.control_id.0 == incoming.meta.control_id.get() && stored.nonce.0 == incoming.nonce } -fn same_reply(stored: &HelloReply, incoming: &HelloReply) -> bool { - stored.meta.control_id == incoming.meta.control_id && stored.nonce == incoming.nonce +fn same_reply_ref(stored: &HelloReply, incoming: &RefMut<'_, HelloReplyWire>) -> bool { + stored.meta.control_id.0 == incoming.meta.control_id.get() && stored.nonce.0 == incoming.nonce } -fn peer_hello_wins( +fn peer_hello_wins_ref( local_hello: &Hello, local_sender: XID, - peer_hello: &Hello, + peer_hello: &RefMut<'_, HelloWire>, peer_sender: XID, ) -> bool { - match peer_hello.nonce.0.cmp(&local_hello.nonce.0) { + match peer_hello.nonce.cmp(&local_hello.nonce.0) { Ordering::Less => true, Ordering::Greater => false, Ordering::Equal => peer_sender.0.cmp(&local_sender.0) == Ordering::Less, diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs index cf0632b8..76605a20 100644 --- a/ql-fsm/src/implementation/peer.rs +++ b/ql-fsm/src/implementation/peer.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, PairRequestRecordRef, QlCrypto, QlHeader}; +use ql_wire::{self as wire, PairRequestRecordWire, QlCrypto, QlHeader, RefMut}; use super::{emit_peer_status, handshake, is_replayed_control, next_control_meta, reset_session}; use crate::{state::PeerRecord, Peer, QlFsm, QlFsmError, QlFsmEvent}; @@ -25,7 +25,7 @@ pub fn handle_pair( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - request: &mut PairRequestRecordRef<&mut [u8]>, + request: &mut RefMut<'_, PairRequestRecordWire>, ) -> Result<(), QlFsmError> { let payload = match wire::decrypt_pair_request( crypto, diff --git a/ql-wire/src/control.rs b/ql-wire/src/control.rs index b0519c9c..47f1c74a 100644 --- a/ql-wire/src/control.rs +++ b/ql-wire/src/control.rs @@ -23,25 +23,25 @@ impl ControlMeta { Ok(()) } } -} -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub(crate) struct ControlMetaWire { - pub(crate) control_id: U32Le, - pub(crate) valid_until: U64Le, -} + pub fn to_wire(&self) -> ControlMetaWire { + ControlMetaWire { + control_id: U32Le::new(self.control_id.0), + valid_until: U64Le::new(self.valid_until), + } + } -pub(crate) fn control_meta_to_wire(meta: &ControlMeta) -> ControlMetaWire { - ControlMetaWire { - control_id: U32Le::new(meta.control_id.0), - valid_until: U64Le::new(meta.valid_until), + pub fn from_wire(meta: ControlMetaWire) -> Self { + Self { + control_id: ControlId(meta.control_id.get()), + valid_until: meta.valid_until.get(), + } } } -pub(crate) fn control_meta_from_wire(meta: ControlMetaWire) -> ControlMeta { - ControlMeta { - control_id: ControlId(meta.control_id.get()), - valid_until: meta.valid_until.get(), - } +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub struct ControlMetaWire { + pub control_id: U32Le, + pub valid_until: U64Le, } diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index 83dc9ef5..d0b95953 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -13,22 +13,29 @@ pub struct SessionCloseBody { #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -struct SessionCloseBodyWire { - code: U16Le, +pub struct SessionCloseBodyWire { + pub code: U16Le, } impl SessionCloseBody { - pub(crate) fn encode_into(&self, out: &mut Vec) { - let wire = SessionCloseBodyWire { + pub fn from_wire(wire: SessionCloseBodyWire) -> Self { + Self { + code: CloseCode(wire.code.get()), + } + } + + pub fn to_wire(&self) -> SessionCloseBodyWire { + SessionCloseBodyWire { code: U16Le::new(self.code.0), - }; - push_value(out, &wire); + } } - pub(crate) fn decode(bytes: &[u8]) -> Result { + pub fn encode_into(&self, out: &mut Vec) { + push_value(out, &self.to_wire()); + } + + pub fn decode(bytes: &[u8]) -> Result { let wire: SessionCloseBodyWire = read_exact(bytes)?; - Ok(Self { - code: CloseCode(wire.code.get()), - }) + Ok(Self::from_wire(wire)) } } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 4d7846b2..07a75d56 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -5,7 +5,7 @@ use zerocopy::{ use crate::{ codec::{parse, push_value, U64Le}, - encrypted_message::{EncryptedMessage, EncryptedMessageRef}, + encrypted_message::{EncryptedMessage, EncryptedMessageWire}, Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, SessionKey, WireError, }; @@ -63,8 +63,8 @@ pub enum SessionBodyRef { Ack, Ping, Unpair, - Stream(StreamChunkRef), - StreamClose(StreamCloseRef), + Stream(Ref), + StreamClose(Ref), Close(close::SessionCloseBody), } @@ -91,67 +91,32 @@ pub struct SessionEnvelopeWire { pub body: [u8], } -pub type SessionEnvelopeRef = Ref; - -impl SessionEnvelopeWire { - pub fn parse(bytes: B) -> Result, WireError> { +impl SessionEnvelope { + pub fn parse(bytes: B) -> Result, WireError> { parse(bytes) } - fn body_kind(&self) -> Result { - crate::codec::read_byte(self.kind) - } - - pub fn to_session_envelope(&self) -> Result { - let body = match parse_session_body(self.body_kind()?, &self.body)? { + pub fn from_wire(wire: &SessionEnvelopeWire) -> Result { + let body = match parse_session_body(session_body_kind(wire)?, &wire.body)? { SessionBodyRef::Ack => SessionBody::Ack, SessionBodyRef::Ping => SessionBody::Ping(ping::PingBody), SessionBodyRef::Unpair => SessionBody::Unpair(unpair::UnpairBody), - SessionBodyRef::Stream(frame) => SessionBody::Stream(frame.to_stream_chunk()?), + SessionBodyRef::Stream(frame) => SessionBody::Stream(StreamChunk::from_wire(&frame)?), SessionBodyRef::StreamClose(frame) => { - SessionBody::StreamClose(frame.to_stream_close()?) + SessionBody::StreamClose(StreamClose::from_wire(&frame)?) } SessionBodyRef::Close(body) => SessionBody::Close(body), }; - Ok(SessionEnvelope { - seq: SessionSeq(self.seq.get()), + Ok(Self { + seq: SessionSeq(wire.seq.get()), ack: SessionAck { - base: SessionSeq(self.ack_base.get()), - bitmap: self.ack_bitmap.get(), + base: SessionSeq(wire.ack_base.get()), + bitmap: wire.ack_bitmap.get(), }, body, }) } -} - -fn parse_session_body( - kind: SessionBodyKind, - body: B, -) -> Result, WireError> { - match kind { - SessionBodyKind::Ack => { - crate::codec::ensure_empty(&body)?; - Ok(SessionBodyRef::Ack) - } - SessionBodyKind::Ping => { - crate::codec::ensure_empty(&body)?; - Ok(SessionBodyRef::Ping) - } - SessionBodyKind::Unpair => { - crate::codec::ensure_empty(&body)?; - Ok(SessionBodyRef::Unpair) - } - SessionBodyKind::Stream => Ok(SessionBodyRef::Stream(StreamChunkWire::parse(body)?)), - SessionBodyKind::StreamClose => { - Ok(SessionBodyRef::StreamClose(StreamCloseWire::parse(body)?)) - } - SessionBodyKind::Close => Ok(SessionBodyRef::Close(close::SessionCloseBody::decode( - &body, - )?)), - } -} -impl SessionEnvelope { pub fn encode(&self) -> Vec { let mut out = Vec::new(); let kind = match &self.body { @@ -179,7 +144,7 @@ impl SessionEnvelope { } pub fn decode(bytes: &[u8]) -> Result { - SessionEnvelopeWire::parse(bytes)?.to_session_envelope() + Self::from_wire(&Self::parse(bytes)?) } } @@ -202,19 +167,48 @@ pub fn encrypt_record( pub fn decrypt_record<'a, B: ByteSliceMut>( crypto: &impl QlCrypto, header: &QlHeader, - encrypted: &'a mut EncryptedMessageRef, + encrypted: &'a mut Ref, session_key: &SessionKey, -) -> Result, WireError> { +) -> Result, WireError> { let aad = header.aad(); - let plaintext = encrypted.decrypt(crypto, session_key, &aad)?; - SessionEnvelopeWire::parse(plaintext) + let plaintext = EncryptedMessage::decrypt_in_place(encrypted, crypto, session_key, &aad)?; + SessionEnvelope::parse(plaintext) } #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -struct SessionEnvelopeHeaderWire { - seq: U64Le, - ack_base: U64Le, - ack_bitmap: U64Le, - kind: u8, +pub struct SessionEnvelopeHeaderWire { + pub seq: U64Le, + pub ack_base: U64Le, + pub ack_bitmap: U64Le, + pub kind: u8, +} + +fn session_body_kind(wire: &SessionEnvelopeWire) -> Result { + crate::codec::read_byte(wire.kind) +} + +fn parse_session_body( + kind: SessionBodyKind, + body: B, +) -> Result, WireError> { + match kind { + SessionBodyKind::Ack => { + crate::codec::ensure_empty(&body)?; + Ok(SessionBodyRef::Ack) + } + SessionBodyKind::Ping => { + crate::codec::ensure_empty(&body)?; + Ok(SessionBodyRef::Ping) + } + SessionBodyKind::Unpair => { + crate::codec::ensure_empty(&body)?; + Ok(SessionBodyRef::Unpair) + } + SessionBodyKind::Stream => Ok(SessionBodyRef::Stream(StreamChunk::parse(body)?)), + SessionBodyKind::StreamClose => Ok(SessionBodyRef::StreamClose(StreamClose::parse(body)?)), + SessionBodyKind::Close => Ok(SessionBodyRef::Close(close::SessionCloseBody::decode( + &body, + )?)), + } } diff --git a/ql-wire/src/encrypted/stream_chunk.rs b/ql-wire/src/encrypted/stream_chunk.rs index 60c98415..37fd6872 100644 --- a/ql-wire/src/encrypted/stream_chunk.rs +++ b/ql-wire/src/encrypted/stream_chunk.rs @@ -25,25 +25,21 @@ pub struct StreamChunkWire { pub bytes: [u8], } -pub type StreamChunkRef = Ref; - -impl StreamChunkWire { - pub fn parse(bytes: B) -> Result, WireError> { +impl StreamChunk { + pub fn parse(bytes: B) -> Result, WireError> { parse(bytes) } - pub fn to_stream_chunk(&self) -> Result { + pub fn from_wire(wire: &StreamChunkWire) -> Result { Ok(StreamChunk { - stream_id: StreamId(self.stream_id.get()), - offset: self.offset.get(), - bytes: self.bytes.to_vec(), - fin: crate::codec::read_byte(self.fin)?, + stream_id: StreamId(wire.stream_id.get()), + offset: wire.offset.get(), + bytes: wire.bytes.to_vec(), + fin: crate::codec::read_byte(wire.fin)?, }) } -} -impl StreamChunk { - pub(crate) fn encode_into(&self, out: &mut Vec) { + pub fn encode_into(&self, out: &mut Vec) { let header = StreamChunkHeaderWire { stream_id: U32Le::new(self.stream_id.0), offset: U64Le::new(self.offset), @@ -56,8 +52,8 @@ impl StreamChunk { #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -struct StreamChunkHeaderWire { - stream_id: U32Le, - offset: U64Le, - fin: u8, +pub struct StreamChunkHeaderWire { + pub stream_id: U32Le, + pub offset: U64Le, + pub fin: u8, } diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index dd314f47..b76a7d68 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -28,7 +28,7 @@ pub enum CloseTarget { } impl CloseTarget { - pub(crate) const fn to_wire(self) -> u8 { + pub const fn to_wire(self) -> u8 { self as u8 } } @@ -59,25 +59,21 @@ pub struct StreamCloseWire { pub payload: [u8], } -pub type StreamCloseRef = Ref; - -impl StreamCloseWire { - pub fn parse(bytes: B) -> Result, WireError> { +impl StreamClose { + pub fn parse(bytes: B) -> Result, WireError> { parse(bytes) } - pub fn to_stream_close(&self) -> Result { + pub fn from_wire(wire: &StreamCloseWire) -> Result { Ok(StreamClose { - stream_id: StreamId(self.stream_id.get()), - target: crate::codec::read_byte(self.target)?, - code: CloseCode(self.code.get()), - payload: self.payload.to_vec(), + stream_id: StreamId(wire.stream_id.get()), + target: crate::codec::read_byte(wire.target)?, + code: CloseCode(wire.code.get()), + payload: wire.payload.to_vec(), }) } -} -impl StreamClose { - pub(crate) fn encode_into(&self, out: &mut Vec) { + pub fn encode_into(&self, out: &mut Vec) { let header = StreamCloseHeaderWire { stream_id: U32Le::new(self.stream_id.0), target: self.target.to_wire(), @@ -90,8 +86,8 @@ impl StreamClose { #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -struct StreamCloseHeaderWire { - stream_id: U32Le, - target: u8, - code: U16Le, +pub struct StreamCloseHeaderWire { + pub stream_id: U32Le, + pub target: u8, + pub code: U16Le, } diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index 1a08bcab..1c80b1bd 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -1,5 +1,6 @@ use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, + byte_slice::{ByteSlice, ByteSliceMut}, + FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, }; use crate::{ @@ -15,8 +16,6 @@ pub struct EncryptedMessageWire { pub ciphertext: [u8], } -pub type EncryptedMessageRef = Ref; - #[derive(Debug, Clone, PartialEq, Eq)] pub struct EncryptedMessage { pub nonce: Nonce, @@ -24,37 +23,21 @@ pub struct EncryptedMessage { pub ciphertext: Vec, } -impl EncryptedMessageWire { - pub fn parse(bytes: B) -> Result, WireError> { +impl EncryptedMessage { + pub const AUTH_SIZE: usize = 16; + + pub fn parse(bytes: B) -> Result, WireError> { parse(bytes) } - pub fn to_encrypted_message(&self) -> EncryptedMessage { - EncryptedMessage { - nonce: Nonce(self.nonce), - auth: self.auth, - ciphertext: self.ciphertext.to_vec(), + pub fn from_wire(wire: &EncryptedMessageWire) -> Self { + Self { + nonce: Nonce(wire.nonce), + auth: wire.auth, + ciphertext: wire.ciphertext.to_vec(), } } - pub fn decrypt<'a>( - &'a mut self, - crypto: &impl QlCrypto, - key: &SessionKey, - aad: &[u8], - ) -> Result<&'a mut [u8], WireError> { - let nonce = Nonce(self.nonce); - let auth = self.auth; - if !crypto.decrypt_with_aead(key, &nonce, aad, &mut self.ciphertext, &auth) { - return Err(WireError::DecryptFailed); - } - Ok(&mut self.ciphertext) - } -} - -impl EncryptedMessage { - pub const AUTH_SIZE: usize = 16; - pub fn encode(&self) -> Vec { let mut out = Vec::with_capacity(Nonce::SIZE + Self::AUTH_SIZE + self.ciphertext.len()); self.encode_into(&mut out); @@ -62,7 +45,7 @@ impl EncryptedMessage { } pub fn decode(bytes: &[u8]) -> Result { - Ok(EncryptedMessageWire::parse(bytes)?.to_encrypted_message()) + Ok(Self::from_wire(&Self::parse(bytes)?)) } pub fn encode_into(&self, out: &mut Vec) { @@ -105,11 +88,25 @@ impl EncryptedMessage { } Ok(plaintext) } + + pub fn decrypt_in_place<'a, B: ByteSliceMut>( + wire: &'a mut Ref, + crypto: &impl QlCrypto, + key: &SessionKey, + aad: &[u8], + ) -> Result<&'a mut [u8], WireError> { + let nonce = Nonce(wire.nonce); + let auth = wire.auth; + if !crypto.decrypt_with_aead(key, &nonce, aad, &mut wire.ciphertext, &auth) { + return Err(WireError::DecryptFailed); + } + Ok(&mut wire.ciphertext) + } } #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -struct EncryptedMessageHeaderWire { - nonce: [u8; Nonce::SIZE], - auth: [u8; EncryptedMessage::AUTH_SIZE], +pub struct EncryptedMessageHeaderWire { + pub nonce: [u8; Nonce::SIZE], + pub auth: [u8; EncryptedMessage::AUTH_SIZE], } diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs index f8fbc4bb..74e7a68d 100644 --- a/ql-wire/src/handshake/crypto.rs +++ b/ql-wire/src/handshake/crypto.rs @@ -1,9 +1,13 @@ -use zerocopy::byte_slice::ByteSliceMut; +use zerocopy::{ + byte_slice::{ByteSlice, ByteSliceMut}, + Ref, +}; -use super::{verify_signature, Confirm, Hello, HelloReply, Ready, ReadyBody, ReadyRef}; +use super::{Confirm, ConfirmWire, Hello, HelloReply, HelloReplyWire, HelloWire, Ready, ReadyBody}; use crate::{ - pq::ML_KEM_SUITE_TAG, ControlMeta, MlDsaPublicKey, MlKemCiphertext, MlKemPublicKey, Nonce, - QlCrypto, QlHeader, QlIdentity, SessionKey, WireError, XID, + pq::ML_KEM_SUITE_TAG, ControlMeta, EncryptedMessage, EncryptedMessageWire, MlDsaPublicKey, + MlDsaSignature, MlKemCiphertext, MlKemPublicKey, Nonce, QlCrypto, QlHeader, QlIdentity, + SessionKey, WireError, XID, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -22,7 +26,14 @@ pub fn build_hello( let nonce = next_nonce(crypto); let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(crypto)?; - let proof_data = hash_hello_proof_data(crypto, identity.xid, recipient, &meta, &nonce, &kem_ct); + let proof_data = hash_hello_proof_data( + crypto, + identity.xid, + recipient, + &meta, + &nonce.0, + kem_ct.as_bytes(), + ); let signature = identity.signing_private_key.sign(crypto, &proof_data)?; Ok(( Hello { @@ -35,33 +46,34 @@ pub fn build_hello( )) } -pub fn verify_hello( +pub fn verify_hello( crypto: &impl QlCrypto, initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, - hello: &Hello, + hello: &Ref, now_seconds: u64, ) -> Result<(), WireError> { - hello.meta.ensure_not_expired(now_seconds)?; + let meta = ControlMeta::from_wire(hello.meta); + meta.ensure_not_expired(now_seconds)?; let proof_data = hash_hello_proof_data( crypto, initiator, responder, - &hello.meta, + &meta, &hello.nonce, &hello.kem_ct, ); - verify_signature(initiator_signing_key, &hello.signature, &proof_data) + verify_signature_bytes(initiator_signing_key, &hello.signature, &proof_data) } -pub fn respond_hello( +pub fn respond_hello( crypto: &impl QlCrypto, identity: &QlIdentity, initiator: XID, initiator_signing_key: &MlDsaPublicKey, initiator_encapsulation_key: &MlKemPublicKey, - hello: &Hello, + hello: &Ref, meta: ControlMeta, now_seconds: u64, ) -> Result<(HelloReply, ResponderSecrets), WireError> { @@ -75,7 +87,8 @@ pub fn respond_hello( )?; let initiator_secret = identity .encapsulation_private_key - .decapsulate_shared_secret(&hello.kem_ct)?; + .decapsulate_shared_secret_bytes(&hello.kem_ct)?; + let hello_meta = ControlMeta::from_wire(hello.meta); let nonce = next_nonce(crypto); let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(crypto)?; @@ -83,12 +96,12 @@ pub fn respond_hello( crypto, initiator, identity.xid, - &hello.meta, + &hello_meta, &hello.nonce, &hello.kem_ct, &meta, - &nonce, - &kem_ct, + &nonce.0, + kem_ct.as_bytes(), ); let signature = identity.signing_private_key.sign(crypto, &transcript)?; Ok(( @@ -105,42 +118,43 @@ pub fn respond_hello( )) } -pub fn build_confirm( +pub fn build_confirm( crypto: &impl QlCrypto, identity: &QlIdentity, responder: XID, responder_signing_key: &MlDsaPublicKey, hello: &Hello, - reply: &HelloReply, + reply: &Ref, initiator_secret: &SessionKey, meta: ControlMeta, now_seconds: u64, ) -> Result<(Confirm, SessionKey), WireError> { - reply.meta.ensure_not_expired(now_seconds)?; + let reply_meta = ControlMeta::from_wire(reply.meta); + reply_meta.ensure_not_expired(now_seconds)?; let transcript = hash_handshake_transcript( crypto, identity.xid, responder, &hello.meta, - &hello.nonce, - &hello.kem_ct, - &reply.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), + &reply_meta, &reply.nonce, &reply.kem_ct, ); - verify_signature(responder_signing_key, &reply.signature, &transcript)?; + verify_signature_bytes(responder_signing_key, &reply.signature, &transcript)?; let responder_secret = identity .encapsulation_private_key - .decapsulate_shared_secret(&reply.kem_ct)?; + .decapsulate_shared_secret_bytes(&reply.kem_ct)?; let proof_data = hash_confirm_proof_data( crypto, &meta, identity.xid, responder, &hello.meta, - &hello.nonce, - &hello.kem_ct, - &reply.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), + &reply_meta, &reply.nonce, &reply.kem_ct, ); @@ -152,23 +166,23 @@ pub fn build_confirm( identity.xid, responder, &hello.meta, - &hello.nonce, - &hello.kem_ct, - &reply.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), + &reply_meta, &reply.nonce, &reply.kem_ct, ); Ok((Confirm { meta, signature }, session_key)) } -pub fn finalize_confirm( +pub fn finalize_confirm( crypto: &impl QlCrypto, initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, hello: &Hello, reply: &HelloReply, - confirm: &Confirm, + confirm: &Ref, secrets: &ResponderSecrets, now_seconds: u64, ) -> Result { @@ -189,38 +203,39 @@ pub fn finalize_confirm( initiator, responder, &hello.meta, - &hello.nonce, - &hello.kem_ct, + &hello.nonce.0, + hello.kem_ct.as_bytes(), &reply.meta, - &reply.nonce, - &reply.kem_ct, + &reply.nonce.0, + reply.kem_ct.as_bytes(), )) } -pub fn verify_confirm( +pub fn verify_confirm( crypto: &impl QlCrypto, initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, hello: &Hello, reply: &HelloReply, - confirm: &Confirm, + confirm: &Ref, now_seconds: u64, ) -> Result<(), WireError> { - confirm.meta.ensure_not_expired(now_seconds)?; + let confirm_meta = ControlMeta::from_wire(confirm.meta); + confirm_meta.ensure_not_expired(now_seconds)?; let proof_data = hash_confirm_proof_data( crypto, - &confirm.meta, + &confirm_meta, initiator, responder, &hello.meta, - &hello.nonce, - &hello.kem_ct, + &hello.nonce.0, + hello.kem_ct.as_bytes(), &reply.meta, - &reply.nonce, - &reply.kem_ct, + &reply.nonce.0, + reply.kem_ct.as_bytes(), ); - verify_signature(initiator_signing_key, &confirm.signature, &proof_data) + verify_signature_bytes(initiator_signing_key, &confirm.signature, &proof_data) } pub fn build_ready( @@ -233,25 +248,19 @@ pub fn build_ready( let aad = header.aad(); let body_bytes = ReadyBody { meta }.encode(); Ok(Ready { - encrypted: crate::encrypted_message::EncryptedMessage::encrypt( - crypto, - session_key, - body_bytes, - &aad, - nonce, - )?, + encrypted: EncryptedMessage::encrypt(crypto, session_key, body_bytes, &aad, nonce)?, }) } pub fn decrypt_ready( crypto: &impl QlCrypto, header: &QlHeader, - ready: &mut ReadyRef, + ready: &mut Ref, session_key: &SessionKey, now_seconds: u64, ) -> Result { let aad = header.aad(); - let plaintext = ready.decrypt(crypto, session_key, &aad)?; + let plaintext = EncryptedMessage::decrypt_in_place(ready, crypto, session_key, &aad)?; let body = ReadyBody::decode(plaintext)?; body.meta.ensure_not_expired(now_seconds)?; Ok(body) @@ -262,8 +271,8 @@ fn hash_hello_proof_data( initiator: XID, responder: XID, meta: &ControlMeta, - nonce: &Nonce, - kem_ct: &MlKemCiphertext, + nonce: &[u8; Nonce::SIZE], + kem_ct: &[u8; MlKemCiphertext::SIZE], ) -> [u8; 32] { let control_id = meta.control_id.0.to_le_bytes(); let valid_until = meta.valid_until.to_le_bytes(); @@ -278,11 +287,11 @@ fn hash_hello_proof_data( b"valid-until", &valid_until, b"nonce", - &nonce.0, + nonce, b"kem-suite", ML_KEM_SUITE_TAG, b"kem-ct", - kem_ct.as_bytes(), + kem_ct, ]) } @@ -291,11 +300,11 @@ fn hash_handshake_transcript( initiator: XID, responder: XID, hello_meta: &ControlMeta, - initiator_nonce: &Nonce, - initiator_kem_ct: &MlKemCiphertext, + initiator_nonce: &[u8; Nonce::SIZE], + initiator_kem_ct: &[u8; MlKemCiphertext::SIZE], reply_meta: &ControlMeta, - responder_nonce: &Nonce, - responder_kem_ct: &MlKemCiphertext, + responder_nonce: &[u8; Nonce::SIZE], + responder_kem_ct: &[u8; MlKemCiphertext::SIZE], ) -> [u8; 32] { let hello_control_id = hello_meta.control_id.0.to_le_bytes(); let hello_valid_until = hello_meta.valid_until.to_le_bytes(); @@ -312,21 +321,21 @@ fn hash_handshake_transcript( b"hello-valid-until", &hello_valid_until, b"initiator-nonce", - &initiator_nonce.0, + initiator_nonce, b"initiator-kem-suite", ML_KEM_SUITE_TAG, b"initiator-kem-ct", - initiator_kem_ct.as_bytes(), + initiator_kem_ct, b"reply-control-id", &reply_control_id, b"reply-valid-until", &reply_valid_until, b"responder-nonce", - &responder_nonce.0, + responder_nonce, b"responder-kem-suite", ML_KEM_SUITE_TAG, b"responder-kem-ct", - responder_kem_ct.as_bytes(), + responder_kem_ct, ]) } @@ -336,11 +345,11 @@ fn hash_confirm_proof_data( initiator: XID, responder: XID, hello_meta: &ControlMeta, - initiator_nonce: &Nonce, - initiator_kem_ct: &MlKemCiphertext, + initiator_nonce: &[u8; Nonce::SIZE], + initiator_kem_ct: &[u8; MlKemCiphertext::SIZE], reply_meta: &ControlMeta, - responder_nonce: &Nonce, - responder_kem_ct: &MlKemCiphertext, + responder_nonce: &[u8; Nonce::SIZE], + responder_kem_ct: &[u8; MlKemCiphertext::SIZE], ) -> [u8; 32] { let confirm_control_id = confirm_meta.control_id.0.to_le_bytes(); let confirm_valid_until = confirm_meta.valid_until.to_le_bytes(); @@ -363,21 +372,21 @@ fn hash_confirm_proof_data( b"hello-valid-until", &hello_valid_until, b"initiator-nonce", - &initiator_nonce.0, + initiator_nonce, b"initiator-kem-suite", ML_KEM_SUITE_TAG, b"initiator-kem-ct", - initiator_kem_ct.as_bytes(), + initiator_kem_ct, b"reply-control-id", &reply_control_id, b"reply-valid-until", &reply_valid_until, b"responder-nonce", - &responder_nonce.0, + responder_nonce, b"responder-kem-suite", ML_KEM_SUITE_TAG, b"responder-kem-ct", - responder_kem_ct.as_bytes(), + responder_kem_ct, ]) } @@ -394,11 +403,11 @@ fn derive_session_key( initiator: XID, responder: XID, hello_meta: &ControlMeta, - initiator_nonce: &Nonce, - initiator_kem_ct: &MlKemCiphertext, + initiator_nonce: &[u8; Nonce::SIZE], + initiator_kem_ct: &[u8; MlKemCiphertext::SIZE], reply_meta: &ControlMeta, - responder_nonce: &Nonce, - responder_kem_ct: &MlKemCiphertext, + responder_nonce: &[u8; Nonce::SIZE], + responder_kem_ct: &[u8; MlKemCiphertext::SIZE], ) -> SessionKey { let hello_control_id = hello_meta.control_id.0.to_le_bytes(); let hello_valid_until = hello_meta.valid_until.to_le_bytes(); @@ -419,20 +428,32 @@ fn derive_session_key( b"hello-valid-until", &hello_valid_until, b"initiator-nonce", - &initiator_nonce.0, + initiator_nonce, b"initiator-kem-suite", ML_KEM_SUITE_TAG, b"initiator-kem-ct", - initiator_kem_ct.as_bytes(), + initiator_kem_ct, b"reply-control-id", &reply_control_id, b"reply-valid-until", &reply_valid_until, b"responder-nonce", - &responder_nonce.0, + responder_nonce, b"responder-kem-suite", ML_KEM_SUITE_TAG, b"responder-kem-ct", - responder_kem_ct.as_bytes(), + responder_kem_ct, ])) } + +fn verify_signature_bytes( + signing_key: &MlDsaPublicKey, + signature: &[u8; MlDsaSignature::SIZE], + proof_data: &[u8], +) -> Result<(), WireError> { + if signing_key.verify_bytes(signature, proof_data) { + Ok(()) + } else { + Err(WireError::InvalidSignature) + } +} diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index c98b5f33..e503b63e 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,10 +1,12 @@ -use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; use crate::{ - codec::{push_value, read_exact}, - control::{control_meta_from_wire, control_meta_to_wire, ControlMetaWire}, - encrypted_message::{EncryptedMessage, EncryptedMessageRef}, - ControlMeta, MlDsaPublicKey, MlDsaSignature, MlKemCiphertext, Nonce, WireError, + codec::{parse, push_value, read_exact}, + control::ControlMetaWire, + encrypted_message::{EncryptedMessage, EncryptedMessageWire}, + ControlMeta, MlDsaSignature, MlKemCiphertext, Nonce, WireError, }; mod crypto; @@ -18,6 +20,48 @@ pub struct Hello { pub signature: MlDsaSignature, } +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub struct HelloWire { + pub meta: ControlMetaWire, + pub nonce: [u8; Nonce::SIZE], + pub kem_ct: [u8; MlKemCiphertext::SIZE], + pub signature: [u8; MlDsaSignature::SIZE], +} + +impl Hello { + pub fn parse(bytes: B) -> Result, WireError> { + parse(bytes) + } + + pub fn from_wire(wire: &HelloWire) -> Self { + Self { + meta: ControlMeta::from_wire(wire.meta), + nonce: Nonce(wire.nonce), + kem_ct: MlKemCiphertext::from_data(wire.kem_ct), + signature: MlDsaSignature::from_data(wire.signature), + } + } + + pub fn decode(bytes: &[u8]) -> Result { + let wire = Self::parse(bytes)?; + Ok(Self::from_wire(&wire)) + } + + pub fn to_wire(&self) -> HelloWire { + HelloWire { + meta: self.meta.to_wire(), + nonce: self.nonce.0, + kem_ct: *self.kem_ct.as_bytes(), + signature: *self.signature.as_bytes(), + } + } + + pub fn encode_into(&self, out: &mut Vec) { + push_value(out, &self.to_wire()); + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct HelloReply { pub meta: ControlMeta, @@ -26,130 +70,131 @@ pub struct HelloReply { pub signature: MlDsaSignature, } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Confirm { - pub meta: ControlMeta, - pub signature: MlDsaSignature, +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub struct HelloReplyWire { + pub meta: ControlMetaWire, + pub nonce: [u8; Nonce::SIZE], + pub kem_ct: [u8; MlKemCiphertext::SIZE], + pub signature: [u8; MlDsaSignature::SIZE], } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Ready { - pub encrypted: EncryptedMessage, +impl HelloReply { + pub fn parse(bytes: B) -> Result, WireError> { + parse(bytes) + } + + pub fn from_wire(wire: &HelloReplyWire) -> Self { + Self { + meta: ControlMeta::from_wire(wire.meta), + nonce: Nonce(wire.nonce), + kem_ct: MlKemCiphertext::from_data(wire.kem_ct), + signature: MlDsaSignature::from_data(wire.signature), + } + } + + pub fn decode(bytes: &[u8]) -> Result { + let wire = Self::parse(bytes)?; + Ok(Self::from_wire(&wire)) + } + + pub fn to_wire(&self) -> HelloReplyWire { + HelloReplyWire { + meta: self.meta.to_wire(), + nonce: self.nonce.0, + kem_ct: *self.kem_ct.as_bytes(), + signature: *self.signature.as_bytes(), + } + } + + pub fn encode_into(&self, out: &mut Vec) { + push_value(out, &self.to_wire()); + } } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ReadyBody { +pub struct Confirm { pub meta: ControlMeta, -} - -pub type ReadyRef = EncryptedMessageRef; - -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -struct HelloWire { - meta: ControlMetaWire, - nonce: [u8; Nonce::SIZE], - kem_ct: [u8; MlKemCiphertext::SIZE], - signature: [u8; MlDsaSignature::SIZE], + pub signature: MlDsaSignature, } #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -struct ConfirmWire { - meta: ControlMetaWire, - signature: [u8; MlDsaSignature::SIZE], +pub struct ConfirmWire { + pub meta: ControlMetaWire, + pub signature: [u8; MlDsaSignature::SIZE], } -impl Hello { - pub(crate) fn encode_into(&self, out: &mut Vec) { - let wire = HelloWire { - meta: control_meta_to_wire(&self.meta), - nonce: self.nonce.0, - kem_ct: *self.kem_ct.as_bytes(), - signature: *self.signature.as_bytes(), - }; - push_value(out, &wire); +impl Confirm { + pub fn parse(bytes: B) -> Result, WireError> { + parse(bytes) } - pub(crate) fn decode(bytes: &[u8]) -> Result { - let wire: HelloWire = read_exact(bytes)?; - Ok(Self { - meta: control_meta_from_wire(wire.meta), - nonce: Nonce(wire.nonce), - kem_ct: MlKemCiphertext::from_data(wire.kem_ct), + pub fn from_wire(wire: &ConfirmWire) -> Self { + Self { + meta: ControlMeta::from_wire(wire.meta), signature: MlDsaSignature::from_data(wire.signature), - }) - } -} - -impl HelloReply { - pub(crate) fn encode_into(&self, out: &mut Vec) { - Hello { - meta: self.meta, - nonce: self.nonce, - kem_ct: self.kem_ct.clone(), - signature: self.signature.clone(), } - .encode_into(out); } - pub(crate) fn decode(bytes: &[u8]) -> Result { - let hello = Hello::decode(bytes)?; - Ok(Self { - meta: hello.meta, - nonce: hello.nonce, - kem_ct: hello.kem_ct, - signature: hello.signature, - }) + pub fn decode(bytes: &[u8]) -> Result { + let wire = Self::parse(bytes)?; + Ok(Self::from_wire(&wire)) } -} -impl Confirm { - pub(crate) fn encode_into(&self, out: &mut Vec) { - let wire = ConfirmWire { - meta: control_meta_to_wire(&self.meta), + pub fn to_wire(&self) -> ConfirmWire { + ConfirmWire { + meta: self.meta.to_wire(), signature: *self.signature.as_bytes(), - }; - push_value(out, &wire); + } } - pub(crate) fn decode(bytes: &[u8]) -> Result { - let wire: ConfirmWire = read_exact(bytes)?; - Ok(Self { - meta: control_meta_from_wire(wire.meta), - signature: MlDsaSignature::from_data(wire.signature), - }) + pub fn encode_into(&self, out: &mut Vec) { + push_value(out, &self.to_wire()); } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ready { + pub encrypted: EncryptedMessage, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReadyBody { + pub meta: ControlMeta, +} + impl Ready { - pub(crate) fn encode_into(&self, out: &mut Vec) { + pub fn parse(bytes: B) -> Result, WireError> { + EncryptedMessage::parse(bytes) + } + + pub fn from_wire(wire: &EncryptedMessageWire) -> Self { + Self { + encrypted: EncryptedMessage::from_wire(wire), + } + } + + pub fn decode(bytes: &[u8]) -> Result { + let wire = Self::parse(bytes)?; + Ok(Self::from_wire(&wire)) + } + + pub fn encode_into(&self, out: &mut Vec) { self.encrypted.encode_into(out); } } impl ReadyBody { - pub(crate) fn encode(&self) -> Vec { - let wire = control_meta_to_wire(&self.meta); + pub fn encode(&self) -> Vec { + let wire = self.meta.to_wire(); wire.as_bytes().to_vec() } - pub(crate) fn decode(bytes: &[u8]) -> Result { + pub fn decode(bytes: &[u8]) -> Result { let wire: ControlMetaWire = read_exact(bytes)?; Ok(Self { - meta: control_meta_from_wire(wire), + meta: ControlMeta::from_wire(wire), }) } } - -pub fn verify_signature( - signing_key: &MlDsaPublicKey, - signature: &MlDsaSignature, - proof_data: &[u8], -) -> Result<(), WireError> { - if signing_key.verify(signature, proof_data) { - Ok(()) - } else { - Err(WireError::InvalidSignature) - } -} diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index 7fe1813a..d4a460a1 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -18,10 +18,10 @@ impl QlHeader { #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -pub(crate) struct QlRecordHeaderWire { - pub(crate) kind: u8, - pub(crate) sender: [u8; XID::SIZE], - pub(crate) recipient: [u8; XID::SIZE], +pub struct QlRecordHeaderWire { + pub kind: u8, + pub sender: [u8; XID::SIZE], + pub recipient: [u8; XID::SIZE], } #[derive(Debug, Clone, Copy)] diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 36d8ff7c..117c0ab7 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -4,6 +4,9 @@ #![allow(clippy::too_many_arguments)] +pub type Ref<'a, T> = zerocopy::Ref<&'a [u8], T>; +pub type RefMut<'a, T> = zerocopy::Ref<&'a mut [u8], T>; + mod codec; mod control; mod encrypted; diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs index 466e935a..045e7b78 100644 --- a/ql-wire/src/pair/crypto.rs +++ b/ql-wire/src/pair/crypto.rs @@ -1,6 +1,6 @@ -use zerocopy::byte_slice::ByteSliceMut; +use zerocopy::{byte_slice::ByteSliceMut, Ref}; -use super::{PairRequestBody, PairRequestRecordRef}; +use super::{PairRequestBody, PairRequestRecordWire}; use crate::{ pq::ML_KEM_SUITE_TAG, ControlMeta, MlDsaPublicKey, MlKemCiphertext, MlKemPublicKey, QlCrypto, QlHeader, QlIdentity, QlPayload, QlRecord, WireError, XID, @@ -59,7 +59,7 @@ pub fn decrypt_pair_request( crypto: &impl QlCrypto, identity: &QlIdentity, header: &QlHeader, - request: &mut PairRequestRecordRef, + request: &mut Ref, now_seconds: u64, ) -> Result { let kem_ct = MlKemCiphertext::from_data(request.kem_ct); @@ -67,9 +67,13 @@ pub fn decrypt_pair_request( let session_key = identity .encapsulation_private_key .decapsulate_shared_secret(&kem_ct)?; - let mut encrypted = - crate::encrypted_message::EncryptedMessageWire::parse(&mut request.encrypted)?; - let plaintext = encrypted.decrypt(crypto, &session_key, &aad)?; + let mut encrypted = crate::encrypted_message::EncryptedMessage::parse(&mut request.encrypted)?; + let plaintext = crate::encrypted_message::EncryptedMessage::decrypt_in_place( + &mut encrypted, + crypto, + &session_key, + &aad, + )?; let decrypted = PairRequestBody::decode(plaintext)?; decrypted.meta.ensure_not_expired(now_seconds)?; if decrypted.xid != header.sender { diff --git a/ql-wire/src/pair/mod.rs b/ql-wire/src/pair/mod.rs index 5973b22e..53999b43 100644 --- a/ql-wire/src/pair/mod.rs +++ b/ql-wire/src/pair/mod.rs @@ -4,8 +4,8 @@ use zerocopy::{ use crate::{ codec::{parse, push_value, read_exact}, - control::{control_meta_from_wire, control_meta_to_wire, ControlMetaWire}, - encrypted_message::{EncryptedMessage, EncryptedMessageWire}, + control::ControlMetaWire, + encrypted_message::EncryptedMessage, ControlMeta, MlDsaPublicKey, MlDsaSignature, MlKemCiphertext, MlKemPublicKey, WireError, XID, }; @@ -34,37 +34,33 @@ pub struct PairRequestRecordWire { pub encrypted: [u8], } -pub type PairRequestRecordRef = Ref; - #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -struct PairRequestBodyWire { - meta: ControlMetaWire, - xid: [u8; XID::SIZE], - signing_pub_key: [u8; MlDsaPublicKey::SIZE], - encapsulation_pub_key: [u8; MlKemPublicKey::SIZE], - proof: [u8; MlDsaSignature::SIZE], +pub struct PairRequestBodyWire { + pub meta: ControlMetaWire, + pub xid: [u8; XID::SIZE], + pub signing_pub_key: [u8; MlDsaPublicKey::SIZE], + pub encapsulation_pub_key: [u8; MlKemPublicKey::SIZE], + pub proof: [u8; MlDsaSignature::SIZE], } -impl PairRequestRecordWire { - pub fn parse(bytes: B) -> Result, WireError> { - let record: PairRequestRecordRef = parse(bytes)?; - let _ = EncryptedMessageWire::parse(&record.encrypted)?; +impl PairRequestRecord { + pub fn parse(bytes: B) -> Result, WireError> { + let record: Ref = parse(bytes)?; + let _ = EncryptedMessage::parse(&record.encrypted)?; Ok(record) } - pub fn to_pair_request_record(&self) -> PairRequestRecord { - PairRequestRecord { - kem_ct: MlKemCiphertext::from_data(self.kem_ct), - encrypted: EncryptedMessageWire::parse(&self.encrypted) - .expect("validated pair request") - .to_encrypted_message(), + pub fn from_wire(wire: &PairRequestRecordWire) -> Self { + let encrypted = + EncryptedMessage::parse(&wire.encrypted).expect("validated pair request record"); + Self { + kem_ct: MlKemCiphertext::from_data(wire.kem_ct), + encrypted: EncryptedMessage::from_wire(&encrypted), } } -} -impl PairRequestRecord { - pub(crate) fn encode_into(&self, out: &mut Vec) { + pub fn encode_into(&self, out: &mut Vec) { push_value( out, &PairRequestHeaderWire { @@ -76,31 +72,39 @@ impl PairRequestRecord { } impl PairRequestBody { - pub(crate) fn encode(&self) -> Vec { - let wire = PairRequestBodyWire { - meta: control_meta_to_wire(&self.meta), + pub fn from_wire(wire: PairRequestBodyWire) -> Self { + Self { + meta: ControlMeta::from_wire(wire.meta), + xid: XID(wire.xid), + signing_pub_key: MlDsaPublicKey::from_data(wire.signing_pub_key), + encapsulation_pub_key: MlKemPublicKey::from_data(wire.encapsulation_pub_key), + proof: MlDsaSignature::from_data(wire.proof), + } + } + + pub fn to_wire(&self) -> PairRequestBodyWire { + PairRequestBodyWire { + meta: self.meta.to_wire(), xid: self.xid.0, signing_pub_key: *self.signing_pub_key.as_bytes(), encapsulation_pub_key: *self.encapsulation_pub_key.as_bytes(), proof: *self.proof.as_bytes(), - }; + } + } + + pub fn encode(&self) -> Vec { + let wire = self.to_wire(); wire.as_bytes().to_vec() } - pub(crate) fn decode(bytes: &[u8]) -> Result { + pub fn decode(bytes: &[u8]) -> Result { let wire: PairRequestBodyWire = read_exact(bytes)?; - Ok(Self { - meta: control_meta_from_wire(wire.meta), - xid: XID(wire.xid), - signing_pub_key: MlDsaPublicKey::from_data(wire.signing_pub_key), - encapsulation_pub_key: MlKemPublicKey::from_data(wire.encapsulation_pub_key), - proof: MlDsaSignature::from_data(wire.proof), - }) + Ok(Self::from_wire(wire)) } } #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] #[repr(C)] -struct PairRequestHeaderWire { - kem_ct: [u8; MlKemCiphertext::SIZE], +pub struct PairRequestHeaderWire { + pub kem_ct: [u8; MlKemCiphertext::SIZE], } diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index 9f831668..76b594e1 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -77,6 +77,12 @@ impl MlDsaPublicKey { let signature = ml_dsa_87::MLDSA87Signature::new(*signature.as_bytes()); ml_dsa_87::verify(&verification_key, message, b"", &signature).is_ok() } + + pub fn verify_bytes(&self, signature: &[u8; MlDsaSignature::SIZE], message: &[u8]) -> bool { + let verification_key = ml_dsa_87::MLDSA87VerificationKey::new(*self.as_bytes()); + let signature = ml_dsa_87::MLDSA87Signature::new(*signature); + ml_dsa_87::verify(&verification_key, message, b"", &signature).is_ok() + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -140,9 +146,16 @@ impl MlKemPrivateKey { pub fn decapsulate_shared_secret( &self, ciphertext: &MlKemCiphertext, + ) -> Result { + self.decapsulate_shared_secret_bytes(ciphertext.as_bytes()) + } + + pub fn decapsulate_shared_secret_bytes( + &self, + ciphertext: &[u8; MlKemCiphertext::SIZE], ) -> Result { let private_key = mlkem1024::MlKem1024PrivateKey::from(self.as_bytes()); - let ciphertext = mlkem1024::MlKem1024Ciphertext::from(ciphertext.as_bytes()); + let ciphertext = mlkem1024::MlKem1024Ciphertext::from(ciphertext); let shared_secret = mlkem1024::decapsulate(&private_key, &ciphertext); Ok(SessionKey::from_data(shared_secret)) } diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 79e67c58..666a93b1 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -1,14 +1,15 @@ use zerocopy::{ byte_slice::{ByteSlice, SplitByteSlice}, - Immutable, IntoBytes, KnownLayout, TryFromBytes, Unaligned, + Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, Unaligned, }; use crate::{ codec, - encrypted_message::{EncryptedMessage, EncryptedMessageRef, EncryptedMessageWire}, - handshake, + encrypted_message::{EncryptedMessage, EncryptedMessageWire}, + handshake::{self, ConfirmWire, HelloReplyWire, HelloWire}, header::{decode_record_header, encode_record_header, QlHeader}, - pair, WireError, QL_WIRE_VERSION, + pair::{self, PairRequestRecordWire}, + WireError, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -33,12 +34,12 @@ pub struct QlRecordRef { } pub enum QlPayloadRef { - PairRequest(pair::PairRequestRecordRef), - Hello(handshake::Hello), - HelloReply(handshake::HelloReply), - Confirm(handshake::Confirm), - Ready(handshake::ReadyRef), - Session(EncryptedMessageRef), + PairRequest(Ref), + Hello(Ref), + HelloReply(Ref), + Confirm(Ref), + Ready(Ref), + Session(Ref), } #[derive( @@ -124,29 +125,31 @@ impl QlRecordRef { impl QlPayloadRef { pub fn to_owned(&self) -> QlPayload { match self { - Self::PairRequest(request) => QlPayload::PairRequest(request.to_pair_request_record()), - Self::Hello(hello) => QlPayload::Hello(hello.clone()), - Self::HelloReply(reply) => QlPayload::HelloReply(reply.clone()), - Self::Confirm(confirm) => QlPayload::Confirm(confirm.clone()), - Self::Ready(ready) => QlPayload::Ready(handshake::Ready { - encrypted: ready.to_encrypted_message(), - }), - Self::Session(encrypted) => QlPayload::Session(encrypted.to_encrypted_message()), + Self::PairRequest(request) => { + QlPayload::PairRequest(pair::PairRequestRecord::from_wire(request)) + } + Self::Hello(hello) => QlPayload::Hello(handshake::Hello::from_wire(hello)), + Self::HelloReply(reply) => { + QlPayload::HelloReply(handshake::HelloReply::from_wire(reply)) + } + Self::Confirm(confirm) => QlPayload::Confirm(handshake::Confirm::from_wire(confirm)), + Self::Ready(ready) => QlPayload::Ready(handshake::Ready::from_wire(ready)), + Self::Session(encrypted) => QlPayload::Session(EncryptedMessage::from_wire(encrypted)), } } } fn parse_payload(kind: RecordKind, payload: B) -> Result, WireError> { match kind { - RecordKind::PairRequest => Ok(QlPayloadRef::PairRequest( - pair::PairRequestRecordWire::parse(payload)?, - )), - RecordKind::Hello => Ok(QlPayloadRef::Hello(handshake::Hello::decode(&payload)?)), - RecordKind::HelloReply => Ok(QlPayloadRef::HelloReply(handshake::HelloReply::decode( - &payload, + RecordKind::PairRequest => Ok(QlPayloadRef::PairRequest(pair::PairRequestRecord::parse( + payload, + )?)), + RecordKind::Hello => Ok(QlPayloadRef::Hello(handshake::Hello::parse(payload)?)), + RecordKind::HelloReply => Ok(QlPayloadRef::HelloReply(handshake::HelloReply::parse( + payload, )?)), - RecordKind::Confirm => Ok(QlPayloadRef::Confirm(handshake::Confirm::decode(&payload)?)), - RecordKind::Ready => Ok(QlPayloadRef::Ready(EncryptedMessageWire::parse(payload)?)), - RecordKind::Session => Ok(QlPayloadRef::Session(EncryptedMessageWire::parse(payload)?)), + RecordKind::Confirm => Ok(QlPayloadRef::Confirm(handshake::Confirm::parse(payload)?)), + RecordKind::Ready => Ok(QlPayloadRef::Ready(handshake::Ready::parse(payload)?)), + RecordKind::Session => Ok(QlPayloadRef::Session(EncryptedMessage::parse(payload)?)), } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index f0f574bf..929c3773 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -110,7 +110,7 @@ fn encrypted_session_record_round_trip_and_decrypt() { }; let decrypted = encrypted::decrypt_record(&crypto, &header, &mut encrypted, &session_key).unwrap(); - assert_eq!(decrypted.to_session_envelope().unwrap(), body); + assert_eq!(SessionEnvelope::from_wire(&decrypted).unwrap(), body); } #[test] From 75ff4bd33cef3c420d4f9f8c092d37f3d75aad48 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 11:02:02 -0400 Subject: [PATCH 015/304] ql-runtime: get rid of useless enum --- ql-runtime/src/driver.rs | 9 ++++----- ql-runtime/src/lib.rs | 5 ----- ql-runtime/src/platform.rs | 2 +- ql-runtime/src/tests/handshake.rs | 8 ++------ ql-runtime/src/tests/heartbeat.rs | 4 +--- ql-runtime/src/tests/mod.rs | 10 +++++----- ql-runtime/src/tests/stream.rs | 24 ++++++------------------ ql-runtime/src/tests/unpair.rs | 4 +--- 8 files changed, 20 insertions(+), 46 deletions(-) diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index f5467f03..483ddbf2 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -12,8 +12,7 @@ use crate::{ command::RuntimeCommand, handle::{ByteReader, ByteWriter, InboundStream}, platform::{PlatformFuture, QlPlatform}, - CloseCode, CloseTarget, HandlerEvent, InboundEvent, OpenedStreamDelivery, QlError, Runtime, - StreamId, + CloseCode, CloseTarget, InboundEvent, OpenedStreamDelivery, QlError, Runtime, StreamId, }; struct InFlightWrite<'a> { @@ -343,7 +342,7 @@ impl DriverState { }, ); - platform.handle_inbound(HandlerEvent::Stream(InboundStream { + platform.handle_inbound(InboundStream { stream_id, request: ByteReader::new( stream_id, @@ -357,7 +356,7 @@ impl DriverState { response_writer, self.runtime_tx.clone(), ), - })); + }); } fn handle_inbound_readable(&mut self, stream_id: StreamId) { @@ -661,7 +660,7 @@ mod tests { fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} - fn handle_inbound(&self, _event: HandlerEvent) {} + fn handle_inbound(&self, _event: InboundStream) {} } fn new_driver_state() -> DriverState { diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 590d16b1..b51093c0 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -90,11 +90,6 @@ impl RuntimeConfig { } } -#[derive(Debug)] -pub enum HandlerEvent { - Stream(InboundStream), -} - #[derive(Debug)] pub(crate) enum InboundEvent { Data(Vec), diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 08660674..bb690242 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -15,5 +15,5 @@ pub trait QlPlatform: QlCrypto { fn clear_peer(&self); fn handle_peer_status(&self, peer: XID, status: PeerStatus); - fn handle_inbound(&self, event: super::HandlerEvent); + fn handle_inbound(&self, event: super::InboundStream); } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index c95db6e5..7db4cef6 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -50,9 +50,7 @@ async fn opening_stream_auto_connects() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); let responder = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); let request = read_all(stream.request).await.unwrap(); stream.response.finish().await.unwrap(); request @@ -131,9 +129,7 @@ async fn rejected_session_write_is_reissued() { await_status(&status_b, identity_a.xid, PeerStage::Connected).await; let responder = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); let request = read_all(stream.request).await.unwrap(); stream.response.finish().await.unwrap(); request diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 2d1ba718..12231080 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -42,9 +42,7 @@ async fn session_timeout_disconnects_and_fails_pending_open() { await_status(&status_b, identity_a.xid, PeerStage::Connected).await; let responder_task = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); let _ = read_all(stream.request).await; let response = stream.response; let _ = response.finish().await; diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index b789b41a..d7391b5f 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -18,7 +18,7 @@ use sha2::{Digest, Sha256}; use tokio::task::LocalSet; use crate::{ - new_runtime, platform::PlatformFuture, HandlerEvent, Peer, PeerStatus, QlError, QlFsmConfig, + new_runtime, platform::PlatformFuture, InboundStream, Peer, PeerStatus, QlError, QlFsmConfig, RuntimeConfig, RuntimeHandle, }; @@ -128,7 +128,7 @@ impl QlCrypto for DeterministicCrypto { struct TestPlatform { outbound: Sender>, status: Sender, - inbound: Option>, + inbound: Option>, nonce_seed: u8, nonce_counter: AtomicU8, encrypted_write_counter: AtomicUsize, @@ -148,7 +148,7 @@ impl TestPlatform { Self, Receiver>, Receiver, - Receiver, + Receiver, ) { let (inbound_tx, inbound_rx) = async_channel::unbounded(); let (platform, outbound_rx, status_rx) = @@ -179,7 +179,7 @@ impl TestPlatform { fn new_inner( seed: u8, - inbound: Option>, + inbound: Option>, fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, @@ -318,7 +318,7 @@ impl crate::platform::QlPlatform for TestPlatform { let _ = self.status.try_send(StatusEvent { peer, stage }); } - fn handle_inbound(&self, event: HandlerEvent) { + fn handle_inbound(&self, event: InboundStream) { if let Some(tx) = &self.inbound { let _ = tx.try_send(event); } diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 2a092dc3..f66096b1 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -28,9 +28,7 @@ async fn open_stream_duplex_happy_path() { await_status(&status_b, identity_a.xid, PeerStage::Connected).await; let responder = tokio::task::spawn_local(async move { - let inbound = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let inbound = inbound_b.recv().await.unwrap(); let mut request = inbound.request; let mut response = inbound.response; @@ -93,9 +91,7 @@ async fn stream_backpressure_with_small_runtime_buffer() { await_status(&status_b, identity_a.xid, PeerStage::Connected).await; let responder = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); let request_data = read_all(stream.request).await.unwrap(); stream.response.finish().await.unwrap(); done_tx.send(request_data).await.unwrap(); @@ -145,9 +141,7 @@ async fn dropping_responder_closes_initiator_response() { await_status(&status_b, identity_a.xid, PeerStage::Connected).await; let responder = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); drop(stream.response); }); @@ -201,9 +195,7 @@ async fn dropping_inbound_reader_cancels_remote_writer() { await_status(&status_b, identity_a.xid, PeerStage::Connected).await; let responder = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); let mut request = stream.request; let mut response = stream.response; assert_eq!(request.next_chunk().await.unwrap(), None); @@ -261,9 +253,7 @@ async fn max_concurrent_message_writes_is_respected() { let responder = tokio::task::spawn_local(async move { for _ in 0..4 { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); let _ = read_all(stream.request).await; let _ = stream.response.finish().await; } @@ -337,9 +327,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { await_status(&status_b, identity_a.xid, PeerStage::Connected).await; let responder = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); let received_request = read_all(stream.request).await.unwrap(); let mut response = stream.response; response.write_all(&response_payload).await.unwrap(); diff --git a/ql-runtime/src/tests/unpair.rs b/ql-runtime/src/tests/unpair.rs index 24791416..74898fae 100644 --- a/ql-runtime/src/tests/unpair.rs +++ b/ql-runtime/src/tests/unpair.rs @@ -25,9 +25,7 @@ async fn unpair_clears_remote_peer_and_aborts_active_stream() { await_status(&status_b, identity_a.xid, PeerStage::Connected).await; let responder = tokio::task::spawn_local(async move { - let stream = match inbound_b.recv().await.unwrap() { - HandlerEvent::Stream(stream) => stream, - }; + let stream = inbound_b.recv().await.unwrap(); let mut request = stream.request; let _ = request.next_chunk().await; let second = request.next_chunk().await; From d7a3747213cc5dab2ebecebaed2c349919aa305d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 11:37:53 -0400 Subject: [PATCH 016/304] ql-fsm: crate docs --- ql-fsm/src/lib.rs | 109 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 96 insertions(+), 13 deletions(-) diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 1f6fc695..6eb4c7ac 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -1,3 +1,22 @@ +//! sync finite state machine for quantum link protocol +//! +//! a caller drives `QlFsm` inside its own event loop +//! +//! inputs to that loop usually include +//! - app actions like `bind_peer`, `pair`, `connect`, `open_stream`, or `write_stream` +//! - inbound transport bytes passed to `receive` +//! - a deadline expiring, handled by calling `on_timer` +//! - transport write results passed to `confirm_session_write` or `reject_session_write` +//! +//! outputs from `QlFsm` are +//! - outbound records from `take_next_write` +//! - peer events from `take_next_event` +//! - session events from `take_next_session_event` +//! +//! call `next_deadline` after handling current inputs and draining current outputs +//! use it to decide how long the outer loop can wait before `on_timer` must run +//! another input may arrive before that deadline, which is fine + mod error; pub(crate) mod implementation; pub(crate) mod replay_cache; @@ -20,64 +39,110 @@ use crate::{ state::{PeerRecord, QlFsmState}, }; +/// time input for `QlFsm` #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct FsmTime { + /// monotonic time used for local deadlines pub instant: Instant, + /// wall-clock unix time used for expiration checks pub unix_secs: u64, } +/// bound remote peer identity and public keys #[derive(Debug, Clone, PartialEq, Eq)] pub struct Peer { + /// peer xid pub xid: XID, + /// peer signing public key pub signing_key: MlDsaPublicKey, + /// peer encapsulation public key pub encapsulation_key: MlKemPublicKey, } +/// connection state for the bound peer #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PeerStatus { + /// no active encrypted session Disconnected, + /// we are driving the handshake Initiator, + /// the peer is driving the handshake Responder, + /// the encrypted session is up Connected, } +/// peer-level events emitted by `QlFsm` #[derive(Debug, Clone)] pub enum QlFsmEvent { + /// a peer was bound or replaced NewPeer(Peer), + /// the bound peer was cleared ClearPeer, - PeerStatusChanged { peer: XID, status: PeerStatus }, + /// the peer changed connection state + PeerStatusChanged { + /// peer that changed state + peer: XID, + /// new connection state + status: PeerStatus, + }, } +/// session and stream events emitted by `QlFsm` #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlSessionEvent { + /// a stream was opened Opened(StreamId), + /// a stream has bytes ready to read Readable(StreamId), + /// the peer finished writing this stream Finished(StreamId), + /// a stream was closed Closed(StreamClose), + /// local writes on this stream are closed WritableClosed(StreamId), + /// the peer requested unpairing Unpaired, + /// the encrypted session was closed SessionClosed(SessionCloseBody), } +/// handle for a session write returned by `QlFsm::take_next_write` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct SessionWriteId(pub SessionSeq); +pub struct SessionWriteId( + /// session sequence number for this write + pub SessionSeq, +); +/// outbound record produced by `QlFsm` #[derive(Debug, Clone, PartialEq)] pub struct OutboundWrite { + /// record to hand to the transport pub record: QlRecord, + /// write handle that must be confirmed or rejected pub session_write_id: Option, } +/// timing and buffering knobs for `QlFsm` #[derive(Debug, Clone, Copy)] pub struct QlFsmConfig { + /// overall time limit for one handshake attempt pub handshake_timeout: Duration, + /// delay before retrying the current handshake message pub handshake_retry_interval: Duration, + /// maximum retries for each handshake step pub max_handshake_retries: u8, + /// how far into the future control messages remain valid pub control_expiration: Duration, + /// delay before sending a pure ack pub session_ack_delay: Duration, + /// how long to wait before resending unacked session data pub session_retransmit_timeout: Duration, + /// idle delay before sending a keepalive ping pub session_keepalive_interval: Duration, + /// how long to wait before declaring the peer dead pub session_peer_timeout: Duration, + /// maximum bytes per outbound stream chunk pub session_stream_chunk_size: usize, } @@ -97,8 +162,11 @@ impl Default for QlFsmConfig { } } +/// synchronous driver for pairing, handshake, and encrypted streams pub struct QlFsm { + /// active configuration pub config: QlFsmConfig, + /// local identity and private keys pub identity: QlIdentity, pub(crate) peer: Option, pub(crate) session: SessionFsm, @@ -106,6 +174,7 @@ pub struct QlFsm { } impl QlFsm { + /// creates a new `QlFsm` pub fn new(config: QlFsmConfig, identity: QlIdentity, now: FsmTime) -> Self { Self { config, @@ -133,20 +202,24 @@ impl QlFsm { } } + /// binds or replaces the remote peer pub fn bind_peer(&mut self, peer: Peer) { implementation::handle_bind_peer(self, peer); } + /// queues a pair request for the bound peer pub fn pair(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; implementation::handle_pair_local(self, crypto) } + /// starts or resumes the encrypted session handshake pub fn connect(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; implementation::handle_connect(self, crypto) } + /// handles one inbound wire message pub fn receive( &mut self, now: FsmTime, @@ -157,21 +230,23 @@ impl QlFsm { implementation::receive(self, bytes, crypto) } + /// advances time-based state pub fn on_timer(&mut self, now: FsmTime) { self.state.now = now; implementation::on_timer(self); } + /// returns the next timer deadline, if any pub fn next_deadline(&self) -> Option { implementation::next_deadline(self) } - /// Returns the next outbound record. + /// returns the next outbound record /// - /// If `session_write_id` is `Some`, it must be followed by exactly one of - /// [`Self::confirm_session_write`] or [`Self::return_session_write`]. + /// if `session_write_id` is `Some`, call exactly one of + /// `confirm_session_write` or `reject_session_write` /// - /// If `session_write_id` is `None`, the record is fire-and-forget. + /// if it is `None`, the record is fire-and-forget pub fn take_next_write( &mut self, now: FsmTime, @@ -181,40 +256,42 @@ impl QlFsm { implementation::take_next_write(self, crypto) } - /// Marks a previously issued session write as successfully handed to the transport. + /// marks a `SessionWriteId` from `take_next_write` as handed to the transport /// - /// This must be called at most once for a `SessionWriteId` returned by - /// [`Self::take_next_write`] whose `session_write_id` was `Some`. + /// call this at most once for each returned `SessionWriteId` pub fn confirm_session_write(&mut self, now: FsmTime, write_id: SessionWriteId) { self.state.now = now; implementation::confirm_session_write(self, write_id); } - /// Reports that a previously issued session write was not accepted by the transport. + /// reports that a `SessionWriteId` from `take_next_write` was not accepted /// - /// This must be called at most once for a `SessionWriteId` returned by - /// [`Self::take_next_write`] whose `session_write_id` was `Some`. + /// call this at most once for each returned `SessionWriteId` pub fn reject_session_write(&mut self, write_id: SessionWriteId) { implementation::reject_session_write(self, write_id); } - /// Aborts the current encrypted session locally. + /// closes the current encrypted session locally pub fn kill_session(&mut self, code: CloseCode) { implementation::kill_session(self, code); } + /// returns the next peer-level event pub fn take_next_event(&mut self) -> Option { implementation::take_next_event(self) } + /// opens a new outgoing stream pub fn open_stream(&mut self) -> Result { implementation::open_stream(self) } + /// queues bytes for an open stream pub fn write_stream(&mut self, stream_id: StreamId, bytes: Vec) -> Result<(), QlFsmError> { implementation::write_stream(self, stream_id, bytes) } + /// reads queued bytes from a stream into `out` pub fn read_stream( &mut self, stream_id: StreamId, @@ -223,14 +300,17 @@ impl QlFsm { implementation::read_stream(self, stream_id, out) } + /// returns how many bytes can be read from a stream pub fn stream_available_bytes(&self, stream_id: StreamId) -> Result { implementation::stream_available_bytes(self, stream_id) } + /// marks the local write side as finished pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), QlFsmError> { implementation::finish_stream(self, stream_id) } + /// closes part or all of a stream pub fn close_stream( &mut self, stream_id: StreamId, @@ -241,14 +321,17 @@ impl QlFsm { implementation::close_stream(self, stream_id, target, code, payload) } + /// queues a ping on the active session pub fn queue_ping(&mut self) -> Result<(), QlFsmError> { implementation::queue_ping(self) } + /// queues an unpair request on the active session pub fn queue_unpair(&mut self) -> Result<(), QlFsmError> { implementation::queue_unpair(self) } + /// returns the next session or stream event pub fn take_next_session_event(&mut self) -> Option { implementation::take_next_session_event(self) } From edc12bc554270063ccf4b298421e7c451040a946 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 18 Mar 2026 13:11:11 -0400 Subject: [PATCH 017/304] ql-rpc --- Cargo.lock | 11 + Cargo.toml | 2 + ql-rpc/Cargo.toml | 11 + ql-rpc/src/codec.rs | 103 +++++++ ql-rpc/src/error.rs | 78 ++++++ ql-rpc/src/header.rs | 41 +++ ql-rpc/src/lib.rs | 35 +++ ql-rpc/src/rpc/mod.rs | 9 + ql-rpc/src/rpc/notification.rs | 61 ++++ ql-rpc/src/rpc/request.rs | 86 ++++++ ql-rpc/src/rpc/request_with_progress.rs | 295 ++++++++++++++++++++ ql-rpc/src/rpc/subscription.rs | 212 ++++++++++++++ ql-runtime/Cargo.toml | 6 + ql-runtime/src/handle.rs | 43 ++- ql-runtime/src/lib.rs | 2 + ql-runtime/src/rpc/error.rs | 55 ++++ ql-runtime/src/rpc/mod.rs | 153 ++++++++++ ql-runtime/src/rpc/request_with_progress.rs | 134 +++++++++ ql-runtime/src/rpc/subscription.rs | 73 +++++ ql-runtime/src/tests/mod.rs | 2 + ql-runtime/src/tests/rpc.rs | 248 ++++++++++++++++ 21 files changed, 1650 insertions(+), 10 deletions(-) create mode 100644 ql-rpc/Cargo.toml create mode 100644 ql-rpc/src/codec.rs create mode 100644 ql-rpc/src/error.rs create mode 100644 ql-rpc/src/header.rs create mode 100644 ql-rpc/src/lib.rs create mode 100644 ql-rpc/src/rpc/mod.rs create mode 100644 ql-rpc/src/rpc/notification.rs create mode 100644 ql-rpc/src/rpc/request.rs create mode 100644 ql-rpc/src/rpc/request_with_progress.rs create mode 100644 ql-rpc/src/rpc/subscription.rs create mode 100644 ql-runtime/src/rpc/error.rs create mode 100644 ql-runtime/src/rpc/mod.rs create mode 100644 ql-runtime/src/rpc/request_with_progress.rs create mode 100644 ql-runtime/src/rpc/subscription.rs create mode 100644 ql-runtime/src/tests/rpc.rs diff --git a/Cargo.lock b/Cargo.lock index c0e954a0..c89a427b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2158,16 +2158,27 @@ dependencies = [ "thiserror", ] +[[package]] +name = "ql-rpc" +version = "0.1.0" +dependencies = [ + "bytes", + "ql-wire", + "thiserror", +] + [[package]] name = "ql-runtime" version = "0.1.0" dependencies = [ "async-channel", + "bytes", "futures-lite", "libcrux-aesgcm", "oneshot", "piper", "ql-fsm", + "ql-rpc", "ql-wire", "sha2", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index d56111c1..a8a932ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "backup-shard", "btp", "ql-fsm", + "ql-rpc", "ql-runtime", "ql-wire", "quantum-link-macros", @@ -34,6 +35,7 @@ foundation-api = { path = "api" } quantum-link-macros = { path = "quantum-link-macros" } ql-protocol = { path = "ql-protocol" } ql-fsm = { path = "ql-fsm" } +ql-rpc = { path = "ql-rpc" } ql-wire = { path = "ql-wire" } [patch.crates-io] diff --git a/ql-rpc/Cargo.toml b/ql-rpc/Cargo.toml new file mode 100644 index 00000000..a7df20c9 --- /dev/null +++ b/ql-rpc/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "ql-rpc" +version = "0.1.0" +edition = "2021" +description = "Quantum Link RPC protocol traits and framing" +license = "Proprietary" + +[dependencies] +bytes = { version = "1" } +ql-wire = { path = "../ql-wire" } +thiserror = { version = "2" } diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs new file mode 100644 index 00000000..b6786a84 --- /dev/null +++ b/ql-rpc/src/codec.rs @@ -0,0 +1,103 @@ +use std::collections::VecDeque; + +use bytes::Buf; + +use crate::{RpcCodec, RpcError}; + +const LENGTH_SIZE: usize = 8; + +pub fn encode_value_part(value: &T, out: &mut Vec) -> Result<(), T::Error> { + let mut payload = Vec::new(); + value.encode_value(&mut payload)?; + push_length(out, payload.len()); + out.extend_from_slice(&payload); + Ok(()) +} + +pub fn try_measure_next_part(mut bytes: B) -> Result, RpcError> { + if bytes.remaining() < LENGTH_SIZE { + return Ok(None); + } + + let len = bytes.get_u64_le(); + let len: usize = len.try_into().map_err(|_| RpcError::LengthOverflow)?; + let consumed = LENGTH_SIZE + .checked_add(len) + .ok_or(RpcError::LengthOverflow)?; + if bytes.remaining() < len { + return Ok(None); + } + + Ok(Some((consumed, len))) +} + +pub fn try_measure_next_tagged_part( + mut bytes: B, +) -> Result, RpcError> { + if !bytes.has_remaining() { + return Ok(None); + } + + let kind = bytes.get_u8(); + let Some((consumed, len)) = try_measure_next_part(bytes)? else { + return Ok(None); + }; + + Ok(Some((kind, 1 + consumed, len))) +} + +pub struct DrainBuf<'a> { + bytes: &'a mut VecDeque, + offset: usize, + len: usize, +} + +impl<'a> DrainBuf<'a> { + pub fn new(bytes: &'a mut VecDeque, len: usize) -> Self { + debug_assert!(bytes.len() >= len); + Self { + bytes, + offset: 0, + len, + } + } +} + +impl Buf for DrainBuf<'_> { + fn remaining(&self) -> usize { + self.len - self.offset + } + + fn chunk(&self) -> &[u8] { + if self.remaining() == 0 { + return &[]; + } + + let (first, second) = self.bytes.as_slices(); + if self.offset < first.len() { + let start = self.offset; + let end = (start + self.remaining()).min(first.len()); + &first[start..end] + } else { + let start = self.offset - first.len(); + let end = (start + self.remaining()).min(second.len()); + &second[start..end] + } + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining(), "advanced past payload boundary"); + self.offset += cnt; + } +} + +impl Drop for DrainBuf<'_> { + fn drop(&mut self) { + self.bytes.drain(..self.len); + } +} + +pub fn push_length(out: &mut Vec, len: usize) { + let len = u64::try_from(len).expect("rpc payload exceeds u64 length framing"); + out.extend_from_slice(&len.to_le_bytes()); +} diff --git a/ql-rpc/src/error.rs b/ql-rpc/src/error.rs new file mode 100644 index 00000000..95b4fff7 --- /dev/null +++ b/ql-rpc/src/error.rs @@ -0,0 +1,78 @@ +use ql_wire::CloseCode; + +use crate::MethodId; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +pub enum RpcError { + #[error("truncated rpc payload")] + Truncated, + #[error("rpc payload length overflow")] + LengthOverflow, + #[error("invalid rpc version {0}")] + InvalidVersion(u8), + #[error("unexpected rpc method {actual:?}, expected {expected:?}")] + UnexpectedMethod { + expected: MethodId, + actual: MethodId, + }, + #[error("unexpected rpc frame kind {0}")] + UnexpectedFrameKind(u8), + #[error("missing terminal rpc response")] + MissingResponse, + #[error("trailing rpc bytes")] + TrailingBytes, +} + +impl RpcError { + pub const fn close_code(self) -> CloseCode { + match self { + Self::UnexpectedMethod { .. } => CloseCode::UNKNOWN_ROUTE, + Self::Truncated + | Self::LengthOverflow + | Self::InvalidVersion(_) + | Self::UnexpectedFrameKind(_) + | Self::MissingResponse + | Self::TrailingBytes => CloseCode::INVALID_HEAD, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RpcCodecError { + Rpc(RpcError), + Codec(E), +} + +impl std::error::Error for RpcCodecError +where + E: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + RpcCodecError::Rpc(e) => Some(e), + RpcCodecError::Codec(e) => Some(e), + } + } + + fn cause(&self) -> Option<&dyn std::error::Error> { + self.source() + } +} + +impl std::fmt::Display for RpcCodecError +where + E: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RpcCodecError::Rpc(e) => write!(f, "{e}"), + RpcCodecError::Codec(e) => write!(f, "{e}"), + } + } +} + +impl From for RpcCodecError { + fn from(error: RpcError) -> Self { + Self::Rpc(error) + } +} diff --git a/ql-rpc/src/header.rs b/ql-rpc/src/header.rs new file mode 100644 index 00000000..cd7fa83f --- /dev/null +++ b/ql-rpc/src/header.rs @@ -0,0 +1,41 @@ +use crate::{MethodId, RpcError, RPC_VERSION}; + +const HEADER_SIZE: usize = 1 + 8; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RpcHeader { + pub version: u8, + pub method: MethodId, +} + +impl RpcHeader { + pub const WIRE_SIZE: usize = HEADER_SIZE; + + pub const fn new(method: MethodId) -> Self { + Self { + version: RPC_VERSION, + method, + } + } + + pub fn encode_into(&self, out: &mut Vec) { + out.push(self.version); + out.extend_from_slice(&self.method.0.to_le_bytes()); + } + + pub fn decode(bytes: &[u8]) -> Result<(Self, &[u8]), RpcError> { + if bytes.len() < Self::WIRE_SIZE { + return Err(RpcError::Truncated); + } + + let version = bytes[0]; + if version != RPC_VERSION { + return Err(RpcError::InvalidVersion(version)); + } + + let method = MethodId(u64::from_le_bytes( + bytes[1..Self::WIRE_SIZE].try_into().unwrap(), + )); + Ok((Self { version, method }, &bytes[Self::WIRE_SIZE..])) + } +} diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs new file mode 100644 index 00000000..e8252136 --- /dev/null +++ b/ql-rpc/src/lib.rs @@ -0,0 +1,35 @@ +//! quantum link rpc protocol traits and framing helpers. + +use bytes::Buf; + +pub(crate) mod codec; +mod error; +pub mod header; +pub mod rpc; + +pub use error::*; +pub use rpc::*; + +pub const RPC_VERSION: u8 = 1; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct MethodId(pub u64); + +pub trait RpcCodec: Sized { + type Error; + + fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error>; + fn decode_value(bytes: &mut B) -> Result; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Inbound<'a> { + pub header: header::RpcHeader, + pub body: &'a [u8], +} + +pub fn parse_inbound(bytes: &[u8]) -> Result, RpcError> { + let (header, body) = header::RpcHeader::decode(bytes)?; + Ok(Inbound { header, body }) +} diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs new file mode 100644 index 00000000..d61a88a6 --- /dev/null +++ b/ql-rpc/src/rpc/mod.rs @@ -0,0 +1,9 @@ +pub mod notification; +pub mod request; +pub mod request_with_progress; +pub mod subscription; + +pub use notification::Notification; +pub use request::Request; +pub use request_with_progress::RequestWithProgress; +pub use subscription::Subscription; diff --git a/ql-rpc/src/rpc/notification.rs b/ql-rpc/src/rpc/notification.rs new file mode 100644 index 00000000..e288e043 --- /dev/null +++ b/ql-rpc/src/rpc/notification.rs @@ -0,0 +1,61 @@ +use crate::{MethodId, RpcCodec}; + +pub trait Notification { + const METHOD: MethodId; + type Error; + type Event: RpcCodec; +} + +pub fn encode_event(event: &M::Event, out: &mut Vec) -> Result<(), M::Error> { + crate::header::RpcHeader::new(M::METHOD).encode_into(out); + event.encode_value(out) +} + +pub fn decode_event(mut body: &[u8]) -> Result { + M::Event::decode_value(&mut body) +} + +#[cfg(test)] +mod tests { + use bytes::Buf; + + use super::{decode_event, encode_event, Notification}; + use crate::{parse_inbound, MethodId, RpcCodec}; + + #[derive(Debug, Clone, PartialEq, Eq)] + struct BytesValue(Vec); + + impl RpcCodec for BytesValue { + type Error = core::convert::Infallible; + + fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { + out.extend_from_slice(&self.0); + Ok(()) + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) + } + } + + struct Notify; + + impl Notification for Notify { + const METHOD: MethodId = MethodId(13); + type Error = core::convert::Infallible; + type Event = BytesValue; + } + + #[test] + fn event_round_trip_preserves_header_and_payload() { + let mut encoded = Vec::new(); + encode_event::(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); + + let inbound = parse_inbound(&encoded).unwrap(); + assert_eq!(inbound.header.method, Notify::METHOD); + assert_eq!( + decode_event::(inbound.body).unwrap(), + BytesValue(b"hello".to_vec()) + ); + } +} diff --git a/ql-rpc/src/rpc/request.rs b/ql-rpc/src/rpc/request.rs new file mode 100644 index 00000000..e7f4f5fb --- /dev/null +++ b/ql-rpc/src/rpc/request.rs @@ -0,0 +1,86 @@ +use crate::{MethodId, RpcCodec}; + +pub trait Request { + const METHOD: MethodId; + type Error; + type Request: RpcCodec; + type Response: RpcCodec; +} + +pub fn encode_request(request: &M::Request, out: &mut Vec) -> Result<(), M::Error> { + crate::header::RpcHeader::new(M::METHOD).encode_into(out); + request.encode_value(out) +} + +pub fn decode_request(body: &[u8]) -> Result { + let mut body = body; + M::Request::decode_value(&mut body) +} + +pub fn encode_response( + response: &M::Response, + out: &mut Vec, +) -> Result<(), M::Error> { + response.encode_value(out) +} + +pub fn decode_response(bytes: &[u8]) -> Result { + let mut bytes = bytes; + M::Response::decode_value(&mut bytes) +} + +#[cfg(test)] +mod tests { + use bytes::Buf; + + use super::*; + use crate::{parse_inbound, MethodId, RpcCodec}; + + #[derive(Debug, Clone, PartialEq, Eq)] + struct BytesValue(Vec); + + impl RpcCodec for BytesValue { + type Error = core::convert::Infallible; + + fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { + out.extend_from_slice(&self.0); + Ok(()) + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) + } + } + + struct Echo; + + impl Request for Echo { + const METHOD: MethodId = MethodId(7); + type Error = core::convert::Infallible; + type Request = BytesValue; + type Response = BytesValue; + } + + #[test] + fn request_round_trip_preserves_header_and_payload() { + let mut encoded = Vec::new(); + encode_request::(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); + + let inbound = parse_inbound(&encoded).unwrap(); + assert_eq!(inbound.header.method, Echo::METHOD); + assert_eq!( + decode_request::(inbound.body).unwrap(), + BytesValue(b"hello".to_vec()) + ); + } + + #[test] + fn response_round_trip_preserves_payload() { + let mut encoded = Vec::new(); + encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); + assert_eq!( + decode_response::(&encoded).unwrap(), + BytesValue(b"done".to_vec()) + ); + } +} diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs new file mode 100644 index 00000000..159c78db --- /dev/null +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -0,0 +1,295 @@ +use std::{collections::VecDeque, marker::PhantomData}; + +use bytes::Buf; + +use crate::{codec, MethodId, RpcCodec, RpcCodecError, RpcError}; + +const FRAME_HEADER_SIZE: usize = 1 + core::mem::size_of::(); + +pub trait RequestWithProgress { + const METHOD: MethodId; + type Error; + type Request: RpcCodec; + type Progress: RpcCodec; + type Response: RpcCodec; +} + +pub enum ReadStep { + NeedMore(ResponseReader), + Progress { + value: M::Progress, + next: ResponseReader, + }, + Response(M::Response), +} + +pub struct ResponseReader { + bytes: VecDeque, + marker: PhantomData M>, +} + +impl Default for ResponseReader { + fn default() -> Self { + Self::new() + } +} + +impl ResponseReader { + pub fn new() -> Self { + Self { + bytes: VecDeque::new(), + marker: PhantomData, + } + } + + pub fn push(mut self, chunk: &[u8]) -> Self { + self.bytes.extend(chunk); + self + } + + pub fn advance(self) -> Result, RpcCodecError> { + let mut this = self; + + let (first, second) = this.bytes.as_slices(); + let Some((kind, consumed, payload_len)) = + codec::try_measure_next_tagged_part(first.chain(second)).map_err(RpcCodecError::Rpc)? + else { + return Ok(ReadStep::NeedMore(this)); + }; + + match kind { + x if x == FrameKind::Progress as u8 => { + this.bytes.drain(..FRAME_HEADER_SIZE); + let value = { + let mut body = codec::DrainBuf::new(&mut this.bytes, payload_len); + M::Progress::decode_value(&mut body).map_err(RpcCodecError::Codec)? + }; + Ok(ReadStep::Progress { value, next: this }) + } + x if x == FrameKind::Response as u8 => { + let has_trailing = this.bytes.len() > consumed; + this.bytes.drain(..FRAME_HEADER_SIZE); + let mut body = codec::DrainBuf::new(&mut this.bytes, payload_len); + let response = + M::Response::decode_value(&mut body).map_err(RpcCodecError::Codec)?; + if has_trailing { + Err(RpcCodecError::Rpc(RpcError::TrailingBytes)) + } else { + Ok(ReadStep::Response(response)) + } + } + other => Err(RpcCodecError::Rpc(RpcError::UnexpectedFrameKind(other))), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +enum FrameKind { + Progress = 1, + Response = 2, +} + +pub fn encode_request( + request: &M::Request, + out: &mut Vec, +) -> Result<(), M::Error> { + crate::header::RpcHeader::new(M::METHOD).encode_into(out); + request.encode_value(out) +} + +pub fn decode_request(mut body: &[u8]) -> Result { + M::Request::decode_value(&mut body) +} + +pub fn encode_progress( + progress: &M::Progress, + out: &mut Vec, +) -> Result<(), M::Error> { + encode_tagged_value_part(FrameKind::Progress, progress, out) +} + +pub fn encode_response( + response: &M::Response, + out: &mut Vec, +) -> Result<(), M::Error> { + encode_tagged_value_part(FrameKind::Response, response, out) +} + +fn encode_tagged_value_part( + kind: FrameKind, + value: &T, + out: &mut Vec, +) -> Result<(), T::Error> { + let mut payload = Vec::new(); + value.encode_value(&mut payload)?; + out.push(kind as u8); + codec::push_length(out, payload.len()); + out.extend_from_slice(&payload); + Ok(()) +} + +#[cfg(test)] +mod tests { + use bytes::Buf; + + use super::{ + decode_request, encode_progress, encode_request, encode_response, ReadStep, + RequestWithProgress, ResponseReader, + }; + use crate::{parse_inbound, MethodId, RpcCodec, RpcCodecError, RpcError}; + + #[derive(Debug, Clone, PartialEq, Eq)] + struct BytesValue(Vec); + + impl RpcCodec for BytesValue { + type Error = core::convert::Infallible; + + fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { + out.extend_from_slice(&self.0); + Ok(()) + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) + } + } + + struct Watch; + + impl RequestWithProgress for Watch { + const METHOD: MethodId = MethodId(11); + type Error = core::convert::Infallible; + type Request = BytesValue; + type Progress = BytesValue; + type Response = BytesValue; + } + + #[test] + fn request_round_trip_preserves_header_and_payload() { + let mut encoded = Vec::new(); + encode_request::(&BytesValue(b"watch".to_vec()), &mut encoded).unwrap(); + + let inbound = parse_inbound(&encoded).unwrap(); + assert_eq!(inbound.header.method, Watch::METHOD); + assert_eq!( + decode_request::(inbound.body).unwrap(), + BytesValue(b"watch".to_vec()) + ); + } + + #[test] + fn response_with_progress_requires_terminal_response() { + let mut encoded = Vec::new(); + encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); + + let reader = match ResponseReader::::new() + .push(&encoded) + .advance() + .unwrap() + { + ReadStep::Progress { value, next } => { + assert_eq!(value, BytesValue(b"10%".to_vec())); + next + } + _ => unreachable!(), + }; + let reader = match reader.advance().unwrap() { + ReadStep::NeedMore(next) => next, + _ => unreachable!(), + }; + let _ = reader; + } + + #[test] + fn response_with_progress_rejects_bytes_after_response() { + let mut encoded = Vec::new(); + encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); + encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); + encode_progress::(&BytesValue(b"late".to_vec()), &mut encoded).unwrap(); + + let reader = match ResponseReader::::new() + .push(&encoded) + .advance() + .unwrap() + { + ReadStep::Progress { next, .. } => next, + _ => unreachable!(), + }; + match reader.advance() { + Err(RpcCodecError::Rpc(RpcError::TrailingBytes)) => {} + _ => unreachable!(), + } + } + + #[test] + fn response_reader_emits_typed_events() { + let mut encoded = Vec::new(); + encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); + encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); + + let reader = ResponseReader::::new().push(&encoded[..4]); + let reader = match reader.advance().unwrap() { + ReadStep::NeedMore(next) => next, + _ => unreachable!(), + }; + let reader = reader.push(&encoded[4..encoded.len() - 2]); + let reader = match reader.advance().unwrap() { + ReadStep::Progress { + value: BytesValue(bytes), + next, + } => { + assert_eq!(bytes, b"10%".to_vec()); + next + } + _ => unreachable!(), + }; + let reader = match reader.advance().unwrap() { + ReadStep::NeedMore(next) => next, + _ => unreachable!(), + }; + let reader = reader.push(&encoded[encoded.len() - 2..]); + match reader.advance().unwrap() { + ReadStep::Response(value) => assert_eq!(value, BytesValue(b"done".to_vec())), + _ => unreachable!(), + } + } + + #[test] + fn response_progress_then_response_round_trips() { + let mut encoded = Vec::new(); + encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); + encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); + + let reader = match ResponseReader::::new() + .push(&encoded) + .advance() + .unwrap() + { + ReadStep::Progress { value, next } => { + assert_eq!(value, BytesValue(b"10%".to_vec())); + next + } + _ => unreachable!(), + }; + match reader.advance().unwrap() { + ReadStep::Response(value) => assert_eq!(value, BytesValue(b"done".to_vec())), + _ => unreachable!(), + } + } + + #[test] + fn response_can_be_encoded_without_progress() { + let mut encoded = Vec::new(); + encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); + + match ResponseReader::::new() + .push(&encoded) + .advance() + .unwrap() + { + ReadStep::Response(value) => assert_eq!(value, BytesValue(b"done".to_vec())), + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs new file mode 100644 index 00000000..442ef156 --- /dev/null +++ b/ql-rpc/src/rpc/subscription.rs @@ -0,0 +1,212 @@ +use std::{collections::VecDeque, marker::PhantomData}; + +use bytes::Buf; + +use crate::{codec, MethodId, RpcCodec, RpcCodecError, RpcError}; + +const ITEM_HEADER_SIZE: usize = core::mem::size_of::(); + +pub trait Subscription { + const METHOD: MethodId; + type Error; + type Request: RpcCodec; + type Event: RpcCodec; +} + +pub enum ReadStep { + NeedMore(ResponseReader), + Item { + value: M::Event, + next: ResponseReader, + }, + End, +} + +pub struct ResponseReader { + bytes: VecDeque, + marker: PhantomData M>, +} + +impl Default for ResponseReader { + fn default() -> Self { + Self::new() + } +} + +impl ResponseReader { + pub fn new() -> Self { + Self { + bytes: VecDeque::new(), + marker: PhantomData, + } + } + + pub fn push(mut self, chunk: &[u8]) -> Self { + self.bytes.extend(chunk); + self + } + + pub fn advance(self) -> Result, RpcCodecError> { + let mut this = self; + let (first, second) = this.bytes.as_slices(); + let Some((consumed, payload_len)) = + codec::try_measure_next_part(first.chain(second)).map_err(RpcCodecError::Rpc)? + else { + return Ok(ReadStep::NeedMore(this)); + }; + + if payload_len == 0 { + if this.bytes.len() == consumed { + return Ok(ReadStep::End); + } + return Err(RpcCodecError::Rpc(RpcError::TrailingBytes)); + } + + this.bytes.drain(..ITEM_HEADER_SIZE); + let item = { + let mut body = codec::DrainBuf::new(&mut this.bytes, payload_len); + M::Event::decode_value(&mut body).map_err(RpcCodecError::Codec)? + }; + Ok(ReadStep::Item { + value: item, + next: this, + }) + } +} + +pub fn encode_request( + request: &M::Request, + out: &mut Vec, +) -> Result<(), M::Error> { + crate::header::RpcHeader::new(M::METHOD).encode_into(out); + request.encode_value(out) +} + +pub fn decode_request(mut body: &[u8]) -> Result { + M::Request::decode_value(&mut body) +} + +pub fn encode_item( + item: &M::Event, + out: &mut Vec, +) -> Result<(), ::Error> { + codec::encode_value_part(item, out) +} + +pub fn encode_end(out: &mut Vec) { + codec::push_length(out, 0); +} + +#[cfg(test)] +mod tests { + use bytes::Buf; + + use super::{ + decode_request, encode_end, encode_item, encode_request, ReadStep, ResponseReader, + Subscription, + }; + use crate::{parse_inbound, MethodId, RpcCodec}; + + #[derive(Debug, Clone, PartialEq, Eq)] + struct BytesValue(Vec); + + impl RpcCodec for BytesValue { + type Error = core::convert::Infallible; + + fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { + out.extend_from_slice(&self.0); + Ok(()) + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) + } + } + + struct Feed; + + impl Subscription for Feed { + const METHOD: MethodId = MethodId(17); + type Error = core::convert::Infallible; + type Request = BytesValue; + type Event = BytesValue; + } + + #[test] + fn request_round_trip_preserves_header_and_payload() { + let mut encoded = Vec::new(); + encode_request::(&BytesValue(b"watch".to_vec()), &mut encoded).unwrap(); + + let inbound = parse_inbound(&encoded).unwrap(); + assert_eq!(inbound.header.method, Feed::METHOD); + assert_eq!( + decode_request::(inbound.body).unwrap(), + BytesValue(b"watch".to_vec()) + ); + } + + #[test] + fn decode_item_stream_reads_all_items() { + let mut encoded = Vec::new(); + encode_item::(&BytesValue(b"one".to_vec()), &mut encoded).unwrap(); + encode_item::(&BytesValue(b"two".to_vec()), &mut encoded).unwrap(); + encode_end(&mut encoded); + + let reader = match ResponseReader::::new() + .push(&encoded) + .advance() + .unwrap() + { + ReadStep::Item { value, next } => { + assert_eq!(value, BytesValue(b"one".to_vec())); + next + } + _ => unreachable!(), + }; + + let reader = match reader.advance().unwrap() { + ReadStep::Item { value, next } => { + assert_eq!(value, BytesValue(b"two".to_vec())); + next + } + _ => unreachable!(), + }; + + assert!(matches!(reader.advance().unwrap(), ReadStep::End)); + } + + #[test] + fn response_reader_emits_items_as_chunks_arrive() { + let mut encoded = Vec::new(); + encode_item::(&BytesValue(b"one".to_vec()), &mut encoded).unwrap(); + encode_item::(&BytesValue(b"two".to_vec()), &mut encoded).unwrap(); + encode_end(&mut encoded); + + let reader = match ResponseReader::::new() + .push(&encoded[..5]) + .advance() + .unwrap() + { + ReadStep::NeedMore(next) => next, + _ => unreachable!(), + }; + + let reader = match reader.push(&encoded[5..]).advance().unwrap() { + ReadStep::Item { value, next } => { + assert_eq!(value, BytesValue(b"one".to_vec())); + next + } + _ => unreachable!(), + }; + + let reader = match reader.advance().unwrap() { + ReadStep::Item { value, next } => { + assert_eq!(value, BytesValue(b"two".to_vec())); + next + } + _ => unreachable!(), + }; + + assert!(matches!(reader.advance().unwrap(), ReadStep::End)); + } +} diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index 03988f66..be9e018e 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -5,16 +5,22 @@ edition = "2021" description = "Quantum Link runtime" license = "Proprietary" +[features] +default = [] +rpc = ["dep:ql-rpc"] + [dependencies] async-channel = { version = "2.5" } futures-lite = { version = "2.5" } oneshot = { version = "0.1.11" } piper = { version = "0.2.4" } ql-fsm = { path = "../ql-fsm" } +ql-rpc = { path = "../ql-rpc", optional = true } ql-wire = { path = "../ql-wire" } thiserror = { version = "2" } [dev-dependencies] +bytes = "1" libcrux-aesgcm = "0.0.7" sha2 = "0.10" tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } diff --git a/ql-runtime/src/handle.rs b/ql-runtime/src/handle.rs index 10b69a22..f013b78c 100644 --- a/ql-runtime/src/handle.rs +++ b/ql-runtime/src/handle.rs @@ -1,5 +1,7 @@ +use std::{pin::Pin, task::Poll}; + use async_channel::{Receiver, Sender}; -use futures_lite::future::poll_fn; +use futures_lite::{future::poll_fn, Stream}; use crate::{ command::RuntimeCommand, CloseCode, CloseTarget, InboundEvent, OpenedStreamDelivery, Peer, @@ -78,22 +80,36 @@ impl ByteReader { } pub async fn next_chunk(&mut self) -> Result>, QlError> { + poll_fn(|cx| self.poll_next_chunk(cx)).await + } + + pub fn poll_next_chunk( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>, QlError>> { if self.finished { - return Ok(None); + return Poll::Ready(Ok(None)); } - match self.rx.recv().await { - Ok(InboundEvent::Data(bytes)) => Ok(Some(bytes)), - Ok(InboundEvent::Finished) => { + + // `async_channel::Receiver` implements `Stream` and stores its listener state + // internally, so poll it directly rather than recreating a `recv()` future. + // SAFETY: `self.rx` is pinned for the duration of this call and is not moved + // before `poll_next` returns. + let mut rx = unsafe { Pin::new_unchecked(&mut self.rx) }; + match Stream::poll_next(rx.as_mut(), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(InboundEvent::Data(bytes))) => Poll::Ready(Ok(Some(bytes))), + Poll::Ready(Some(InboundEvent::Finished)) => { self.finished = true; - Ok(None) + Poll::Ready(Ok(None)) } - Ok(InboundEvent::Failed(error)) => { + Poll::Ready(Some(InboundEvent::Failed(error))) => { self.finished = true; - Err(error) + Poll::Ready(Err(error)) } - Err(_) => { + Poll::Ready(None) => { self.finished = true; - Err(QlError::Cancelled) + Poll::Ready(Err(QlError::Cancelled)) } } } @@ -270,6 +286,13 @@ impl RuntimeHandle { response: ByteReader::new(stream_id, CloseTarget::Response, response, self.tx.clone()), }) } + + #[cfg(feature = "rpc")] + pub fn rpc(&self) -> crate::rpc::RpcHandle { + crate::rpc::RpcHandle { + inner: self.clone(), + } + } } impl RuntimeHandle { diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index b51093c0..f732b695 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -6,6 +6,8 @@ pub(crate) mod command; pub(crate) mod driver; pub mod handle; pub mod platform; +#[cfg(feature = "rpc")] +pub mod rpc; #[cfg(test)] mod tests; diff --git a/ql-runtime/src/rpc/error.rs b/ql-runtime/src/rpc/error.rs new file mode 100644 index 00000000..f81ad652 --- /dev/null +++ b/ql-runtime/src/rpc/error.rs @@ -0,0 +1,55 @@ +use crate::QlError; + +#[derive(Debug)] +pub enum RpcCallError { + Runtime(QlError), + Rpc(ql_rpc::RpcError), + Codec(E), +} + +impl From for RpcCallError { + fn from(error: QlError) -> Self { + Self::Runtime(error) + } +} + +impl From for RpcCallError { + fn from(error: ql_rpc::RpcError) -> Self { + Self::Rpc(error) + } +} + +impl From> for RpcCallError { + fn from(error: ql_rpc::RpcCodecError) -> Self { + match error { + ql_rpc::RpcCodecError::Rpc(error) => Self::Rpc(error), + ql_rpc::RpcCodecError::Codec(error) => Self::Codec(error), + } + } +} + +impl std::fmt::Display for RpcCallError +where + E: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Runtime(error) => write!(f, "{error}"), + Self::Rpc(error) => write!(f, "{error}"), + Self::Codec(error) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for RpcCallError +where + E: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Runtime(error) => Some(error), + Self::Rpc(error) => Some(error), + Self::Codec(error) => Some(error), + } + } +} diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs new file mode 100644 index 00000000..d7d72b56 --- /dev/null +++ b/ql-runtime/src/rpc/mod.rs @@ -0,0 +1,153 @@ +use std::task::{Context, Poll}; + +mod error; +mod request_with_progress; +mod subscription; + +pub use error::*; +use ql_rpc::{ + notification::{self, Notification}, + request::{self, Request as RequestRpc}, + request_with_progress::{self as rpc_request_with_progress, RequestWithProgress}, + subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, + RpcError, +}; +pub use request_with_progress::*; +pub use subscription::*; + +use crate::{ByteReader, OutboundStream, QlError, RuntimeHandle}; + +#[derive(Clone)] +pub struct RpcHandle { + pub(crate) inner: RuntimeHandle, +} + +pub(super) enum ChunkState { + Open(ByteReader), + Closed, +} + +impl RpcHandle { + pub async fn event(&self, event: &M::Event) -> Result<(), RpcCallError> + where + M: Notification, + { + let mut payload = Vec::new(); + notification::encode_event::(event, &mut payload).map_err(RpcCallError::Codec)?; + + let response = self + .start_request(payload) + .await + .map_err(RpcCallError::Runtime)?; + let response = read_all(response).await.map_err(RpcCallError::Runtime)?; + if response.is_empty() { + Ok(()) + } else { + Err(RpcCallError::Rpc(RpcError::TrailingBytes)) + } + } + + pub async fn request( + &self, + request: &M::Request, + ) -> Result> + where + M: RequestRpc, + { + let mut payload = Vec::new(); + request::encode_request::(request, &mut payload).map_err(RpcCallError::Codec)?; + let response = self + .start_request(payload) + .await + .map_err(RpcCallError::Runtime)?; + let response = read_all(response).await.map_err(RpcCallError::Runtime)?; + request::decode_response::(&response).map_err(RpcCallError::Codec) + } + + pub async fn subscribe( + &self, + request: &M::Request, + ) -> Result, RpcCallError> + where + M: SubscriptionRpc, + { + let mut payload = Vec::new(); + rpc_subscription::encode_request::(request, &mut payload) + .map_err(RpcCallError::Codec)?; + + let response = self + .start_request(payload) + .await + .map_err(RpcCallError::Runtime)?; + Ok(Subscription { + chunks: ChunkState::new(response), + reader: Some(rpc_subscription::ResponseReader::new()), + }) + } + + pub async fn request_with_progress( + &self, + request: &M::Request, + ) -> Result, RpcCallError> + where + M: RequestWithProgress, + { + let mut payload = Vec::new(); + rpc_request_with_progress::encode_request::(request, &mut payload) + .map_err(RpcCallError::Codec)?; + + let response = self + .start_request(payload) + .await + .map_err(RpcCallError::Runtime)?; + Ok(ProgressCall { + chunks: ChunkState::new(response), + reader: Some(rpc_request_with_progress::ResponseReader::new()), + terminal: None, + }) + } + + async fn start_request(&self, payload: Vec) -> Result { + let OutboundStream { + mut request, + response, + .. + } = self.inner.open_stream().await?; + + request.write_all(&payload).await?; + request.finish().await?; + Ok(response) + } +} + +impl ChunkState { + fn new(reader: ByteReader) -> Self { + Self::Open(reader) + } + + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>, QlError>> { + match self { + Self::Open(reader) => match reader.poll_next_chunk(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(Some(bytes))) => Poll::Ready(Ok(Some(bytes))), + Poll::Ready(Ok(None)) => { + *self = Self::Closed; + Poll::Ready(Ok(None)) + } + Poll::Ready(Err(error)) => { + *self = Self::Closed; + Poll::Ready(Err(error)) + } + }, + Self::Closed => Poll::Ready(Ok(None)), + } + } +} + +async fn read_all(mut reader: ByteReader) -> Result, QlError> { + let mut bytes = Vec::new(); + while let Some(chunk) = reader.next_chunk().await? { + bytes.extend_from_slice(&chunk); + } + Ok(bytes) +} diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs new file mode 100644 index 00000000..89021929 --- /dev/null +++ b/ql-runtime/src/rpc/request_with_progress.rs @@ -0,0 +1,134 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_lite::{future::poll_fn, Stream}; +use ql_rpc::{ + request_with_progress::{ReadStep, RequestWithProgress}, + RpcError, +}; + +use super::{ChunkState, RpcCallError}; + +pub struct ProgressCall { + pub(super) chunks: ChunkState, + pub(super) reader: Option>, + pub(super) terminal: Option>>, +} + +impl Unpin for ProgressCall where M: RequestWithProgress {} + +impl ProgressCall +where + M: RequestWithProgress, +{ + pub async fn progress(&mut self) -> Option { + poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await + } +} + +impl Stream for ProgressCall +where + M: RequestWithProgress, +{ + type Item = M::Progress; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if this.terminal.is_some() || this.reader.is_none() { + return Poll::Ready(None); + } + + loop { + let reader = this.reader.take().expect("progress reader is present"); + match reader.advance() { + Ok(ReadStep::Progress { value, next }) => { + this.reader = Some(next); + return Poll::Ready(Some(value)); + } + Ok(ReadStep::Response(response)) => { + this.terminal = Some(Ok(response)); + return Poll::Ready(None); + } + Ok(ReadStep::NeedMore(next)) => { + this.reader = Some(next); + } + Err(error) => { + this.terminal = Some(Err(error.into())); + return Poll::Ready(None); + } + } + + match this.chunks.poll_next(cx) { + Poll::Ready(Ok(Some(chunk))) => { + let reader = this.reader.take().expect("progress reader is present"); + this.reader = Some(reader.push(&chunk)); + } + Poll::Ready(Ok(None)) => { + this.reader = None; + this.terminal = Some(Err(RpcError::MissingResponse.into())); + return Poll::Ready(None); + } + Poll::Ready(Err(error)) => { + this.reader = None; + this.terminal = Some(Err(RpcCallError::Runtime(error))); + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + } + } + } +} + +impl Future for ProgressCall +where + M: RequestWithProgress, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if let Some(result) = this.terminal.take() { + return Poll::Ready(result); + } + + loop { + let Some(reader) = this.reader.take() else { + panic!("progress call polled after completion"); + }; + + match reader.advance() { + Ok(ReadStep::Progress { next, .. }) => { + this.reader = Some(next); + } + Ok(ReadStep::Response(response)) => { + return Poll::Ready(Ok(response)); + } + Ok(ReadStep::NeedMore(next)) => { + this.reader = Some(next); + } + Err(error) => return Poll::Ready(Err(error.into())), + } + + match this.chunks.poll_next(cx) { + Poll::Ready(Ok(Some(chunk))) => { + let reader = this.reader.take().expect("progress reader is present"); + this.reader = Some(reader.push(&chunk)); + } + Poll::Ready(Ok(None)) => { + this.reader = None; + return Poll::Ready(Err(RpcError::MissingResponse.into())); + } + Poll::Ready(Err(error)) => { + this.reader = None; + return Poll::Ready(Err(RpcCallError::Runtime(error))); + } + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs new file mode 100644 index 00000000..2b2e51fa --- /dev/null +++ b/ql-runtime/src/rpc/subscription.rs @@ -0,0 +1,73 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures_lite::{future::poll_fn, Stream}; +use ql_rpc::{ + subscription::{ReadStep, Subscription as SubscriptionRpc}, + RpcError, +}; + +use super::{ChunkState, RpcCallError}; + +pub struct Subscription { + pub(super) chunks: ChunkState, + pub(super) reader: Option>, +} + +impl Unpin for Subscription where M: SubscriptionRpc {} + +impl Subscription +where + M: SubscriptionRpc, +{ + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await + } +} + +impl Stream for Subscription +where + M: SubscriptionRpc, +{ + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + loop { + let Some(reader) = this.reader.take() else { + return Poll::Ready(None); + }; + + match reader.advance() { + Ok(ReadStep::Item { value, next }) => { + this.reader = Some(next); + return Poll::Ready(Some(Ok(value))); + } + Ok(ReadStep::End) => return Poll::Ready(None), + Ok(ReadStep::NeedMore(next)) => { + this.reader = Some(next); + } + Err(error) => return Poll::Ready(Some(Err(error.into()))), + } + + match this.chunks.poll_next(cx) { + Poll::Ready(Ok(Some(chunk))) => { + let reader = this.reader.take().expect("subscription reader is present"); + this.reader = Some(reader.push(&chunk)); + } + Poll::Ready(Ok(None)) => { + this.reader = None; + return Poll::Ready(Some(Err(RpcError::Truncated.into()))); + } + Poll::Ready(Err(error)) => { + this.reader = None; + return Poll::Ready(Some(Err(RpcCallError::Runtime(error)))); + } + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index d7391b5f..0cce74d9 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -24,6 +24,8 @@ use crate::{ mod handshake; mod heartbeat; +#[cfg(feature = "rpc")] +mod rpc; mod stream; mod unpair; diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs new file mode 100644 index 00000000..0fd48c1a --- /dev/null +++ b/ql-runtime/src/tests/rpc.rs @@ -0,0 +1,248 @@ +use std::time::Duration; + +use bytes::Buf; +use futures_lite::StreamExt; + +use super::*; +#[derive(Debug, Clone, PartialEq, Eq)] +struct BytesValue(Vec); + +impl ql_rpc::RpcCodec for BytesValue { + type Error = core::convert::Infallible; + + fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { + out.extend_from_slice(&self.0); + Ok(()) + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) + } +} + +struct Echo; + +impl ql_rpc::request::Request for Echo { + const METHOD: ql_rpc::MethodId = ql_rpc::MethodId(51); + type Error = core::convert::Infallible; + type Request = BytesValue; + type Response = BytesValue; +} + +struct Feed; + +impl ql_rpc::subscription::Subscription for Feed { + const METHOD: ql_rpc::MethodId = ql_rpc::MethodId(52); + type Error = core::convert::Infallible; + type Request = BytesValue; + type Event = BytesValue; +} + +struct Download; + +impl ql_rpc::request_with_progress::RequestWithProgress for Download { + const METHOD: ql_rpc::MethodId = ql_rpc::MethodId(53); + type Error = core::convert::Infallible; + type Request = BytesValue; + type Progress = BytesValue; + type Response = BytesValue; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_request_round_trips() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(11); + let identity_b = new_identity(73); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let request = read_all(inbound.request).await.unwrap(); + let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); + assert_eq!( + ql_rpc::request::decode_request::(rpc_inbound.body).unwrap(), + BytesValue(b"hello".to_vec()) + ); + + let mut encoded = Vec::new(); + ql_rpc::request::encode_response::(&BytesValue(b"world".to_vec()), &mut encoded) + .unwrap(); + let mut response = inbound.response; + response.write_all(&encoded).await.unwrap(); + response.finish().await.unwrap(); + }); + + let rpc = handle_a.rpc(); + let response = rpc + .request::(&BytesValue(b"hello".to_vec())) + .await + .unwrap(); + assert_eq!(response, BytesValue(b"world".to_vec())); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_subscription_streams_events() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(11); + let identity_b = new_identity(73); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let request = read_all(inbound.request).await.unwrap(); + let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); + assert_eq!( + ql_rpc::subscription::decode_request::(rpc_inbound.body).unwrap(), + BytesValue(b"watch".to_vec()) + ); + + let mut encoded = Vec::new(); + ql_rpc::subscription::encode_item::(&BytesValue(b"one".to_vec()), &mut encoded) + .unwrap(); + ql_rpc::subscription::encode_item::(&BytesValue(b"two".to_vec()), &mut encoded) + .unwrap(); + ql_rpc::subscription::encode_end(&mut encoded); + + let mut response = inbound.response; + response.write_all(&encoded).await.unwrap(); + response.finish().await.unwrap(); + }); + + let rpc = handle_a.rpc(); + let mut subscription = rpc + .subscribe::(&BytesValue(b"watch".to_vec())) + .await + .unwrap(); + assert_eq!( + subscription.next().await.unwrap().unwrap(), + BytesValue(b"one".to_vec()) + ); + assert_eq!( + subscription.next().await.unwrap().unwrap(), + BytesValue(b"two".to_vec()) + ); + assert!(subscription.next().await.is_none()); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_request_with_progress_supports_progress_then_await() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(11); + let identity_b = new_identity(73); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect().unwrap(); + + await_status(&status_a, identity_b.xid, PeerStage::Connected).await; + await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let request = read_all(inbound.request).await.unwrap(); + let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); + assert_eq!( + ql_rpc::request_with_progress::decode_request::(rpc_inbound.body) + .unwrap(), + BytesValue(b"logo".to_vec()) + ); + + let mut encoded = Vec::new(); + ql_rpc::request_with_progress::encode_progress::( + &BytesValue(b"10".to_vec()), + &mut encoded, + ) + .unwrap(); + ql_rpc::request_with_progress::encode_progress::( + &BytesValue(b"90".to_vec()), + &mut encoded, + ) + .unwrap(); + ql_rpc::request_with_progress::encode_response::( + &BytesValue(b"done".to_vec()), + &mut encoded, + ) + .unwrap(); + + let mut response = inbound.response; + response.write_all(&encoded).await.unwrap(); + response.finish().await.unwrap(); + }); + + let rpc = handle_a.rpc(); + let mut download = rpc + .request_with_progress::(&BytesValue(b"logo".to_vec())) + .await + .unwrap(); + + assert_eq!(download.progress().await, Some(BytesValue(b"10".to_vec()))); + assert_eq!(download.progress().await, Some(BytesValue(b"90".to_vec()))); + assert_eq!(download.progress().await, None); + assert_eq!(download.await.unwrap(), BytesValue(b"done".to_vec())); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} From 31ed49de623569b0e293843dc0df5dcc16e50f6d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 23 Mar 2026 10:10:13 -0400 Subject: [PATCH 018/304] design doc --- QL_V2.md | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 QL_V2.md diff --git a/QL_V2.md b/QL_V2.md new file mode 100644 index 00000000..30012271 --- /dev/null +++ b/QL_V2.md @@ -0,0 +1,68 @@ +# QuantumLink V2 Design Document + + QuantumLink V2 is a peer-to-peer protocol for authenticated, encrypted sessions carrying multiplexed byte streams. + + It replaces QLv1's one-message-at-a-time model with explicit pairing, handshake, session, and stream state. + + QLv2 operates on complete QL records and leaves transport-specific framing, fragmentation, reassembly, and delivery behavior to platform adapters. + +## Design goals +1. [use ephemeral peer sessions for record encryption](#1-explicit-peer-sessions) +2. [include a minimal unencrypted but authenticated header](#2-minimal-authenticated-header) +3. [keep the record layer transport-agnostic](#3-transport-agnostic-record-layer) +4. [add QL-level reliability above the transport](#4-ql-level-reliability) +5. [use duplex byte streams as the application primitive](#5-duplex-byte-streams) +6. [efficient protocol wire format](#6-efficient-wire-format) +7. [provide a single shared protocol state machine across platforms](#7-shared-core-state-machine) +8. [support hardware-backed cryptography](#8-hardware-backed-cryptography) + +### 1. Explicit peer sessions +QLv2 replaces per-exchange sealing with explicit pairing, handshake, session, and stream state. This keeps peer state durable across many records, amortizes large post-quantum signatures and expensive key exchange, and keeps steady-state traffic smaller and cheaper. + +### 2. Minimal authenticated header +QLv2 keeps a small header visible on the wire while still authenticating it. This lets a host route a record to the correct local or third-party application before decryption without exposing more metadata than necessary. + +### 3. Transport-agnostic record layer +The core protocol only consumes and produces complete QL records. Framing, batching, fragmentation, and reassembly stay in the transport adapter so the same protocol can run over transports such as TCP, BLE, or L2CAP without rewriting core logic. + +### 4. QL-level reliability +QLv2 includes QL-level sequence numbers and acknowledgments above the transport. A transport can usually only tell us that bytes were accepted for transmission. A QL acknowledgment tells us something stronger: the peer received, decrypted, and authenticated the record with the current session key. + +This is deliberate redundancy, not a replacement for transport reliability. It is not sufficient for a fully unreliable transport like raw UDP, but it does make QLv2 more robust on transports that should be reliable in theory yet have shown implementation-level flakiness in practice, such as Passport Prime's embedded BLE. + +### 5. Duplex byte streams +QLv2 treats duplex byte streams as the application primitive rather than building in a separate model for each interaction style. Request/response, subscriptions, progress updates, and bulk transfer can all be adapted to the same abstraction, which also gives useful behavior such as finish semantics, cancellation, and backpressure without separate protocol features. + +### 6. Efficient wire format +The wire format should stay compact, cheap to process, and independent of any one implementation language. QLv2 uses an efficient binary encoding with explicit endianness and fixed layouts, so records can be parsed consistently across platforms and can support zero-copy or near-zero-copy implementations where appropriate. + +The record sizes shows the protocol's intended split between setup and steady-state traffic. Setup records are relatively large because they carry post-quantum material, while steady-state session records are much smaller. + +| Record type | Encoded size | +| --- | ---: | +| `hello` | 6253 bytes | +| `hello_reply` | 6253 bytes | +| `confirm` | 4673 bytes | +| `pair_request empty` | 1630 bytes | +| `ready empty` | 62 bytes | +| `session ack` | 87 bytes | +| `session ping` | 87 bytes | +| `session unpair` | 87 bytes | +| `session stream empty` | 100 bytes | +| `session stream fin` | 100 bytes | +| `session stream close` | 94 bytes | +| `session close` | 89 bytes | + +### 7. Shared core state machine +QLv2 should have one core implementation of pairing, handshake, session, retransmission, and stream behavior. Platforms should integrate that shared state machine instead of rebuilding subtle protocol logic independently. + +### 8. Hardware-backed cryptography +QLv2 separates parts of its cryptographic implementation through the `QlCrypto` trait. Each platform can provide its own source of randomness, hashing, and AEAD encryption and decryption, choosing software or hardware-backed implementations as appropriate. + +## Non-design goals +- not a replacement for TCP, QUIC, BLE, or any other transport +- not a universal reliability layer for arbitrary raw packets +- not responsible for framing, batching, fragmentation, or reassembly on a given platform +- not responsible for how QL records map onto TCP reads/writes, BLE packets, or similar transport units +- not a general-purpose message bus above the stream layer +- not an attempt to preserve QLv1's sealed-message model in the core protocol From b7401eb119a63a8d03bc75154de9960b034f9324 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 23 Mar 2026 10:17:07 -0400 Subject: [PATCH 019/304] ql: smaller xid size --- ql-wire/src/xid.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ql-wire/src/xid.rs b/ql-wire/src/xid.rs index 164b32e0..040b3127 100644 --- a/ql-wire/src/xid.rs +++ b/ql-wire/src/xid.rs @@ -3,5 +3,5 @@ pub struct XID(pub [u8; Self::SIZE]); impl XID { - pub const SIZE: usize = 32; + pub const SIZE: usize = 16; } From 0e66e7e3715b180c58c8a552f225889acc365843 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 23 Mar 2026 10:56:49 -0400 Subject: [PATCH 020/304] ql: infallible encryption --- ql-fsm/src/error.rs | 6 ------ ql-fsm/src/implementation/fsm.rs | 3 +-- ql-fsm/src/implementation/handshake.rs | 4 ++-- ql-fsm/src/implementation/peer.rs | 2 +- ql-fsm/src/tests/mod.rs | 6 +++--- ql-runtime/src/driver.rs | 4 ++-- ql-runtime/src/lib.rs | 6 ------ ql-runtime/src/tests/mod.rs | 12 +++++------ ql-wire/src/encrypted/mod.rs | 8 +++---- ql-wire/src/encrypted_message.rs | 10 ++++----- ql-wire/src/error.rs | 4 ---- ql-wire/src/handshake/crypto.rs | 29 +++++++++++++------------- ql-wire/src/lib.rs | 2 +- ql-wire/src/pair/crypto.rs | 15 +++++++------ ql-wire/src/pq.rs | 29 +++++++++++--------------- ql-wire/src/tests.rs | 15 ++++++------- 16 files changed, 63 insertions(+), 92 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 9a25a943..454b1554 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -11,10 +11,6 @@ pub enum QlFsmError { InvalidSignature, #[error("expired")] Expired, - #[error("signing failed")] - SigningFailed, - #[error("encryption failed")] - EncryptFailed, #[error("decryption failed")] DecryptFailed, #[error("missing stream")] @@ -33,8 +29,6 @@ impl From for QlFsmError { WireError::InvalidPayload => Self::InvalidPayload, WireError::InvalidSignature => Self::InvalidSignature, WireError::Expired => Self::Expired, - WireError::SigningFailed => Self::SigningFailed, - WireError::EncryptFailed => Self::EncryptFailed, WireError::DecryptFailed => Self::DecryptFailed, } } diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index ac388313..b7e4c8c6 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -111,8 +111,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Result< peer.xid, &peer.encapsulation_key, meta, - )?; + ); let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs index 76605a20..001f238c 100644 --- a/ql-fsm/src/implementation/peer.rs +++ b/ql-fsm/src/implementation/peer.rs @@ -16,7 +16,7 @@ pub fn handle_pair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), peer.peer.xid, &peer.peer.encapsulation_key, meta, - )?; + ); fsm.state.outbound.push_back(record); Ok(()) } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 14b2a0f2..f5055354 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -55,7 +55,7 @@ impl QlCrypto for TestCrypto { nonce: &ql_wire::Nonce, aad: &[u8], buffer: &mut [u8], - ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]> { + ) -> [u8; EncryptedMessage::AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); let plaintext = buffer.to_vec(); let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; @@ -66,8 +66,8 @@ impl QlCrypto for TestCrypto { aad, &plaintext, ) - .ok()?; - Some(auth) + .unwrap(); + auth } fn decrypt_with_aead( diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 483ddbf2..79edcdd3 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -625,8 +625,8 @@ mod tests { _nonce: &ql_wire::Nonce, _aad: &[u8], _buffer: &mut [u8], - ) -> Option<[u8; ql_wire::EncryptedMessage::AUTH_SIZE]> { - None + ) -> [u8; ql_wire::EncryptedMessage::AUTH_SIZE] { + [0; ql_wire::EncryptedMessage::AUTH_SIZE] } fn decrypt_with_aead( diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index f732b695..d1a58564 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -24,10 +24,6 @@ pub enum QlError { InvalidSignature, #[error("expired")] Expired, - #[error("signing failed")] - SigningFailed, - #[error("encryption failed")] - EncryptFailed, #[error("decryption failed")] DecryptFailed, #[error("missing stream")] @@ -56,8 +52,6 @@ impl From for QlError { QlFsmError::InvalidPayload => Self::InvalidPayload, QlFsmError::InvalidSignature => Self::InvalidSignature, QlFsmError::Expired => Self::Expired, - QlFsmError::SigningFailed => Self::SigningFailed, - QlFsmError::EncryptFailed => Self::EncryptFailed, QlFsmError::DecryptFailed => Self::DecryptFailed, QlFsmError::MissingStream => Self::MissingStream, QlFsmError::NotWritable => Self::NotWritable, diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 0cce74d9..dfcda3a3 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -97,7 +97,7 @@ impl QlCrypto for DeterministicCrypto { nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]> { + ) -> [u8; EncryptedMessage::AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); let plaintext = buffer.to_vec(); let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; @@ -108,8 +108,8 @@ impl QlCrypto for DeterministicCrypto { aad, &plaintext, ) - .ok()?; - Some(auth) + .unwrap(); + auth } fn decrypt_with_aead( @@ -228,7 +228,7 @@ impl QlCrypto for TestPlatform { nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]> { + ) -> [u8; EncryptedMessage::AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); let plaintext = buffer.to_vec(); let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; @@ -239,8 +239,8 @@ impl QlCrypto for TestPlatform { aad, &plaintext, ) - .ok()?; - Some(auth) + .unwrap(); + auth } fn decrypt_with_aead( diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 07a75d56..99aee244 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -154,14 +154,14 @@ pub fn encrypt_record( session_key: &SessionKey, body: &SessionEnvelope, nonce: Nonce, -) -> Result { +) -> QlRecord { let aad = header.aad(); let body_bytes = body.encode(); - let encrypted = EncryptedMessage::encrypt(crypto, session_key, body_bytes, &aad, nonce)?; - Ok(QlRecord { + let encrypted = EncryptedMessage::encrypt(crypto, session_key, body_bytes, &aad, nonce); + QlRecord { header, payload: QlPayload::Session(encrypted), - }) + } } pub fn decrypt_record<'a, B: ByteSliceMut>( diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index 1c80b1bd..07f20aa1 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -65,15 +65,13 @@ impl EncryptedMessage { mut plaintext: Vec, aad: &[u8], nonce: Nonce, - ) -> Result { - let auth = crypto - .encrypt_with_aead(key, &nonce, aad, &mut plaintext) - .ok_or(WireError::EncryptFailed)?; - Ok(Self { + ) -> Self { + let auth = crypto.encrypt_with_aead(key, &nonce, aad, &mut plaintext); + Self { nonce, auth, ciphertext: plaintext, - }) + } } pub fn decrypt( diff --git a/ql-wire/src/error.rs b/ql-wire/src/error.rs index 2a84ce38..a1861866 100644 --- a/ql-wire/src/error.rs +++ b/ql-wire/src/error.rs @@ -8,10 +8,6 @@ pub enum WireError { InvalidSignature, #[error("expired")] Expired, - #[error("signing failed")] - SigningFailed, - #[error("encryption failed")] - EncryptFailed, #[error("decryption failed")] DecryptFailed, } diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs index 74e7a68d..0ca2f636 100644 --- a/ql-wire/src/handshake/crypto.rs +++ b/ql-wire/src/handshake/crypto.rs @@ -22,10 +22,9 @@ pub fn build_hello( recipient: XID, recipient_encapsulation_key: &MlKemPublicKey, meta: ControlMeta, -) -> Result<(Hello, SessionKey), WireError> { +) -> (Hello, SessionKey) { let nonce = next_nonce(crypto); - let (session_key, kem_ct) = - recipient_encapsulation_key.encapsulate_new_shared_secret(crypto)?; + let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(crypto); let proof_data = hash_hello_proof_data( crypto, identity.xid, @@ -34,8 +33,8 @@ pub fn build_hello( &nonce.0, kem_ct.as_bytes(), ); - let signature = identity.signing_private_key.sign(crypto, &proof_data)?; - Ok(( + let signature = identity.signing_private_key.sign(crypto, &proof_data); + ( Hello { meta, nonce, @@ -43,7 +42,7 @@ pub fn build_hello( signature, }, session_key, - )) + ) } pub fn verify_hello( @@ -87,11 +86,11 @@ pub fn respond_hello( )?; let initiator_secret = identity .encapsulation_private_key - .decapsulate_shared_secret_bytes(&hello.kem_ct)?; + .decapsulate_shared_secret_bytes(&hello.kem_ct); let hello_meta = ControlMeta::from_wire(hello.meta); let nonce = next_nonce(crypto); let (responder_secret, kem_ct) = - initiator_encapsulation_key.encapsulate_new_shared_secret(crypto)?; + initiator_encapsulation_key.encapsulate_new_shared_secret(crypto); let transcript = hash_handshake_transcript( crypto, initiator, @@ -103,7 +102,7 @@ pub fn respond_hello( &nonce.0, kem_ct.as_bytes(), ); - let signature = identity.signing_private_key.sign(crypto, &transcript)?; + let signature = identity.signing_private_key.sign(crypto, &transcript); Ok(( HelloReply { meta, @@ -145,7 +144,7 @@ pub fn build_confirm( verify_signature_bytes(responder_signing_key, &reply.signature, &transcript)?; let responder_secret = identity .encapsulation_private_key - .decapsulate_shared_secret_bytes(&reply.kem_ct)?; + .decapsulate_shared_secret_bytes(&reply.kem_ct); let proof_data = hash_confirm_proof_data( crypto, &meta, @@ -158,7 +157,7 @@ pub fn build_confirm( &reply.nonce, &reply.kem_ct, ); - let signature = identity.signing_private_key.sign(crypto, &proof_data)?; + let signature = identity.signing_private_key.sign(crypto, &proof_data); let session_key = derive_session_key( crypto, initiator_secret, @@ -244,12 +243,12 @@ pub fn build_ready( session_key: &SessionKey, meta: ControlMeta, nonce: Nonce, -) -> Result { +) -> Ready { let aad = header.aad(); let body_bytes = ReadyBody { meta }.encode(); - Ok(Ready { - encrypted: EncryptedMessage::encrypt(crypto, session_key, body_bytes, &aad, nonce)?, - }) + Ready { + encrypted: EncryptedMessage::encrypt(crypto, session_key, body_bytes, &aad, nonce), + } } pub fn decrypt_ready( diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 117c0ab7..74e59212 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -47,7 +47,7 @@ pub trait QlCrypto { nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]>; + ) -> [u8; EncryptedMessage::AUTH_SIZE]; fn decrypt_with_aead( &self, diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs index 045e7b78..69e4d9c7 100644 --- a/ql-wire/src/pair/crypto.rs +++ b/ql-wire/src/pair/crypto.rs @@ -12,9 +12,8 @@ pub fn build_pair_request( recipient: XID, recipient_encapsulation_key: &MlKemPublicKey, meta: ControlMeta, -) -> Result { - let (session_key, kem_ct) = - recipient_encapsulation_key.encapsulate_new_shared_secret(crypto)?; +) -> QlRecord { + let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(crypto); let header = QlHeader { sender: identity.xid, recipient, @@ -30,7 +29,7 @@ pub fn build_pair_request( &signing_pub_key, &sender_encapsulation_key, ); - let proof = identity.signing_private_key.sign(crypto, &proof_data)?; + let proof = identity.signing_private_key.sign(crypto, &proof_data); let body = PairRequestBody { meta, xid: identity.xid, @@ -48,11 +47,11 @@ pub fn build_pair_request( body_bytes, &aad, crate::Nonce(nonce), - )?; - Ok(QlRecord { + ); + QlRecord { header, payload: QlPayload::PairRequest(super::PairRequestRecord { kem_ct, encrypted }), - }) + } } pub fn decrypt_pair_request( @@ -66,7 +65,7 @@ pub fn decrypt_pair_request( let aad = pairing_aad(header, &kem_ct); let session_key = identity .encapsulation_private_key - .decapsulate_shared_secret(&kem_ct)?; + .decapsulate_shared_secret(&kem_ct); let mut encrypted = crate::encrypted_message::EncryptedMessage::parse(&mut request.encrypted)?; let plaintext = crate::encrypted_message::EncryptedMessage::decrypt_in_place( &mut encrypted, diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index 76b594e1..0f2db1bb 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -1,7 +1,7 @@ use libcrux_ml_dsa::{ml_dsa_87, KEY_GENERATION_RANDOMNESS_SIZE, SIGNING_RANDOMNESS_SIZE}; use libcrux_ml_kem::{mlkem1024, KEY_GENERATION_SEED_SIZE, SHARED_SECRET_SIZE}; -use crate::{QlCrypto, WireError}; +use crate::QlCrypto; pub(crate) const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; @@ -44,17 +44,15 @@ impl MlDsaPrivateKey { self.0.as_ref() } - pub fn sign( - &self, - crypto: &impl QlCrypto, - message: &[u8], - ) -> Result { + pub fn sign(&self, crypto: &impl QlCrypto, message: &[u8]) -> MlDsaSignature { let mut randomness = [0u8; SIGNING_RANDOMNESS_SIZE]; crypto.fill_random_bytes(&mut randomness); let signing_key = ml_dsa_87::MLDSA87SigningKey::new(*self.as_bytes()); + // Safe: we always sign with the empty context, so the only remaining + // error is libcrux's negligible-probability rejection-sampling failure. let signature = ml_dsa_87::sign(&signing_key, message, b"", randomness) - .map_err(|_| WireError::SigningFailed)?; - Ok(MlDsaSignature::from_data(*signature.as_ref())) + .expect("ML-DSA signing should not fail"); + MlDsaSignature::from_data(*signature.as_ref()) } } @@ -117,15 +115,15 @@ impl MlKemPublicKey { pub fn encapsulate_new_shared_secret( &self, crypto: &impl QlCrypto, - ) -> Result<(SessionKey, MlKemCiphertext), WireError> { + ) -> (SessionKey, MlKemCiphertext) { let mut randomness = [0u8; SHARED_SECRET_SIZE]; crypto.fill_random_bytes(&mut randomness); let public_key = mlkem1024::MlKem1024PublicKey::from(self.as_bytes()); let (ciphertext, shared_secret) = mlkem1024::encapsulate(&public_key, randomness); - Ok(( + ( SessionKey::from_data(shared_secret), MlKemCiphertext::from_data(*ciphertext.as_slice()), - )) + ) } } @@ -143,21 +141,18 @@ impl MlKemPrivateKey { self.0.as_ref() } - pub fn decapsulate_shared_secret( - &self, - ciphertext: &MlKemCiphertext, - ) -> Result { + pub fn decapsulate_shared_secret(&self, ciphertext: &MlKemCiphertext) -> SessionKey { self.decapsulate_shared_secret_bytes(ciphertext.as_bytes()) } pub fn decapsulate_shared_secret_bytes( &self, ciphertext: &[u8; MlKemCiphertext::SIZE], - ) -> Result { + ) -> SessionKey { let private_key = mlkem1024::MlKem1024PrivateKey::from(self.as_bytes()); let ciphertext = mlkem1024::MlKem1024Ciphertext::from(ciphertext); let shared_secret = mlkem1024::decapsulate(&private_key, &ciphertext); - Ok(SessionKey::from_data(shared_secret)) + SessionKey::from_data(shared_secret) } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 929c3773..8a536224 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -35,7 +35,7 @@ impl QlCrypto for TestCrypto { nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - ) -> Option<[u8; EncryptedMessage::AUTH_SIZE]> { + ) -> [u8; EncryptedMessage::AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); let plaintext = buffer.to_vec(); let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; @@ -46,8 +46,8 @@ impl QlCrypto for TestCrypto { aad, &plaintext, ) - .ok()?; - Some(auth) + .unwrap(); + auth } fn decrypt_with_aead( @@ -92,8 +92,7 @@ fn encrypted_session_record_round_trip_and_decrypt() { &session_key, &body, Nonce([8; Nonce::SIZE]), - ) - .unwrap(); + ); let bytes = record.encode(); let decoded = QlRecord::decode(&bytes).unwrap(); @@ -145,8 +144,7 @@ fn pair_request_round_trip_and_decrypt() { recipient.xid, &recipient.encapsulation_public_key, meta, - ) - .unwrap(); + ); let mut bytes = record.encode(); let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); @@ -178,8 +176,7 @@ fn ready_round_trip_and_decrypt() { &session_key, meta, Nonce([12; Nonce::SIZE]), - ) - .unwrap(); + ); let record = QlRecord { header, payload: QlPayload::Ready(ready), From a4cfe3133e90ff15647e6ebfa5a7ad7eabc6b1cf Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 23 Mar 2026 11:52:37 -0400 Subject: [PATCH 021/304] ql: surface inbound message errors --- ql-fsm/src/error.rs | 4 ++ ql-fsm/src/implementation/fsm.rs | 16 +++----- ql-fsm/src/implementation/handshake.rs | 56 +++++++------------------- ql-fsm/src/implementation/peer.rs | 9 ++--- ql-fsm/src/tests/handshake.rs | 45 ++++++++++++++++++++- ql-runtime/src/lib.rs | 6 +++ 6 files changed, 78 insertions(+), 58 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 454b1554..5a436812 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -13,6 +13,8 @@ pub enum QlFsmError { Expired, #[error("decryption failed")] DecryptFailed, + #[error("invalid xid")] + InvalidXid, #[error("missing stream")] MissingStream, #[error("stream is not writable")] @@ -21,6 +23,8 @@ pub enum QlFsmError { SessionClosed, #[error("no peer bound")] NoPeerBound, + #[error("no active session")] + NoSession, } impl From for QlFsmError { diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index b7e4c8c6..c6c4fe4a 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -12,14 +12,14 @@ pub fn receive( let wire::QlRecordRef { header, payload } = wire::QlRecord::parse_mut(&mut bytes)?; if header.recipient != fsm.identity.xid { - return Ok(()); + return Err(QlFsmError::InvalidXid); } if !matches!(&payload, QlPayloadRef::PairRequest(_)) { let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.xid) else { - return Ok(()); + return Err(QlFsmError::NoPeerBound); }; if header.sender != peer { - return Ok(()); + return Err(QlFsmError::InvalidXid); } } @@ -41,14 +41,10 @@ pub fn receive( } QlPayloadRef::Session(mut encrypted) => { let Some((_, session_key)) = super::peer_session(fsm) else { - return Ok(()); - }; - let envelope = match wire::decrypt_record(crypto, &header, &mut encrypted, &session_key) - .and_then(|envelope| wire::SessionEnvelope::from_wire(&envelope)) - { - Ok(envelope) => envelope, - Err(_) => return Ok(()), + return Err(QlFsmError::NoSession); }; + let envelope = wire::decrypt_record(crypto, &header, &mut encrypted, &session_key)?; + let envelope = wire::SessionEnvelope::from_wire(&envelope)?; fsm.session.receive(fsm.state.now.instant, envelope); super::drain_session_events(fsm); } diff --git a/ql-fsm/src/implementation/handshake.rs b/ql-fsm/src/implementation/handshake.rs index ebd5415f..d7f79ba0 100644 --- a/ql-fsm/src/implementation/handshake.rs +++ b/ql-fsm/src/implementation/handshake.rs @@ -55,19 +55,6 @@ pub fn handle_hello( let Some(entry) = fsm.peer.as_ref() else { return Ok(()); }; - if wire::verify_hello( - crypto, - header.sender, - fsm.identity.xid, - &entry.peer.signing_key, - hello, - fsm.state.now.unix_secs, - ) - .is_err() - { - return Ok(()); - } - match &entry.session { ConnectionState::Initiator { hello: local_hello, .. @@ -97,6 +84,15 @@ pub fn handle_hello( } } }; + let peer = fsm.peer.as_ref().map(|entry| entry.peer.clone()).unwrap(); + wire::verify_hello( + crypto, + header.sender, + fsm.identity.xid, + &peer.signing_key, + hello, + fsm.state.now.unix_secs, + )?; match action { HelloAction::Ignore => {} @@ -108,9 +104,8 @@ pub fn handle_hello( return Ok(()); } - let peer = fsm.peer.as_ref().map(|entry| entry.peer.clone()).unwrap(); let reply_meta = next_control_meta(fsm, fsm.config.handshake_timeout); - let responder = wire::respond_hello( + let (reply, secrets) = wire::respond_hello( crypto, &fsm.identity, peer.xid, @@ -119,18 +114,7 @@ pub fn handle_hello( hello, reply_meta, fsm.state.now.unix_secs, - ); - - let (reply, secrets) = match responder { - Ok(result) => result, - Err(_) => { - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Disconnected; - } - emit_peer_status(fsm); - return Ok(()); - } - }; + )?; let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); @@ -202,7 +186,7 @@ pub fn handle_hello_reply( responder_signing_key, } => { let confirm_meta = next_control_meta(fsm, fsm.config.handshake_timeout); - let (confirm, session_key) = match wire::build_confirm( + let (confirm, session_key) = wire::build_confirm( crypto, &fsm.identity, header.sender, @@ -212,10 +196,7 @@ pub fn handle_hello_reply( &initiator_secret, confirm_meta, fsm.state.now.unix_secs, - ) { - Ok(result) => result, - Err(_) => return Ok(()), - }; + )?; if is_replayed_control(fsm, header.sender, wire::ControlMeta::from_wire(reply.meta)) { return Ok(()); @@ -282,10 +263,7 @@ pub fn handle_confirm( .map(|session_key| (hello.clone(), reply.clone(), *deadline, session_key)) }; - let (hello, reply, deadline, session_key) = match outcome { - Ok(result) => result, - Err(_) => return Ok(()), - }; + let (hello, reply, deadline, session_key) = outcome?; if is_replayed_control( fsm, @@ -342,11 +320,7 @@ pub fn handle_ready( } }; - let body = - match wire::decrypt_ready(crypto, header, ready, &session_key, fsm.state.now.unix_secs) { - Ok(body) => body, - Err(_) => return Ok(()), - }; + let body = wire::decrypt_ready(crypto, header, ready, &session_key, fsm.state.now.unix_secs)?; if is_replayed_control(fsm, header.sender, body.meta) { return Ok(()); } diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs index 001f238c..bccf112c 100644 --- a/ql-fsm/src/implementation/peer.rs +++ b/ql-fsm/src/implementation/peer.rs @@ -27,16 +27,13 @@ pub fn handle_pair( header: &QlHeader, request: &mut RefMut<'_, PairRequestRecordWire>, ) -> Result<(), QlFsmError> { - let payload = match wire::decrypt_pair_request( + let payload = wire::decrypt_pair_request( crypto, &fsm.identity, header, request, fsm.state.now.unix_secs, - ) { - Ok(payload) => payload, - Err(_) => return Ok(()), - }; + )?; let peer = Peer { xid: payload.xid, signing_key: payload.signing_pub_key, @@ -47,7 +44,7 @@ pub fn handle_pair( } match fsm.peer.as_ref() { - Some(existing) if existing.peer != peer => return Ok(()), + Some(existing) if existing.peer != peer => return Err(QlFsmError::InvalidXid), Some(_) => {} None => bind_peer_record(fsm, peer.clone()), } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 2e573572..02ab2e58 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use ql_wire::QlPayload; +use ql_wire::{QlPayload, XID}; use super::*; use crate::state::{ConnectionState, HandshakeInitiator, HandshakeResponder}; @@ -313,3 +313,46 @@ fn simultaneous_connect_converges_to_connected_peers() { Some(ConnectionState::Connected { .. }) )); } + +#[test] +fn receive_surfaces_invalid_xid_for_wrong_recipient() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let mut hello = harness.next_outbound_a().unwrap(); + hello.header.recipient = XID([0xAA; XID::SIZE]); + + assert_eq!( + harness + .b + .fsm + .receive(harness.time(), hello.encode(), &harness.b.crypto), + Err(crate::QlFsmError::InvalidXid) + ); +} + +#[test] +fn receive_surfaces_invalid_signature_for_tampered_hello() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + harness + .a + .fsm + .connect(harness.time(), &harness.a.crypto) + .unwrap(); + let hello = harness.next_outbound_a().unwrap(); + let mut bytes = hello.encode(); + *bytes.last_mut().unwrap() ^= 0x01; + + assert_eq!( + harness + .b + .fsm + .receive(harness.time(), bytes, &harness.b.crypto), + Err(crate::QlFsmError::InvalidSignature) + ); +} diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index d1a58564..fcc93cc3 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -26,6 +26,8 @@ pub enum QlError { Expired, #[error("decryption failed")] DecryptFailed, + #[error("invalid xid")] + InvalidXid, #[error("missing stream")] MissingStream, #[error("stream is not writable")] @@ -34,6 +36,8 @@ pub enum QlError { SessionClosed, #[error("no peer bound")] NoPeerBound, + #[error("no active session")] + NoSession, #[error("send failed")] SendFailed, #[error("stream closed {code:?}")] @@ -53,10 +57,12 @@ impl From for QlError { QlFsmError::InvalidSignature => Self::InvalidSignature, QlFsmError::Expired => Self::Expired, QlFsmError::DecryptFailed => Self::DecryptFailed, + QlFsmError::InvalidXid => Self::InvalidXid, QlFsmError::MissingStream => Self::MissingStream, QlFsmError::NotWritable => Self::NotWritable, QlFsmError::SessionClosed => Self::SessionClosed, QlFsmError::NoPeerBound => Self::NoPeerBound, + QlFsmError::NoSession => Self::NoSession, } } } From e39f3e5a3d72e21bc1cb2616876ebe131f7419c7 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 23 Mar 2026 12:10:50 -0400 Subject: [PATCH 022/304] ql: clean up handle stream frame --- ql-fsm/src/session/mod.rs | 67 +++++++++++++++------------------------ 1 file changed, 26 insertions(+), 41 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index ab78203f..e22b28a6 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -737,8 +737,15 @@ impl SessionFsm { } fn handle_stream_frame(&mut self, frame: StreamChunk) { - let stream_id = frame.stream_id; + let StreamChunk { + stream_id, + offset, + bytes, + fin, + } = frame; + let frame_end = offset + bytes.len() as u64; let remote_namespace = self.config.local_namespace.remote(); + if !self.state.streams.contains_key(&stream_id) { if !remote_namespace.matches(stream_id) { self.fail_session(SessionCloseBody { @@ -746,7 +753,7 @@ impl SessionFsm { }); return; } - if frame.offset != 0 { + if offset != 0 { return; } self.state @@ -758,24 +765,21 @@ impl SessionFsm { let Some(stream) = self.state.streams.get_mut(&stream_id) else { return; }; - if matches!(stream.inbound_state, InboundState::Discarding) { - return; - } - if matches!( - stream.inbound_state, - InboundState::Closed(_) | InboundState::Finished - ) { - if frame.offset + frame.bytes.len() as u64 <= stream.next_recv_offset { + match stream.inbound_state { + InboundState::Open => (), + InboundState::Finished | InboundState::Closed(_) => { + if frame_end <= stream.next_recv_offset { + return; + } + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); return; } - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; + InboundState::Discarding => return, } - if frame.offset < stream.next_recv_offset { - let frame_end = frame.offset + frame.bytes.len() as u64; + if offset < stream.next_recv_offset { if frame_end <= stream.next_recv_offset { return; } @@ -785,20 +789,16 @@ impl SessionFsm { return; } - if frame.offset == stream.next_recv_offset { + if offset == stream.next_recv_offset { let was_readable = !stream.recv_buf.is_empty(); - let was_finished = matches!(stream.inbound_state, InboundState::Finished); - Self::commit_inbound_frame(stream, frame); + Self::commit_inbound_chunk(stream, bytes, fin); Self::drain_pending_recv(stream); - let became_readable = !was_readable && !stream.recv_buf.is_empty(); - let became_finished = - !was_finished && matches!(stream.inbound_state, InboundState::Finished); - if became_readable { + if !was_readable && !stream.recv_buf.is_empty() { self.state .events .push_back(SessionEvent::Readable(stream_id)); } - if became_finished { + if matches!(stream.inbound_state, InboundState::Finished) { self.state .events .push_back(SessionEvent::Finished(stream_id)); @@ -807,16 +807,7 @@ impl SessionFsm { return; } - if Self::insert_pending_chunk( - stream, - frame.offset, - PendingRxChunk { - bytes: frame.bytes, - fin: frame.fin, - }, - ) - .is_err() - { + if Self::insert_pending_chunk(stream, offset, PendingRxChunk { bytes, fin }).is_err() { self.fail_session(SessionCloseBody { code: CloseCode::PROTOCOL, }); @@ -877,15 +868,9 @@ impl SessionFsm { matches!(target, CloseTarget::Both) || role.outbound_target() == target } - fn commit_inbound_frame(stream: &mut StreamState, frame: StreamChunk) { - Self::commit_inbound_chunk(stream, frame.bytes, frame.fin); - } - fn commit_inbound_chunk(stream: &mut StreamState, bytes: Vec, fin: bool) { stream.next_recv_offset += bytes.len() as u64; - if !bytes.is_empty() { - stream.recv_buf.extend(bytes); - } + stream.recv_buf.extend(bytes); if fin { stream.inbound_state = InboundState::Finished; } From e70f829ab0cb2a9063f164168986d10279762cac Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 23 Mar 2026 12:15:39 -0400 Subject: [PATCH 023/304] ql: clean up test only methods --- ql-fsm/src/session/mod.rs | 12 ------- ql-fsm/src/session/tests.rs | 64 +++++++++++++++++++------------------ 2 files changed, 33 insertions(+), 43 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index e22b28a6..bde271a1 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -425,13 +425,6 @@ impl SessionFsm { entry.state = TxState::Pending; } - #[cfg(test)] - pub fn next_outbound(&mut self, now: Instant) -> Option { - let envelope = self.take_next_write(now)?; - self.confirm_write(now, envelope.seq); - Some(envelope) - } - pub fn on_timer(&mut self, now: Instant) { self.state.now = now; self.collect_timeouts(); @@ -495,11 +488,6 @@ impl SessionFsm { self.state.events.pop_front() } - #[cfg(test)] - pub fn session_state(&self) -> SessionState { - self.state.session_state - } - pub fn has_pending_stream_work(&self) -> bool { self.state.streams.values().any(|stream| { stream.pending_close.is_some() diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index d41fa7bc..653d9b12 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -36,6 +36,12 @@ fn ping(seq: u64, ack: SessionAck) -> SessionEnvelope { } } +fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option { + let envelope = fsm.take_next_write(now)?; + fsm.confirm_write(now, envelope.seq); + Some(envelope) +} + #[test] fn outbound_session_seq_increments_monotonically() { let now = Instant::now(); @@ -43,7 +49,7 @@ fn outbound_session_seq_increments_monotonically() { let stream_id = fsm.open_stream().unwrap(); fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); - let first = fsm.next_outbound(now).unwrap(); + let first = next_outbound(&mut fsm, now).unwrap(); fsm.receive( now + Duration::from_millis(1), @@ -57,7 +63,7 @@ fn outbound_session_seq_increments_monotonically() { ); fsm.write_stream(stream_id, b"two".to_vec()).unwrap(); - let second = fsm.next_outbound(now + Duration::from_millis(2)).unwrap(); + let second = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); assert_eq!(first.seq, SessionSeq(1)); assert_eq!(second.seq, SessionSeq(2)); @@ -70,7 +76,7 @@ fn inbound_ack_removes_acked_tx_entries() { let stream_id = fsm.open_stream().unwrap(); fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); - let first = fsm.next_outbound(now).unwrap(); + let first = next_outbound(&mut fsm, now).unwrap(); assert_eq!(first.seq, SessionSeq(1)); assert!(fsm.state.tx_ring.contains_key(&SessionSeq(1))); @@ -108,7 +114,7 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { }), }, ); - let gap_ack = fsm.next_outbound(now).unwrap(); + let gap_ack = next_outbound(&mut fsm, now).unwrap(); assert_eq!(gap_ack.seq, SessionSeq(1)); assert_eq!( gap_ack.ack, @@ -131,7 +137,7 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { }), }, ); - let contiguous_ack = fsm.next_outbound(now + Duration::from_millis(10)).unwrap(); + let contiguous_ack = next_outbound(&mut fsm, now + Duration::from_millis(10)).unwrap(); assert_eq!(contiguous_ack.seq, SessionSeq(2)); assert_eq!( contiguous_ack.ack, @@ -149,10 +155,10 @@ fn retransmit_reuses_session_seq() { let stream_id = fsm.open_stream().unwrap(); fsm.write_stream(stream_id, b"retry-me".to_vec()).unwrap(); - let first = fsm.next_outbound(now).unwrap(); + let first = next_outbound(&mut fsm, now).unwrap(); let retransmit_at = now + Duration::from_millis(200); - let retried = fsm.next_outbound(retransmit_at).unwrap(); + let retried = next_outbound(&mut fsm, retransmit_at).unwrap(); assert_eq!(first.seq, SessionSeq(1)); assert_eq!(retried.seq, SessionSeq(1)); @@ -169,10 +175,10 @@ fn repeated_outbound_messages_keep_reporting_latest_receive_ack() { fsm.receive(now, ack(1, SessionAck::EMPTY)); fsm.write_stream(stream_id_a, b"one".to_vec()).unwrap(); - let first = fsm.next_outbound(now).unwrap(); + let first = next_outbound(&mut fsm, now).unwrap(); fsm.write_stream(stream_id_b, b"two".to_vec()).unwrap(); - let second = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); + let second = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); assert_eq!(first.ack.base, SessionSeq(1)); assert_eq!(second.ack.base, SessionSeq(1)); @@ -208,7 +214,7 @@ fn local_inbound_close_ignores_late_remote_bytes() { }, ); - assert_eq!(fsm.session_state(), SessionState::Open); + assert_eq!(fsm.state.session_state, SessionState::Open); assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); assert!(fsm.take_next_event().is_none()); } @@ -233,7 +239,7 @@ fn missing_stream_nonzero_offset_is_ignored_until_offset_zero_arrives() { }, ); - assert_eq!(fsm.session_state(), SessionState::Open); + assert_eq!(fsm.state.session_state, SessionState::Open); assert!(fsm.take_next_event().is_none()); assert!(!fsm.state.streams.contains_key(&stream_id)); @@ -276,7 +282,7 @@ fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { fsm.write_stream(stream_id, b"hello".to_vec()).unwrap(); - let first = fsm.next_outbound(now).unwrap(); + let first = next_outbound(&mut fsm, now).unwrap(); assert_eq!( first.body, SessionBody::Stream(StreamChunk { @@ -286,7 +292,7 @@ fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { fin: false, }) ); - assert!(fsm.next_outbound(now + Duration::from_millis(1)).is_none()); + assert!(next_outbound(&mut fsm, now + Duration::from_millis(1)).is_none()); fsm.receive( now + Duration::from_millis(2), @@ -299,7 +305,7 @@ fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { ), ); - let second = fsm.next_outbound(now + Duration::from_millis(3)).unwrap(); + let second = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); assert_eq!( second.body, SessionBody::Stream(StreamChunk { @@ -347,7 +353,7 @@ fn stream_is_reaped_after_terminal_state_and_last_stream_ack() { assert!(fsm.state.streams.contains_key(&stream_id)); fsm.finish_stream(stream_id).unwrap(); - let fin = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); + let fin = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); assert_eq!( fin.body, SessionBody::Stream(StreamChunk { @@ -406,7 +412,7 @@ fn replayed_remote_open_does_not_recreate_reaped_stream() { ); fsm.finish_stream(stream_id).unwrap(); - let fin = fsm.next_outbound(now + Duration::from_millis(1)).unwrap(); + let fin = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); assert_eq!( fin.body, SessionBody::Stream(StreamChunk { @@ -432,7 +438,7 @@ fn replayed_remote_open_does_not_recreate_reaped_stream() { fsm.receive(now + Duration::from_millis(3), opener); - assert_eq!(fsm.session_state(), SessionState::Open); + assert_eq!(fsm.state.session_state, SessionState::Open); assert!(!fsm.state.streams.contains_key(&stream_id)); assert!(fsm.take_next_event().is_none()); } @@ -493,7 +499,7 @@ fn next_outbound_round_robins_across_ready_streams() { fsm.write_stream(stream_id_b, b"b-2".to_vec()).unwrap(); let first_round: Vec<_> = (0..2) - .map(|_| match fsm.next_outbound(now).unwrap().body { + .map(|_| match next_outbound(&mut fsm, now).unwrap().body { SessionBody::Stream(frame) => frame.stream_id, other => panic!("expected stream frame, got {other:?}"), }) @@ -512,8 +518,7 @@ fn next_outbound_round_robins_across_ready_streams() { let second_round: Vec<_> = (0..2) .map(|_| { - match fsm - .next_outbound(now + Duration::from_millis(2)) + match next_outbound(&mut fsm, now + Duration::from_millis(2)) .unwrap() .body { @@ -539,10 +544,10 @@ fn idle_session_sends_ping_after_keepalive_interval() { ); assert_eq!(fsm.next_deadline(), Some(now + Duration::from_millis(50))); - assert!(fsm.next_outbound(now + Duration::from_millis(49)).is_none()); + assert!(next_outbound(&mut fsm, now + Duration::from_millis(49)).is_none()); fsm.on_timer(now + Duration::from_millis(50)); - let envelope = fsm.next_outbound(now + Duration::from_millis(50)).unwrap(); + let envelope = next_outbound(&mut fsm, now + Duration::from_millis(50)).unwrap(); assert!(matches!(envelope.body, SessionBody::Ping(PingBody))); } @@ -553,11 +558,11 @@ fn receive_ping_schedules_ack_without_ping_pong() { fsm.receive(now, ping(1, SessionAck::EMPTY)); - let ack_envelope = fsm.next_outbound(now + Duration::from_millis(10)).unwrap(); + let ack_envelope = next_outbound(&mut fsm, now + Duration::from_millis(10)).unwrap(); assert_eq!(ack_envelope.body, SessionBody::Ack); fsm.receive(now + Duration::from_millis(20), ack(2, SessionAck::EMPTY)); - assert!(fsm.next_outbound(now + Duration::from_millis(30)).is_none()); + assert!(next_outbound(&mut fsm, now + Duration::from_millis(30)).is_none()); } #[test] @@ -568,8 +573,7 @@ fn tx_selective_ack_keeps_front_gap_pinned() { for (byte, stream_id) in (0..64u8).zip(stream_ids.iter().copied()) { fsm.write_stream(stream_id, vec![byte]).unwrap(); - let _ = fsm - .next_outbound(now + Duration::from_millis(byte as u64)) + let _ = next_outbound(&mut fsm, now + Duration::from_millis(byte as u64)) .unwrap(); } @@ -589,9 +593,7 @@ fn tx_selective_ack_keeps_front_gap_pinned() { let extra_stream = fsm.open_stream().unwrap(); fsm.write_stream(extra_stream, b"x".to_vec()).unwrap(); - assert!(fsm - .next_outbound(now + Duration::from_millis(101)) - .is_none()); + assert!(next_outbound(&mut fsm, now + Duration::from_millis(101)).is_none()); fsm.receive( now + Duration::from_millis(102), @@ -605,7 +607,7 @@ fn tx_selective_ack_keeps_front_gap_pinned() { ); assert_eq!( - fsm.next_outbound(now + Duration::from_millis(103)) + next_outbound(&mut fsm, now + Duration::from_millis(103)) .unwrap() .seq, SessionSeq(65) @@ -619,7 +621,7 @@ fn rx_seq_past_window_closes_protocol() { fsm.receive(now, ping(65, SessionAck::EMPTY)); - assert_eq!(fsm.session_state(), SessionState::Closed); + assert_eq!(fsm.state.session_state, SessionState::Closed); assert!(matches!( fsm.take_next_event(), Some(super::SessionEvent::SessionClosed(close)) if close.code == CloseCode::PROTOCOL From 9f05937d2dbde2c325cc4a43fcbe635d0d187008 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 24 Mar 2026 09:11:35 -0400 Subject: [PATCH 024/304] ql: add todos --- ql-fsm/src/implementation/fsm.rs | 1 + ql-runtime/src/driver.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index c6c4fe4a..a8ba04e9 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -44,6 +44,7 @@ pub fn receive( return Err(QlFsmError::NoSession); }; let envelope = wire::decrypt_record(crypto, &header, &mut encrypted, &session_key)?; + // TODO: this seems unnecessary to me? let envelope = wire::SessionEnvelope::from_wire(&envelope)?; fsm.session.receive(fsm.state.now.instant, envelope); super::drain_session_events(fsm); diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 79edcdd3..4aeae179 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -215,6 +215,7 @@ impl DriverState { self.finish_step(platform, in_flight); } RuntimeCommand::Incoming(bytes) => { + // TODO: surface these errors somehow? let _ = self.fsm.receive(now(), bytes, platform); self.finish_step(platform, in_flight); } From 51e2315963d604dfdff87e9a77cce6755d5b1ffc Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 24 Mar 2026 10:23:06 -0400 Subject: [PATCH 025/304] ql: stream numbered chunks instead of byte offset --- ql-fsm/src/session/mod.rs | 221 +++++++++++--------------- ql-fsm/src/session/state.rs | 34 ++-- ql-fsm/src/session/stream_window.rs | 71 +++++++++ ql-fsm/src/session/tests.rs | 147 +++++++++++++++-- ql-wire/src/encrypted/stream_chunk.rs | 10 +- ql-wire/src/tests.rs | 6 +- 6 files changed, 317 insertions(+), 172 deletions(-) create mode 100644 ql-fsm/src/session/stream_window.rs diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index bde271a1..739b02de 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,24 +1,28 @@ pub(crate) mod ring; pub(crate) mod state; +pub(crate) mod stream_window; #[cfg(test)] mod tests; use std::time::{Duration, Instant}; +use indexmap::map::Entry; use ql_wire::{ CloseCode, CloseTarget, PingBody, SessionBody, SessionCloseBody, SessionEnvelope, SessionSeq, StreamChunk, StreamClose, StreamId, UnpairBody, XID, }; use self::{ - ring::SeqRingInsertError, state::{ - AckState, InboundState, OutboundState, PendingRxChunk, PendingSessionBody, SessionFsmState, + AckState, InboundState, OutboundState, PendingSessionBody, SessionFsmState, StreamOpenState, StreamRole, StreamState, TxEntry, TxState, }, + stream_window::{RecvInsertOutcome, RxChunk}, }; +struct RejectNoAck; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamNamespace { Low, @@ -278,35 +282,22 @@ impl SessionFsm { } return; } - match self.state.rx_ring.insert(seq, ()) { - Ok(()) => { - let out_of_order = seq != self.state.rx_ring.base_seq(); - self.state.rx_ring.advance_occupied_front(); - if !matches!(envelope.body, SessionBody::Ack) { - self.schedule_ack(out_of_order); - } - } - Err(SeqRingInsertError::OutOfWindow) => { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; - } - Err(SeqRingInsertError::Occupied) => { - if !matches!(envelope.body, SessionBody::Ack) { - self.schedule_ack(true); - } - return; - } + if !self.state.rx_ring.accepts_seq(seq) { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return; } - match envelope.body { - SessionBody::Ack => {} - SessionBody::Ping(_) => {} + let out_of_order = seq != self.state.rx_ring.base_seq(); + let body_kind_is_ack = matches!(envelope.body, SessionBody::Ack); + let apply_inbound_body = match envelope.body { + SessionBody::Ack | SessionBody::Ping(_) => Ok(()), SessionBody::Unpair(_) => { self.state.session_state = SessionState::Closed; self.clear_streams(); self.state.events.push_back(SessionEvent::Unpaired); + Ok(()) } SessionBody::Close(close) => { self.state.session_state = SessionState::Closed; @@ -314,9 +305,28 @@ impl SessionFsm { self.state .events .push_back(SessionEvent::SessionClosed(close)); + Ok(()) } SessionBody::Stream(frame) => self.handle_stream_frame(frame), - SessionBody::StreamClose(frame) => self.handle_stream_close(frame), + SessionBody::StreamClose(frame) => { + self.handle_stream_close(frame); + Ok(()) + } + }; + if apply_inbound_body.is_err() { + return; + } + + match self.state.rx_ring.insert(seq, ()) { + Ok(()) => { + self.state.rx_ring.advance_occupied_front(); + if !body_kind_is_ack { + self.schedule_ack(out_of_order); + } + } + Err(e) => { + unreachable!("seq window was pre-validated before body handling {e:?}"); + } } } @@ -607,7 +617,7 @@ impl SessionFsm { } let (stream_id, opens_stream) = match &entry.pending.body { - SessionBody::Stream(frame) => (Some(frame.stream_id), frame.offset == 0), + SessionBody::Stream(frame) => (Some(frame.stream_id), frame.chunk_seq == 0), SessionBody::StreamClose(frame) => (Some(frame.stream_id), false), _ => (None, false), }; @@ -714,7 +724,7 @@ impl SessionFsm { .is_some_and(|stream| { !matches!(stream.outbound_state, OutboundState::Closed) || (matches!(stream.open_state, StreamOpenState::WaitingForAck) - && frame.offset == 0) + && frame.chunk_seq == 0) }) } SessionBody::StreamClose(frame) => { @@ -724,81 +734,71 @@ impl SessionFsm { } } - fn handle_stream_frame(&mut self, frame: StreamChunk) { + fn handle_stream_frame(&mut self, frame: StreamChunk) -> Result<(), RejectNoAck> { let StreamChunk { stream_id, - offset, + chunk_seq, bytes, fin, } = frame; - let frame_end = offset + bytes.len() as u64; let remote_namespace = self.config.local_namespace.remote(); - - if !self.state.streams.contains_key(&stream_id) { - if !remote_namespace.matches(stream_id) { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; - } - if offset != 0 { - return; + let stream = match self.state.streams.entry(stream_id) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + if !remote_namespace.matches(stream_id) { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + return Ok(()); + } + if chunk_seq != 0 { + return Err(RejectNoAck); + } + self.state.events.push_back(SessionEvent::Opened(stream_id)); + entry.insert(StreamState::new(StreamRole::Responder)) } - self.state - .streams - .insert(stream_id, StreamState::new(StreamRole::Responder)); - self.state.events.push_back(SessionEvent::Opened(stream_id)); - } - - let Some(stream) = self.state.streams.get_mut(&stream_id) else { - return; }; match stream.inbound_state { InboundState::Open => (), InboundState::Finished | InboundState::Closed(_) => { - if frame_end <= stream.next_recv_offset { - return; + if chunk_seq < stream.recv_window.next_chunk_seq() { + return Ok(()); } self.fail_session(SessionCloseBody { code: CloseCode::PROTOCOL, }); - return; + return Ok(()); } - InboundState::Discarding => return, + InboundState::Discarding => return Ok(()), } - if offset < stream.next_recv_offset { - if frame_end <= stream.next_recv_offset { - return; - } - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); - return; - } + let was_readable = !stream.recv_buf.is_empty(); + let outcome = stream.recv_window.insert(chunk_seq, RxChunk { bytes, fin }); - if offset == stream.next_recv_offset { - let was_readable = !stream.recv_buf.is_empty(); - Self::commit_inbound_chunk(stream, bytes, fin); - Self::drain_pending_recv(stream); - if !was_readable && !stream.recv_buf.is_empty() { - self.state - .events - .push_back(SessionEvent::Readable(stream_id)); + match outcome { + RecvInsertOutcome::Inserted => { + Self::drain_recv_window(stream); + if !was_readable && !stream.recv_buf.is_empty() { + self.state + .events + .push_back(SessionEvent::Readable(stream_id)); + } + if matches!(stream.inbound_state, InboundState::Finished) { + self.state + .events + .push_back(SessionEvent::Finished(stream_id)); + } + self.try_reap_stream(stream_id); + Ok(()) } - if matches!(stream.inbound_state, InboundState::Finished) { - self.state - .events - .push_back(SessionEvent::Finished(stream_id)); + RecvInsertOutcome::Duplicate => Ok(()), + RecvInsertOutcome::RejectNoAck => Err(RejectNoAck), + RecvInsertOutcome::Conflict => { + self.fail_session(SessionCloseBody { + code: CloseCode::PROTOCOL, + }); + Ok(()) } - self.try_reap_stream(stream_id); - return; - } - - if Self::insert_pending_chunk(stream, offset, PendingRxChunk { bytes, fin }).is_err() { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); } } @@ -818,7 +818,7 @@ impl SessionFsm { { stream.inbound_state = InboundState::Closed(frame.clone()); stream.recv_buf.clear(); - stream.pending_recv.clear(); + stream.recv_window.clear(); self.state .events .push_back(SessionEvent::Closed(frame.clone())); @@ -840,7 +840,7 @@ impl SessionFsm { if Self::target_affects_inbound(stream.role, target) { stream.inbound_state = InboundState::Discarding; stream.recv_buf.clear(); - stream.pending_recv.clear(); + stream.recv_window.clear(); } if Self::target_affects_outbound(stream.role, target) { stream.outbound_state = OutboundState::Closed; @@ -856,50 +856,17 @@ impl SessionFsm { matches!(target, CloseTarget::Both) || role.outbound_target() == target } - fn commit_inbound_chunk(stream: &mut StreamState, bytes: Vec, fin: bool) { - stream.next_recv_offset += bytes.len() as u64; - stream.recv_buf.extend(bytes); - if fin { - stream.inbound_state = InboundState::Finished; - } - } - - fn drain_pending_recv(stream: &mut StreamState) { - while let Some(chunk) = stream.pending_recv.remove(&stream.next_recv_offset) { - Self::commit_inbound_chunk(stream, chunk.bytes, chunk.fin); - if matches!(stream.inbound_state, InboundState::Finished) { + fn drain_recv_window(stream: &mut StreamState) { + while let Some(chunk) = stream.recv_window.pop_contiguous() { + let RxChunk { bytes, fin } = chunk; + stream.recv_buf.extend(bytes); + if fin { + stream.inbound_state = InboundState::Finished; break; } } } - fn insert_pending_chunk( - stream: &mut StreamState, - offset: u64, - chunk: PendingRxChunk, - ) -> Result<(), ()> { - let end = chunk.end_offset(offset); - - if let Some((&prev_offset, prev)) = stream.pending_recv.range(..=offset).next_back() { - let prev_end = prev.end_offset(prev_offset); - if prev_end > offset { - if prev_offset == offset && prev.bytes == chunk.bytes && prev.fin == chunk.fin { - return Ok(()); - } - return Err(()); - } - } - - if let Some((&next_offset, _)) = stream.pending_recv.range(offset..).next() { - if end > next_offset { - return Err(()); - } - } - - stream.pending_recv.insert(offset, chunk); - Ok(()) - } - fn take_stream_frame( stream: &mut StreamState, stream_id: StreamId, @@ -918,11 +885,11 @@ impl SessionFsm { }; let frame = StreamChunk { stream_id, - offset: stream.next_send_offset, + chunk_seq: stream.next_send_chunk_seq, bytes, fin, }; - stream.next_send_offset += frame.bytes.len() as u64; + stream.next_send_chunk_seq += 1; return Some(frame); } @@ -930,7 +897,7 @@ impl SessionFsm { stream.outbound_state = OutboundState::Finished; return Some(StreamChunk { stream_id, - offset: stream.next_send_offset, + chunk_seq: stream.next_send_chunk_seq, bytes: Vec::new(), fin: true, }); @@ -956,7 +923,7 @@ impl SessionFsm { if !stream.send_buf.is_empty() || !stream.recv_buf.is_empty() - || !stream.pending_recv.is_empty() + || !stream.recv_window.is_empty() { return false; } diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 209a8700..557305d4 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -1,14 +1,16 @@ -use std::{ - collections::{BTreeMap, VecDeque}, - time::Instant, -}; +use std::{collections::VecDeque, time::Instant}; use indexmap::IndexMap; use ql_wire::{ CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamClose, StreamId, }; -use super::{ring::SeqRing, SessionEvent, SessionState}; +use super::{ + ring::SeqRing, + stream_window::StreamRecvWindow, + SessionEvent, + SessionState, +}; pub const SESSION_WINDOW_CAPACITY: usize = 64; @@ -34,18 +36,6 @@ impl StreamRole { } } -#[derive(Debug, Clone)] -pub struct PendingRxChunk { - pub bytes: Vec, - pub fin: bool, -} - -impl PendingRxChunk { - pub fn end_offset(&self, offset: u64) -> u64 { - offset + self.bytes.len() as u64 - } -} - #[derive(Debug, Clone)] pub enum OutboundState { Open, @@ -76,9 +66,8 @@ pub struct StreamState { pub send_buf: VecDeque, pub pending_close: Option, pub recv_buf: VecDeque, - pub pending_recv: BTreeMap, - pub next_send_offset: u64, - pub next_recv_offset: u64, + pub recv_window: StreamRecvWindow, + pub next_send_chunk_seq: u64, pub outbound_state: OutboundState, pub inbound_state: InboundState, } @@ -94,9 +83,8 @@ impl StreamState { send_buf: VecDeque::new(), pending_close: None, recv_buf: VecDeque::new(), - pending_recv: BTreeMap::new(), - next_send_offset: 0, - next_recv_offset: 0, + recv_window: StreamRecvWindow::new(), + next_send_chunk_seq: 0, outbound_state: OutboundState::Open, inbound_state: InboundState::Open, } diff --git a/ql-fsm/src/session/stream_window.rs b/ql-fsm/src/session/stream_window.rs new file mode 100644 index 00000000..e25b5e37 --- /dev/null +++ b/ql-fsm/src/session/stream_window.rs @@ -0,0 +1,71 @@ +use std::array; + +pub const STREAM_RECV_WINDOW_CAPACITY: usize = 8; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RxChunk { + pub bytes: Vec, + pub fin: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RecvInsertOutcome { + Inserted, + Duplicate, + RejectNoAck, + Conflict, +} + +#[derive(Debug)] +pub struct StreamRecvWindow { + next_chunk_seq: u64, + slots: [Option; STREAM_RECV_WINDOW_CAPACITY], +} + +impl StreamRecvWindow { + pub fn new() -> Self { + Self { + next_chunk_seq: 0, + slots: array::from_fn(|_| None), + } + } + + pub fn clear(&mut self) { + self.slots.fill(None); + } + + pub fn is_empty(&self) -> bool { + self.slots.iter().all(Option::is_none) + } + + pub fn next_chunk_seq(&self) -> u64 { + self.next_chunk_seq + } + + pub fn insert(&mut self, chunk_seq: u64, chunk: RxChunk) -> RecvInsertOutcome { + let Some(delta) = chunk_seq.checked_sub(self.next_chunk_seq) else { + return RecvInsertOutcome::Duplicate; + }; + if delta >= self.slots.len() as u64 { + return RecvInsertOutcome::RejectNoAck; + } + + let slot = &mut self.slots[delta as usize]; + match slot { + Some(existing) if *existing == chunk => RecvInsertOutcome::Duplicate, + Some(_) => RecvInsertOutcome::Conflict, + None => { + *slot = Some(chunk); + RecvInsertOutcome::Inserted + } + } + } + + pub fn pop_contiguous(&mut self) -> Option { + let chunk = self.slots[0].take()?; + self.next_chunk_seq += 1; + self.slots.rotate_left(1); + self.slots[self.slots.len() - 1] = None; + Some(chunk) + } +} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 653d9b12..c399f00f 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -108,7 +108,7 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id: stream_id_a, - offset: 0, + chunk_seq: 0, bytes: b"a".to_vec(), fin: false, }), @@ -131,7 +131,7 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id: stream_id_b, - offset: 0, + chunk_seq: 0, bytes: b"b".to_vec(), fin: false, }), @@ -207,7 +207,7 @@ fn local_inbound_close_ignores_late_remote_bytes() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: b"late".to_vec(), fin: false, }), @@ -220,7 +220,7 @@ fn local_inbound_close_ignores_late_remote_bytes() { } #[test] -fn missing_stream_nonzero_offset_is_ignored_until_offset_zero_arrives() { +fn missing_stream_nonzero_chunk_is_ignored_until_chunk_zero_arrives() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 7); @@ -232,7 +232,7 @@ fn missing_stream_nonzero_offset_is_ignored_until_offset_zero_arrives() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id, - offset: 1, + chunk_seq: 1, bytes: b"b".to_vec(), fin: false, }), @@ -250,7 +250,7 @@ fn missing_stream_nonzero_offset_is_ignored_until_offset_zero_arrives() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: b"a".to_vec(), fin: false, }), @@ -268,6 +268,125 @@ fn missing_stream_nonzero_offset_is_ignored_until_offset_zero_arrives() { assert_eq!(read_stream_all(&mut fsm, stream_id), b"a".to_vec()); } +#[test] +fn out_of_order_chunks_within_recv_window_are_buffered_and_drained() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 8); + + fsm.receive( + now, + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + chunk_seq: 0, + bytes: b"a".to_vec(), + fin: false, + }), + }, + ); + fsm.receive( + now + Duration::from_millis(1), + SessionEnvelope { + seq: SessionSeq(2), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + chunk_seq: 2, + bytes: b"c".to_vec(), + fin: false, + }), + }, + ); + fsm.receive( + now + Duration::from_millis(2), + SessionEnvelope { + seq: SessionSeq(3), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + chunk_seq: 1, + bytes: b"b".to_vec(), + fin: false, + }), + }, + ); + + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Opened(stream_id)) + ); + assert_eq!( + fsm.take_next_event(), + Some(super::SessionEvent::Readable(stream_id)) + ); + assert_eq!(read_stream_all(&mut fsm, stream_id), b"abc".to_vec()); + assert!(fsm.take_next_event().is_none()); +} + +#[test] +fn chunk_past_recv_window_is_dropped_without_session_ack() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionFsmConfig { + ack_delay: Duration::ZERO, + ..SessionFsmConfig::default() + }, + now, + ); + let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 10); + + fsm.receive( + now, + SessionEnvelope { + seq: SessionSeq(1), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + chunk_seq: 0, + bytes: b"a".to_vec(), + fin: false, + }), + }, + ); + + let ack = next_outbound(&mut fsm, now).unwrap(); + assert_eq!( + ack.ack, + SessionAck { + base: SessionSeq(1), + bitmap: 0, + } + ); + + fsm.receive( + now + Duration::from_millis(1), + SessionEnvelope { + seq: SessionSeq(2), + ack: SessionAck::EMPTY, + body: SessionBody::Stream(StreamChunk { + stream_id, + chunk_seq: 9, + bytes: b"z".to_vec(), + fin: false, + }), + }, + ); + + assert_eq!(fsm.state.rx_ring.base_seq(), SessionSeq(2)); + assert!(!fsm.state.rx_ring.contains_key(&SessionSeq(2))); + assert_eq!( + fsm.state.current_ack(), + SessionAck { + base: SessionSeq(1), + bitmap: 0, + } + ); + assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); +} + #[test] fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { let now = Instant::now(); @@ -287,7 +406,7 @@ fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { first.body, SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: b"he".to_vec(), fin: false, }) @@ -310,7 +429,7 @@ fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { second.body, SessionBody::Stream(StreamChunk { stream_id, - offset: 2, + chunk_seq: 1, bytes: b"ll".to_vec(), fin: false, }) @@ -330,7 +449,7 @@ fn stream_is_reaped_after_terminal_state_and_last_stream_ack() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: b"hi".to_vec(), fin: true, }), @@ -358,7 +477,7 @@ fn stream_is_reaped_after_terminal_state_and_last_stream_ack() { fin.body, SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: Vec::new(), fin: true, }) @@ -389,7 +508,7 @@ fn replayed_remote_open_does_not_recreate_reaped_stream() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: b"hi".to_vec(), fin: true, }), @@ -417,7 +536,7 @@ fn replayed_remote_open_does_not_recreate_reaped_stream() { fin.body, SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: Vec::new(), fin: true, }) @@ -450,7 +569,7 @@ fn duplicate_committed_data_is_not_redelivered() { let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 9); let body = SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: b"dup".to_vec(), fin: false, }); @@ -635,7 +754,7 @@ fn duplicate_old_packet_seq_is_ignored() { let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 11); let body = SessionBody::Stream(StreamChunk { stream_id, - offset: 0, + chunk_seq: 0, bytes: b"x".to_vec(), fin: false, }); diff --git a/ql-wire/src/encrypted/stream_chunk.rs b/ql-wire/src/encrypted/stream_chunk.rs index 37fd6872..2c41b416 100644 --- a/ql-wire/src/encrypted/stream_chunk.rs +++ b/ql-wire/src/encrypted/stream_chunk.rs @@ -11,7 +11,7 @@ use crate::{ #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamChunk { pub stream_id: StreamId, - pub offset: u64, + pub chunk_seq: u64, pub fin: bool, pub bytes: Vec, } @@ -20,7 +20,7 @@ pub struct StreamChunk { #[repr(C, packed)] pub struct StreamChunkWire { pub stream_id: U32Le, - pub offset: U64Le, + pub chunk_seq: U64Le, pub fin: u8, pub bytes: [u8], } @@ -33,7 +33,7 @@ impl StreamChunk { pub fn from_wire(wire: &StreamChunkWire) -> Result { Ok(StreamChunk { stream_id: StreamId(wire.stream_id.get()), - offset: wire.offset.get(), + chunk_seq: wire.chunk_seq.get(), bytes: wire.bytes.to_vec(), fin: crate::codec::read_byte(wire.fin)?, }) @@ -42,7 +42,7 @@ impl StreamChunk { pub fn encode_into(&self, out: &mut Vec) { let header = StreamChunkHeaderWire { stream_id: U32Le::new(self.stream_id.0), - offset: U64Le::new(self.offset), + chunk_seq: U64Le::new(self.chunk_seq), fin: u8::from(self.fin), }; push_value(out, &header); @@ -54,6 +54,6 @@ impl StreamChunk { #[repr(C)] pub struct StreamChunkHeaderWire { pub stream_id: U32Le, - pub offset: U64Le, + pub chunk_seq: U64Le, pub fin: u8, } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 8a536224..7283dfb9 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -80,7 +80,7 @@ fn encrypted_session_record_round_trip_and_decrypt() { }, body: SessionBody::Stream(StreamChunk { stream_id: StreamId(9), - offset: 11, + chunk_seq: 11, bytes: b"hello".to_vec(), fin: true, }), @@ -301,7 +301,7 @@ fn protocol_record_size_breakdown() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id: StreamId(1), - offset: 0, + chunk_seq: 0, fin: false, bytes: Vec::new(), }), @@ -315,7 +315,7 @@ fn protocol_record_size_breakdown() { ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id: StreamId(1), - offset: 0, + chunk_seq: 0, fin: true, bytes: Vec::new(), }), From c074c3840d83859d6477da9daa6225e45d75c252 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 24 Mar 2026 16:26:47 -0400 Subject: [PATCH 026/304] ql: design doc --- QL_V2.md | 124 ++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 6 deletions(-) diff --git a/QL_V2.md b/QL_V2.md index 30012271..fbe297fe 100644 --- a/QL_V2.md +++ b/QL_V2.md @@ -1,10 +1,19 @@ # QuantumLink V2 Design Document - QuantumLink V2 is a peer-to-peer protocol for authenticated, encrypted sessions carrying multiplexed byte streams. +QuantumLink V2 is a peer-to-peer protocol for authenticated, encrypted sessions carrying multiplexed byte streams. - It replaces QLv1's one-message-at-a-time model with explicit pairing, handshake, session, and stream state. +It replaces QLv1's one-message-at-a-time model with explicit pairing, handshake, session, and stream state. - QLv2 operates on complete QL records and leaves transport-specific framing, fragmentation, reassembly, and delivery behavior to platform adapters. +QLv2 operates on complete QL records and leaves transport-specific framing, fragmentation, reassembly, and delivery behavior to platform adapters. + +## Table of contents +- [Design goals](#design-goals) +- [Non-design goals](#non-design-goals) +- [Protocol model](#protocol-model) +- [Session handshake](#handshake) +- [Session sequencing and reliability](#session-sequencing-and-reliability) +- [Keepalive and liveness](#keepalive-and-liveness) +- [Stream model](#stream-model) ## Design goals 1. [use ephemeral peer sessions for record encryption](#1-explicit-peer-sessions) @@ -22,11 +31,20 @@ QLv2 replaces per-exchange sealing with explicit pairing, handshake, session, an ### 2. Minimal authenticated header QLv2 keeps a small header visible on the wire while still authenticating it. This lets a host route a record to the correct local or third-party application before decryption without exposing more metadata than necessary. +The visible record header currently includes: + +- protocol version +- record kind +- sender XID +- recipient XID + +This header is intentionally narrow, and can be extended in the future if needed. + ### 3. Transport-agnostic record layer The core protocol only consumes and produces complete QL records. Framing, batching, fragmentation, and reassembly stay in the transport adapter so the same protocol can run over transports such as TCP, BLE, or L2CAP without rewriting core logic. ### 4. QL-level reliability -QLv2 includes QL-level sequence numbers and acknowledgments above the transport. A transport can usually only tell us that bytes were accepted for transmission. A QL acknowledgment tells us something stronger: the peer received, decrypted, and authenticated the record with the current session key. +QLv2 includes QL-level sequence numbers and acknowledgments above the transport. A transport can usually only tell us that bytes were accepted for transmission. A QL acknowledgment tells us something stronger: the peer received and decrypted the message with the session key. This is deliberate redundancy, not a replacement for transport reliability. It is not sufficient for a fully unreliable transport like raw UDP, but it does make QLv2 more robust on transports that should be reliable in theory yet have shown implementation-level flakiness in practice, such as Passport Prime's embedded BLE. @@ -34,9 +52,9 @@ This is deliberate redundancy, not a replacement for transport reliability. It i QLv2 treats duplex byte streams as the application primitive rather than building in a separate model for each interaction style. Request/response, subscriptions, progress updates, and bulk transfer can all be adapted to the same abstraction, which also gives useful behavior such as finish semantics, cancellation, and backpressure without separate protocol features. ### 6. Efficient wire format -The wire format should stay compact, cheap to process, and independent of any one implementation language. QLv2 uses an efficient binary encoding with explicit endianness and fixed layouts, so records can be parsed consistently across platforms and can support zero-copy or near-zero-copy implementations where appropriate. +The wire format should stay compact, cheap to process, and independent of any one implementation language. QLv2 uses an efficient binary encoding with explicit endianness and fixed layouts, so records can be parsed consistently across platforms. -The record sizes shows the protocol's intended split between setup and steady-state traffic. Setup records are relatively large because they carry post-quantum material, while steady-state session records are much smaller. +The record sizes shows the protocol's intended split between setup and steady-state traffic. Setup records are relatively large because they carry post-quantum cryptography material, while steady-state session records are much smaller. | Record type | Encoded size | | --- | ---: | @@ -59,6 +77,15 @@ QLv2 should have one core implementation of pairing, handshake, session, retrans ### 8. Hardware-backed cryptography QLv2 separates parts of its cryptographic implementation through the `QlCrypto` trait. Each platform can provide its own source of randomness, hashing, and AEAD encryption and decryption, choosing software or hardware-backed implementations as appropriate. +```rust +pub trait QlCrypto { + fn fill_random_bytes(&self, data: &mut [u8]); + fn hash(&self, parts: &[&[u8]]) -> [u8; 32]; + fn encrypt_with_aead(&self, /*...*/) -> [u8; EncryptedMessage::AUTH_SIZE]; + fn decrypt_with_aead(&self, /*...*/) -> bool; +} +``` + ## Non-design goals - not a replacement for TCP, QUIC, BLE, or any other transport - not a universal reliability layer for arbitrary raw packets @@ -66,3 +93,88 @@ QLv2 separates parts of its cryptographic implementation through the `QlCrypto` - not responsible for how QL records map onto TCP reads/writes, BLE packets, or similar transport units - not a general-purpose message bus above the stream layer - not an attempt to preserve QLv1's sealed-message model in the core protocol + +## Protocol model +QLv2 has four layers of state: + +- `Pairing` establish a durable peer relationship +- `Handshake` establish a fresh encrypted session between paired peers +- `Session` carries authenticated encrypted traffic with QL-level acknowledgment and retransmission +- `Stream` multiplex many concurrent duplex byte streams inside one session + +This structure gives QLv2 a few important properties: + +- one peer relationship can span many sessions over time +- one session can carry many streams at once +- stream data from different streams can be interwoven on the same session +- ordering is preserved within a stream, not across all streams +- one blocked stream does not block unrelated streams + +## Handshake +The handshake authenticates both peers, derives a fresh session key, and confirms that both sides can use it. + +| Message | Sender | Est. size | Purpose | +| --- | --- | ---: | --- | +| `hello` | initiator | ~6253 bytes | start the handshake, contribute fresh key material, prove initiator identity | +| `hello_reply` | responder | ~6253 bytes | contribute fresh key material, prove responder identity, bind to `hello` | +| `confirm` | initiator | ~4673 bytes | prove the initiator saw `hello_reply` and derived the same session | +| `ready` | responder | ~62 bytes | prove the responder derived the session key by encrypting under it | + +Both peers contribute fresh key material during the handshake. The signatures bind the exchange to the two peers and to the full handshake transcript rather than to isolated messages. The session key is derived from the combined exchange. `ready` is the final key confirmation step because it is encrypted under that new session key. + +The handshake also follows a few simple rules: + +- each handshake message has a bounded lifetime +- duplicate handshake messages can trigger resend of the matching response +- simultaneous `hello` messages are resolved deterministically so only one side continues as the initiator + +## Session sequencing and reliability +This layer gives the session record-level acknowledgment and retransmission, independent of any one stream. + +| Term | Meaning | +| --- | --- | +| `seq` | session-wide sequence number for one encrypted record | +| `ack.base` | all sequence numbers up to this point are acknowledged | +| `ack.bitmap` | selective acknowledgment for the next 64 sequence numbers after `ack.base` | + +- every encrypted session record gets a `seq` +- the sequence space is shared by all streams on the session +- receivers can acknowledge out-of-order records within the session receive window +- retransmission resends the same logical session record with the same `seq` +- a QL acknowledgment tells us that the peer received the record, decrypted it successfully under the current session key, verified it, and accepted its session sequence number + +### Keepalive and liveness +- when a session is idle, a peer may send a `ping` to show that the session is still alive +- the peer does not answer with another `ping`; it simply acknowledges the record at the normal session layer +- if inbound traffic stays silent for too long, the session is treated as dead and closed + +Multiple streams can be interwoven in the same session. A missing session record can stall byte delivery on its own stream, but it does not block unrelated streams. + +## Stream model +QLv2 uses duplex byte streams as the application primitive. + +- each stream has independent inbound and outbound directions +- either peer can open a stream at any time +- many streams can be active on the same session +- bytes are delivered in order within a stream +- each stream chunk may carry bytes and may also mark that direction as complete +- this supports both bounded exchanges and long-lived streams + +Normal completion means one side is done sending bytes on that direction while the other direction may continue. Explicit close is different. It terminates one side or both sides of the stream early and carries a close code. + +By convention, higher-level protocols can treat one direction as a request and the other as a response. + +### Example: RPC over streams + +#### Unary request/response + +- the caller opens a stream +- the caller writes the request bytes and marks the request direction complete +- the responder reads the request, writes the response bytes, and marks the response direction complete + +#### Subscription + +- the caller opens a stream and writes a request body (any subscription parameters) +- the caller marks the request direction complete once the request is sent +- the responder keeps writing response updates on the response direction until the subscription ends or the job completes +- either side can explicitly close the stream early to cancel From 2a5c5817530d01bf31c47d6687badf6dc09119fd Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 24 Mar 2026 17:35:44 -0400 Subject: [PATCH 027/304] ql: unpair outside of session --- QL_V2.md | 4 +- ql-fsm/src/implementation/fsm.rs | 33 ++++++++++----- ql-fsm/src/implementation/mod.rs | 16 ++++--- ql-fsm/src/implementation/peer.rs | 45 +++++++++++++++++++- ql-fsm/src/lib.rs | 12 +++--- ql-fsm/src/replay_cache.rs | 2 +- ql-fsm/src/session/mod.rs | 26 +----------- ql-fsm/src/session/state.rs | 1 - ql-fsm/src/tests/session.rs | 42 +++++++++++++++++-- ql-runtime/src/driver.rs | 7 +++- ql-runtime/src/tests/unpair.rs | 4 ++ ql-wire/src/encrypted/mod.rs | 13 +----- ql-wire/src/encrypted/unpair.rs | 2 - ql-wire/src/lib.rs | 2 + ql-wire/src/record.rs | 8 ++++ ql-wire/src/tests.rs | 69 +++++++++++++++++++++++-------- ql-wire/src/unpair/crypto.rs | 64 ++++++++++++++++++++++++++++ ql-wire/src/unpair/mod.rs | 48 +++++++++++++++++++++ 18 files changed, 312 insertions(+), 86 deletions(-) delete mode 100644 ql-wire/src/encrypted/unpair.rs create mode 100644 ql-wire/src/unpair/crypto.rs create mode 100644 ql-wire/src/unpair/mod.rs diff --git a/QL_V2.md b/QL_V2.md index fbe297fe..6e868c92 100644 --- a/QL_V2.md +++ b/QL_V2.md @@ -62,10 +62,10 @@ The record sizes shows the protocol's intended split between setup and steady-st | `hello_reply` | 6253 bytes | | `confirm` | 4673 bytes | | `pair_request empty` | 1630 bytes | +| `unpair` | 4673 bytes | | `ready empty` | 62 bytes | | `session ack` | 87 bytes | | `session ping` | 87 bytes | -| `session unpair` | 87 bytes | | `session stream empty` | 100 bytes | | `session stream fin` | 100 bytes | | `session stream close` | 94 bytes | @@ -102,6 +102,8 @@ QLv2 has four layers of state: - `Session` carries authenticated encrypted traffic with QL-level acknowledgment and retransmission - `Stream` multiplex many concurrent duplex byte streams inside one session +`Unpair` is a peer-level signed control record outside the session. It tears down the pairing relationship on a best-effort basis and does not depend on session ordering or session establishment. + This structure gives QLv2 a few important properties: - one peer relationship can span many sessions over time diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index a8ba04e9..2b436723 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -14,12 +14,23 @@ pub fn receive( if header.recipient != fsm.identity.xid { return Err(QlFsmError::InvalidXid); } - if !matches!(&payload, QlPayloadRef::PairRequest(_)) { - let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.xid) else { - return Err(QlFsmError::NoPeerBound); - }; - if header.sender != peer { - return Err(QlFsmError::InvalidXid); + match &payload { + QlPayloadRef::PairRequest(_) => {} + QlPayloadRef::Unpair(_) => { + let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.xid) else { + return Ok(()); + }; + if header.sender != peer { + return Err(QlFsmError::InvalidXid); + } + } + _ => { + let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.xid) else { + return Err(QlFsmError::NoPeerBound); + }; + if header.sender != peer { + return Err(QlFsmError::InvalidXid); + } } } @@ -27,6 +38,9 @@ pub fn receive( QlPayloadRef::PairRequest(mut request) => { super::handle_pair(fsm, crypto, &header, &mut request)?; } + QlPayloadRef::Unpair(unpair) => { + super::handle_unpair(fsm, crypto, &header, &unpair)?; + } QlPayloadRef::Hello(hello) => { super::handle_hello(fsm, crypto, &header, &hello)?; } @@ -197,11 +211,8 @@ pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { Ok(fsm.session.queue_ping()?) } -pub fn queue_unpair(fsm: &mut QlFsm) -> Result<(), QlFsmError> { - ensure_session_open(fsm)?; - // TODO: keep local peer/session state alive until this queued unpair is acked or times out, - // then clear it locally. Right now this only requests remote unpair. - Ok(fsm.session.queue_unpair()?) +pub fn unpair(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { + super::handle_unpair_local(fsm, crypto) } fn ensure_peer_bound(fsm: &QlFsm) -> Result<(), QlFsmError> { diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index fbb002d8..002e016a 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -73,6 +73,16 @@ fn reset_session(fsm: &mut QlFsm) { ); } +fn clear_bound_peer(fsm: &mut QlFsm) { + if fsm.peer.take().is_none() { + return; + } + fsm.state.outbound.clear(); + reset_session(fsm); + fsm.state.session_events.push_back(QlSessionEvent::Unpaired); + fsm.state.events.push_back(QlFsmEvent::ClearPeer); +} + fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::CloseCode) { if !fsm.session.has_pending_stream_work() { return; @@ -111,12 +121,6 @@ fn drain_session_events(fsm: &mut QlFsm) { .session_events .push_back(QlSessionEvent::WritableClosed(stream_id)); } - SessionEvent::Unpaired => { - fsm.state.session_events.push_back(QlSessionEvent::Unpaired); - fsm.peer = None; - reset_session(fsm); - fsm.state.events.push_back(QlFsmEvent::ClearPeer); - } SessionEvent::SessionClosed(close) => { fsm.state .session_events diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs index bccf112c..2fc87145 100644 --- a/ql-fsm/src/implementation/peer.rs +++ b/ql-fsm/src/implementation/peer.rs @@ -1,6 +1,9 @@ -use ql_wire::{self as wire, PairRequestRecordWire, QlCrypto, QlHeader, RefMut}; +use ql_wire::{self as wire, PairRequestRecordWire, QlCrypto, QlHeader, RefMut, UnpairWire}; -use super::{emit_peer_status, handshake, is_replayed_control, next_control_meta, reset_session}; +use super::{ + clear_bound_peer, emit_peer_status, handshake, is_replayed_control, next_control_meta, + reset_session, +}; use crate::{state::PeerRecord, Peer, QlFsm, QlFsmError, QlFsmEvent}; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: Peer) { @@ -21,6 +24,19 @@ pub fn handle_pair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), Ok(()) } +pub fn handle_unpair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { + let peer = fsm.peer.as_ref()?.peer.clone(); + let meta = next_control_meta(fsm, fsm.config.control_expiration); + let record = wire::build_unpair( + crypto, + &fsm.identity, + peer.xid, + meta, + ); + clear_bound_peer(fsm); + Some(record) +} + pub fn handle_pair( fsm: &mut QlFsm, crypto: &impl QlCrypto, @@ -52,6 +68,31 @@ pub fn handle_pair( handshake::handle_connect(fsm, crypto) } +pub fn handle_unpair( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + header: &QlHeader, + unpair: &RefMut<'_, UnpairWire>, +) -> Result<(), QlFsmError> { + let Some(entry) = fsm.peer.as_ref() else { + return Ok(()); + }; + + wire::verify_unpair( + crypto, + header, + &entry.peer.signing_key, + unpair, + fsm.state.now.unix_secs, + )?; + if is_replayed_control(fsm, header.sender, wire::ControlMeta::from_wire(unpair.meta)) { + return Ok(()); + } + + clear_bound_peer(fsm); + Ok(()) +} + fn bind_peer_record(fsm: &mut QlFsm, peer: Peer) { fsm.peer = Some(PeerRecord::new(peer.clone())); reset_session(fsm); diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 6eb4c7ac..ac6b77b7 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -3,13 +3,14 @@ //! a caller drives `QlFsm` inside its own event loop //! //! inputs to that loop usually include -//! - app actions like `bind_peer`, `pair`, `connect`, `open_stream`, or `write_stream` +//! - app actions like `bind_peer`, `pair`, `connect`, `unpair`, `open_stream`, or `write_stream` //! - inbound transport bytes passed to `receive` //! - a deadline expiring, handled by calling `on_timer` //! - transport write results passed to `confirm_session_write` or `reject_session_write` //! //! outputs from `QlFsm` are -//! - outbound records from `take_next_write` +//! - outbound session and handshake records from `take_next_write` +//! - a best-effort peer unpair record returned directly from `unpair` //! - peer events from `take_next_event` //! - session events from `take_next_session_event` //! @@ -326,9 +327,10 @@ impl QlFsm { implementation::queue_ping(self) } - /// queues an unpair request on the active session - pub fn queue_unpair(&mut self) -> Result<(), QlFsmError> { - implementation::queue_unpair(self) + /// clears the bound peer locally and returns a best-effort unpair record + pub fn unpair(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Option { + self.state.now = now; + implementation::unpair(self, crypto) } /// returns the next session or stream event diff --git a/ql-fsm/src/replay_cache.rs b/ql-fsm/src/replay_cache.rs index 5ffbb03d..470d4100 100644 --- a/ql-fsm/src/replay_cache.rs +++ b/ql-fsm/src/replay_cache.rs @@ -21,7 +21,7 @@ impl ReplayCache { now_secs: u64, ) -> bool { self.valid_until_by_key - .retain(|_, valid_until| *valid_until > now_secs); + .retain(|_, stored_valid_until| *stored_valid_until > now_secs); let key = ReplayKey { peer, diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 739b02de..f78bcf96 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -10,7 +10,7 @@ use std::time::{Duration, Instant}; use indexmap::map::Entry; use ql_wire::{ CloseCode, CloseTarget, PingBody, SessionBody, SessionCloseBody, SessionEnvelope, SessionSeq, - StreamChunk, StreamClose, StreamId, UnpairBody, XID, + StreamChunk, StreamClose, StreamId, XID, }; use self::{ @@ -88,7 +88,6 @@ pub enum SessionEvent { Finished(StreamId), Closed(StreamClose), WritableClosed(StreamId), - Unpaired, SessionClosed(SessionCloseBody), } @@ -257,12 +256,6 @@ impl SessionFsm { Ok(()) } - pub fn queue_unpair(&mut self) -> Result<(), StreamError> { - self.ensure_session_open()?; - self.state.pending_control.unpair = true; - Ok(()) - } - pub fn receive(&mut self, now: Instant, envelope: SessionEnvelope) { self.state.now = now; self.collect_timeouts(); @@ -293,12 +286,6 @@ impl SessionFsm { let body_kind_is_ack = matches!(envelope.body, SessionBody::Ack); let apply_inbound_body = match envelope.body { SessionBody::Ack | SessionBody::Ping(_) => Ok(()), - SessionBody::Unpair(_) => { - self.state.session_state = SessionState::Closed; - self.clear_streams(); - self.state.events.push_back(SessionEvent::Unpaired); - Ok(()) - } SessionBody::Close(close) => { self.state.session_state = SessionState::Closed; self.clear_streams(); @@ -513,13 +500,6 @@ impl SessionFsm { retransmit: true, }); } - if self.state.pending_control.unpair { - self.state.pending_control.unpair = false; - return Some(PendingSessionBody { - body: SessionBody::Unpair(UnpairBody), - retransmit: true, - }); - } if self.state.pending_control.ping { self.state.pending_control.ping = false; return Some(PendingSessionBody { @@ -711,9 +691,7 @@ impl SessionFsm { fn should_retry_body(&self, body: &SessionBody) -> bool { match body { SessionBody::Ack => true, - SessionBody::Ping(_) | SessionBody::Unpair(_) => { - self.state.session_state == SessionState::Open - } + SessionBody::Ping(_) => self.state.session_state == SessionState::Open, SessionBody::Close(_) => true, SessionBody::Stream(frame) => { self.state.session_state == SessionState::Open diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 557305d4..bec8b129 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -105,7 +105,6 @@ pub struct PendingSessionBody { #[derive(Debug, Clone, Default)] pub struct PendingSessionControl { pub ping: bool, - pub unpair: bool, pub close: Option, } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 58f05752..640fcec7 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -110,8 +110,23 @@ fn lost_encrypted_record_is_retried_and_acked() { fn remote_unpair_clears_peer() { let mut harness = Harness::connected(QlFsmConfig::default()); - harness.a.fsm.queue_unpair().unwrap(); - harness.pump(); + let record = harness + .a + .fsm + .unpair(harness.time(), &harness.a.crypto) + .unwrap(); + + assert_eq!( + harness.a.fsm.take_next_session_event(), + Some(QlSessionEvent::Unpaired) + ); + assert!(harness.a.fsm.peer.is_none()); + assert!(matches!( + harness.a.fsm.take_next_event(), + Some(QlFsmEvent::ClearPeer) + )); + + harness.deliver_to_b(record); assert_eq!( harness.b.fsm.take_next_session_event(), @@ -122,7 +137,28 @@ fn remote_unpair_clears_peer() { harness.b.fsm.take_next_event(), Some(QlFsmEvent::ClearPeer) )); - assert!(harness.a.fsm.peer.is_some()); +} + +#[test] +fn unpair_returns_record_without_active_session() { + let mut harness = Harness::paired(QlFsmConfig::default()); + + let record = harness + .a + .fsm + .unpair(harness.time(), &harness.a.crypto) + .unwrap(); + + assert!(matches!(record.payload, QlPayload::Unpair(_))); + assert!(harness.a.fsm.peer.is_none()); + assert_eq!( + harness.a.fsm.take_next_session_event(), + Some(QlSessionEvent::Unpaired) + ); + assert!(matches!( + harness.a.fsm.take_next_event(), + Some(QlFsmEvent::ClearPeer) + )); } #[test] diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 4aeae179..137ec69e 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -211,7 +211,12 @@ impl DriverState { self.finish_step(platform, in_flight); } RuntimeCommand::Unpair => { - let _ = self.fsm.queue_unpair(); + if let Some(record) = self.fsm.unpair(now(), platform) { + in_flight.push(InFlightWrite { + session_write_id: None, + future: platform.write_message(record.encode()), + }); + } self.finish_step(platform, in_flight); } RuntimeCommand::Incoming(bytes) => { diff --git a/ql-runtime/src/tests/unpair.rs b/ql-runtime/src/tests/unpair.rs index 74898fae..600ee3cd 100644 --- a/ql-runtime/src/tests/unpair.rs +++ b/ql-runtime/src/tests/unpair.rs @@ -36,6 +36,10 @@ async fn unpair_clears_remote_peer_and_aborts_active_stream() { stream.request.write_all(&[1, 2, 3, 4]).await.unwrap(); handle_a.unpair().unwrap(); + assert!(matches!( + handle_a.open_stream().await, + Err(QlError::NoPeerBound) + )); tokio::time::timeout(std::time::Duration::from_secs(2), responder) .await diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 99aee244..db1bf09a 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -13,13 +13,11 @@ mod close; mod ping; mod stream_chunk; mod stream_close; -mod unpair; pub use close::*; pub use ping::*; pub use stream_chunk::*; pub use stream_close::*; -pub use unpair::*; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] @@ -53,7 +51,6 @@ impl SessionAck { pub enum SessionBody { Ack, Ping(ping::PingBody), - Unpair(unpair::UnpairBody), Stream(StreamChunk), StreamClose(StreamClose), Close(close::SessionCloseBody), @@ -62,7 +59,6 @@ pub enum SessionBody { pub enum SessionBodyRef { Ack, Ping, - Unpair, Stream(Ref), StreamClose(Ref), Close(close::SessionCloseBody), @@ -75,7 +71,6 @@ pub enum SessionBodyRef { enum SessionBodyKind { Ack = 1, Ping = 2, - Unpair = 3, Stream = 4, StreamClose = 5, Close = 6, @@ -100,7 +95,6 @@ impl SessionEnvelope { let body = match parse_session_body(session_body_kind(wire)?, &wire.body)? { SessionBodyRef::Ack => SessionBody::Ack, SessionBodyRef::Ping => SessionBody::Ping(ping::PingBody), - SessionBodyRef::Unpair => SessionBody::Unpair(unpair::UnpairBody), SessionBodyRef::Stream(frame) => SessionBody::Stream(StreamChunk::from_wire(&frame)?), SessionBodyRef::StreamClose(frame) => { SessionBody::StreamClose(StreamClose::from_wire(&frame)?) @@ -122,7 +116,6 @@ impl SessionEnvelope { let kind = match &self.body { SessionBody::Ack => SessionBodyKind::Ack, SessionBody::Ping(_) => SessionBodyKind::Ping, - SessionBody::Unpair(_) => SessionBodyKind::Unpair, SessionBody::Stream(_) => SessionBodyKind::Stream, SessionBody::StreamClose(_) => SessionBodyKind::StreamClose, SessionBody::Close(_) => SessionBodyKind::Close, @@ -135,7 +128,7 @@ impl SessionEnvelope { }; push_value(&mut out, &header); match &self.body { - SessionBody::Ack | SessionBody::Ping(_) | SessionBody::Unpair(_) => {} + SessionBody::Ack | SessionBody::Ping(_) => {} SessionBody::Stream(frame) => frame.encode_into(&mut out), SessionBody::StreamClose(frame) => frame.encode_into(&mut out), SessionBody::Close(body) => body.encode_into(&mut out), @@ -201,10 +194,6 @@ fn parse_session_body( crate::codec::ensure_empty(&body)?; Ok(SessionBodyRef::Ping) } - SessionBodyKind::Unpair => { - crate::codec::ensure_empty(&body)?; - Ok(SessionBodyRef::Unpair) - } SessionBodyKind::Stream => Ok(SessionBodyRef::Stream(StreamChunk::parse(body)?)), SessionBodyKind::StreamClose => Ok(SessionBodyRef::StreamClose(StreamClose::parse(body)?)), SessionBodyKind::Close => Ok(SessionBodyRef::Close(close::SessionCloseBody::decode( diff --git a/ql-wire/src/encrypted/unpair.rs b/ql-wire/src/encrypted/unpair.rs deleted file mode 100644 index a638b045..00000000 --- a/ql-wire/src/encrypted/unpair.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub struct UnpairBody; diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 74e59212..78bd38ff 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -19,6 +19,7 @@ mod nonce; mod pair; mod pq; mod record; +mod unpair; mod xid; pub use control::*; @@ -32,6 +33,7 @@ pub use nonce::*; pub use pair::*; pub use pq::*; pub use record::*; +pub use unpair::*; pub use xid::*; pub const QL_WIRE_VERSION: u8 = 1; diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 666a93b1..aad98e7b 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -9,6 +9,7 @@ use crate::{ handshake::{self, ConfirmWire, HelloReplyWire, HelloWire}, header::{decode_record_header, encode_record_header, QlHeader}, pair::{self, PairRequestRecordWire}, + unpair::{self, UnpairWire}, WireError, QL_WIRE_VERSION, }; @@ -21,6 +22,7 @@ pub struct QlRecord { #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlPayload { PairRequest(pair::PairRequestRecord), + Unpair(unpair::Unpair), Hello(handshake::Hello), HelloReply(handshake::HelloReply), Confirm(handshake::Confirm), @@ -35,6 +37,7 @@ pub struct QlRecordRef { pub enum QlPayloadRef { PairRequest(Ref), + Unpair(Ref), Hello(Ref), HelloReply(Ref), Confirm(Ref), @@ -53,12 +56,14 @@ pub(crate) enum RecordKind { Confirm = 4, Ready = 5, Session = 6, + Unpair = 7, } impl RecordKind { fn for_payload(payload: &QlPayload) -> Self { match payload { QlPayload::PairRequest(_) => Self::PairRequest, + QlPayload::Unpair(_) => Self::Unpair, QlPayload::Hello(_) => Self::Hello, QlPayload::HelloReply(_) => Self::HelloReply, QlPayload::Confirm(_) => Self::Confirm, @@ -76,6 +81,7 @@ impl QlRecord { codec::push_value(&mut out, &header); match &self.payload { QlPayload::PairRequest(request) => request.encode_into(&mut out), + QlPayload::Unpair(unpair) => unpair.encode_into(&mut out), QlPayload::Hello(hello) => hello.encode_into(&mut out), QlPayload::HelloReply(reply) => reply.encode_into(&mut out), QlPayload::Confirm(confirm) => confirm.encode_into(&mut out), @@ -128,6 +134,7 @@ impl QlPayloadRef { Self::PairRequest(request) => { QlPayload::PairRequest(pair::PairRequestRecord::from_wire(request)) } + Self::Unpair(unpair) => QlPayload::Unpair(unpair::Unpair::from_wire(unpair)), Self::Hello(hello) => QlPayload::Hello(handshake::Hello::from_wire(hello)), Self::HelloReply(reply) => { QlPayload::HelloReply(handshake::HelloReply::from_wire(reply)) @@ -144,6 +151,7 @@ fn parse_payload(kind: RecordKind, payload: B) -> Result Ok(QlPayloadRef::PairRequest(pair::PairRequestRecord::parse( payload, )?)), + RecordKind::Unpair => Ok(QlPayloadRef::Unpair(unpair::Unpair::parse(payload)?)), RecordKind::Hello => Ok(QlPayloadRef::Hello(handshake::Hello::parse(payload)?)), RecordKind::HelloReply => Ok(QlPayloadRef::HelloReply(handshake::HelloReply::parse( payload, diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 7283dfb9..37dbde73 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -194,6 +194,43 @@ fn ready_round_trip_and_decrypt() { assert_eq!(body.meta, meta); } +#[test] +fn unpair_round_trip_and_verify() { + let crypto = TestCrypto::new(40); + let (sender_signing_private, sender_signing_public) = generate_ml_dsa_keypair(&crypto); + let sender_kem = generate_ml_kem_keypair(&crypto); + let identity = QlIdentity::new( + XID([7; XID::SIZE]), + sender_signing_private, + sender_signing_public.clone(), + sender_kem.0, + sender_kem.1, + ); + let recipient = XID([8; XID::SIZE]); + let meta = ControlMeta { + control_id: ControlId(88), + valid_until: 600, + }; + let record = unpair::build_unpair(&crypto, &identity, recipient, meta); + + let mut bytes = record.encode(); + let parsed = QlRecord::decode(&bytes).unwrap(); + assert_eq!(parsed, record); + + let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); + let QlPayloadRef::Unpair(unpair) = payload else { + panic!("expected unpair payload"); + }; + unpair::verify_unpair( + &crypto, + &header, + &sender_signing_public, + &unpair, + 100, + ) + .unwrap(); +} + #[test] fn protocol_record_size_breakdown() { fn meta(id: u32) -> ControlMeta { @@ -259,16 +296,23 @@ fn protocol_record_size_breakdown() { encrypted: encrypted(11, 0), }), }; + let unpair = QlRecord { + header, + payload: QlPayload::Unpair(unpair::Unpair { + meta: meta(4), + signature: MlDsaSignature::from_data([12; MlDsaSignature::SIZE]), + }), + }; let ready = QlRecord { header, payload: QlPayload::Ready(handshake::Ready { - encrypted: encrypted(12, 0), + encrypted: encrypted(13, 0), }), }; let session_ack = session_record( header, - 13, + 14, SessionEnvelope { seq: SessionSeq(1), ack: SessionAck::EMPTY, @@ -277,27 +321,18 @@ fn protocol_record_size_breakdown() { ); let session_ping = session_record( header, - 14, + 15, SessionEnvelope { seq: SessionSeq(2), ack: SessionAck::EMPTY, body: SessionBody::Ping(PingBody), }, ); - let session_unpair = session_record( - header, - 15, - SessionEnvelope { - seq: SessionSeq(3), - ack: SessionAck::EMPTY, - body: SessionBody::Unpair(UnpairBody), - }, - ); let session_stream_empty = session_record( header, 16, SessionEnvelope { - seq: SessionSeq(4), + seq: SessionSeq(3), ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id: StreamId(1), @@ -311,7 +346,7 @@ fn protocol_record_size_breakdown() { header, 17, SessionEnvelope { - seq: SessionSeq(5), + seq: SessionSeq(4), ack: SessionAck::EMPTY, body: SessionBody::Stream(StreamChunk { stream_id: StreamId(1), @@ -325,7 +360,7 @@ fn protocol_record_size_breakdown() { header, 18, SessionEnvelope { - seq: SessionSeq(6), + seq: SessionSeq(5), ack: SessionAck::EMPTY, body: SessionBody::StreamClose(StreamClose { stream_id: StreamId(1), @@ -339,7 +374,7 @@ fn protocol_record_size_breakdown() { header, 19, SessionEnvelope { - seq: SessionSeq(7), + seq: SessionSeq(6), ack: SessionAck::EMPTY, body: SessionBody::Close(SessionCloseBody { code: CloseCode::PROTOCOL, @@ -355,10 +390,10 @@ fn protocol_record_size_breakdown() { print_size("ql-wire hello_reply", hello_reply.encode().len()); print_size("ql-wire confirm", confirm.encode().len()); print_size("ql-wire pair_request empty", pair_request.encode().len()); + print_size("ql-wire unpair", unpair.encode().len()); print_size("ql-wire ready empty", ready.encode().len()); print_size("ql-wire session ack", session_ack.encode().len()); print_size("ql-wire session ping", session_ping.encode().len()); - print_size("ql-wire session unpair", session_unpair.encode().len()); print_size( "ql-wire session stream empty", session_stream_empty.encode().len(), diff --git a/ql-wire/src/unpair/crypto.rs b/ql-wire/src/unpair/crypto.rs new file mode 100644 index 00000000..fa5ea03a --- /dev/null +++ b/ql-wire/src/unpair/crypto.rs @@ -0,0 +1,64 @@ +use zerocopy::{byte_slice::ByteSlice, Ref}; + +use super::UnpairWire; +use crate::{ + ControlMeta, MlDsaPublicKey, QlCrypto, QlHeader, QlIdentity, QlPayload, QlRecord, WireError, + XID, +}; + +pub fn build_unpair( + crypto: &impl QlCrypto, + identity: &QlIdentity, + recipient: XID, + meta: ControlMeta, +) -> QlRecord { + let header = QlHeader { + sender: identity.xid, + recipient, + }; + let signature = identity + .signing_private_key + .sign(crypto, &hash_unpair_signature_data(crypto, &header, &meta)); + QlRecord { + header, + payload: QlPayload::Unpair(super::Unpair { meta, signature }), + } +} + +pub fn verify_unpair( + crypto: &impl QlCrypto, + header: &QlHeader, + signer: &MlDsaPublicKey, + unpair: &Ref, + now_seconds: u64, +) -> Result<(), WireError> { + let meta = ControlMeta::from_wire(unpair.meta); + meta.ensure_not_expired(now_seconds)?; + if signer.verify_bytes( + &unpair.signature, + &hash_unpair_signature_data(crypto, header, &meta), + ) { + Ok(()) + } else { + Err(WireError::InvalidSignature) + } +} + +fn hash_unpair_signature_data( + crypto: &impl QlCrypto, + header: &QlHeader, + meta: &ControlMeta, +) -> [u8; 32] { + let aad = header.aad(); + let control_id = meta.control_id.0.to_le_bytes(); + let valid_until = meta.valid_until.to_le_bytes(); + crypto.hash(&[ + b"ql-wire:unpair:v1", + b"aad", + &aad, + b"control-id", + &control_id, + b"valid-until", + &valid_until, + ]) +} diff --git a/ql-wire/src/unpair/mod.rs b/ql-wire/src/unpair/mod.rs new file mode 100644 index 00000000..12b0b612 --- /dev/null +++ b/ql-wire/src/unpair/mod.rs @@ -0,0 +1,48 @@ +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +use crate::{ + codec::{parse, push_value}, + control::ControlMetaWire, + ControlMeta, MlDsaSignature, WireError, +}; + +mod crypto; +pub use crypto::*; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Unpair { + pub meta: ControlMeta, + pub signature: MlDsaSignature, +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub struct UnpairWire { + pub meta: ControlMetaWire, + pub signature: [u8; MlDsaSignature::SIZE], +} + +impl Unpair { + pub fn parse(bytes: B) -> Result, WireError> { + parse(bytes) + } + + pub fn from_wire(wire: &UnpairWire) -> Self { + Self { + meta: ControlMeta::from_wire(wire.meta), + signature: MlDsaSignature::from_data(wire.signature), + } + } + + pub fn encode_into(&self, out: &mut Vec) { + push_value( + out, + &UnpairWire { + meta: self.meta.to_wire(), + signature: *self.signature.as_bytes(), + }, + ); + } +} From c0633cd8bd0c15d82b3247b697e3280e9b72741a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 24 Mar 2026 18:21:02 -0400 Subject: [PATCH 028/304] ql: wire/non-wire support for handshake --- ql-wire/src/handshake/crypto.rs | 121 +++++++++++---------- ql-wire/src/handshake/mod.rs | 184 +++++++++++++++++++++++--------- 2 files changed, 200 insertions(+), 105 deletions(-) diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs index 0ca2f636..15bdeeee 100644 --- a/ql-wire/src/handshake/crypto.rs +++ b/ql-wire/src/handshake/crypto.rs @@ -1,9 +1,11 @@ use zerocopy::{ - byte_slice::{ByteSlice, ByteSliceMut}, + byte_slice::ByteSliceMut, Ref, }; -use super::{Confirm, ConfirmWire, Hello, HelloReply, HelloReplyWire, HelloWire, Ready, ReadyBody}; +use super::{ + Confirm, ConfirmView, Hello, HelloReply, HelloReplyView, HelloView, Ready, ReadyBody, +}; use crate::{ pq::ML_KEM_SUITE_TAG, ControlMeta, EncryptedMessage, EncryptedMessageWire, MlDsaPublicKey, MlDsaSignature, MlKemCiphertext, MlKemPublicKey, Nonce, QlCrypto, QlHeader, QlIdentity, @@ -45,34 +47,34 @@ pub fn build_hello( ) } -pub fn verify_hello( +pub fn verify_hello( crypto: &impl QlCrypto, initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, - hello: &Ref, + hello: &impl HelloView, now_seconds: u64, ) -> Result<(), WireError> { - let meta = ControlMeta::from_wire(hello.meta); + let meta = hello.meta(); meta.ensure_not_expired(now_seconds)?; let proof_data = hash_hello_proof_data( crypto, initiator, responder, &meta, - &hello.nonce, - &hello.kem_ct, + hello.nonce(), + hello.kem_ct(), ); - verify_signature_bytes(initiator_signing_key, &hello.signature, &proof_data) + verify_signature_bytes(initiator_signing_key, hello.signature(), &proof_data) } -pub fn respond_hello( +pub fn respond_hello( crypto: &impl QlCrypto, identity: &QlIdentity, initiator: XID, initiator_signing_key: &MlDsaPublicKey, initiator_encapsulation_key: &MlKemPublicKey, - hello: &Ref, + hello: &impl HelloView, meta: ControlMeta, now_seconds: u64, ) -> Result<(HelloReply, ResponderSecrets), WireError> { @@ -86,8 +88,8 @@ pub fn respond_hello( )?; let initiator_secret = identity .encapsulation_private_key - .decapsulate_shared_secret_bytes(&hello.kem_ct); - let hello_meta = ControlMeta::from_wire(hello.meta); + .decapsulate_shared_secret_bytes(hello.kem_ct()); + let hello_meta = hello.meta(); let nonce = next_nonce(crypto); let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(crypto); @@ -96,8 +98,8 @@ pub fn respond_hello( initiator, identity.xid, &hello_meta, - &hello.nonce, - &hello.kem_ct, + hello.nonce(), + hello.kem_ct(), &meta, &nonce.0, kem_ct.as_bytes(), @@ -117,45 +119,46 @@ pub fn respond_hello( )) } -pub fn build_confirm( +pub fn build_confirm( crypto: &impl QlCrypto, identity: &QlIdentity, responder: XID, responder_signing_key: &MlDsaPublicKey, - hello: &Hello, - reply: &Ref, + hello: &impl HelloView, + reply: &impl HelloReplyView, initiator_secret: &SessionKey, meta: ControlMeta, now_seconds: u64, ) -> Result<(Confirm, SessionKey), WireError> { - let reply_meta = ControlMeta::from_wire(reply.meta); + let hello_meta = hello.meta(); + let reply_meta = reply.meta(); reply_meta.ensure_not_expired(now_seconds)?; let transcript = hash_handshake_transcript( crypto, identity.xid, responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), + &hello_meta, + hello.nonce(), + hello.kem_ct(), &reply_meta, - &reply.nonce, - &reply.kem_ct, + reply.nonce(), + reply.kem_ct(), ); - verify_signature_bytes(responder_signing_key, &reply.signature, &transcript)?; + verify_signature_bytes(responder_signing_key, reply.signature(), &transcript)?; let responder_secret = identity .encapsulation_private_key - .decapsulate_shared_secret_bytes(&reply.kem_ct); + .decapsulate_shared_secret_bytes(reply.kem_ct()); let proof_data = hash_confirm_proof_data( crypto, &meta, identity.xid, responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), + &hello_meta, + hello.nonce(), + hello.kem_ct(), &reply_meta, - &reply.nonce, - &reply.kem_ct, + reply.nonce(), + reply.kem_ct(), ); let signature = identity.signing_private_key.sign(crypto, &proof_data); let session_key = derive_session_key( @@ -164,27 +167,29 @@ pub fn build_confirm( &responder_secret, identity.xid, responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), + &hello_meta, + hello.nonce(), + hello.kem_ct(), &reply_meta, - &reply.nonce, - &reply.kem_ct, + reply.nonce(), + reply.kem_ct(), ); Ok((Confirm { meta, signature }, session_key)) } -pub fn finalize_confirm( +pub fn finalize_confirm( crypto: &impl QlCrypto, initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, - hello: &Hello, - reply: &HelloReply, - confirm: &Ref, + hello: &impl HelloView, + reply: &impl HelloReplyView, + confirm: &impl ConfirmView, secrets: &ResponderSecrets, now_seconds: u64, ) -> Result { + let hello_meta = hello.meta(); + let reply_meta = reply.meta(); verify_confirm( crypto, initiator, @@ -201,40 +206,42 @@ pub fn finalize_confirm( &secrets.responder_secret, initiator, responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - &reply.meta, - &reply.nonce.0, - reply.kem_ct.as_bytes(), + &hello_meta, + hello.nonce(), + hello.kem_ct(), + &reply_meta, + reply.nonce(), + reply.kem_ct(), )) } -pub fn verify_confirm( +pub fn verify_confirm( crypto: &impl QlCrypto, initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, - hello: &Hello, - reply: &HelloReply, - confirm: &Ref, + hello: &impl HelloView, + reply: &impl HelloReplyView, + confirm: &impl ConfirmView, now_seconds: u64, ) -> Result<(), WireError> { - let confirm_meta = ControlMeta::from_wire(confirm.meta); + let hello_meta = hello.meta(); + let reply_meta = reply.meta(); + let confirm_meta = confirm.meta(); confirm_meta.ensure_not_expired(now_seconds)?; let proof_data = hash_confirm_proof_data( crypto, &confirm_meta, initiator, responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - &reply.meta, - &reply.nonce.0, - reply.kem_ct.as_bytes(), + &hello_meta, + hello.nonce(), + hello.kem_ct(), + &reply_meta, + reply.nonce(), + reply.kem_ct(), ); - verify_signature_bytes(initiator_signing_key, &confirm.signature, &proof_data) + verify_signature_bytes(initiator_signing_key, confirm.signature(), &proof_data) } pub fn build_ready( diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index e503b63e..e2fbb9fe 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -29,6 +29,49 @@ pub struct HelloWire { pub signature: [u8; MlDsaSignature::SIZE], } +pub trait HelloView { + fn meta(&self) -> ControlMeta; + fn nonce(&self) -> &[u8; Nonce::SIZE]; + fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE]; + fn signature(&self) -> &[u8; MlDsaSignature::SIZE]; +} + +impl HelloView for Hello { + fn meta(&self) -> ControlMeta { + self.meta + } + + fn nonce(&self) -> &[u8; Nonce::SIZE] { + &self.nonce.0 + } + + fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE] { + self.kem_ct.as_bytes() + } + + fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { + self.signature.as_bytes() + } +} + +impl HelloView for Ref { + fn meta(&self) -> ControlMeta { + ControlMeta::from_wire(self.meta) + } + + fn nonce(&self) -> &[u8; Nonce::SIZE] { + &self.nonce + } + + fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE] { + &self.kem_ct + } + + fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { + &self.signature + } +} + impl Hello { pub fn parse(bytes: B) -> Result, WireError> { parse(bytes) @@ -43,22 +86,16 @@ impl Hello { } } - pub fn decode(bytes: &[u8]) -> Result { - let wire = Self::parse(bytes)?; - Ok(Self::from_wire(&wire)) - } - - pub fn to_wire(&self) -> HelloWire { - HelloWire { - meta: self.meta.to_wire(), - nonce: self.nonce.0, - kem_ct: *self.kem_ct.as_bytes(), - signature: *self.signature.as_bytes(), - } - } - pub fn encode_into(&self, out: &mut Vec) { - push_value(out, &self.to_wire()); + push_value( + out, + &HelloWire { + meta: self.meta.to_wire(), + nonce: self.nonce.0, + kem_ct: *self.kem_ct.as_bytes(), + signature: *self.signature.as_bytes(), + }, + ); } } @@ -79,6 +116,49 @@ pub struct HelloReplyWire { pub signature: [u8; MlDsaSignature::SIZE], } +pub trait HelloReplyView { + fn meta(&self) -> ControlMeta; + fn nonce(&self) -> &[u8; Nonce::SIZE]; + fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE]; + fn signature(&self) -> &[u8; MlDsaSignature::SIZE]; +} + +impl HelloReplyView for HelloReply { + fn meta(&self) -> ControlMeta { + self.meta + } + + fn nonce(&self) -> &[u8; Nonce::SIZE] { + &self.nonce.0 + } + + fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE] { + self.kem_ct.as_bytes() + } + + fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { + self.signature.as_bytes() + } +} + +impl HelloReplyView for Ref { + fn meta(&self) -> ControlMeta { + ControlMeta::from_wire(self.meta) + } + + fn nonce(&self) -> &[u8; Nonce::SIZE] { + &self.nonce + } + + fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE] { + &self.kem_ct + } + + fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { + &self.signature + } +} + impl HelloReply { pub fn parse(bytes: B) -> Result, WireError> { parse(bytes) @@ -93,22 +173,16 @@ impl HelloReply { } } - pub fn decode(bytes: &[u8]) -> Result { - let wire = Self::parse(bytes)?; - Ok(Self::from_wire(&wire)) - } - - pub fn to_wire(&self) -> HelloReplyWire { - HelloReplyWire { - meta: self.meta.to_wire(), - nonce: self.nonce.0, - kem_ct: *self.kem_ct.as_bytes(), - signature: *self.signature.as_bytes(), - } - } - pub fn encode_into(&self, out: &mut Vec) { - push_value(out, &self.to_wire()); + push_value( + out, + &HelloReplyWire { + meta: self.meta.to_wire(), + nonce: self.nonce.0, + kem_ct: *self.kem_ct.as_bytes(), + signature: *self.signature.as_bytes(), + }, + ); } } @@ -125,6 +199,31 @@ pub struct ConfirmWire { pub signature: [u8; MlDsaSignature::SIZE], } +pub trait ConfirmView { + fn meta(&self) -> ControlMeta; + fn signature(&self) -> &[u8; MlDsaSignature::SIZE]; +} + +impl ConfirmView for Confirm { + fn meta(&self) -> ControlMeta { + self.meta + } + + fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { + self.signature.as_bytes() + } +} + +impl ConfirmView for Ref { + fn meta(&self) -> ControlMeta { + ControlMeta::from_wire(self.meta) + } + + fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { + &self.signature + } +} + impl Confirm { pub fn parse(bytes: B) -> Result, WireError> { parse(bytes) @@ -137,20 +236,14 @@ impl Confirm { } } - pub fn decode(bytes: &[u8]) -> Result { - let wire = Self::parse(bytes)?; - Ok(Self::from_wire(&wire)) - } - - pub fn to_wire(&self) -> ConfirmWire { - ConfirmWire { - meta: self.meta.to_wire(), - signature: *self.signature.as_bytes(), - } - } - pub fn encode_into(&self, out: &mut Vec) { - push_value(out, &self.to_wire()); + push_value( + out, + &ConfirmWire { + meta: self.meta.to_wire(), + signature: *self.signature.as_bytes(), + }, + ); } } @@ -175,11 +268,6 @@ impl Ready { } } - pub fn decode(bytes: &[u8]) -> Result { - let wire = Self::parse(bytes)?; - Ok(Self::from_wire(&wire)) - } - pub fn encode_into(&self, out: &mut Vec) { self.encrypted.encode_into(out); } From ca708b26f484e0f0ce4e673d3031c112b395fd88 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 24 Mar 2026 18:21:40 -0400 Subject: [PATCH 029/304] ql: avoid copying pq types repeatedly --- ql-wire/src/pq.rs | 72 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index 0f2db1bb..c6783c1d 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -30,71 +30,99 @@ impl AsRef<[u8]> for SessionKey { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct MlDsaPrivateKey(Box<[u8; MlDsaPrivateKey::SIZE]>); +macro_rules! impl_byte_traits { + ($name:ident) => { + impl std::fmt::Debug for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple(stringify!($name)) + .field(&self.as_bytes()) + .finish() + } + } + + impl PartialEq for $name { + fn eq(&self, other: &Self) -> bool { + self.as_bytes() == other.as_bytes() + } + } + + impl Eq for $name {} + + impl std::hash::Hash for $name { + fn hash(&self, state: &mut H) { + self.as_bytes().hash(state); + } + } + }; +} + +#[derive(Clone)] +pub struct MlDsaPrivateKey(Box); + +impl_byte_traits!(MlDsaPrivateKey); impl MlDsaPrivateKey { pub const SIZE: usize = ml_dsa_87::MLDSA87SigningKey::len(); pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(data)) + Self(Box::new(ml_dsa_87::MLDSA87SigningKey::new(data))) } pub fn as_bytes(&self) -> &[u8; Self::SIZE] { - self.0.as_ref() + self.0.as_ref().as_ref() } pub fn sign(&self, crypto: &impl QlCrypto, message: &[u8]) -> MlDsaSignature { let mut randomness = [0u8; SIGNING_RANDOMNESS_SIZE]; crypto.fill_random_bytes(&mut randomness); - let signing_key = ml_dsa_87::MLDSA87SigningKey::new(*self.as_bytes()); // Safe: we always sign with the empty context, so the only remaining // error is libcrux's negligible-probability rejection-sampling failure. - let signature = ml_dsa_87::sign(&signing_key, message, b"", randomness) + let signature = ml_dsa_87::sign(self.0.as_ref(), message, b"", randomness) .expect("ML-DSA signing should not fail"); - MlDsaSignature::from_data(*signature.as_ref()) + MlDsaSignature(Box::new(signature)) } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct MlDsaPublicKey(Box<[u8; MlDsaPublicKey::SIZE]>); +#[derive(Clone)] +pub struct MlDsaPublicKey(Box); + +impl_byte_traits!(MlDsaPublicKey); impl MlDsaPublicKey { pub const SIZE: usize = ml_dsa_87::MLDSA87VerificationKey::len(); pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(data)) + Self(Box::new(ml_dsa_87::MLDSA87VerificationKey::new(data))) } pub fn as_bytes(&self) -> &[u8; Self::SIZE] { - self.0.as_ref() + self.0.as_ref().as_ref() } pub fn verify(&self, signature: &MlDsaSignature, message: &[u8]) -> bool { - let verification_key = ml_dsa_87::MLDSA87VerificationKey::new(*self.as_bytes()); - let signature = ml_dsa_87::MLDSA87Signature::new(*signature.as_bytes()); - ml_dsa_87::verify(&verification_key, message, b"", &signature).is_ok() + ml_dsa_87::verify(self.0.as_ref(), message, b"", signature.0.as_ref()).is_ok() } pub fn verify_bytes(&self, signature: &[u8; MlDsaSignature::SIZE], message: &[u8]) -> bool { - let verification_key = ml_dsa_87::MLDSA87VerificationKey::new(*self.as_bytes()); let signature = ml_dsa_87::MLDSA87Signature::new(*signature); - ml_dsa_87::verify(&verification_key, message, b"", &signature).is_ok() + ml_dsa_87::verify(self.0.as_ref(), message, b"", &signature).is_ok() } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct MlDsaSignature(Box<[u8; MlDsaSignature::SIZE]>); +#[derive(Clone)] +pub struct MlDsaSignature(Box); + +impl_byte_traits!(MlDsaSignature); impl MlDsaSignature { pub const SIZE: usize = ml_dsa_87::MLDSA87Signature::len(); pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(data)) + Self(Box::new(ml_dsa_87::MLDSA87Signature::new(data))) } pub fn as_bytes(&self) -> &[u8; Self::SIZE] { - self.0.as_ref() + ml_dsa_87::MLDSA87Signature::as_ref(self.0.as_ref()) } } @@ -176,8 +204,8 @@ pub fn generate_ml_dsa_keypair(crypto: &impl QlCrypto) -> (MlDsaPrivateKey, MlDs crypto.fill_random_bytes(&mut randomness); let key_pair = ml_dsa_87::generate_key_pair(randomness); ( - MlDsaPrivateKey::from_data(*key_pair.signing_key.as_ref()), - MlDsaPublicKey::from_data(*key_pair.verification_key.as_ref()), + MlDsaPrivateKey(Box::new(key_pair.signing_key)), + MlDsaPublicKey(Box::new(key_pair.verification_key)), ) } From fe93f76f12920f2347b1bec06f3af48232a11655 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 24 Mar 2026 18:43:57 -0400 Subject: [PATCH 030/304] ql: fix ci --- api/src/api/quantum_link.rs | 2 ++ ql-fsm/src/implementation/peer.rs | 13 ++++++------- ql-fsm/src/session/state.rs | 7 +------ ql-fsm/src/session/tests.rs | 3 +-- ql-wire/src/handshake/crypto.rs | 9 ++------- ql-wire/src/tests.rs | 9 +-------- 6 files changed, 13 insertions(+), 30 deletions(-) diff --git a/api/src/api/quantum_link.rs b/api/src/api/quantum_link.rs index 2b67c10e..0d956897 100644 --- a/api/src/api/quantum_link.rs +++ b/api/src/api/quantum_link.rs @@ -325,6 +325,7 @@ mod tests { let message = EnvoyMessage { message: QuantumLinkMessage::ExchangeRate(fx_rate), timestamp: 123456, + protocol_version: None, }; let envelope = QuantumLink::seal( @@ -346,6 +347,7 @@ mod tests { let message = EnvoyMessage { message: QuantumLinkMessage::Heartbeat(Heartbeat {}), timestamp: 123456, + protocol_version: None, }; let envelope = QuantumLink::seal( diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs index 2fc87145..bd2d0785 100644 --- a/ql-fsm/src/implementation/peer.rs +++ b/ql-fsm/src/implementation/peer.rs @@ -27,12 +27,7 @@ pub fn handle_pair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), pub fn handle_unpair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { let peer = fsm.peer.as_ref()?.peer.clone(); let meta = next_control_meta(fsm, fsm.config.control_expiration); - let record = wire::build_unpair( - crypto, - &fsm.identity, - peer.xid, - meta, - ); + let record = wire::build_unpair(crypto, &fsm.identity, peer.xid, meta); clear_bound_peer(fsm); Some(record) } @@ -85,7 +80,11 @@ pub fn handle_unpair( unpair, fsm.state.now.unix_secs, )?; - if is_replayed_control(fsm, header.sender, wire::ControlMeta::from_wire(unpair.meta)) { + if is_replayed_control( + fsm, + header.sender, + wire::ControlMeta::from_wire(unpair.meta), + ) { return Ok(()); } diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index bec8b129..f9708806 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -5,12 +5,7 @@ use ql_wire::{ CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamClose, StreamId, }; -use super::{ - ring::SeqRing, - stream_window::StreamRecvWindow, - SessionEvent, - SessionState, -}; +use super::{ring::SeqRing, stream_window::StreamRecvWindow, SessionEvent, SessionState}; pub const SESSION_WINDOW_CAPACITY: usize = 64; diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index c399f00f..340babf6 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -692,8 +692,7 @@ fn tx_selective_ack_keeps_front_gap_pinned() { for (byte, stream_id) in (0..64u8).zip(stream_ids.iter().copied()) { fsm.write_stream(stream_id, vec![byte]).unwrap(); - let _ = next_outbound(&mut fsm, now + Duration::from_millis(byte as u64)) - .unwrap(); + let _ = next_outbound(&mut fsm, now + Duration::from_millis(byte as u64)).unwrap(); } fsm.receive( diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs index 15bdeeee..b2d00756 100644 --- a/ql-wire/src/handshake/crypto.rs +++ b/ql-wire/src/handshake/crypto.rs @@ -1,11 +1,6 @@ -use zerocopy::{ - byte_slice::ByteSliceMut, - Ref, -}; +use zerocopy::{byte_slice::ByteSliceMut, Ref}; -use super::{ - Confirm, ConfirmView, Hello, HelloReply, HelloReplyView, HelloView, Ready, ReadyBody, -}; +use super::{Confirm, ConfirmView, Hello, HelloReply, HelloReplyView, HelloView, Ready, ReadyBody}; use crate::{ pq::ML_KEM_SUITE_TAG, ControlMeta, EncryptedMessage, EncryptedMessageWire, MlDsaPublicKey, MlDsaSignature, MlKemCiphertext, MlKemPublicKey, Nonce, QlCrypto, QlHeader, QlIdentity, diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 37dbde73..de80c488 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -221,14 +221,7 @@ fn unpair_round_trip_and_verify() { let QlPayloadRef::Unpair(unpair) = payload else { panic!("expected unpair payload"); }; - unpair::verify_unpair( - &crypto, - &header, - &sender_signing_public, - &unpair, - 100, - ) - .unwrap(); + unpair::verify_unpair(&crypto, &header, &sender_signing_public, &unpair, 100).unwrap(); } #[test] From f3e2bd3375419411fd1933e4295770a4c9af0e62 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 25 Mar 2026 09:27:52 -0400 Subject: [PATCH 031/304] ql: get rid of duplicate queues --- ql-fsm/src/implementation/fsm.rs | 24 +++- ql-fsm/src/implementation/mod.rs | 84 ++++++------- ql-fsm/src/session/mod.rs | 108 +++++++++-------- ql-fsm/src/session/state.rs | 3 +- ql-fsm/src/session/tests.rs | 195 ++++++++++++++++++------------- 5 files changed, 233 insertions(+), 181 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 2b436723..94d226cd 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -60,8 +60,16 @@ pub fn receive( let envelope = wire::decrypt_record(crypto, &header, &mut encrypted, &session_key)?; // TODO: this seems unnecessary to me? let envelope = wire::SessionEnvelope::from_wire(&envelope)?; - fsm.session.receive(fsm.state.now.instant, envelope); - super::drain_session_events(fsm); + let mut session_closed = false; + fsm.session.receive(fsm.state.now.instant, envelope, { + let session_events = &mut fsm.state.session_events; + |event| { + session_closed |= super::forward_session_event(session_events, event); + } + }); + if session_closed { + super::apply_session_closed(fsm); + } } } @@ -71,8 +79,16 @@ pub fn receive( pub fn on_timer(fsm: &mut QlFsm) { super::handle_timer(fsm); if super::peer_session(fsm).is_some() { - fsm.session.on_timer(fsm.state.now.instant); - super::drain_session_events(fsm); + let mut session_closed = false; + fsm.session.on_timer(fsm.state.now.instant, { + let session_events = &mut fsm.state.session_events; + |event| { + session_closed |= super::forward_session_event(session_events, event); + } + }); + if session_closed { + super::apply_session_closed(fsm); + } } } diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index 002e016a..54bd727a 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -2,7 +2,7 @@ mod fsm; mod handshake; mod peer; -use std::time::Duration; +use std::{collections::VecDeque, time::Duration}; pub use fsm::*; pub use handshake::*; @@ -95,51 +95,51 @@ fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::CloseCode) { })); } -fn drain_session_events(fsm: &mut QlFsm) { - while let Some(event) = fsm.session.take_next_event() { - match event { - SessionEvent::Opened(stream_id) => { - fsm.state - .session_events - .push_back(QlSessionEvent::Opened(stream_id)); - } - SessionEvent::Readable(stream_id) => { - fsm.state - .session_events - .push_back(QlSessionEvent::Readable(stream_id)); - } - SessionEvent::Finished(stream_id) => fsm - .state - .session_events - .push_back(QlSessionEvent::Finished(stream_id)), - SessionEvent::Closed(frame) => fsm - .state - .session_events - .push_back(QlSessionEvent::Closed(frame)), - SessionEvent::WritableClosed(stream_id) => { - fsm.state - .session_events - .push_back(QlSessionEvent::WritableClosed(stream_id)); - } - SessionEvent::SessionClosed(close) => { - fsm.state - .session_events - .push_back(QlSessionEvent::SessionClosed(close.clone())); - if let Some(entry) = fsm.peer.as_mut() { - if matches!( - entry.session, - crate::state::ConnectionState::Connected { .. } - ) { - entry.session = crate::state::ConnectionState::Disconnected; - emit_peer_status(fsm); - } - } - reset_session(fsm); - } +fn forward_session_event( + session_events: &mut VecDeque, + event: SessionEvent, +) -> bool { + match event { + SessionEvent::Opened(stream_id) => { + session_events.push_back(QlSessionEvent::Opened(stream_id)); + false + } + SessionEvent::Readable(stream_id) => { + session_events.push_back(QlSessionEvent::Readable(stream_id)); + false + } + SessionEvent::Finished(stream_id) => { + session_events.push_back(QlSessionEvent::Finished(stream_id)); + false + } + SessionEvent::Closed(frame) => { + session_events.push_back(QlSessionEvent::Closed(frame)); + false + } + SessionEvent::WritableClosed(stream_id) => { + session_events.push_back(QlSessionEvent::WritableClosed(stream_id)); + false + } + SessionEvent::SessionClosed(close) => { + session_events.push_back(QlSessionEvent::SessionClosed(close)); + true } } } +fn apply_session_closed(fsm: &mut QlFsm) { + if let Some(entry) = fsm.peer.as_mut() { + if matches!( + entry.session, + crate::state::ConnectionState::Connected { .. } + ) { + entry.session = crate::state::ConnectionState::Disconnected; + emit_peer_status(fsm); + } + } + reset_session(fsm); +} + fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { now_secs.saturating_add(duration_to_secs(duration)) } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index f78bcf96..7b07f0de 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -130,7 +130,6 @@ impl SessionFsm { pending_control: Default::default(), streams: Default::default(), next_stream_index: 0, - events: Default::default(), }, } } @@ -256,7 +255,12 @@ impl SessionFsm { Ok(()) } - pub fn receive(&mut self, now: Instant, envelope: SessionEnvelope) { + pub fn receive( + &mut self, + now: Instant, + envelope: SessionEnvelope, + mut emit: impl FnMut(SessionEvent), + ) { self.state.now = now; self.collect_timeouts(); self.process_ack(envelope.ack); @@ -276,9 +280,12 @@ impl SessionFsm { return; } if !self.state.rx_ring.accepts_seq(seq) { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); + self.fail_session( + SessionCloseBody { + code: CloseCode::PROTOCOL, + }, + &mut emit, + ); return; } @@ -289,14 +296,12 @@ impl SessionFsm { SessionBody::Close(close) => { self.state.session_state = SessionState::Closed; self.clear_streams(); - self.state - .events - .push_back(SessionEvent::SessionClosed(close)); + emit(SessionEvent::SessionClosed(close)); Ok(()) } - SessionBody::Stream(frame) => self.handle_stream_frame(frame), + SessionBody::Stream(frame) => self.handle_stream_frame(frame, &mut emit), SessionBody::StreamClose(frame) => { - self.handle_stream_close(frame); + self.handle_stream_close(frame, &mut emit); Ok(()) } }; @@ -422,7 +427,7 @@ impl SessionFsm { entry.state = TxState::Pending; } - pub fn on_timer(&mut self, now: Instant) { + pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { self.state.now = now; self.collect_timeouts(); if self.state.session_state == SessionState::Closed { @@ -436,9 +441,12 @@ impl SessionFsm { if !self.config.peer_timeout.is_zero() && self.state.last_inbound_at + self.config.peer_timeout <= self.state.now { - self.fail_session(SessionCloseBody { - code: CloseCode::TIMEOUT, - }); + self.fail_session( + SessionCloseBody { + code: CloseCode::TIMEOUT, + }, + &mut emit, + ); return; } if !self.config.keepalive_interval.is_zero() @@ -481,10 +489,6 @@ impl SessionFsm { .min() } - pub fn take_next_event(&mut self) -> Option { - self.state.events.pop_front() - } - pub fn has_pending_stream_work(&self) -> bool { self.state.streams.values().any(|stream| { stream.pending_close.is_some() @@ -712,7 +716,11 @@ impl SessionFsm { } } - fn handle_stream_frame(&mut self, frame: StreamChunk) -> Result<(), RejectNoAck> { + fn handle_stream_frame( + &mut self, + frame: StreamChunk, + emit: &mut impl FnMut(SessionEvent), + ) -> Result<(), RejectNoAck> { let StreamChunk { stream_id, chunk_seq, @@ -724,15 +732,18 @@ impl SessionFsm { Entry::Occupied(entry) => entry.into_mut(), Entry::Vacant(entry) => { if !remote_namespace.matches(stream_id) { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); + self.fail_session( + SessionCloseBody { + code: CloseCode::PROTOCOL, + }, + emit, + ); return Ok(()); } if chunk_seq != 0 { return Err(RejectNoAck); } - self.state.events.push_back(SessionEvent::Opened(stream_id)); + emit(SessionEvent::Opened(stream_id)); entry.insert(StreamState::new(StreamRole::Responder)) } }; @@ -742,9 +753,12 @@ impl SessionFsm { if chunk_seq < stream.recv_window.next_chunk_seq() { return Ok(()); } - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); + self.fail_session( + SessionCloseBody { + code: CloseCode::PROTOCOL, + }, + emit, + ); return Ok(()); } InboundState::Discarding => return Ok(()), @@ -757,14 +771,10 @@ impl SessionFsm { RecvInsertOutcome::Inserted => { Self::drain_recv_window(stream); if !was_readable && !stream.recv_buf.is_empty() { - self.state - .events - .push_back(SessionEvent::Readable(stream_id)); + emit(SessionEvent::Readable(stream_id)); } if matches!(stream.inbound_state, InboundState::Finished) { - self.state - .events - .push_back(SessionEvent::Finished(stream_id)); + emit(SessionEvent::Finished(stream_id)); } self.try_reap_stream(stream_id); Ok(()) @@ -772,19 +782,25 @@ impl SessionFsm { RecvInsertOutcome::Duplicate => Ok(()), RecvInsertOutcome::RejectNoAck => Err(RejectNoAck), RecvInsertOutcome::Conflict => { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); + self.fail_session( + SessionCloseBody { + code: CloseCode::PROTOCOL, + }, + emit, + ); Ok(()) } } } - fn handle_stream_close(&mut self, frame: StreamClose) { + fn handle_stream_close(&mut self, frame: StreamClose, emit: &mut impl FnMut(SessionEvent)) { let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { - self.fail_session(SessionCloseBody { - code: CloseCode::PROTOCOL, - }); + self.fail_session( + SessionCloseBody { + code: CloseCode::PROTOCOL, + }, + emit, + ); return; }; @@ -797,9 +813,7 @@ impl SessionFsm { stream.inbound_state = InboundState::Closed(frame.clone()); stream.recv_buf.clear(); stream.recv_window.clear(); - self.state - .events - .push_back(SessionEvent::Closed(frame.clone())); + emit(SessionEvent::Closed(frame.clone())); } if Self::target_affects_outbound(stream.role, frame.target) && !matches!(stream.outbound_state, OutboundState::Closed) @@ -807,9 +821,7 @@ impl SessionFsm { stream.outbound_state = OutboundState::Closed; stream.send_buf.clear(); stream.pending_close = None; - self.state - .events - .push_back(SessionEvent::WritableClosed(frame.stream_id)); + emit(SessionEvent::WritableClosed(frame.stream_id)); } self.try_reap_stream(frame.stream_id); } @@ -950,7 +962,7 @@ impl SessionFsm { } } - fn fail_session(&mut self, close: SessionCloseBody) { + fn fail_session(&mut self, close: SessionCloseBody, emit: &mut impl FnMut(SessionEvent)) { if self.state.session_state == SessionState::Closed { return; } @@ -959,9 +971,7 @@ impl SessionFsm { self.clear_streams(); self.state.pending_control = Default::default(); self.state.pending_control.close = Some(close.clone()); - self.state - .events - .push_back(SessionEvent::SessionClosed(close)); + emit(SessionEvent::SessionClosed(close)); } fn clear_streams(&mut self) { diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index f9708806..99ebb3ee 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -5,7 +5,7 @@ use ql_wire::{ CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamClose, StreamId, }; -use super::{ring::SeqRing, stream_window::StreamRecvWindow, SessionEvent, SessionState}; +use super::{ring::SeqRing, stream_window::StreamRecvWindow, SessionState}; pub const SESSION_WINDOW_CAPACITY: usize = 64; @@ -138,7 +138,6 @@ pub struct SessionFsmState { /// scheduling, so we do not need a separate ready queue pub streams: IndexMap, pub next_stream_index: usize, - pub events: VecDeque, } impl SessionFsmState { diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 340babf6..bd7423b3 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -5,7 +5,7 @@ use ql_wire::{ StreamChunk, StreamClose, }; -use super::{SessionFsm, SessionFsmConfig, SessionState}; +use super::{SessionEvent, SessionFsm, SessionFsmConfig, SessionState}; fn read_stream_all(fsm: &mut SessionFsm, stream_id: ql_wire::StreamId) -> Vec { let mut out = Vec::new(); @@ -42,6 +42,22 @@ fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option Some(envelope) } +fn receive_events( + fsm: &mut SessionFsm, + now: Instant, + envelope: SessionEnvelope, +) -> Vec { + let mut events = Vec::new(); + fsm.receive(now, envelope, |event| events.push(event)); + events +} + +fn on_timer_events(fsm: &mut SessionFsm, now: Instant) -> Vec { + let mut events = Vec::new(); + fsm.on_timer(now, |event| events.push(event)); + events +} + #[test] fn outbound_session_seq_increments_monotonically() { let now = Instant::now(); @@ -51,7 +67,8 @@ fn outbound_session_seq_increments_monotonically() { fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); let first = next_outbound(&mut fsm, now).unwrap(); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(1), ack( 1, @@ -80,7 +97,8 @@ fn inbound_ack_removes_acked_tx_entries() { assert_eq!(first.seq, SessionSeq(1)); assert!(fsm.state.tx_ring.contains_key(&SessionSeq(1))); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(1), ack( 1, @@ -101,7 +119,8 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { let stream_id_a = ql_wire::StreamId(super::StreamNamespace::High.bit() | 1); let stream_id_b = ql_wire::StreamId(super::StreamNamespace::High.bit() | 2); - fsm.receive( + let _ = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(2), @@ -124,7 +143,8 @@ fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { } ); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(1), SessionEnvelope { seq: SessionSeq(1), @@ -172,7 +192,7 @@ fn repeated_outbound_messages_keep_reporting_latest_receive_ack() { let stream_id_a = fsm.open_stream().unwrap(); let stream_id_b = fsm.open_stream().unwrap(); - fsm.receive(now, ack(1, SessionAck::EMPTY)); + let _ = receive_events(&mut fsm, now, ack(1, SessionAck::EMPTY)); fsm.write_stream(stream_id_a, b"one".to_vec()).unwrap(); let first = next_outbound(&mut fsm, now).unwrap(); @@ -200,7 +220,8 @@ fn local_inbound_close_ignores_late_remote_bytes() { ) .unwrap(); - fsm.receive( + let events = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(1), @@ -216,7 +237,7 @@ fn local_inbound_close_ignores_late_remote_bytes() { assert_eq!(fsm.state.session_state, SessionState::Open); assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); - assert!(fsm.take_next_event().is_none()); + assert!(events.is_empty()); } #[test] @@ -225,7 +246,8 @@ fn missing_stream_nonzero_chunk_is_ignored_until_chunk_zero_arrives() { let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 7); - fsm.receive( + let events = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(1), @@ -240,10 +262,11 @@ fn missing_stream_nonzero_chunk_is_ignored_until_chunk_zero_arrives() { ); assert_eq!(fsm.state.session_state, SessionState::Open); - assert!(fsm.take_next_event().is_none()); + assert!(events.is_empty()); assert!(!fsm.state.streams.contains_key(&stream_id)); - fsm.receive( + let events = receive_events( + &mut fsm, now + Duration::from_millis(1), SessionEnvelope { seq: SessionSeq(2), @@ -258,12 +281,11 @@ fn missing_stream_nonzero_chunk_is_ignored_until_chunk_zero_arrives() { ); assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Opened(stream_id)) - ); - assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Readable(stream_id)) + events, + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Readable(stream_id) + ] ); assert_eq!(read_stream_all(&mut fsm, stream_id), b"a".to_vec()); } @@ -274,7 +296,8 @@ fn out_of_order_chunks_within_recv_window_are_buffered_and_drained() { let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 8); - fsm.receive( + let mut events = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(1), @@ -287,7 +310,8 @@ fn out_of_order_chunks_within_recv_window_are_buffered_and_drained() { }), }, ); - fsm.receive( + events.extend(receive_events( + &mut fsm, now + Duration::from_millis(1), SessionEnvelope { seq: SessionSeq(2), @@ -299,8 +323,9 @@ fn out_of_order_chunks_within_recv_window_are_buffered_and_drained() { fin: false, }), }, - ); - fsm.receive( + )); + events.extend(receive_events( + &mut fsm, now + Duration::from_millis(2), SessionEnvelope { seq: SessionSeq(3), @@ -312,18 +337,16 @@ fn out_of_order_chunks_within_recv_window_are_buffered_and_drained() { fin: false, }), }, - ); + )); assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Opened(stream_id)) - ); - assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Readable(stream_id)) + events, + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Readable(stream_id) + ] ); assert_eq!(read_stream_all(&mut fsm, stream_id), b"abc".to_vec()); - assert!(fsm.take_next_event().is_none()); } #[test] @@ -338,7 +361,8 @@ fn chunk_past_recv_window_is_dropped_without_session_ack() { ); let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 10); - fsm.receive( + let _ = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(1), @@ -361,7 +385,8 @@ fn chunk_past_recv_window_is_dropped_without_session_ack() { } ); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(1), SessionEnvelope { seq: SessionSeq(2), @@ -413,7 +438,8 @@ fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { ); assert!(next_outbound(&mut fsm, now + Duration::from_millis(1)).is_none()); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(2), ack( 1, @@ -442,7 +468,8 @@ fn stream_is_reaped_after_terminal_state_and_last_stream_ack() { let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 13); - fsm.receive( + let events = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(1), @@ -457,18 +484,14 @@ fn stream_is_reaped_after_terminal_state_and_last_stream_ack() { ); assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Opened(stream_id)) - ); - assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Readable(stream_id)) + events, + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Readable(stream_id), + SessionEvent::Finished(stream_id), + ] ); assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); - assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Finished(stream_id)) - ); assert!(fsm.state.streams.contains_key(&stream_id)); fsm.finish_stream(stream_id).unwrap(); @@ -484,7 +507,8 @@ fn stream_is_reaped_after_terminal_state_and_last_stream_ack() { ); assert!(fsm.state.streams.contains_key(&stream_id)); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(2), ack( 2, @@ -514,21 +538,17 @@ fn replayed_remote_open_does_not_recreate_reaped_stream() { }), }; - fsm.receive(now, opener.clone()); + let events = receive_events(&mut fsm, now, opener.clone()); assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Opened(stream_id)) - ); - assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Readable(stream_id)) + events, + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Readable(stream_id), + SessionEvent::Finished(stream_id), + ] ); assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); - assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Finished(stream_id)) - ); fsm.finish_stream(stream_id).unwrap(); let fin = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); @@ -542,7 +562,8 @@ fn replayed_remote_open_does_not_recreate_reaped_stream() { }) ); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(2), ack( 2, @@ -555,11 +576,11 @@ fn replayed_remote_open_does_not_recreate_reaped_stream() { assert!(!fsm.state.streams.contains_key(&stream_id)); - fsm.receive(now + Duration::from_millis(3), opener); + let events = receive_events(&mut fsm, now + Duration::from_millis(3), opener); assert_eq!(fsm.state.session_state, SessionState::Open); assert!(!fsm.state.streams.contains_key(&stream_id)); - assert!(fsm.take_next_event().is_none()); + assert!(events.is_empty()); } #[test] @@ -574,7 +595,8 @@ fn duplicate_committed_data_is_not_redelivered() { fin: false, }); - fsm.receive( + let _ = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(1), @@ -582,11 +604,10 @@ fn duplicate_committed_data_is_not_redelivered() { body: body.clone(), }, ); - let _ = fsm.take_next_event(); - let _ = fsm.take_next_event(); let _ = read_stream_all(&mut fsm, stream_id); - fsm.receive( + let events = receive_events( + &mut fsm, now + Duration::from_millis(1), SessionEnvelope { seq: SessionSeq(2), @@ -595,7 +616,7 @@ fn duplicate_committed_data_is_not_redelivered() { }, ); - assert!(fsm.take_next_event().is_none()); + assert!(events.is_empty()); assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); } @@ -624,7 +645,8 @@ fn next_outbound_round_robins_across_ready_streams() { }) .collect(); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(1), ack( 1, @@ -664,7 +686,7 @@ fn idle_session_sends_ping_after_keepalive_interval() { assert_eq!(fsm.next_deadline(), Some(now + Duration::from_millis(50))); assert!(next_outbound(&mut fsm, now + Duration::from_millis(49)).is_none()); - fsm.on_timer(now + Duration::from_millis(50)); + assert!(on_timer_events(&mut fsm, now + Duration::from_millis(50)).is_empty()); let envelope = next_outbound(&mut fsm, now + Duration::from_millis(50)).unwrap(); assert!(matches!(envelope.body, SessionBody::Ping(PingBody))); @@ -675,12 +697,16 @@ fn receive_ping_schedules_ack_without_ping_pong() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - fsm.receive(now, ping(1, SessionAck::EMPTY)); + let _ = receive_events(&mut fsm, now, ping(1, SessionAck::EMPTY)); let ack_envelope = next_outbound(&mut fsm, now + Duration::from_millis(10)).unwrap(); assert_eq!(ack_envelope.body, SessionBody::Ack); - fsm.receive(now + Duration::from_millis(20), ack(2, SessionAck::EMPTY)); + let _ = receive_events( + &mut fsm, + now + Duration::from_millis(20), + ack(2, SessionAck::EMPTY), + ); assert!(next_outbound(&mut fsm, now + Duration::from_millis(30)).is_none()); } @@ -695,7 +721,8 @@ fn tx_selective_ack_keeps_front_gap_pinned() { let _ = next_outbound(&mut fsm, now + Duration::from_millis(byte as u64)).unwrap(); } - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(100), ack( 1, @@ -713,7 +740,8 @@ fn tx_selective_ack_keeps_front_gap_pinned() { fsm.write_stream(extra_stream, b"x".to_vec()).unwrap(); assert!(next_outbound(&mut fsm, now + Duration::from_millis(101)).is_none()); - fsm.receive( + let _ = receive_events( + &mut fsm, now + Duration::from_millis(102), ack( 2, @@ -737,12 +765,12 @@ fn rx_seq_past_window_closes_protocol() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - fsm.receive(now, ping(65, SessionAck::EMPTY)); + let events = receive_events(&mut fsm, now, ping(65, SessionAck::EMPTY)); assert_eq!(fsm.state.session_state, SessionState::Closed); assert!(matches!( - fsm.take_next_event(), - Some(super::SessionEvent::SessionClosed(close)) if close.code == CloseCode::PROTOCOL + events.as_slice(), + [SessionEvent::SessionClosed(close)] if close.code == CloseCode::PROTOCOL )); } @@ -758,7 +786,8 @@ fn duplicate_old_packet_seq_is_ignored() { fin: false, }); - fsm.receive( + let _ = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(1), @@ -766,11 +795,10 @@ fn duplicate_old_packet_seq_is_ignored() { body: body.clone(), }, ); - let _ = fsm.take_next_event(); - let _ = fsm.take_next_event(); let _ = read_stream_all(&mut fsm, stream_id); - fsm.receive( + let events = receive_events( + &mut fsm, now + Duration::from_millis(1), SessionEnvelope { seq: SessionSeq(1), @@ -779,7 +807,7 @@ fn duplicate_old_packet_seq_is_ignored() { }, ); - assert!(fsm.take_next_event().is_none()); + assert!(events.is_empty()); assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); } @@ -795,7 +823,8 @@ fn retransmitted_stream_close_is_idempotent() { payload: Vec::new(), }; - fsm.receive( + let events = receive_events( + &mut fsm, now, SessionEnvelope { seq: SessionSeq(1), @@ -804,13 +833,11 @@ fn retransmitted_stream_close_is_idempotent() { }, ); - assert_eq!( - fsm.take_next_event(), - Some(super::SessionEvent::Closed(frame.clone())) - ); + assert_eq!(events, vec![SessionEvent::Closed(frame.clone())]); assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); - fsm.receive( + let events = receive_events( + &mut fsm, now + Duration::from_millis(1), SessionEnvelope { seq: SessionSeq(2), @@ -819,6 +846,6 @@ fn retransmitted_stream_close_is_idempotent() { }, ); - assert!(fsm.take_next_event().is_none()); + assert!(events.is_empty()); assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); } From 3a0202f869ddfc13dc2aa30c3c124bd98bb5275f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 25 Mar 2026 12:44:01 -0400 Subject: [PATCH 032/304] ql: more efficient take_next_write --- ql-fsm/src/session/mod.rs | 110 ++++++++++++++++++++++---------------- 1 file changed, 63 insertions(+), 47 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 7b07f0de..7b664f16 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -9,8 +9,8 @@ use std::time::{Duration, Instant}; use indexmap::map::Entry; use ql_wire::{ - CloseCode, CloseTarget, PingBody, SessionBody, SessionCloseBody, SessionEnvelope, SessionSeq, - StreamChunk, StreamClose, StreamId, XID, + CloseCode, CloseTarget, PingBody, SessionAck, SessionBody, SessionCloseBody, SessionEnvelope, + SessionSeq, StreamChunk, StreamClose, StreamId, XID, }; use self::{ @@ -326,52 +326,9 @@ impl SessionFsm { self.state.now = now; self.collect_timeouts(); let ack = self.state.current_ack(); - loop { - let Some(seq) = - self.state.tx_ring.iter().find_map(|(seq, entry)| { - matches!(entry.state, TxState::Pending).then_some(seq) - }) - else { - break; - }; - let body = self - .state - .tx_ring - .get(&seq) - .map(|entry| entry.pending.body.clone())?; - if !self.should_retry_body(&body) { - let _ = self.state.tx_ring.remove(&seq); - self.state - .tx_ring - .advance_empty_front_until(self.state.next_seq); - continue; - } - - let entry = self.state.tx_ring.get_mut(&seq)?; - entry.state = TxState::Issued; - return Some(SessionEnvelope { seq, ack, body }); - } - - if !self.state.tx_ring.accepts_seq(self.state.next_seq) { - return None; - } - - let pending = self.next_pending_body()?; - let seq = self.state.next_seq; - self.state.next_seq = SessionSeq(seq.0 + 1); - let body = pending.body.clone(); - self.state - .tx_ring - .insert( - seq, - TxEntry { - pending, - state: TxState::Issued, - }, - ) - .unwrap(); - Some(SessionEnvelope { seq, ack, body }) + self.take_pending_retransmit(ack) + .or_else(|| self.take_fresh_write(ack)) } pub fn confirm_write(&mut self, now: Instant, seq: SessionSeq) { @@ -497,6 +454,65 @@ impl SessionFsm { }) } + fn take_pending_retransmit(&mut self, ack: SessionAck) -> Option { + let base_seq = self.state.tx_ring.base_seq().0; + let next_seq = self.state.next_seq.0; + + for seq in (base_seq..next_seq).map(SessionSeq) { + let should_retry = match self.state.tx_ring.get(&seq) { + Some(entry) if matches!(entry.state, TxState::Pending) => { + self.should_retry_body(&entry.pending.body) + } + _ => continue, + }; + + if !should_retry { + let _ = self.state.tx_ring.remove(&seq); + continue; + } + + self.state + .tx_ring + .advance_empty_front_until(self.state.next_seq); + let entry = self.state.tx_ring.get_mut(&seq).unwrap(); + entry.state = TxState::Issued; + return Some(SessionEnvelope { + seq, + ack, + body: entry.pending.body.clone(), + }); + } + + self.state + .tx_ring + .advance_empty_front_until(self.state.next_seq); + + None + } + + fn take_fresh_write(&mut self, ack: SessionAck) -> Option { + if !self.state.tx_ring.accepts_seq(self.state.next_seq) { + return None; + } + + let pending = self.next_pending_body()?; + let seq = self.state.next_seq; + self.state.next_seq = SessionSeq(seq.0 + 1); + let body = pending.body.clone(); + self.state + .tx_ring + .insert( + seq, + TxEntry { + pending, + state: TxState::Issued, + }, + ) + .unwrap(); + + Some(SessionEnvelope { seq, ack, body }) + } + fn next_pending_body(&mut self) -> Option { if let Some(close) = self.state.pending_control.close.take() { return Some(PendingSessionBody { From bd82782b4643fdde25d751c13e23aa9d28bc8c9f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 25 Mar 2026 13:26:01 -0400 Subject: [PATCH 033/304] ql: avoid clone for take_next_write --- ql-fsm/src/implementation/fsm.rs | 16 ++--- ql-fsm/src/session/mod.rs | 37 +++++----- ql-fsm/src/session/ring.rs | 2 +- ql-fsm/src/session/tests.rs | 14 +++- ql-wire/src/encrypted/mod.rs | 117 +++++++++++++++++++++++-------- 5 files changed, 125 insertions(+), 61 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 94d226cd..3c25e442 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -124,22 +124,22 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Option { - self.state.now = now; - self.collect_timeouts(); - let ack = self.state.current_ack(); - - self.take_pending_retransmit(ack) - .or_else(|| self.take_fresh_write(ack)) - } - pub fn confirm_write(&mut self, now: Instant, seq: SessionSeq) { self.state.now = now; let Some((retransmit, should_clear_ack)) = self.state.tx_ring.get(&seq).map(|entry| { @@ -454,7 +445,21 @@ impl SessionFsm { }) } - fn take_pending_retransmit(&mut self, ack: SessionAck) -> Option { + pub fn take_next_write( + &mut self, + now: Instant, + ) -> Option<(SessionSeq, SessionAck, &SessionBody)> { + self.state.now = now; + self.collect_timeouts(); + let ack = self.state.current_ack(); + let seq = self + .take_pending_retransmit() + .or_else(|| self.take_fresh_write())?; + let entry = self.state.tx_ring.get(&seq).unwrap(); + Some((seq, ack, &entry.pending.body)) + } + + fn take_pending_retransmit(&mut self) -> Option { let base_seq = self.state.tx_ring.base_seq().0; let next_seq = self.state.next_seq.0; @@ -476,11 +481,7 @@ impl SessionFsm { .advance_empty_front_until(self.state.next_seq); let entry = self.state.tx_ring.get_mut(&seq).unwrap(); entry.state = TxState::Issued; - return Some(SessionEnvelope { - seq, - ack, - body: entry.pending.body.clone(), - }); + return Some(seq); } self.state @@ -490,7 +491,7 @@ impl SessionFsm { None } - fn take_fresh_write(&mut self, ack: SessionAck) -> Option { + fn take_fresh_write(&mut self) -> Option { if !self.state.tx_ring.accepts_seq(self.state.next_seq) { return None; } @@ -498,7 +499,6 @@ impl SessionFsm { let pending = self.next_pending_body()?; let seq = self.state.next_seq; self.state.next_seq = SessionSeq(seq.0 + 1); - let body = pending.body.clone(); self.state .tx_ring .insert( @@ -509,8 +509,7 @@ impl SessionFsm { }, ) .unwrap(); - - Some(SessionEnvelope { seq, ack, body }) + Some(seq) } fn next_pending_body(&mut self) -> Option { diff --git a/ql-fsm/src/session/ring.rs b/ql-fsm/src/session/ring.rs index b6aad72b..4fa6427e 100644 --- a/ql-fsm/src/session/ring.rs +++ b/ql-fsm/src/session/ring.rs @@ -18,6 +18,7 @@ pub struct SeqRing { impl SeqRing { pub fn new(base_seq: SessionSeq) -> Self { + debug_assert!(N <= 64); Self { base_seq, head: 0, @@ -91,7 +92,6 @@ impl SeqRing { } pub fn bitmap(&self) -> u64 { - debug_assert!(N <= 64); let mut bitmap = 0u64; for offset in 0..N { let index = self.index_for_offset(offset); diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index bd7423b3..00008fc5 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -37,8 +37,18 @@ fn ping(seq: u64, ack: SessionAck) -> SessionEnvelope { } fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option { - let envelope = fsm.take_next_write(now)?; - fsm.confirm_write(now, envelope.seq); + let (seq, envelope) = { + let (seq, ack, body) = fsm.take_next_write(now)?; + ( + seq, + SessionEnvelope { + seq, + ack, + body: body.clone(), + }, + ) + }; + fsm.confirm_write(now, seq); Some(envelope) } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index db1bf09a..acaf02e3 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,3 +1,5 @@ +use std::mem::size_of; + use zerocopy::{ byte_slice::{ByteSlice, ByteSliceMut}, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, Unaligned, @@ -27,13 +29,6 @@ pub struct SessionSeq(pub u64); #[repr(transparent)] pub struct StreamId(pub u32); -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct SessionEnvelope { - pub seq: SessionSeq, - pub ack: SessionAck, - pub body: SessionBody, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SessionAck { pub base: SessionSeq, @@ -47,6 +42,13 @@ impl SessionAck { }; } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionEnvelope { + pub seq: SessionSeq, + pub ack: SessionAck, + pub body: SessionBody, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionBody { Ack, @@ -112,28 +114,7 @@ impl SessionEnvelope { } pub fn encode(&self) -> Vec { - let mut out = Vec::new(); - let kind = match &self.body { - SessionBody::Ack => SessionBodyKind::Ack, - SessionBody::Ping(_) => SessionBodyKind::Ping, - SessionBody::Stream(_) => SessionBodyKind::Stream, - SessionBody::StreamClose(_) => SessionBodyKind::StreamClose, - SessionBody::Close(_) => SessionBodyKind::Close, - }; - let header = SessionEnvelopeHeaderWire { - seq: U64Le::new(self.seq.0), - ack_base: U64Le::new(self.ack.base.0), - ack_bitmap: U64Le::new(self.ack.bitmap), - kind: kind as u8, - }; - push_value(&mut out, &header); - match &self.body { - SessionBody::Ack | SessionBody::Ping(_) => {} - SessionBody::Stream(frame) => frame.encode_into(&mut out), - SessionBody::StreamClose(frame) => frame.encode_into(&mut out), - SessionBody::Close(body) => body.encode_into(&mut out), - } - out + encode_session_envelope(self.seq, self.ack, &self.body) } pub fn decode(bytes: &[u8]) -> Result { @@ -147,10 +128,30 @@ pub fn encrypt_record( session_key: &SessionKey, body: &SessionEnvelope, nonce: Nonce, +) -> QlRecord { + encrypt_record_parts( + crypto, + header, + session_key, + body.seq, + body.ack, + &body.body, + nonce, + ) +} + +pub fn encrypt_record_parts( + crypto: &impl QlCrypto, + header: QlHeader, + session_key: &SessionKey, + seq: SessionSeq, + ack: SessionAck, + body: &SessionBody, + nonce: Nonce, ) -> QlRecord { let aad = header.aad(); - let body_bytes = body.encode(); - let encrypted = EncryptedMessage::encrypt(crypto, session_key, body_bytes, &aad, nonce); + let body = encode_session_envelope(seq, ack, body); + let encrypted = EncryptedMessage::encrypt(crypto, session_key, body, &aad, nonce); QlRecord { header, payload: QlPayload::Session(encrypted), @@ -177,6 +178,60 @@ pub struct SessionEnvelopeHeaderWire { pub kind: u8, } +fn encode_session_envelope(seq: SessionSeq, ack: SessionAck, body: &SessionBody) -> Vec { + let expected_len = size_of::() + session_body_encoded_len(body); + let mut out = Vec::with_capacity(expected_len); + let initial_capacity = out.capacity(); + encode_session_envelope_into(seq, ack, body, &mut out); + debug_assert_eq!(out.len(), expected_len); + debug_assert_eq!(out.capacity(), initial_capacity); + out +} + +fn encode_session_envelope_into( + seq: SessionSeq, + ack: SessionAck, + body: &SessionBody, + out: &mut Vec, +) { + let header = SessionEnvelopeHeaderWire { + seq: U64Le::new(seq.0), + ack_base: U64Le::new(ack.base.0), + ack_bitmap: U64Le::new(ack.bitmap), + kind: session_body_kind_for(body) as u8, + }; + push_value(out, &header); + encode_session_body_into(body, out); +} + +fn session_body_kind_for(body: &SessionBody) -> SessionBodyKind { + match body { + SessionBody::Ack => SessionBodyKind::Ack, + SessionBody::Ping(_) => SessionBodyKind::Ping, + SessionBody::Stream(_) => SessionBodyKind::Stream, + SessionBody::StreamClose(_) => SessionBodyKind::StreamClose, + SessionBody::Close(_) => SessionBodyKind::Close, + } +} + +fn session_body_encoded_len(body: &SessionBody) -> usize { + match body { + SessionBody::Ack | SessionBody::Ping(_) => 0, + SessionBody::Stream(frame) => size_of::() + frame.bytes.len(), + SessionBody::StreamClose(frame) => size_of::() + frame.payload.len(), + SessionBody::Close(_) => size_of::(), + } +} + +fn encode_session_body_into(body: &SessionBody, out: &mut Vec) { + match body { + SessionBody::Ack | SessionBody::Ping(_) => {} + SessionBody::Stream(frame) => frame.encode_into(out), + SessionBody::StreamClose(frame) => frame.encode_into(out), + SessionBody::Close(body) => body.encode_into(out), + } +} + fn session_body_kind(wire: &SessionEnvelopeWire) -> Result { crate::codec::read_byte(wire.kind) } From f5ffcc50f17f0ce29a75478f36ccbc1ab38e11d3 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 26 Mar 2026 11:18:26 -0400 Subject: [PATCH 034/304] ql-design: add encrypted record size --- QL_V2.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/QL_V2.md b/QL_V2.md index 6e868c92..f3e2f7a5 100644 --- a/QL_V2.md +++ b/QL_V2.md @@ -71,6 +71,29 @@ The record sizes shows the protocol's intended split between setup and steady-st | `session stream close` | 94 bytes | | `session close` | 89 bytes | +Any encrypted record has the same outer wire shape: + +| Component | Size | +| --- | ---: | +| protocol version | 1 byte | +| record kind | 1 byte | +| sender XID | 16 bytes | +| recipient XID | 16 bytes | +| AEAD nonce | 12 bytes | +| AEAD auth tag | 16 bytes | +| ciphertext | N bytes | + +That gives a 62-byte minimum for any encrypted record before counting the encrypted plaintext. The AEAD keeps the ciphertext the same length as the plaintext, so after that fixed 62-byte overhead, each additional plaintext byte becomes one additional ciphertext byte. + +For session records, the encrypted plaintext always starts with a 25-byte session envelope: + +| Session envelope field | Size | +| --- | ---: | +| `seq` | 8 bytes | +| `ack.base` | 8 bytes | +| `ack.bitmap` | 8 bytes | +| session body kind discriminator | 1 byte | + ### 7. Shared core state machine QLv2 should have one core implementation of pairing, handshake, session, retransmission, and stream behavior. Platforms should integrate that shared state machine instead of rebuilding subtle protocol logic independently. From c69790eb92c24d72db778140685306b2db3621bb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 27 Mar 2026 18:43:01 -0400 Subject: [PATCH 035/304] ql-wire: stream acks and multiple frames per record --- ql-wire/src/codec.rs | 8 - ql-wire/src/encrypted/close.rs | 22 +- ql-wire/src/encrypted/mod.rs | 354 +++++++++++++------------ ql-wire/src/encrypted/ping.rs | 2 - ql-wire/src/encrypted/stream_ack.rs | 136 ++++++++++ ql-wire/src/encrypted/stream_chunk.rs | 59 ----- ql-wire/src/encrypted/stream_close.rs | 45 ++-- ql-wire/src/encrypted/stream_data.rs | 82 ++++++ ql-wire/src/encrypted/stream_window.rs | 50 ++++ ql-wire/src/tests.rs | 318 ++++++++++++++++++---- 10 files changed, 758 insertions(+), 318 deletions(-) delete mode 100644 ql-wire/src/encrypted/ping.rs create mode 100644 ql-wire/src/encrypted/stream_ack.rs delete mode 100644 ql-wire/src/encrypted/stream_chunk.rs create mode 100644 ql-wire/src/encrypted/stream_data.rs create mode 100644 ql-wire/src/encrypted/stream_window.rs diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index 63eb81d9..5e3e16fd 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -48,14 +48,6 @@ where Ref::<_, T>::from_bytes(bytes).map_err(|_| WireError::InvalidPayload) } -pub fn ensure_empty(bytes: &[u8]) -> Result<(), WireError> { - if bytes.is_empty() { - Ok(()) - } else { - Err(WireError::InvalidPayload) - } -} - pub fn append_field(out: &mut Vec, label: &[u8], value: &[u8]) { append_framed_bytes(out, label); append_framed_bytes(out, value); diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index d0b95953..a4a1048f 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -1,11 +1,16 @@ -use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; +use std::mem::size_of; + +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; use super::CloseCode; use crate::{ - codec::{push_value, read_exact, U16Le}, + codec::{parse, push_value, read_exact, U16Le}, WireError, }; +/// closes the whole session immediately with a close code. #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionCloseBody { pub code: CloseCode, @@ -18,7 +23,16 @@ pub struct SessionCloseBodyWire { } impl SessionCloseBody { - pub fn from_wire(wire: SessionCloseBodyWire) -> Self { + pub const WIRE_SIZE: usize = size_of::(); + + pub fn parse(bytes: B) -> Result, WireError> { + if bytes.len() != Self::WIRE_SIZE { + return Err(WireError::InvalidPayload); + } + parse(bytes) + } + + pub fn from_wire(wire: &SessionCloseBodyWire) -> Self { Self { code: CloseCode(wire.code.get()), } @@ -36,6 +50,6 @@ impl SessionCloseBody { pub fn decode(bytes: &[u8]) -> Result { let wire: SessionCloseBodyWire = read_exact(bytes)?; - Ok(Self::from_wire(wire)) + Ok(Self::from_wire(&wire)) } } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index acaf02e3..2ee317cb 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -6,115 +6,97 @@ use zerocopy::{ }; use crate::{ - codec::{parse, push_value, U64Le}, + codec::{parse, read_byte}, encrypted_message::{EncryptedMessage, EncryptedMessageWire}, Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, SessionKey, WireError, }; mod close; -mod ping; -mod stream_chunk; +mod stream_ack; mod stream_close; +mod stream_data; +mod stream_window; pub use close::*; -pub use ping::*; -pub use stream_chunk::*; +pub use stream_ack::*; pub use stream_close::*; +pub use stream_data::*; +pub use stream_window::*; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[repr(transparent)] -pub struct SessionSeq(pub u64); - +// todo: should use even/odd based on xid ordering #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct StreamId(pub u32); -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct SessionAck { - pub base: SessionSeq, - pub bitmap: u64, -} - -impl SessionAck { - pub const EMPTY: Self = Self { - base: SessionSeq(0), - bitmap: 0, - }; -} - #[derive(Debug, Clone, PartialEq, Eq)] -pub struct SessionEnvelope { - pub seq: SessionSeq, - pub ack: SessionAck, - pub body: SessionBody, +pub struct SessionRecord { + pub frames: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum SessionBody { - Ack, - Ping(ping::PingBody), - Stream(StreamChunk), +pub enum SessionFrame { + Ping, + Pong, + StreamData(StreamData), + StreamAck(StreamAck), + StreamWindow(StreamWindow), StreamClose(StreamClose), - Close(close::SessionCloseBody), + Close(SessionCloseBody), } -pub enum SessionBodyRef { - Ack, +pub enum SessionFrameRef<'a> { Ping, - Stream(Ref), - StreamClose(Ref), - Close(close::SessionCloseBody), + Pong, + StreamData(Ref<&'a [u8], StreamDataWire>), + StreamAck(Ref<&'a [u8], StreamAckWire>), + StreamWindow(Ref<&'a [u8], StreamWindowWire>), + StreamClose(Ref<&'a [u8], StreamCloseWire>), + Close(Ref<&'a [u8], SessionCloseBodyWire>), } #[derive( Debug, Clone, Copy, PartialEq, Eq, TryFromBytes, KnownLayout, Immutable, IntoBytes, Unaligned, )] #[repr(u8)] -enum SessionBodyKind { - Ack = 1, - Ping = 2, - Stream = 4, - StreamClose = 5, - Close = 6, +pub(crate) enum SessionFrameKind { + Ping = 1, + Pong = 2, + StreamData = 3, + StreamAck = 4, + StreamWindow = 5, + StreamClose = 6, + Close = 7, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] #[repr(C, packed)] -pub struct SessionEnvelopeWire { - pub seq: U64Le, - pub ack_base: U64Le, - pub ack_bitmap: U64Le, - pub kind: u8, - pub body: [u8], +pub struct SessionRecordWire { + pub frames: [u8], } -impl SessionEnvelope { - pub fn parse(bytes: B) -> Result, WireError> { +pub struct SessionFrameIter<'a> { + remaining: &'a [u8], +} + +impl SessionRecord { + pub fn parse(bytes: B) -> Result, WireError> { parse(bytes) } - pub fn from_wire(wire: &SessionEnvelopeWire) -> Result { - let body = match parse_session_body(session_body_kind(wire)?, &wire.body)? { - SessionBodyRef::Ack => SessionBody::Ack, - SessionBodyRef::Ping => SessionBody::Ping(ping::PingBody), - SessionBodyRef::Stream(frame) => SessionBody::Stream(StreamChunk::from_wire(&frame)?), - SessionBodyRef::StreamClose(frame) => { - SessionBody::StreamClose(StreamClose::from_wire(&frame)?) - } - SessionBodyRef::Close(body) => SessionBody::Close(body), - }; - Ok(Self { - seq: SessionSeq(wire.seq.get()), - ack: SessionAck { - base: SessionSeq(wire.ack_base.get()), - bitmap: wire.ack_bitmap.get(), - }, - body, - }) + pub fn from_wire(wire: &SessionRecordWire) -> Result { + let frames = wire + .frames() + .map(|frame| frame?.to_owned()) + .collect::, _>>()?; + Ok(Self { frames }) } pub fn encode(&self) -> Vec { - encode_session_envelope(self.seq, self.ack, &self.body) + let mut out = Vec::new(); + for frame in &self.frames { + frame.encode_into(&mut out); + } + out } pub fn decode(bytes: &[u8]) -> Result { @@ -122,35 +104,91 @@ impl SessionEnvelope { } } -pub fn encrypt_record( - crypto: &impl QlCrypto, - header: QlHeader, - session_key: &SessionKey, - body: &SessionEnvelope, - nonce: Nonce, -) -> QlRecord { - encrypt_record_parts( - crypto, - header, - session_key, - body.seq, - body.ack, - &body.body, - nonce, - ) +impl SessionRecordWire { + pub fn frames(&self) -> SessionFrameIter<'_> { + SessionFrameIter { + remaining: &self.frames, + } + } } -pub fn encrypt_record_parts( +impl SessionFrame { + pub fn encode_into(&self, out: &mut Vec) { + match self { + Self::Ping => out.push(SessionFrameKind::Ping as u8), + Self::Pong => out.push(SessionFrameKind::Pong as u8), + Self::StreamData(frame) => { + out.push(SessionFrameKind::StreamData as u8); + push_variable_len(out, frame.encoded_len()); + frame.encode_into(out); + } + Self::StreamAck(frame) => { + out.push(SessionFrameKind::StreamAck as u8); + push_variable_len(out, frame.encoded_len()); + frame.encode_into(out); + } + Self::StreamWindow(frame) => { + out.push(SessionFrameKind::StreamWindow as u8); + frame.encode_into(out); + } + Self::StreamClose(frame) => { + out.push(SessionFrameKind::StreamClose as u8); + push_variable_len(out, frame.encoded_len()); + frame.encode_into(out); + } + Self::Close(body) => { + out.push(SessionFrameKind::Close as u8); + body.encode_into(out); + } + } + } +} + +impl SessionFrameRef<'_> { + pub fn to_owned(&self) -> Result { + Ok(match self { + Self::Ping => SessionFrame::Ping, + Self::Pong => SessionFrame::Pong, + Self::StreamData(frame) => SessionFrame::StreamData(StreamData::from_wire(frame)?), + Self::StreamAck(frame) => SessionFrame::StreamAck(StreamAck::from_wire(frame)?), + Self::StreamWindow(frame) => SessionFrame::StreamWindow(StreamWindow::from_wire(frame)), + Self::StreamClose(frame) => SessionFrame::StreamClose(StreamClose::from_wire(frame)?), + Self::Close(frame) => SessionFrame::Close(SessionCloseBody::from_wire(frame)), + }) + } +} + +impl<'a> Iterator for SessionFrameIter<'a> { + type Item = Result, WireError>; + + fn next(&mut self) -> Option { + if self.remaining.is_empty() { + return None; + } + + let parsed = parse_next_frame(self.remaining); + match parsed { + Ok((frame, rest)) => { + self.remaining = rest; + Some(Ok(frame)) + } + Err(error) => { + self.remaining = &[]; + Some(Err(error)) + } + } + } +} + +pub fn encrypt_record( crypto: &impl QlCrypto, header: QlHeader, session_key: &SessionKey, - seq: SessionSeq, - ack: SessionAck, - body: &SessionBody, + body: &SessionRecord, nonce: Nonce, ) -> QlRecord { let aad = header.aad(); - let body = encode_session_envelope(seq, ack, body); + let body = body.encode(); let encrypted = EncryptedMessage::encrypt(crypto, session_key, body, &aad, nonce); QlRecord { header, @@ -163,96 +201,68 @@ pub fn decrypt_record<'a, B: ByteSliceMut>( header: &QlHeader, encrypted: &'a mut Ref, session_key: &SessionKey, -) -> Result, WireError> { +) -> Result, WireError> { let aad = header.aad(); let plaintext = EncryptedMessage::decrypt_in_place(encrypted, crypto, session_key, &aad)?; - SessionEnvelope::parse(plaintext) -} - -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct SessionEnvelopeHeaderWire { - pub seq: U64Le, - pub ack_base: U64Le, - pub ack_bitmap: U64Le, - pub kind: u8, -} - -fn encode_session_envelope(seq: SessionSeq, ack: SessionAck, body: &SessionBody) -> Vec { - let expected_len = size_of::() + session_body_encoded_len(body); - let mut out = Vec::with_capacity(expected_len); - let initial_capacity = out.capacity(); - encode_session_envelope_into(seq, ack, body, &mut out); - debug_assert_eq!(out.len(), expected_len); - debug_assert_eq!(out.capacity(), initial_capacity); - out -} - -fn encode_session_envelope_into( - seq: SessionSeq, - ack: SessionAck, - body: &SessionBody, - out: &mut Vec, -) { - let header = SessionEnvelopeHeaderWire { - seq: U64Le::new(seq.0), - ack_base: U64Le::new(ack.base.0), - ack_bitmap: U64Le::new(ack.bitmap), - kind: session_body_kind_for(body) as u8, - }; - push_value(out, &header); - encode_session_body_into(body, out); + SessionRecord::parse(plaintext) } -fn session_body_kind_for(body: &SessionBody) -> SessionBodyKind { - match body { - SessionBody::Ack => SessionBodyKind::Ack, - SessionBody::Ping(_) => SessionBodyKind::Ping, - SessionBody::Stream(_) => SessionBodyKind::Stream, - SessionBody::StreamClose(_) => SessionBodyKind::StreamClose, - SessionBody::Close(_) => SessionBodyKind::Close, - } -} - -fn session_body_encoded_len(body: &SessionBody) -> usize { - match body { - SessionBody::Ack | SessionBody::Ping(_) => 0, - SessionBody::Stream(frame) => size_of::() + frame.bytes.len(), - SessionBody::StreamClose(frame) => size_of::() + frame.payload.len(), - SessionBody::Close(_) => size_of::(), +fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrameRef<'_>, &[u8]), WireError> { + let (&kind, rest) = bytes.split_first().ok_or(WireError::InvalidPayload)?; + let kind: SessionFrameKind = read_byte(kind)?; + match kind { + SessionFrameKind::Ping => Ok((SessionFrameRef::Ping, rest)), + SessionFrameKind::Pong => Ok((SessionFrameRef::Pong, rest)), + SessionFrameKind::StreamData => { + let (frame, rest) = split_variable_frame(rest)?; + Ok((SessionFrameRef::StreamData(StreamData::parse(frame)?), rest)) + } + SessionFrameKind::StreamAck => { + let (frame, rest) = split_variable_frame(rest)?; + Ok((SessionFrameRef::StreamAck(StreamAck::parse(frame)?), rest)) + } + SessionFrameKind::StreamWindow => { + let wire_size = StreamWindow::WIRE_SIZE; + if rest.len() < wire_size { + return Err(WireError::InvalidPayload); + } + let (frame, rest) = rest.split_at(wire_size); + Ok(( + SessionFrameRef::StreamWindow(StreamWindow::parse(frame)?), + rest, + )) + } + SessionFrameKind::StreamClose => { + let (frame, rest) = split_variable_frame(rest)?; + Ok(( + SessionFrameRef::StreamClose(StreamClose::parse(frame)?), + rest, + )) + } + SessionFrameKind::Close => { + let wire_size = SessionCloseBody::WIRE_SIZE; + if rest.len() < wire_size { + return Err(WireError::InvalidPayload); + } + let (frame, rest) = rest.split_at(wire_size); + let frame = SessionCloseBody::parse(frame)?; + Ok((SessionFrameRef::Close(frame), rest)) + } } } -fn encode_session_body_into(body: &SessionBody, out: &mut Vec) { - match body { - SessionBody::Ack | SessionBody::Ping(_) => {} - SessionBody::Stream(frame) => frame.encode_into(out), - SessionBody::StreamClose(frame) => frame.encode_into(out), - SessionBody::Close(body) => body.encode_into(out), - } +fn push_variable_len(out: &mut Vec, len: usize) { + let len = u16::try_from(len).expect("session frame exceeds u16"); + out.extend_from_slice(&len.to_le_bytes()); } -fn session_body_kind(wire: &SessionEnvelopeWire) -> Result { - crate::codec::read_byte(wire.kind) -} +fn split_variable_frame(bytes: &[u8]) -> Result<(&[u8], &[u8]), WireError> { + const LEN_SIZE: usize = size_of::(); -fn parse_session_body( - kind: SessionBodyKind, - body: B, -) -> Result, WireError> { - match kind { - SessionBodyKind::Ack => { - crate::codec::ensure_empty(&body)?; - Ok(SessionBodyRef::Ack) - } - SessionBodyKind::Ping => { - crate::codec::ensure_empty(&body)?; - Ok(SessionBodyRef::Ping) - } - SessionBodyKind::Stream => Ok(SessionBodyRef::Stream(StreamChunk::parse(body)?)), - SessionBodyKind::StreamClose => Ok(SessionBodyRef::StreamClose(StreamClose::parse(body)?)), - SessionBodyKind::Close => Ok(SessionBodyRef::Close(close::SessionCloseBody::decode( - &body, - )?)), + if bytes.len() < LEN_SIZE { + return Err(WireError::InvalidPayload); } + let len = u16::from_le_bytes([bytes[0], bytes[1]]) as usize; + let bytes = &bytes[LEN_SIZE..]; + bytes.split_at_checked(len).ok_or(WireError::InvalidPayload) } diff --git a/ql-wire/src/encrypted/ping.rs b/ql-wire/src/encrypted/ping.rs deleted file mode 100644 index e0dd3fd2..00000000 --- a/ql-wire/src/encrypted/ping.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub struct PingBody; diff --git a/ql-wire/src/encrypted/stream_ack.rs b/ql-wire/src/encrypted/stream_ack.rs new file mode 100644 index 00000000..b49b5dbd --- /dev/null +++ b/ql-wire/src/encrypted/stream_ack.rs @@ -0,0 +1,136 @@ +use std::mem::size_of; + +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +use super::StreamId; +use crate::{ + codec::{parse, push_value, read_exact, U32Le, U64Le}, + WireError, +}; + +/// acknowledges a contiguous prefix plus optional selective ranges. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamAck { + pub stream_id: StreamId, + pub acked_prefix: u64, + pub ranges: Vec, +} + +/// one acknowledged range after the acked prefix. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamAckRange { + pub start_offset: u64, + pub end_offset: u64, +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub struct StreamAckRangeWire { + pub start_offset: U64Le, + pub end_offset: U64Le, +} + +#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct StreamAckWire { + pub stream_id: U32Le, + pub acked_prefix: U64Le, + pub ranges: [u8], +} + +pub struct StreamAckRangeIter<'a> { + remaining: &'a [u8], +} + +impl StreamAck { + pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::(); + + pub fn parse(bytes: B) -> Result, WireError> { + let wire = parse(bytes)?; + validate_ack_frame(&wire)?; + Ok(wire) + } + + pub fn encoded_len(&self) -> usize { + Self::MIN_WIRE_SIZE + self.ranges.len() * size_of::() + } + + pub fn from_wire(wire: &StreamAckWire) -> Result { + validate_ack_frame(wire)?; + Ok(Self { + stream_id: wire.stream_id(), + acked_prefix: wire.acked_prefix(), + ranges: wire.ranges().collect(), + }) + } + + pub fn encode_into(&self, out: &mut Vec) { + out.extend_from_slice(&self.stream_id.0.to_le_bytes()); + out.extend_from_slice(&self.acked_prefix.to_le_bytes()); + for range in &self.ranges { + push_value( + out, + &StreamAckRangeWire { + start_offset: U64Le::new(range.start_offset), + end_offset: U64Le::new(range.end_offset), + }, + ); + } + } +} + +impl StreamAckWire { + pub fn stream_id(&self) -> StreamId { + StreamId(self.stream_id.get()) + } + + pub fn acked_prefix(&self) -> u64 { + self.acked_prefix.get() + } + + pub fn ranges(&self) -> StreamAckRangeIter<'_> { + StreamAckRangeIter { + remaining: &self.ranges, + } + } +} + +impl Iterator for StreamAckRangeIter<'_> { + type Item = StreamAckRange; + + fn next(&mut self) -> Option { + if self.remaining.is_empty() { + return None; + } + let (head, tail) = self.remaining.split_at(size_of::()); + self.remaining = tail; + let wire: StreamAckRangeWire = + read_exact(head).expect("stream ack ranges are validated before iteration"); + Some(StreamAckRange { + start_offset: wire.start_offset.get(), + end_offset: wire.end_offset.get(), + }) + } +} + +fn validate_ack_frame(wire: &StreamAckWire) -> Result<(), WireError> { + if wire.ranges.len() % size_of::() != 0 { + return Err(WireError::InvalidPayload); + } + + let acked_prefix = wire.acked_prefix(); + let mut previous_end = acked_prefix; + for range in wire.ranges() { + if range.start_offset < acked_prefix + || range.start_offset >= range.end_offset + || range.start_offset < previous_end + { + return Err(WireError::InvalidPayload); + } + previous_end = range.end_offset; + } + + Ok(()) +} diff --git a/ql-wire/src/encrypted/stream_chunk.rs b/ql-wire/src/encrypted/stream_chunk.rs deleted file mode 100644 index 2c41b416..00000000 --- a/ql-wire/src/encrypted/stream_chunk.rs +++ /dev/null @@ -1,59 +0,0 @@ -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - -use super::StreamId; -use crate::{ - codec::{parse, push_value, U32Le, U64Le}, - WireError, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamChunk { - pub stream_id: StreamId, - pub chunk_seq: u64, - pub fin: bool, - pub bytes: Vec, -} - -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] -#[repr(C, packed)] -pub struct StreamChunkWire { - pub stream_id: U32Le, - pub chunk_seq: U64Le, - pub fin: u8, - pub bytes: [u8], -} - -impl StreamChunk { - pub fn parse(bytes: B) -> Result, WireError> { - parse(bytes) - } - - pub fn from_wire(wire: &StreamChunkWire) -> Result { - Ok(StreamChunk { - stream_id: StreamId(wire.stream_id.get()), - chunk_seq: wire.chunk_seq.get(), - bytes: wire.bytes.to_vec(), - fin: crate::codec::read_byte(wire.fin)?, - }) - } - - pub fn encode_into(&self, out: &mut Vec) { - let header = StreamChunkHeaderWire { - stream_id: U32Le::new(self.stream_id.0), - chunk_seq: U64Le::new(self.chunk_seq), - fin: u8::from(self.fin), - }; - push_value(out, &header); - out.extend_from_slice(&self.bytes); - } -} - -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct StreamChunkHeaderWire { - pub stream_id: U32Le, - pub chunk_seq: U64Le, - pub fin: u8, -} diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index b76a7d68..665ef02a 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -1,14 +1,14 @@ +use std::mem::size_of; + use zerocopy::{ byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, Unaligned, }; use super::StreamId; -use crate::{ - codec::{parse, push_value, U16Le, U32Le}, - WireError, -}; +use crate::{codec::{parse, read_byte, U16Le, U32Le}, WireError}; +/// aborts one or both directions of a stream with a close code. #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamClose { pub stream_id: StreamId, @@ -50,7 +50,7 @@ impl CloseCode { pub const UNHANDLED: Self = Self(20); } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] #[repr(C, packed)] pub struct StreamCloseWire { pub stream_id: U32Le, @@ -60,34 +60,35 @@ pub struct StreamCloseWire { } impl StreamClose { + pub const MIN_WIRE_SIZE: usize = + size_of::() + size_of::() + size_of::(); + pub fn parse(bytes: B) -> Result, WireError> { - parse(bytes) + if bytes.len() < Self::MIN_WIRE_SIZE { + return Err(WireError::InvalidPayload); + } + let wire: Ref = parse(bytes)?; + let _ = read_byte::(wire.target)?; + Ok(wire) + } + + pub fn encoded_len(&self) -> usize { + Self::MIN_WIRE_SIZE + self.payload.len() } pub fn from_wire(wire: &StreamCloseWire) -> Result { - Ok(StreamClose { + Ok(Self { stream_id: StreamId(wire.stream_id.get()), - target: crate::codec::read_byte(wire.target)?, + target: read_byte(wire.target)?, code: CloseCode(wire.code.get()), payload: wire.payload.to_vec(), }) } pub fn encode_into(&self, out: &mut Vec) { - let header = StreamCloseHeaderWire { - stream_id: U32Le::new(self.stream_id.0), - target: self.target.to_wire(), - code: U16Le::new(self.code.0), - }; - push_value(out, &header); + out.extend_from_slice(&self.stream_id.0.to_le_bytes()); + out.push(self.target.to_wire()); + out.extend_from_slice(&self.code.0.to_le_bytes()); out.extend_from_slice(&self.payload); } } - -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct StreamCloseHeaderWire { - pub stream_id: U32Le, - pub target: u8, - pub code: U16Le, -} diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs new file mode 100644 index 00000000..ac69d012 --- /dev/null +++ b/ql-wire/src/encrypted/stream_data.rs @@ -0,0 +1,82 @@ +use std::mem::size_of; + +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, KnownLayout, Ref, Unaligned, +}; + +use super::StreamId; +use crate::{codec::{parse, U32Le, U64Le}, WireError}; + +/// carries bytes for a stream and may finish that sending direction. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamData { + pub stream_id: StreamId, + pub offset: u64, + pub fin: bool, + pub bytes: Vec, +} + +#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct StreamDataWire { + pub stream_id: U32Le, + pub offset: U64Le, + pub fin: u8, + pub bytes: [u8], +} + +impl StreamData { + pub const MIN_WIRE_SIZE: usize = + size_of::() + size_of::() + size_of::(); + + pub fn parse(bytes: B) -> Result, WireError> { + if bytes.len() < Self::MIN_WIRE_SIZE { + return Err(WireError::InvalidPayload); + } + let wire: Ref = parse(bytes)?; + let _ = wire.fin()?; + Ok(wire) + } + + pub fn encoded_len(&self) -> usize { + Self::MIN_WIRE_SIZE + self.bytes.len() + } + + pub fn from_wire(wire: &StreamDataWire) -> Result { + Ok(Self { + stream_id: wire.stream_id(), + offset: wire.offset(), + fin: wire.fin()?, + bytes: wire.bytes().to_vec(), + }) + } + + pub fn encode_into(&self, out: &mut Vec) { + out.extend_from_slice(&self.stream_id.0.to_le_bytes()); + out.extend_from_slice(&self.offset.to_le_bytes()); + out.push(u8::from(self.fin)); + out.extend_from_slice(&self.bytes); + } +} + +impl StreamDataWire { + pub fn stream_id(&self) -> StreamId { + StreamId(self.stream_id.get()) + } + + pub fn offset(&self) -> u64 { + self.offset.get() + } + + pub fn fin(&self) -> Result { + match self.fin { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(WireError::InvalidPayload), + } + } + + pub fn bytes(&self) -> &[u8] { + &self.bytes + } +} diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs new file mode 100644 index 00000000..2f6e2c4a --- /dev/null +++ b/ql-wire/src/encrypted/stream_window.rs @@ -0,0 +1,50 @@ +use std::mem::size_of; + +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +use super::StreamId; +use crate::{codec::{parse, push_value, U32Le, U64Le}, WireError}; + +/// advertises the highest byte offset the peer may send on a stream. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamWindow { + pub stream_id: StreamId, + pub maximum_offset: u64, +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub struct StreamWindowWire { + pub stream_id: U32Le, + pub maximum_offset: U64Le, +} + +impl StreamWindow { + pub const WIRE_SIZE: usize = size_of::(); + + pub fn parse(bytes: B) -> Result, WireError> { + if bytes.len() != Self::WIRE_SIZE { + return Err(WireError::InvalidPayload); + } + parse(bytes) + } + + pub fn from_wire(wire: &StreamWindowWire) -> Self { + Self { + stream_id: StreamId(wire.stream_id.get()), + maximum_offset: wire.maximum_offset.get(), + } + } + + pub fn encode_into(&self, out: &mut Vec) { + push_value( + out, + &StreamWindowWire { + stream_id: U32Le::new(self.stream_id.0), + maximum_offset: U64Le::new(self.maximum_offset), + }, + ); + } +} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index de80c488..f71f09d2 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -72,18 +72,44 @@ fn encrypted_session_record_round_trip_and_decrypt() { sender: XID([1; XID::SIZE]), recipient: XID([2; XID::SIZE]), }; - let body = SessionEnvelope { - seq: SessionSeq(7), - ack: SessionAck { - base: SessionSeq(3), - bitmap: 0b101, - }, - body: SessionBody::Stream(StreamChunk { - stream_id: StreamId(9), - chunk_seq: 11, - bytes: b"hello".to_vec(), - fin: true, - }), + let body = SessionRecord { + frames: vec![ + SessionFrame::Ping, + SessionFrame::Pong, + SessionFrame::StreamAck(StreamAck { + stream_id: StreamId(7), + acked_prefix: 12, + ranges: vec![ + StreamAckRange { + start_offset: 20, + end_offset: 24, + }, + StreamAckRange { + start_offset: 30, + end_offset: 33, + }, + ], + }), + SessionFrame::StreamWindow(StreamWindow { + stream_id: StreamId(9), + maximum_offset: 65_536, + }), + SessionFrame::StreamData(StreamData { + stream_id: StreamId(9), + offset: 1024, + bytes: b"hello".to_vec(), + fin: true, + }), + SessionFrame::StreamClose(StreamClose { + stream_id: StreamId(9), + target: CloseTarget::Both, + code: CloseCode::PROTOCOL, + payload: b"bye".to_vec(), + }), + SessionFrame::Close(SessionCloseBody { + code: CloseCode::TIMEOUT, + }), + ], }; let session_key = SessionKey::from_data([7; SessionKey::SIZE]); let record = encrypted::encrypt_record( @@ -109,7 +135,92 @@ fn encrypted_session_record_round_trip_and_decrypt() { }; let decrypted = encrypted::decrypt_record(&crypto, &header, &mut encrypted, &session_key).unwrap(); - assert_eq!(SessionEnvelope::from_wire(&decrypted).unwrap(), body); + assert_eq!(SessionRecord::from_wire(&decrypted).unwrap(), body); +} + +#[test] +fn decrypted_session_record_iterates_zero_copy_frames() { + let crypto = TestCrypto::new(2); + let header = QlHeader { + sender: XID([9; XID::SIZE]), + recipient: XID([10; XID::SIZE]), + }; + let body = SessionRecord { + frames: vec![ + SessionFrame::StreamData(StreamData { + stream_id: StreamId(1), + offset: 5, + fin: false, + bytes: b"abc".to_vec(), + }), + SessionFrame::StreamAck(StreamAck { + stream_id: StreamId(1), + acked_prefix: 3, + ranges: vec![StreamAckRange { + start_offset: 5, + end_offset: 8, + }], + }), + SessionFrame::StreamClose(StreamClose { + stream_id: StreamId(1), + target: CloseTarget::Response, + code: CloseCode::CANCELLED, + payload: b"later".to_vec(), + }), + ], + }; + let session_key = SessionKey::from_data([3; SessionKey::SIZE]); + let record = encrypted::encrypt_record( + &crypto, + header, + &session_key, + &body, + Nonce([4; Nonce::SIZE]), + ); + + let mut bytes = record.encode(); + let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); + let QlPayloadRef::Session(mut encrypted) = payload else { + panic!("expected session payload"); + }; + let decrypted = + encrypted::decrypt_record(&crypto, &header, &mut encrypted, &session_key).unwrap(); + + let mut frames = decrypted.frames(); + match frames.next().unwrap().unwrap() { + SessionFrameRef::StreamData(frame) => { + assert_eq!(frame.stream_id(), StreamId(1)); + assert_eq!(frame.offset(), 5); + assert!(!frame.fin().unwrap()); + assert_eq!(frame.bytes(), b"abc"); + } + other => panic!("expected stream data, got {}", frame_name(&other)), + } + match frames.next().unwrap().unwrap() { + SessionFrameRef::StreamAck(frame) => { + assert_eq!(frame.stream_id(), StreamId(1)); + assert_eq!(frame.acked_prefix(), 3); + let ranges: Vec<_> = frame.ranges().collect(); + assert_eq!( + ranges, + vec![StreamAckRange { + start_offset: 5, + end_offset: 8, + }] + ); + } + other => panic!("expected stream ack, got {}", frame_name(&other)), + } + match frames.next().unwrap().unwrap() { + SessionFrameRef::StreamClose(frame) => { + let owned = StreamClose::from_wire(&frame).unwrap(); + assert_eq!(owned.stream_id, StreamId(1)); + assert_eq!(owned.target, CloseTarget::Response); + assert_eq!(owned.payload, b"later".to_vec()); + } + other => panic!("expected stream close, got {}", frame_name(&other)), + } + assert!(frames.next().is_none()); } #[test] @@ -224,6 +335,85 @@ fn unpair_round_trip_and_verify() { unpair::verify_unpair(&crypto, &header, &sender_signing_public, &unpair, 100).unwrap(); } +#[test] +fn session_record_rejects_malformed_frames() { + let invalid_cases = [ + vec![0xff], + { + let mut bytes = vec![SessionFrameKind::StreamData as u8]; + bytes.push(1); + bytes + }, + { + let mut bytes = vec![SessionFrameKind::StreamData as u8]; + bytes.extend_from_slice(&13u16.to_le_bytes()); + bytes.extend_from_slice(&1u32.to_le_bytes()); + bytes.extend_from_slice(&4u64.to_le_bytes()); + bytes.push(0); + bytes.extend_from_slice(b"abc"); + bytes + }, + { + let mut bytes = vec![SessionFrameKind::StreamAck as u8]; + bytes.extend_from_slice(&20u16.to_le_bytes()); + bytes.extend_from_slice(&1u32.to_le_bytes()); + bytes.extend_from_slice(&3u64.to_le_bytes()); + bytes.extend_from_slice(&5u64.to_le_bytes()); + bytes + }, + { + let mut bytes = vec![SessionFrameKind::StreamAck as u8]; + bytes.extend_from_slice(&28u16.to_le_bytes()); + bytes.extend_from_slice(&1u32.to_le_bytes()); + bytes.extend_from_slice(&6u64.to_le_bytes()); + bytes.extend_from_slice(&4u64.to_le_bytes()); + bytes.extend_from_slice(&8u64.to_le_bytes()); + bytes + }, + { + let mut bytes = vec![SessionFrameKind::StreamClose as u8]; + bytes.extend_from_slice(&9u16.to_le_bytes()); + bytes.extend_from_slice(&1u32.to_le_bytes()); + bytes.push(CloseTarget::Both as u8); + bytes.extend_from_slice(&CloseCode::PROTOCOL.0.to_le_bytes()); + bytes.extend_from_slice(b"abc"); + bytes + }, + { + let mut bytes = vec![SessionFrameKind::Close as u8]; + bytes.push(0); + bytes + }, + ]; + + for bytes in invalid_cases { + assert_eq!(SessionRecord::decode(&bytes), Err(WireError::InvalidPayload)); + } +} + +#[test] +fn session_record_supports_empty_fin_stream_data_and_empty_ping_pong() { + let record = SessionRecord { + frames: vec![ + SessionFrame::Ping, + SessionFrame::Pong, + SessionFrame::StreamData(StreamData { + stream_id: StreamId(42), + offset: 999, + fin: true, + bytes: Vec::new(), + }), + ], + }; + + let encoded = record.encode(); + assert_eq!(encoded[0], SessionFrameKind::Ping as u8); + assert_eq!(encoded[1], SessionFrameKind::Pong as u8); + + let decoded = SessionRecord::decode(&encoded).unwrap(); + assert_eq!(decoded, record); +} + #[test] fn protocol_record_size_breakdown() { fn meta(id: u32) -> ControlMeta { @@ -248,7 +438,7 @@ fn protocol_record_size_breakdown() { } } - fn session_record(header: QlHeader, tag: u8, body: SessionEnvelope) -> QlRecord { + fn session_record(header: QlHeader, tag: u8, body: SessionRecord) -> QlRecord { let ciphertext_len = body.encode().len(); QlRecord { header, @@ -303,75 +493,87 @@ fn protocol_record_size_breakdown() { }), }; - let session_ack = session_record( + let session_ping = session_record( header, 14, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::Ack, + SessionRecord { + frames: vec![SessionFrame::Ping], }, ); - let session_ping = session_record( + let session_pong = session_record( header, 15, - SessionEnvelope { - seq: SessionSeq(2), - ack: SessionAck::EMPTY, - body: SessionBody::Ping(PingBody), + SessionRecord { + frames: vec![SessionFrame::Pong], }, ); - let session_stream_empty = session_record( + let session_stream_window = session_record( header, 16, - SessionEnvelope { - seq: SessionSeq(3), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { + SessionRecord { + frames: vec![SessionFrame::StreamWindow(StreamWindow { stream_id: StreamId(1), - chunk_seq: 0, + maximum_offset: 65_536, + })], + }, + ); + let session_stream_ack = session_record( + header, + 17, + SessionRecord { + frames: vec![SessionFrame::StreamAck(StreamAck { + stream_id: StreamId(1), + acked_prefix: 4, + ranges: vec![StreamAckRange { + start_offset: 8, + end_offset: 12, + }], + })], + }, + ); + let session_stream_empty = session_record( + header, + 18, + SessionRecord { + frames: vec![SessionFrame::StreamData(StreamData { + stream_id: StreamId(1), + offset: 0, fin: false, bytes: Vec::new(), - }), + })], }, ); let session_stream_fin = session_record( header, - 17, - SessionEnvelope { - seq: SessionSeq(4), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { + 19, + SessionRecord { + frames: vec![SessionFrame::StreamData(StreamData { stream_id: StreamId(1), - chunk_seq: 0, + offset: 0, fin: true, bytes: Vec::new(), - }), + })], }, ); let session_stream_close = session_record( header, - 18, - SessionEnvelope { - seq: SessionSeq(5), - ack: SessionAck::EMPTY, - body: SessionBody::StreamClose(StreamClose { + 20, + SessionRecord { + frames: vec![SessionFrame::StreamClose(StreamClose { stream_id: StreamId(1), target: CloseTarget::Both, code: CloseCode::PROTOCOL, payload: Vec::new(), - }), + })], }, ); let session_close = session_record( header, - 19, - SessionEnvelope { - seq: SessionSeq(6), - ack: SessionAck::EMPTY, - body: SessionBody::Close(SessionCloseBody { + 21, + SessionRecord { + frames: vec![SessionFrame::Close(SessionCloseBody { code: CloseCode::PROTOCOL, - }), + })], }, ); @@ -385,8 +587,10 @@ fn protocol_record_size_breakdown() { print_size("ql-wire pair_request empty", pair_request.encode().len()); print_size("ql-wire unpair", unpair.encode().len()); print_size("ql-wire ready empty", ready.encode().len()); - print_size("ql-wire session ack", session_ack.encode().len()); print_size("ql-wire session ping", session_ping.encode().len()); + print_size("ql-wire session pong", session_pong.encode().len()); + print_size("ql-wire session stream window", session_stream_window.encode().len()); + print_size("ql-wire session stream ack", session_stream_ack.encode().len()); print_size( "ql-wire session stream empty", session_stream_empty.encode().len(), @@ -401,3 +605,15 @@ fn protocol_record_size_breakdown() { ); print_size("ql-wire session close", session_close.encode().len()); } + +fn frame_name(frame: &SessionFrameRef<'_>) -> &'static str { + match frame { + SessionFrameRef::Ping => "ping", + SessionFrameRef::Pong => "pong", + SessionFrameRef::StreamData(_) => "stream_data", + SessionFrameRef::StreamAck(_) => "stream_ack", + SessionFrameRef::StreamWindow(_) => "stream_window", + SessionFrameRef::StreamClose(_) => "stream_close", + SessionFrameRef::Close(_) => "close", + } +} From ca7fa31ffe42d86cdfc32ab2e1c3b0a08c97ae9c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 27 Mar 2026 23:40:31 -0400 Subject: [PATCH 036/304] ql-wire: byte reassembly --- ql-wire/src/encrypted/byte_reassembly.rs | 627 +++++++++++++++++++++++ ql-wire/src/encrypted/mod.rs | 2 + 2 files changed, 629 insertions(+) create mode 100644 ql-wire/src/encrypted/byte_reassembly.rs diff --git a/ql-wire/src/encrypted/byte_reassembly.rs b/ql-wire/src/encrypted/byte_reassembly.rs new file mode 100644 index 00000000..c6cb306b --- /dev/null +++ b/ql-wire/src/encrypted/byte_reassembly.rs @@ -0,0 +1,627 @@ +use std::collections::VecDeque; + +/// reassembles one stream direction from out-of-order byte ranges. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ByteReassembly { + start_offset: u64, + bytes: VecDeque, + missing: MissingRanges, + final_offset: Option, + max_buffered: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct MissingRange { + pub start: u64, + pub end: u64, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct InsertOutcome { + pub newly_readable_bytes: usize, + pub became_complete: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ByteReassemblyError { + OffsetOverflow, + OutOfWindow, + InconsistentFinalOffset, + FinalOffsetBeforeBufferedData, + BeyondFinalOffset, + ConflictingOverlap, + ConsumeBeyondReadable, + TooManyMissingRanges, +} + +#[derive(Debug, Clone, Copy)] +pub struct BytesIter<'a> { + front: Option<&'a [u8]>, + back: Option<&'a [u8]>, +} + +impl ByteReassembly { + pub fn new(max_buffered: usize) -> Self { + Self::with_start_offset(0, max_buffered) + } + + pub fn with_start_offset(start_offset: u64, max_buffered: usize) -> Self { + Self { + start_offset, + bytes: VecDeque::new(), + missing: MissingRanges::new(), + final_offset: None, + max_buffered, + } + } + + pub fn start_offset(&self) -> u64 { + self.start_offset + } + + pub fn buffered_end_offset(&self) -> u64 { + self.start_offset + self.bytes.len() as u64 + } + + pub fn final_offset(&self) -> Option { + self.final_offset + } + + pub fn max_buffered(&self) -> usize { + self.max_buffered + } + + pub fn missing_ranges(&self) -> &[MissingRange] { + self.missing.as_slice() + } + + pub fn readable_len(&self) -> usize { + if self.bytes.is_empty() { + return 0; + } + + match self.missing.first() { + Some(range) if range.start <= self.start_offset => 0, + Some(range) => usize::try_from(range.start - self.start_offset) + .expect("readable prefix exceeds usize"), + None => self.bytes.len(), + } + } + + pub fn bytes(&self) -> BytesIter<'_> { + let readable = self.readable_len(); + if readable == 0 { + return BytesIter { + front: None, + back: None, + }; + } + + let (front, back) = self.bytes.as_slices(); + if readable <= front.len() { + BytesIter { + front: Some(&front[..readable]), + back: None, + } + } else { + BytesIter { + front: Some(front), + back: Some(&back[..readable - front.len()]), + } + } + } + + pub fn copy_readable(&self) -> Vec { + let readable = self.readable_len(); + let mut out = Vec::with_capacity(readable); + for chunk in self.bytes() { + out.extend_from_slice(chunk); + } + out + } + + pub fn is_complete(&self) -> bool { + matches!(self.final_offset, Some(final_offset) if final_offset == self.buffered_end_offset()) + && self.missing.is_empty() + } + + pub fn insert( + &mut self, + offset: u64, + fin: bool, + bytes: &[u8], + ) -> Result { + let end = offset + .checked_add(bytes.len() as u64) + .ok_or(ByteReassemblyError::OffsetOverflow)?; + + let was_complete = self.is_complete(); + let old_readable = self.readable_len(); + + if fin { + self.set_or_validate_final_offset(end)?; + } + if let Some(final_offset) = self.final_offset { + if end > final_offset { + return Err(ByteReassemblyError::BeyondFinalOffset); + } + } + + if bytes.is_empty() || end <= self.start_offset { + return Ok(self.insert_outcome(was_complete, old_readable)); + } + + let effective_offset = offset.max(self.start_offset); + let trim_front = + usize::try_from(effective_offset - offset).expect("front trim exceeds usize"); + let effective_bytes = &bytes[trim_front..]; + if effective_bytes.is_empty() { + return Ok(self.insert_outcome(was_complete, old_readable)); + } + + self.ensure_within_window(end)?; + self.ensure_buffered(end)?; + self.validate_overlap(effective_offset, effective_bytes)?; + self.write_bytes(effective_offset, effective_bytes); + self.subtract_missing_range(effective_offset, end)?; + + Ok(self.insert_outcome(was_complete, old_readable)) + } + + pub fn consume(&mut self, len: usize) -> Result<(), ByteReassemblyError> { + let readable = self.readable_len(); + if len > readable { + return Err(ByteReassemblyError::ConsumeBeyondReadable); + } + + self.bytes.drain(..len); + self.start_offset = self.start_offset.saturating_add(len as u64); + Ok(()) + } + + fn insert_outcome(&self, was_complete: bool, old_readable: usize) -> InsertOutcome { + InsertOutcome { + newly_readable_bytes: self.readable_len().saturating_sub(old_readable), + became_complete: !was_complete && self.is_complete(), + } + } + + fn set_or_validate_final_offset( + &mut self, + final_offset: u64, + ) -> Result<(), ByteReassemblyError> { + if let Some(existing) = self.final_offset { + return if existing == final_offset { + Ok(()) + } else { + Err(ByteReassemblyError::InconsistentFinalOffset) + }; + } + + let buffered_end = self.buffered_end_offset(); + if final_offset < buffered_end { + return Err(ByteReassemblyError::FinalOffsetBeforeBufferedData); + } + + self.final_offset = Some(final_offset); + Ok(()) + } + + fn ensure_within_window(&self, end: u64) -> Result<(), ByteReassemblyError> { + let attempted = end.saturating_sub(self.start_offset); + if attempted > self.max_buffered as u64 { + return Err(ByteReassemblyError::OutOfWindow); + } + Ok(()) + } + + fn ensure_buffered(&mut self, end: u64) -> Result<(), ByteReassemblyError> { + let buffered_end = self.buffered_end_offset(); + if end <= buffered_end { + return Ok(()); + } + + let additional = usize::try_from(end - buffered_end).expect("buffer growth exceeds usize"); + self.bytes.resize(self.bytes.len() + additional, 0); + self.push_missing_range(MissingRange { + start: buffered_end, + end, + }) + } + + fn push_missing_range( + &mut self, + range: MissingRange, + ) -> Result<(), ByteReassemblyError> { + if range.start >= range.end { + return Ok(()); + } + + if let Some(last) = self.missing.last_mut() { + if last.end >= range.start { + last.end = last.end.max(range.end); + return Ok(()); + } + } + + self.missing.push(range) + } + + fn validate_overlap( + &self, + offset: u64, + bytes: &[u8], + ) -> Result<(), ByteReassemblyError> { + let mut gap_index = self.first_gap_index_after(offset); + + for (index, byte) in bytes.iter().copied().enumerate() { + let absolute = offset + index as u64; + + while gap_index < self.missing.len() && self.missing[gap_index].end <= absolute { + gap_index += 1; + } + + let is_missing = gap_index < self.missing.len() + && self.missing[gap_index].start <= absolute + && absolute < self.missing[gap_index].end; + if is_missing { + continue; + } + + if self.byte_at(absolute) != byte { + return Err(ByteReassemblyError::ConflictingOverlap); + } + } + + Ok(()) + } + + fn write_bytes(&mut self, offset: u64, bytes: &[u8]) { + let start_index = + usize::try_from(offset - self.start_offset).expect("write index exceeds usize"); + for (index, byte) in bytes.iter().copied().enumerate() { + self.bytes[start_index + index] = byte; + } + } + + fn subtract_missing_range( + &mut self, + start: u64, + end: u64, + ) -> Result<(), ByteReassemblyError> { + let first = self.first_gap_index_after(start); + if first == self.missing.len() || self.missing[first].start >= end { + return Ok(()); + } + + let mut last_exclusive = first; + while last_exclusive < self.missing.len() && self.missing[last_exclusive].start < end { + last_exclusive += 1; + } + + let last = last_exclusive - 1; + let keep_left = self.missing[first].start < start; + let keep_right = self.missing[last].end > end; + + if first == last { + let original = self.missing[first]; + match (keep_left, keep_right) { + (true, true) => { + self.missing[first].end = start; + self.missing.insert( + first + 1, + MissingRange { + start: end, + end: original.end, + }, + )?; + } + (true, false) => { + self.missing[first].end = start; + } + (false, true) => { + self.missing[first].start = end; + } + (false, false) => { + self.missing.remove(first); + } + } + return Ok(()); + } + + match (keep_left, keep_right) { + (true, true) => { + self.missing[first].end = start; + self.missing[last].start = end; + self.missing.drain(first + 1..last); + } + (true, false) => { + self.missing[first].end = start; + self.missing.drain(first + 1..last_exclusive); + } + (false, true) => { + self.missing[last].start = end; + self.missing.drain(first..last); + } + (false, false) => { + self.missing.drain(first..last_exclusive); + } + } + + Ok(()) + } + + fn first_gap_index_after(&self, offset: u64) -> usize { + self.missing.as_slice().partition_point(|range| range.end <= offset) + } + + fn byte_at(&self, offset: u64) -> u8 { + let index = usize::try_from(offset - self.start_offset).expect("read index exceeds usize"); + self.bytes[index] + } +} + +impl<'a> Iterator for BytesIter<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + if let Some(front) = self.front.take() { + if !front.is_empty() { + return Some(front); + } + } + + if let Some(back) = self.back.take() { + if !back.is_empty() { + return Some(back); + } + } + + None + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct MissingRanges { + ranges: [MissingRange; N], + len: usize, +} + +impl MissingRanges { + fn new() -> Self { + Self { + ranges: [MissingRange { start: 0, end: 0 }; N], + len: 0, + } + } + + fn as_slice(&self) -> &[MissingRange] { + &self.ranges[..self.len] + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn len(&self) -> usize { + self.len + } + + fn first(&self) -> Option<&MissingRange> { + self.as_slice().first() + } + + fn last_mut(&mut self) -> Option<&mut MissingRange> { + if self.len == 0 { + None + } else { + Some(&mut self.ranges[self.len - 1]) + } + } + + fn push(&mut self, range: MissingRange) -> Result<(), ByteReassemblyError> { + if self.len == N { + return Err(ByteReassemblyError::TooManyMissingRanges); + } + self.ranges[self.len] = range; + self.len += 1; + Ok(()) + } + + fn insert(&mut self, index: usize, range: MissingRange) -> Result<(), ByteReassemblyError> { + if self.len == N { + return Err(ByteReassemblyError::TooManyMissingRanges); + } + for i in (index..self.len).rev() { + self.ranges[i + 1] = self.ranges[i]; + } + self.ranges[index] = range; + self.len += 1; + Ok(()) + } + + fn remove(&mut self, index: usize) -> MissingRange { + let removed = self.ranges[index]; + for i in index + 1..self.len { + self.ranges[i - 1] = self.ranges[i]; + } + self.len -= 1; + self.ranges[self.len] = MissingRange { start: 0, end: 0 }; + removed + } + + fn drain(&mut self, range: std::ops::Range) { + let count = range.end - range.start; + if count == 0 { + return; + } + + for i in range.end..self.len { + self.ranges[i - count] = self.ranges[i]; + } + for i in self.len - count..self.len { + self.ranges[i] = MissingRange { start: 0, end: 0 }; + } + self.len -= count; + } +} + +impl std::ops::Index for MissingRanges { + type Output = MissingRange; + + fn index(&self, index: usize) -> &Self::Output { + &self.as_slice()[index] + } +} + +impl std::ops::IndexMut for MissingRanges { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.ranges[index] + } +} + +#[cfg(test)] +mod tests { + use super::{ByteReassembly, ByteReassemblyError, InsertOutcome, MissingRange}; + + #[test] + fn contiguous_insert_becomes_readable_and_complete() { + let mut assembler = ByteReassembly::<8>::new(64); + + let outcome = assembler.insert(0, true, b"hello").unwrap(); + + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 5, + became_complete: true, + } + ); + assert_eq!(assembler.readable_len(), 5); + assert_eq!(assembler.copy_readable(), b"hello"); + assert_eq!(assembler.final_offset(), Some(5)); + assert!(assembler.is_complete()); + assert!(assembler.missing_ranges().is_empty()); + } + + #[test] + fn out_of_order_insert_tracks_missing_ranges_until_gap_is_filled() { + let mut assembler = ByteReassembly::<8>::new(64); + + let first = assembler.insert(5, true, b" world").unwrap(); + assert_eq!( + first, + InsertOutcome { + newly_readable_bytes: 0, + became_complete: false, + } + ); + assert_eq!( + assembler.missing_ranges(), + &[MissingRange { start: 0, end: 5 }] + ); + assert_eq!(assembler.readable_len(), 0); + + let second = assembler.insert(0, false, b"hello").unwrap(); + assert_eq!( + second, + InsertOutcome { + newly_readable_bytes: 11, + became_complete: true, + } + ); + assert_eq!(assembler.copy_readable(), b"hello world"); + assert!(assembler.missing_ranges().is_empty()); + assert!(assembler.is_complete()); + } + + #[test] + fn duplicate_insert_is_ignored_if_bytes_match() { + let mut assembler = ByteReassembly::<8>::new(64); + + assembler.insert(0, false, b"hello").unwrap(); + let duplicate = assembler.insert(0, false, b"hello").unwrap(); + + assert_eq!( + duplicate, + InsertOutcome { + newly_readable_bytes: 0, + became_complete: false, + } + ); + assert_eq!(assembler.copy_readable(), b"hello"); + } + + #[test] + fn conflicting_overlap_is_rejected() { + let mut assembler = ByteReassembly::<8>::new(64); + + assembler.insert(0, false, b"abcdef").unwrap(); + let error = assembler.insert(3, false, b"xyz").unwrap_err(); + + assert_eq!(error, ByteReassemblyError::ConflictingOverlap); + } + + #[test] + fn consume_advances_start_offset_and_trims_old_prefix() { + let mut assembler = ByteReassembly::<8>::new(64); + + assembler.insert(0, false, b"abcd").unwrap(); + assembler.consume(2).unwrap(); + assert_eq!(assembler.start_offset(), 2); + assert_eq!(assembler.copy_readable(), b"cd"); + + let outcome = assembler.insert(1, true, b"bcde").unwrap(); + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 1, + became_complete: true, + } + ); + assert_eq!(assembler.copy_readable(), b"cde"); + assert_eq!(assembler.final_offset(), Some(5)); + assert!(assembler.is_complete()); + } + + #[test] + fn insert_rejects_when_missing_range_budget_is_exhausted() { + let mut assembler = ByteReassembly::<2>::new(64); + + assembler.insert(1, false, b"a").unwrap(); + assembler.insert(3, false, b"b").unwrap(); + let error = assembler.insert(5, false, b"c").unwrap_err(); + + assert_eq!(error, ByteReassemblyError::TooManyMissingRanges); + } + + #[test] + fn insert_can_fill_multiple_gaps_without_rebuilding_state() { + let mut assembler = ByteReassembly::<8>::new(64); + + assembler.insert(0, false, b"ab").unwrap(); + assembler.insert(4, false, b"ef").unwrap(); + assembler.insert(8, true, b"ij").unwrap(); + + assert_eq!( + assembler.missing_ranges(), + &[ + MissingRange { start: 2, end: 4 }, + MissingRange { start: 6, end: 8 }, + ] + ); + + let outcome = assembler.insert(2, false, b"cdefgh").unwrap(); + + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 8, + became_complete: true, + } + ); + assert!(assembler.missing_ranges().is_empty()); + assert_eq!(assembler.copy_readable(), b"abcdefghij"); + assert!(assembler.is_complete()); + } +} diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 2ee317cb..c4cb348a 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -11,12 +11,14 @@ use crate::{ Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, SessionKey, WireError, }; +mod byte_reassembly; mod close; mod stream_ack; mod stream_close; mod stream_data; mod stream_window; +pub use byte_reassembly::*; pub use close::*; pub use stream_ack::*; pub use stream_close::*; From 070264f319b054479034e00302cf2d16bf592d96 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 00:01:04 -0400 Subject: [PATCH 037/304] ql-wire: record ack range --- ql-wire/src/encrypted/ack.rs | 117 ++++++++++++++++++++++++ ql-wire/src/encrypted/mod.rs | 56 +++++++----- ql-wire/src/encrypted/stream_ack.rs | 136 --------------------------- ql-wire/src/tests.rs | 137 ++++++++++++++-------------- 4 files changed, 217 insertions(+), 229 deletions(-) create mode 100644 ql-wire/src/encrypted/ack.rs delete mode 100644 ql-wire/src/encrypted/stream_ack.rs diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs new file mode 100644 index 00000000..1a3f03dd --- /dev/null +++ b/ql-wire/src/encrypted/ack.rs @@ -0,0 +1,117 @@ +use std::mem::size_of; + +use zerocopy::{ + byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +use crate::{ + codec::{parse, push_value, read_exact, U64Le}, + WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RecordAck { + pub ranges: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecordAckRange { + pub start: u64, + pub end: u64, +} + +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] +#[repr(C)] +pub struct RecordAckRangeWire { + pub start: U64Le, + pub end: U64Le, +} + +#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct RecordAckQire { + pub ranges: [u8], +} + +pub struct RecordAckRangeIter<'a> { + remaining: &'a [u8], +} + +impl RecordAck { + pub fn parse(bytes: B) -> Result, WireError> { + let wire = parse(bytes)?; + validate_ack_frame(&wire)?; + Ok(wire) + } + + pub fn encoded_len(&self) -> usize { + self.ranges.len() * size_of::() + } + + pub fn from_wire(wire: &RecordAckQire) -> Result { + validate_ack_frame(wire)?; + Ok(Self { + ranges: wire.ranges().collect(), + }) + } + + pub fn encode_into(&self, out: &mut Vec) { + for range in &self.ranges { + push_value( + out, + &RecordAckRangeWire { + start: U64Le::new(range.start), + end: U64Le::new(range.end), + }, + ); + } + } +} + +impl RecordAckQire { + pub fn ranges(&self) -> RecordAckRangeIter<'_> { + RecordAckRangeIter { + remaining: &self.ranges, + } + } +} + +impl Iterator for RecordAckRangeIter<'_> { + type Item = RecordAckRange; + + fn next(&mut self) -> Option { + if self.remaining.is_empty() { + return None; + } + + let (head, tail) = self.remaining.split_at(size_of::()); + self.remaining = tail; + let wire: RecordAckRangeWire = + read_exact(head).expect("ack ranges are validated before iteration"); + Some(RecordAckRange { + start: wire.start.get(), + end: wire.end.get(), + }) + } +} + +fn validate_ack_frame(wire: &RecordAckQire) -> Result<(), WireError> { + if wire.ranges.is_empty() || wire.ranges.len() % size_of::() != 0 { + return Err(WireError::InvalidPayload); + } + + let mut previous_end = 0; + let mut first = true; + for range in wire.ranges() { + if range.start >= range.end { + return Err(WireError::InvalidPayload); + } + if !first && range.start < previous_end { + return Err(WireError::InvalidPayload); + } + first = false; + previous_end = range.end; + } + + Ok(()) +} diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index c4cb348a..fc1d7a27 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -11,16 +11,16 @@ use crate::{ Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, SessionKey, WireError, }; +mod ack; mod byte_reassembly; mod close; -mod stream_ack; mod stream_close; mod stream_data; mod stream_window; +pub use ack::*; pub use byte_reassembly::*; pub use close::*; -pub use stream_ack::*; pub use stream_close::*; pub use stream_data::*; pub use stream_window::*; @@ -30,17 +30,21 @@ pub use stream_window::*; #[repr(transparent)] pub struct StreamId(pub u32); +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RecordSeq(pub u64); + #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecord { + pub seq: RecordSeq, pub frames: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionFrame { Ping, - Pong, + Ack(RecordAck), StreamData(StreamData), - StreamAck(StreamAck), StreamWindow(StreamWindow), StreamClose(StreamClose), Close(SessionCloseBody), @@ -48,9 +52,8 @@ pub enum SessionFrame { pub enum SessionFrameRef<'a> { Ping, - Pong, + Ack(Ref<&'a [u8], RecordAckQire>), StreamData(Ref<&'a [u8], StreamDataWire>), - StreamAck(Ref<&'a [u8], StreamAckWire>), StreamWindow(Ref<&'a [u8], StreamWindowWire>), StreamClose(Ref<&'a [u8], StreamCloseWire>), Close(Ref<&'a [u8], SessionCloseBodyWire>), @@ -62,17 +65,17 @@ pub enum SessionFrameRef<'a> { #[repr(u8)] pub(crate) enum SessionFrameKind { Ping = 1, - Pong = 2, + Ack = 2, StreamData = 3, - StreamAck = 4, - StreamWindow = 5, - StreamClose = 6, - Close = 7, + StreamWindow = 4, + StreamClose = 5, + Close = 6, } #[derive(FromBytes, KnownLayout, Immutable, Unaligned)] #[repr(C, packed)] pub struct SessionRecordWire { + pub seq: crate::codec::U64Le, pub frames: [u8], } @@ -90,11 +93,15 @@ impl SessionRecord { .frames() .map(|frame| frame?.to_owned()) .collect::, _>>()?; - Ok(Self { frames }) + Ok(Self { + seq: wire.seq(), + frames, + }) } pub fn encode(&self) -> Vec { let mut out = Vec::new(); + out.extend_from_slice(&self.seq.0.to_le_bytes()); for frame in &self.frames { frame.encode_into(&mut out); } @@ -107,6 +114,10 @@ impl SessionRecord { } impl SessionRecordWire { + pub fn seq(&self) -> RecordSeq { + RecordSeq(self.seq.get()) + } + pub fn frames(&self) -> SessionFrameIter<'_> { SessionFrameIter { remaining: &self.frames, @@ -118,14 +129,13 @@ impl SessionFrame { pub fn encode_into(&self, out: &mut Vec) { match self { Self::Ping => out.push(SessionFrameKind::Ping as u8), - Self::Pong => out.push(SessionFrameKind::Pong as u8), - Self::StreamData(frame) => { - out.push(SessionFrameKind::StreamData as u8); + Self::Ack(frame) => { + out.push(SessionFrameKind::Ack as u8); push_variable_len(out, frame.encoded_len()); frame.encode_into(out); } - Self::StreamAck(frame) => { - out.push(SessionFrameKind::StreamAck as u8); + Self::StreamData(frame) => { + out.push(SessionFrameKind::StreamData as u8); push_variable_len(out, frame.encoded_len()); frame.encode_into(out); } @@ -150,9 +160,8 @@ impl SessionFrameRef<'_> { pub fn to_owned(&self) -> Result { Ok(match self { Self::Ping => SessionFrame::Ping, - Self::Pong => SessionFrame::Pong, + Self::Ack(frame) => SessionFrame::Ack(RecordAck::from_wire(frame)?), Self::StreamData(frame) => SessionFrame::StreamData(StreamData::from_wire(frame)?), - Self::StreamAck(frame) => SessionFrame::StreamAck(StreamAck::from_wire(frame)?), Self::StreamWindow(frame) => SessionFrame::StreamWindow(StreamWindow::from_wire(frame)), Self::StreamClose(frame) => SessionFrame::StreamClose(StreamClose::from_wire(frame)?), Self::Close(frame) => SessionFrame::Close(SessionCloseBody::from_wire(frame)), @@ -214,14 +223,13 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrameRef<'_>, &[u8]), WireEr let kind: SessionFrameKind = read_byte(kind)?; match kind { SessionFrameKind::Ping => Ok((SessionFrameRef::Ping, rest)), - SessionFrameKind::Pong => Ok((SessionFrameRef::Pong, rest)), - SessionFrameKind::StreamData => { + SessionFrameKind::Ack => { let (frame, rest) = split_variable_frame(rest)?; - Ok((SessionFrameRef::StreamData(StreamData::parse(frame)?), rest)) + Ok((SessionFrameRef::Ack(RecordAck::parse(frame)?), rest)) } - SessionFrameKind::StreamAck => { + SessionFrameKind::StreamData => { let (frame, rest) = split_variable_frame(rest)?; - Ok((SessionFrameRef::StreamAck(StreamAck::parse(frame)?), rest)) + Ok((SessionFrameRef::StreamData(StreamData::parse(frame)?), rest)) } SessionFrameKind::StreamWindow => { let wire_size = StreamWindow::WIRE_SIZE; diff --git a/ql-wire/src/encrypted/stream_ack.rs b/ql-wire/src/encrypted/stream_ack.rs deleted file mode 100644 index b49b5dbd..00000000 --- a/ql-wire/src/encrypted/stream_ack.rs +++ /dev/null @@ -1,136 +0,0 @@ -use std::mem::size_of; - -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - -use super::StreamId; -use crate::{ - codec::{parse, push_value, read_exact, U32Le, U64Le}, - WireError, -}; - -/// acknowledges a contiguous prefix plus optional selective ranges. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamAck { - pub stream_id: StreamId, - pub acked_prefix: u64, - pub ranges: Vec, -} - -/// one acknowledged range after the acked prefix. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamAckRange { - pub start_offset: u64, - pub end_offset: u64, -} - -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct StreamAckRangeWire { - pub start_offset: U64Le, - pub end_offset: U64Le, -} - -#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] -#[repr(C, packed)] -pub struct StreamAckWire { - pub stream_id: U32Le, - pub acked_prefix: U64Le, - pub ranges: [u8], -} - -pub struct StreamAckRangeIter<'a> { - remaining: &'a [u8], -} - -impl StreamAck { - pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::(); - - pub fn parse(bytes: B) -> Result, WireError> { - let wire = parse(bytes)?; - validate_ack_frame(&wire)?; - Ok(wire) - } - - pub fn encoded_len(&self) -> usize { - Self::MIN_WIRE_SIZE + self.ranges.len() * size_of::() - } - - pub fn from_wire(wire: &StreamAckWire) -> Result { - validate_ack_frame(wire)?; - Ok(Self { - stream_id: wire.stream_id(), - acked_prefix: wire.acked_prefix(), - ranges: wire.ranges().collect(), - }) - } - - pub fn encode_into(&self, out: &mut Vec) { - out.extend_from_slice(&self.stream_id.0.to_le_bytes()); - out.extend_from_slice(&self.acked_prefix.to_le_bytes()); - for range in &self.ranges { - push_value( - out, - &StreamAckRangeWire { - start_offset: U64Le::new(range.start_offset), - end_offset: U64Le::new(range.end_offset), - }, - ); - } - } -} - -impl StreamAckWire { - pub fn stream_id(&self) -> StreamId { - StreamId(self.stream_id.get()) - } - - pub fn acked_prefix(&self) -> u64 { - self.acked_prefix.get() - } - - pub fn ranges(&self) -> StreamAckRangeIter<'_> { - StreamAckRangeIter { - remaining: &self.ranges, - } - } -} - -impl Iterator for StreamAckRangeIter<'_> { - type Item = StreamAckRange; - - fn next(&mut self) -> Option { - if self.remaining.is_empty() { - return None; - } - let (head, tail) = self.remaining.split_at(size_of::()); - self.remaining = tail; - let wire: StreamAckRangeWire = - read_exact(head).expect("stream ack ranges are validated before iteration"); - Some(StreamAckRange { - start_offset: wire.start_offset.get(), - end_offset: wire.end_offset.get(), - }) - } -} - -fn validate_ack_frame(wire: &StreamAckWire) -> Result<(), WireError> { - if wire.ranges.len() % size_of::() != 0 { - return Err(WireError::InvalidPayload); - } - - let acked_prefix = wire.acked_prefix(); - let mut previous_end = acked_prefix; - for range in wire.ranges() { - if range.start_offset < acked_prefix - || range.start_offset >= range.end_offset - || range.start_offset < previous_end - { - return Err(WireError::InvalidPayload); - } - previous_end = range.end_offset; - } - - Ok(()) -} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index f71f09d2..6ba1360c 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -73,21 +73,13 @@ fn encrypted_session_record_round_trip_and_decrypt() { recipient: XID([2; XID::SIZE]), }; let body = SessionRecord { + seq: RecordSeq(11), frames: vec![ SessionFrame::Ping, - SessionFrame::Pong, - SessionFrame::StreamAck(StreamAck { - stream_id: StreamId(7), - acked_prefix: 12, + SessionFrame::Ack(RecordAck { ranges: vec![ - StreamAckRange { - start_offset: 20, - end_offset: 24, - }, - StreamAckRange { - start_offset: 30, - end_offset: 33, - }, + RecordAckRange { start: 12, end: 14 }, + RecordAckRange { start: 20, end: 24 }, ], }), SessionFrame::StreamWindow(StreamWindow { @@ -146,6 +138,7 @@ fn decrypted_session_record_iterates_zero_copy_frames() { recipient: XID([10; XID::SIZE]), }; let body = SessionRecord { + seq: RecordSeq(7), frames: vec![ SessionFrame::StreamData(StreamData { stream_id: StreamId(1), @@ -153,13 +146,8 @@ fn decrypted_session_record_iterates_zero_copy_frames() { fin: false, bytes: b"abc".to_vec(), }), - SessionFrame::StreamAck(StreamAck { - stream_id: StreamId(1), - acked_prefix: 3, - ranges: vec![StreamAckRange { - start_offset: 5, - end_offset: 8, - }], + SessionFrame::Ack(RecordAck { + ranges: vec![RecordAckRange { start: 3, end: 8 }], }), SessionFrame::StreamClose(StreamClose { stream_id: StreamId(1), @@ -186,6 +174,7 @@ fn decrypted_session_record_iterates_zero_copy_frames() { let decrypted = encrypted::decrypt_record(&crypto, &header, &mut encrypted, &session_key).unwrap(); + assert_eq!(decrypted.seq(), RecordSeq(7)); let mut frames = decrypted.frames(); match frames.next().unwrap().unwrap() { SessionFrameRef::StreamData(frame) => { @@ -197,19 +186,11 @@ fn decrypted_session_record_iterates_zero_copy_frames() { other => panic!("expected stream data, got {}", frame_name(&other)), } match frames.next().unwrap().unwrap() { - SessionFrameRef::StreamAck(frame) => { - assert_eq!(frame.stream_id(), StreamId(1)); - assert_eq!(frame.acked_prefix(), 3); + SessionFrameRef::Ack(frame) => { let ranges: Vec<_> = frame.ranges().collect(); - assert_eq!( - ranges, - vec![StreamAckRange { - start_offset: 5, - end_offset: 8, - }] - ); + assert_eq!(ranges, vec![RecordAckRange { start: 3, end: 8 }]); } - other => panic!("expected stream ack, got {}", frame_name(&other)), + other => panic!("expected ack, got {}", frame_name(&other)), } match frames.next().unwrap().unwrap() { SessionFrameRef::StreamClose(frame) => { @@ -338,14 +319,25 @@ fn unpair_round_trip_and_verify() { #[test] fn session_record_rejects_malformed_frames() { let invalid_cases = [ - vec![0xff], { - let mut bytes = vec![SessionFrameKind::StreamData as u8]; + let mut bytes = Vec::new(); + bytes.extend_from_slice(&1u32.to_le_bytes()); + bytes + }, + { + let mut bytes = 1u64.to_le_bytes().to_vec(); + bytes.push(0xff); + bytes + }, + { + let mut bytes = 1u64.to_le_bytes().to_vec(); + bytes.push(SessionFrameKind::StreamData as u8); bytes.push(1); bytes }, { - let mut bytes = vec![SessionFrameKind::StreamData as u8]; + let mut bytes = 1u64.to_le_bytes().to_vec(); + bytes.push(SessionFrameKind::StreamData as u8); bytes.extend_from_slice(&13u16.to_le_bytes()); bytes.extend_from_slice(&1u32.to_le_bytes()); bytes.extend_from_slice(&4u64.to_le_bytes()); @@ -354,24 +346,31 @@ fn session_record_rejects_malformed_frames() { bytes }, { - let mut bytes = vec![SessionFrameKind::StreamAck as u8]; - bytes.extend_from_slice(&20u16.to_le_bytes()); - bytes.extend_from_slice(&1u32.to_le_bytes()); - bytes.extend_from_slice(&3u64.to_le_bytes()); + let mut bytes = 1u64.to_le_bytes().to_vec(); + bytes.push(SessionFrameKind::Ack as u8); + bytes.extend_from_slice(&0u16.to_le_bytes()); + bytes + }, + { + let mut bytes = 1u64.to_le_bytes().to_vec(); + bytes.push(SessionFrameKind::Ack as u8); + bytes.extend_from_slice(&8u16.to_le_bytes()); bytes.extend_from_slice(&5u64.to_le_bytes()); bytes }, { - let mut bytes = vec![SessionFrameKind::StreamAck as u8]; - bytes.extend_from_slice(&28u16.to_le_bytes()); - bytes.extend_from_slice(&1u32.to_le_bytes()); + let mut bytes = 1u64.to_le_bytes().to_vec(); + bytes.push(SessionFrameKind::Ack as u8); + bytes.extend_from_slice(&32u16.to_le_bytes()); bytes.extend_from_slice(&6u64.to_le_bytes()); - bytes.extend_from_slice(&4u64.to_le_bytes()); bytes.extend_from_slice(&8u64.to_le_bytes()); + bytes.extend_from_slice(&7u64.to_le_bytes()); + bytes.extend_from_slice(&9u64.to_le_bytes()); bytes }, { - let mut bytes = vec![SessionFrameKind::StreamClose as u8]; + let mut bytes = 1u64.to_le_bytes().to_vec(); + bytes.push(SessionFrameKind::StreamClose as u8); bytes.extend_from_slice(&9u16.to_le_bytes()); bytes.extend_from_slice(&1u32.to_le_bytes()); bytes.push(CloseTarget::Both as u8); @@ -380,23 +379,27 @@ fn session_record_rejects_malformed_frames() { bytes }, { - let mut bytes = vec![SessionFrameKind::Close as u8]; + let mut bytes = 1u64.to_le_bytes().to_vec(); + bytes.push(SessionFrameKind::Close as u8); bytes.push(0); bytes }, ]; for bytes in invalid_cases { - assert_eq!(SessionRecord::decode(&bytes), Err(WireError::InvalidPayload)); + assert_eq!( + SessionRecord::decode(bytes.as_slice()), + Err(WireError::InvalidPayload) + ); } } #[test] -fn session_record_supports_empty_fin_stream_data_and_empty_ping_pong() { +fn session_record_supports_empty_fin_stream_data_and_empty_ping() { let record = SessionRecord { + seq: RecordSeq(99), frames: vec![ SessionFrame::Ping, - SessionFrame::Pong, SessionFrame::StreamData(StreamData { stream_id: StreamId(42), offset: 999, @@ -407,8 +410,8 @@ fn session_record_supports_empty_fin_stream_data_and_empty_ping_pong() { }; let encoded = record.encode(); - assert_eq!(encoded[0], SessionFrameKind::Ping as u8); - assert_eq!(encoded[1], SessionFrameKind::Pong as u8); + assert_eq!(&encoded[..8], &99u64.to_le_bytes()); + assert_eq!(encoded[8], SessionFrameKind::Ping as u8); let decoded = SessionRecord::decode(&encoded).unwrap(); assert_eq!(decoded, record); @@ -497,44 +500,36 @@ fn protocol_record_size_breakdown() { header, 14, SessionRecord { + seq: RecordSeq(1), frames: vec![SessionFrame::Ping], }, ); - let session_pong = session_record( + let session_ack = session_record( header, 15, SessionRecord { - frames: vec![SessionFrame::Pong], + seq: RecordSeq(2), + frames: vec![SessionFrame::Ack(RecordAck { + ranges: vec![RecordAckRange { start: 1, end: 3 }], + })], }, ); let session_stream_window = session_record( header, 16, SessionRecord { + seq: RecordSeq(3), frames: vec![SessionFrame::StreamWindow(StreamWindow { stream_id: StreamId(1), maximum_offset: 65_536, })], }, ); - let session_stream_ack = session_record( - header, - 17, - SessionRecord { - frames: vec![SessionFrame::StreamAck(StreamAck { - stream_id: StreamId(1), - acked_prefix: 4, - ranges: vec![StreamAckRange { - start_offset: 8, - end_offset: 12, - }], - })], - }, - ); let session_stream_empty = session_record( header, 18, SessionRecord { + seq: RecordSeq(4), frames: vec![SessionFrame::StreamData(StreamData { stream_id: StreamId(1), offset: 0, @@ -547,6 +542,7 @@ fn protocol_record_size_breakdown() { header, 19, SessionRecord { + seq: RecordSeq(5), frames: vec![SessionFrame::StreamData(StreamData { stream_id: StreamId(1), offset: 0, @@ -559,6 +555,7 @@ fn protocol_record_size_breakdown() { header, 20, SessionRecord { + seq: RecordSeq(6), frames: vec![SessionFrame::StreamClose(StreamClose { stream_id: StreamId(1), target: CloseTarget::Both, @@ -571,6 +568,7 @@ fn protocol_record_size_breakdown() { header, 21, SessionRecord { + seq: RecordSeq(7), frames: vec![SessionFrame::Close(SessionCloseBody { code: CloseCode::PROTOCOL, })], @@ -588,9 +586,11 @@ fn protocol_record_size_breakdown() { print_size("ql-wire unpair", unpair.encode().len()); print_size("ql-wire ready empty", ready.encode().len()); print_size("ql-wire session ping", session_ping.encode().len()); - print_size("ql-wire session pong", session_pong.encode().len()); - print_size("ql-wire session stream window", session_stream_window.encode().len()); - print_size("ql-wire session stream ack", session_stream_ack.encode().len()); + print_size("ql-wire session ack", session_ack.encode().len()); + print_size( + "ql-wire session stream window", + session_stream_window.encode().len(), + ); print_size( "ql-wire session stream empty", session_stream_empty.encode().len(), @@ -609,9 +609,8 @@ fn protocol_record_size_breakdown() { fn frame_name(frame: &SessionFrameRef<'_>) -> &'static str { match frame { SessionFrameRef::Ping => "ping", - SessionFrameRef::Pong => "pong", + SessionFrameRef::Ack(_) => "ack", SessionFrameRef::StreamData(_) => "stream_data", - SessionFrameRef::StreamAck(_) => "stream_ack", SessionFrameRef::StreamWindow(_) => "stream_window", SessionFrameRef::StreamClose(_) => "stream_close", SessionFrameRef::Close(_) => "close", From 6b552e5eb8a9c01821b642cf5713214420ae790a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 00:42:20 -0400 Subject: [PATCH 038/304] ql-fsm: quic lite wip --- ql-fsm/src/error.rs | 3 + ql-fsm/src/implementation/fsm.rs | 33 +- ql-fsm/src/implementation/mod.rs | 22 +- ql-fsm/src/lib.rs | 66 +- ql-fsm/src/session/mod.rs | 1264 ++++++++++++++++----------- ql-fsm/src/session/ring.rs | 197 ----- ql-fsm/src/session/state.rs | 261 ++++-- ql-fsm/src/session/stream_window.rs | 71 -- ql-fsm/src/session/tests.rs | 879 +++---------------- ql-fsm/src/tests/handshake.rs | 2 +- ql-fsm/src/tests/mod.rs | 30 +- ql-fsm/src/tests/session.rs | 237 +++-- 12 files changed, 1289 insertions(+), 1776 deletions(-) delete mode 100644 ql-fsm/src/session/ring.rs delete mode 100644 ql-fsm/src/session/stream_window.rs diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 5a436812..1642222c 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -19,6 +19,8 @@ pub enum QlFsmError { MissingStream, #[error("stream is not writable")] NotWritable, + #[error("invalid read commit")] + InvalidRead, #[error("session is closed")] SessionClosed, #[error("no peer bound")] @@ -43,6 +45,7 @@ impl From for QlFsmError { match value { StreamError::MissingStream => Self::MissingStream, StreamError::NotWritable => Self::NotWritable, + StreamError::InvalidRead => Self::InvalidRead, StreamError::SessionClosed => Self::SessionClosed, } } diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 3c25e442..428bb255 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -57,11 +57,10 @@ pub fn receive( let Some((_, session_key)) = super::peer_session(fsm) else { return Err(QlFsmError::NoSession); }; - let envelope = wire::decrypt_record(crypto, &header, &mut encrypted, &session_key)?; - // TODO: this seems unnecessary to me? - let envelope = wire::SessionEnvelope::from_wire(&envelope)?; + let record = wire::decrypt_record(crypto, &header, &mut encrypted, &session_key)?; + let record = wire::SessionRecord::from_wire(&record)?; let mut session_closed = false; - fsm.session.receive(fsm.state.now.instant, envelope, { + fsm.session.receive(fsm.state.now.instant, record, { let session_events = &mut fsm.state.session_events; |event| { session_closed |= super::forward_session_event(session_events, event); @@ -126,20 +125,18 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Result { pub fn write_stream( fsm: &mut QlFsm, stream_id: StreamId, - bytes: Vec, -) -> Result<(), QlFsmError> { + bytes: &[u8], +) -> Result { ensure_peer_bound(fsm)?; Ok(fsm.session.write_stream(stream_id, bytes)?) } -pub fn read_stream( +pub fn peek_stream( fsm: &mut QlFsm, stream_id: StreamId, out: &mut [u8], ) -> Result { - Ok(fsm.session.read_stream(stream_id, out)?) + Ok(fsm.session.peek_stream(stream_id, out)?) +} + +pub fn commit_stream_read( + fsm: &mut QlFsm, + stream_id: StreamId, + len: usize, +) -> Result<(), QlFsmError> { + Ok(fsm.session.commit_stream_read(stream_id, len)?) } pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Result { diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index 54bd727a..3d8b9261 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -10,7 +10,7 @@ pub use peer::*; use ql_wire::{ControlId, ControlMeta, QlHeader, QlPayload, QlRecord, SessionKey, XID}; use crate::{ - session::{SessionEvent, SessionFsmConfig, StreamNamespace}, + session::{state::StreamParity, SessionEvent, SessionFsmConfig}, QlFsm, QlFsmEvent, QlSessionEvent, }; @@ -55,19 +55,21 @@ fn peer_session(fsm: &QlFsm) -> Option<(XID, SessionKey)> { } fn reset_session(fsm: &mut QlFsm) { - let local_namespace = fsm + let local_parity = fsm .peer .as_ref() - .map(|peer| StreamNamespace::for_local(fsm.identity.xid, peer.peer.xid)) - .unwrap_or(StreamNamespace::Low); + .map(|peer| StreamParity::for_local(fsm.identity.xid, peer.peer.xid)) + .unwrap_or(StreamParity::Even); fsm.session = crate::session::SessionFsm::new( SessionFsmConfig { - local_namespace, - stream_chunk_size: fsm.config.session_stream_chunk_size, - ack_delay: fsm.config.session_ack_delay, - retransmit_timeout: fsm.config.session_retransmit_timeout, + local_parity, + record_size: fsm.config.session_record_size, + ack_delay: fsm.config.session_record_ack_delay, + retransmit_timeout: fsm.config.session_record_retransmit_timeout, keepalive_interval: fsm.config.session_keepalive_interval, peer_timeout: fsm.config.session_peer_timeout, + stream_send_buffer_size: fsm.config.session_stream_send_buffer_size, + stream_receive_buffer_size: fsm.config.session_stream_receive_buffer_size, }, fsm.state.now.instant, ); @@ -108,6 +110,10 @@ fn forward_session_event( session_events.push_back(QlSessionEvent::Readable(stream_id)); false } + SessionEvent::Writable(stream_id) => { + session_events.push_back(QlSessionEvent::Writable(stream_id)); + false + } SessionEvent::Finished(stream_id) => { session_events.push_back(QlSessionEvent::Finished(stream_id)); false diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index ac6b77b7..c7bacb69 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -31,7 +31,7 @@ use std::time::{Duration, Instant}; pub use error::QlFsmError; use ql_wire::{ CloseCode, CloseTarget, MlDsaPublicKey, MlKemPublicKey, QlCrypto, QlIdentity, QlRecord, - SessionCloseBody, SessionSeq, StreamClose, StreamId, XID, + SessionCloseBody, StreamClose, StreamId, XID, }; use crate::{ @@ -96,6 +96,8 @@ pub enum QlSessionEvent { Opened(StreamId), /// a stream has bytes ready to read Readable(StreamId), + /// a stream has room for more local writes + Writable(StreamId), /// the peer finished writing this stream Finished(StreamId), /// a stream was closed @@ -110,10 +112,7 @@ pub enum QlSessionEvent { /// handle for a session write returned by `QlFsm::take_next_write` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct SessionWriteId( - /// session sequence number for this write - pub SessionSeq, -); +pub struct SessionWriteId(pub(crate) u64); /// outbound record produced by `QlFsm` #[derive(Debug, Clone, PartialEq)] @@ -135,16 +134,20 @@ pub struct QlFsmConfig { pub max_handshake_retries: u8, /// how far into the future control messages remain valid pub control_expiration: Duration, - /// delay before sending a pure ack - pub session_ack_delay: Duration, - /// how long to wait before resending unacked session data - pub session_retransmit_timeout: Duration, + /// delay before sending a pure record ack + pub session_record_ack_delay: Duration, + /// how long to wait before resending unacked session records + pub session_record_retransmit_timeout: Duration, /// idle delay before sending a keepalive ping pub session_keepalive_interval: Duration, /// how long to wait before declaring the peer dead pub session_peer_timeout: Duration, - /// maximum bytes per outbound stream chunk - pub session_stream_chunk_size: usize, + /// target plaintext size for one session record + pub session_record_size: usize, + /// maximum bytes buffered locally for one stream send side + pub session_stream_send_buffer_size: usize, + /// maximum bytes buffered locally for one stream receive side + pub session_stream_receive_buffer_size: usize, } impl Default for QlFsmConfig { @@ -154,11 +157,13 @@ impl Default for QlFsmConfig { handshake_retry_interval: Duration::from_millis(750), max_handshake_retries: 3, control_expiration: Duration::from_secs(30), - session_ack_delay: Duration::from_millis(5), - session_retransmit_timeout: Duration::from_millis(150), + session_record_ack_delay: Duration::from_millis(5), + session_record_retransmit_timeout: Duration::from_millis(150), session_keepalive_interval: Duration::from_secs(10), session_peer_timeout: Duration::from_secs(30), - session_stream_chunk_size: 16 * 1024, + session_record_size: 16 * 1024, + session_stream_send_buffer_size: 64 * 1024, + session_stream_receive_buffer_size: 64 * 1024, } } } @@ -183,12 +188,14 @@ impl QlFsm { peer: None, session: session::SessionFsm::new( session::SessionFsmConfig { - local_namespace: session::StreamNamespace::Low, - stream_chunk_size: config.session_stream_chunk_size, - ack_delay: config.session_ack_delay, - retransmit_timeout: config.session_retransmit_timeout, + local_parity: session::state::StreamParity::Even, + record_size: config.session_record_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, keepalive_interval: config.session_keepalive_interval, peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, }, now.instant, ), @@ -287,18 +294,31 @@ impl QlFsm { implementation::open_stream(self) } - /// queues bytes for an open stream - pub fn write_stream(&mut self, stream_id: StreamId, bytes: Vec) -> Result<(), QlFsmError> { + /// queues bytes for an open stream and returns the accepted count + pub fn write_stream( + &mut self, + stream_id: StreamId, + bytes: &[u8], + ) -> Result { implementation::write_stream(self, stream_id, bytes) } - /// reads queued bytes from a stream into `out` - pub fn read_stream( + /// copies readable bytes from a stream into `out` without consuming them + pub fn peek_stream( &mut self, stream_id: StreamId, out: &mut [u8], ) -> Result { - implementation::read_stream(self, stream_id, out) + implementation::peek_stream(self, stream_id, out) + } + + /// marks previously peeked bytes as consumed + pub fn commit_stream_read( + &mut self, + stream_id: StreamId, + len: usize, + ) -> Result<(), QlFsmError> { + implementation::commit_stream_read(self, stream_id, len) } /// returns how many bytes can be read from a stream diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index c073f6e1..0fbeecb0 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,6 +1,4 @@ -pub(crate) mod ring; pub(crate) mod state; -pub(crate) mod stream_window; #[cfg(test)] mod tests; @@ -9,74 +7,40 @@ use std::time::{Duration, Instant}; use indexmap::map::Entry; use ql_wire::{ - CloseCode, CloseTarget, PingBody, SessionAck, SessionBody, SessionCloseBody, SessionEnvelope, - SessionSeq, StreamChunk, StreamClose, StreamId, XID, + ByteReassemblyError, CloseCode, CloseTarget, RecordAck, RecordSeq, SessionCloseBody, + SessionFrame, SessionRecord, StreamClose, StreamData, StreamId, StreamWindow, }; -use self::{ - state::{ - AckState, InboundState, OutboundState, PendingSessionBody, SessionFsmState, - StreamOpenState, StreamRole, StreamState, TxEntry, TxState, - }, - stream_window::{RecvInsertOutcome, RxChunk}, +use self::state::{ + AckState, InboundState, OutboundState, PendingRecord, ReceivedRecords, ReceiveInsertOutcome, + ReliableFrame, SentRecord, SessionFsmState, StreamParity, StreamRole, StreamState, }; -struct RejectNoAck; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamNamespace { - Low, - High, -} - -impl StreamNamespace { - const BIT: u32 = 1 << 31; - - pub fn for_local(local: XID, peer: XID) -> Self { - match local.0.cmp(&peer.0) { - std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Low, - std::cmp::Ordering::Greater => Self::High, - } - } - - pub fn bit(self) -> u32 { - match self { - Self::Low => 0, - Self::High => Self::BIT, - } - } - - pub fn matches(self, stream_id: StreamId) -> bool { - (stream_id.0 & Self::BIT) == self.bit() - } - - pub fn remote(self) -> Self { - match self { - Self::Low => Self::High, - Self::High => Self::Low, - } - } -} +pub(crate) const SESSION_RECORD_TRACKED_WINDOW: u64 = 256; #[derive(Debug, Clone, Copy)] pub struct SessionFsmConfig { - pub local_namespace: StreamNamespace, - pub stream_chunk_size: usize, + pub local_parity: StreamParity, + pub record_size: usize, pub ack_delay: Duration, pub retransmit_timeout: Duration, pub keepalive_interval: Duration, pub peer_timeout: Duration, + pub stream_send_buffer_size: usize, + pub stream_receive_buffer_size: usize, } impl Default for SessionFsmConfig { fn default() -> Self { Self { - local_namespace: StreamNamespace::Low, - stream_chunk_size: 16 * 1024, + local_parity: StreamParity::Even, + record_size: 16 * 1024, ack_delay: Duration::from_millis(5), retransmit_timeout: Duration::from_millis(150), keepalive_interval: Duration::from_secs(10), peer_timeout: Duration::from_secs(30), + stream_send_buffer_size: 64 * 1024, + stream_receive_buffer_size: 64 * 1024, } } } @@ -85,6 +49,7 @@ impl Default for SessionFsmConfig { pub enum SessionEvent { Opened(StreamId), Readable(StreamId), + Writable(StreamId), Finished(StreamId), Closed(StreamClose), WritableClosed(StreamId), @@ -103,6 +68,8 @@ pub enum StreamError { MissingStream, #[error("stream is not writable")] NotWritable, + #[error("invalid read commit")] + InvalidRead, #[error("session is closed")] SessionClosed, } @@ -114,7 +81,9 @@ pub struct SessionFsm { impl SessionFsm { pub fn new(mut config: SessionFsmConfig, now: Instant) -> Self { - config.stream_chunk_size = config.stream_chunk_size.max(1); + config.record_size = config.record_size.max(64); + config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); + config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); Self { config, state: SessionFsmState { @@ -122,10 +91,12 @@ impl SessionFsm { last_activity_at: now, last_inbound_at: now, session_state: SessionState::Open, - next_stream_ordinal: 1, - next_seq: SessionSeq(1), - tx_ring: ring::SeqRing::new(SessionSeq(1)), - rx_ring: ring::SeqRing::new(SessionSeq(1)), + next_stream_ordinal: 0, + next_record_seq: RecordSeq(0), + next_write_id: 0, + issued_records: Default::default(), + sent_records: Default::default(), + received_records: ReceivedRecords::default(), ack_state: AckState::Idle, pending_control: Default::default(), streams: Default::default(), @@ -136,21 +107,28 @@ impl SessionFsm { pub fn open_stream(&mut self) -> Result { self.ensure_session_open()?; - let stream_id = - StreamId(self.config.local_namespace.bit() | self.state.next_stream_ordinal); + let stream_id = self + .config + .local_parity + .make_stream_id(self.state.next_stream_ordinal); self.state.next_stream_ordinal = self.state.next_stream_ordinal.saturating_add(1); - self.state - .streams - .insert(stream_id, StreamState::new(StreamRole::Initiator)); + self.state.streams.insert( + stream_id, + StreamState::new( + StreamRole::Initiator, + self.config.stream_send_buffer_size, + self.config.stream_receive_buffer_size, + ), + ); Ok(stream_id) } - pub fn write_stream(&mut self, stream_id: StreamId, bytes: Vec) -> Result<(), StreamError> { + pub fn write_stream( + &mut self, + stream_id: StreamId, + bytes: &[u8], + ) -> Result { self.ensure_session_open()?; - if bytes.is_empty() { - return Ok(()); - } - let stream = self .state .streams @@ -160,8 +138,11 @@ impl SessionFsm { return Err(StreamError::NotWritable); } - stream.send_buf.extend(bytes); - Ok(()) + let accepted = bytes + .len() + .min(stream.send_capacity(self.config.stream_send_buffer_size)); + stream.send_buf.extend(bytes[..accepted].iter().copied()); + Ok(accepted) } pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { @@ -174,7 +155,6 @@ impl SessionFsm { if !stream.is_writable() { return Err(StreamError::NotWritable); } - stream.outbound_state = OutboundState::FinQueued; Ok(()) } @@ -193,8 +173,7 @@ impl SessionFsm { .streams .get_mut(&stream_id) .ok_or(StreamError::MissingStream)?; - - Self::apply_close_to_stream(stream, target); + Self::apply_local_close_to_stream(stream, target); stream.pending_close = Some(StreamClose { stream_id, target, @@ -206,47 +185,65 @@ impl SessionFsm { Ok(()) } - pub fn read_stream( + pub fn peek_stream( &mut self, stream_id: StreamId, out: &mut [u8], ) -> Result { - let written = { - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; - if out.is_empty() || stream.recv_buf.is_empty() { - return Ok(0); - } - - let (front, back) = stream.recv_buf.as_slices(); - let front_len = front.len().min(out.len()); - out[..front_len].copy_from_slice(&front[..front_len]); + let stream = self + .state + .streams + .get(&stream_id) + .ok_or(StreamError::MissingStream)?; + if out.is_empty() { + return Ok(0); + } - let mut written = front_len; - let remaining = out.len() - front_len; - if remaining > 0 { - let back_len = back.len().min(remaining); - out[written..written + back_len].copy_from_slice(&back[..back_len]); - written += back_len; + let mut written = 0; + for chunk in stream.recv.bytes() { + let remaining = out.len().saturating_sub(written); + if remaining == 0 { + break; + } + let len = remaining.min(chunk.len()); + out[written..written + len].copy_from_slice(&chunk[..len]); + written += len; + if len < chunk.len() { + break; } + } - stream.recv_buf.drain(..written); - written - }; - self.try_reap_stream(stream_id); Ok(written) } + pub fn commit_stream_read( + &mut self, + stream_id: StreamId, + len: usize, + ) -> Result<(), StreamError> { + let stream = self + .state + .streams + .get_mut(&stream_id) + .ok_or(StreamError::MissingStream)?; + if len > stream.readable_bytes() { + return Err(StreamError::InvalidRead); + } + stream.recv.consume(len).map_err(|_| StreamError::InvalidRead)?; + if stream.recv_limit() > stream.advertised_max_offset { + stream.pending_window = true; + } + self.try_reap_stream(stream_id); + Ok(()) + } + pub fn stream_available_bytes(&self, stream_id: StreamId) -> Result { let stream = self .state .streams .get(&stream_id) .ok_or(StreamError::MissingStream)?; - Ok(stream.recv_buf.len()) + Ok(stream.readable_bytes()) } pub fn queue_ping(&mut self) -> Result<(), StreamError> { @@ -258,129 +255,89 @@ impl SessionFsm { pub fn receive( &mut self, now: Instant, - envelope: SessionEnvelope, + record: SessionRecord, mut emit: impl FnMut(SessionEvent), ) { self.state.now = now; self.collect_timeouts(); - self.process_ack(envelope.ack); - - if self.state.session_state == SessionState::Closed { - return; - } + let ack_eliciting = Self::record_is_ack_eliciting(&record); self.state.last_activity_at = self.state.now; self.state.last_inbound_at = self.state.now; - let seq = envelope.seq; - if seq.0 < self.state.rx_ring.base_seq().0 || self.state.rx_ring.contains_key(&seq) { - if !matches!(envelope.body, SessionBody::Ack) { - self.schedule_ack(true); + let out_of_order = match self.state.received_records.insert(record.seq) { + ReceiveInsertOutcome::Duplicate => { + if ack_eliciting { + self.schedule_ack(true); + } + return; } - return; - } - if !self.state.rx_ring.accepts_seq(seq) { - self.fail_session( - SessionCloseBody { - code: CloseCode::PROTOCOL, - }, - &mut emit, - ); - return; - } + ReceiveInsertOutcome::New { out_of_order } => out_of_order, + }; - let out_of_order = seq != self.state.rx_ring.base_seq(); - let body_kind_is_ack = matches!(envelope.body, SessionBody::Ack); - let apply_inbound_body = match envelope.body { - SessionBody::Ack | SessionBody::Ping(_) => Ok(()), - SessionBody::Close(close) => { - self.state.session_state = SessionState::Closed; - self.clear_streams(); - emit(SessionEvent::SessionClosed(close)); - Ok(()) - } - SessionBody::Stream(frame) => self.handle_stream_frame(frame, &mut emit), - SessionBody::StreamClose(frame) => { - self.handle_stream_close(frame, &mut emit); - Ok(()) + if self.state.session_state == SessionState::Closed { + if ack_eliciting { + self.schedule_ack(true); } - }; - if apply_inbound_body.is_err() { return; } - match self.state.rx_ring.insert(seq, ()) { - Ok(()) => { - self.state.rx_ring.advance_occupied_front(); - if !body_kind_is_ack { - self.schedule_ack(out_of_order); + for frame in record.frames { + match frame { + SessionFrame::Ping => {} + SessionFrame::Ack(ack) => self.process_record_ack(ack, &mut emit), + SessionFrame::StreamData(frame) => { + if self.handle_stream_data(frame, &mut emit).is_err() { + return; + } + } + SessionFrame::StreamWindow(frame) => { + if self.handle_stream_window(frame, &mut emit).is_err() { + return; + } + } + SessionFrame::StreamClose(frame) => { + if self.handle_stream_close(frame, &mut emit).is_err() { + return; + } + } + SessionFrame::Close(close) => { + self.handle_session_close(close, &mut emit); + return; } } - Err(e) => { - unreachable!("seq window was pre-validated before body handling {e:?}"); - } + } + + if ack_eliciting { + self.schedule_ack(out_of_order); } } - pub fn confirm_write(&mut self, now: Instant, seq: SessionSeq) { + pub fn confirm_write(&mut self, now: Instant, write_id: u64) { self.state.now = now; - let Some((retransmit, should_clear_ack)) = self.state.tx_ring.get(&seq).map(|entry| { - ( - entry.pending.retransmit, - matches!(entry.pending.body, SessionBody::Ack), - ) - }) else { + let Some(pending) = self.state.issued_records.shift_remove(&write_id) else { return; }; - debug_assert!(matches!( - self.state.tx_ring.get(&seq).map(|entry| entry.state), - Some(TxState::Issued) - )); - if !matches!( - self.state.tx_ring.get(&seq).map(|entry| entry.state), - Some(TxState::Issued) - ) { - return; - } - - self.state.last_activity_at = self.state.now; - if retransmit { - if let Some(entry) = self.state.tx_ring.get_mut(&seq) { - entry.state = TxState::Sent { - sent_at: self.state.now, - }; - } - } else { - let _ = self.state.tx_ring.remove(&seq); - self.state - .tx_ring - .advance_empty_front_until(self.state.next_seq); - if should_clear_ack { - self.state.clear_ack_schedule(); - } - } + self.state.last_activity_at = now; + self.state.sent_records.insert( + pending.seq.0, + SentRecord { + pending, + sent_at: now, + }, + ); } - pub fn reject_write(&mut self, seq: SessionSeq) { - debug_assert!(matches!( - self.state.tx_ring.get(&seq).map(|entry| entry.state), - Some(TxState::Issued) - )); - let Some(entry) = self.state.tx_ring.get_mut(&seq) else { + pub fn reject_write(&mut self, write_id: u64) { + let Some(pending) = self.state.issued_records.shift_remove(&write_id) else { return; }; - if !matches!(entry.state, TxState::Issued) { - return; - } - entry.state = TxState::Pending; + self.restore_pending_record(pending); } pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { self.state.now = now; self.collect_timeouts(); - if self.state.session_state == SessionState::Closed { - return; - } if let AckState::Delayed { due_at } = self.state.ack_state { if due_at <= self.state.now { self.state.ack_state = AckState::Immediate; @@ -397,7 +354,8 @@ impl SessionFsm { ); return; } - if !self.config.keepalive_interval.is_zero() + if self.state.session_state == SessionState::Open + && !self.config.keepalive_interval.is_zero() && self.state.last_activity_at + self.config.keepalive_interval <= self.state.now { self.state.pending_control.ping = true; @@ -412,12 +370,9 @@ impl SessionFsm { }; let retransmit_deadline = self .state - .tx_ring - .iter() - .filter_map(|(_, entry)| match entry.state { - TxState::Sent { sent_at } => Some(sent_at + self.config.retransmit_timeout), - TxState::Pending | TxState::Issued => None, - }) + .sent_records + .values() + .map(|record| record.sent_at + self.config.retransmit_timeout) .min(); let keepalive_deadline = (self.state.session_state == SessionState::Open && !self.config.keepalive_interval.is_zero() @@ -440,163 +395,337 @@ impl SessionFsm { pub fn has_pending_stream_work(&self) -> bool { self.state.streams.values().any(|stream| { stream.pending_close.is_some() + || !stream.retransmit.is_empty() || !stream.send_buf.is_empty() + || stream.pending_window || matches!(stream.outbound_state, OutboundState::FinQueued) }) } - pub fn take_next_write( - &mut self, - now: Instant, - ) -> Option<(SessionSeq, SessionAck, &SessionBody)> { + pub fn take_next_write(&mut self, now: Instant) -> Option<(u64, SessionRecord)> { self.state.now = now; self.collect_timeouts(); - let ack = self.state.current_ack(); - let seq = self - .take_pending_retransmit() - .or_else(|| self.take_fresh_write())?; - let entry = self.state.tx_ring.get(&seq).unwrap(); - Some((seq, ack, &entry.pending.body)) - } - - fn take_pending_retransmit(&mut self) -> Option { - let base_seq = self.state.tx_ring.base_seq().0; - let next_seq = self.state.next_seq.0; - - for seq in (base_seq..next_seq).map(SessionSeq) { - let should_retry = match self.state.tx_ring.get(&seq) { - Some(entry) if matches!(entry.state, TxState::Pending) => { - self.should_retry_body(&entry.pending.body) + + let built = self.build_next_record()?; + let write_id = self.state.next_write_id; + self.state.next_write_id = self.state.next_write_id.wrapping_add(1); + self.state.issued_records.insert(write_id, built.pending); + Some((write_id, built.record)) + } + + fn build_next_record(&mut self) -> Option { + let seq = self.state.next_record_seq; + let mut record = SessionRecord { + seq, + frames: Vec::new(), + }; + let mut pending = PendingRecord { + seq, + reliable: Vec::new(), + ack_included: false, + ping_included: false, + window_updates: Vec::new(), + }; + let mut remaining = self.config.record_size.saturating_sub(8); + + if self.should_send_ack() { + if let Some(ack) = self.state.received_records.ack() { + let frame = SessionFrame::Ack(ack); + if self.push_frame(&mut record, &mut remaining, frame, true) { + pending.ack_included = true; + self.state.ack_state = AckState::Idle; } - _ => continue, - }; + } + } - if !should_retry { - let _ = self.state.tx_ring.remove(&seq); - continue; + while let Some(close) = self.take_pending_session_close(remaining, record.frames.is_empty()) { + let frame = SessionFrame::Close(close.clone()); + if !self.push_frame(&mut record, &mut remaining, frame, true) { + self.state.pending_control.close = Some(close); + break; } + pending.reliable.push(ReliableFrame::Close(close)); + } - self.state - .tx_ring - .advance_empty_front_until(self.state.next_seq); - let entry = self.state.tx_ring.get_mut(&seq).unwrap(); - entry.state = TxState::Issued; - return Some(seq); + while let Some(close) = self.take_next_pending_stream_close(remaining, record.frames.is_empty()) { + let frame = SessionFrame::StreamClose(close.clone()); + if !self.push_frame(&mut record, &mut remaining, frame, true) { + self.restore_stream_close(close); + break; + } + pending.reliable.push(ReliableFrame::StreamClose(close)); } - self.state - .tx_ring - .advance_empty_front_until(self.state.next_seq); + if let Some(ping) = self.take_pending_ping(remaining, record.frames.is_empty()) { + if self.push_frame(&mut record, &mut remaining, ping, true) { + pending.ping_included = true; + } else { + self.state.pending_control.ping = true; + } + } - None + while let Some(window) = + self.take_next_pending_stream_window(remaining, record.frames.is_empty()) + { + let maximum_offset = window.maximum_offset; + let stream_id = window.stream_id; + if !self.push_frame( + &mut record, + &mut remaining, + SessionFrame::StreamWindow(window), + true, + ) { + if let Some(stream) = self.state.streams.get_mut(&stream_id) { + stream.pending_window = true; + } + break; + } + pending.window_updates.push((stream_id, maximum_offset)); + } + + while let Some(frame) = + self.take_next_retransmit_stream_data(remaining, record.frames.is_empty()) + { + if !self.push_frame( + &mut record, + &mut remaining, + SessionFrame::StreamData(frame.clone()), + true, + ) { + self.restore_stream_data(frame); + break; + } + pending.reliable.push(ReliableFrame::StreamData(frame)); + } + + while let Some(frame) = self.take_next_fresh_stream_data(remaining, record.frames.is_empty()) { + if !self.push_frame( + &mut record, + &mut remaining, + SessionFrame::StreamData(frame.clone()), + true, + ) { + self.restore_stream_data(frame); + break; + } + pending.reliable.push(ReliableFrame::StreamData(frame)); + } + + if record.frames.is_empty() { + return None; + } + + self.state.next_record_seq = RecordSeq(self.state.next_record_seq.0.saturating_add(1)); + Some(BuiltRecord { record, pending }) } - fn take_fresh_write(&mut self) -> Option { - if !self.state.tx_ring.accepts_seq(self.state.next_seq) { + fn take_pending_session_close( + &mut self, + remaining: usize, + record_empty: bool, + ) -> Option { + let close = self.state.pending_control.close.clone()?; + let frame = SessionFrame::Close(close.clone()); + if !self.frame_fits(remaining, record_empty, &frame) { return None; } + self.state.pending_control.close.take() + } - let pending = self.next_pending_body()?; - let seq = self.state.next_seq; - self.state.next_seq = SessionSeq(seq.0 + 1); - self.state - .tx_ring - .insert( - seq, - TxEntry { - pending, - state: TxState::Issued, - }, - ) - .unwrap(); - Some(seq) + fn take_pending_ping( + &mut self, + remaining: usize, + record_empty: bool, + ) -> Option { + if !self.state.pending_control.ping { + return None; + } + let frame = SessionFrame::Ping; + if !self.frame_fits(remaining, record_empty, &frame) { + return None; + } + self.state.pending_control.ping = false; + Some(frame) } - fn next_pending_body(&mut self) -> Option { - if let Some(close) = self.state.pending_control.close.take() { - return Some(PendingSessionBody { - body: SessionBody::Close(close), - retransmit: true, - }); + fn take_next_pending_stream_close( + &mut self, + remaining: usize, + record_empty: bool, + ) -> Option { + let len = self.state.streams.len(); + if len == 0 { + return None; } - if self.state.pending_control.ping { - self.state.pending_control.ping = false; - return Some(PendingSessionBody { - body: SessionBody::Ping(PingBody), - retransmit: false, - }); + + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let Some((_, stream)) = self.state.streams.get_index(index) else { + continue; + }; + let Some(close) = stream.pending_close.clone() else { + continue; + }; + let frame = SessionFrame::StreamClose(close.clone()); + if !self.frame_fits(remaining, record_empty, &frame) { + continue; + } + + let stream = self.state.streams.get_index_mut(index).unwrap().1; + self.state.next_stream_index = (index + 1) % len; + return stream.pending_close.take().or(Some(close)); + } + + None + } + + fn take_next_pending_stream_window( + &mut self, + remaining: usize, + record_empty: bool, + ) -> Option { + let len = self.state.streams.len(); + if len == 0 { + return None; + } + + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let Some((&stream_id, stream)) = self.state.streams.get_index(index) else { + continue; + }; + if !stream.pending_window { + continue; + } + let frame = StreamWindow { + stream_id, + maximum_offset: stream.recv_limit(), + }; + if !self.frame_fits( + remaining, + record_empty, + &SessionFrame::StreamWindow(frame.clone()), + ) { + continue; + } + + let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); + stream.pending_window = false; + stream.advertised_max_offset = frame.maximum_offset; + self.state.next_stream_index = (index + 1) % len; + return Some(frame); } + None + } + + fn take_next_retransmit_stream_data( + &mut self, + remaining: usize, + record_empty: bool, + ) -> Option { + let max_payload = self.max_stream_data_payload(remaining, record_empty)?; let len = self.state.streams.len(); - if len > 0 { - let start = self.state.next_stream_index % len; - for offset in 0..len { - let index = (start + offset) % len; - let has_pending = self - .state - .streams - .get_index(index) - .is_some_and(|(_, stream)| { - stream.pending_close.is_some() - || !stream.send_buf.is_empty() - || matches!(stream.outbound_state, OutboundState::FinQueued) - }); - if !has_pending { - continue; + if len == 0 { + return None; + } + + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let Some((_, stream)) = self.state.streams.get_index(index) else { + continue; + }; + + if matches!(stream.outbound_state, OutboundState::Closed) { + let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); + while let Some(frame) = stream.retransmit.pop_front() { + stream.inflight_bytes = stream.inflight_bytes.saturating_sub(frame.bytes.len()); } + continue; + } - let body = { - let Some((&stream_id, stream)) = self.state.streams.get_index_mut(index) else { - continue; - }; - match stream.open_state { - StreamOpenState::PendingSend => { - let body = Self::take_stream_frame( - stream, - stream_id, - self.config.stream_chunk_size, - ) - .map(SessionBody::Stream); - if body.is_some() { - stream.open_state = StreamOpenState::WaitingForAck; - } - body - } - StreamOpenState::WaitingForAck => None, - StreamOpenState::Opened => { - if let Some(close) = stream.pending_close.take() { - Some(SessionBody::StreamClose(close)) - } else { - Self::take_stream_frame( - stream, - stream_id, - self.config.stream_chunk_size, - ) - .map(SessionBody::Stream) - } - } - } - }; - let Some(body) = body else { - continue; - }; - self.state.next_stream_index = (index + 1) % len; - return Some(PendingSessionBody { - body, - retransmit: true, - }); + let Some(_) = stream.retransmit.front() else { + continue; + }; + let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); + let frame = stream.retransmit.pop_front().unwrap(); + let (head, tail) = Self::split_stream_data(frame, max_payload); + if let Some(tail) = tail { + stream.retransmit.push_front(tail); } + self.state.next_stream_index = (index + 1) % len; + return Some(head); } - let ack_due = match self.state.ack_state { - AckState::Immediate => true, - AckState::Delayed { due_at } => due_at <= self.state.now, - AckState::Idle => false, - }; - ack_due.then_some(PendingSessionBody { - body: SessionBody::Ack, - retransmit: false, - }) + None + } + + fn take_next_fresh_stream_data( + &mut self, + remaining: usize, + record_empty: bool, + ) -> Option { + let max_payload = self.max_stream_data_payload(remaining, record_empty)?; + let len = self.state.streams.len(); + if len == 0 { + return None; + } + + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let Some((&stream_id, stream)) = self.state.streams.get_index(index) else { + continue; + }; + if matches!(stream.outbound_state, OutboundState::Closed) { + continue; + } + + let credit_remaining = stream + .peer_max_offset + .saturating_sub(stream.next_send_offset) as usize; + let has_empty_fin = matches!(stream.outbound_state, OutboundState::FinQueued) + && stream.send_buf.is_empty() + && stream.next_send_offset <= stream.peer_max_offset; + if stream.send_buf.is_empty() && !has_empty_fin { + continue; + } + + if credit_remaining == 0 && !has_empty_fin { + continue; + } + + let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); + let payload_len = stream + .send_buf + .len() + .min(max_payload) + .min(credit_remaining); + let bytes: Vec = stream.send_buf.drain(..payload_len).collect(); + let fin = matches!(stream.outbound_state, OutboundState::FinQueued) + && stream.send_buf.is_empty() + && stream.next_send_offset + bytes.len() as u64 <= stream.peer_max_offset; + let frame = StreamData { + stream_id, + offset: stream.next_send_offset, + fin, + bytes, + }; + stream.next_send_offset = stream + .next_send_offset + .saturating_add(frame.bytes.len() as u64); + stream.inflight_bytes = stream.inflight_bytes.saturating_add(frame.bytes.len()); + if fin { + stream.outbound_state = OutboundState::Finished; + } + self.state.next_stream_index = (index + 1) % len; + return Some(frame); + } + + None } fn ensure_session_open(&self) -> Result<(), StreamError> { @@ -607,52 +736,29 @@ impl SessionFsm { } } - fn process_ack(&mut self, ack: ql_wire::SessionAck) { - loop { - let Some((seq, stream_id, opens_stream)) = - self.state.tx_ring.iter().find_map(|(seq, entry)| { - if !matches!(entry.state, TxState::Sent { .. }) || !Self::ack_covers(ack, seq) { - return None; - } - - let (stream_id, opens_stream) = match &entry.pending.body { - SessionBody::Stream(frame) => (Some(frame.stream_id), frame.chunk_seq == 0), - SessionBody::StreamClose(frame) => (Some(frame.stream_id), false), - _ => (None, false), - }; + fn process_record_ack(&mut self, ack: RecordAck, emit: &mut impl FnMut(SessionEvent)) { + let acked: Vec = self + .state + .sent_records + .keys() + .copied() + .filter(|seq| Self::ack_covers(&ack, RecordSeq(*seq))) + .collect(); - Some((seq, stream_id, opens_stream)) - }) - else { - break; + for seq in acked { + let Some(sent) = self.state.sent_records.shift_remove(&seq) else { + continue; }; - - let _ = self.state.tx_ring.remove(&seq); - if let Some(stream_id) = stream_id { - if opens_stream { - if let Some(stream) = self.state.streams.get_mut(&stream_id) { - if matches!(stream.open_state, StreamOpenState::WaitingForAck) { - stream.open_state = StreamOpenState::Opened; - } - } - } - self.try_reap_stream(stream_id); + for frame in sent.pending.reliable { + self.acknowledge_reliable_frame(frame, emit); } } - self.state - .tx_ring - .advance_empty_front_until(self.state.next_seq); } - fn ack_covers(ack: ql_wire::SessionAck, seq: SessionSeq) -> bool { - if seq.0 <= ack.base.0 { - return true; - } - let delta = seq.0 - ack.base.0; - if delta == 0 || delta > 64 { - return false; - } - (ack.bitmap & (1u64 << (delta - 1))) != 0 + fn ack_covers(ack: &RecordAck, seq: RecordSeq) -> bool { + ack.ranges + .iter() + .any(|range| range.start <= seq.0 && seq.0 < range.end) } fn schedule_ack(&mut self, immediate: bool) { @@ -666,106 +772,120 @@ impl SessionFsm { }; } + fn should_send_ack(&self) -> bool { + if self.state.received_records.ack().is_none() { + return false; + } + match self.state.ack_state { + AckState::Immediate => true, + AckState::Delayed { due_at } => due_at <= self.state.now, + AckState::Idle => false, + } + } + fn collect_timeouts(&mut self) { - let expired: Vec<_> = self + let expired: Vec = self .state - .tx_ring + .sent_records .iter() - .filter_map(|(seq, entry)| match entry.state { - TxState::Sent { sent_at } - if sent_at + self.config.retransmit_timeout <= self.state.now => - { - Some(seq) - } - TxState::Pending | TxState::Issued | TxState::Sent { .. } => None, + .filter_map(|(seq, record)| { + (record.sent_at + self.config.retransmit_timeout <= self.state.now).then_some(*seq) }) .collect(); for seq in expired { - let Some((retransmit, body)) = self - .state - .tx_ring - .get(&seq) - .map(|entry| (entry.pending.retransmit, entry.pending.body.clone())) - else { + let Some(sent) = self.state.sent_records.shift_remove(&seq) else { continue; }; - if retransmit && self.should_retry_body(&body) { - if let Some(entry) = self.state.tx_ring.get_mut(&seq) { - entry.state = TxState::Pending; - } - } else { - let _ = self.state.tx_ring.remove(&seq); - if matches!(body, SessionBody::Ack) { - self.state.clear_ack_schedule(); + self.restore_pending_record(sent.pending); + } + } + + fn restore_pending_record(&mut self, pending: PendingRecord) { + if pending.ack_included { + self.schedule_ack(true); + } + if pending.ping_included { + self.state.pending_control.ping = true; + } + for (stream_id, maximum_offset) in pending.window_updates { + if let Some(stream) = self.state.streams.get_mut(&stream_id) { + if stream.recv_limit() >= maximum_offset { + stream.pending_window = true; } } } + for frame in pending.reliable { + self.requeue_reliable_frame(frame); + } + } - self.state - .tx_ring - .advance_empty_front_until(self.state.next_seq); + fn requeue_reliable_frame(&mut self, frame: ReliableFrame) { + match frame { + ReliableFrame::Close(close) => { + self.state.pending_control.close = Some(close); + } + ReliableFrame::StreamClose(close) => self.restore_stream_close(close), + ReliableFrame::StreamData(frame) => self.restore_stream_data(frame), + } } - fn should_retry_body(&self, body: &SessionBody) -> bool { - match body { - SessionBody::Ack => true, - SessionBody::Ping(_) => self.state.session_state == SessionState::Open, - SessionBody::Close(_) => true, - SessionBody::Stream(frame) => { - self.state.session_state == SessionState::Open - && self - .state - .streams - .get(&frame.stream_id) - .is_some_and(|stream| { - !matches!(stream.outbound_state, OutboundState::Closed) - || (matches!(stream.open_state, StreamOpenState::WaitingForAck) - && frame.chunk_seq == 0) - }) + fn acknowledge_reliable_frame( + &mut self, + frame: ReliableFrame, + emit: &mut impl FnMut(SessionEvent), + ) { + match frame { + ReliableFrame::Close(_) => {} + ReliableFrame::StreamClose(frame) => { + self.try_reap_stream(frame.stream_id); } - SessionBody::StreamClose(frame) => { - self.state.session_state == SessionState::Open - && self.state.streams.contains_key(&frame.stream_id) + ReliableFrame::StreamData(frame) => { + let stream_id = frame.stream_id; + if let Some(stream) = self.state.streams.get_mut(&stream_id) { + let was_full = stream.send_capacity(self.config.stream_send_buffer_size) == 0; + stream.inflight_bytes = stream.inflight_bytes.saturating_sub(frame.bytes.len()); + if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { + emit(SessionEvent::Writable(stream_id)); + } + } + self.try_reap_stream(stream_id); } } } - fn handle_stream_frame( + fn handle_stream_data( &mut self, - frame: StreamChunk, + frame: StreamData, emit: &mut impl FnMut(SessionEvent), - ) -> Result<(), RejectNoAck> { - let StreamChunk { - stream_id, - chunk_seq, - bytes, - fin, - } = frame; - let remote_namespace = self.config.local_namespace.remote(); + ) -> Result<(), ()> { + let stream_id = frame.stream_id; let stream = match self.state.streams.entry(stream_id) { Entry::Occupied(entry) => entry.into_mut(), Entry::Vacant(entry) => { - if !remote_namespace.matches(stream_id) { + if !self.config.local_parity.remote().matches(stream_id) { self.fail_session( SessionCloseBody { code: CloseCode::PROTOCOL, }, emit, ); - return Ok(()); - } - if chunk_seq != 0 { - return Err(RejectNoAck); + return Err(()); } emit(SessionEvent::Opened(stream_id)); - entry.insert(StreamState::new(StreamRole::Responder)) + entry.insert(StreamState::new( + StreamRole::Responder, + self.config.stream_send_buffer_size, + self.config.stream_receive_buffer_size, + )) } }; + match stream.inbound_state { - InboundState::Open => (), + InboundState::Open => {} + InboundState::Discarding => return Ok(()), InboundState::Finished | InboundState::Closed(_) => { - if chunk_seq < stream.recv_window.next_chunk_seq() { + if frame.offset + frame.bytes.len() as u64 <= stream.recv.start_offset() { return Ok(()); } self.fail_session( @@ -774,41 +894,48 @@ impl SessionFsm { }, emit, ); - return Ok(()); + return Err(()); } - InboundState::Discarding => return Ok(()), } - let was_readable = !stream.recv_buf.is_empty(); - let outcome = stream.recv_window.insert(chunk_seq, RxChunk { bytes, fin }); - - match outcome { - RecvInsertOutcome::Inserted => { - Self::drain_recv_window(stream); - if !was_readable && !stream.recv_buf.is_empty() { + let was_readable = stream.readable_bytes() > 0; + let insert = stream.recv.insert(frame.offset, frame.fin, &frame.bytes); + match insert { + Ok(outcome) => { + if !was_readable && outcome.newly_readable_bytes > 0 { emit(SessionEvent::Readable(stream_id)); } - if matches!(stream.inbound_state, InboundState::Finished) { + if outcome.became_complete { + stream.inbound_state = InboundState::Finished; emit(SessionEvent::Finished(stream_id)); } self.try_reap_stream(stream_id); Ok(()) } - RecvInsertOutcome::Duplicate => Ok(()), - RecvInsertOutcome::RejectNoAck => Err(RejectNoAck), - RecvInsertOutcome::Conflict => { + Err(ByteReassemblyError::ConflictingOverlap) + | Err(ByteReassemblyError::OutOfWindow) + | Err(ByteReassemblyError::InconsistentFinalOffset) + | Err(ByteReassemblyError::FinalOffsetBeforeBufferedData) + | Err(ByteReassemblyError::BeyondFinalOffset) + | Err(ByteReassemblyError::TooManyMissingRanges) + | Err(ByteReassemblyError::OffsetOverflow) => { self.fail_session( SessionCloseBody { code: CloseCode::PROTOCOL, }, emit, ); - Ok(()) + Err(()) } + Err(ByteReassemblyError::ConsumeBeyondReadable) => unreachable!(), } } - fn handle_stream_close(&mut self, frame: StreamClose, emit: &mut impl FnMut(SessionEvent)) { + fn handle_stream_window( + &mut self, + frame: StreamWindow, + emit: &mut impl FnMut(SessionEvent), + ) -> Result<(), ()> { let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { self.fail_session( SessionCloseBody { @@ -816,9 +943,50 @@ impl SessionFsm { }, emit, ); - return; + return Err(()); }; + let was_full = stream.send_capacity(self.config.stream_send_buffer_size) == 0; + if frame.maximum_offset > stream.peer_max_offset { + stream.peer_max_offset = frame.maximum_offset; + } + if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { + emit(SessionEvent::Writable(frame.stream_id)); + } + Ok(()) + } + + fn handle_stream_close( + &mut self, + frame: StreamClose, + emit: &mut impl FnMut(SessionEvent), + ) -> Result<(), ()> { + let created = match self.state.streams.entry(frame.stream_id) { + Entry::Occupied(_) => false, + Entry::Vacant(entry) => { + if !self.config.local_parity.remote().matches(frame.stream_id) { + self.fail_session( + SessionCloseBody { + code: CloseCode::PROTOCOL, + }, + emit, + ); + return Err(()); + } + entry.insert(StreamState::new( + StreamRole::Responder, + self.config.stream_send_buffer_size, + self.config.stream_receive_buffer_size, + )); + true + } + }; + + let stream = self.state.streams.get_mut(&frame.stream_id).unwrap(); + if created { + emit(SessionEvent::Opened(frame.stream_id)); + } + if Self::target_affects_inbound(stream.role, frame.target) && !matches!( stream.inbound_state, @@ -826,8 +994,7 @@ impl SessionFsm { ) { stream.inbound_state = InboundState::Closed(frame.clone()); - stream.recv_buf.clear(); - stream.recv_window.clear(); + stream.reset_recv(); emit(SessionEvent::Closed(frame.clone())); } if Self::target_affects_outbound(stream.role, frame.target) @@ -835,21 +1002,41 @@ impl SessionFsm { { stream.outbound_state = OutboundState::Closed; stream.send_buf.clear(); + stream.retransmit.clear(); stream.pending_close = None; + stream.inflight_bytes = 0; emit(SessionEvent::WritableClosed(frame.stream_id)); } self.try_reap_stream(frame.stream_id); + Ok(()) } - fn apply_close_to_stream(stream: &mut StreamState, target: CloseTarget) { + fn handle_session_close( + &mut self, + close: SessionCloseBody, + emit: &mut impl FnMut(SessionEvent), + ) { + if self.state.session_state == SessionState::Closed { + return; + } + + self.state.session_state = SessionState::Closed; + self.state.issued_records.clear(); + self.state.sent_records.clear(); + self.clear_streams(); + self.state.pending_control = Default::default(); + emit(SessionEvent::SessionClosed(close)); + } + + fn apply_local_close_to_stream(stream: &mut StreamState, target: CloseTarget) { if Self::target_affects_inbound(stream.role, target) { stream.inbound_state = InboundState::Discarding; - stream.recv_buf.clear(); - stream.recv_window.clear(); + stream.reset_recv(); } if Self::target_affects_outbound(stream.role, target) { stream.outbound_state = OutboundState::Closed; stream.send_buf.clear(); + stream.retransmit.clear(); } } @@ -861,93 +1048,137 @@ impl SessionFsm { matches!(target, CloseTarget::Both) || role.outbound_target() == target } - fn drain_recv_window(stream: &mut StreamState) { - while let Some(chunk) = stream.recv_window.pop_contiguous() { - let RxChunk { bytes, fin } = chunk; - stream.recv_buf.extend(bytes); - if fin { - stream.inbound_state = InboundState::Finished; - break; + fn restore_stream_close(&mut self, close: StreamClose) { + if let Some(stream) = self.state.streams.get_mut(&close.stream_id) { + stream.pending_close = Some(close); + } + } + + fn restore_stream_data(&mut self, frame: StreamData) { + if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { + if matches!(stream.outbound_state, OutboundState::Closed) { + stream.inflight_bytes = stream.inflight_bytes.saturating_sub(frame.bytes.len()); + return; } + stream.retransmit.push_front(frame); } } - fn take_stream_frame( - stream: &mut StreamState, - stream_id: StreamId, - chunk_size: usize, - ) -> Option { - if !stream.send_buf.is_empty() { - let len = stream.send_buf.len().min(chunk_size); - let bytes: Vec<_> = stream.send_buf.drain(..len).collect(); - let fin = if stream.send_buf.is_empty() - && matches!(stream.outbound_state, OutboundState::FinQueued) - { - stream.outbound_state = OutboundState::Finished; - true - } else { - false - }; - let frame = StreamChunk { - stream_id, - chunk_seq: stream.next_send_chunk_seq, - bytes, - fin, - }; - stream.next_send_chunk_seq += 1; - return Some(frame); + fn split_stream_data(frame: StreamData, max_payload: usize) -> (StreamData, Option) { + if frame.bytes.len() <= max_payload || frame.bytes.is_empty() { + return (frame, None); } - if matches!(stream.outbound_state, OutboundState::FinQueued) { - stream.outbound_state = OutboundState::Finished; - return Some(StreamChunk { - stream_id, - chunk_seq: stream.next_send_chunk_seq, - bytes: Vec::new(), - fin: true, - }); + let split = max_payload.max(1).min(frame.bytes.len()); + let mut head = frame.clone(); + head.bytes.truncate(split); + head.fin = false; + + let tail = StreamData { + stream_id: frame.stream_id, + offset: frame.offset + split as u64, + fin: frame.fin, + bytes: frame.bytes[split..].to_vec(), + }; + (head, Some(tail)) + } + + fn max_stream_data_payload(&self, remaining: usize, record_empty: bool) -> Option { + let overhead = self.frame_len(&SessionFrame::StreamData(StreamData { + stream_id: StreamId(0), + offset: 0, + fin: false, + bytes: Vec::new(), + })); + if remaining > overhead { + Some(remaining - overhead) + } else if record_empty { + Some(self.config.record_size) + } else { + None } + } - None + fn frame_fits(&self, remaining: usize, record_empty: bool, frame: &SessionFrame) -> bool { + let len = self.frame_len(frame); + len <= remaining || record_empty + } + + fn push_frame( + &self, + record: &mut SessionRecord, + remaining: &mut usize, + frame: SessionFrame, + force_if_empty: bool, + ) -> bool { + let len = self.frame_len(&frame); + if len > *remaining && !(force_if_empty && record.frames.is_empty()) { + return false; + } + record.frames.push(frame); + *remaining = remaining.saturating_sub(len); + true + } + + fn frame_len(&self, frame: &SessionFrame) -> usize { + let mut bytes = Vec::new(); + frame.encode_into(&mut bytes); + bytes.len() + } + + fn record_is_ack_eliciting(record: &SessionRecord) -> bool { + record + .frames + .iter() + .any(|frame| !matches!(frame, SessionFrame::Ack(_))) } fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { - let tx_ring_references_stream = - self.state - .tx_ring - .iter() - .any(|(_, entry)| match &entry.pending.body { - SessionBody::Stream(frame) => frame.stream_id == stream_id, - SessionBody::StreamClose(frame) => frame.stream_id == stream_id, - _ => false, - }); + let issued_refs_stream = self.state.issued_records.values().any(|record| { + record.window_updates.iter().any(|(id, _)| *id == stream_id) + || record.reliable.iter().any(|frame| match frame { + ReliableFrame::StreamData(frame) => frame.stream_id == stream_id, + ReliableFrame::StreamClose(frame) => frame.stream_id == stream_id, + ReliableFrame::Close(_) => false, + }) + }); + if issued_refs_stream { + return false; + } - if tx_ring_references_stream { + let sent_refs_stream = self.state.sent_records.values().any(|record| { + record + .pending + .window_updates + .iter() + .any(|(id, _)| *id == stream_id) + || record.pending.reliable.iter().any(|frame| match frame { + ReliableFrame::StreamData(frame) => frame.stream_id == stream_id, + ReliableFrame::StreamClose(frame) => frame.stream_id == stream_id, + ReliableFrame::Close(_) => false, + }) + }); + if sent_refs_stream { return false; } if !stream.send_buf.is_empty() - || !stream.recv_buf.is_empty() - || !stream.recv_window.is_empty() + || !stream.retransmit.is_empty() + || stream.pending_close.is_some() + || stream.inflight_bytes > 0 + || stream.readable_bytes() > 0 + || stream.recv.buffered_end_offset() > stream.recv.start_offset() { return false; } - match stream.open_state { - StreamOpenState::WaitingForAck => false, - StreamOpenState::PendingSend => matches!(stream.outbound_state, OutboundState::Closed), - StreamOpenState::Opened => { - stream.pending_close.is_none() - && matches!( - stream.inbound_state, - InboundState::Finished | InboundState::Closed(_) | InboundState::Discarding - ) - && matches!( - stream.outbound_state, - OutboundState::Finished | OutboundState::Closed - ) - } - } + matches!( + stream.inbound_state, + InboundState::Finished | InboundState::Closed(_) | InboundState::Discarding + ) && matches!( + stream.outbound_state, + OutboundState::Finished | OutboundState::Closed + ) } fn try_reap_stream(&mut self, stream_id: StreamId) { @@ -983,9 +1214,11 @@ impl SessionFsm { } self.state.session_state = SessionState::Closed; - self.clear_streams(); + self.state.issued_records.clear(); + self.state.sent_records.clear(); self.state.pending_control = Default::default(); self.state.pending_control.close = Some(close.clone()); + self.clear_streams(); emit(SessionEvent::SessionClosed(close)); } @@ -994,3 +1227,8 @@ impl SessionFsm { self.state.streams.clear(); } } + +struct BuiltRecord { + record: SessionRecord, + pending: PendingRecord, +} diff --git a/ql-fsm/src/session/ring.rs b/ql-fsm/src/session/ring.rs deleted file mode 100644 index 4fa6427e..00000000 --- a/ql-fsm/src/session/ring.rs +++ /dev/null @@ -1,197 +0,0 @@ -use std::array; - -use ql_wire::SessionSeq; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SeqRingInsertError { - OutOfWindow, - Occupied, -} - -#[derive(Debug)] -pub struct SeqRing { - base_seq: SessionSeq, - head: usize, - len: usize, - slots: [Option; N], -} - -impl SeqRing { - pub fn new(base_seq: SessionSeq) -> Self { - debug_assert!(N <= 64); - Self { - base_seq, - head: 0, - len: 0, - slots: array::from_fn(|_| None), - } - } - - pub fn base_seq(&self) -> SessionSeq { - self.base_seq - } - - pub fn accepts_seq(&self, seq: SessionSeq) -> bool { - self.offset_for(seq).is_some() - } - - pub fn contains_key(&self, seq: &SessionSeq) -> bool { - self.get(seq).is_some() - } - - pub fn get(&self, seq: &SessionSeq) -> Option<&T> { - let index = self.index_for(*seq)?; - self.slots[index].as_ref() - } - - pub fn get_mut(&mut self, seq: &SessionSeq) -> Option<&mut T> { - let index = self.index_for(*seq)?; - self.slots[index].as_mut() - } - - pub fn insert(&mut self, seq: SessionSeq, value: T) -> Result<(), SeqRingInsertError> { - let index = self.index_for(seq).ok_or(SeqRingInsertError::OutOfWindow)?; - if self.slots[index].is_some() { - return Err(SeqRingInsertError::Occupied); - } - self.slots[index] = Some(value); - self.len += 1; - Ok(()) - } - - pub fn remove(&mut self, seq: &SessionSeq) -> Option { - let index = self.index_for(*seq)?; - let value = self.slots[index].take(); - if value.is_some() { - self.len -= 1; - } - value - } - - pub fn advance_empty_front_until(&mut self, limit_seq: SessionSeq) { - while self.base_seq.0 < limit_seq.0 && self.slots[self.head].is_none() { - self.head = self.next_index(self.head); - self.base_seq = SessionSeq(self.base_seq.0 + 1); - } - } - - pub fn advance_occupied_front(&mut self) { - while self.slots[self.head].is_some() { - let _ = self.slots[self.head].take(); - self.len -= 1; - self.head = self.next_index(self.head); - self.base_seq = SessionSeq(self.base_seq.0 + 1); - } - } - - pub fn iter(&self) -> SeqRingIter<'_, N, T> { - SeqRingIter { - ring: self, - offset: 0, - } - } - - pub fn bitmap(&self) -> u64 { - let mut bitmap = 0u64; - for offset in 0..N { - let index = self.index_for_offset(offset); - if self.slots[index].is_some() { - bitmap |= 1u64 << offset; - } - } - bitmap - } - - fn index_for(&self, seq: SessionSeq) -> Option { - let offset = self.offset_for(seq)?; - Some(self.index_for_offset(offset)) - } - - fn offset_for(&self, seq: SessionSeq) -> Option { - if seq.0 < self.base_seq.0 { - return None; - } - let offset = (seq.0 - self.base_seq.0) as usize; - (offset < N).then_some(offset) - } - - fn index_for_offset(&self, offset: usize) -> usize { - (self.head + offset) % N - } - - fn next_index(&self, index: usize) -> usize { - (index + 1) % N - } -} - -pub struct SeqRingIter<'a, const N: usize, T> { - ring: &'a SeqRing, - offset: usize, -} - -impl<'a, const N: usize, T> Iterator for SeqRingIter<'a, N, T> { - type Item = (SessionSeq, &'a T); - - fn next(&mut self) -> Option { - while self.offset < N { - let offset = self.offset; - self.offset += 1; - let index = self.ring.index_for_offset(offset); - if let Some(value) = self.ring.slots[index].as_ref() { - return Some((SessionSeq(self.ring.base_seq.0 + offset as u64), value)); - } - } - None - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn insert_iter_and_bitmap() { - let mut ring = SeqRing::<4, u8>::new(SessionSeq(10)); - - ring.insert(SessionSeq(10), 1).unwrap(); - ring.insert(SessionSeq(12), 3).unwrap(); - - assert_eq!(ring.bitmap(), 0b0101); - assert_eq!( - ring.iter() - .map(|(seq, value)| (seq, *value)) - .collect::>(), - vec![(SessionSeq(10), 1), (SessionSeq(12), 3)] - ); - } - - #[test] - fn advance_fronts() { - let mut ring = SeqRing::<4, u8>::new(SessionSeq(10)); - - ring.insert(SessionSeq(11), 2).unwrap(); - ring.advance_empty_front_until(SessionSeq(11)); - assert_eq!(ring.base_seq(), SessionSeq(11)); - assert_eq!(ring.get(&SessionSeq(11)), Some(&2)); - - ring.advance_occupied_front(); - assert_eq!(ring.base_seq(), SessionSeq(12)); - assert!(ring.get(&SessionSeq(11)).is_none()); - } - - #[test] - fn insert_errors() { - let mut ring = SeqRing::<2, u8>::new(SessionSeq(5)); - - ring.insert(SessionSeq(5), 1).unwrap(); - - assert_eq!( - ring.insert(SessionSeq(5), 2), - Err(SeqRingInsertError::Occupied) - ); - assert_eq!( - ring.insert(SessionSeq(7), 3), - Err(SeqRingInsertError::OutOfWindow) - ); - } -} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 99ebb3ee..dcc7bf6f 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -1,13 +1,58 @@ -use std::{collections::VecDeque, time::Instant}; +use std::{ + collections::{BTreeSet, VecDeque}, + time::Instant, +}; use indexmap::IndexMap; use ql_wire::{ - CloseTarget, SessionAck, SessionBody, SessionCloseBody, SessionSeq, StreamClose, StreamId, + ByteReassembly, CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionCloseBody, + StreamClose, StreamData, StreamId, XID, }; -use super::{ring::SeqRing, stream_window::StreamRecvWindow, SessionState}; +use super::{SessionState, SESSION_RECORD_TRACKED_WINDOW}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamParity { + Even, + Odd, +} + +impl StreamParity { + pub fn for_local(local: XID, peer: XID) -> Self { + match local.0.cmp(&peer.0) { + std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Even, + std::cmp::Ordering::Greater => Self::Odd, + } + } -pub const SESSION_WINDOW_CAPACITY: usize = 64; + pub const fn first_stream_id(self) -> u32 { + match self { + Self::Even => 0, + Self::Odd => 1, + } + } + + pub const fn matches(self, stream_id: StreamId) -> bool { + match self { + Self::Even => stream_id.0 % 2 == 0, + Self::Odd => stream_id.0 % 2 == 1, + } + } + + pub const fn remote(self) -> Self { + match self { + Self::Even => Self::Odd, + Self::Odd => Self::Even, + } + } + + pub fn make_stream_id(self, ordinal: u32) -> StreamId { + StreamId( + self.first_stream_id() + .saturating_add(ordinal.saturating_mul(2)), + ) + } +} #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamRole { @@ -31,7 +76,7 @@ impl StreamRole { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum OutboundState { Open, FinQueued, @@ -39,13 +84,6 @@ pub enum OutboundState { Closed, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamOpenState { - PendingSend, - WaitingForAck, - Opened, -} - #[derive(Debug, Clone, PartialEq, Eq)] pub enum InboundState { Open, @@ -57,63 +95,114 @@ pub enum InboundState { #[derive(Debug)] pub struct StreamState { pub role: StreamRole, - pub open_state: StreamOpenState, pub send_buf: VecDeque, + // TODO: this is a stopgap shape and should be replaced. + // + // right now we keep sent-but-not-yet-acked stream bytes here as `ql_wire::StreamData` + // segments so the session scheduler can re-pack them into later records after loss. + // that works mechanically, but it is the wrong abstraction boundary: the fsm is caching + // wire-shaped frames instead of owning transport-neutral outbound stream state. + // + // the cleaner model is: + // - keep one authoritative outbound byte buffer per stream + // - track offsets/cursors into that buffer: + // - oldest buffered offset + // - next unsent offset + // - final offset, if known + // - keep lightweight sent-range/manifests that reference byte ranges in that buffer + // instead of cloning/storing `StreamData` + // - build `ql_wire::StreamData` only at pack time + // - free buffered prefix bytes only once no in-flight record manifest still references them + // + // that would let `send_buf` remain the source of truth while record manifests explain + // which accepted byte ranges were carried by which record. + pub retransmit: VecDeque, pub pending_close: Option, - pub recv_buf: VecDeque, - pub recv_window: StreamRecvWindow, - pub next_send_chunk_seq: u64, + pub inflight_bytes: usize, + pub next_send_offset: u64, + pub peer_max_offset: u64, pub outbound_state: OutboundState, pub inbound_state: InboundState, + pub recv: ByteReassembly, + pub advertised_max_offset: u64, + pub pending_window: bool, } impl StreamState { - pub fn new(role: StreamRole) -> Self { + pub fn new( + role: StreamRole, + _send_buffer_size: usize, + receive_buffer_size: usize, + ) -> Self { Self { role, - open_state: match role { - StreamRole::Initiator => StreamOpenState::PendingSend, - StreamRole::Responder => StreamOpenState::Opened, - }, send_buf: VecDeque::new(), + retransmit: VecDeque::new(), pending_close: None, - recv_buf: VecDeque::new(), - recv_window: StreamRecvWindow::new(), - next_send_chunk_seq: 0, + inflight_bytes: 0, + next_send_offset: 0, + peer_max_offset: receive_buffer_size as u64, outbound_state: OutboundState::Open, inbound_state: InboundState::Open, + recv: ByteReassembly::new(receive_buffer_size), + advertised_max_offset: receive_buffer_size as u64, + pending_window: false, } } pub fn is_writable(&self) -> bool { matches!(self.outbound_state, OutboundState::Open) } + + pub fn buffered_send_bytes(&self) -> usize { + self.send_buf.len().saturating_add(self.inflight_bytes) + } + + pub fn send_capacity(&self, send_buffer_size: usize) -> usize { + send_buffer_size.saturating_sub(self.buffered_send_bytes()) + } + + pub fn readable_bytes(&self) -> usize { + self.recv.readable_len() + } + + pub fn recv_limit(&self) -> u64 { + self.recv + .start_offset() + .saturating_add(self.recv.max_buffered() as u64) + } + + pub fn reset_recv(&mut self) { + self.recv = ByteReassembly::with_start_offset(self.recv.start_offset(), self.recv.max_buffered()); + } } #[derive(Debug, Clone)] -pub struct PendingSessionBody { - pub body: SessionBody, - /// whether the body should be retransmitted after a confirmed send times out without ack - pub retransmit: bool, +pub enum ReliableFrame { + StreamData(StreamData), + StreamClose(StreamClose), + Close(SessionCloseBody), } -#[derive(Debug, Clone, Default)] -pub struct PendingSessionControl { - pub ping: bool, - pub close: Option, +#[derive(Debug, Clone)] +pub struct PendingRecord { + pub seq: RecordSeq, + pub reliable: Vec, + pub ack_included: bool, + pub ping_included: bool, + pub window_updates: Vec<(StreamId, u64)>, } #[derive(Debug, Clone)] -pub struct TxEntry { - pub pending: PendingSessionBody, - pub state: TxState, +pub struct SentRecord { + pub pending: PendingRecord, + pub sent_at: Instant, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TxState { - Pending, - Issued, - Sent { sent_at: Instant }, +#[derive(Debug, Clone, Default)] +pub struct PendingSessionControl { + pub ping: bool, + pub close: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -123,32 +212,88 @@ pub enum AckState { Immediate, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReceiveInsertOutcome { + New { out_of_order: bool }, + Duplicate, +} + +#[derive(Debug, Default)] +pub struct ReceivedRecords { + seen: BTreeSet, + largest: Option, +} + +impl ReceivedRecords { + pub fn insert(&mut self, seq: RecordSeq) -> ReceiveInsertOutcome { + if self.seen.contains(&seq.0) { + return ReceiveInsertOutcome::Duplicate; + } + + if self + .largest + .is_some_and(|largest| largest.saturating_sub(seq.0) > SESSION_RECORD_TRACKED_WINDOW) + { + return ReceiveInsertOutcome::Duplicate; + } + + let out_of_order = self + .largest + .is_some_and(|largest| seq.0 != largest.saturating_add(1)); + self.seen.insert(seq.0); + self.largest = Some(self.largest.map_or(seq.0, |largest| largest.max(seq.0))); + self.prune(); + ReceiveInsertOutcome::New { out_of_order } + } + + pub fn ack(&self) -> Option { + if self.seen.is_empty() { + return None; + } + + let mut ranges = Vec::new(); + let mut iter = self.seen.iter().copied(); + let first = iter.next()?; + let mut start = first; + let mut end = first.saturating_add(1); + + for seq in iter { + if seq == end { + end = end.saturating_add(1); + continue; + } + + ranges.push(RecordAckRange { start, end }); + start = seq; + end = seq.saturating_add(1); + } + + ranges.push(RecordAckRange { start, end }); + Some(RecordAck { ranges }) + } + + fn prune(&mut self) { + let Some(largest) = self.largest else { + return; + }; + let keep_from = largest.saturating_sub(SESSION_RECORD_TRACKED_WINDOW); + self.seen.retain(|seq| *seq >= keep_from); + } +} + pub struct SessionFsmState { pub now: Instant, pub last_activity_at: Instant, pub last_inbound_at: Instant, pub session_state: SessionState, pub next_stream_ordinal: u32, - pub next_seq: SessionSeq, - pub tx_ring: SeqRing, - pub rx_ring: SeqRing, + pub next_record_seq: RecordSeq, + pub next_write_id: u64, + pub issued_records: IndexMap, + pub sent_records: IndexMap, + pub received_records: ReceivedRecords, pub ack_state: AckState, pub pending_control: PendingSessionControl, - /// `IndexMap` has stable (and fast) iteration order for round-robin - /// scheduling, so we do not need a separate ready queue pub streams: IndexMap, pub next_stream_index: usize, } - -impl SessionFsmState { - pub fn current_ack(&self) -> SessionAck { - SessionAck { - base: SessionSeq(self.rx_ring.base_seq().0.saturating_sub(1)), - bitmap: self.rx_ring.bitmap(), - } - } - - pub fn clear_ack_schedule(&mut self) { - self.ack_state = AckState::Idle; - } -} diff --git a/ql-fsm/src/session/stream_window.rs b/ql-fsm/src/session/stream_window.rs deleted file mode 100644 index e25b5e37..00000000 --- a/ql-fsm/src/session/stream_window.rs +++ /dev/null @@ -1,71 +0,0 @@ -use std::array; - -pub const STREAM_RECV_WINDOW_CAPACITY: usize = 8; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RxChunk { - pub bytes: Vec, - pub fin: bool, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum RecvInsertOutcome { - Inserted, - Duplicate, - RejectNoAck, - Conflict, -} - -#[derive(Debug)] -pub struct StreamRecvWindow { - next_chunk_seq: u64, - slots: [Option; STREAM_RECV_WINDOW_CAPACITY], -} - -impl StreamRecvWindow { - pub fn new() -> Self { - Self { - next_chunk_seq: 0, - slots: array::from_fn(|_| None), - } - } - - pub fn clear(&mut self) { - self.slots.fill(None); - } - - pub fn is_empty(&self) -> bool { - self.slots.iter().all(Option::is_none) - } - - pub fn next_chunk_seq(&self) -> u64 { - self.next_chunk_seq - } - - pub fn insert(&mut self, chunk_seq: u64, chunk: RxChunk) -> RecvInsertOutcome { - let Some(delta) = chunk_seq.checked_sub(self.next_chunk_seq) else { - return RecvInsertOutcome::Duplicate; - }; - if delta >= self.slots.len() as u64 { - return RecvInsertOutcome::RejectNoAck; - } - - let slot = &mut self.slots[delta as usize]; - match slot { - Some(existing) if *existing == chunk => RecvInsertOutcome::Duplicate, - Some(_) => RecvInsertOutcome::Conflict, - None => { - *slot = Some(chunk); - RecvInsertOutcome::Inserted - } - } - } - - pub fn pop_contiguous(&mut self) -> Option { - let chunk = self.slots[0].take()?; - self.next_chunk_seq += 1; - self.slots.rotate_left(1); - self.slots[self.slots.len() - 1] = None; - Some(chunk) - } -} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 00008fc5..3422039c 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -1,861 +1,258 @@ use std::time::{Duration, Instant}; use ql_wire::{ - CloseCode, CloseTarget, PingBody, SessionAck, SessionBody, SessionEnvelope, SessionSeq, - StreamChunk, StreamClose, + CloseCode, CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionFrame, SessionRecord, + StreamClose, StreamData, StreamId, XID, }; -use super::{SessionEvent, SessionFsm, SessionFsmConfig, SessionState}; +use super::{state::StreamParity, SessionEvent, SessionFsm, SessionFsmConfig}; -fn read_stream_all(fsm: &mut SessionFsm, stream_id: ql_wire::StreamId) -> Vec { +fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); let mut buf = [0u8; 64]; loop { - let read = fsm.read_stream(stream_id, &mut buf).unwrap(); + let read = fsm.peek_stream(stream_id, &mut buf).unwrap(); if read == 0 { break; } out.extend_from_slice(&buf[..read]); + fsm.commit_stream_read(stream_id, read).unwrap(); } out } -fn ack(seq: u64, ack: SessionAck) -> SessionEnvelope { - SessionEnvelope { - seq: SessionSeq(seq), - ack, - body: SessionBody::Ack, - } -} - -fn ping(seq: u64, ack: SessionAck) -> SessionEnvelope { - SessionEnvelope { - seq: SessionSeq(seq), - ack, - body: SessionBody::Ping(PingBody), - } -} - -fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option { - let (seq, envelope) = { - let (seq, ack, body) = fsm.take_next_write(now)?; - ( - seq, - SessionEnvelope { - seq, - ack, - body: body.clone(), - }, - ) - }; - fsm.confirm_write(now, seq); - Some(envelope) -} - -fn receive_events( - fsm: &mut SessionFsm, - now: Instant, - envelope: SessionEnvelope, -) -> Vec { - let mut events = Vec::new(); - fsm.receive(now, envelope, |event| events.push(event)); - events +fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option { + let (write_id, record) = fsm.take_next_write(now)?; + fsm.confirm_write(now, write_id); + Some(record) } -fn on_timer_events(fsm: &mut SessionFsm, now: Instant) -> Vec { +fn receive_events(fsm: &mut SessionFsm, now: Instant, record: SessionRecord) -> Vec { let mut events = Vec::new(); - fsm.on_timer(now, |event| events.push(event)); + fsm.receive(now, record, |event| events.push(event)); events } #[test] -fn outbound_session_seq_increments_monotonically() { +fn outbound_record_seq_increments_monotonically() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); - fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); + assert_eq!(fsm.write_stream(stream_id, b"one").unwrap(), 3); let first = next_outbound(&mut fsm, now).unwrap(); - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(1), - ack( - 1, - SessionAck { - base: SessionSeq(1), - bitmap: 0, - }, - ), - ); - - fsm.write_stream(stream_id, b"two".to_vec()).unwrap(); - let second = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert_eq!(fsm.write_stream(stream_id, b"two").unwrap(), 3); + let second = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); - assert_eq!(first.seq, SessionSeq(1)); - assert_eq!(second.seq, SessionSeq(2)); + assert_eq!(first.seq, RecordSeq(0)); + assert_eq!(second.seq, RecordSeq(1)); } #[test] -fn inbound_ack_removes_acked_tx_entries() { +fn retransmit_uses_new_record_seq() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); - fsm.write_stream(stream_id, b"one".to_vec()).unwrap(); + assert_eq!(fsm.write_stream(stream_id, b"retry").unwrap(), 5); let first = next_outbound(&mut fsm, now).unwrap(); - assert_eq!(first.seq, SessionSeq(1)); - assert!(fsm.state.tx_ring.contains_key(&SessionSeq(1))); - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(1), - ack( - 1, - SessionAck { - base: SessionSeq(1), - bitmap: 0, - }, - ), - ); + fsm.on_timer(now + Duration::from_millis(200), |_| {}); + let retried = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); - assert!(!fsm.state.tx_ring.contains_key(&SessionSeq(1))); + assert_ne!(first.seq, retried.seq); + assert_eq!(first.frames, retried.frames); } #[test] -fn out_of_order_receive_produces_bitmap_ack_then_advances_base() { +fn lost_record_on_one_stream_does_not_block_another_stream() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id_a = ql_wire::StreamId(super::StreamNamespace::High.bit() | 1); - let stream_id_b = ql_wire::StreamId(super::StreamNamespace::High.bit() | 2); - - let _ = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(2), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id: stream_id_a, - chunk_seq: 0, - bytes: b"a".to_vec(), - fin: false, - }), - }, - ); - let gap_ack = next_outbound(&mut fsm, now).unwrap(); - assert_eq!(gap_ack.seq, SessionSeq(1)); - assert_eq!( - gap_ack.ack, - SessionAck { - base: SessionSeq(0), - bitmap: 0b10, - } - ); - - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(1), - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id: stream_id_b, - chunk_seq: 0, - bytes: b"b".to_vec(), - fin: false, - }), + let mut fsm = SessionFsm::new( + SessionFsmConfig { + record_size: 80, + ..SessionFsmConfig::default() }, + now, ); - let contiguous_ack = next_outbound(&mut fsm, now + Duration::from_millis(10)).unwrap(); - assert_eq!(contiguous_ack.seq, SessionSeq(2)); - assert_eq!( - contiguous_ack.ack, - SessionAck { - base: SessionSeq(2), - bitmap: 0, - } - ); -} - -#[test] -fn retransmit_reuses_session_seq() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = fsm.open_stream().unwrap(); - - fsm.write_stream(stream_id, b"retry-me".to_vec()).unwrap(); - let first = next_outbound(&mut fsm, now).unwrap(); - - let retransmit_at = now + Duration::from_millis(200); - let retried = next_outbound(&mut fsm, retransmit_at).unwrap(); - - assert_eq!(first.seq, SessionSeq(1)); - assert_eq!(retried.seq, SessionSeq(1)); - assert_eq!(retried.body, first.body); -} - -#[test] -fn repeated_outbound_messages_keep_reporting_latest_receive_ack() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id_a = fsm.open_stream().unwrap(); let stream_id_b = fsm.open_stream().unwrap(); + let payload_a = vec![b'a'; 40]; + let payload_b = vec![b'b'; 40]; - let _ = receive_events(&mut fsm, now, ack(1, SessionAck::EMPTY)); + assert_eq!(fsm.write_stream(stream_id_a, &payload_a).unwrap(), 40); + assert_eq!(fsm.write_stream(stream_id_b, &payload_b).unwrap(), 40); - fsm.write_stream(stream_id_a, b"one".to_vec()).unwrap(); let first = next_outbound(&mut fsm, now).unwrap(); - - fsm.write_stream(stream_id_b, b"two".to_vec()).unwrap(); let second = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + assert_ne!(first.seq, second.seq); - assert_eq!(first.ack.base, SessionSeq(1)); - assert_eq!(second.ack.base, SessionSeq(1)); - assert_eq!(first.ack.bitmap, 0); - assert_eq!(second.ack.bitmap, 0); -} + assert_eq!(fsm.write_stream(stream_id_b, b"b-2").unwrap(), 3); + let third = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); -#[test] -fn local_inbound_close_ignores_late_remote_bytes() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = fsm.open_stream().unwrap(); - - fsm.close_stream( - stream_id, - CloseTarget::Response, - CloseCode::CANCELLED, - Vec::new(), - ) - .unwrap(); - - let events = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: b"late".to_vec(), - fin: false, - }), - }, - ); - - assert_eq!(fsm.state.session_state, SessionState::Open); - assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); - assert!(events.is_empty()); -} - -#[test] -fn missing_stream_nonzero_chunk_is_ignored_until_chunk_zero_arrives() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 7); - - let events = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 1, - bytes: b"b".to_vec(), - fin: false, - }), - }, - ); - - assert_eq!(fsm.state.session_state, SessionState::Open); - assert!(events.is_empty()); - assert!(!fsm.state.streams.contains_key(&stream_id)); - - let events = receive_events( - &mut fsm, - now + Duration::from_millis(1), - SessionEnvelope { - seq: SessionSeq(2), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: b"a".to_vec(), - fin: false, - }), - }, - ); - - assert_eq!( - events, - vec![ - SessionEvent::Opened(stream_id), - SessionEvent::Readable(stream_id) - ] - ); - assert_eq!(read_stream_all(&mut fsm, stream_id), b"a".to_vec()); -} - -#[test] -fn out_of_order_chunks_within_recv_window_are_buffered_and_drained() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 8); - - let mut events = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: b"a".to_vec(), - fin: false, - }), - }, - ); - events.extend(receive_events( - &mut fsm, - now + Duration::from_millis(1), - SessionEnvelope { - seq: SessionSeq(2), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 2, - bytes: b"c".to_vec(), - fin: false, - }), - }, - )); - events.extend(receive_events( - &mut fsm, - now + Duration::from_millis(2), - SessionEnvelope { - seq: SessionSeq(3), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 1, - bytes: b"b".to_vec(), - fin: false, - }), - }, - )); - - assert_eq!( - events, - vec![ - SessionEvent::Opened(stream_id), - SessionEvent::Readable(stream_id) - ] - ); - assert_eq!(read_stream_all(&mut fsm, stream_id), b"abc".to_vec()); + let stream_ids: Vec<_> = third + .frames + .iter() + .filter_map(|frame| match frame { + SessionFrame::StreamData(frame) => Some(frame.stream_id), + _ => None, + }) + .collect(); + assert_eq!(stream_ids, vec![stream_id_b]); } #[test] -fn chunk_past_recv_window_is_dropped_without_session_ack() { +fn write_stream_is_partial_and_ack_emits_writable() { let now = Instant::now(); let mut fsm = SessionFsm::new( SessionFsmConfig { - ack_delay: Duration::ZERO, + stream_send_buffer_size: 4, ..SessionFsmConfig::default() }, now, ); - let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 10); - - let _ = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: b"a".to_vec(), - fin: false, - }), - }, - ); - - let ack = next_outbound(&mut fsm, now).unwrap(); - assert_eq!( - ack.ack, - SessionAck { - base: SessionSeq(1), - bitmap: 0, - } - ); - - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(1), - SessionEnvelope { - seq: SessionSeq(2), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 9, - bytes: b"z".to_vec(), - fin: false, - }), - }, - ); + let stream_id = fsm.open_stream().unwrap(); - assert_eq!(fsm.state.rx_ring.base_seq(), SessionSeq(2)); - assert!(!fsm.state.rx_ring.contains_key(&SessionSeq(2))); - assert_eq!( - fsm.state.current_ack(), - SessionAck { - base: SessionSeq(1), - bitmap: 0, - } - ); - assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); + assert_eq!(fsm.write_stream(stream_id, b"abcd").unwrap(), 4); + assert_eq!(fsm.write_stream(stream_id, b"z").unwrap(), 0); + + let sent = next_outbound(&mut fsm, now).unwrap(); + let ack = SessionRecord { + seq: RecordSeq(99), + frames: vec![SessionFrame::Ack(RecordAck { + ranges: vec![RecordAckRange { + start: sent.seq.0, + end: sent.seq.0 + 1, + }], + })], + }; + let events = receive_events(&mut fsm, now + Duration::from_millis(1), ack); + assert_eq!(events, vec![SessionEvent::Writable(stream_id)]); + assert_eq!(fsm.write_stream(stream_id, b"z").unwrap(), 1); } #[test] -fn local_stream_waits_for_open_frame_ack_before_sending_follow_up_data() { +fn commit_stream_read_is_what_advances_stream_window() { let now = Instant::now(); let mut fsm = SessionFsm::new( SessionFsmConfig { - stream_chunk_size: 2, + local_parity: StreamParity::Even, + ack_delay: Duration::ZERO, ..SessionFsmConfig::default() }, now, ); - let stream_id = fsm.open_stream().unwrap(); - - fsm.write_stream(stream_id, b"hello".to_vec()).unwrap(); - - let first = next_outbound(&mut fsm, now).unwrap(); - assert_eq!( - first.body, - SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: b"he".to_vec(), - fin: false, - }) - ); - assert!(next_outbound(&mut fsm, now + Duration::from_millis(1)).is_none()); - - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(2), - ack( - 1, - SessionAck { - base: SessionSeq(1), - bitmap: 0, - }, - ), - ); - - let second = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); - assert_eq!( - second.body, - SessionBody::Stream(StreamChunk { + let stream_id = StreamId(1); + let data = SessionRecord { + seq: RecordSeq(7), + frames: vec![SessionFrame::StreamData(StreamData { stream_id, - chunk_seq: 1, - bytes: b"ll".to_vec(), + offset: 0, fin: false, - }) - ); -} - -#[test] -fn stream_is_reaped_after_terminal_state_and_last_stream_ack() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 13); - - let events = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: b"hi".to_vec(), - fin: true, - }), - }, - ); - - assert_eq!( - events, - vec![ - SessionEvent::Opened(stream_id), - SessionEvent::Readable(stream_id), - SessionEvent::Finished(stream_id), - ] - ); - assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); - assert!(fsm.state.streams.contains_key(&stream_id)); - - fsm.finish_stream(stream_id).unwrap(); - let fin = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); - assert_eq!( - fin.body, - SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: Vec::new(), - fin: true, - }) - ); - assert!(fsm.state.streams.contains_key(&stream_id)); - - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(2), - ack( - 2, - SessionAck { - base: SessionSeq(2), - bitmap: 0, - }, - ), - ); - - assert!(!fsm.state.streams.contains_key(&stream_id)); -} - -#[test] -fn replayed_remote_open_does_not_recreate_reaped_stream() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 17); - let opener = SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, bytes: b"hi".to_vec(), - fin: true, - }), + })], }; - - let events = receive_events(&mut fsm, now, opener.clone()); - + let events = receive_events(&mut fsm, now, data); assert_eq!( events, - vec![ - SessionEvent::Opened(stream_id), - SessionEvent::Readable(stream_id), - SessionEvent::Finished(stream_id), - ] - ); - assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); - - fsm.finish_stream(stream_id).unwrap(); - let fin = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); - assert_eq!( - fin.body, - SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: Vec::new(), - fin: true, - }) + vec![SessionEvent::Opened(stream_id), SessionEvent::Readable(stream_id)] ); - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(2), - ack( - 2, - SessionAck { - base: SessionSeq(1), - bitmap: 0, - }, - ), - ); + let first = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + assert!(matches!(first.frames.as_slice(), [SessionFrame::Ack(_)])); - assert!(!fsm.state.streams.contains_key(&stream_id)); + let mut buf = [0u8; 8]; + assert_eq!(fsm.peek_stream(stream_id, &mut buf).unwrap(), 2); - let events = receive_events(&mut fsm, now + Duration::from_millis(3), opener); + assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); - assert_eq!(fsm.state.session_state, SessionState::Open); - assert!(!fsm.state.streams.contains_key(&stream_id)); - assert!(events.is_empty()); + fsm.commit_stream_read(stream_id, 2).unwrap(); + let second = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); + assert!(matches!( + second.frames.as_slice(), + [SessionFrame::StreamWindow(window)] if window.stream_id == stream_id + )); } #[test] -fn duplicate_committed_data_is_not_redelivered() { +fn remote_stream_close_is_reliable_and_retried() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 9); - let body = SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: b"dup".to_vec(), - fin: false, - }); + let stream_id = fsm.open_stream().unwrap(); - let _ = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: body.clone(), - }, - ); - let _ = read_stream_all(&mut fsm, stream_id); + fsm.close_stream( + stream_id, + CloseTarget::Both, + CloseCode::CANCELLED, + b"bye".to_vec(), + ) + .unwrap(); - let events = receive_events( - &mut fsm, - now + Duration::from_millis(1), - SessionEnvelope { - seq: SessionSeq(2), - ack: SessionAck::EMPTY, - body, - }, - ); + let (write_id, first) = fsm.take_next_write(now).unwrap(); + fsm.confirm_write(now, write_id); + assert!(matches!( + first.frames.as_slice(), + [SessionFrame::StreamClose(StreamClose { stream_id: id, .. })] if *id == stream_id + )); - assert!(events.is_empty()); - assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); + fsm.on_timer(now + Duration::from_millis(200), |_| {}); + let retried = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); + assert_ne!(first.seq, retried.seq); + assert_eq!(first.frames, retried.frames); } #[test] -fn next_outbound_round_robins_across_ready_streams() { +fn stream_ids_follow_even_odd_xid_ordering() { let now = Instant::now(); - let mut fsm = SessionFsm::new( + let even = StreamParity::for_local(XID([1; XID::SIZE]), XID([2; XID::SIZE])); + let odd = StreamParity::for_local(XID([2; XID::SIZE]), XID([1; XID::SIZE])); + + let even_id = SessionFsm::new( SessionFsmConfig { - stream_chunk_size: 3, + local_parity: even, ..SessionFsmConfig::default() }, now, - ); - let stream_id_a = fsm.open_stream().unwrap(); - let stream_id_b = fsm.open_stream().unwrap(); - - fsm.write_stream(stream_id_a, b"a-1".to_vec()).unwrap(); - fsm.write_stream(stream_id_b, b"b-1".to_vec()).unwrap(); - fsm.write_stream(stream_id_a, b"a-2".to_vec()).unwrap(); - fsm.write_stream(stream_id_b, b"b-2".to_vec()).unwrap(); - - let first_round: Vec<_> = (0..2) - .map(|_| match next_outbound(&mut fsm, now).unwrap().body { - SessionBody::Stream(frame) => frame.stream_id, - other => panic!("expected stream frame, got {other:?}"), - }) - .collect(); - - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(1), - ack( - 1, - SessionAck { - base: SessionSeq(2), - bitmap: 0, - }, - ), - ); - - let second_round: Vec<_> = (0..2) - .map(|_| { - match next_outbound(&mut fsm, now + Duration::from_millis(2)) - .unwrap() - .body - { - SessionBody::Stream(frame) => frame.stream_id, - other => panic!("expected stream frame, got {other:?}"), - } - }) - .collect(); - - assert_eq!(first_round, vec![stream_id_a, stream_id_b]); - assert_eq!(second_round, vec![stream_id_a, stream_id_b]); -} - -#[test] -fn idle_session_sends_ping_after_keepalive_interval() { - let now = Instant::now(); - let mut fsm = SessionFsm::new( + ) + .open_stream() + .unwrap(); + let odd_id = SessionFsm::new( SessionFsmConfig { - keepalive_interval: Duration::from_millis(50), + local_parity: odd, ..SessionFsmConfig::default() }, now, - ); - - assert_eq!(fsm.next_deadline(), Some(now + Duration::from_millis(50))); - assert!(next_outbound(&mut fsm, now + Duration::from_millis(49)).is_none()); - assert!(on_timer_events(&mut fsm, now + Duration::from_millis(50)).is_empty()); - - let envelope = next_outbound(&mut fsm, now + Duration::from_millis(50)).unwrap(); - assert!(matches!(envelope.body, SessionBody::Ping(PingBody))); -} - -#[test] -fn receive_ping_schedules_ack_without_ping_pong() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - - let _ = receive_events(&mut fsm, now, ping(1, SessionAck::EMPTY)); - - let ack_envelope = next_outbound(&mut fsm, now + Duration::from_millis(10)).unwrap(); - assert_eq!(ack_envelope.body, SessionBody::Ack); - - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(20), - ack(2, SessionAck::EMPTY), - ); - assert!(next_outbound(&mut fsm, now + Duration::from_millis(30)).is_none()); -} - -#[test] -fn tx_selective_ack_keeps_front_gap_pinned() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_ids: Vec<_> = (0..64).map(|_| fsm.open_stream().unwrap()).collect(); - - for (byte, stream_id) in (0..64u8).zip(stream_ids.iter().copied()) { - fsm.write_stream(stream_id, vec![byte]).unwrap(); - let _ = next_outbound(&mut fsm, now + Duration::from_millis(byte as u64)).unwrap(); - } - - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(100), - ack( - 1, - SessionAck { - base: SessionSeq(0), - bitmap: u64::MAX ^ 1, - }, - ), - ); - - assert!(fsm.state.tx_ring.contains_key(&SessionSeq(1))); - assert!(!fsm.state.tx_ring.contains_key(&SessionSeq(2))); - - let extra_stream = fsm.open_stream().unwrap(); - fsm.write_stream(extra_stream, b"x".to_vec()).unwrap(); - assert!(next_outbound(&mut fsm, now + Duration::from_millis(101)).is_none()); - - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(102), - ack( - 2, - SessionAck { - base: SessionSeq(1), - bitmap: 0, - }, - ), - ); - - assert_eq!( - next_outbound(&mut fsm, now + Duration::from_millis(103)) - .unwrap() - .seq, - SessionSeq(65) - ); -} - -#[test] -fn rx_seq_past_window_closes_protocol() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - - let events = receive_events(&mut fsm, now, ping(65, SessionAck::EMPTY)); - - assert_eq!(fsm.state.session_state, SessionState::Closed); - assert!(matches!( - events.as_slice(), - [SessionEvent::SessionClosed(close)] if close.code == CloseCode::PROTOCOL - )); -} - -#[test] -fn duplicate_old_packet_seq_is_ignored() { - let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = ql_wire::StreamId(super::StreamNamespace::High.bit() | 11); - let body = SessionBody::Stream(StreamChunk { - stream_id, - chunk_seq: 0, - bytes: b"x".to_vec(), - fin: false, - }); - - let _ = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: body.clone(), - }, - ); - let _ = read_stream_all(&mut fsm, stream_id); - - let events = receive_events( - &mut fsm, - now + Duration::from_millis(1), - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body, - }, - ); + ) + .open_stream() + .unwrap(); - assert!(events.is_empty()); - assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); + assert_eq!(even_id.0 % 2, 0); + assert_eq!(odd_id.0 % 2, 1); } #[test] -fn retransmitted_stream_close_is_idempotent() { +fn duplicate_stream_data_is_not_redelivered() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = fsm.open_stream().unwrap(); - let frame = StreamClose { - stream_id, - target: CloseTarget::Response, - code: CloseCode::CANCELLED, - payload: Vec::new(), + let stream_id = StreamId(1); + let record = SessionRecord { + seq: RecordSeq(1), + frames: vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: 0, + fin: false, + bytes: b"hi".to_vec(), + })], }; - - let events = receive_events( - &mut fsm, - now, - SessionEnvelope { - seq: SessionSeq(1), - ack: SessionAck::EMPTY, - body: SessionBody::StreamClose(frame.clone()), - }, - ); - - assert_eq!(events, vec![SessionEvent::Closed(frame.clone())]); - assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); - - let events = receive_events( + let _ = receive_events(&mut fsm, now, record.clone()); + let _ = receive_events( &mut fsm, now + Duration::from_millis(1), - SessionEnvelope { - seq: SessionSeq(2), - ack: SessionAck::EMPTY, - body: SessionBody::StreamClose(frame), + SessionRecord { + seq: RecordSeq(2), + ..record }, ); - assert!(events.is_empty()); - assert_eq!(read_stream_all(&mut fsm, stream_id), Vec::::new()); + assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 02ab2e58..7716c089 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -229,7 +229,7 @@ fn initiator_waits_for_ready_before_connecting() { harness .a .fsm - .write_stream(stream_id, b"queued".to_vec()) + .write_stream(stream_id, b"queued") .unwrap(); let confirm = harness.next_outbound_a().unwrap(); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index f5055354..abee0d40 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -9,12 +9,12 @@ use std::{ use libcrux_aesgcm::AesGcm256Key; use ql_wire::{ self, generate_ml_dsa_keypair, generate_ml_kem_keypair, EncryptedMessage, QlCrypto, QlIdentity, - QlPayload, QlRecord, SessionEnvelope, SessionKey, XID, + QlPayload, QlRecord, SessionKey, XID, }; use sha2::{Digest, Sha256}; use crate::{ - session::{SessionFsm, SessionFsmConfig, StreamNamespace}, + session::{state::StreamParity, SessionFsm, SessionFsmConfig}, state::ConnectionState, FsmTime, OutboundWrite, Peer, QlFsm, QlFsmConfig, SessionWriteId, }; @@ -144,29 +144,33 @@ impl Harness { }; harness.a.fsm.session = SessionFsm::new( SessionFsmConfig { - local_namespace: StreamNamespace::for_local( + local_parity: StreamParity::for_local( harness.a.fsm.identity.xid, harness.a.fsm.peer.as_ref().unwrap().peer.xid, ), - stream_chunk_size: config.session_stream_chunk_size, - ack_delay: config.session_ack_delay, - retransmit_timeout: config.session_retransmit_timeout, + record_size: config.session_record_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, keepalive_interval: config.session_keepalive_interval, peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, }, harness.now, ); harness.b.fsm.session = SessionFsm::new( SessionFsmConfig { - local_namespace: StreamNamespace::for_local( + local_parity: StreamParity::for_local( harness.b.fsm.identity.xid, harness.b.fsm.peer.as_ref().unwrap().peer.xid, ), - stream_chunk_size: config.session_stream_chunk_size, - ack_delay: config.session_ack_delay, - retransmit_timeout: config.session_retransmit_timeout, + record_size: config.session_record_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, keepalive_interval: config.session_keepalive_interval, peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, }, harness.now, ); @@ -271,15 +275,15 @@ fn peer_from_identity(identity: &QlIdentity) -> Peer { } } -fn decrypt_envelope( +fn decrypt_record( crypto: &impl QlCrypto, record: &QlRecord, session_key: &SessionKey, -) -> ql_wire::SessionEnvelope { +) -> ql_wire::SessionRecord { let aad = record.header.aad(); let QlPayload::Session(encrypted) = &record.payload else { panic!("expected encrypted payload"); }; let plaintext = encrypted.decrypt(crypto, session_key, &aad).unwrap(); - SessionEnvelope::decode(&plaintext).unwrap() + ql_wire::SessionRecord::decode(&plaintext).unwrap() } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 640fcec7..5a832e28 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -1,19 +1,20 @@ use std::time::Duration; -use ql_wire::{SessionCloseBody, StreamId}; +use ql_wire::{SessionCloseBody, SessionFrame, StreamId}; use super::*; -use crate::{session::StreamNamespace, QlFsmEvent, QlSessionEvent}; +use crate::{session::state::StreamParity, QlFsmEvent, QlSessionEvent}; fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); let mut buf = [0u8; 64]; loop { - let read = fsm.read_stream(stream_id, &mut buf).unwrap(); + let read = fsm.peek_stream(stream_id, &mut buf).unwrap(); if read == 0 { break; } out.extend_from_slice(&buf[..read]); + fsm.commit_stream_read(stream_id, read).unwrap(); } out } @@ -23,11 +24,7 @@ fn connected_fsms_deliver_stream_data() { let mut harness = Harness::connected(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); - harness - .a - .fsm - .write_stream(stream_id, b"hello".to_vec()) - .unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"hello").unwrap(), 5); harness.a.fsm.finish_stream(stream_id).unwrap(); harness.pump(); @@ -51,19 +48,15 @@ fn connected_fsms_deliver_stream_data() { } #[test] -fn lost_encrypted_record_is_retried_and_acked() { +fn lost_record_is_retried_with_new_record_seq() { let config = QlFsmConfig::default(); let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - harness - .a - .fsm - .write_stream(stream_id, b"retry".to_vec()) - .unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"retry").unwrap(), 5); let first = harness.next_outbound_a().unwrap(); - let session_key = harness + let session_key = *harness .b .fsm .peer @@ -72,19 +65,19 @@ fn lost_encrypted_record_is_retried_and_acked() { .session .session_key() .unwrap(); - let session_key = *session_key; - let first_body = decrypt_envelope(&harness.b.crypto, &first, &session_key); + let first_record = decrypt_record(&harness.b.crypto, &first, &session_key); - harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.a.fsm.on_timer(harness.time()); let retried = harness.next_outbound_a().unwrap(); - let retried_body = decrypt_envelope(&harness.b.crypto, &retried, &session_key); + let retried_record = decrypt_record(&harness.b.crypto, &retried, &session_key); - assert_eq!(first_body.seq, retried_body.seq); - assert_eq!(first_body.body, retried_body.body); + assert_ne!(retried_record.seq, first_record.seq); + assert_eq!(retried_record.frames, first_record.frames); harness.deliver_to_b(retried); - harness.advance(config.session_ack_delay); + harness.advance(config.session_record_ack_delay); harness.a.fsm.on_timer(harness.time()); harness.b.fsm.on_timer(harness.time()); harness.pump(); @@ -102,92 +95,26 @@ fn lost_encrypted_record_is_retried_and_acked() { b"retry".to_vec() ); - harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.a.fsm.on_timer(harness.time()); assert!(harness.next_outbound_a().is_none()); } #[test] -fn remote_unpair_clears_peer() { - let mut harness = Harness::connected(QlFsmConfig::default()); - - let record = harness - .a - .fsm - .unpair(harness.time(), &harness.a.crypto) - .unwrap(); - - assert_eq!( - harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::Unpaired) - ); - assert!(harness.a.fsm.peer.is_none()); - assert!(matches!( - harness.a.fsm.take_next_event(), - Some(QlFsmEvent::ClearPeer) - )); - - harness.deliver_to_b(record); - - assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Unpaired) - ); - assert!(harness.b.fsm.peer.is_none()); - assert!(matches!( - harness.b.fsm.take_next_event(), - Some(QlFsmEvent::ClearPeer) - )); -} - -#[test] -fn unpair_returns_record_without_active_session() { - let mut harness = Harness::paired(QlFsmConfig::default()); - - let record = harness - .a - .fsm - .unpair(harness.time(), &harness.a.crypto) - .unwrap(); - - assert!(matches!(record.payload, QlPayload::Unpair(_))); - assert!(harness.a.fsm.peer.is_none()); - assert_eq!( - harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::Unpaired) - ); - assert!(matches!( - harness.a.fsm.take_next_event(), - Some(QlFsmEvent::ClearPeer) - )); -} - -#[test] -fn simultaneous_opens_use_disjoint_stream_id_namespaces() { +fn simultaneous_opens_use_even_and_odd_stream_ids() { let mut harness = Harness::connected(QlFsmConfig::default()); let stream_id_a = harness.a.fsm.open_stream().unwrap(); let stream_id_b = harness.b.fsm.open_stream().unwrap(); assert_ne!(stream_id_a, stream_id_b); - assert!( - StreamNamespace::for_local(harness.a.fsm.identity.xid, harness.b.fsm.identity.xid) - .matches(stream_id_a) - ); - assert!( - StreamNamespace::for_local(harness.b.fsm.identity.xid, harness.a.fsm.identity.xid) - .matches(stream_id_b) - ); + assert!(StreamParity::for_local(harness.a.fsm.identity.xid, harness.b.fsm.identity.xid) + .matches(stream_id_a)); + assert!(StreamParity::for_local(harness.b.fsm.identity.xid, harness.a.fsm.identity.xid) + .matches(stream_id_b)); - harness - .a - .fsm - .write_stream(stream_id_a, b"from-a".to_vec()) - .unwrap(); - harness - .b - .fsm - .write_stream(stream_id_b, b"from-b".to_vec()) - .unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id_a, b"from-a").unwrap(), 6); + assert_eq!(harness.b.fsm.write_stream(stream_id_b, b"from-b").unwrap(), 6); harness.pump(); @@ -222,11 +149,7 @@ fn queued_stream_work_auto_connects_and_drains_after_handshake() { let mut harness = Harness::paired(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); - harness - .a - .fsm - .write_stream(stream_id, b"queued".to_vec()) - .unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); harness.a.fsm.finish_stream(stream_id).unwrap(); harness.pump(); @@ -267,11 +190,7 @@ fn queued_stream_work_is_failed_when_handshake_times_out() { let mut harness = Harness::paired(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - harness - .a - .fsm - .write_stream(stream_id, b"queued".to_vec()) - .unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); let _hello = harness.next_outbound_a().unwrap(); @@ -292,20 +211,16 @@ fn queued_stream_work_is_failed_when_handshake_times_out() { } #[test] -fn returned_session_write_is_reissued_with_same_seq() { +fn returned_session_write_is_reissued_with_new_record_seq() { let mut harness = Harness::connected(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); - harness - .a - .fsm - .write_stream(stream_id, b"retry".to_vec()) - .unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"retry").unwrap(), 5); let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); let record = write.record; - let session_key = harness + let session_key = *harness .b .fsm .peer @@ -314,21 +229,18 @@ fn returned_session_write_is_reissued_with_same_seq() { .session .session_key() .unwrap(); - let session_key = *session_key; - let first = decrypt_envelope(&harness.b.crypto, &record, &session_key); + let first = decrypt_record(&harness.b.crypto, &record, &session_key); harness.return_write_a(id); let write = harness.next_write_a().unwrap(); - let reissued_id = write - .session_write_id - .expect("expected reissued session write"); + let reissued_id = write.session_write_id.expect("expected reissued write"); let record = write.record; - let reissued = decrypt_envelope(&harness.b.crypto, &record, &session_key); + let reissued = decrypt_record(&harness.b.crypto, &record, &session_key); - assert_eq!(reissued_id, id); - assert_eq!(reissued.seq, first.seq); - assert_eq!(reissued.body, first.body); + assert_ne!(reissued_id, id); + assert_ne!(reissued.seq, first.seq); + assert_eq!(reissued.frames, first.frames); harness.confirm_write_a(reissued_id); harness.deliver_to_b(record); @@ -354,16 +266,12 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - harness - .a - .fsm - .write_stream(stream_id, b"retry".to_vec()) - .unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"retry").unwrap(), 5); let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); let record = write.record; - let session_key = harness + let session_key = *harness .b .fsm .peer @@ -372,23 +280,47 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { .session .session_key() .unwrap(); - let session_key = *session_key; - let first = decrypt_envelope(&harness.b.crypto, &record, &session_key); + let first = decrypt_record(&harness.b.crypto, &record, &session_key); - harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); harness.a.fsm.on_timer(harness.time()); assert!(harness.next_write_a().is_none()); harness.confirm_write_a(id); - harness.advance(config.session_retransmit_timeout + Duration::from_millis(1)); + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.a.fsm.on_timer(harness.time()); let write = harness.next_write_a().unwrap(); - assert!(write.session_write_id.is_some(), "expected retransmit"); let record = write.record; - let retried = decrypt_envelope(&harness.b.crypto, &record, &session_key); + let retried = decrypt_record(&harness.b.crypto, &record, &session_key); - assert_eq!(retried.seq, first.seq); - assert_eq!(retried.body, first.body); + assert_ne!(retried.seq, first.seq); + assert_eq!(retried.frames, first.frames); +} + +#[test] +fn ack_frame_releases_stream_capacity_and_emits_writable() { + let config = QlFsmConfig { + session_stream_send_buffer_size: 4, + ..QlFsmConfig::default() + }; + let mut harness = Harness::connected(config); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"abcd").unwrap(), 4); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"z").unwrap(), 0); + + let record = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(record); + harness.advance(config.session_record_ack_delay); + harness.a.fsm.on_timer(harness.time()); + harness.b.fsm.on_timer(harness.time()); + harness.pump(); + + assert_eq!( + harness.a.fsm.take_next_session_event(), + Some(QlSessionEvent::Writable(stream_id)) + ); } #[test] @@ -407,4 +339,35 @@ fn kill_session_disconnects_locally() { code: ql_wire::CloseCode::CANCELLED })) ); + assert!(matches!( + harness.a.fsm.take_next_event(), + Some(QlFsmEvent::PeerStatusChanged { .. }) + )); +} + +#[test] +fn session_records_contain_ack_frames_after_delivery() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"x").unwrap(), 1); + + let data = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(data); + harness.advance(config.session_record_ack_delay); + harness.b.fsm.on_timer(harness.time()); + + let ack = harness.next_outbound_b().unwrap(); + let session_key = *harness + .a + .fsm + .peer + .as_ref() + .unwrap() + .session + .session_key() + .unwrap(); + let ack_record = decrypt_record(&harness.a.crypto, &ack, &session_key); + assert!(matches!(ack_record.frames.as_slice(), [SessionFrame::Ack(_)])); } From 407804551d1ef9b7c3b0cd3f4d5a69f945a3be63 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 01:00:23 -0400 Subject: [PATCH 039/304] ql-fsm: house ByteAssembler and read returns a borrowed view --- ql-fsm/src/implementation/fsm.rs | 16 ++-- ql-fsm/src/lib.rs | 23 ++---- ql-fsm/src/session/mod.rs | 79 ++++++++----------- .../src/session/reassembly.rs | 24 +++--- ql-fsm/src/session/state.rs | 15 ++-- ql-fsm/src/session/tests.rs | 25 ++++-- ql-fsm/src/tests/handshake.rs | 6 +- ql-fsm/src/tests/session.rs | 37 ++++++--- ql-wire/src/encrypted/mod.rs | 2 - ql-wire/src/encrypted/stream_close.rs | 8 +- ql-wire/src/encrypted/stream_data.rs | 12 +-- ql-wire/src/encrypted/stream_window.rs | 5 +- 12 files changed, 122 insertions(+), 130 deletions(-) rename ql-wire/src/encrypted/byte_reassembly.rs => ql-fsm/src/session/reassembly.rs (97%) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 428bb255..01d85f82 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -2,7 +2,9 @@ use std::time::Instant; use ql_wire::{self as wire, CloseCode, CloseTarget, Nonce, QlCrypto, QlPayloadRef, StreamId}; -use crate::{OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId}; +use crate::{ + BytesIter, OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId, +}; pub fn receive( fsm: &mut QlFsm, @@ -191,20 +193,16 @@ pub fn write_stream( Ok(fsm.session.write_stream(stream_id, bytes)?) } -pub fn peek_stream( - fsm: &mut QlFsm, - stream_id: StreamId, - out: &mut [u8], -) -> Result { - Ok(fsm.session.peek_stream(stream_id, out)?) +pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Result, QlFsmError> { + Ok(fsm.session.stream_read(stream_id)?) } -pub fn commit_stream_read( +pub fn stream_read_commit( fsm: &mut QlFsm, stream_id: StreamId, len: usize, ) -> Result<(), QlFsmError> { - Ok(fsm.session.commit_stream_read(stream_id, len)?) + Ok(fsm.session.stream_read_commit(stream_id, len)?) } pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Result { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index c7bacb69..22091a10 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -33,6 +33,7 @@ use ql_wire::{ CloseCode, CloseTarget, MlDsaPublicKey, MlKemPublicKey, QlCrypto, QlIdentity, QlRecord, SessionCloseBody, StreamClose, StreamId, XID, }; +pub use session::reassembly::BytesIter; use crate::{ replay_cache::ReplayCache, @@ -295,30 +296,22 @@ impl QlFsm { } /// queues bytes for an open stream and returns the accepted count - pub fn write_stream( - &mut self, - stream_id: StreamId, - bytes: &[u8], - ) -> Result { + pub fn write_stream(&mut self, stream_id: StreamId, bytes: &[u8]) -> Result { implementation::write_stream(self, stream_id, bytes) } - /// copies readable bytes from a stream into `out` without consuming them - pub fn peek_stream( - &mut self, - stream_id: StreamId, - out: &mut [u8], - ) -> Result { - implementation::peek_stream(self, stream_id, out) + /// returns the readable stream bytes as borrowed chunks without consuming them + pub fn stream_read(&self, stream_id: StreamId) -> Result, QlFsmError> { + implementation::stream_read(self, stream_id) } - /// marks previously peeked bytes as consumed - pub fn commit_stream_read( + /// marks previously read bytes as consumed + pub fn stream_read_commit( &mut self, stream_id: StreamId, len: usize, ) -> Result<(), QlFsmError> { - implementation::commit_stream_read(self, stream_id, len) + implementation::stream_read_commit(self, stream_id, len) } /// returns how many bytes can be read from a stream diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 0fbeecb0..408c8449 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod reassembly; pub(crate) mod state; #[cfg(test)] @@ -7,13 +8,17 @@ use std::time::{Duration, Instant}; use indexmap::map::Entry; use ql_wire::{ - ByteReassemblyError, CloseCode, CloseTarget, RecordAck, RecordSeq, SessionCloseBody, - SessionFrame, SessionRecord, StreamClose, StreamData, StreamId, StreamWindow, + CloseCode, CloseTarget, RecordAck, RecordSeq, SessionCloseBody, SessionFrame, SessionRecord, + StreamClose, StreamData, StreamId, StreamWindow, }; -use self::state::{ - AckState, InboundState, OutboundState, PendingRecord, ReceivedRecords, ReceiveInsertOutcome, - ReliableFrame, SentRecord, SessionFsmState, StreamParity, StreamRole, StreamState, +use self::{ + reassembly::{ByteReassemblyError, BytesIter}, + state::{ + AckState, InboundState, OutboundState, PendingRecord, ReceiveInsertOutcome, + ReceivedRecords, ReliableFrame, SentRecord, SessionFsmState, StreamParity, StreamRole, + StreamState, + }, }; pub(crate) const SESSION_RECORD_TRACKED_WINDOW: u64 = 256; @@ -185,38 +190,16 @@ impl SessionFsm { Ok(()) } - pub fn peek_stream( - &mut self, - stream_id: StreamId, - out: &mut [u8], - ) -> Result { + pub fn stream_read(&self, stream_id: StreamId) -> Result, StreamError> { let stream = self .state .streams .get(&stream_id) .ok_or(StreamError::MissingStream)?; - if out.is_empty() { - return Ok(0); - } - - let mut written = 0; - for chunk in stream.recv.bytes() { - let remaining = out.len().saturating_sub(written); - if remaining == 0 { - break; - } - let len = remaining.min(chunk.len()); - out[written..written + len].copy_from_slice(&chunk[..len]); - written += len; - if len < chunk.len() { - break; - } - } - - Ok(written) + Ok(stream.recv.bytes()) } - pub fn commit_stream_read( + pub fn stream_read_commit( &mut self, stream_id: StreamId, len: usize, @@ -229,7 +212,10 @@ impl SessionFsm { if len > stream.readable_bytes() { return Err(StreamError::InvalidRead); } - stream.recv.consume(len).map_err(|_| StreamError::InvalidRead)?; + stream + .recv + .consume(len) + .map_err(|_| StreamError::InvalidRead)?; if stream.recv_limit() > stream.advertised_max_offset { stream.pending_window = true; } @@ -438,7 +424,8 @@ impl SessionFsm { } } - while let Some(close) = self.take_pending_session_close(remaining, record.frames.is_empty()) { + while let Some(close) = self.take_pending_session_close(remaining, record.frames.is_empty()) + { let frame = SessionFrame::Close(close.clone()); if !self.push_frame(&mut record, &mut remaining, frame, true) { self.state.pending_control.close = Some(close); @@ -447,7 +434,9 @@ impl SessionFsm { pending.reliable.push(ReliableFrame::Close(close)); } - while let Some(close) = self.take_next_pending_stream_close(remaining, record.frames.is_empty()) { + while let Some(close) = + self.take_next_pending_stream_close(remaining, record.frames.is_empty()) + { let frame = SessionFrame::StreamClose(close.clone()); if !self.push_frame(&mut record, &mut remaining, frame, true) { self.restore_stream_close(close); @@ -498,7 +487,9 @@ impl SessionFsm { pending.reliable.push(ReliableFrame::StreamData(frame)); } - while let Some(frame) = self.take_next_fresh_stream_data(remaining, record.frames.is_empty()) { + while let Some(frame) = + self.take_next_fresh_stream_data(remaining, record.frames.is_empty()) + { if !self.push_frame( &mut record, &mut remaining, @@ -532,11 +523,7 @@ impl SessionFsm { self.state.pending_control.close.take() } - fn take_pending_ping( - &mut self, - remaining: usize, - record_empty: bool, - ) -> Option { + fn take_pending_ping(&mut self, remaining: usize, record_empty: bool) -> Option { if !self.state.pending_control.ping { return None; } @@ -686,7 +673,8 @@ impl SessionFsm { let credit_remaining = stream .peer_max_offset - .saturating_sub(stream.next_send_offset) as usize; + .saturating_sub(stream.next_send_offset) + as usize; let has_empty_fin = matches!(stream.outbound_state, OutboundState::FinQueued) && stream.send_buf.is_empty() && stream.next_send_offset <= stream.peer_max_offset; @@ -699,11 +687,7 @@ impl SessionFsm { } let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); - let payload_len = stream - .send_buf - .len() - .min(max_payload) - .min(credit_remaining); + let payload_len = stream.send_buf.len().min(max_payload).min(credit_remaining); let bytes: Vec = stream.send_buf.drain(..payload_len).collect(); let fin = matches!(stream.outbound_state, OutboundState::FinQueued) && stream.send_buf.is_empty() @@ -1064,7 +1048,10 @@ impl SessionFsm { } } - fn split_stream_data(frame: StreamData, max_payload: usize) -> (StreamData, Option) { + fn split_stream_data( + frame: StreamData, + max_payload: usize, + ) -> (StreamData, Option) { if frame.bytes.len() <= max_payload || frame.bytes.is_empty() { return (frame, None); } diff --git a/ql-wire/src/encrypted/byte_reassembly.rs b/ql-fsm/src/session/reassembly.rs similarity index 97% rename from ql-wire/src/encrypted/byte_reassembly.rs rename to ql-fsm/src/session/reassembly.rs index c6cb306b..c1b6730f 100644 --- a/ql-wire/src/encrypted/byte_reassembly.rs +++ b/ql-fsm/src/session/reassembly.rs @@ -63,6 +63,7 @@ impl ByteReassembly { self.start_offset + self.bytes.len() as u64 } + #[cfg(test)] pub fn final_offset(&self) -> Option { self.final_offset } @@ -71,6 +72,7 @@ impl ByteReassembly { self.max_buffered } + #[cfg(test)] pub fn missing_ranges(&self) -> &[MissingRange] { self.missing.as_slice() } @@ -111,6 +113,7 @@ impl ByteReassembly { } } + #[cfg(test)] pub fn copy_readable(&self) -> Vec { let readable = self.readable_len(); let mut out = Vec::with_capacity(readable); @@ -229,10 +232,7 @@ impl ByteReassembly { }) } - fn push_missing_range( - &mut self, - range: MissingRange, - ) -> Result<(), ByteReassemblyError> { + fn push_missing_range(&mut self, range: MissingRange) -> Result<(), ByteReassemblyError> { if range.start >= range.end { return Ok(()); } @@ -247,11 +247,7 @@ impl ByteReassembly { self.missing.push(range) } - fn validate_overlap( - &self, - offset: u64, - bytes: &[u8], - ) -> Result<(), ByteReassemblyError> { + fn validate_overlap(&self, offset: u64, bytes: &[u8]) -> Result<(), ByteReassemblyError> { let mut gap_index = self.first_gap_index_after(offset); for (index, byte) in bytes.iter().copied().enumerate() { @@ -284,11 +280,7 @@ impl ByteReassembly { } } - fn subtract_missing_range( - &mut self, - start: u64, - end: u64, - ) -> Result<(), ByteReassemblyError> { + fn subtract_missing_range(&mut self, start: u64, end: u64) -> Result<(), ByteReassemblyError> { let first = self.first_gap_index_after(start); if first == self.missing.len() || self.missing[first].start >= end { return Ok(()); @@ -352,7 +344,9 @@ impl ByteReassembly { } fn first_gap_index_after(&self, offset: u64) -> usize { - self.missing.as_slice().partition_point(|range| range.end <= offset) + self.missing + .as_slice() + .partition_point(|range| range.end <= offset) } fn byte_at(&self, offset: u64) -> u8 { diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index dcc7bf6f..c013f40d 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -5,11 +5,11 @@ use std::{ use indexmap::IndexMap; use ql_wire::{ - ByteReassembly, CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionCloseBody, - StreamClose, StreamData, StreamId, XID, + CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionCloseBody, StreamClose, StreamData, + StreamId, XID, }; -use super::{SessionState, SESSION_RECORD_TRACKED_WINDOW}; +use super::{reassembly::ByteReassembly, SessionState, SESSION_RECORD_TRACKED_WINDOW}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamParity { @@ -129,11 +129,7 @@ pub struct StreamState { } impl StreamState { - pub fn new( - role: StreamRole, - _send_buffer_size: usize, - receive_buffer_size: usize, - ) -> Self { + pub fn new(role: StreamRole, _send_buffer_size: usize, receive_buffer_size: usize) -> Self { Self { role, send_buf: VecDeque::new(), @@ -173,7 +169,8 @@ impl StreamState { } pub fn reset_recv(&mut self) { - self.recv = ByteReassembly::with_start_offset(self.recv.start_offset(), self.recv.max_buffered()); + self.recv = + ByteReassembly::with_start_offset(self.recv.start_offset(), self.recv.max_buffered()); } } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 3422039c..4b66c0ed 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -9,14 +9,16 @@ use super::{state::StreamParity, SessionEvent, SessionFsm, SessionFsmConfig}; fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); - let mut buf = [0u8; 64]; loop { - let read = fsm.peek_stream(stream_id, &mut buf).unwrap(); + let mut read = 0; + for chunk in fsm.stream_read(stream_id).unwrap() { + out.extend_from_slice(chunk); + read += chunk.len(); + } if read == 0 { break; } - out.extend_from_slice(&buf[..read]); - fsm.commit_stream_read(stream_id, read).unwrap(); + fsm.stream_read_commit(stream_id, read).unwrap(); } out } @@ -155,18 +157,25 @@ fn commit_stream_read_is_what_advances_stream_window() { let events = receive_events(&mut fsm, now, data); assert_eq!( events, - vec![SessionEvent::Opened(stream_id), SessionEvent::Readable(stream_id)] + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Readable(stream_id) + ] ); let first = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); assert!(matches!(first.frames.as_slice(), [SessionFrame::Ack(_)])); - let mut buf = [0u8; 8]; - assert_eq!(fsm.peek_stream(stream_id, &mut buf).unwrap(), 2); + let read = fsm + .stream_read(stream_id) + .unwrap() + .map(|chunk| chunk.len()) + .sum::(); + assert_eq!(read, 2); assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); - fsm.commit_stream_read(stream_id, 2).unwrap(); + fsm.stream_read_commit(stream_id, 2).unwrap(); let second = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); assert!(matches!( second.frames.as_slice(), diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 7716c089..44624b95 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -226,11 +226,7 @@ fn initiator_waits_for_ready_before_connecting() { }) )); let stream_id = harness.a.fsm.open_stream().unwrap(); - harness - .a - .fsm - .write_stream(stream_id, b"queued") - .unwrap(); + harness.a.fsm.write_stream(stream_id, b"queued").unwrap(); let confirm = harness.next_outbound_a().unwrap(); assert!(matches!(confirm.payload, QlPayload::Confirm(_))); diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 5a832e28..e003cfa0 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -7,14 +7,16 @@ use crate::{session::state::StreamParity, QlFsmEvent, QlSessionEvent}; fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); - let mut buf = [0u8; 64]; loop { - let read = fsm.peek_stream(stream_id, &mut buf).unwrap(); + let mut read = 0; + for chunk in fsm.stream_read(stream_id).unwrap() { + out.extend_from_slice(chunk); + read += chunk.len(); + } if read == 0 { break; } - out.extend_from_slice(&buf[..read]); - fsm.commit_stream_read(stream_id, read).unwrap(); + fsm.stream_read_commit(stream_id, read).unwrap(); } out } @@ -108,13 +110,23 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { let stream_id_b = harness.b.fsm.open_stream().unwrap(); assert_ne!(stream_id_a, stream_id_b); - assert!(StreamParity::for_local(harness.a.fsm.identity.xid, harness.b.fsm.identity.xid) - .matches(stream_id_a)); - assert!(StreamParity::for_local(harness.b.fsm.identity.xid, harness.a.fsm.identity.xid) - .matches(stream_id_b)); + assert!( + StreamParity::for_local(harness.a.fsm.identity.xid, harness.b.fsm.identity.xid) + .matches(stream_id_a) + ); + assert!( + StreamParity::for_local(harness.b.fsm.identity.xid, harness.a.fsm.identity.xid) + .matches(stream_id_b) + ); - assert_eq!(harness.a.fsm.write_stream(stream_id_a, b"from-a").unwrap(), 6); - assert_eq!(harness.b.fsm.write_stream(stream_id_b, b"from-b").unwrap(), 6); + assert_eq!( + harness.a.fsm.write_stream(stream_id_a, b"from-a").unwrap(), + 6 + ); + assert_eq!( + harness.b.fsm.write_stream(stream_id_b, b"from-b").unwrap(), + 6 + ); harness.pump(); @@ -369,5 +381,8 @@ fn session_records_contain_ack_frames_after_delivery() { .session_key() .unwrap(); let ack_record = decrypt_record(&harness.a.crypto, &ack, &session_key); - assert!(matches!(ack_record.frames.as_slice(), [SessionFrame::Ack(_)])); + assert!(matches!( + ack_record.frames.as_slice(), + [SessionFrame::Ack(_)] + )); } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index fc1d7a27..1e87bac5 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -12,14 +12,12 @@ use crate::{ }; mod ack; -mod byte_reassembly; mod close; mod stream_close; mod stream_data; mod stream_window; pub use ack::*; -pub use byte_reassembly::*; pub use close::*; pub use stream_close::*; pub use stream_data::*; diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 665ef02a..5e3bc0ac 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -6,7 +6,10 @@ use zerocopy::{ }; use super::StreamId; -use crate::{codec::{parse, read_byte, U16Le, U32Le}, WireError}; +use crate::{ + codec::{parse, read_byte, U16Le, U32Le}, + WireError, +}; /// aborts one or both directions of a stream with a close code. #[derive(Debug, Clone, PartialEq, Eq)] @@ -60,8 +63,7 @@ pub struct StreamCloseWire { } impl StreamClose { - pub const MIN_WIRE_SIZE: usize = - size_of::() + size_of::() + size_of::(); + pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); pub fn parse(bytes: B) -> Result, WireError> { if bytes.len() < Self::MIN_WIRE_SIZE { diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index ac69d012..5507ab07 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -1,11 +1,12 @@ use std::mem::size_of; -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, KnownLayout, Ref, Unaligned, -}; +use zerocopy::{byte_slice::ByteSlice, FromBytes, Immutable, KnownLayout, Ref, Unaligned}; use super::StreamId; -use crate::{codec::{parse, U32Le, U64Le}, WireError}; +use crate::{ + codec::{parse, U32Le, U64Le}, + WireError, +}; /// carries bytes for a stream and may finish that sending direction. #[derive(Debug, Clone, PartialEq, Eq)] @@ -26,8 +27,7 @@ pub struct StreamDataWire { } impl StreamData { - pub const MIN_WIRE_SIZE: usize = - size_of::() + size_of::() + size_of::(); + pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); pub fn parse(bytes: B) -> Result, WireError> { if bytes.len() < Self::MIN_WIRE_SIZE { diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index 2f6e2c4a..764b3ff7 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -5,7 +5,10 @@ use zerocopy::{ }; use super::StreamId; -use crate::{codec::{parse, push_value, U32Le, U64Le}, WireError}; +use crate::{ + codec::{parse, push_value, U32Le, U64Le}, + WireError, +}; /// advertises the highest byte offset the peer may send on a stream. #[derive(Debug, Clone, PartialEq, Eq)] From 482e0ad8968637401cd107493eb56b599988f243 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 01:09:07 -0400 Subject: [PATCH 040/304] ql-fsm: rename to StreamAssembler --- ql-fsm/src/implementation/fsm.rs | 4 +- ql-fsm/src/lib.rs | 4 +- ql-fsm/src/session/mod.rs | 20 ++++----- ql-fsm/src/session/reassembly.rs | 76 ++++++++++++++++---------------- ql-fsm/src/session/state.rs | 8 ++-- 5 files changed, 56 insertions(+), 56 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 01d85f82..c16e109a 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -3,7 +3,7 @@ use std::time::Instant; use ql_wire::{self as wire, CloseCode, CloseTarget, Nonce, QlCrypto, QlPayloadRef, StreamId}; use crate::{ - BytesIter, OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId, + OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId, StreamReadIter, }; pub fn receive( @@ -193,7 +193,7 @@ pub fn write_stream( Ok(fsm.session.write_stream(stream_id, bytes)?) } -pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Result, QlFsmError> { +pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Result, QlFsmError> { Ok(fsm.session.stream_read(stream_id)?) } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 22091a10..1860990c 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -33,7 +33,7 @@ use ql_wire::{ CloseCode, CloseTarget, MlDsaPublicKey, MlKemPublicKey, QlCrypto, QlIdentity, QlRecord, SessionCloseBody, StreamClose, StreamId, XID, }; -pub use session::reassembly::BytesIter; +pub use session::reassembly::StreamReadIter; use crate::{ replay_cache::ReplayCache, @@ -301,7 +301,7 @@ impl QlFsm { } /// returns the readable stream bytes as borrowed chunks without consuming them - pub fn stream_read(&self, stream_id: StreamId) -> Result, QlFsmError> { + pub fn stream_read(&self, stream_id: StreamId) -> Result, QlFsmError> { implementation::stream_read(self, stream_id) } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 408c8449..372d92b7 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -13,7 +13,7 @@ use ql_wire::{ }; use self::{ - reassembly::{ByteReassemblyError, BytesIter}, + reassembly::{StreamAssemblerError, StreamReadIter}, state::{ AckState, InboundState, OutboundState, PendingRecord, ReceiveInsertOutcome, ReceivedRecords, ReliableFrame, SentRecord, SessionFsmState, StreamParity, StreamRole, @@ -190,7 +190,7 @@ impl SessionFsm { Ok(()) } - pub fn stream_read(&self, stream_id: StreamId) -> Result, StreamError> { + pub fn stream_read(&self, stream_id: StreamId) -> Result, StreamError> { let stream = self .state .streams @@ -896,13 +896,13 @@ impl SessionFsm { self.try_reap_stream(stream_id); Ok(()) } - Err(ByteReassemblyError::ConflictingOverlap) - | Err(ByteReassemblyError::OutOfWindow) - | Err(ByteReassemblyError::InconsistentFinalOffset) - | Err(ByteReassemblyError::FinalOffsetBeforeBufferedData) - | Err(ByteReassemblyError::BeyondFinalOffset) - | Err(ByteReassemblyError::TooManyMissingRanges) - | Err(ByteReassemblyError::OffsetOverflow) => { + Err(StreamAssemblerError::ConflictingOverlap) + | Err(StreamAssemblerError::OutOfWindow) + | Err(StreamAssemblerError::InconsistentFinalOffset) + | Err(StreamAssemblerError::FinalOffsetBeforeBufferedData) + | Err(StreamAssemblerError::BeyondFinalOffset) + | Err(StreamAssemblerError::TooManyMissingRanges) + | Err(StreamAssemblerError::OffsetOverflow) => { self.fail_session( SessionCloseBody { code: CloseCode::PROTOCOL, @@ -911,7 +911,7 @@ impl SessionFsm { ); Err(()) } - Err(ByteReassemblyError::ConsumeBeyondReadable) => unreachable!(), + Err(StreamAssemblerError::ConsumeBeyondReadable) => unreachable!(), } } diff --git a/ql-fsm/src/session/reassembly.rs b/ql-fsm/src/session/reassembly.rs index c1b6730f..c3120e72 100644 --- a/ql-fsm/src/session/reassembly.rs +++ b/ql-fsm/src/session/reassembly.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; /// reassembles one stream direction from out-of-order byte ranges. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ByteReassembly { +pub struct StreamAssembler { start_offset: u64, bytes: VecDeque, missing: MissingRanges, @@ -23,7 +23,7 @@ pub struct InsertOutcome { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ByteReassemblyError { +pub enum StreamAssemblerError { OffsetOverflow, OutOfWindow, InconsistentFinalOffset, @@ -35,12 +35,12 @@ pub enum ByteReassemblyError { } #[derive(Debug, Clone, Copy)] -pub struct BytesIter<'a> { +pub struct StreamReadIter<'a> { front: Option<&'a [u8]>, back: Option<&'a [u8]>, } -impl ByteReassembly { +impl StreamAssembler { pub fn new(max_buffered: usize) -> Self { Self::with_start_offset(0, max_buffered) } @@ -90,10 +90,10 @@ impl ByteReassembly { } } - pub fn bytes(&self) -> BytesIter<'_> { + pub fn bytes(&self) -> StreamReadIter<'_> { let readable = self.readable_len(); if readable == 0 { - return BytesIter { + return StreamReadIter { front: None, back: None, }; @@ -101,12 +101,12 @@ impl ByteReassembly { let (front, back) = self.bytes.as_slices(); if readable <= front.len() { - BytesIter { + StreamReadIter { front: Some(&front[..readable]), back: None, } } else { - BytesIter { + StreamReadIter { front: Some(front), back: Some(&back[..readable - front.len()]), } @@ -133,10 +133,10 @@ impl ByteReassembly { offset: u64, fin: bool, bytes: &[u8], - ) -> Result { + ) -> Result { let end = offset .checked_add(bytes.len() as u64) - .ok_or(ByteReassemblyError::OffsetOverflow)?; + .ok_or(StreamAssemblerError::OffsetOverflow)?; let was_complete = self.is_complete(); let old_readable = self.readable_len(); @@ -146,7 +146,7 @@ impl ByteReassembly { } if let Some(final_offset) = self.final_offset { if end > final_offset { - return Err(ByteReassemblyError::BeyondFinalOffset); + return Err(StreamAssemblerError::BeyondFinalOffset); } } @@ -171,10 +171,10 @@ impl ByteReassembly { Ok(self.insert_outcome(was_complete, old_readable)) } - pub fn consume(&mut self, len: usize) -> Result<(), ByteReassemblyError> { + pub fn consume(&mut self, len: usize) -> Result<(), StreamAssemblerError> { let readable = self.readable_len(); if len > readable { - return Err(ByteReassemblyError::ConsumeBeyondReadable); + return Err(StreamAssemblerError::ConsumeBeyondReadable); } self.bytes.drain(..len); @@ -192,33 +192,33 @@ impl ByteReassembly { fn set_or_validate_final_offset( &mut self, final_offset: u64, - ) -> Result<(), ByteReassemblyError> { + ) -> Result<(), StreamAssemblerError> { if let Some(existing) = self.final_offset { return if existing == final_offset { Ok(()) } else { - Err(ByteReassemblyError::InconsistentFinalOffset) + Err(StreamAssemblerError::InconsistentFinalOffset) }; } let buffered_end = self.buffered_end_offset(); if final_offset < buffered_end { - return Err(ByteReassemblyError::FinalOffsetBeforeBufferedData); + return Err(StreamAssemblerError::FinalOffsetBeforeBufferedData); } self.final_offset = Some(final_offset); Ok(()) } - fn ensure_within_window(&self, end: u64) -> Result<(), ByteReassemblyError> { + fn ensure_within_window(&self, end: u64) -> Result<(), StreamAssemblerError> { let attempted = end.saturating_sub(self.start_offset); if attempted > self.max_buffered as u64 { - return Err(ByteReassemblyError::OutOfWindow); + return Err(StreamAssemblerError::OutOfWindow); } Ok(()) } - fn ensure_buffered(&mut self, end: u64) -> Result<(), ByteReassemblyError> { + fn ensure_buffered(&mut self, end: u64) -> Result<(), StreamAssemblerError> { let buffered_end = self.buffered_end_offset(); if end <= buffered_end { return Ok(()); @@ -232,7 +232,7 @@ impl ByteReassembly { }) } - fn push_missing_range(&mut self, range: MissingRange) -> Result<(), ByteReassemblyError> { + fn push_missing_range(&mut self, range: MissingRange) -> Result<(), StreamAssemblerError> { if range.start >= range.end { return Ok(()); } @@ -247,7 +247,7 @@ impl ByteReassembly { self.missing.push(range) } - fn validate_overlap(&self, offset: u64, bytes: &[u8]) -> Result<(), ByteReassemblyError> { + fn validate_overlap(&self, offset: u64, bytes: &[u8]) -> Result<(), StreamAssemblerError> { let mut gap_index = self.first_gap_index_after(offset); for (index, byte) in bytes.iter().copied().enumerate() { @@ -265,7 +265,7 @@ impl ByteReassembly { } if self.byte_at(absolute) != byte { - return Err(ByteReassemblyError::ConflictingOverlap); + return Err(StreamAssemblerError::ConflictingOverlap); } } @@ -280,7 +280,7 @@ impl ByteReassembly { } } - fn subtract_missing_range(&mut self, start: u64, end: u64) -> Result<(), ByteReassemblyError> { + fn subtract_missing_range(&mut self, start: u64, end: u64) -> Result<(), StreamAssemblerError> { let first = self.first_gap_index_after(start); if first == self.missing.len() || self.missing[first].start >= end { return Ok(()); @@ -355,7 +355,7 @@ impl ByteReassembly { } } -impl<'a> Iterator for BytesIter<'a> { +impl<'a> Iterator for StreamReadIter<'a> { type Item = &'a [u8]; fn next(&mut self) -> Option { @@ -413,18 +413,18 @@ impl MissingRanges { } } - fn push(&mut self, range: MissingRange) -> Result<(), ByteReassemblyError> { + fn push(&mut self, range: MissingRange) -> Result<(), StreamAssemblerError> { if self.len == N { - return Err(ByteReassemblyError::TooManyMissingRanges); + return Err(StreamAssemblerError::TooManyMissingRanges); } self.ranges[self.len] = range; self.len += 1; Ok(()) } - fn insert(&mut self, index: usize, range: MissingRange) -> Result<(), ByteReassemblyError> { + fn insert(&mut self, index: usize, range: MissingRange) -> Result<(), StreamAssemblerError> { if self.len == N { - return Err(ByteReassemblyError::TooManyMissingRanges); + return Err(StreamAssemblerError::TooManyMissingRanges); } for i in (index..self.len).rev() { self.ranges[i + 1] = self.ranges[i]; @@ -476,11 +476,11 @@ impl std::ops::IndexMut for MissingRanges { #[cfg(test)] mod tests { - use super::{ByteReassembly, ByteReassemblyError, InsertOutcome, MissingRange}; + use super::{InsertOutcome, MissingRange, StreamAssembler, StreamAssemblerError}; #[test] fn contiguous_insert_becomes_readable_and_complete() { - let mut assembler = ByteReassembly::<8>::new(64); + let mut assembler = StreamAssembler::<8>::new(64); let outcome = assembler.insert(0, true, b"hello").unwrap(); @@ -500,7 +500,7 @@ mod tests { #[test] fn out_of_order_insert_tracks_missing_ranges_until_gap_is_filled() { - let mut assembler = ByteReassembly::<8>::new(64); + let mut assembler = StreamAssembler::<8>::new(64); let first = assembler.insert(5, true, b" world").unwrap(); assert_eq!( @@ -531,7 +531,7 @@ mod tests { #[test] fn duplicate_insert_is_ignored_if_bytes_match() { - let mut assembler = ByteReassembly::<8>::new(64); + let mut assembler = StreamAssembler::<8>::new(64); assembler.insert(0, false, b"hello").unwrap(); let duplicate = assembler.insert(0, false, b"hello").unwrap(); @@ -548,17 +548,17 @@ mod tests { #[test] fn conflicting_overlap_is_rejected() { - let mut assembler = ByteReassembly::<8>::new(64); + let mut assembler = StreamAssembler::<8>::new(64); assembler.insert(0, false, b"abcdef").unwrap(); let error = assembler.insert(3, false, b"xyz").unwrap_err(); - assert_eq!(error, ByteReassemblyError::ConflictingOverlap); + assert_eq!(error, StreamAssemblerError::ConflictingOverlap); } #[test] fn consume_advances_start_offset_and_trims_old_prefix() { - let mut assembler = ByteReassembly::<8>::new(64); + let mut assembler = StreamAssembler::<8>::new(64); assembler.insert(0, false, b"abcd").unwrap(); assembler.consume(2).unwrap(); @@ -580,18 +580,18 @@ mod tests { #[test] fn insert_rejects_when_missing_range_budget_is_exhausted() { - let mut assembler = ByteReassembly::<2>::new(64); + let mut assembler = StreamAssembler::<2>::new(64); assembler.insert(1, false, b"a").unwrap(); assembler.insert(3, false, b"b").unwrap(); let error = assembler.insert(5, false, b"c").unwrap_err(); - assert_eq!(error, ByteReassemblyError::TooManyMissingRanges); + assert_eq!(error, StreamAssemblerError::TooManyMissingRanges); } #[test] fn insert_can_fill_multiple_gaps_without_rebuilding_state() { - let mut assembler = ByteReassembly::<8>::new(64); + let mut assembler = StreamAssembler::<8>::new(64); assembler.insert(0, false, b"ab").unwrap(); assembler.insert(4, false, b"ef").unwrap(); diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index c013f40d..c312b9dd 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -9,7 +9,7 @@ use ql_wire::{ StreamId, XID, }; -use super::{reassembly::ByteReassembly, SessionState, SESSION_RECORD_TRACKED_WINDOW}; +use super::{reassembly::StreamAssembler, SessionState, SESSION_RECORD_TRACKED_WINDOW}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamParity { @@ -123,7 +123,7 @@ pub struct StreamState { pub peer_max_offset: u64, pub outbound_state: OutboundState, pub inbound_state: InboundState, - pub recv: ByteReassembly, + pub recv: StreamAssembler, pub advertised_max_offset: u64, pub pending_window: bool, } @@ -140,7 +140,7 @@ impl StreamState { peer_max_offset: receive_buffer_size as u64, outbound_state: OutboundState::Open, inbound_state: InboundState::Open, - recv: ByteReassembly::new(receive_buffer_size), + recv: StreamAssembler::new(receive_buffer_size), advertised_max_offset: receive_buffer_size as u64, pending_window: false, } @@ -170,7 +170,7 @@ impl StreamState { pub fn reset_recv(&mut self) { self.recv = - ByteReassembly::with_start_offset(self.recv.start_offset(), self.recv.max_buffered()); + StreamAssembler::with_start_offset(self.recv.start_offset(), self.recv.max_buffered()); } } From 185a45ed018ea123174d47a60fa4095745ab7734 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 01:18:47 -0400 Subject: [PATCH 041/304] ql-fsm: outbound record --- ql-fsm/src/session/mod.rs | 133 ++++++++++++++++++------------------ ql-fsm/src/session/state.rs | 12 +--- 2 files changed, 68 insertions(+), 77 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 372d92b7..f157358f 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -15,9 +15,8 @@ use ql_wire::{ use self::{ reassembly::{StreamAssemblerError, StreamReadIter}, state::{ - AckState, InboundState, OutboundState, PendingRecord, ReceiveInsertOutcome, - ReceivedRecords, ReliableFrame, SentRecord, SessionFsmState, StreamParity, StreamRole, - StreamState, + AckState, InboundState, OutboundRecord, OutboundState, ReceiveInsertOutcome, + ReceivedRecords, ReliableFrame, SessionFsmState, StreamParity, StreamRole, StreamState, }, }; @@ -99,8 +98,7 @@ impl SessionFsm { next_stream_ordinal: 0, next_record_seq: RecordSeq(0), next_write_id: 0, - issued_records: Default::default(), - sent_records: Default::default(), + outbound_records: Default::default(), received_records: ReceivedRecords::default(), ack_state: AckState::Idle, pending_control: Default::default(), @@ -301,24 +299,29 @@ impl SessionFsm { pub fn confirm_write(&mut self, now: Instant, write_id: u64) { self.state.now = now; - let Some(pending) = self.state.issued_records.shift_remove(&write_id) else { + let Some(record) = self.state.outbound_records.get_mut(&write_id) else { return; }; + if record.sent_at.is_some() { + return; + } self.state.last_activity_at = now; - self.state.sent_records.insert( - pending.seq.0, - SentRecord { - pending, - sent_at: now, - }, - ); + record.sent_at = Some(now); } pub fn reject_write(&mut self, write_id: u64) { - let Some(pending) = self.state.issued_records.shift_remove(&write_id) else { + if self + .state + .outbound_records + .get(&write_id) + .is_some_and(|record| record.sent_at.is_some()) + { + return; + } + let Some(record) = self.state.outbound_records.shift_remove(&write_id) else { return; }; - self.restore_pending_record(pending); + self.restore_outbound_record(record); } pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { @@ -356,9 +359,13 @@ impl SessionFsm { }; let retransmit_deadline = self .state - .sent_records + .outbound_records .values() - .map(|record| record.sent_at + self.config.retransmit_timeout) + .filter_map(|record| { + record + .sent_at + .map(|sent_at| sent_at + self.config.retransmit_timeout) + }) .min(); let keepalive_deadline = (self.state.session_state == SessionState::Open && !self.config.keepalive_interval.is_zero() @@ -395,7 +402,7 @@ impl SessionFsm { let built = self.build_next_record()?; let write_id = self.state.next_write_id; self.state.next_write_id = self.state.next_write_id.wrapping_add(1); - self.state.issued_records.insert(write_id, built.pending); + self.state.outbound_records.insert(write_id, built.outbound); Some((write_id, built.record)) } @@ -405,12 +412,13 @@ impl SessionFsm { seq, frames: Vec::new(), }; - let mut pending = PendingRecord { + let mut outbound = OutboundRecord { seq, reliable: Vec::new(), ack_included: false, ping_included: false, window_updates: Vec::new(), + sent_at: None, }; let mut remaining = self.config.record_size.saturating_sub(8); @@ -418,7 +426,7 @@ impl SessionFsm { if let Some(ack) = self.state.received_records.ack() { let frame = SessionFrame::Ack(ack); if self.push_frame(&mut record, &mut remaining, frame, true) { - pending.ack_included = true; + outbound.ack_included = true; self.state.ack_state = AckState::Idle; } } @@ -431,7 +439,7 @@ impl SessionFsm { self.state.pending_control.close = Some(close); break; } - pending.reliable.push(ReliableFrame::Close(close)); + outbound.reliable.push(ReliableFrame::Close(close)); } while let Some(close) = @@ -442,12 +450,12 @@ impl SessionFsm { self.restore_stream_close(close); break; } - pending.reliable.push(ReliableFrame::StreamClose(close)); + outbound.reliable.push(ReliableFrame::StreamClose(close)); } if let Some(ping) = self.take_pending_ping(remaining, record.frames.is_empty()) { if self.push_frame(&mut record, &mut remaining, ping, true) { - pending.ping_included = true; + outbound.ping_included = true; } else { self.state.pending_control.ping = true; } @@ -469,7 +477,7 @@ impl SessionFsm { } break; } - pending.window_updates.push((stream_id, maximum_offset)); + outbound.window_updates.push((stream_id, maximum_offset)); } while let Some(frame) = @@ -484,7 +492,7 @@ impl SessionFsm { self.restore_stream_data(frame); break; } - pending.reliable.push(ReliableFrame::StreamData(frame)); + outbound.reliable.push(ReliableFrame::StreamData(frame)); } while let Some(frame) = @@ -499,7 +507,7 @@ impl SessionFsm { self.restore_stream_data(frame); break; } - pending.reliable.push(ReliableFrame::StreamData(frame)); + outbound.reliable.push(ReliableFrame::StreamData(frame)); } if record.frames.is_empty() { @@ -507,7 +515,7 @@ impl SessionFsm { } self.state.next_record_seq = RecordSeq(self.state.next_record_seq.0.saturating_add(1)); - Some(BuiltRecord { record, pending }) + Some(BuiltRecord { record, outbound }) } fn take_pending_session_close( @@ -723,17 +731,21 @@ impl SessionFsm { fn process_record_ack(&mut self, ack: RecordAck, emit: &mut impl FnMut(SessionEvent)) { let acked: Vec = self .state - .sent_records - .keys() - .copied() - .filter(|seq| Self::ack_covers(&ack, RecordSeq(*seq))) + .outbound_records + .iter() + .filter_map(|(write_id, record)| { + record + .sent_at + .filter(|_| Self::ack_covers(&ack, record.seq)) + .map(|_| *write_id) + }) .collect(); - for seq in acked { - let Some(sent) = self.state.sent_records.shift_remove(&seq) else { + for write_id in acked { + let Some(record) = self.state.outbound_records.shift_remove(&write_id) else { continue; }; - for frame in sent.pending.reliable { + for frame in record.reliable { self.acknowledge_reliable_frame(frame, emit); } } @@ -770,36 +782,39 @@ impl SessionFsm { fn collect_timeouts(&mut self) { let expired: Vec = self .state - .sent_records + .outbound_records .iter() - .filter_map(|(seq, record)| { - (record.sent_at + self.config.retransmit_timeout <= self.state.now).then_some(*seq) + .filter_map(|(write_id, record)| { + record + .sent_at + .filter(|sent_at| *sent_at + self.config.retransmit_timeout <= self.state.now) + .map(|_| *write_id) }) .collect(); - for seq in expired { - let Some(sent) = self.state.sent_records.shift_remove(&seq) else { + for write_id in expired { + let Some(record) = self.state.outbound_records.shift_remove(&write_id) else { continue; }; - self.restore_pending_record(sent.pending); + self.restore_outbound_record(record); } } - fn restore_pending_record(&mut self, pending: PendingRecord) { - if pending.ack_included { + fn restore_outbound_record(&mut self, record: OutboundRecord) { + if record.ack_included { self.schedule_ack(true); } - if pending.ping_included { + if record.ping_included { self.state.pending_control.ping = true; } - for (stream_id, maximum_offset) in pending.window_updates { + for (stream_id, maximum_offset) in record.window_updates { if let Some(stream) = self.state.streams.get_mut(&stream_id) { if stream.recv_limit() >= maximum_offset { stream.pending_window = true; } } } - for frame in pending.reliable { + for frame in record.reliable { self.requeue_reliable_frame(frame); } } @@ -1005,8 +1020,7 @@ impl SessionFsm { } self.state.session_state = SessionState::Closed; - self.state.issued_records.clear(); - self.state.sent_records.clear(); + self.state.outbound_records.clear(); self.clear_streams(); self.state.pending_control = Default::default(); emit(SessionEvent::SessionClosed(close)); @@ -1121,7 +1135,7 @@ impl SessionFsm { } fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { - let issued_refs_stream = self.state.issued_records.values().any(|record| { + let outbound_refs_stream = self.state.outbound_records.values().any(|record| { record.window_updates.iter().any(|(id, _)| *id == stream_id) || record.reliable.iter().any(|frame| match frame { ReliableFrame::StreamData(frame) => frame.stream_id == stream_id, @@ -1129,23 +1143,7 @@ impl SessionFsm { ReliableFrame::Close(_) => false, }) }); - if issued_refs_stream { - return false; - } - - let sent_refs_stream = self.state.sent_records.values().any(|record| { - record - .pending - .window_updates - .iter() - .any(|(id, _)| *id == stream_id) - || record.pending.reliable.iter().any(|frame| match frame { - ReliableFrame::StreamData(frame) => frame.stream_id == stream_id, - ReliableFrame::StreamClose(frame) => frame.stream_id == stream_id, - ReliableFrame::Close(_) => false, - }) - }); - if sent_refs_stream { + if outbound_refs_stream { return false; } @@ -1201,8 +1199,7 @@ impl SessionFsm { } self.state.session_state = SessionState::Closed; - self.state.issued_records.clear(); - self.state.sent_records.clear(); + self.state.outbound_records.clear(); self.state.pending_control = Default::default(); self.state.pending_control.close = Some(close.clone()); self.clear_streams(); @@ -1217,5 +1214,5 @@ impl SessionFsm { struct BuiltRecord { record: SessionRecord, - pending: PendingRecord, + outbound: OutboundRecord, } diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index c312b9dd..21e0892b 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -182,18 +182,13 @@ pub enum ReliableFrame { } #[derive(Debug, Clone)] -pub struct PendingRecord { +pub struct OutboundRecord { pub seq: RecordSeq, pub reliable: Vec, pub ack_included: bool, pub ping_included: bool, pub window_updates: Vec<(StreamId, u64)>, -} - -#[derive(Debug, Clone)] -pub struct SentRecord { - pub pending: PendingRecord, - pub sent_at: Instant, + pub sent_at: Option, } #[derive(Debug, Clone, Default)] @@ -286,8 +281,7 @@ pub struct SessionFsmState { pub next_stream_ordinal: u32, pub next_record_seq: RecordSeq, pub next_write_id: u64, - pub issued_records: IndexMap, - pub sent_records: IndexMap, + pub outbound_records: IndexMap, pub received_records: ReceivedRecords, pub ack_state: AckState, pub pending_control: PendingSessionControl, From 595fc19e49e1b9a4a88d48d463980849dc1ee255 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 10:35:48 -0400 Subject: [PATCH 042/304] ql-wire: bytes and ref --- ql-wire/src/bytes.rs | 112 ++++++++++++++++++++++++++++++ ql-wire/src/lib.rs | 2 + ql-wire/src/ref.rs | 162 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 276 insertions(+) create mode 100644 ql-wire/src/bytes.rs create mode 100644 ql-wire/src/ref.rs diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs new file mode 100644 index 00000000..fdd98008 --- /dev/null +++ b/ql-wire/src/bytes.rs @@ -0,0 +1,112 @@ +use core::ops::{Deref, DerefMut}; + +/// A mutable or immutable reference to bytes. +/// +/// # Safety +/// +/// Implementations must provide stable dereferences. Given some `b: B`, repeated +/// calls to `Deref::deref(&b)` must always produce a byte slice with the same +/// address and length for as long as `b` is alive. If `B: ByteSliceMut`, +/// repeated calls to `DerefMut::deref_mut(&mut b)` must provide the same +/// guarantee. +pub unsafe trait ByteSlice: Deref + Sized {} + +/// A mutable reference to bytes. +pub trait ByteSliceMut: ByteSlice + DerefMut {} + +impl ByteSliceMut for B where B: ByteSlice + DerefMut {} + +/// A [`ByteSlice`] that can be split in two. +/// +/// # Safety +/// +/// Implementations must guarantee that `split_at` and `split_at_unchecked` +/// correctly split the underlying bytes. If `self.deref()` yields a slice with +/// address `addr` and length `len`, then splitting at `mid <= len` must return +/// `(left, right)` such that: +/// - `left` starts at `addr` and has length `mid` +/// - `right` starts at `addr + mid` and has length `len - mid` +pub unsafe trait SplitByteSlice: ByteSlice { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + // SAFETY: We just proved that `mid` is in bounds. + unsafe { Ok(self.split_at_unchecked(mid)) } + } else { + Err(self) + } + } + + /// Splits the underlying bytes at `mid`. + /// + /// # Safety + /// + /// `mid` must be less than or equal to the underlying slice length. + unsafe fn split_at_unchecked(self, mid: usize) -> (Self, Self); +} + +/// A shorthand for [`SplitByteSlice`] and [`ByteSliceMut`]. +pub trait SplitByteSliceMut: SplitByteSlice + ByteSliceMut {} + +impl SplitByteSliceMut for B where B: SplitByteSlice + ByteSliceMut {} + +// SAFETY: `&[u8]` dereferences to the same slice for the lifetime of the +// reference. +unsafe impl ByteSlice for &[u8] {} + +// SAFETY: `&mut [u8]` dereferences to the same slice for the lifetime of the +// reference. +unsafe impl ByteSlice for &mut [u8] {} + +// SAFETY: These methods delegate to the standard library slice split methods, +// which return the exact left and right sub-slices at `mid`. +unsafe impl SplitByteSlice for &[u8] { + #[inline] + unsafe fn split_at_unchecked(self, mid: usize) -> (Self, Self) { + <[u8]>::split_at(self, mid) + } +} + +// SAFETY: These methods delegate to the standard library slice split methods, +// which return the exact left and right sub-slices at `mid`. +unsafe impl SplitByteSlice for &mut [u8] { + #[inline] + unsafe fn split_at_unchecked(self, mid: usize) -> (Self, Self) { + <[u8]>::split_at_mut(self, mid) + } +} + +#[cfg(test)] +mod tests { + use super::{SplitByteSlice, SplitByteSliceMut}; + + #[test] + fn shared_slice_split_at() { + let bytes: &[u8] = b"abcdef"; + let (left, right) = SplitByteSlice::split_at(bytes, 2).unwrap(); + assert_eq!(left, b"ab"); + assert_eq!(right, b"cdef"); + } + + #[test] + fn mutable_slice_split_at() { + let mut bytes = *b"abcdef"; + let (left, right) = SplitByteSlice::split_at(&mut bytes[..], 2).unwrap(); + assert_eq!(left, b"ab"); + assert_eq!(right, b"cdef"); + } + + #[test] + fn mutable_split_trait_is_implemented() { + fn assert_split_mut(_value: T) {} + + let mut bytes = [0u8; 4]; + assert_split_mut(&mut bytes[..]); + } + + #[test] + fn split_at_rejects_out_of_bounds_index() { + let bytes: &[u8] = b"abcdef"; + assert!(SplitByteSlice::split_at(bytes, 7).is_err()); + } +} diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 78bd38ff..2c7151f2 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -7,6 +7,7 @@ pub type Ref<'a, T> = zerocopy::Ref<&'a [u8], T>; pub type RefMut<'a, T> = zerocopy::Ref<&'a mut [u8], T>; +mod bytes; mod codec; mod control; mod encrypted; @@ -22,6 +23,7 @@ mod record; mod unpair; mod xid; +pub use bytes::*; pub use control::*; pub use encrypted::*; pub use encrypted_message::*; diff --git a/ql-wire/src/ref.rs b/ql-wire/src/ref.rs new file mode 100644 index 00000000..45377642 --- /dev/null +++ b/ql-wire/src/ref.rs @@ -0,0 +1,162 @@ +use core::{ + fmt, + marker::PhantomData, + ops::{Deref, DerefMut}, +}; + +use crate::{ByteSlice, ByteSliceMut}; + +/// Typed bytes backed by a mutable or immutable byte slice. +/// +/// Unlike `zerocopy::Ref`, this type does not perform any size, alignment, or +/// layout validation for `T`. `T` is only a marker carried alongside the bytes. +pub struct Ref { + bytes: B, + _marker: PhantomData<*const T>, +} + +impl Ref { + pub const fn new(bytes: B) -> Self { + Self { + bytes, + _marker: PhantomData, + } + } + + pub fn into_bytes(self) -> B { + self.bytes + } + + pub fn retag(self) -> Ref { + Ref::new(self.bytes) + } +} + +impl Ref { + pub fn bytes(&self) -> &[u8] { + self.bytes.deref() + } + + pub fn len(&self) -> usize { + self.bytes.len() + } + + pub fn is_empty(&self) -> bool { + self.bytes.is_empty() + } + + pub fn reborrow(&self) -> Ref<&[u8], T> { + Ref::new(self.bytes()) + } +} + +impl Ref { + pub fn bytes_mut(&mut self) -> &mut [u8] { + self.bytes.deref_mut() + } + + pub fn reborrow_mut(&mut self) -> Ref<&mut [u8], T> { + Ref::new(self.bytes_mut()) + } +} + +impl Deref for Ref { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.bytes() + } +} + +impl DerefMut for Ref { + fn deref_mut(&mut self) -> &mut Self::Target { + self.bytes_mut() + } +} + +impl Clone for Ref { + fn clone(&self) -> Self { + Self::new(self.bytes.clone()) + } +} + +impl Copy for Ref {} + +impl AsRef<[u8]> for Ref { + fn as_ref(&self) -> &[u8] { + self.bytes() + } +} + +impl AsMut<[u8]> for Ref { + fn as_mut(&mut self) -> &mut [u8] { + self.bytes_mut() + } +} + +impl fmt::Debug for Ref { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Ref") + .field("type", &core::any::type_name::()) + .field("bytes", &self.bytes()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::Ref; + + struct Message; + struct OtherMessage; + + #[test] + fn shared_ref_exposes_bytes() { + let bytes = b"hello"; + let reference = Ref::<_, Message>::new(&bytes[..]); + + assert_eq!(reference.bytes(), b"hello"); + assert_eq!(reference.len(), 5); + assert!(!reference.is_empty()); + } + + #[test] + fn mutable_ref_exposes_mutable_bytes() { + let mut bytes = *b"hello"; + let mut reference = Ref::<_, Message>::new(&mut bytes[..]); + + reference.bytes_mut()[0] = b'j'; + assert_eq!(&bytes, b"jello"); + } + + #[test] + fn ref_can_be_retagged() { + let bytes = b"hello"; + let reference = Ref::<_, Message>::new(&bytes[..]); + let other = reference.retag::(); + + assert_eq!(other.bytes(), b"hello"); + } + + #[test] + fn ref_can_be_reborrowed() { + let bytes = b"hello"; + let reference = Ref::<_, Message>::new(&bytes[..]); + let borrowed = reference.reborrow(); + + assert_eq!(borrowed.bytes(), b"hello"); + } + + #[test] + fn mutable_ref_can_be_reborrowed_mutably() { + let mut bytes = *b"hello"; + let mut reference = Ref::<_, Message>::new(&mut bytes[..]); + + { + let mut borrowed = reference.reborrow_mut(); + borrowed[1] = b'a'; + } + + assert_eq!(&bytes, b"hallo"); + } +} From 95291a8ca9609cb3585e00dfe7e605e1ac17308a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 11:44:53 -0400 Subject: [PATCH 043/304] ql-wire: remove zerocopy encrypted message: decrypt in place consumes message ql-wire: simplify byteslice --- Cargo.lock | 1 - ql-wire/Cargo.toml | 11 +- ql-wire/src/bytes.rs | 83 ++----- ql-wire/src/codec.rs | 129 +++++++---- ql-wire/src/control.rs | 50 +++-- ql-wire/src/encrypted/ack.rs | 121 +++------- ql-wire/src/encrypted/close.rs | 42 +--- ql-wire/src/encrypted/mod.rs | 165 ++++++-------- ql-wire/src/encrypted/stream_close.rs | 104 ++++----- ql-wire/src/encrypted/stream_data.rs | 90 +++----- ql-wire/src/encrypted/stream_window.rs | 49 ++--- ql-wire/src/encrypted_message.rs | 132 ++++++----- ql-wire/src/handshake/crypto.rs | 139 ++++++------ ql-wire/src/handshake/mod.rs | 292 +++++++------------------ ql-wire/src/header.rs | 39 ++-- ql-wire/src/lib.rs | 8 +- ql-wire/src/pair/crypto.rs | 22 +- ql-wire/src/pair/mod.rs | 120 ++++------ ql-wire/src/record.rs | 137 +++++++----- ql-wire/src/ref.rs | 162 -------------- ql-wire/src/tests.rs | 68 +++--- ql-wire/src/unpair/crypto.rs | 18 +- ql-wire/src/unpair/mod.rs | 43 +--- 23 files changed, 763 insertions(+), 1262 deletions(-) delete mode 100644 ql-wire/src/ref.rs diff --git a/Cargo.lock b/Cargo.lock index c89a427b..1206f164 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2194,7 +2194,6 @@ dependencies = [ "libcrux-ml-kem", "sha2", "thiserror", - "zerocopy", ] [[package]] diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml index cd579826..4147ec23 100644 --- a/ql-wire/Cargo.toml +++ b/ql-wire/Cargo.toml @@ -6,10 +6,15 @@ description = "Quantum Link wire format types and crypto helpers" license = "Proprietary" [dependencies] -libcrux-ml-dsa = { version = "0.0.7", default-features = false, features = ["std", "mldsa87"] } -libcrux-ml-kem = { version = "0.0.7", default-features = false, features = ["std", "mlkem1024"] } +libcrux-ml-dsa = { version = "0.0.7", default-features = false, features = [ + "std", + "mldsa87", +] } +libcrux-ml-kem = { version = "0.0.7", default-features = false, features = [ + "std", + "mlkem1024", +] } thiserror = { version = "2" } -zerocopy = { version = "0.8", features = ["derive"] } [dev-dependencies] libcrux-aesgcm = "0.0.7" diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index fdd98008..19795fab 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -1,89 +1,48 @@ use core::ops::{Deref, DerefMut}; -/// A mutable or immutable reference to bytes. -/// -/// # Safety -/// -/// Implementations must provide stable dereferences. Given some `b: B`, repeated -/// calls to `Deref::deref(&b)` must always produce a byte slice with the same -/// address and length for as long as `b` is alive. If `B: ByteSliceMut`, -/// repeated calls to `DerefMut::deref_mut(&mut b)` must provide the same -/// guarantee. -pub unsafe trait ByteSlice: Deref + Sized {} +/// A mutable or immutable byte slice owner used by the wire parser. +pub trait ByteSlice: Deref + Sized { + /// Splits the current byte view at `mid`. + /// + /// Returns `Err(self)` when `mid` is out of bounds. + fn split_at(self, mid: usize) -> Result<(Self, Self), Self>; +} /// A mutable reference to bytes. pub trait ByteSliceMut: ByteSlice + DerefMut {} impl ByteSliceMut for B where B: ByteSlice + DerefMut {} -/// A [`ByteSlice`] that can be split in two. -/// -/// # Safety -/// -/// Implementations must guarantee that `split_at` and `split_at_unchecked` -/// correctly split the underlying bytes. If `self.deref()` yields a slice with -/// address `addr` and length `len`, then splitting at `mid <= len` must return -/// `(left, right)` such that: -/// - `left` starts at `addr` and has length `mid` -/// - `right` starts at `addr + mid` and has length `len - mid` -pub unsafe trait SplitByteSlice: ByteSlice { +impl ByteSlice for &[u8] { #[inline] fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { if mid <= self.len() { - // SAFETY: We just proved that `mid` is in bounds. - unsafe { Ok(self.split_at_unchecked(mid)) } + Ok(<[u8]>::split_at(self, mid)) } else { Err(self) } } - - /// Splits the underlying bytes at `mid`. - /// - /// # Safety - /// - /// `mid` must be less than or equal to the underlying slice length. - unsafe fn split_at_unchecked(self, mid: usize) -> (Self, Self); -} - -/// A shorthand for [`SplitByteSlice`] and [`ByteSliceMut`]. -pub trait SplitByteSliceMut: SplitByteSlice + ByteSliceMut {} - -impl SplitByteSliceMut for B where B: SplitByteSlice + ByteSliceMut {} - -// SAFETY: `&[u8]` dereferences to the same slice for the lifetime of the -// reference. -unsafe impl ByteSlice for &[u8] {} - -// SAFETY: `&mut [u8]` dereferences to the same slice for the lifetime of the -// reference. -unsafe impl ByteSlice for &mut [u8] {} - -// SAFETY: These methods delegate to the standard library slice split methods, -// which return the exact left and right sub-slices at `mid`. -unsafe impl SplitByteSlice for &[u8] { - #[inline] - unsafe fn split_at_unchecked(self, mid: usize) -> (Self, Self) { - <[u8]>::split_at(self, mid) - } } -// SAFETY: These methods delegate to the standard library slice split methods, -// which return the exact left and right sub-slices at `mid`. -unsafe impl SplitByteSlice for &mut [u8] { +impl ByteSlice for &mut [u8] { #[inline] - unsafe fn split_at_unchecked(self, mid: usize) -> (Self, Self) { - <[u8]>::split_at_mut(self, mid) + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok(<[u8]>::split_at_mut(self, mid)) + } else { + Err(self) + } } } #[cfg(test)] mod tests { - use super::{SplitByteSlice, SplitByteSliceMut}; + use super::{ByteSlice, ByteSliceMut}; #[test] fn shared_slice_split_at() { let bytes: &[u8] = b"abcdef"; - let (left, right) = SplitByteSlice::split_at(bytes, 2).unwrap(); + let (left, right) = ByteSlice::split_at(bytes, 2).unwrap(); assert_eq!(left, b"ab"); assert_eq!(right, b"cdef"); } @@ -91,14 +50,14 @@ mod tests { #[test] fn mutable_slice_split_at() { let mut bytes = *b"abcdef"; - let (left, right) = SplitByteSlice::split_at(&mut bytes[..], 2).unwrap(); + let (left, right) = ByteSlice::split_at(&mut bytes[..], 2).unwrap(); assert_eq!(left, b"ab"); assert_eq!(right, b"cdef"); } #[test] fn mutable_split_trait_is_implemented() { - fn assert_split_mut(_value: T) {} + fn assert_split_mut(_value: T) {} let mut bytes = [0u8; 4]; assert_split_mut(&mut bytes[..]); @@ -107,6 +66,6 @@ mod tests { #[test] fn split_at_rejects_out_of_bounds_index() { let bytes: &[u8] = b"abcdef"; - assert!(SplitByteSlice::split_at(bytes, 7).is_err()); + assert!(ByteSlice::split_at(bytes, 7).is_err()); } } diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index 5e3e16fd..e5e0ebdd 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -1,51 +1,100 @@ -use zerocopy::{ - byte_slice::{ByteSlice, SplitByteSlice}, - byteorder::little_endian, - FromBytes, Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, -}; - -use crate::{QlHeader, WireError}; - -pub type U16Le = little_endian::U16; -pub type U32Le = little_endian::U32; -pub type U64Le = little_endian::U64; - -pub fn push_value(out: &mut Vec, value: &T) -where - T: IntoBytes + Immutable + ?Sized, -{ - out.extend_from_slice(value.as_bytes()); +use crate::{ByteSlice, QlHeader, WireError}; + +pub fn push_u8(out: &mut Vec, value: u8) { + out.push(value); +} + +pub fn push_u16(out: &mut Vec, value: u16) { + out.extend_from_slice(&value.to_le_bytes()); +} + +pub fn push_u32(out: &mut Vec, value: u32) { + out.extend_from_slice(&value.to_le_bytes()); } -pub fn read_exact(bytes: &[u8]) -> Result -where - T: FromBytes + Copy, -{ - T::read_from_bytes(bytes).map_err(|_| WireError::InvalidPayload) +pub fn push_u64(out: &mut Vec, value: u64) { + out.extend_from_slice(&value.to_le_bytes()); } -pub fn read_byte(byte: u8) -> Result -where - T: TryFromBytes + Copy, -{ - T::try_read_from_bytes(core::slice::from_ref(&byte)).map_err(|_| WireError::InvalidPayload) +pub fn push_bytes(out: &mut Vec, bytes: &[u8]) { + out.extend_from_slice(bytes); } -pub fn read_prefix(bytes: B) -> Result<(T, B), WireError> -where - B: SplitByteSlice, - T: FromBytes + KnownLayout + Immutable + Copy, -{ - let (value, rest) = Ref::<_, T>::from_prefix(bytes).map_err(|_| WireError::InvalidPayload)?; - Ok((*value, rest)) +pub struct Reader { + remaining: Option, } -pub fn parse(bytes: B) -> Result, WireError> -where - B: ByteSlice, - T: KnownLayout + Immutable + ?Sized, -{ - Ref::<_, T>::from_bytes(bytes).map_err(|_| WireError::InvalidPayload) +impl Reader { + pub fn new(bytes: B) -> Self { + Self { + remaining: Some(bytes), + } + } + + pub fn is_empty(&self) -> bool { + self.remaining.as_ref().unwrap().is_empty() + } + + pub fn remaining(&self) -> usize { + self.remaining.as_ref().unwrap().len() + } + + pub fn take_bytes(&mut self, len: usize) -> Result { + let remaining = self.remaining.take().unwrap(); + match remaining.split_at(len) { + Ok((head, tail)) => { + self.remaining = Some(tail); + Ok(head) + } + Err(remaining) => { + self.remaining = Some(remaining); + Err(WireError::InvalidPayload) + } + } + } + + pub fn take_rest(mut self) -> B { + self.remaining.take().unwrap() + } + + pub fn take_array(&mut self) -> Result<[u8; N], WireError> { + let bytes = self.take_bytes(N)?; + let mut out = [0u8; N]; + out.copy_from_slice(&bytes); + Ok(out) + } + + pub fn take_u8(&mut self) -> Result { + Ok(self.take_bytes(1)?[0]) + } + + pub fn take_u16(&mut self) -> Result { + Ok(u16::from_le_bytes(self.take_array()?)) + } + + pub fn take_u32(&mut self) -> Result { + Ok(u32::from_le_bytes(self.take_array()?)) + } + + pub fn take_u64(&mut self) -> Result { + Ok(u64::from_le_bytes(self.take_array()?)) + } + + pub fn take_bool(&mut self) -> Result { + match self.take_u8()? { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(WireError::InvalidPayload), + } + } + + pub fn finish(self) -> Result<(), WireError> { + if self.is_empty() { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } } pub fn append_field(out: &mut Vec, label: &[u8], value: &[u8]) { diff --git a/ql-wire/src/control.rs b/ql-wire/src/control.rs index 47f1c74a..17dbfd67 100644 --- a/ql-wire/src/control.rs +++ b/ql-wire/src/control.rs @@ -1,9 +1,4 @@ -use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; - -use crate::{ - codec::{U32Le, U64Le}, - WireError, -}; +use crate::{codec, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] @@ -16,6 +11,8 @@ pub struct ControlMeta { } impl ControlMeta { + pub const ENCODED_LEN: usize = core::mem::size_of::() + core::mem::size_of::(); + pub fn ensure_not_expired(&self, now_seconds: u64) -> Result<(), WireError> { if now_seconds > self.valid_until { Err(WireError::Expired) @@ -24,24 +21,33 @@ impl ControlMeta { } } - pub fn to_wire(&self) -> ControlMetaWire { - ControlMetaWire { - control_id: U32Le::new(self.control_id.0), - valid_until: U64Le::new(self.valid_until), - } + pub fn encode_into(&self, out: &mut Vec) { + codec::push_u32(out, self.control_id.0); + codec::push_u64(out, self.valid_until); } - pub fn from_wire(meta: ControlMetaWire) -> Self { - Self { - control_id: ControlId(meta.control_id.get()), - valid_until: meta.valid_until.get(), - } + pub fn encode(&self) -> Vec { + let mut out = Vec::with_capacity(Self::ENCODED_LEN); + self.encode_into(&mut out); + out + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let meta = Self { + control_id: ControlId(reader.take_u32()?), + valid_until: reader.take_u64()?, + }; + reader.finish()?; + Ok(meta) } -} -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct ControlMetaWire { - pub control_id: U32Le, - pub valid_until: U64Le, + pub fn decode_from( + reader: &mut codec::Reader, + ) -> Result { + Ok(Self { + control_id: ControlId(reader.take_u32()?), + valid_until: reader.take_u64()?, + }) + } } diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 1a3f03dd..46d068fd 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -1,13 +1,6 @@ use std::mem::size_of; -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - -use crate::{ - codec::{parse, push_value, read_exact, U64Le}, - WireError, -}; +use crate::{codec, WireError}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct RecordAck { @@ -20,98 +13,46 @@ pub struct RecordAckRange { pub end: u64, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct RecordAckRangeWire { - pub start: U64Le, - pub end: U64Le, -} - -#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] -#[repr(C, packed)] -pub struct RecordAckQire { - pub ranges: [u8], -} - -pub struct RecordAckRangeIter<'a> { - remaining: &'a [u8], -} - impl RecordAck { - pub fn parse(bytes: B) -> Result, WireError> { - let wire = parse(bytes)?; - validate_ack_frame(&wire)?; - Ok(wire) - } + pub const RANGE_ENCODED_LEN: usize = size_of::() + size_of::(); - pub fn encoded_len(&self) -> usize { - self.ranges.len() * size_of::() - } - - pub fn from_wire(wire: &RecordAckQire) -> Result { - validate_ack_frame(wire)?; - Ok(Self { - ranges: wire.ranges().collect(), - }) - } - - pub fn encode_into(&self, out: &mut Vec) { - for range in &self.ranges { - push_value( - out, - &RecordAckRangeWire { - start: U64Le::new(range.start), - end: U64Le::new(range.end), - }, - ); - } - } -} - -impl RecordAckQire { - pub fn ranges(&self) -> RecordAckRangeIter<'_> { - RecordAckRangeIter { - remaining: &self.ranges, + pub fn decode(bytes: &[u8]) -> Result { + if bytes.is_empty() || bytes.len() % Self::RANGE_ENCODED_LEN != 0 { + return Err(WireError::InvalidPayload); } - } -} -impl Iterator for RecordAckRangeIter<'_> { - type Item = RecordAckRange; - - fn next(&mut self) -> Option { - if self.remaining.is_empty() { - return None; + let mut reader = codec::Reader::new(bytes); + let mut ranges = Vec::with_capacity(bytes.len() / Self::RANGE_ENCODED_LEN); + let mut previous_end = 0; + + while !reader.is_empty() { + let range = RecordAckRange { + start: reader.take_u64()?, + end: reader.take_u64()?, + }; + + if range.start >= range.end { + return Err(WireError::InvalidPayload); + } + if !ranges.is_empty() && range.start < previous_end { + return Err(WireError::InvalidPayload); + } + + previous_end = range.end; + ranges.push(range); } - let (head, tail) = self.remaining.split_at(size_of::()); - self.remaining = tail; - let wire: RecordAckRangeWire = - read_exact(head).expect("ack ranges are validated before iteration"); - Some(RecordAckRange { - start: wire.start.get(), - end: wire.end.get(), - }) + Ok(Self { ranges }) } -} -fn validate_ack_frame(wire: &RecordAckQire) -> Result<(), WireError> { - if wire.ranges.is_empty() || wire.ranges.len() % size_of::() != 0 { - return Err(WireError::InvalidPayload); + pub fn encoded_len(&self) -> usize { + self.ranges.len() * Self::RANGE_ENCODED_LEN } - let mut previous_end = 0; - let mut first = true; - for range in wire.ranges() { - if range.start >= range.end { - return Err(WireError::InvalidPayload); - } - if !first && range.start < previous_end { - return Err(WireError::InvalidPayload); + pub fn encode_into(&self, out: &mut Vec) { + for range in &self.ranges { + codec::push_u64(out, range.start); + codec::push_u64(out, range.end); } - first = false; - previous_end = range.end; } - - Ok(()) } diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index a4a1048f..4702566a 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -1,12 +1,8 @@ use std::mem::size_of; -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - use super::CloseCode; use crate::{ - codec::{parse, push_value, read_exact, U16Le}, + codec::{self, Reader}, WireError, }; @@ -16,40 +12,18 @@ pub struct SessionCloseBody { pub code: CloseCode, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct SessionCloseBodyWire { - pub code: U16Le, -} - impl SessionCloseBody { - pub const WIRE_SIZE: usize = size_of::(); - - pub fn parse(bytes: B) -> Result, WireError> { - if bytes.len() != Self::WIRE_SIZE { - return Err(WireError::InvalidPayload); - } - parse(bytes) - } - - pub fn from_wire(wire: &SessionCloseBodyWire) -> Self { - Self { - code: CloseCode(wire.code.get()), - } - } - - pub fn to_wire(&self) -> SessionCloseBodyWire { - SessionCloseBodyWire { - code: U16Le::new(self.code.0), - } - } + pub const WIRE_SIZE: usize = size_of::(); pub fn encode_into(&self, out: &mut Vec) { - push_value(out, &self.to_wire()); + codec::push_u16(out, self.code.0); } pub fn decode(bytes: &[u8]) -> Result { - let wire: SessionCloseBodyWire = read_exact(bytes)?; - Ok(Self::from_wire(&wire)) + let mut reader = Reader::new(bytes); + let code = reader.take_u16()?; + Ok(Self { + code: CloseCode(code), + }) } } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 1e87bac5..38b260cc 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,14 +1,9 @@ use std::mem::size_of; -use zerocopy::{ - byte_slice::{ByteSlice, ByteSliceMut}, - FromBytes, Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, Unaligned, -}; - use crate::{ - codec::{parse, read_byte}, - encrypted_message::{EncryptedMessage, EncryptedMessageWire}, - Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, SessionKey, WireError, + codec, + encrypted_message::EncryptedMessage, + QlCrypto, QlHeader, QlPayload, QlRecord, SessionKey, WireError, }; mod ack; @@ -35,31 +30,24 @@ pub struct RecordSeq(pub u64); #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecord { pub seq: RecordSeq, - pub frames: Vec, + pub frames: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum SessionFrame { +pub enum SessionFrame { Ping, Ack(RecordAck), - StreamData(StreamData), + StreamData(StreamData), StreamWindow(StreamWindow), - StreamClose(StreamClose), + StreamClose(StreamClose), Close(SessionCloseBody), } -pub enum SessionFrameRef<'a> { - Ping, - Ack(Ref<&'a [u8], RecordAckQire>), - StreamData(Ref<&'a [u8], StreamDataWire>), - StreamWindow(Ref<&'a [u8], StreamWindowWire>), - StreamClose(Ref<&'a [u8], StreamCloseWire>), - Close(Ref<&'a [u8], SessionCloseBodyWire>), -} +pub type SessionFrameVec = SessionFrame>; +pub type StreamDataVec = StreamData>; +pub type StreamCloseVec = StreamClose>; -#[derive( - Debug, Clone, Copy, PartialEq, Eq, TryFromBytes, KnownLayout, Immutable, IntoBytes, Unaligned, -)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub(crate) enum SessionFrameKind { Ping = 1, @@ -70,60 +58,60 @@ pub(crate) enum SessionFrameKind { Close = 6, } -#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] -#[repr(C, packed)] -pub struct SessionRecordWire { - pub seq: crate::codec::U64Le, - pub frames: [u8], -} - pub struct SessionFrameIter<'a> { remaining: &'a [u8], } +impl TryFrom for SessionFrameKind { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Ping), + 2 => Ok(Self::Ack), + 3 => Ok(Self::StreamData), + 4 => Ok(Self::StreamWindow), + 5 => Ok(Self::StreamClose), + 6 => Ok(Self::Close), + _ => Err(WireError::InvalidPayload), + } + } +} + impl SessionRecord { - pub fn parse(bytes: B) -> Result, WireError> { - parse(bytes) + pub fn parse(bytes: &[u8]) -> Result<(RecordSeq, SessionFrameIter<'_>), WireError> { + let mut reader = codec::Reader::new(bytes); + let seq = RecordSeq(reader.take_u64()?); + Ok(( + seq, + SessionFrameIter { + remaining: reader.take_rest(), + }, + )) } - pub fn from_wire(wire: &SessionRecordWire) -> Result { - let frames = wire - .frames() - .map(|frame| frame?.to_owned()) + pub fn decode(bytes: &[u8]) -> Result { + let (seq, frames) = Self::parse(bytes)?; + let frames = frames + .map(|frame| frame.map(SessionFrame::into_owned)) .collect::, _>>()?; Ok(Self { - seq: wire.seq(), + seq, frames, }) } pub fn encode(&self) -> Vec { let mut out = Vec::new(); - out.extend_from_slice(&self.seq.0.to_le_bytes()); + codec::push_u64(&mut out, self.seq.0); for frame in &self.frames { frame.encode_into(&mut out); } out } - - pub fn decode(bytes: &[u8]) -> Result { - Self::from_wire(&Self::parse(bytes)?) - } -} - -impl SessionRecordWire { - pub fn seq(&self) -> RecordSeq { - RecordSeq(self.seq.get()) - } - - pub fn frames(&self) -> SessionFrameIter<'_> { - SessionFrameIter { - remaining: &self.frames, - } - } } -impl SessionFrame { +impl> SessionFrame { pub fn encode_into(&self, out: &mut Vec) { match self { Self::Ping => out.push(SessionFrameKind::Ping as u8), @@ -152,23 +140,21 @@ impl SessionFrame { } } } -} -impl SessionFrameRef<'_> { - pub fn to_owned(&self) -> Result { - Ok(match self { + pub fn into_owned(self) -> SessionFrameVec { + match self { Self::Ping => SessionFrame::Ping, - Self::Ack(frame) => SessionFrame::Ack(RecordAck::from_wire(frame)?), - Self::StreamData(frame) => SessionFrame::StreamData(StreamData::from_wire(frame)?), - Self::StreamWindow(frame) => SessionFrame::StreamWindow(StreamWindow::from_wire(frame)), - Self::StreamClose(frame) => SessionFrame::StreamClose(StreamClose::from_wire(frame)?), - Self::Close(frame) => SessionFrame::Close(SessionCloseBody::from_wire(frame)), - }) + Self::Ack(frame) => SessionFrame::Ack(frame), + Self::StreamData(frame) => SessionFrame::StreamData(frame.into_owned()), + Self::StreamWindow(frame) => SessionFrame::StreamWindow(frame), + Self::StreamClose(frame) => SessionFrame::StreamClose(frame.into_owned()), + Self::Close(frame) => SessionFrame::Close(frame), + } } } impl<'a> Iterator for SessionFrameIter<'a> { - type Item = Result, WireError>; + type Item = Result, WireError>; fn next(&mut self) -> Option { if self.remaining.is_empty() { @@ -194,7 +180,7 @@ pub fn encrypt_record( header: QlHeader, session_key: &SessionKey, body: &SessionRecord, - nonce: Nonce, + nonce: crate::Nonce, ) -> QlRecord { let aad = header.aad(); let body = body.encode(); @@ -205,63 +191,52 @@ pub fn encrypt_record( } } -pub fn decrypt_record<'a, B: ByteSliceMut>( +pub fn decrypt_record>( crypto: &impl QlCrypto, header: &QlHeader, - encrypted: &'a mut Ref, + encrypted: EncryptedMessage, session_key: &SessionKey, -) -> Result, WireError> { +) -> Result { let aad = header.aad(); - let plaintext = EncryptedMessage::decrypt_in_place(encrypted, crypto, session_key, &aad)?; - SessionRecord::parse(plaintext) + encrypted.decrypt_in_place(crypto, session_key, &aad) } -fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrameRef<'_>, &[u8]), WireError> { +fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireError> { let (&kind, rest) = bytes.split_first().ok_or(WireError::InvalidPayload)?; - let kind: SessionFrameKind = read_byte(kind)?; - match kind { - SessionFrameKind::Ping => Ok((SessionFrameRef::Ping, rest)), + match SessionFrameKind::try_from(kind)? { + SessionFrameKind::Ping => Ok((SessionFrame::Ping, rest)), SessionFrameKind::Ack => { let (frame, rest) = split_variable_frame(rest)?; - Ok((SessionFrameRef::Ack(RecordAck::parse(frame)?), rest)) + Ok((SessionFrame::Ack(RecordAck::decode(frame)?), rest)) } SessionFrameKind::StreamData => { let (frame, rest) = split_variable_frame(rest)?; - Ok((SessionFrameRef::StreamData(StreamData::parse(frame)?), rest)) + Ok((SessionFrame::StreamData(StreamData::parse(frame)?), rest)) } SessionFrameKind::StreamWindow => { - let wire_size = StreamWindow::WIRE_SIZE; - if rest.len() < wire_size { + if rest.len() < StreamWindow::WIRE_SIZE { return Err(WireError::InvalidPayload); } - let (frame, rest) = rest.split_at(wire_size); - Ok(( - SessionFrameRef::StreamWindow(StreamWindow::parse(frame)?), - rest, - )) + let (frame, rest) = rest.split_at(StreamWindow::WIRE_SIZE); + Ok((SessionFrame::StreamWindow(StreamWindow::decode(frame)?), rest)) } SessionFrameKind::StreamClose => { let (frame, rest) = split_variable_frame(rest)?; - Ok(( - SessionFrameRef::StreamClose(StreamClose::parse(frame)?), - rest, - )) + Ok((SessionFrame::StreamClose(StreamClose::parse(frame)?), rest)) } SessionFrameKind::Close => { - let wire_size = SessionCloseBody::WIRE_SIZE; - if rest.len() < wire_size { + if rest.len() < SessionCloseBody::WIRE_SIZE { return Err(WireError::InvalidPayload); } - let (frame, rest) = rest.split_at(wire_size); - let frame = SessionCloseBody::parse(frame)?; - Ok((SessionFrameRef::Close(frame), rest)) + let (frame, rest) = rest.split_at(SessionCloseBody::WIRE_SIZE); + Ok((SessionFrame::Close(SessionCloseBody::decode(frame)?), rest)) } } } fn push_variable_len(out: &mut Vec, len: usize) { let len = u16::try_from(len).expect("session frame exceeds u16"); - out.extend_from_slice(&len.to_le_bytes()); + codec::push_u16(out, len); } fn split_variable_frame(bytes: &[u8]) -> Result<(&[u8], &[u8]), WireError> { diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 5e3bc0ac..3ae539a3 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -1,28 +1,9 @@ use std::mem::size_of; -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, - Unaligned, -}; - use super::StreamId; -use crate::{ - codec::{parse, read_byte, U16Le, U32Le}, - WireError, -}; - -/// aborts one or both directions of a stream with a close code. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamClose { - pub stream_id: StreamId, - pub target: CloseTarget, - pub code: CloseCode, - pub payload: Vec, -} +use crate::{codec, ByteSlice, WireError}; -#[derive( - Debug, Clone, Copy, PartialEq, Eq, TryFromBytes, KnownLayout, Immutable, IntoBytes, Unaligned, -)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum CloseTarget { Request = 1, @@ -36,6 +17,19 @@ impl CloseTarget { } } +impl TryFrom for CloseTarget { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Request), + 2 => Ok(Self::Response), + 3 => Ok(Self::Both), + _ => Err(WireError::InvalidPayload), + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] pub struct CloseCode(pub u16); @@ -53,44 +47,54 @@ impl CloseCode { pub const UNHANDLED: Self = Self(20); } -#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] -#[repr(C, packed)] -pub struct StreamCloseWire { - pub stream_id: U32Le, - pub target: u8, - pub code: U16Le, - pub payload: [u8], +/// aborts one or both directions of a stream with a close code. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamClose { + pub stream_id: StreamId, + pub target: CloseTarget, + pub code: CloseCode, + pub payload: B, } -impl StreamClose { - pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); +impl StreamClose { + pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); +} - pub fn parse(bytes: B) -> Result, WireError> { - if bytes.len() < Self::MIN_WIRE_SIZE { - return Err(WireError::InvalidPayload); - } - let wire: Ref = parse(bytes)?; - let _ = read_byte::(wire.target)?; - Ok(wire) +impl StreamClose { + pub fn parse(bytes: B) -> Result { + let mut reader = codec::Reader::new(bytes); + Ok(Self { + stream_id: StreamId(reader.take_u32()?), + target: CloseTarget::try_from(reader.take_u8()?)?, + code: CloseCode(reader.take_u16()?), + payload: reader.take_rest(), + }) } +} - pub fn encoded_len(&self) -> usize { - Self::MIN_WIRE_SIZE + self.payload.len() +impl StreamClose { + pub fn into_owned(self) -> StreamClose> + where + B: AsRef<[u8]>, + { + StreamClose { + stream_id: self.stream_id, + target: self.target, + code: self.code, + payload: self.payload.as_ref().to_vec(), + } } +} - pub fn from_wire(wire: &StreamCloseWire) -> Result { - Ok(Self { - stream_id: StreamId(wire.stream_id.get()), - target: read_byte(wire.target)?, - code: CloseCode(wire.code.get()), - payload: wire.payload.to_vec(), - }) +impl> StreamClose { + pub fn encoded_len(&self) -> usize { + Self::MIN_WIRE_SIZE + self.payload.as_ref().len() } pub fn encode_into(&self, out: &mut Vec) { - out.extend_from_slice(&self.stream_id.0.to_le_bytes()); - out.push(self.target.to_wire()); - out.extend_from_slice(&self.code.0.to_le_bytes()); - out.extend_from_slice(&self.payload); + codec::push_u32(out, self.stream_id.0); + codec::push_u8(out, self.target.to_wire()); + codec::push_u16(out, self.code.0); + codec::push_bytes(out, self.payload.as_ref()); } } diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 5507ab07..6e183d65 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -1,82 +1,56 @@ use std::mem::size_of; -use zerocopy::{byte_slice::ByteSlice, FromBytes, Immutable, KnownLayout, Ref, Unaligned}; - use super::StreamId; -use crate::{ - codec::{parse, U32Le, U64Le}, - WireError, -}; +use crate::{codec, ByteSlice, WireError}; /// carries bytes for a stream and may finish that sending direction. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamData { +pub struct StreamData { pub stream_id: StreamId, pub offset: u64, pub fin: bool, - pub bytes: Vec, + pub bytes: B, } -#[derive(FromBytes, KnownLayout, Immutable, Unaligned)] -#[repr(C, packed)] -pub struct StreamDataWire { - pub stream_id: U32Le, - pub offset: U64Le, - pub fin: u8, - pub bytes: [u8], +impl StreamData { + pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); } -impl StreamData { - pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); - - pub fn parse(bytes: B) -> Result, WireError> { - if bytes.len() < Self::MIN_WIRE_SIZE { - return Err(WireError::InvalidPayload); - } - let wire: Ref = parse(bytes)?; - let _ = wire.fin()?; - Ok(wire) - } - - pub fn encoded_len(&self) -> usize { - Self::MIN_WIRE_SIZE + self.bytes.len() - } - - pub fn from_wire(wire: &StreamDataWire) -> Result { +impl StreamData { + pub fn parse(bytes: B) -> Result { + let mut reader = codec::Reader::new(bytes); Ok(Self { - stream_id: wire.stream_id(), - offset: wire.offset(), - fin: wire.fin()?, - bytes: wire.bytes().to_vec(), + stream_id: StreamId(reader.take_u32()?), + offset: reader.take_u64()?, + fin: reader.take_bool()?, + bytes: reader.take_rest(), }) } - - pub fn encode_into(&self, out: &mut Vec) { - out.extend_from_slice(&self.stream_id.0.to_le_bytes()); - out.extend_from_slice(&self.offset.to_le_bytes()); - out.push(u8::from(self.fin)); - out.extend_from_slice(&self.bytes); - } } -impl StreamDataWire { - pub fn stream_id(&self) -> StreamId { - StreamId(self.stream_id.get()) - } - - pub fn offset(&self) -> u64 { - self.offset.get() +impl StreamData { + pub fn into_owned(self) -> StreamData> + where + B: AsRef<[u8]>, + { + StreamData { + stream_id: self.stream_id, + offset: self.offset, + fin: self.fin, + bytes: self.bytes.as_ref().to_vec(), + } } +} - pub fn fin(&self) -> Result { - match self.fin { - 0 => Ok(false), - 1 => Ok(true), - _ => Err(WireError::InvalidPayload), - } +impl> StreamData { + pub fn encoded_len(&self) -> usize { + Self::MIN_WIRE_SIZE + self.bytes.as_ref().len() } - pub fn bytes(&self) -> &[u8] { - &self.bytes + pub fn encode_into(&self, out: &mut Vec) { + codec::push_u32(out, self.stream_id.0); + codec::push_u64(out, self.offset); + codec::push_u8(out, u8::from(self.fin)); + codec::push_bytes(out, self.bytes.as_ref()); } } diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index 764b3ff7..d03f0d02 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -1,14 +1,7 @@ use std::mem::size_of; -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - use super::StreamId; -use crate::{ - codec::{parse, push_value, U32Le, U64Le}, - WireError, -}; +use crate::{codec, WireError}; /// advertises the highest byte offset the peer may send on a stream. #[derive(Debug, Clone, PartialEq, Eq)] @@ -17,37 +10,21 @@ pub struct StreamWindow { pub maximum_offset: u64, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct StreamWindowWire { - pub stream_id: U32Le, - pub maximum_offset: U64Le, -} - impl StreamWindow { - pub const WIRE_SIZE: usize = size_of::(); + pub const WIRE_SIZE: usize = size_of::() + size_of::(); - pub fn parse(bytes: B) -> Result, WireError> { - if bytes.len() != Self::WIRE_SIZE { - return Err(WireError::InvalidPayload); - } - parse(bytes) - } - - pub fn from_wire(wire: &StreamWindowWire) -> Self { - Self { - stream_id: StreamId(wire.stream_id.get()), - maximum_offset: wire.maximum_offset.get(), - } + pub fn encode_into(&self, out: &mut Vec) { + codec::push_u32(out, self.stream_id.0); + codec::push_u64(out, self.maximum_offset); } - pub fn encode_into(&self, out: &mut Vec) { - push_value( - out, - &StreamWindowWire { - stream_id: U32Le::new(self.stream_id.0), - maximum_offset: U64Le::new(self.maximum_offset), - }, - ); + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let window = Self { + stream_id: StreamId(reader.take_u32()?), + maximum_offset: reader.take_u64()?, + }; + reader.finish()?; + Ok(window) } } diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index 07f20aa1..a2473fd3 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -1,77 +1,52 @@ -use zerocopy::{ - byte_slice::{ByteSlice, ByteSliceMut}, - FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - use crate::{ - codec::{parse, push_value}, - Nonce, QlCrypto, SessionKey, WireError, + codec, ByteSlice, ENCRYPTED_MESSAGE_AUTH_SIZE, Nonce, QlCrypto, SessionKey, WireError, }; -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] -#[repr(C, packed)] -pub struct EncryptedMessageWire { - pub nonce: [u8; Nonce::SIZE], - pub auth: [u8; EncryptedMessage::AUTH_SIZE], - pub ciphertext: [u8], -} - #[derive(Debug, Clone, PartialEq, Eq)] -pub struct EncryptedMessage { +pub struct EncryptedMessage { pub nonce: Nonce, - pub auth: [u8; Self::AUTH_SIZE], - pub ciphertext: Vec, + pub auth: [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + pub ciphertext: B, } -impl EncryptedMessage { - pub const AUTH_SIZE: usize = 16; +impl EncryptedMessage { + pub const AUTH_SIZE: usize = ENCRYPTED_MESSAGE_AUTH_SIZE; + pub const HEADER_LEN: usize = Nonce::SIZE + Self::AUTH_SIZE; - pub fn parse(bytes: B) -> Result, WireError> { - parse(bytes) - } - - pub fn from_wire(wire: &EncryptedMessageWire) -> Self { - Self { - nonce: Nonce(wire.nonce), - auth: wire.auth, - ciphertext: wire.ciphertext.to_vec(), + pub fn into_owned(self) -> EncryptedMessage> + where + B: AsRef<[u8]>, + { + EncryptedMessage { + nonce: self.nonce, + auth: self.auth, + ciphertext: self.ciphertext.as_ref().to_vec(), } } +} - pub fn encode(&self) -> Vec { - let mut out = Vec::with_capacity(Nonce::SIZE + Self::AUTH_SIZE + self.ciphertext.len()); - self.encode_into(&mut out); - out - } - - pub fn decode(bytes: &[u8]) -> Result { - Ok(Self::from_wire(&Self::parse(bytes)?)) +impl EncryptedMessage { + pub fn parse(bytes: B) -> Result { + let mut reader = codec::Reader::new(bytes); + Ok(Self { + nonce: Nonce(reader.take_array()?), + auth: reader.take_array()?, + ciphertext: reader.take_rest(), + }) } +} +impl> EncryptedMessage { pub fn encode_into(&self, out: &mut Vec) { - push_value( - out, - &EncryptedMessageHeaderWire { - nonce: self.nonce.0, - auth: self.auth, - }, - ); - out.extend_from_slice(&self.ciphertext); + codec::push_bytes(out, &self.nonce.0); + codec::push_bytes(out, &self.auth); + codec::push_bytes(out, self.ciphertext.as_ref()); } - pub fn encrypt( - crypto: &impl QlCrypto, - key: &SessionKey, - mut plaintext: Vec, - aad: &[u8], - nonce: Nonce, - ) -> Self { - let auth = crypto.encrypt_with_aead(key, &nonce, aad, &mut plaintext); - Self { - nonce, - auth, - ciphertext: plaintext, - } + pub fn encode(&self) -> Vec { + let mut out = Vec::with_capacity(Self::HEADER_LEN + self.ciphertext.as_ref().len()); + self.encode_into(&mut out); + out } pub fn decrypt( @@ -80,31 +55,46 @@ impl EncryptedMessage { key: &SessionKey, aad: &[u8], ) -> Result, WireError> { - let mut plaintext = self.ciphertext.clone(); + let mut plaintext = self.ciphertext.as_ref().to_vec(); if !crypto.decrypt_with_aead(key, &self.nonce, aad, &mut plaintext, &self.auth) { return Err(WireError::DecryptFailed); } Ok(plaintext) } +} - pub fn decrypt_in_place<'a, B: ByteSliceMut>( - wire: &'a mut Ref, +impl> EncryptedMessage { + pub fn decrypt_in_place( + mut self, crypto: &impl QlCrypto, key: &SessionKey, aad: &[u8], - ) -> Result<&'a mut [u8], WireError> { - let nonce = Nonce(wire.nonce); - let auth = wire.auth; - if !crypto.decrypt_with_aead(key, &nonce, aad, &mut wire.ciphertext, &auth) { + ) -> Result { + let ciphertext = self.ciphertext.as_mut(); + if !crypto.decrypt_with_aead(key, &self.nonce, aad, ciphertext, &self.auth) { return Err(WireError::DecryptFailed); } - Ok(&mut wire.ciphertext) + Ok(self.ciphertext) } } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct EncryptedMessageHeaderWire { - pub nonce: [u8; Nonce::SIZE], - pub auth: [u8; EncryptedMessage::AUTH_SIZE], +impl EncryptedMessage> { + pub fn encrypt( + crypto: &impl QlCrypto, + key: &SessionKey, + mut plaintext: Vec, + aad: &[u8], + nonce: Nonce, + ) -> Self { + let auth = crypto.encrypt_with_aead(key, &nonce, aad, &mut plaintext); + Self { + nonce, + auth, + ciphertext: plaintext, + } + } + + pub fn decode(bytes: &[u8]) -> Result { + Ok(EncryptedMessage::parse(bytes)?.into_owned()) + } } diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs index b2d00756..80ab76fc 100644 --- a/ql-wire/src/handshake/crypto.rs +++ b/ql-wire/src/handshake/crypto.rs @@ -1,10 +1,8 @@ -use zerocopy::{byte_slice::ByteSliceMut, Ref}; - -use super::{Confirm, ConfirmView, Hello, HelloReply, HelloReplyView, HelloView, Ready, ReadyBody}; +use super::{Confirm, Hello, HelloReply, Ready, ReadyBody}; use crate::{ - pq::ML_KEM_SUITE_TAG, ControlMeta, EncryptedMessage, EncryptedMessageWire, MlDsaPublicKey, - MlDsaSignature, MlKemCiphertext, MlKemPublicKey, Nonce, QlCrypto, QlHeader, QlIdentity, - SessionKey, WireError, XID, + pq::ML_KEM_SUITE_TAG, ControlMeta, EncryptedMessage, MlDsaPublicKey, MlDsaSignature, + MlKemCiphertext, MlKemPublicKey, Nonce, QlCrypto, QlHeader, QlIdentity, SessionKey, + WireError, XID, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -47,20 +45,19 @@ pub fn verify_hello( initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, - hello: &impl HelloView, + hello: &Hello, now_seconds: u64, ) -> Result<(), WireError> { - let meta = hello.meta(); - meta.ensure_not_expired(now_seconds)?; + hello.meta.ensure_not_expired(now_seconds)?; let proof_data = hash_hello_proof_data( crypto, initiator, responder, - &meta, - hello.nonce(), - hello.kem_ct(), + &hello.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), ); - verify_signature_bytes(initiator_signing_key, hello.signature(), &proof_data) + verify_signature_bytes(initiator_signing_key, hello.signature.as_bytes(), &proof_data) } pub fn respond_hello( @@ -69,7 +66,7 @@ pub fn respond_hello( initiator: XID, initiator_signing_key: &MlDsaPublicKey, initiator_encapsulation_key: &MlKemPublicKey, - hello: &impl HelloView, + hello: &Hello, meta: ControlMeta, now_seconds: u64, ) -> Result<(HelloReply, ResponderSecrets), WireError> { @@ -83,8 +80,7 @@ pub fn respond_hello( )?; let initiator_secret = identity .encapsulation_private_key - .decapsulate_shared_secret_bytes(hello.kem_ct()); - let hello_meta = hello.meta(); + .decapsulate_shared_secret_bytes(hello.kem_ct.as_bytes()); let nonce = next_nonce(crypto); let (responder_secret, kem_ct) = initiator_encapsulation_key.encapsulate_new_shared_secret(crypto); @@ -92,9 +88,9 @@ pub fn respond_hello( crypto, initiator, identity.xid, - &hello_meta, - hello.nonce(), - hello.kem_ct(), + &hello.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), &meta, &nonce.0, kem_ct.as_bytes(), @@ -119,41 +115,39 @@ pub fn build_confirm( identity: &QlIdentity, responder: XID, responder_signing_key: &MlDsaPublicKey, - hello: &impl HelloView, - reply: &impl HelloReplyView, + hello: &Hello, + reply: &HelloReply, initiator_secret: &SessionKey, meta: ControlMeta, now_seconds: u64, ) -> Result<(Confirm, SessionKey), WireError> { - let hello_meta = hello.meta(); - let reply_meta = reply.meta(); - reply_meta.ensure_not_expired(now_seconds)?; + reply.meta.ensure_not_expired(now_seconds)?; let transcript = hash_handshake_transcript( crypto, identity.xid, responder, - &hello_meta, - hello.nonce(), - hello.kem_ct(), - &reply_meta, - reply.nonce(), - reply.kem_ct(), + &hello.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), + &reply.meta, + &reply.nonce.0, + reply.kem_ct.as_bytes(), ); - verify_signature_bytes(responder_signing_key, reply.signature(), &transcript)?; + verify_signature_bytes(responder_signing_key, reply.signature.as_bytes(), &transcript)?; let responder_secret = identity .encapsulation_private_key - .decapsulate_shared_secret_bytes(reply.kem_ct()); + .decapsulate_shared_secret_bytes(reply.kem_ct.as_bytes()); let proof_data = hash_confirm_proof_data( crypto, &meta, identity.xid, responder, - &hello_meta, - hello.nonce(), - hello.kem_ct(), - &reply_meta, - reply.nonce(), - reply.kem_ct(), + &hello.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), + &reply.meta, + &reply.nonce.0, + reply.kem_ct.as_bytes(), ); let signature = identity.signing_private_key.sign(crypto, &proof_data); let session_key = derive_session_key( @@ -162,12 +156,12 @@ pub fn build_confirm( &responder_secret, identity.xid, responder, - &hello_meta, - hello.nonce(), - hello.kem_ct(), - &reply_meta, - reply.nonce(), - reply.kem_ct(), + &hello.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), + &reply.meta, + &reply.nonce.0, + reply.kem_ct.as_bytes(), ); Ok((Confirm { meta, signature }, session_key)) } @@ -177,14 +171,12 @@ pub fn finalize_confirm( initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, - hello: &impl HelloView, - reply: &impl HelloReplyView, - confirm: &impl ConfirmView, + hello: &Hello, + reply: &HelloReply, + confirm: &Confirm, secrets: &ResponderSecrets, now_seconds: u64, ) -> Result { - let hello_meta = hello.meta(); - let reply_meta = reply.meta(); verify_confirm( crypto, initiator, @@ -201,12 +193,12 @@ pub fn finalize_confirm( &secrets.responder_secret, initiator, responder, - &hello_meta, - hello.nonce(), - hello.kem_ct(), - &reply_meta, - reply.nonce(), - reply.kem_ct(), + &hello.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), + &reply.meta, + &reply.nonce.0, + reply.kem_ct.as_bytes(), )) } @@ -215,28 +207,25 @@ pub fn verify_confirm( initiator: XID, responder: XID, initiator_signing_key: &MlDsaPublicKey, - hello: &impl HelloView, - reply: &impl HelloReplyView, - confirm: &impl ConfirmView, + hello: &Hello, + reply: &HelloReply, + confirm: &Confirm, now_seconds: u64, ) -> Result<(), WireError> { - let hello_meta = hello.meta(); - let reply_meta = reply.meta(); - let confirm_meta = confirm.meta(); - confirm_meta.ensure_not_expired(now_seconds)?; + confirm.meta.ensure_not_expired(now_seconds)?; let proof_data = hash_confirm_proof_data( crypto, - &confirm_meta, + &confirm.meta, initiator, responder, - &hello_meta, - hello.nonce(), - hello.kem_ct(), - &reply_meta, - reply.nonce(), - reply.kem_ct(), + &hello.meta, + &hello.nonce.0, + hello.kem_ct.as_bytes(), + &reply.meta, + &reply.nonce.0, + reply.kem_ct.as_bytes(), ); - verify_signature_bytes(initiator_signing_key, confirm.signature(), &proof_data) + verify_signature_bytes(initiator_signing_key, confirm.signature.as_bytes(), &proof_data) } pub fn build_ready( @@ -245,7 +234,7 @@ pub fn build_ready( session_key: &SessionKey, meta: ControlMeta, nonce: Nonce, -) -> Ready { +) -> Ready> { let aad = header.aad(); let body_bytes = ReadyBody { meta }.encode(); Ready { @@ -253,16 +242,16 @@ pub fn build_ready( } } -pub fn decrypt_ready( +pub fn decrypt_ready>( crypto: &impl QlCrypto, header: &QlHeader, - ready: &mut Ref, + ready: Ready, session_key: &SessionKey, now_seconds: u64, ) -> Result { let aad = header.aad(); - let plaintext = EncryptedMessage::decrypt_in_place(ready, crypto, session_key, &aad)?; - let body = ReadyBody::decode(plaintext)?; + let mut plaintext = ready.encrypted.decrypt_in_place(crypto, session_key, &aad)?; + let body = ReadyBody::decode(plaintext.as_mut())?; body.meta.ensure_not_expired(now_seconds)?; Ok(body) } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index e2fbb9fe..351a73fc 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,12 +1,6 @@ -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - use crate::{ - codec::{parse, push_value, read_exact}, - control::ControlMetaWire, - encrypted_message::{EncryptedMessage, EncryptedMessageWire}, - ControlMeta, MlDsaSignature, MlKemCiphertext, Nonce, WireError, + codec, encrypted_message::EncryptedMessage, ControlMeta, MlDsaSignature, MlKemCiphertext, + ByteSlice, Nonce, WireError, }; mod crypto; @@ -20,82 +14,29 @@ pub struct Hello { pub signature: MlDsaSignature, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct HelloWire { - pub meta: ControlMetaWire, - pub nonce: [u8; Nonce::SIZE], - pub kem_ct: [u8; MlKemCiphertext::SIZE], - pub signature: [u8; MlDsaSignature::SIZE], -} - -pub trait HelloView { - fn meta(&self) -> ControlMeta; - fn nonce(&self) -> &[u8; Nonce::SIZE]; - fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE]; - fn signature(&self) -> &[u8; MlDsaSignature::SIZE]; -} - -impl HelloView for Hello { - fn meta(&self) -> ControlMeta { - self.meta - } - - fn nonce(&self) -> &[u8; Nonce::SIZE] { - &self.nonce.0 - } - - fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE] { - self.kem_ct.as_bytes() - } - - fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { - self.signature.as_bytes() - } -} - -impl HelloView for Ref { - fn meta(&self) -> ControlMeta { - ControlMeta::from_wire(self.meta) - } - - fn nonce(&self) -> &[u8; Nonce::SIZE] { - &self.nonce - } - - fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE] { - &self.kem_ct - } - - fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { - &self.signature - } -} - impl Hello { - pub fn parse(bytes: B) -> Result, WireError> { - parse(bytes) - } + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + + Nonce::SIZE + + MlKemCiphertext::SIZE + + MlDsaSignature::SIZE; - pub fn from_wire(wire: &HelloWire) -> Self { - Self { - meta: ControlMeta::from_wire(wire.meta), - nonce: Nonce(wire.nonce), - kem_ct: MlKemCiphertext::from_data(wire.kem_ct), - signature: MlDsaSignature::from_data(wire.signature), - } + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, &self.nonce.0); + codec::push_bytes(out, self.kem_ct.as_bytes()); + codec::push_bytes(out, self.signature.as_bytes()); } - pub fn encode_into(&self, out: &mut Vec) { - push_value( - out, - &HelloWire { - meta: self.meta.to_wire(), - nonce: self.nonce.0, - kem_ct: *self.kem_ct.as_bytes(), - signature: *self.signature.as_bytes(), - }, - ); + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let hello = Self { + meta: ControlMeta::decode_from(&mut reader)?, + nonce: Nonce(reader.take_array()?), + kem_ct: MlKemCiphertext::from_data(reader.take_array()?), + signature: MlDsaSignature::from_data(reader.take_array()?), + }; + reader.finish()?; + Ok(hello) } } @@ -107,82 +48,29 @@ pub struct HelloReply { pub signature: MlDsaSignature, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct HelloReplyWire { - pub meta: ControlMetaWire, - pub nonce: [u8; Nonce::SIZE], - pub kem_ct: [u8; MlKemCiphertext::SIZE], - pub signature: [u8; MlDsaSignature::SIZE], -} - -pub trait HelloReplyView { - fn meta(&self) -> ControlMeta; - fn nonce(&self) -> &[u8; Nonce::SIZE]; - fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE]; - fn signature(&self) -> &[u8; MlDsaSignature::SIZE]; -} - -impl HelloReplyView for HelloReply { - fn meta(&self) -> ControlMeta { - self.meta - } - - fn nonce(&self) -> &[u8; Nonce::SIZE] { - &self.nonce.0 - } - - fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE] { - self.kem_ct.as_bytes() - } - - fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { - self.signature.as_bytes() - } -} - -impl HelloReplyView for Ref { - fn meta(&self) -> ControlMeta { - ControlMeta::from_wire(self.meta) - } - - fn nonce(&self) -> &[u8; Nonce::SIZE] { - &self.nonce - } - - fn kem_ct(&self) -> &[u8; MlKemCiphertext::SIZE] { - &self.kem_ct - } - - fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { - &self.signature - } -} - impl HelloReply { - pub fn parse(bytes: B) -> Result, WireError> { - parse(bytes) - } + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + + Nonce::SIZE + + MlKemCiphertext::SIZE + + MlDsaSignature::SIZE; - pub fn from_wire(wire: &HelloReplyWire) -> Self { - Self { - meta: ControlMeta::from_wire(wire.meta), - nonce: Nonce(wire.nonce), - kem_ct: MlKemCiphertext::from_data(wire.kem_ct), - signature: MlDsaSignature::from_data(wire.signature), - } + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, &self.nonce.0); + codec::push_bytes(out, self.kem_ct.as_bytes()); + codec::push_bytes(out, self.signature.as_bytes()); } - pub fn encode_into(&self, out: &mut Vec) { - push_value( - out, - &HelloReplyWire { - meta: self.meta.to_wire(), - nonce: self.nonce.0, - kem_ct: *self.kem_ct.as_bytes(), - signature: *self.signature.as_bytes(), - }, - ); + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let reply = Self { + meta: ControlMeta::decode_from(&mut reader)?, + nonce: Nonce(reader.take_array()?), + kem_ct: MlKemCiphertext::from_data(reader.take_array()?), + signature: MlDsaSignature::from_data(reader.take_array()?), + }; + reader.finish()?; + Ok(reply) } } @@ -192,64 +80,28 @@ pub struct Confirm { pub signature: MlDsaSignature, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct ConfirmWire { - pub meta: ControlMetaWire, - pub signature: [u8; MlDsaSignature::SIZE], -} - -pub trait ConfirmView { - fn meta(&self) -> ControlMeta; - fn signature(&self) -> &[u8; MlDsaSignature::SIZE]; -} - -impl ConfirmView for Confirm { - fn meta(&self) -> ControlMeta { - self.meta - } - - fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { - self.signature.as_bytes() - } -} - -impl ConfirmView for Ref { - fn meta(&self) -> ControlMeta { - ControlMeta::from_wire(self.meta) - } - - fn signature(&self) -> &[u8; MlDsaSignature::SIZE] { - &self.signature - } -} - impl Confirm { - pub fn parse(bytes: B) -> Result, WireError> { - parse(bytes) - } + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + MlDsaSignature::SIZE; - pub fn from_wire(wire: &ConfirmWire) -> Self { - Self { - meta: ControlMeta::from_wire(wire.meta), - signature: MlDsaSignature::from_data(wire.signature), - } + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, self.signature.as_bytes()); } - pub fn encode_into(&self, out: &mut Vec) { - push_value( - out, - &ConfirmWire { - meta: self.meta.to_wire(), - signature: *self.signature.as_bytes(), - }, - ); + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let confirm = Self { + meta: ControlMeta::decode_from(&mut reader)?, + signature: MlDsaSignature::from_data(reader.take_array()?), + }; + reader.finish()?; + Ok(confirm) } } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Ready { - pub encrypted: EncryptedMessage, +pub struct Ready { + pub encrypted: EncryptedMessage, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -257,32 +109,48 @@ pub struct ReadyBody { pub meta: ControlMeta, } -impl Ready { - pub fn parse(bytes: B) -> Result, WireError> { - EncryptedMessage::parse(bytes) +impl Ready { + pub fn parse(bytes: B) -> Result { + Ok(Self { + encrypted: EncryptedMessage::parse(bytes)?, + }) } +} - pub fn from_wire(wire: &EncryptedMessageWire) -> Self { - Self { - encrypted: EncryptedMessage::from_wire(wire), +impl Ready { + pub fn into_owned(self) -> Ready> + where + B: AsRef<[u8]>, + { + Ready { + encrypted: self.encrypted.into_owned(), } } +} +impl> Ready { pub fn encode_into(&self, out: &mut Vec) { self.encrypted.encode_into(out); } } +impl Ready> { + pub fn decode(bytes: &[u8]) -> Result { + EncryptedMessage::parse(bytes) + .map(|encrypted| Self { + encrypted: encrypted.into_owned(), + }) + } +} + impl ReadyBody { pub fn encode(&self) -> Vec { - let wire = self.meta.to_wire(); - wire.as_bytes().to_vec() + self.meta.encode() } pub fn decode(bytes: &[u8]) -> Result { - let wire: ControlMetaWire = read_exact(bytes)?; Ok(Self { - meta: ControlMeta::from_wire(wire), + meta: ControlMeta::decode(bytes)?, }) } } diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index d4a460a1..bcf69201 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -1,8 +1,4 @@ -use zerocopy::{ - byte_slice::SplitByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned, -}; - -use crate::{codec, record::RecordKind, WireError, XID}; +use crate::{codec, record::RecordKind, ByteSlice, WireError, XID}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct QlHeader { @@ -16,40 +12,33 @@ impl QlHeader { } } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct QlRecordHeaderWire { - pub kind: u8, - pub sender: [u8; XID::SIZE], - pub recipient: [u8; XID::SIZE], -} - #[derive(Debug, Clone, Copy)] pub(crate) struct DecodedRecordHeader { pub(crate) kind: RecordKind, pub(crate) header: QlHeader, } -pub(crate) fn encode_record_header(header: &QlHeader, kind: RecordKind) -> QlRecordHeaderWire { - QlRecordHeaderWire { - kind: kind as u8, - sender: header.sender.0, - recipient: header.recipient.0, - } +pub(crate) fn encode_record_header(out: &mut Vec, header: &QlHeader, kind: RecordKind) { + codec::push_u8(out, kind as u8); + codec::push_bytes(out, &header.sender.0); + codec::push_bytes(out, &header.recipient.0); } -pub(crate) fn decode_record_header( +pub(crate) fn decode_record_header( bytes: B, ) -> Result<(DecodedRecordHeader, B), WireError> { - let (wire, payload_bytes) = codec::read_prefix::(bytes)?; + let mut reader = codec::Reader::new(bytes); + let kind = RecordKind::try_from(reader.take_u8()?)?; + let sender = XID(reader.take_array()?); + let recipient = XID(reader.take_array()?); Ok(( DecodedRecordHeader { - kind: codec::read_byte(wire.kind)?, + kind, header: QlHeader { - sender: XID(wire.sender), - recipient: XID(wire.recipient), + sender, + recipient, }, }, - payload_bytes, + reader.take_rest(), )) } diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 2c7151f2..abed8470 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -4,9 +4,6 @@ #![allow(clippy::too_many_arguments)] -pub type Ref<'a, T> = zerocopy::Ref<&'a [u8], T>; -pub type RefMut<'a, T> = zerocopy::Ref<&'a mut [u8], T>; - mod bytes; mod codec; mod control; @@ -39,6 +36,7 @@ pub use unpair::*; pub use xid::*; pub const QL_WIRE_VERSION: u8 = 1; +pub const ENCRYPTED_MESSAGE_AUTH_SIZE: usize = 16; pub trait QlCrypto { fn fill_random_bytes(&self, data: &mut [u8]); @@ -51,7 +49,7 @@ pub trait QlCrypto { nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - ) -> [u8; EncryptedMessage::AUTH_SIZE]; + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; fn decrypt_with_aead( &self, @@ -59,7 +57,7 @@ pub trait QlCrypto { nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], ) -> bool; } diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs index 69e4d9c7..b42be525 100644 --- a/ql-wire/src/pair/crypto.rs +++ b/ql-wire/src/pair/crypto.rs @@ -1,6 +1,4 @@ -use zerocopy::{byte_slice::ByteSliceMut, Ref}; - -use super::{PairRequestBody, PairRequestRecordWire}; +use super::{PairRequestBody, PairRequestRecord}; use crate::{ pq::ML_KEM_SUITE_TAG, ControlMeta, MlDsaPublicKey, MlKemCiphertext, MlKemPublicKey, QlCrypto, QlHeader, QlIdentity, QlPayload, QlRecord, WireError, XID, @@ -50,30 +48,24 @@ pub fn build_pair_request( ); QlRecord { header, - payload: QlPayload::PairRequest(super::PairRequestRecord { kem_ct, encrypted }), + payload: QlPayload::PairRequest(PairRequestRecord { kem_ct, encrypted }), } } -pub fn decrypt_pair_request( +pub fn decrypt_pair_request>( crypto: &impl QlCrypto, identity: &QlIdentity, header: &QlHeader, - request: &mut Ref, + request: PairRequestRecord, now_seconds: u64, ) -> Result { - let kem_ct = MlKemCiphertext::from_data(request.kem_ct); + let PairRequestRecord { kem_ct, encrypted } = request; let aad = pairing_aad(header, &kem_ct); let session_key = identity .encapsulation_private_key .decapsulate_shared_secret(&kem_ct); - let mut encrypted = crate::encrypted_message::EncryptedMessage::parse(&mut request.encrypted)?; - let plaintext = crate::encrypted_message::EncryptedMessage::decrypt_in_place( - &mut encrypted, - crypto, - &session_key, - &aad, - )?; - let decrypted = PairRequestBody::decode(plaintext)?; + let mut plaintext = encrypted.decrypt_in_place(crypto, &session_key, &aad)?; + let decrypted = PairRequestBody::decode(plaintext.as_mut())?; decrypted.meta.ensure_not_expired(now_seconds)?; if decrypted.xid != header.sender { return Err(WireError::InvalidPayload); diff --git a/ql-wire/src/pair/mod.rs b/ql-wire/src/pair/mod.rs index 53999b43..d1352f61 100644 --- a/ql-wire/src/pair/mod.rs +++ b/ql-wire/src/pair/mod.rs @@ -1,21 +1,15 @@ -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - use crate::{ - codec::{parse, push_value, read_exact}, - control::ControlMetaWire, - encrypted_message::EncryptedMessage, - ControlMeta, MlDsaPublicKey, MlDsaSignature, MlKemCiphertext, MlKemPublicKey, WireError, XID, + codec, encrypted_message::EncryptedMessage, ControlMeta, MlDsaPublicKey, MlDsaSignature, + ByteSlice, MlKemCiphertext, MlKemPublicKey, WireError, XID, }; mod crypto; pub use crypto::*; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct PairRequestRecord { +pub struct PairRequestRecord { pub kem_ct: MlKemCiphertext, - pub encrypted: EncryptedMessage, + pub encrypted: EncryptedMessage, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -27,84 +21,66 @@ pub struct PairRequestBody { pub proof: MlDsaSignature, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] -#[repr(C, packed)] -pub struct PairRequestRecordWire { - pub kem_ct: [u8; MlKemCiphertext::SIZE], - pub encrypted: [u8], -} - -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct PairRequestBodyWire { - pub meta: ControlMetaWire, - pub xid: [u8; XID::SIZE], - pub signing_pub_key: [u8; MlDsaPublicKey::SIZE], - pub encapsulation_pub_key: [u8; MlKemPublicKey::SIZE], - pub proof: [u8; MlDsaSignature::SIZE], -} - -impl PairRequestRecord { - pub fn parse(bytes: B) -> Result, WireError> { - let record: Ref = parse(bytes)?; - let _ = EncryptedMessage::parse(&record.encrypted)?; - Ok(record) +impl PairRequestRecord { + pub fn parse(bytes: B) -> Result { + let mut reader = codec::Reader::new(bytes); + Ok(Self { + kem_ct: MlKemCiphertext::from_data(reader.take_array()?), + encrypted: EncryptedMessage::parse(reader.take_rest())?, + }) } +} - pub fn from_wire(wire: &PairRequestRecordWire) -> Self { - let encrypted = - EncryptedMessage::parse(&wire.encrypted).expect("validated pair request record"); - Self { - kem_ct: MlKemCiphertext::from_data(wire.kem_ct), - encrypted: EncryptedMessage::from_wire(&encrypted), +impl PairRequestRecord { + pub fn into_owned(self) -> PairRequestRecord> + where + B: AsRef<[u8]>, + { + PairRequestRecord { + kem_ct: self.kem_ct, + encrypted: self.encrypted.into_owned(), } } +} +impl> PairRequestRecord { pub fn encode_into(&self, out: &mut Vec) { - push_value( - out, - &PairRequestHeaderWire { - kem_ct: *self.kem_ct.as_bytes(), - }, - ); - out.extend_from_slice(&self.encrypted.encode()); + codec::push_bytes(out, self.kem_ct.as_bytes()); + self.encrypted.encode_into(out); } } impl PairRequestBody { - pub fn from_wire(wire: PairRequestBodyWire) -> Self { - Self { - meta: ControlMeta::from_wire(wire.meta), - xid: XID(wire.xid), - signing_pub_key: MlDsaPublicKey::from_data(wire.signing_pub_key), - encapsulation_pub_key: MlKemPublicKey::from_data(wire.encapsulation_pub_key), - proof: MlDsaSignature::from_data(wire.proof), - } - } + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + + XID::SIZE + + MlDsaPublicKey::SIZE + + MlKemPublicKey::SIZE + + MlDsaSignature::SIZE; - pub fn to_wire(&self) -> PairRequestBodyWire { - PairRequestBodyWire { - meta: self.meta.to_wire(), - xid: self.xid.0, - signing_pub_key: *self.signing_pub_key.as_bytes(), - encapsulation_pub_key: *self.encapsulation_pub_key.as_bytes(), - proof: *self.proof.as_bytes(), - } + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, &self.xid.0); + codec::push_bytes(out, self.signing_pub_key.as_bytes()); + codec::push_bytes(out, self.encapsulation_pub_key.as_bytes()); + codec::push_bytes(out, self.proof.as_bytes()); } pub fn encode(&self) -> Vec { - let wire = self.to_wire(); - wire.as_bytes().to_vec() + let mut out = Vec::with_capacity(Self::ENCODED_LEN); + self.encode_into(&mut out); + out } pub fn decode(bytes: &[u8]) -> Result { - let wire: PairRequestBodyWire = read_exact(bytes)?; - Ok(Self::from_wire(wire)) + let mut reader = codec::Reader::new(bytes); + let body = Self { + meta: ControlMeta::decode_from(&mut reader)?, + xid: XID(reader.take_array()?), + signing_pub_key: MlDsaPublicKey::from_data(reader.take_array()?), + encapsulation_pub_key: MlKemPublicKey::from_data(reader.take_array()?), + proof: MlDsaSignature::from_data(reader.take_array()?), + }; + reader.finish()?; + Ok(body) } } - -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct PairRequestHeaderWire { - pub kem_ct: [u8; MlKemCiphertext::SIZE], -} diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index aad98e7b..9a2e362e 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -1,15 +1,10 @@ -use zerocopy::{ - byte_slice::{ByteSlice, SplitByteSlice}, - Immutable, IntoBytes, KnownLayout, Ref, TryFromBytes, Unaligned, -}; - use crate::{ codec, - encrypted_message::{EncryptedMessage, EncryptedMessageWire}, - handshake::{self, ConfirmWire, HelloReplyWire, HelloWire}, + encrypted_message::EncryptedMessage, + handshake::{self, Confirm, Hello, HelloReply, Ready}, header::{decode_record_header, encode_record_header, QlHeader}, - pair::{self, PairRequestRecordWire}, - unpair::{self, UnpairWire}, + pair::PairRequestRecord, + unpair::Unpair, WireError, QL_WIRE_VERSION, }; @@ -21,33 +16,33 @@ pub struct QlRecord { #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlPayload { - PairRequest(pair::PairRequestRecord), - Unpair(unpair::Unpair), - Hello(handshake::Hello), - HelloReply(handshake::HelloReply), - Confirm(handshake::Confirm), - Ready(handshake::Ready), - Session(EncryptedMessage), + PairRequest(PairRequestRecord>), + Unpair(Unpair), + Hello(Hello), + HelloReply(HelloReply), + Confirm(Confirm), + Ready(Ready>), + Session(EncryptedMessage>), } +#[derive(Debug, Clone, PartialEq, Eq)] pub struct QlRecordRef { pub header: QlHeader, pub payload: QlPayloadRef, } +#[derive(Debug, Clone, PartialEq, Eq)] pub enum QlPayloadRef { - PairRequest(Ref), - Unpair(Ref), - Hello(Ref), - HelloReply(Ref), - Confirm(Ref), - Ready(Ref), - Session(Ref), + PairRequest(PairRequestRecord), + Unpair(Unpair), + Hello(Hello), + HelloReply(HelloReply), + Confirm(Confirm), + Ready(Ready), + Session(EncryptedMessage), } -#[derive( - Debug, Clone, Copy, PartialEq, Eq, TryFromBytes, KnownLayout, Immutable, IntoBytes, Unaligned, -)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub(crate) enum RecordKind { PairRequest = 1, @@ -59,6 +54,23 @@ pub(crate) enum RecordKind { Unpair = 7, } +impl TryFrom for RecordKind { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::PairRequest), + 2 => Ok(Self::Hello), + 3 => Ok(Self::HelloReply), + 4 => Ok(Self::Confirm), + 5 => Ok(Self::Ready), + 6 => Ok(Self::Session), + 7 => Ok(Self::Unpair), + _ => Err(WireError::InvalidPayload), + } + } +} + impl RecordKind { fn for_payload(payload: &QlPayload) -> Self { match payload { @@ -76,9 +88,8 @@ impl RecordKind { impl QlRecord { pub fn encode(&self) -> Vec { let mut out = Vec::new(); - out.push(QL_WIRE_VERSION); - let header = encode_record_header(&self.header, RecordKind::for_payload(&self.payload)); - codec::push_value(&mut out, &header); + codec::push_u8(&mut out, QL_WIRE_VERSION); + encode_record_header(&mut out, &self.header, RecordKind::for_payload(&self.payload)); match &self.payload { QlPayload::PairRequest(request) => request.encode_into(&mut out), QlPayload::Unpair(unpair) => unpair.encode_into(&mut out), @@ -104,13 +115,13 @@ impl QlRecord { } } -impl QlRecordRef { +impl QlRecordRef { pub fn parse(bytes: B) -> Result { - let (version, payload_bytes) = codec::read_prefix::(bytes)?; - if version != QL_WIRE_VERSION { + let mut reader = codec::Reader::new(bytes); + if reader.take_u8()? != QL_WIRE_VERSION { return Err(WireError::InvalidPayload); } - let (header, payload_bytes) = decode_record_header(payload_bytes)?; + let (header, payload_bytes) = decode_record_header(reader.take_rest())?; let payload = parse_payload(header.kind, payload_bytes)?; Ok(Self { header: header.header, @@ -119,7 +130,7 @@ impl QlRecordRef { } } -impl QlRecordRef { +impl> QlRecordRef { pub fn to_owned(&self) -> QlRecord { QlRecord { header: self.header, @@ -128,36 +139,48 @@ impl QlRecordRef { } } -impl QlPayloadRef { +impl> QlPayloadRef { pub fn to_owned(&self) -> QlPayload { match self { - Self::PairRequest(request) => { - QlPayload::PairRequest(pair::PairRequestRecord::from_wire(request)) - } - Self::Unpair(unpair) => QlPayload::Unpair(unpair::Unpair::from_wire(unpair)), - Self::Hello(hello) => QlPayload::Hello(handshake::Hello::from_wire(hello)), - Self::HelloReply(reply) => { - QlPayload::HelloReply(handshake::HelloReply::from_wire(reply)) - } - Self::Confirm(confirm) => QlPayload::Confirm(handshake::Confirm::from_wire(confirm)), - Self::Ready(ready) => QlPayload::Ready(handshake::Ready::from_wire(ready)), - Self::Session(encrypted) => QlPayload::Session(EncryptedMessage::from_wire(encrypted)), + Self::PairRequest(request) => QlPayload::PairRequest(PairRequestRecord { + kem_ct: request.kem_ct.clone(), + encrypted: EncryptedMessage { + nonce: request.encrypted.nonce, + auth: request.encrypted.auth, + ciphertext: request.encrypted.ciphertext.as_ref().to_vec(), + }, + }), + Self::Unpair(unpair) => QlPayload::Unpair(unpair.clone()), + Self::Hello(hello) => QlPayload::Hello(hello.clone()), + Self::HelloReply(reply) => QlPayload::HelloReply(reply.clone()), + Self::Confirm(confirm) => QlPayload::Confirm(confirm.clone()), + Self::Ready(ready) => QlPayload::Ready(Ready { + encrypted: EncryptedMessage { + nonce: ready.encrypted.nonce, + auth: ready.encrypted.auth, + ciphertext: ready.encrypted.ciphertext.as_ref().to_vec(), + }, + }), + Self::Session(encrypted) => QlPayload::Session(EncryptedMessage { + nonce: encrypted.nonce, + auth: encrypted.auth, + ciphertext: encrypted.ciphertext.as_ref().to_vec(), + }), } } } -fn parse_payload(kind: RecordKind, payload: B) -> Result, WireError> { +fn parse_payload( + kind: RecordKind, + payload: B, +) -> Result, WireError> { match kind { - RecordKind::PairRequest => Ok(QlPayloadRef::PairRequest(pair::PairRequestRecord::parse( - payload, - )?)), - RecordKind::Unpair => Ok(QlPayloadRef::Unpair(unpair::Unpair::parse(payload)?)), - RecordKind::Hello => Ok(QlPayloadRef::Hello(handshake::Hello::parse(payload)?)), - RecordKind::HelloReply => Ok(QlPayloadRef::HelloReply(handshake::HelloReply::parse( - payload, - )?)), - RecordKind::Confirm => Ok(QlPayloadRef::Confirm(handshake::Confirm::parse(payload)?)), - RecordKind::Ready => Ok(QlPayloadRef::Ready(handshake::Ready::parse(payload)?)), + RecordKind::PairRequest => Ok(QlPayloadRef::PairRequest(PairRequestRecord::parse(payload)?)), + RecordKind::Unpair => Ok(QlPayloadRef::Unpair(Unpair::decode(&payload[..])?)), + RecordKind::Hello => Ok(QlPayloadRef::Hello(handshake::Hello::decode(&payload[..])?)), + RecordKind::HelloReply => Ok(QlPayloadRef::HelloReply(HelloReply::decode(&payload[..])?)), + RecordKind::Confirm => Ok(QlPayloadRef::Confirm(Confirm::decode(&payload[..])?)), + RecordKind::Ready => Ok(QlPayloadRef::Ready(Ready::parse(payload)?)), RecordKind::Session => Ok(QlPayloadRef::Session(EncryptedMessage::parse(payload)?)), } } diff --git a/ql-wire/src/ref.rs b/ql-wire/src/ref.rs deleted file mode 100644 index 45377642..00000000 --- a/ql-wire/src/ref.rs +++ /dev/null @@ -1,162 +0,0 @@ -use core::{ - fmt, - marker::PhantomData, - ops::{Deref, DerefMut}, -}; - -use crate::{ByteSlice, ByteSliceMut}; - -/// Typed bytes backed by a mutable or immutable byte slice. -/// -/// Unlike `zerocopy::Ref`, this type does not perform any size, alignment, or -/// layout validation for `T`. `T` is only a marker carried alongside the bytes. -pub struct Ref { - bytes: B, - _marker: PhantomData<*const T>, -} - -impl Ref { - pub const fn new(bytes: B) -> Self { - Self { - bytes, - _marker: PhantomData, - } - } - - pub fn into_bytes(self) -> B { - self.bytes - } - - pub fn retag(self) -> Ref { - Ref::new(self.bytes) - } -} - -impl Ref { - pub fn bytes(&self) -> &[u8] { - self.bytes.deref() - } - - pub fn len(&self) -> usize { - self.bytes.len() - } - - pub fn is_empty(&self) -> bool { - self.bytes.is_empty() - } - - pub fn reborrow(&self) -> Ref<&[u8], T> { - Ref::new(self.bytes()) - } -} - -impl Ref { - pub fn bytes_mut(&mut self) -> &mut [u8] { - self.bytes.deref_mut() - } - - pub fn reborrow_mut(&mut self) -> Ref<&mut [u8], T> { - Ref::new(self.bytes_mut()) - } -} - -impl Deref for Ref { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - self.bytes() - } -} - -impl DerefMut for Ref { - fn deref_mut(&mut self) -> &mut Self::Target { - self.bytes_mut() - } -} - -impl Clone for Ref { - fn clone(&self) -> Self { - Self::new(self.bytes.clone()) - } -} - -impl Copy for Ref {} - -impl AsRef<[u8]> for Ref { - fn as_ref(&self) -> &[u8] { - self.bytes() - } -} - -impl AsMut<[u8]> for Ref { - fn as_mut(&mut self) -> &mut [u8] { - self.bytes_mut() - } -} - -impl fmt::Debug for Ref { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Ref") - .field("type", &core::any::type_name::()) - .field("bytes", &self.bytes()) - .finish() - } -} - -#[cfg(test)] -mod tests { - use super::Ref; - - struct Message; - struct OtherMessage; - - #[test] - fn shared_ref_exposes_bytes() { - let bytes = b"hello"; - let reference = Ref::<_, Message>::new(&bytes[..]); - - assert_eq!(reference.bytes(), b"hello"); - assert_eq!(reference.len(), 5); - assert!(!reference.is_empty()); - } - - #[test] - fn mutable_ref_exposes_mutable_bytes() { - let mut bytes = *b"hello"; - let mut reference = Ref::<_, Message>::new(&mut bytes[..]); - - reference.bytes_mut()[0] = b'j'; - assert_eq!(&bytes, b"jello"); - } - - #[test] - fn ref_can_be_retagged() { - let bytes = b"hello"; - let reference = Ref::<_, Message>::new(&bytes[..]); - let other = reference.retag::(); - - assert_eq!(other.bytes(), b"hello"); - } - - #[test] - fn ref_can_be_reborrowed() { - let bytes = b"hello"; - let reference = Ref::<_, Message>::new(&bytes[..]); - let borrowed = reference.reborrow(); - - assert_eq!(borrowed.bytes(), b"hello"); - } - - #[test] - fn mutable_ref_can_be_reborrowed_mutably() { - let mut bytes = *b"hello"; - let mut reference = Ref::<_, Message>::new(&mut bytes[..]); - - { - let mut borrowed = reference.reborrow_mut(); - borrowed[1] = b'a'; - } - - assert_eq!(&bytes, b"hallo"); - } -} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 6ba1360c..e7becfb6 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -35,10 +35,10 @@ impl QlCrypto for TestCrypto { nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - ) -> [u8; EncryptedMessage::AUTH_SIZE] { + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); let plaintext = buffer.to_vec(); - let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; + let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; key.encrypt( buffer, (&mut auth).into(), @@ -56,7 +56,7 @@ impl QlCrypto for TestCrypto { nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], ) -> bool { let key: AesGcm256Key = (*key.data()).into(); let ciphertext = buffer.to_vec(); @@ -122,12 +122,11 @@ fn encrypted_session_record_round_trip_and_decrypt() { let mut bytes = bytes; let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::Session(mut encrypted) = payload else { + let QlPayloadRef::Session(encrypted) = payload else { panic!("expected session payload"); }; - let decrypted = - encrypted::decrypt_record(&crypto, &header, &mut encrypted, &session_key).unwrap(); - assert_eq!(SessionRecord::from_wire(&decrypted).unwrap(), body); + let decrypted = encrypted::decrypt_record(&crypto, &header, encrypted, &session_key).unwrap(); + assert_eq!(SessionRecord::decode(decrypted).unwrap(), body); } #[test] @@ -168,33 +167,30 @@ fn decrypted_session_record_iterates_zero_copy_frames() { let mut bytes = record.encode(); let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::Session(mut encrypted) = payload else { + let QlPayloadRef::Session(encrypted) = payload else { panic!("expected session payload"); }; - let decrypted = - encrypted::decrypt_record(&crypto, &header, &mut encrypted, &session_key).unwrap(); - - assert_eq!(decrypted.seq(), RecordSeq(7)); - let mut frames = decrypted.frames(); + let decrypted = encrypted::decrypt_record(&crypto, &header, encrypted, &session_key).unwrap(); + let (seq, mut frames) = SessionRecord::parse(decrypted).unwrap(); + assert_eq!(seq, RecordSeq(7)); match frames.next().unwrap().unwrap() { - SessionFrameRef::StreamData(frame) => { - assert_eq!(frame.stream_id(), StreamId(1)); - assert_eq!(frame.offset(), 5); - assert!(!frame.fin().unwrap()); - assert_eq!(frame.bytes(), b"abc"); + SessionFrame::StreamData(frame) => { + assert_eq!(frame.stream_id, StreamId(1)); + assert_eq!(frame.offset, 5); + assert!(!frame.fin); + assert_eq!(frame.bytes, b"abc"); } other => panic!("expected stream data, got {}", frame_name(&other)), } match frames.next().unwrap().unwrap() { - SessionFrameRef::Ack(frame) => { - let ranges: Vec<_> = frame.ranges().collect(); - assert_eq!(ranges, vec![RecordAckRange { start: 3, end: 8 }]); + SessionFrame::Ack(frame) => { + assert_eq!(frame.ranges, vec![RecordAckRange { start: 3, end: 8 }]); } other => panic!("expected ack, got {}", frame_name(&other)), } match frames.next().unwrap().unwrap() { - SessionFrameRef::StreamClose(frame) => { - let owned = StreamClose::from_wire(&frame).unwrap(); + SessionFrame::StreamClose(frame) => { + let owned = frame.into_owned(); assert_eq!(owned.stream_id, StreamId(1)); assert_eq!(owned.target, CloseTarget::Response); assert_eq!(owned.payload, b"later".to_vec()); @@ -240,10 +236,10 @@ fn pair_request_round_trip_and_decrypt() { let mut bytes = record.encode(); let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::PairRequest(mut request) = payload else { + let QlPayloadRef::PairRequest(request) = payload else { panic!("expected pair request"); }; - let body = pair::decrypt_pair_request(&crypto, &recipient, &header, &mut request, 100).unwrap(); + let body = pair::decrypt_pair_request(&crypto, &recipient, &header, request, 100).unwrap(); assert_eq!(body.meta, meta); assert_eq!(body.xid, sender.xid); assert_eq!(body.signing_pub_key, sender.signing_public_key); @@ -279,10 +275,10 @@ fn ready_round_trip_and_decrypt() { assert_eq!(parsed, record); let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::Ready(mut ready) = payload else { + let QlPayloadRef::Ready(ready) = payload else { panic!("expected ready payload"); }; - let body = handshake::decrypt_ready(&crypto, &header, &mut ready, &session_key, 100).unwrap(); + let body = handshake::decrypt_ready(&crypto, &header, ready, &session_key, 100).unwrap(); assert_eq!(body.meta, meta); } @@ -433,10 +429,10 @@ fn protocol_record_size_breakdown() { } } - fn encrypted(tag: u8, ciphertext_len: usize) -> EncryptedMessage { + fn encrypted(tag: u8, ciphertext_len: usize) -> EncryptedMessage> { EncryptedMessage { nonce: Nonce([tag; Nonce::SIZE]), - auth: [tag; EncryptedMessage::AUTH_SIZE], + auth: [tag; ENCRYPTED_MESSAGE_AUTH_SIZE], ciphertext: vec![tag; ciphertext_len], } } @@ -606,13 +602,13 @@ fn protocol_record_size_breakdown() { print_size("ql-wire session close", session_close.encode().len()); } -fn frame_name(frame: &SessionFrameRef<'_>) -> &'static str { +fn frame_name(frame: &SessionFrame<&[u8]>) -> &'static str { match frame { - SessionFrameRef::Ping => "ping", - SessionFrameRef::Ack(_) => "ack", - SessionFrameRef::StreamData(_) => "stream_data", - SessionFrameRef::StreamWindow(_) => "stream_window", - SessionFrameRef::StreamClose(_) => "stream_close", - SessionFrameRef::Close(_) => "close", + SessionFrame::Ping => "ping", + SessionFrame::Ack(_) => "ack", + SessionFrame::StreamData(_) => "stream_data", + SessionFrame::StreamWindow(_) => "stream_window", + SessionFrame::StreamClose(_) => "stream_close", + SessionFrame::Close(_) => "close", } } diff --git a/ql-wire/src/unpair/crypto.rs b/ql-wire/src/unpair/crypto.rs index fa5ea03a..6062ca2c 100644 --- a/ql-wire/src/unpair/crypto.rs +++ b/ql-wire/src/unpair/crypto.rs @@ -1,11 +1,10 @@ -use zerocopy::{byte_slice::ByteSlice, Ref}; - -use super::UnpairWire; use crate::{ ControlMeta, MlDsaPublicKey, QlCrypto, QlHeader, QlIdentity, QlPayload, QlRecord, WireError, XID, }; +use super::Unpair; + pub fn build_unpair( crypto: &impl QlCrypto, identity: &QlIdentity, @@ -21,22 +20,21 @@ pub fn build_unpair( .sign(crypto, &hash_unpair_signature_data(crypto, &header, &meta)); QlRecord { header, - payload: QlPayload::Unpair(super::Unpair { meta, signature }), + payload: QlPayload::Unpair(Unpair { meta, signature }), } } -pub fn verify_unpair( +pub fn verify_unpair( crypto: &impl QlCrypto, header: &QlHeader, signer: &MlDsaPublicKey, - unpair: &Ref, + unpair: &Unpair, now_seconds: u64, ) -> Result<(), WireError> { - let meta = ControlMeta::from_wire(unpair.meta); - meta.ensure_not_expired(now_seconds)?; + unpair.meta.ensure_not_expired(now_seconds)?; if signer.verify_bytes( - &unpair.signature, - &hash_unpair_signature_data(crypto, header, &meta), + unpair.signature.as_bytes(), + &hash_unpair_signature_data(crypto, header, &unpair.meta), ) { Ok(()) } else { diff --git a/ql-wire/src/unpair/mod.rs b/ql-wire/src/unpair/mod.rs index 12b0b612..593c3b7b 100644 --- a/ql-wire/src/unpair/mod.rs +++ b/ql-wire/src/unpair/mod.rs @@ -1,12 +1,4 @@ -use zerocopy::{ - byte_slice::ByteSlice, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, -}; - -use crate::{ - codec::{parse, push_value}, - control::ControlMetaWire, - ControlMeta, MlDsaSignature, WireError, -}; +use crate::{codec, ControlMeta, MlDsaSignature, WireError}; mod crypto; pub use crypto::*; @@ -17,32 +9,21 @@ pub struct Unpair { pub signature: MlDsaSignature, } -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, Debug, Clone, Copy)] -#[repr(C)] -pub struct UnpairWire { - pub meta: ControlMetaWire, - pub signature: [u8; MlDsaSignature::SIZE], -} - impl Unpair { - pub fn parse(bytes: B) -> Result, WireError> { - parse(bytes) - } + pub const WIRE_SIZE: usize = ControlMeta::ENCODED_LEN + MlDsaSignature::SIZE; - pub fn from_wire(wire: &UnpairWire) -> Self { - Self { - meta: ControlMeta::from_wire(wire.meta), - signature: MlDsaSignature::from_data(wire.signature), - } + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let unpair = Self { + meta: ControlMeta::decode_from(&mut reader)?, + signature: MlDsaSignature::from_data(reader.take_array()?), + }; + reader.finish()?; + Ok(unpair) } pub fn encode_into(&self, out: &mut Vec) { - push_value( - out, - &UnpairWire { - meta: self.meta.to_wire(), - signature: *self.signature.as_bytes(), - }, - ); + self.meta.encode_into(out); + codec::push_bytes(out, self.signature.as_bytes()); } } From 7711012fa9d4c375e01224ff4cd3baf0f311c5c4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 15:33:45 -0400 Subject: [PATCH 044/304] ql-wire: builder --- ql-wire/src/encrypted/ack.rs | 5 + ql-wire/src/encrypted/builder.rs | 139 +++++++++++++++++++++++++ ql-wire/src/encrypted/close.rs | 1 + ql-wire/src/encrypted/mod.rs | 57 ++++++---- ql-wire/src/encrypted/stream_close.rs | 6 ++ ql-wire/src/encrypted/stream_data.rs | 6 ++ ql-wire/src/encrypted/stream_window.rs | 1 + ql-wire/src/encrypted_message.rs | 2 +- ql-wire/src/handshake/crypto.rs | 26 +++-- ql-wire/src/handshake/mod.rs | 23 ++-- ql-wire/src/header.rs | 5 +- ql-wire/src/pair/mod.rs | 4 +- ql-wire/src/record.rs | 10 +- ql-wire/src/tests.rs | 34 ++++++ ql-wire/src/unpair/crypto.rs | 3 +- 15 files changed, 273 insertions(+), 49 deletions(-) create mode 100644 ql-wire/src/encrypted/builder.rs diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 46d068fd..7944b22f 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -14,6 +14,7 @@ pub struct RecordAckRange { } impl RecordAck { + pub const FRAME_OVERHEAD: usize = std::mem::size_of::() + size_of::(); pub const RANGE_ENCODED_LEN: usize = size_of::() + size_of::(); pub fn decode(bytes: &[u8]) -> Result { @@ -49,6 +50,10 @@ impl RecordAck { self.ranges.len() * Self::RANGE_ENCODED_LEN } + pub fn frame_encoded_len(&self) -> usize { + Self::FRAME_OVERHEAD + self.encoded_len() + } + pub fn encode_into(&self, out: &mut Vec) { for range in &self.ranges { codec::push_u64(out, range.start); diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs new file mode 100644 index 00000000..dffe25db --- /dev/null +++ b/ql-wire/src/encrypted/builder.rs @@ -0,0 +1,139 @@ +use super::{ + push_variable_len, RecordAck, RecordSeq, SessionCloseBody, SessionFrame, SessionFrameKind, + StreamClose, StreamData, StreamWindow, +}; +use crate::{ + codec, encrypted_message::EncryptedMessage, Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, + SessionKey, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionRecordBuilder { + max_capacity: usize, + bytes: Vec, +} + +impl SessionRecordBuilder { + pub const HEADER_LEN: usize = std::mem::size_of::(); + pub const PING_ENCODED_LEN: usize = std::mem::size_of::(); + + pub fn new(seq: RecordSeq, max_capacity: usize) -> Self { + let max_capacity = max_capacity.max(Self::HEADER_LEN); + let mut bytes = Vec::with_capacity(max_capacity); + codec::push_u64(&mut bytes, seq.0); + Self { + max_capacity, + bytes, + } + } + + pub fn max_capacity(&self) -> usize { + self.max_capacity + } + + pub fn len(&self) -> usize { + self.bytes.len() + } + + pub fn is_empty(&self) -> bool { + self.bytes.len() == Self::HEADER_LEN + } + + pub fn remaining_capacity(&self) -> usize { + self.max_capacity.saturating_sub(self.bytes.len()) + } + + pub fn bytes(&self) -> &[u8] { + &self.bytes + } + + pub fn into_plaintext(self) -> Vec { + self.bytes + } + + pub fn can_push_len(&self, len: usize) -> bool { + len <= self.remaining_capacity() || self.is_empty() + } + + pub fn push_ping(&mut self) -> bool { + if !self.can_push_len(Self::PING_ENCODED_LEN) { + return false; + } + self.bytes.push(SessionFrameKind::Ping as u8); + true + } + + pub fn push_ack(&mut self, ack: &RecordAck) -> bool { + if !self.can_push_len(ack.frame_encoded_len()) { + return false; + } + self.bytes.push(SessionFrameKind::Ack as u8); + push_variable_len(&mut self.bytes, ack.encoded_len()); + ack.encode_into(&mut self.bytes); + true + } + + pub fn push_stream_data>(&mut self, frame: &StreamData) -> bool { + if !self.can_push_len(frame.frame_encoded_len()) { + return false; + } + self.bytes.push(SessionFrameKind::StreamData as u8); + push_variable_len(&mut self.bytes, frame.encoded_len()); + frame.encode_into(&mut self.bytes); + true + } + + pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { + if !self.can_push_len(StreamWindow::FRAME_ENCODED_LEN) { + return false; + } + self.bytes.push(SessionFrameKind::StreamWindow as u8); + frame.encode_into(&mut self.bytes); + true + } + + pub fn push_stream_close>(&mut self, frame: &StreamClose) -> bool { + if !self.can_push_len(frame.frame_encoded_len()) { + return false; + } + self.bytes.push(SessionFrameKind::StreamClose as u8); + push_variable_len(&mut self.bytes, frame.encoded_len()); + frame.encode_into(&mut self.bytes); + true + } + + pub fn push_close(&mut self, close: &SessionCloseBody) -> bool { + if !self.can_push_len(SessionCloseBody::FRAME_ENCODED_LEN) { + return false; + } + self.bytes.push(SessionFrameKind::Close as u8); + close.encode_into(&mut self.bytes); + true + } + + pub fn push_frame>(&mut self, frame: &SessionFrame) -> bool { + match frame { + SessionFrame::Ping => self.push_ping(), + SessionFrame::Ack(frame) => self.push_ack(frame), + SessionFrame::StreamData(frame) => self.push_stream_data(frame), + SessionFrame::StreamWindow(frame) => self.push_stream_window(frame), + SessionFrame::StreamClose(frame) => self.push_stream_close(frame), + SessionFrame::Close(close) => self.push_close(close), + } + } + + pub fn encrypt( + self, + crypto: &impl QlCrypto, + header: QlHeader, + session_key: &SessionKey, + nonce: Nonce, + ) -> QlRecord { + let aad = header.aad(); + let encrypted = EncryptedMessage::encrypt(crypto, session_key, self.bytes, &aad, nonce); + QlRecord { + header, + payload: QlPayload::Session(encrypted), + } + } +} diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index 4702566a..ab43e585 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -14,6 +14,7 @@ pub struct SessionCloseBody { impl SessionCloseBody { pub const WIRE_SIZE: usize = size_of::(); + pub const FRAME_ENCODED_LEN: usize = std::mem::size_of::() + Self::WIRE_SIZE; pub fn encode_into(&self, out: &mut Vec) { codec::push_u16(out, self.code.0); diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 38b260cc..cb38e2b5 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,18 +1,18 @@ use std::mem::size_of; use crate::{ - codec, - encrypted_message::EncryptedMessage, - QlCrypto, QlHeader, QlPayload, QlRecord, SessionKey, WireError, + codec, encrypted_message::EncryptedMessage, QlCrypto, QlHeader, QlRecord, SessionKey, WireError, }; mod ack; +mod builder; mod close; mod stream_close; mod stream_data; mod stream_window; pub use ack::*; +pub use builder::*; pub use close::*; pub use stream_close::*; pub use stream_data::*; @@ -79,6 +79,8 @@ impl TryFrom for SessionFrameKind { } impl SessionRecord { + pub const HEADER_LEN: usize = size_of::(); + pub fn parse(bytes: &[u8]) -> Result<(RecordSeq, SessionFrameIter<'_>), WireError> { let mut reader = codec::Reader::new(bytes); let seq = RecordSeq(reader.take_u64()?); @@ -95,23 +97,40 @@ impl SessionRecord { let frames = frames .map(|frame| frame.map(SessionFrame::into_owned)) .collect::, _>>()?; - Ok(Self { - seq, - frames, - }) + Ok(Self { seq, frames }) + } + + pub fn encoded_len(&self) -> usize { + Self::HEADER_LEN + + self + .frames + .iter() + .map(SessionFrame::encoded_len) + .sum::() } pub fn encode(&self) -> Vec { - let mut out = Vec::new(); - codec::push_u64(&mut out, self.seq.0); + let mut out = SessionRecordBuilder::new(self.seq, self.encoded_len()); for frame in &self.frames { - frame.encode_into(&mut out); + let pushed = out.push_frame(frame); + debug_assert!(pushed); } - out + out.into_plaintext() } } impl> SessionFrame { + pub fn encoded_len(&self) -> usize { + match self { + Self::Ping => SessionRecordBuilder::PING_ENCODED_LEN, + Self::Ack(frame) => frame.frame_encoded_len(), + Self::StreamData(frame) => frame.frame_encoded_len(), + Self::StreamWindow(_) => StreamWindow::FRAME_ENCODED_LEN, + Self::StreamClose(frame) => frame.frame_encoded_len(), + Self::Close(_) => SessionCloseBody::FRAME_ENCODED_LEN, + } + } + pub fn encode_into(&self, out: &mut Vec) { match self { Self::Ping => out.push(SessionFrameKind::Ping as u8), @@ -182,13 +201,12 @@ pub fn encrypt_record( body: &SessionRecord, nonce: crate::Nonce, ) -> QlRecord { - let aad = header.aad(); - let body = body.encode(); - let encrypted = EncryptedMessage::encrypt(crypto, session_key, body, &aad, nonce); - QlRecord { - header, - payload: QlPayload::Session(encrypted), + let mut builder = SessionRecordBuilder::new(body.seq, body.encoded_len()); + for frame in &body.frames { + let pushed = builder.push_frame(frame); + debug_assert!(pushed); } + builder.encrypt(crypto, header, session_key, nonce) } pub fn decrypt_record>( @@ -218,7 +236,10 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr return Err(WireError::InvalidPayload); } let (frame, rest) = rest.split_at(StreamWindow::WIRE_SIZE); - Ok((SessionFrame::StreamWindow(StreamWindow::decode(frame)?), rest)) + Ok(( + SessionFrame::StreamWindow(StreamWindow::decode(frame)?), + rest, + )) } SessionFrameKind::StreamClose => { let (frame, rest) = split_variable_frame(rest)?; diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 3ae539a3..44843b8c 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -58,6 +58,8 @@ pub struct StreamClose { impl StreamClose { pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); + pub const FRAME_OVERHEAD: usize = + std::mem::size_of::() + size_of::() + Self::MIN_WIRE_SIZE; } impl StreamClose { @@ -91,6 +93,10 @@ impl> StreamClose { Self::MIN_WIRE_SIZE + self.payload.as_ref().len() } + pub fn frame_encoded_len(&self) -> usize { + Self::FRAME_OVERHEAD + self.payload.as_ref().len() + } + pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); codec::push_u8(out, self.target.to_wire()); diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 6e183d65..4b309b7a 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -14,6 +14,8 @@ pub struct StreamData { impl StreamData { pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); + pub const FRAME_OVERHEAD: usize = + std::mem::size_of::() + size_of::() + Self::MIN_WIRE_SIZE; } impl StreamData { @@ -47,6 +49,10 @@ impl> StreamData { Self::MIN_WIRE_SIZE + self.bytes.as_ref().len() } + pub fn frame_encoded_len(&self) -> usize { + Self::FRAME_OVERHEAD + self.bytes.as_ref().len() + } + pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); codec::push_u64(out, self.offset); diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index d03f0d02..33d9bcf6 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -12,6 +12,7 @@ pub struct StreamWindow { impl StreamWindow { pub const WIRE_SIZE: usize = size_of::() + size_of::(); + pub const FRAME_ENCODED_LEN: usize = std::mem::size_of::() + Self::WIRE_SIZE; pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index a2473fd3..5b40a0f6 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -1,5 +1,5 @@ use crate::{ - codec, ByteSlice, ENCRYPTED_MESSAGE_AUTH_SIZE, Nonce, QlCrypto, SessionKey, WireError, + codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, }; #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs index 80ab76fc..0033cb19 100644 --- a/ql-wire/src/handshake/crypto.rs +++ b/ql-wire/src/handshake/crypto.rs @@ -1,8 +1,8 @@ use super::{Confirm, Hello, HelloReply, Ready, ReadyBody}; use crate::{ pq::ML_KEM_SUITE_TAG, ControlMeta, EncryptedMessage, MlDsaPublicKey, MlDsaSignature, - MlKemCiphertext, MlKemPublicKey, Nonce, QlCrypto, QlHeader, QlIdentity, SessionKey, - WireError, XID, + MlKemCiphertext, MlKemPublicKey, Nonce, QlCrypto, QlHeader, QlIdentity, SessionKey, WireError, + XID, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -57,7 +57,11 @@ pub fn verify_hello( &hello.nonce.0, hello.kem_ct.as_bytes(), ); - verify_signature_bytes(initiator_signing_key, hello.signature.as_bytes(), &proof_data) + verify_signature_bytes( + initiator_signing_key, + hello.signature.as_bytes(), + &proof_data, + ) } pub fn respond_hello( @@ -133,7 +137,11 @@ pub fn build_confirm( &reply.nonce.0, reply.kem_ct.as_bytes(), ); - verify_signature_bytes(responder_signing_key, reply.signature.as_bytes(), &transcript)?; + verify_signature_bytes( + responder_signing_key, + reply.signature.as_bytes(), + &transcript, + )?; let responder_secret = identity .encapsulation_private_key .decapsulate_shared_secret_bytes(reply.kem_ct.as_bytes()); @@ -225,7 +233,11 @@ pub fn verify_confirm( &reply.nonce.0, reply.kem_ct.as_bytes(), ); - verify_signature_bytes(initiator_signing_key, confirm.signature.as_bytes(), &proof_data) + verify_signature_bytes( + initiator_signing_key, + confirm.signature.as_bytes(), + &proof_data, + ) } pub fn build_ready( @@ -250,7 +262,9 @@ pub fn decrypt_ready>( now_seconds: u64, ) -> Result { let aad = header.aad(); - let mut plaintext = ready.encrypted.decrypt_in_place(crypto, session_key, &aad)?; + let mut plaintext = ready + .encrypted + .decrypt_in_place(crypto, session_key, &aad)?; let body = ReadyBody::decode(plaintext.as_mut())?; body.meta.ensure_not_expired(now_seconds)?; Ok(body) diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 351a73fc..1cd4ba57 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,6 +1,6 @@ use crate::{ - codec, encrypted_message::EncryptedMessage, ControlMeta, MlDsaSignature, MlKemCiphertext, - ByteSlice, Nonce, WireError, + codec, encrypted_message::EncryptedMessage, ByteSlice, ControlMeta, MlDsaSignature, + MlKemCiphertext, Nonce, WireError, }; mod crypto; @@ -15,10 +15,8 @@ pub struct Hello { } impl Hello { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN - + Nonce::SIZE - + MlKemCiphertext::SIZE - + MlDsaSignature::SIZE; + pub const ENCODED_LEN: usize = + ControlMeta::ENCODED_LEN + Nonce::SIZE + MlKemCiphertext::SIZE + MlDsaSignature::SIZE; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -49,10 +47,8 @@ pub struct HelloReply { } impl HelloReply { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN - + Nonce::SIZE - + MlKemCiphertext::SIZE - + MlDsaSignature::SIZE; + pub const ENCODED_LEN: usize = + ControlMeta::ENCODED_LEN + Nonce::SIZE + MlKemCiphertext::SIZE + MlDsaSignature::SIZE; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -136,10 +132,9 @@ impl> Ready { impl Ready> { pub fn decode(bytes: &[u8]) -> Result { - EncryptedMessage::parse(bytes) - .map(|encrypted| Self { - encrypted: encrypted.into_owned(), - }) + EncryptedMessage::parse(bytes).map(|encrypted| Self { + encrypted: encrypted.into_owned(), + }) } } diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index bcf69201..dda25f3b 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -34,10 +34,7 @@ pub(crate) fn decode_record_header( Ok(( DecodedRecordHeader { kind, - header: QlHeader { - sender, - recipient, - }, + header: QlHeader { sender, recipient }, }, reader.take_rest(), )) diff --git a/ql-wire/src/pair/mod.rs b/ql-wire/src/pair/mod.rs index d1352f61..6ba3ed60 100644 --- a/ql-wire/src/pair/mod.rs +++ b/ql-wire/src/pair/mod.rs @@ -1,6 +1,6 @@ use crate::{ - codec, encrypted_message::EncryptedMessage, ControlMeta, MlDsaPublicKey, MlDsaSignature, - ByteSlice, MlKemCiphertext, MlKemPublicKey, WireError, XID, + codec, encrypted_message::EncryptedMessage, ByteSlice, ControlMeta, MlDsaPublicKey, + MlDsaSignature, MlKemCiphertext, MlKemPublicKey, WireError, XID, }; mod crypto; diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 9a2e362e..693262e6 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -89,7 +89,11 @@ impl QlRecord { pub fn encode(&self) -> Vec { let mut out = Vec::new(); codec::push_u8(&mut out, QL_WIRE_VERSION); - encode_record_header(&mut out, &self.header, RecordKind::for_payload(&self.payload)); + encode_record_header( + &mut out, + &self.header, + RecordKind::for_payload(&self.payload), + ); match &self.payload { QlPayload::PairRequest(request) => request.encode_into(&mut out), QlPayload::Unpair(unpair) => unpair.encode_into(&mut out), @@ -175,7 +179,9 @@ fn parse_payload( payload: B, ) -> Result, WireError> { match kind { - RecordKind::PairRequest => Ok(QlPayloadRef::PairRequest(PairRequestRecord::parse(payload)?)), + RecordKind::PairRequest => Ok(QlPayloadRef::PairRequest(PairRequestRecord::parse( + payload, + )?)), RecordKind::Unpair => Ok(QlPayloadRef::Unpair(Unpair::decode(&payload[..])?)), RecordKind::Hello => Ok(QlPayloadRef::Hello(handshake::Hello::decode(&payload[..])?)), RecordKind::HelloReply => Ok(QlPayloadRef::HelloReply(HelloReply::decode(&payload[..])?)), diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index e7becfb6..cdfe41a4 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -413,6 +413,40 @@ fn session_record_supports_empty_fin_stream_data_and_empty_ping() { assert_eq!(decoded, record); } +#[test] +fn session_record_builder_writes_frames_without_temp_record_allocation() { + let mut builder = SessionRecordBuilder::new(RecordSeq(55), 12); + let stream = StreamData { + stream_id: StreamId(3), + offset: 7, + fin: true, + bytes: b"hello", + }; + assert!(builder.push_stream_data(&stream)); + assert_eq!(builder.remaining_capacity(), 0); + assert!(!builder.push_ping()); + + let close = SessionCloseBody { + code: CloseCode::PROTOCOL, + }; + assert!(!builder.push_close(&close)); + + let encoded = builder.into_plaintext(); + let decoded = SessionRecord::decode(&encoded).unwrap(); + assert_eq!( + decoded, + SessionRecord { + seq: RecordSeq(55), + frames: vec![SessionFrame::StreamData(StreamData { + stream_id: StreamId(3), + offset: 7, + fin: true, + bytes: b"hello".to_vec(), + })], + } + ); +} + #[test] fn protocol_record_size_breakdown() { fn meta(id: u32) -> ControlMeta { diff --git a/ql-wire/src/unpair/crypto.rs b/ql-wire/src/unpair/crypto.rs index 6062ca2c..20451681 100644 --- a/ql-wire/src/unpair/crypto.rs +++ b/ql-wire/src/unpair/crypto.rs @@ -1,10 +1,9 @@ +use super::Unpair; use crate::{ ControlMeta, MlDsaPublicKey, QlCrypto, QlHeader, QlIdentity, QlPayload, QlRecord, WireError, XID, }; -use super::Unpair; - pub fn build_unpair( crypto: &impl QlCrypto, identity: &QlIdentity, From 8dd4eb6edea2d8de36b4ff0df05b90387892d60b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 15:56:14 -0400 Subject: [PATCH 045/304] too many consts --- ql-wire/src/encrypted/ack.rs | 5 ----- ql-wire/src/encrypted/builder.rs | 15 +++++++-------- ql-wire/src/encrypted/close.rs | 1 - ql-wire/src/encrypted/mod.rs | 22 +++++++++++----------- ql-wire/src/encrypted/stream_close.rs | 6 ------ ql-wire/src/encrypted/stream_data.rs | 6 ------ ql-wire/src/encrypted/stream_window.rs | 1 - 7 files changed, 18 insertions(+), 38 deletions(-) diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 7944b22f..46d068fd 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -14,7 +14,6 @@ pub struct RecordAckRange { } impl RecordAck { - pub const FRAME_OVERHEAD: usize = std::mem::size_of::() + size_of::(); pub const RANGE_ENCODED_LEN: usize = size_of::() + size_of::(); pub fn decode(bytes: &[u8]) -> Result { @@ -50,10 +49,6 @@ impl RecordAck { self.ranges.len() * Self::RANGE_ENCODED_LEN } - pub fn frame_encoded_len(&self) -> usize { - Self::FRAME_OVERHEAD + self.encoded_len() - } - pub fn encode_into(&self, out: &mut Vec) { for range in &self.ranges { codec::push_u64(out, range.start); diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index dffe25db..ba7d6447 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,6 +1,6 @@ use super::{ push_variable_len, RecordAck, RecordSeq, SessionCloseBody, SessionFrame, SessionFrameKind, - StreamClose, StreamData, StreamWindow, + StreamClose, StreamData, StreamWindow, SIZE_LEN, }; use crate::{ codec, encrypted_message::EncryptedMessage, Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, @@ -15,7 +15,6 @@ pub struct SessionRecordBuilder { impl SessionRecordBuilder { pub const HEADER_LEN: usize = std::mem::size_of::(); - pub const PING_ENCODED_LEN: usize = std::mem::size_of::(); pub fn new(seq: RecordSeq, max_capacity: usize) -> Self { let max_capacity = max_capacity.max(Self::HEADER_LEN); @@ -56,7 +55,7 @@ impl SessionRecordBuilder { } pub fn push_ping(&mut self) -> bool { - if !self.can_push_len(Self::PING_ENCODED_LEN) { + if !self.can_push_len(1) { return false; } self.bytes.push(SessionFrameKind::Ping as u8); @@ -64,7 +63,7 @@ impl SessionRecordBuilder { } pub fn push_ack(&mut self, ack: &RecordAck) -> bool { - if !self.can_push_len(ack.frame_encoded_len()) { + if !self.can_push_len(1 + SIZE_LEN + ack.encoded_len()) { return false; } self.bytes.push(SessionFrameKind::Ack as u8); @@ -74,7 +73,7 @@ impl SessionRecordBuilder { } pub fn push_stream_data>(&mut self, frame: &StreamData) -> bool { - if !self.can_push_len(frame.frame_encoded_len()) { + if !self.can_push_len(1 + SIZE_LEN + frame.encoded_len()) { return false; } self.bytes.push(SessionFrameKind::StreamData as u8); @@ -84,7 +83,7 @@ impl SessionRecordBuilder { } pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { - if !self.can_push_len(StreamWindow::FRAME_ENCODED_LEN) { + if !self.can_push_len(1 + StreamWindow::WIRE_SIZE) { return false; } self.bytes.push(SessionFrameKind::StreamWindow as u8); @@ -93,7 +92,7 @@ impl SessionRecordBuilder { } pub fn push_stream_close>(&mut self, frame: &StreamClose) -> bool { - if !self.can_push_len(frame.frame_encoded_len()) { + if !self.can_push_len(1 + SIZE_LEN + frame.encoded_len()) { return false; } self.bytes.push(SessionFrameKind::StreamClose as u8); @@ -103,7 +102,7 @@ impl SessionRecordBuilder { } pub fn push_close(&mut self, close: &SessionCloseBody) -> bool { - if !self.can_push_len(SessionCloseBody::FRAME_ENCODED_LEN) { + if !self.can_push_len(1 + SessionCloseBody::WIRE_SIZE) { return false; } self.bytes.push(SessionFrameKind::Close as u8); diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index ab43e585..4702566a 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -14,7 +14,6 @@ pub struct SessionCloseBody { impl SessionCloseBody { pub const WIRE_SIZE: usize = size_of::(); - pub const FRAME_ENCODED_LEN: usize = std::mem::size_of::() + Self::WIRE_SIZE; pub fn encode_into(&self, out: &mut Vec) { codec::push_u16(out, self.code.0); diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index cb38e2b5..3a0d40d5 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -47,6 +47,8 @@ pub type SessionFrameVec = SessionFrame>; pub type StreamDataVec = StreamData>; pub type StreamCloseVec = StreamClose>; +pub(crate) const SIZE_LEN: usize = size_of::(); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub(crate) enum SessionFrameKind { @@ -121,13 +123,13 @@ impl SessionRecord { impl> SessionFrame { pub fn encoded_len(&self) -> usize { - match self { - Self::Ping => SessionRecordBuilder::PING_ENCODED_LEN, - Self::Ack(frame) => frame.frame_encoded_len(), - Self::StreamData(frame) => frame.frame_encoded_len(), - Self::StreamWindow(_) => StreamWindow::FRAME_ENCODED_LEN, - Self::StreamClose(frame) => frame.frame_encoded_len(), - Self::Close(_) => SessionCloseBody::FRAME_ENCODED_LEN, + 1 + match self { + Self::Ping => 0, + Self::Ack(frame) => SIZE_LEN + frame.encoded_len(), + Self::StreamData(frame) => SIZE_LEN + frame.encoded_len(), + Self::StreamWindow(_) => StreamWindow::WIRE_SIZE, + Self::StreamClose(frame) => SIZE_LEN + frame.encoded_len(), + Self::Close(_) => SessionCloseBody::WIRE_SIZE, } } @@ -261,12 +263,10 @@ fn push_variable_len(out: &mut Vec, len: usize) { } fn split_variable_frame(bytes: &[u8]) -> Result<(&[u8], &[u8]), WireError> { - const LEN_SIZE: usize = size_of::(); - - if bytes.len() < LEN_SIZE { + if bytes.len() < SIZE_LEN { return Err(WireError::InvalidPayload); } let len = u16::from_le_bytes([bytes[0], bytes[1]]) as usize; - let bytes = &bytes[LEN_SIZE..]; + let bytes = &bytes[SIZE_LEN..]; bytes.split_at_checked(len).ok_or(WireError::InvalidPayload) } diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 44843b8c..3ae539a3 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -58,8 +58,6 @@ pub struct StreamClose { impl StreamClose { pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); - pub const FRAME_OVERHEAD: usize = - std::mem::size_of::() + size_of::() + Self::MIN_WIRE_SIZE; } impl StreamClose { @@ -93,10 +91,6 @@ impl> StreamClose { Self::MIN_WIRE_SIZE + self.payload.as_ref().len() } - pub fn frame_encoded_len(&self) -> usize { - Self::FRAME_OVERHEAD + self.payload.as_ref().len() - } - pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); codec::push_u8(out, self.target.to_wire()); diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 4b309b7a..6e183d65 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -14,8 +14,6 @@ pub struct StreamData { impl StreamData { pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); - pub const FRAME_OVERHEAD: usize = - std::mem::size_of::() + size_of::() + Self::MIN_WIRE_SIZE; } impl StreamData { @@ -49,10 +47,6 @@ impl> StreamData { Self::MIN_WIRE_SIZE + self.bytes.as_ref().len() } - pub fn frame_encoded_len(&self) -> usize { - Self::FRAME_OVERHEAD + self.bytes.as_ref().len() - } - pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); codec::push_u64(out, self.offset); diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index 33d9bcf6..d03f0d02 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -12,7 +12,6 @@ pub struct StreamWindow { impl StreamWindow { pub const WIRE_SIZE: usize = size_of::() + size_of::(); - pub const FRAME_ENCODED_LEN: usize = std::mem::size_of::() + Self::WIRE_SIZE; pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); From f03032d4c3536c2c62e03457b0be11a3cb2d2f65 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 19:43:24 -0400 Subject: [PATCH 046/304] use byteslice instead of asref --- ql-wire/src/encrypted/mod.rs | 5 ++++- ql-wire/src/encrypted/stream_close.rs | 4 ++-- ql-wire/src/encrypted/stream_data.rs | 4 ++-- ql-wire/src/encrypted_message.rs | 4 ++-- ql-wire/src/handshake/mod.rs | 2 +- ql-wire/src/pair/mod.rs | 2 +- 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 3a0d40d5..53c95533 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,7 +1,8 @@ use std::mem::size_of; use crate::{ - codec, encrypted_message::EncryptedMessage, QlCrypto, QlHeader, QlRecord, SessionKey, WireError, + codec, encrypted_message::EncryptedMessage, ByteSlice, QlCrypto, QlHeader, QlRecord, + SessionKey, WireError, }; mod ack; @@ -161,7 +162,9 @@ impl> SessionFrame { } } } +} +impl SessionFrame { pub fn into_owned(self) -> SessionFrameVec { match self { Self::Ping => SessionFrame::Ping, diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 3ae539a3..9e816573 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -75,13 +75,13 @@ impl StreamClose { impl StreamClose { pub fn into_owned(self) -> StreamClose> where - B: AsRef<[u8]>, + B: ByteSlice, { StreamClose { stream_id: self.stream_id, target: self.target, code: self.code, - payload: self.payload.as_ref().to_vec(), + payload: self.payload.to_vec(), } } } diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 6e183d65..9dfade23 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -31,13 +31,13 @@ impl StreamData { impl StreamData { pub fn into_owned(self) -> StreamData> where - B: AsRef<[u8]>, + B: ByteSlice, { StreamData { stream_id: self.stream_id, offset: self.offset, fin: self.fin, - bytes: self.bytes.as_ref().to_vec(), + bytes: self.bytes.to_vec(), } } } diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index 5b40a0f6..e0b821f1 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -15,12 +15,12 @@ impl EncryptedMessage { pub fn into_owned(self) -> EncryptedMessage> where - B: AsRef<[u8]>, + B: ByteSlice, { EncryptedMessage { nonce: self.nonce, auth: self.auth, - ciphertext: self.ciphertext.as_ref().to_vec(), + ciphertext: self.ciphertext.to_vec(), } } } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 1cd4ba57..5aca8b82 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -116,7 +116,7 @@ impl Ready { impl Ready { pub fn into_owned(self) -> Ready> where - B: AsRef<[u8]>, + B: ByteSlice, { Ready { encrypted: self.encrypted.into_owned(), diff --git a/ql-wire/src/pair/mod.rs b/ql-wire/src/pair/mod.rs index 6ba3ed60..b0d956d0 100644 --- a/ql-wire/src/pair/mod.rs +++ b/ql-wire/src/pair/mod.rs @@ -34,7 +34,7 @@ impl PairRequestRecord { impl PairRequestRecord { pub fn into_owned(self) -> PairRequestRecord> where - B: AsRef<[u8]>, + B: ByteSlice, { PairRequestRecord { kem_ct: self.kem_ct, From f83fd649b8b2a51a825f6ce9a2727f7d831b7148 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 23:12:14 -0400 Subject: [PATCH 047/304] ql-wire: bytechunks --- ql-wire/src/bytes.rs | 133 +++++++++++++++++++++++++- ql-wire/src/encrypted/builder.rs | 10 +- ql-wire/src/encrypted/mod.rs | 6 +- ql-wire/src/encrypted/stream_close.rs | 10 +- ql-wire/src/encrypted/stream_data.rs | 10 +- ql-wire/src/tests.rs | 37 +++++++ 6 files changed, 188 insertions(+), 18 deletions(-) diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index 19795fab..b6a31dcd 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -1,4 +1,8 @@ -use core::ops::{Deref, DerefMut}; +use core::{ + iter::{once, Chain, Once}, + ops::{Deref, DerefMut}, +}; +use std::collections::VecDeque; /// A mutable or immutable byte slice owner used by the wire parser. pub trait ByteSlice: Deref + Sized { @@ -11,8 +15,110 @@ pub trait ByteSlice: Deref + Sized { /// A mutable reference to bytes. pub trait ByteSliceMut: ByteSlice + DerefMut {} +/// A byte container that can be encoded from one or more chunks. +pub trait ByteChunks { + type Chunks<'a>: Iterator + where + Self: 'a; + + fn len(&self) -> usize; + + fn chunks(&self) -> Self::Chunks<'_>; +} + impl ByteSliceMut for B where B: ByteSlice + DerefMut {} +impl ByteChunks for &T { + type Chunks<'a> + = T::Chunks<'a> + where + Self: 'a; + + fn len(&self) -> usize { + (*self).len() + } + + fn chunks(&self) -> Self::Chunks<'_> { + (*self).chunks() + } +} + +impl ByteChunks for &mut T { + type Chunks<'a> + = T::Chunks<'a> + where + Self: 'a; + + fn len(&self) -> usize { + (**self).len() + } + + fn chunks(&self) -> Self::Chunks<'_> { + (**self).chunks() + } +} + +impl ByteChunks for [u8] { + type Chunks<'a> + = Once<&'a [u8]> + where + Self: 'a; + + fn len(&self) -> usize { + <[u8]>::len(self) + } + + fn chunks(&self) -> Self::Chunks<'_> { + once(self) + } +} + +impl ByteChunks for [u8; N] { + type Chunks<'a> + = Once<&'a [u8]> + where + Self: 'a; + + fn len(&self) -> usize { + N + } + + fn chunks(&self) -> Self::Chunks<'_> { + once(self.as_slice()) + } +} + +impl ByteChunks for Vec { + type Chunks<'a> + = Once<&'a [u8]> + where + Self: 'a; + + fn len(&self) -> usize { + Vec::len(self) + } + + fn chunks(&self) -> Self::Chunks<'_> { + once(self.as_slice()) + } +} + +impl ByteChunks for VecDeque { + type Chunks<'a> + = Chain, Once<&'a [u8]>> + where + Self: 'a; + + fn len(&self) -> usize { + VecDeque::len(self) + } + + fn chunks(&self) -> Self::Chunks<'_> { + let (first, second) = self.as_slices(); + once(first).chain(once(second)) + } +} + impl ByteSlice for &[u8] { #[inline] fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { @@ -37,7 +143,9 @@ impl ByteSlice for &mut [u8] { #[cfg(test)] mod tests { - use super::{ByteSlice, ByteSliceMut}; + use std::collections::VecDeque; + + use super::{ByteChunks, ByteSlice, ByteSliceMut}; #[test] fn shared_slice_split_at() { @@ -68,4 +176,25 @@ mod tests { let bytes: &[u8] = b"abcdef"; assert!(ByteSlice::split_at(bytes, 7).is_err()); } + + #[test] + fn slice_byte_chunks_are_contiguous() { + let bytes: &[u8] = b"abcdef"; + let chunks = ByteChunks::chunks(&bytes).collect::>(); + assert_eq!(bytes.len(), 6); + assert_eq!(chunks, vec![b"abcdef".as_slice()]); + } + + #[test] + fn vec_deque_byte_chunks_preserve_split_storage() { + let mut bytes = VecDeque::with_capacity(8); + bytes.extend(b"abcd".iter().copied()); + bytes.drain(..2); + bytes.extend(b"efgh".iter().copied()); + + let chunks = ByteChunks::chunks(&bytes).collect::>(); + assert_eq!(bytes.len(), 6); + assert_eq!(chunks.concat(), b"cdefgh"); + assert!(chunks.len() >= 1); + } } diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index ba7d6447..df106391 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -3,8 +3,8 @@ use super::{ StreamClose, StreamData, StreamWindow, SIZE_LEN, }; use crate::{ - codec, encrypted_message::EncryptedMessage, Nonce, QlCrypto, QlHeader, QlPayload, QlRecord, - SessionKey, + codec, encrypted_message::EncryptedMessage, ByteChunks, Nonce, QlCrypto, QlHeader, QlPayload, + QlRecord, SessionKey, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -72,7 +72,7 @@ impl SessionRecordBuilder { true } - pub fn push_stream_data>(&mut self, frame: &StreamData) -> bool { + pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { if !self.can_push_len(1 + SIZE_LEN + frame.encoded_len()) { return false; } @@ -91,7 +91,7 @@ impl SessionRecordBuilder { true } - pub fn push_stream_close>(&mut self, frame: &StreamClose) -> bool { + pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { if !self.can_push_len(1 + SIZE_LEN + frame.encoded_len()) { return false; } @@ -110,7 +110,7 @@ impl SessionRecordBuilder { true } - pub fn push_frame>(&mut self, frame: &SessionFrame) -> bool { + pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { match frame { SessionFrame::Ping => self.push_ping(), SessionFrame::Ack(frame) => self.push_ack(frame), diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 53c95533..362bcced 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,8 +1,8 @@ use std::mem::size_of; use crate::{ - codec, encrypted_message::EncryptedMessage, ByteSlice, QlCrypto, QlHeader, QlRecord, - SessionKey, WireError, + codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, QlCrypto, QlHeader, + QlRecord, SessionKey, WireError, }; mod ack; @@ -122,7 +122,7 @@ impl SessionRecord { } } -impl> SessionFrame { +impl SessionFrame { pub fn encoded_len(&self) -> usize { 1 + match self { Self::Ping => 0, diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 9e816573..396e5fa9 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -1,7 +1,7 @@ use std::mem::size_of; use super::StreamId; -use crate::{codec, ByteSlice, WireError}; +use crate::{codec, ByteChunks, ByteSlice, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] @@ -86,15 +86,17 @@ impl StreamClose { } } -impl> StreamClose { +impl StreamClose { pub fn encoded_len(&self) -> usize { - Self::MIN_WIRE_SIZE + self.payload.as_ref().len() + Self::MIN_WIRE_SIZE + self.payload.len() } pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); codec::push_u8(out, self.target.to_wire()); codec::push_u16(out, self.code.0); - codec::push_bytes(out, self.payload.as_ref()); + for chunk in self.payload.chunks() { + codec::push_bytes(out, chunk); + } } } diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 9dfade23..e7253cc3 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -1,7 +1,7 @@ use std::mem::size_of; use super::StreamId; -use crate::{codec, ByteSlice, WireError}; +use crate::{codec, ByteChunks, ByteSlice, WireError}; /// carries bytes for a stream and may finish that sending direction. #[derive(Debug, Clone, PartialEq, Eq)] @@ -42,15 +42,17 @@ impl StreamData { } } -impl> StreamData { +impl StreamData { pub fn encoded_len(&self) -> usize { - Self::MIN_WIRE_SIZE + self.bytes.as_ref().len() + Self::MIN_WIRE_SIZE + self.bytes.len() } pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); codec::push_u64(out, self.offset); codec::push_u8(out, u8::from(self.fin)); - codec::push_bytes(out, self.bytes.as_ref()); + for chunk in self.bytes.chunks() { + codec::push_bytes(out, chunk); + } } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index cdfe41a4..be7c90fc 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -447,6 +447,43 @@ fn session_record_builder_writes_frames_without_temp_record_allocation() { ); } +#[test] +fn session_record_builder_encodes_borrowed_vec_deque_stream_data() { + use std::collections::VecDeque; + + let mut payload = VecDeque::with_capacity(8); + payload.extend(b"abcd".iter().copied()); + payload.drain(..2); + payload.extend(b"efgh".iter().copied()); + + let mut builder = SessionRecordBuilder::new( + RecordSeq(56), + 1 + std::mem::size_of::() + StreamData::<&VecDeque>::MIN_WIRE_SIZE + payload.len(), + ); + let stream = StreamData { + stream_id: StreamId(4), + offset: 9, + fin: false, + bytes: &payload, + }; + assert!(builder.push_stream_data(&stream)); + + let encoded = builder.into_plaintext(); + let decoded = SessionRecord::decode(&encoded).unwrap(); + assert_eq!( + decoded, + SessionRecord { + seq: RecordSeq(56), + frames: vec![SessionFrame::StreamData(StreamData { + stream_id: StreamId(4), + offset: 9, + fin: false, + bytes: b"cdefgh".to_vec(), + })], + } + ); +} + #[test] fn protocol_record_size_breakdown() { fn meta(id: u32) -> ControlMeta { From 5deb2b439ebd094c2972a1f0e160e377aed7fff4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 28 Mar 2026 23:39:09 -0400 Subject: [PATCH 048/304] ql-wire: capped byte chunks --- ql-wire/src/bytes.rs | 83 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index b6a31dcd..adb4b642 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -141,11 +141,62 @@ impl ByteSlice for &mut [u8] { } } +#[derive(Debug, Clone, Copy)] +pub struct CappedByteChunks { + pub inner: T, + pub limit: usize, +} + +pub struct CappedByteChunksIter { + inner: I, + remaining: usize, +} + +impl<'a, I> Iterator for CappedByteChunksIter +where + I: Iterator, +{ + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + while self.remaining > 0 { + let chunk = self.inner.next()?; + if chunk.is_empty() { + continue; + } + + let len = chunk.len().min(self.remaining); + self.remaining -= len; + return Some(&chunk[..len]); + } + + None + } +} + +impl ByteChunks for CappedByteChunks { + type Chunks<'a> + = CappedByteChunksIter> + where + Self: 'a; + + fn len(&self) -> usize { + self.inner.len().min(self.limit) + } + + fn chunks(&self) -> Self::Chunks<'_> { + CappedByteChunksIter { + inner: self.inner.chunks(), + remaining: self.len(), + } + } +} + #[cfg(test)] mod tests { use std::collections::VecDeque; - use super::{ByteChunks, ByteSlice, ByteSliceMut}; + use super::{ByteChunks, ByteSlice, ByteSliceMut, CappedByteChunks}; #[test] fn shared_slice_split_at() { @@ -197,4 +248,34 @@ mod tests { assert_eq!(chunks.concat(), b"cdefgh"); assert!(chunks.len() >= 1); } + + #[test] + fn capped_byte_chunks_truncate_slice() { + let bytes: &[u8] = b"abcdef"; + let capped = CappedByteChunks { + inner: bytes, + limit: 4, + }; + + let chunks = capped.chunks().collect::>(); + assert_eq!(capped.len(), 4); + assert_eq!(chunks, vec![b"abcd".as_slice()]); + } + + #[test] + fn capped_byte_chunks_truncate_borrowed_vec_deque() { + let mut bytes = VecDeque::with_capacity(8); + bytes.extend(b"abcd".iter().copied()); + bytes.drain(..2); + bytes.extend(b"efgh".iter().copied()); + + let capped = CappedByteChunks { + inner: &bytes, + limit: 4, + }; + + let chunks = capped.chunks().collect::>(); + assert_eq!(capped.len(), 4); + assert_eq!(chunks.concat(), b"cdef"); + } } From 25b53943f2648a5fe499cd02ba6fee52b73f4929 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 29 Mar 2026 00:25:00 -0400 Subject: [PATCH 049/304] ql-fsm: port --- ql-fsm/src/implementation/fsm.rs | 21 +- ql-fsm/src/implementation/handshake.rs | 43 +- ql-fsm/src/implementation/peer.rs | 12 +- ql-fsm/src/lib.rs | 6 +- ql-fsm/src/session/mod.rs | 518 ++++++------------ ql-fsm/src/session/state.rs | 76 +-- .../session/{reassembly.rs => stream_rx.rs} | 154 +++--- ql-fsm/src/session/stream_tx.rs | 337 ++++++++++++ ql-fsm/src/session/tests.rs | 15 +- ql-fsm/src/state.rs | 2 +- ql-fsm/src/tests/mod.rs | 10 +- ql-wire/src/bytes.rs | 95 +++- 12 files changed, 746 insertions(+), 543 deletions(-) rename ql-fsm/src/session/{reassembly.rs => stream_rx.rs} (77%) create mode 100644 ql-fsm/src/session/stream_tx.rs diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index c16e109a..5f2146c0 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -37,8 +37,8 @@ pub fn receive( } match payload { - QlPayloadRef::PairRequest(mut request) => { - super::handle_pair(fsm, crypto, &header, &mut request)?; + QlPayloadRef::PairRequest(request) => { + super::handle_pair(fsm, crypto, &header, request)?; } QlPayloadRef::Unpair(unpair) => { super::handle_unpair(fsm, crypto, &header, &unpair)?; @@ -52,17 +52,17 @@ pub fn receive( QlPayloadRef::Confirm(confirm) => { super::handle_confirm(fsm, crypto, &header, &confirm)?; } - QlPayloadRef::Ready(mut ready) => { - super::handle_ready(fsm, crypto, &header, &mut ready)?; + QlPayloadRef::Ready(ready) => { + super::handle_ready(fsm, crypto, &header, ready)?; } - QlPayloadRef::Session(mut encrypted) => { + QlPayloadRef::Session(encrypted) => { let Some((_, session_key)) = super::peer_session(fsm) else { return Err(QlFsmError::NoSession); }; - let record = wire::decrypt_record(crypto, &header, &mut encrypted, &session_key)?; - let record = wire::SessionRecord::from_wire(&record)?; + let plaintext = wire::decrypt_record(crypto, &header, encrypted, &session_key)?; + let (seq, frames) = wire::SessionRecord::parse(plaintext.as_ref())?; let mut session_closed = false; - fsm.session.receive(fsm.state.now.instant, record, { + fsm.session.receive(fsm.state.now.instant, seq, frames, { let session_events = &mut fsm.state.session_events; |event| { session_closed |= super::forward_session_event(session_events, event); @@ -127,15 +127,14 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option, + hello: &Hello, ) -> Result<(), QlFsmError> { let action = { let Some(entry) = fsm.peer.as_ref() else { @@ -100,7 +99,7 @@ pub fn handle_hello( enqueue_handshake(fsm, header.sender, QlPayload::HelloReply(reply)); } HelloAction::StartResponder => { - if is_replayed_control(fsm, header.sender, wire::ControlMeta::from_wire(hello.meta)) { + if is_replayed_control(fsm, header.sender, hello.meta) { return Ok(()); } @@ -120,7 +119,7 @@ pub fn handle_hello( let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); if let Some(entry) = fsm.peer.as_mut() { entry.session = ConnectionState::Responder { - hello: wire::Hello::from_wire(hello), + hello: hello.clone(), reply: reply.clone(), deadline, stage: HandshakeResponder::WaitingConfirm { @@ -142,7 +141,7 @@ pub fn handle_hello_reply( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - reply: &RefMut<'_, HelloReplyWire>, + reply: &HelloReply, ) -> Result<(), QlFsmError> { let action = { let Some(entry) = fsm.peer.as_ref() else { @@ -198,7 +197,7 @@ pub fn handle_hello_reply( fsm.state.now.unix_secs, )?; - if is_replayed_control(fsm, header.sender, wire::ControlMeta::from_wire(reply.meta)) { + if is_replayed_control(fsm, header.sender, reply.meta) { return Ok(()); } @@ -209,7 +208,7 @@ pub fn handle_hello_reply( hello, deadline, stage: HandshakeInitiator::WaitingReady { - reply: wire::HelloReply::from_wire(reply), + reply: reply.clone(), confirm: confirm.clone(), session_key, retry_count: 0, @@ -228,7 +227,7 @@ pub fn handle_confirm( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - confirm: &RefMut<'_, ConfirmWire>, + confirm: &Confirm, ) -> Result<(), QlFsmError> { if let Some(ready) = recent_ready_resend(fsm, crypto, header.sender, confirm) { enqueue_handshake(fsm, header.sender, QlPayload::Ready(ready)); @@ -265,11 +264,7 @@ pub fn handle_confirm( let (hello, reply, deadline, session_key) = outcome?; - if is_replayed_control( - fsm, - header.sender, - wire::ControlMeta::from_wire(confirm.meta), - ) { + if is_replayed_control(fsm, header.sender, confirm.meta) { return Ok(()); } @@ -305,7 +300,7 @@ pub fn handle_ready( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - ready: &mut RefMut<'_, EncryptedMessageWire>, + ready: Ready<&mut [u8]>, ) -> Result<(), QlFsmError> { let session_key = { let Some(entry) = fsm.peer.as_ref() else { @@ -500,8 +495,8 @@ fn recent_ready_resend( fsm: &QlFsm, crypto: &impl QlCrypto, peer: XID, - confirm: &RefMut<'_, ConfirmWire>, -) -> Option { + confirm: &Confirm, +) -> Option>> { let entry = fsm.peer.as_ref()?; let ConnectionState::Connected { recent_ready: Some(recent_ready), @@ -641,21 +636,21 @@ fn responder_retry_at(stage: &HandshakeResponder) -> Option { } } -fn same_hello_ref(stored: &Hello, incoming: &RefMut<'_, HelloWire>) -> bool { - stored.meta.control_id.0 == incoming.meta.control_id.get() && stored.nonce.0 == incoming.nonce +fn same_hello_ref(stored: &Hello, incoming: &Hello) -> bool { + stored.meta.control_id == incoming.meta.control_id && stored.nonce == incoming.nonce } -fn same_reply_ref(stored: &HelloReply, incoming: &RefMut<'_, HelloReplyWire>) -> bool { - stored.meta.control_id.0 == incoming.meta.control_id.get() && stored.nonce.0 == incoming.nonce +fn same_reply_ref(stored: &HelloReply, incoming: &HelloReply) -> bool { + stored.meta.control_id == incoming.meta.control_id && stored.nonce == incoming.nonce } fn peer_hello_wins_ref( local_hello: &Hello, local_sender: XID, - peer_hello: &RefMut<'_, HelloWire>, + peer_hello: &Hello, peer_sender: XID, ) -> bool { - match peer_hello.nonce.cmp(&local_hello.nonce.0) { + match peer_hello.nonce.0.cmp(&local_hello.nonce.0) { Ordering::Less => true, Ordering::Greater => false, Ordering::Equal => peer_sender.0.cmp(&local_sender.0) == Ordering::Less, diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs index bd2d0785..b58659f2 100644 --- a/ql-fsm/src/implementation/peer.rs +++ b/ql-fsm/src/implementation/peer.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, PairRequestRecordWire, QlCrypto, QlHeader, RefMut, UnpairWire}; +use ql_wire::{self as wire, PairRequestRecord, QlCrypto, QlHeader, Unpair}; use super::{ clear_bound_peer, emit_peer_status, handshake, is_replayed_control, next_control_meta, @@ -36,7 +36,7 @@ pub fn handle_pair( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - request: &mut RefMut<'_, PairRequestRecordWire>, + request: PairRequestRecord<&mut [u8]>, ) -> Result<(), QlFsmError> { let payload = wire::decrypt_pair_request( crypto, @@ -67,7 +67,7 @@ pub fn handle_unpair( fsm: &mut QlFsm, crypto: &impl QlCrypto, header: &QlHeader, - unpair: &RefMut<'_, UnpairWire>, + unpair: &Unpair, ) -> Result<(), QlFsmError> { let Some(entry) = fsm.peer.as_ref() else { return Ok(()); @@ -80,11 +80,7 @@ pub fn handle_unpair( unpair, fsm.state.now.unix_secs, )?; - if is_replayed_control( - fsm, - header.sender, - wire::ControlMeta::from_wire(unpair.meta), - ) { + if is_replayed_control(fsm, header.sender, unpair.meta) { return Ok(()); } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 1860990c..e9980a5c 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -31,9 +31,9 @@ use std::time::{Duration, Instant}; pub use error::QlFsmError; use ql_wire::{ CloseCode, CloseTarget, MlDsaPublicKey, MlKemPublicKey, QlCrypto, QlIdentity, QlRecord, - SessionCloseBody, StreamClose, StreamId, XID, + SessionCloseBody, StreamCloseVec, StreamId, XID, }; -pub use session::reassembly::StreamReadIter; +pub use session::stream_rx::StreamReadIter; use crate::{ replay_cache::ReplayCache, @@ -102,7 +102,7 @@ pub enum QlSessionEvent { /// the peer finished writing this stream Finished(StreamId), /// a stream was closed - Closed(StreamClose), + Closed(StreamCloseVec), /// local writes on this stream are closed WritableClosed(StreamId), /// the peer requested unpairing diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index f157358f..bbb32fde 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,5 +1,6 @@ -pub(crate) mod reassembly; pub(crate) mod state; +pub(crate) mod stream_rx; +pub(crate) mod stream_tx; #[cfg(test)] mod tests; @@ -8,16 +9,19 @@ use std::time::{Duration, Instant}; use indexmap::map::Entry; use ql_wire::{ - CloseCode, CloseTarget, RecordAck, RecordSeq, SessionCloseBody, SessionFrame, SessionRecord, - StreamClose, StreamData, StreamId, StreamWindow, + CloseCode, CloseTarget, RecordAck, RecordSeq, SessionCloseBody, SessionFrame, + SessionRecordBuilder, StreamClose, StreamCloseVec, StreamData, StreamId, StreamWindow, + WireError, }; use self::{ - reassembly::{StreamAssemblerError, StreamReadIter}, state::{ AckState, InboundState, OutboundRecord, OutboundState, ReceiveInsertOutcome, - ReceivedRecords, ReliableFrame, SessionFsmState, StreamParity, StreamRole, StreamState, + ReceivedRecords, ReliableFrame, SessionFsmState, StreamDataManifest, StreamParity, + StreamRole, StreamState, }, + stream_rx::{StreamReadIter, StreamRxError}, + stream_tx::StreamTxRange, }; pub(crate) const SESSION_RECORD_TRACKED_WINDOW: u64 = 256; @@ -55,7 +59,7 @@ pub enum SessionEvent { Readable(StreamId), Writable(StreamId), Finished(StreamId), - Closed(StreamClose), + Closed(StreamCloseVec), WritableClosed(StreamId), SessionClosed(SessionCloseBody), } @@ -119,7 +123,6 @@ impl SessionFsm { stream_id, StreamState::new( StreamRole::Initiator, - self.config.stream_send_buffer_size, self.config.stream_receive_buffer_size, ), ); @@ -144,7 +147,7 @@ impl SessionFsm { let accepted = bytes .len() .min(stream.send_capacity(self.config.stream_send_buffer_size)); - stream.send_buf.extend(bytes[..accepted].iter().copied()); + stream.tx.append(&bytes[..accepted]); Ok(accepted) } @@ -158,6 +161,7 @@ impl SessionFsm { if !stream.is_writable() { return Err(StreamError::NotWritable); } + stream.tx.queue_fin(); stream.outbound_state = OutboundState::FinQueued; Ok(()) } @@ -194,7 +198,7 @@ impl SessionFsm { .streams .get(&stream_id) .ok_or(StreamError::MissingStream)?; - Ok(stream.recv.bytes()) + Ok(stream.rx.bytes()) } pub fn stream_read_commit( @@ -211,7 +215,7 @@ impl SessionFsm { return Err(StreamError::InvalidRead); } stream - .recv + .rx .consume(len) .map_err(|_| StreamError::InvalidRead)?; if stream.recv_limit() > stream.advertised_max_offset { @@ -236,37 +240,45 @@ impl SessionFsm { Ok(()) } - pub fn receive( + pub fn receive<'a, I>( &mut self, now: Instant, - record: SessionRecord, + seq: RecordSeq, + frames: I, mut emit: impl FnMut(SessionEvent), - ) { + ) where + I: IntoIterator, WireError>>, + { self.state.now = now; self.collect_timeouts(); - - let ack_eliciting = Self::record_is_ack_eliciting(&record); self.state.last_activity_at = self.state.now; self.state.last_inbound_at = self.state.now; - let out_of_order = match self.state.received_records.insert(record.seq) { - ReceiveInsertOutcome::Duplicate => { - if ack_eliciting { - self.schedule_ack(true); - } - return; - } - ReceiveInsertOutcome::New { out_of_order } => out_of_order, + let (duplicate, out_of_order) = match self.state.received_records.insert(seq) { + ReceiveInsertOutcome::Duplicate => (true, false), + ReceiveInsertOutcome::New { out_of_order } => (false, out_of_order), }; - if self.state.session_state == SessionState::Closed { - if ack_eliciting { - self.schedule_ack(true); + let closed = self.state.session_state == SessionState::Closed; + let mut ack_eliciting = false; + for frame in frames { + let frame = match frame { + Ok(frame) => frame, + Err(_) => { + self.fail_session( + SessionCloseBody { + code: CloseCode::PROTOCOL, + }, + &mut emit, + ); + return; + } + }; + ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); + if duplicate || closed { + continue; } - return; - } - for frame in record.frames { match frame { SessionFrame::Ping => {} SessionFrame::Ack(ack) => self.process_record_ack(ack, &mut emit), @@ -281,7 +293,10 @@ impl SessionFsm { } } SessionFrame::StreamClose(frame) => { - if self.handle_stream_close(frame, &mut emit).is_err() { + if self + .handle_stream_close(frame.into_owned(), &mut emit) + .is_err() + { return; } } @@ -293,7 +308,7 @@ impl SessionFsm { } if ack_eliciting { - self.schedule_ack(out_of_order); + self.schedule_ack(duplicate || closed || out_of_order); } } @@ -387,31 +402,24 @@ impl SessionFsm { pub fn has_pending_stream_work(&self) -> bool { self.state.streams.values().any(|stream| { - stream.pending_close.is_some() - || !stream.retransmit.is_empty() - || !stream.send_buf.is_empty() - || stream.pending_window - || matches!(stream.outbound_state, OutboundState::FinQueued) + stream.pending_close.is_some() || stream.pending_window || stream.tx.has_pending() }) } - pub fn take_next_write(&mut self, now: Instant) -> Option<(u64, SessionRecord)> { + pub fn take_next_write(&mut self, now: Instant) -> Option<(u64, SessionRecordBuilder)> { self.state.now = now; self.collect_timeouts(); - let built = self.build_next_record()?; + let (builder, outbound) = self.build_next_record()?; let write_id = self.state.next_write_id; self.state.next_write_id = self.state.next_write_id.wrapping_add(1); - self.state.outbound_records.insert(write_id, built.outbound); - Some((write_id, built.record)) + self.state.outbound_records.insert(write_id, outbound); + Some((write_id, builder)) } - fn build_next_record(&mut self) -> Option { + fn build_next_record(&mut self) -> Option<(SessionRecordBuilder, OutboundRecord)> { let seq = self.state.next_record_seq; - let mut record = SessionRecord { - seq, - frames: Vec::new(), - }; + let mut builder = SessionRecordBuilder::new(seq, self.config.record_size); let mut outbound = OutboundRecord { seq, reliable: Vec::new(), @@ -420,137 +428,50 @@ impl SessionFsm { window_updates: Vec::new(), sent_at: None, }; - let mut remaining = self.config.record_size.saturating_sub(8); if self.should_send_ack() { if let Some(ack) = self.state.received_records.ack() { - let frame = SessionFrame::Ack(ack); - if self.push_frame(&mut record, &mut remaining, frame, true) { + if builder.push_ack(&ack) { outbound.ack_included = true; self.state.ack_state = AckState::Idle; } } } - while let Some(close) = self.take_pending_session_close(remaining, record.frames.is_empty()) - { - let frame = SessionFrame::Close(close.clone()); - if !self.push_frame(&mut record, &mut remaining, frame, true) { - self.state.pending_control.close = Some(close); - break; + if let Some(close) = self.state.pending_control.close.clone() { + if builder.push_close(&close) { + self.state.pending_control.close = None; + outbound.reliable.push(ReliableFrame::Close(close)); } - outbound.reliable.push(ReliableFrame::Close(close)); } - while let Some(close) = - self.take_next_pending_stream_close(remaining, record.frames.is_empty()) - { - let frame = SessionFrame::StreamClose(close.clone()); - if !self.push_frame(&mut record, &mut remaining, frame, true) { - self.restore_stream_close(close); - break; - } - outbound.reliable.push(ReliableFrame::StreamClose(close)); - } + while self.push_next_pending_stream_close(&mut builder, &mut outbound) {} - if let Some(ping) = self.take_pending_ping(remaining, record.frames.is_empty()) { - if self.push_frame(&mut record, &mut remaining, ping, true) { - outbound.ping_included = true; - } else { - self.state.pending_control.ping = true; - } + if self.state.pending_control.ping && builder.push_ping() { + self.state.pending_control.ping = false; + outbound.ping_included = true; } - while let Some(window) = - self.take_next_pending_stream_window(remaining, record.frames.is_empty()) - { - let maximum_offset = window.maximum_offset; - let stream_id = window.stream_id; - if !self.push_frame( - &mut record, - &mut remaining, - SessionFrame::StreamWindow(window), - true, - ) { - if let Some(stream) = self.state.streams.get_mut(&stream_id) { - stream.pending_window = true; - } - break; - } - outbound.window_updates.push((stream_id, maximum_offset)); - } + while self.push_next_pending_stream_window(&mut builder, &mut outbound) {} - while let Some(frame) = - self.take_next_retransmit_stream_data(remaining, record.frames.is_empty()) - { - if !self.push_frame( - &mut record, - &mut remaining, - SessionFrame::StreamData(frame.clone()), - true, - ) { - self.restore_stream_data(frame); - break; - } - outbound.reliable.push(ReliableFrame::StreamData(frame)); - } + while self.push_next_stream_data(&mut builder, &mut outbound) {} - while let Some(frame) = - self.take_next_fresh_stream_data(remaining, record.frames.is_empty()) - { - if !self.push_frame( - &mut record, - &mut remaining, - SessionFrame::StreamData(frame.clone()), - true, - ) { - self.restore_stream_data(frame); - break; - } - outbound.reliable.push(ReliableFrame::StreamData(frame)); - } - - if record.frames.is_empty() { + if builder.is_empty() { return None; } self.state.next_record_seq = RecordSeq(self.state.next_record_seq.0.saturating_add(1)); - Some(BuiltRecord { record, outbound }) + Some((builder, outbound)) } - fn take_pending_session_close( + fn push_next_pending_stream_close( &mut self, - remaining: usize, - record_empty: bool, - ) -> Option { - let close = self.state.pending_control.close.clone()?; - let frame = SessionFrame::Close(close.clone()); - if !self.frame_fits(remaining, record_empty, &frame) { - return None; - } - self.state.pending_control.close.take() - } - - fn take_pending_ping(&mut self, remaining: usize, record_empty: bool) -> Option { - if !self.state.pending_control.ping { - return None; - } - let frame = SessionFrame::Ping; - if !self.frame_fits(remaining, record_empty, &frame) { - return None; - } - self.state.pending_control.ping = false; - Some(frame) - } - - fn take_next_pending_stream_close( - &mut self, - remaining: usize, - record_empty: bool, - ) -> Option { + builder: &mut SessionRecordBuilder, + outbound: &mut OutboundRecord, + ) -> bool { let len = self.state.streams.len(); if len == 0 { - return None; + return false; } let start = self.state.next_stream_index % len; @@ -559,30 +480,32 @@ impl SessionFsm { let Some((_, stream)) = self.state.streams.get_index(index) else { continue; }; - let Some(close) = stream.pending_close.clone() else { + let Some(close) = stream.pending_close.as_ref() else { continue; }; - let frame = SessionFrame::StreamClose(close.clone()); - if !self.frame_fits(remaining, record_empty, &frame) { + if !builder.push_stream_close(close) { continue; } let stream = self.state.streams.get_index_mut(index).unwrap().1; self.state.next_stream_index = (index + 1) % len; - return stream.pending_close.take().or(Some(close)); + outbound.reliable.push(ReliableFrame::StreamClose( + stream.pending_close.take().unwrap(), + )); + return true; } - None + false } - fn take_next_pending_stream_window( + fn push_next_pending_stream_window( &mut self, - remaining: usize, - record_empty: bool, - ) -> Option { + builder: &mut SessionRecordBuilder, + outbound: &mut OutboundRecord, + ) -> bool { let len = self.state.streams.len(); if len == 0 { - return None; + return false; } let start = self.state.next_stream_index % len; @@ -598,11 +521,7 @@ impl SessionFsm { stream_id, maximum_offset: stream.recv_limit(), }; - if !self.frame_fits( - remaining, - record_empty, - &SessionFrame::StreamWindow(frame.clone()), - ) { + if !builder.push_stream_window(&frame) { continue; } @@ -610,63 +529,26 @@ impl SessionFsm { stream.pending_window = false; stream.advertised_max_offset = frame.maximum_offset; self.state.next_stream_index = (index + 1) % len; - return Some(frame); + outbound + .window_updates + .push((stream_id, frame.maximum_offset)); + return true; } - None + false } - fn take_next_retransmit_stream_data( + fn push_next_stream_data( &mut self, - remaining: usize, - record_empty: bool, - ) -> Option { - let max_payload = self.max_stream_data_payload(remaining, record_empty)?; - let len = self.state.streams.len(); - if len == 0 { - return None; - } - - let start = self.state.next_stream_index % len; - for offset in 0..len { - let index = (start + offset) % len; - let Some((_, stream)) = self.state.streams.get_index(index) else { - continue; - }; - - if matches!(stream.outbound_state, OutboundState::Closed) { - let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); - while let Some(frame) = stream.retransmit.pop_front() { - stream.inflight_bytes = stream.inflight_bytes.saturating_sub(frame.bytes.len()); - } - continue; - } - - let Some(_) = stream.retransmit.front() else { - continue; - }; - let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); - let frame = stream.retransmit.pop_front().unwrap(); - let (head, tail) = Self::split_stream_data(frame, max_payload); - if let Some(tail) = tail { - stream.retransmit.push_front(tail); - } - self.state.next_stream_index = (index + 1) % len; - return Some(head); - } - - None - } - - fn take_next_fresh_stream_data( - &mut self, - remaining: usize, - record_empty: bool, - ) -> Option { - let max_payload = self.max_stream_data_payload(remaining, record_empty)?; + builder: &mut SessionRecordBuilder, + outbound: &mut OutboundRecord, + ) -> bool { + let Some(max_payload) = self.max_stream_data_payload(builder) else { + return false; + }; let len = self.state.streams.len(); if len == 0 { - return None; + return false; } let start = self.state.next_stream_index % len; @@ -679,45 +561,51 @@ impl SessionFsm { continue; } - let credit_remaining = stream - .peer_max_offset - .saturating_sub(stream.next_send_offset) - as usize; - let has_empty_fin = matches!(stream.outbound_state, OutboundState::FinQueued) - && stream.send_buf.is_empty() - && stream.next_send_offset <= stream.peer_max_offset; - if stream.send_buf.is_empty() && !has_empty_fin { - continue; - } - - if credit_remaining == 0 && !has_empty_fin { + let Some(candidate) = stream.tx.next_range(max_payload, stream.peer_max_offset) else { continue; + }; + { + let frame = StreamData { + stream_id, + offset: candidate.offset, + fin: candidate.fin, + bytes: stream.tx.ranged_bytes(candidate), + }; + if !builder.push_stream_data(&frame) { + continue; + } } let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); - let payload_len = stream.send_buf.len().min(max_payload).min(credit_remaining); - let bytes: Vec = stream.send_buf.drain(..payload_len).collect(); - let fin = matches!(stream.outbound_state, OutboundState::FinQueued) - && stream.send_buf.is_empty() - && stream.next_send_offset + bytes.len() as u64 <= stream.peer_max_offset; - let frame = StreamData { - stream_id, - offset: stream.next_send_offset, - fin, - bytes, - }; - stream.next_send_offset = stream - .next_send_offset - .saturating_add(frame.bytes.len() as u64); - stream.inflight_bytes = stream.inflight_bytes.saturating_add(frame.bytes.len()); - if fin { + stream.tx.mark_in_flight(candidate); + if candidate.fin { stream.outbound_state = OutboundState::Finished; } self.state.next_stream_index = (index + 1) % len; - return Some(frame); - } + outbound + .reliable + .push(ReliableFrame::StreamData(StreamDataManifest { + stream_id, + offset: candidate.offset, + len: candidate.len, + fin: candidate.fin, + })); + return true; + } + + false + } - None + fn max_stream_data_payload(&self, builder: &SessionRecordBuilder) -> Option { + let overhead = 1 + std::mem::size_of::() + StreamData::>::MIN_WIRE_SIZE; + let remaining = builder.remaining_capacity(); + if remaining > overhead { + Some(remaining - overhead) + } else if builder.is_empty() { + Some(self.config.record_size) + } else { + None + } } fn ensure_session_open(&self) -> Result<(), StreamError> { @@ -843,7 +731,11 @@ impl SessionFsm { let stream_id = frame.stream_id; if let Some(stream) = self.state.streams.get_mut(&stream_id) { let was_full = stream.send_capacity(self.config.stream_send_buffer_size) == 0; - stream.inflight_bytes = stream.inflight_bytes.saturating_sub(frame.bytes.len()); + stream.tx.mark_acked(StreamTxRange { + offset: frame.offset, + len: frame.len, + fin: frame.fin, + }); if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { emit(SessionEvent::Writable(stream_id)); } @@ -855,7 +747,7 @@ impl SessionFsm { fn handle_stream_data( &mut self, - frame: StreamData, + frame: StreamData<&[u8]>, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { let stream_id = frame.stream_id; @@ -874,7 +766,6 @@ impl SessionFsm { emit(SessionEvent::Opened(stream_id)); entry.insert(StreamState::new( StreamRole::Responder, - self.config.stream_send_buffer_size, self.config.stream_receive_buffer_size, )) } @@ -884,7 +775,7 @@ impl SessionFsm { InboundState::Open => {} InboundState::Discarding => return Ok(()), InboundState::Finished | InboundState::Closed(_) => { - if frame.offset + frame.bytes.len() as u64 <= stream.recv.start_offset() { + if frame.offset + frame.bytes.len() as u64 <= stream.rx.start_offset() { return Ok(()); } self.fail_session( @@ -898,7 +789,7 @@ impl SessionFsm { } let was_readable = stream.readable_bytes() > 0; - let insert = stream.recv.insert(frame.offset, frame.fin, &frame.bytes); + let insert = stream.rx.insert(frame.offset, frame.fin, frame.bytes); match insert { Ok(outcome) => { if !was_readable && outcome.newly_readable_bytes > 0 { @@ -911,13 +802,13 @@ impl SessionFsm { self.try_reap_stream(stream_id); Ok(()) } - Err(StreamAssemblerError::ConflictingOverlap) - | Err(StreamAssemblerError::OutOfWindow) - | Err(StreamAssemblerError::InconsistentFinalOffset) - | Err(StreamAssemblerError::FinalOffsetBeforeBufferedData) - | Err(StreamAssemblerError::BeyondFinalOffset) - | Err(StreamAssemblerError::TooManyMissingRanges) - | Err(StreamAssemblerError::OffsetOverflow) => { + Err(StreamRxError::ConflictingOverlap) + | Err(StreamRxError::OutOfWindow) + | Err(StreamRxError::InconsistentFinalOffset) + | Err(StreamRxError::FinalOffsetBeforeBufferedData) + | Err(StreamRxError::BeyondFinalOffset) + | Err(StreamRxError::TooManyMissingRanges) + | Err(StreamRxError::OffsetOverflow) => { self.fail_session( SessionCloseBody { code: CloseCode::PROTOCOL, @@ -926,7 +817,7 @@ impl SessionFsm { ); Err(()) } - Err(StreamAssemblerError::ConsumeBeyondReadable) => unreachable!(), + Err(StreamRxError::ConsumeBeyondReadable) => unreachable!(), } } @@ -957,7 +848,7 @@ impl SessionFsm { fn handle_stream_close( &mut self, - frame: StreamClose, + frame: StreamCloseVec, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { let created = match self.state.streams.entry(frame.stream_id) { @@ -974,7 +865,6 @@ impl SessionFsm { } entry.insert(StreamState::new( StreamRole::Responder, - self.config.stream_send_buffer_size, self.config.stream_receive_buffer_size, )); true @@ -1000,10 +890,8 @@ impl SessionFsm { && !matches!(stream.outbound_state, OutboundState::Closed) { stream.outbound_state = OutboundState::Closed; - stream.send_buf.clear(); - stream.retransmit.clear(); + stream.tx.clear(); stream.pending_close = None; - stream.inflight_bytes = 0; emit(SessionEvent::WritableClosed(frame.stream_id)); } self.try_reap_stream(frame.stream_id); @@ -1033,8 +921,7 @@ impl SessionFsm { } if Self::target_affects_outbound(stream.role, target) { stream.outbound_state = OutboundState::Closed; - stream.send_buf.clear(); - stream.retransmit.clear(); + stream.tx.clear(); } } @@ -1046,92 +933,28 @@ impl SessionFsm { matches!(target, CloseTarget::Both) || role.outbound_target() == target } - fn restore_stream_close(&mut self, close: StreamClose) { + fn restore_stream_close(&mut self, close: StreamCloseVec) { if let Some(stream) = self.state.streams.get_mut(&close.stream_id) { stream.pending_close = Some(close); } } - fn restore_stream_data(&mut self, frame: StreamData) { + fn restore_stream_data(&mut self, frame: StreamDataManifest) { if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { if matches!(stream.outbound_state, OutboundState::Closed) { - stream.inflight_bytes = stream.inflight_bytes.saturating_sub(frame.bytes.len()); return; } - stream.retransmit.push_front(frame); - } - } - - fn split_stream_data( - frame: StreamData, - max_payload: usize, - ) -> (StreamData, Option) { - if frame.bytes.len() <= max_payload || frame.bytes.is_empty() { - return (frame, None); - } - - let split = max_payload.max(1).min(frame.bytes.len()); - let mut head = frame.clone(); - head.bytes.truncate(split); - head.fin = false; - - let tail = StreamData { - stream_id: frame.stream_id, - offset: frame.offset + split as u64, - fin: frame.fin, - bytes: frame.bytes[split..].to_vec(), - }; - (head, Some(tail)) - } - - fn max_stream_data_payload(&self, remaining: usize, record_empty: bool) -> Option { - let overhead = self.frame_len(&SessionFrame::StreamData(StreamData { - stream_id: StreamId(0), - offset: 0, - fin: false, - bytes: Vec::new(), - })); - if remaining > overhead { - Some(remaining - overhead) - } else if record_empty { - Some(self.config.record_size) - } else { - None - } - } - - fn frame_fits(&self, remaining: usize, record_empty: bool, frame: &SessionFrame) -> bool { - let len = self.frame_len(frame); - len <= remaining || record_empty - } - - fn push_frame( - &self, - record: &mut SessionRecord, - remaining: &mut usize, - frame: SessionFrame, - force_if_empty: bool, - ) -> bool { - let len = self.frame_len(&frame); - if len > *remaining && !(force_if_empty && record.frames.is_empty()) { - return false; + stream.tx.mark_lost(StreamTxRange { + offset: frame.offset, + len: frame.len, + fin: frame.fin, + }); + if frame.fin { + if matches!(stream.outbound_state, OutboundState::Finished) { + stream.outbound_state = OutboundState::FinQueued; + } + } } - record.frames.push(frame); - *remaining = remaining.saturating_sub(len); - true - } - - fn frame_len(&self, frame: &SessionFrame) -> usize { - let mut bytes = Vec::new(); - frame.encode_into(&mut bytes); - bytes.len() - } - - fn record_is_ack_eliciting(record: &SessionRecord) -> bool { - record - .frames - .iter() - .any(|frame| !matches!(frame, SessionFrame::Ack(_))) } fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { @@ -1147,12 +970,10 @@ impl SessionFsm { return false; } - if !stream.send_buf.is_empty() - || !stream.retransmit.is_empty() + if !stream.tx.is_empty() || stream.pending_close.is_some() - || stream.inflight_bytes > 0 || stream.readable_bytes() > 0 - || stream.recv.buffered_end_offset() > stream.recv.start_offset() + || stream.rx.buffered_end_offset() > stream.rx.start_offset() { return false; } @@ -1211,8 +1032,3 @@ impl SessionFsm { self.state.streams.clear(); } } - -struct BuiltRecord { - record: SessionRecord, - outbound: OutboundRecord, -} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 21e0892b..c41af21d 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -1,15 +1,14 @@ -use std::{ - collections::{BTreeSet, VecDeque}, - time::Instant, -}; +use std::{collections::BTreeSet, time::Instant}; use indexmap::IndexMap; use ql_wire::{ - CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionCloseBody, StreamClose, StreamData, - StreamId, XID, + CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionCloseBody, StreamCloseVec, StreamId, + XID, }; -use super::{reassembly::StreamAssembler, SessionState, SESSION_RECORD_TRACKED_WINDOW}; +use super::{ + stream_rx::StreamRx, stream_tx::StreamTx, SessionState, SESSION_RECORD_TRACKED_WINDOW, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamParity { @@ -88,59 +87,33 @@ pub enum OutboundState { pub enum InboundState { Open, Finished, - Closed(StreamClose), + Closed(StreamCloseVec), Discarding, } #[derive(Debug)] pub struct StreamState { pub role: StreamRole, - pub send_buf: VecDeque, - // TODO: this is a stopgap shape and should be replaced. - // - // right now we keep sent-but-not-yet-acked stream bytes here as `ql_wire::StreamData` - // segments so the session scheduler can re-pack them into later records after loss. - // that works mechanically, but it is the wrong abstraction boundary: the fsm is caching - // wire-shaped frames instead of owning transport-neutral outbound stream state. - // - // the cleaner model is: - // - keep one authoritative outbound byte buffer per stream - // - track offsets/cursors into that buffer: - // - oldest buffered offset - // - next unsent offset - // - final offset, if known - // - keep lightweight sent-range/manifests that reference byte ranges in that buffer - // instead of cloning/storing `StreamData` - // - build `ql_wire::StreamData` only at pack time - // - free buffered prefix bytes only once no in-flight record manifest still references them - // - // that would let `send_buf` remain the source of truth while record manifests explain - // which accepted byte ranges were carried by which record. - pub retransmit: VecDeque, - pub pending_close: Option, - pub inflight_bytes: usize, - pub next_send_offset: u64, + pub rx: StreamRx, + pub tx: StreamTx, + pub pending_close: Option, pub peer_max_offset: u64, pub outbound_state: OutboundState, pub inbound_state: InboundState, - pub recv: StreamAssembler, pub advertised_max_offset: u64, pub pending_window: bool, } impl StreamState { - pub fn new(role: StreamRole, _send_buffer_size: usize, receive_buffer_size: usize) -> Self { + pub fn new(role: StreamRole, receive_buffer_size: usize) -> Self { Self { role, - send_buf: VecDeque::new(), - retransmit: VecDeque::new(), + tx: StreamTx::new(), pending_close: None, - inflight_bytes: 0, - next_send_offset: 0, peer_max_offset: receive_buffer_size as u64, outbound_state: OutboundState::Open, inbound_state: InboundState::Open, - recv: StreamAssembler::new(receive_buffer_size), + rx: StreamRx::new(receive_buffer_size), advertised_max_offset: receive_buffer_size as u64, pending_window: false, } @@ -151,7 +124,7 @@ impl StreamState { } pub fn buffered_send_bytes(&self) -> usize { - self.send_buf.len().saturating_add(self.inflight_bytes) + self.tx.buffered_len() } pub fn send_capacity(&self, send_buffer_size: usize) -> usize { @@ -159,28 +132,35 @@ impl StreamState { } pub fn readable_bytes(&self) -> usize { - self.recv.readable_len() + self.rx.readable_len() } pub fn recv_limit(&self) -> u64 { - self.recv + self.rx .start_offset() - .saturating_add(self.recv.max_buffered() as u64) + .saturating_add(self.rx.max_buffered() as u64) } pub fn reset_recv(&mut self) { - self.recv = - StreamAssembler::with_start_offset(self.recv.start_offset(), self.recv.max_buffered()); + self.rx = StreamRx::with_start_offset(self.rx.start_offset(), self.rx.max_buffered()); } } #[derive(Debug, Clone)] pub enum ReliableFrame { - StreamData(StreamData), - StreamClose(StreamClose), + StreamData(StreamDataManifest), + StreamClose(StreamCloseVec), Close(SessionCloseBody), } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamDataManifest { + pub stream_id: StreamId, + pub offset: u64, + pub len: usize, + pub fin: bool, +} + #[derive(Debug, Clone)] pub struct OutboundRecord { pub seq: RecordSeq, diff --git a/ql-fsm/src/session/reassembly.rs b/ql-fsm/src/session/stream_rx.rs similarity index 77% rename from ql-fsm/src/session/reassembly.rs rename to ql-fsm/src/session/stream_rx.rs index c3120e72..34c3c797 100644 --- a/ql-fsm/src/session/reassembly.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; /// reassembles one stream direction from out-of-order byte ranges. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamAssembler { +pub struct StreamRx { start_offset: u64, bytes: VecDeque, missing: MissingRanges, @@ -23,7 +23,7 @@ pub struct InsertOutcome { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamAssemblerError { +pub enum StreamRxError { OffsetOverflow, OutOfWindow, InconsistentFinalOffset, @@ -40,7 +40,7 @@ pub struct StreamReadIter<'a> { back: Option<&'a [u8]>, } -impl StreamAssembler { +impl StreamRx { pub fn new(max_buffered: usize) -> Self { Self::with_start_offset(0, max_buffered) } @@ -63,20 +63,10 @@ impl StreamAssembler { self.start_offset + self.bytes.len() as u64 } - #[cfg(test)] - pub fn final_offset(&self) -> Option { - self.final_offset - } - pub fn max_buffered(&self) -> usize { self.max_buffered } - #[cfg(test)] - pub fn missing_ranges(&self) -> &[MissingRange] { - self.missing.as_slice() - } - pub fn readable_len(&self) -> usize { if self.bytes.is_empty() { return 0; @@ -133,10 +123,10 @@ impl StreamAssembler { offset: u64, fin: bool, bytes: &[u8], - ) -> Result { + ) -> Result { let end = offset .checked_add(bytes.len() as u64) - .ok_or(StreamAssemblerError::OffsetOverflow)?; + .ok_or(StreamRxError::OffsetOverflow)?; let was_complete = self.is_complete(); let old_readable = self.readable_len(); @@ -146,7 +136,7 @@ impl StreamAssembler { } if let Some(final_offset) = self.final_offset { if end > final_offset { - return Err(StreamAssemblerError::BeyondFinalOffset); + return Err(StreamRxError::BeyondFinalOffset); } } @@ -171,10 +161,10 @@ impl StreamAssembler { Ok(self.insert_outcome(was_complete, old_readable)) } - pub fn consume(&mut self, len: usize) -> Result<(), StreamAssemblerError> { + pub fn consume(&mut self, len: usize) -> Result<(), StreamRxError> { let readable = self.readable_len(); if len > readable { - return Err(StreamAssemblerError::ConsumeBeyondReadable); + return Err(StreamRxError::ConsumeBeyondReadable); } self.bytes.drain(..len); @@ -189,36 +179,33 @@ impl StreamAssembler { } } - fn set_or_validate_final_offset( - &mut self, - final_offset: u64, - ) -> Result<(), StreamAssemblerError> { + fn set_or_validate_final_offset(&mut self, final_offset: u64) -> Result<(), StreamRxError> { if let Some(existing) = self.final_offset { return if existing == final_offset { Ok(()) } else { - Err(StreamAssemblerError::InconsistentFinalOffset) + Err(StreamRxError::InconsistentFinalOffset) }; } let buffered_end = self.buffered_end_offset(); if final_offset < buffered_end { - return Err(StreamAssemblerError::FinalOffsetBeforeBufferedData); + return Err(StreamRxError::FinalOffsetBeforeBufferedData); } self.final_offset = Some(final_offset); Ok(()) } - fn ensure_within_window(&self, end: u64) -> Result<(), StreamAssemblerError> { + fn ensure_within_window(&self, end: u64) -> Result<(), StreamRxError> { let attempted = end.saturating_sub(self.start_offset); if attempted > self.max_buffered as u64 { - return Err(StreamAssemblerError::OutOfWindow); + return Err(StreamRxError::OutOfWindow); } Ok(()) } - fn ensure_buffered(&mut self, end: u64) -> Result<(), StreamAssemblerError> { + fn ensure_buffered(&mut self, end: u64) -> Result<(), StreamRxError> { let buffered_end = self.buffered_end_offset(); if end <= buffered_end { return Ok(()); @@ -232,7 +219,7 @@ impl StreamAssembler { }) } - fn push_missing_range(&mut self, range: MissingRange) -> Result<(), StreamAssemblerError> { + fn push_missing_range(&mut self, range: MissingRange) -> Result<(), StreamRxError> { if range.start >= range.end { return Ok(()); } @@ -247,7 +234,7 @@ impl StreamAssembler { self.missing.push(range) } - fn validate_overlap(&self, offset: u64, bytes: &[u8]) -> Result<(), StreamAssemblerError> { + fn validate_overlap(&self, offset: u64, bytes: &[u8]) -> Result<(), StreamRxError> { let mut gap_index = self.first_gap_index_after(offset); for (index, byte) in bytes.iter().copied().enumerate() { @@ -265,7 +252,7 @@ impl StreamAssembler { } if self.byte_at(absolute) != byte { - return Err(StreamAssemblerError::ConflictingOverlap); + return Err(StreamRxError::ConflictingOverlap); } } @@ -280,7 +267,7 @@ impl StreamAssembler { } } - fn subtract_missing_range(&mut self, start: u64, end: u64) -> Result<(), StreamAssemblerError> { + fn subtract_missing_range(&mut self, start: u64, end: u64) -> Result<(), StreamRxError> { let first = self.first_gap_index_after(start); if first == self.missing.len() || self.missing[first].start >= end { return Ok(()); @@ -413,18 +400,18 @@ impl MissingRanges { } } - fn push(&mut self, range: MissingRange) -> Result<(), StreamAssemblerError> { + fn push(&mut self, range: MissingRange) -> Result<(), StreamRxError> { if self.len == N { - return Err(StreamAssemblerError::TooManyMissingRanges); + return Err(StreamRxError::TooManyMissingRanges); } self.ranges[self.len] = range; self.len += 1; Ok(()) } - fn insert(&mut self, index: usize, range: MissingRange) -> Result<(), StreamAssemblerError> { + fn insert(&mut self, index: usize, range: MissingRange) -> Result<(), StreamRxError> { if self.len == N { - return Err(StreamAssemblerError::TooManyMissingRanges); + return Err(StreamRxError::TooManyMissingRanges); } for i in (index..self.len).rev() { self.ranges[i + 1] = self.ranges[i]; @@ -476,13 +463,13 @@ impl std::ops::IndexMut for MissingRanges { #[cfg(test)] mod tests { - use super::{InsertOutcome, MissingRange, StreamAssembler, StreamAssemblerError}; + use super::{InsertOutcome, MissingRange, StreamRx, StreamRxError}; #[test] fn contiguous_insert_becomes_readable_and_complete() { - let mut assembler = StreamAssembler::<8>::new(64); + let mut rx = StreamRx::<8>::new(64); - let outcome = assembler.insert(0, true, b"hello").unwrap(); + let outcome = rx.insert(0, true, b"hello").unwrap(); assert_eq!( outcome, @@ -491,18 +478,18 @@ mod tests { became_complete: true, } ); - assert_eq!(assembler.readable_len(), 5); - assert_eq!(assembler.copy_readable(), b"hello"); - assert_eq!(assembler.final_offset(), Some(5)); - assert!(assembler.is_complete()); - assert!(assembler.missing_ranges().is_empty()); + assert_eq!(rx.readable_len(), 5); + assert_eq!(rx.copy_readable(), b"hello"); + assert_eq!(rx.final_offset, Some(5)); + assert!(rx.is_complete()); + assert!(rx.missing.is_empty()); } #[test] fn out_of_order_insert_tracks_missing_ranges_until_gap_is_filled() { - let mut assembler = StreamAssembler::<8>::new(64); + let mut rx = StreamRx::<8>::new(64); - let first = assembler.insert(5, true, b" world").unwrap(); + let first = rx.insert(5, true, b" world").unwrap(); assert_eq!( first, InsertOutcome { @@ -510,13 +497,10 @@ mod tests { became_complete: false, } ); - assert_eq!( - assembler.missing_ranges(), - &[MissingRange { start: 0, end: 5 }] - ); - assert_eq!(assembler.readable_len(), 0); + assert_eq!(rx.missing.as_slice(), &[MissingRange { start: 0, end: 5 }]); + assert_eq!(rx.readable_len(), 0); - let second = assembler.insert(0, false, b"hello").unwrap(); + let second = rx.insert(0, false, b"hello").unwrap(); assert_eq!( second, InsertOutcome { @@ -524,17 +508,17 @@ mod tests { became_complete: true, } ); - assert_eq!(assembler.copy_readable(), b"hello world"); - assert!(assembler.missing_ranges().is_empty()); - assert!(assembler.is_complete()); + assert_eq!(rx.copy_readable(), b"hello world"); + assert!(rx.missing.is_empty()); + assert!(rx.is_complete()); } #[test] fn duplicate_insert_is_ignored_if_bytes_match() { - let mut assembler = StreamAssembler::<8>::new(64); + let mut rx = StreamRx::<8>::new(64); - assembler.insert(0, false, b"hello").unwrap(); - let duplicate = assembler.insert(0, false, b"hello").unwrap(); + rx.insert(0, false, b"hello").unwrap(); + let duplicate = rx.insert(0, false, b"hello").unwrap(); assert_eq!( duplicate, @@ -543,29 +527,29 @@ mod tests { became_complete: false, } ); - assert_eq!(assembler.copy_readable(), b"hello"); + assert_eq!(rx.copy_readable(), b"hello"); } #[test] fn conflicting_overlap_is_rejected() { - let mut assembler = StreamAssembler::<8>::new(64); + let mut rx = StreamRx::<8>::new(64); - assembler.insert(0, false, b"abcdef").unwrap(); - let error = assembler.insert(3, false, b"xyz").unwrap_err(); + rx.insert(0, false, b"abcdef").unwrap(); + let error = rx.insert(3, false, b"xyz").unwrap_err(); - assert_eq!(error, StreamAssemblerError::ConflictingOverlap); + assert_eq!(error, StreamRxError::ConflictingOverlap); } #[test] fn consume_advances_start_offset_and_trims_old_prefix() { - let mut assembler = StreamAssembler::<8>::new(64); + let mut rx = StreamRx::<8>::new(64); - assembler.insert(0, false, b"abcd").unwrap(); - assembler.consume(2).unwrap(); - assert_eq!(assembler.start_offset(), 2); - assert_eq!(assembler.copy_readable(), b"cd"); + rx.insert(0, false, b"abcd").unwrap(); + rx.consume(2).unwrap(); + assert_eq!(rx.start_offset(), 2); + assert_eq!(rx.copy_readable(), b"cd"); - let outcome = assembler.insert(1, true, b"bcde").unwrap(); + let outcome = rx.insert(1, true, b"bcde").unwrap(); assert_eq!( outcome, InsertOutcome { @@ -573,39 +557,39 @@ mod tests { became_complete: true, } ); - assert_eq!(assembler.copy_readable(), b"cde"); - assert_eq!(assembler.final_offset(), Some(5)); - assert!(assembler.is_complete()); + assert_eq!(rx.copy_readable(), b"cde"); + assert_eq!(rx.final_offset, Some(5)); + assert!(rx.is_complete()); } #[test] fn insert_rejects_when_missing_range_budget_is_exhausted() { - let mut assembler = StreamAssembler::<2>::new(64); + let mut rx = StreamRx::<2>::new(64); - assembler.insert(1, false, b"a").unwrap(); - assembler.insert(3, false, b"b").unwrap(); - let error = assembler.insert(5, false, b"c").unwrap_err(); + rx.insert(1, false, b"a").unwrap(); + rx.insert(3, false, b"b").unwrap(); + let error = rx.insert(5, false, b"c").unwrap_err(); - assert_eq!(error, StreamAssemblerError::TooManyMissingRanges); + assert_eq!(error, StreamRxError::TooManyMissingRanges); } #[test] fn insert_can_fill_multiple_gaps_without_rebuilding_state() { - let mut assembler = StreamAssembler::<8>::new(64); + let mut rx = StreamRx::<8>::new(64); - assembler.insert(0, false, b"ab").unwrap(); - assembler.insert(4, false, b"ef").unwrap(); - assembler.insert(8, true, b"ij").unwrap(); + rx.insert(0, false, b"ab").unwrap(); + rx.insert(4, false, b"ef").unwrap(); + rx.insert(8, true, b"ij").unwrap(); assert_eq!( - assembler.missing_ranges(), + rx.missing.as_slice(), &[ MissingRange { start: 2, end: 4 }, MissingRange { start: 6, end: 8 }, ] ); - let outcome = assembler.insert(2, false, b"cdefgh").unwrap(); + let outcome = rx.insert(2, false, b"cdefgh").unwrap(); assert_eq!( outcome, @@ -614,8 +598,8 @@ mod tests { became_complete: true, } ); - assert!(assembler.missing_ranges().is_empty()); - assert_eq!(assembler.copy_readable(), b"abcdefghij"); - assert!(assembler.is_complete()); + assert!(rx.missing.is_empty()); + assert_eq!(rx.copy_readable(), b"abcdefghij"); + assert!(rx.is_complete()); } } diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs new file mode 100644 index 00000000..7c1a11d4 --- /dev/null +++ b/ql-fsm/src/session/stream_tx.rs @@ -0,0 +1,337 @@ +use std::collections::VecDeque; + +use ql_wire::RangedByteChunks; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SendState { + Unsent, + InFlight, + Lost, + Acked, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct SendSegment { + offset: u64, + len: usize, + state: SendState, +} + +impl SendSegment { + fn end_offset(&self) -> u64 { + self.offset + self.len as u64 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct TrackedFinalOffset { + offset: u64, + state: SendState, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamTxRange { + pub offset: u64, + pub len: usize, + pub fin: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamTx { + bytes: VecDeque, + base_offset: u64, + segments: VecDeque, + final_offset: Option, +} + +impl StreamTx { + pub fn new() -> Self { + Self { + bytes: VecDeque::new(), + base_offset: 0, + segments: VecDeque::new(), + final_offset: None, + } + } + + pub fn buffered_len(&self) -> usize { + self.bytes.len() + } + + pub fn end_offset(&self) -> u64 { + self.base_offset + self.bytes.len() as u64 + } + + pub fn has_pending(&self) -> bool { + self.segments + .iter() + .any(|segment| matches!(segment.state, SendState::Unsent | SendState::Lost)) + || self + .final_offset + .is_some_and(|final_offset| matches!(final_offset.state, SendState::Unsent | SendState::Lost)) + } + + pub fn is_empty(&self) -> bool { + self.bytes.is_empty() && self.segments.is_empty() && self.final_offset.is_none() + } + + pub fn append(&mut self, bytes: &[u8]) { + if bytes.is_empty() { + return; + } + + let start = self.end_offset(); + self.bytes.extend(bytes.iter().copied()); + if let Some(last) = self.segments.back_mut() { + if last.state == SendState::Unsent && last.end_offset() == start { + last.len += bytes.len(); + return; + } + } + + self.segments.push_back(SendSegment { + offset: start, + len: bytes.len(), + state: SendState::Unsent, + }); + } + + pub fn queue_fin(&mut self) { + self.final_offset = Some(TrackedFinalOffset { + offset: self.end_offset(), + state: SendState::Unsent, + }); + } + + pub fn next_range(&self, max_payload: usize, peer_max_offset: u64) -> Option { + let mut unsent = None; + for segment in &self.segments { + if !matches!(segment.state, SendState::Lost | SendState::Unsent) { + continue; + } + + let credit_remaining = peer_max_offset.saturating_sub(segment.offset); + let credit_remaining = credit_remaining.min(usize::MAX as u64) as usize; + let len = segment.len.min(max_payload).min(credit_remaining); + if len == 0 { + continue; + } + + let fin = self.final_offset.is_some_and(|final_offset| { + matches!(final_offset.state, SendState::Lost | SendState::Unsent) + && final_offset.offset == segment.offset + len as u64 + }); + let range = StreamTxRange { + offset: segment.offset, + len, + fin, + }; + + if segment.state == SendState::Lost { + return Some(range); + } + unsent = Some(range); + } + + if let Some(range) = unsent { + return Some(range); + } + + let final_offset = self.final_offset?; + if !matches!(final_offset.state, SendState::Lost | SendState::Unsent) { + return None; + } + if final_offset.offset > peer_max_offset { + return None; + } + Some(StreamTxRange { + offset: final_offset.offset, + len: 0, + fin: true, + }) + } + + pub fn ranged_bytes(&self, range: StreamTxRange) -> RangedByteChunks<&VecDeque> { + let offset = usize::try_from(range.offset - self.base_offset).unwrap(); + RangedByteChunks { + inner: &self.bytes, + offset, + len: range.len, + } + } + + pub fn mark_in_flight(&mut self, range: StreamTxRange) { + self.set_segment_state(range.offset, range.len, SendState::InFlight); + if range.fin { + if let Some(final_offset) = self.final_offset.as_mut() { + final_offset.state = SendState::InFlight; + } + } + } + + pub fn mark_lost(&mut self, range: StreamTxRange) { + self.set_segment_state(range.offset, range.len, SendState::Lost); + if range.fin { + if let Some(final_offset) = self.final_offset.as_mut() { + final_offset.state = SendState::Lost; + } + } + } + + pub fn mark_acked(&mut self, range: StreamTxRange) { + self.set_segment_state(range.offset, range.len, SendState::Acked); + if range.fin { + if let Some(final_offset) = self.final_offset.as_mut() { + final_offset.state = SendState::Acked; + } + } + self.trim_acked_prefix(); + } + + pub fn clear(&mut self) { + self.bytes.clear(); + self.segments.clear(); + self.final_offset = None; + } + + fn set_segment_state(&mut self, offset: u64, len: usize, state: SendState) { + if len == 0 { + return; + } + + let Some(index) = self + .segments + .iter() + .position(|segment| segment.offset == offset && segment.len >= len) + else { + return; + }; + + if self.segments[index].len == len { + self.segments[index].state = state; + } else { + let segment = self.segments.remove(index).unwrap(); + self.segments.insert(index, SendSegment { offset, len, state }); + self.segments.insert( + index + 1, + SendSegment { + offset: offset + len as u64, + len: segment.len - len, + state: segment.state, + }, + ); + } + + self.merge_adjacent_segments(); + } + + fn merge_adjacent_segments(&mut self) { + let mut index = 1; + while index < self.segments.len() { + let prev = self.segments[index - 1]; + let next = self.segments[index]; + if prev.state == next.state && prev.end_offset() == next.offset { + self.segments[index - 1].len += next.len; + self.segments.remove(index); + } else { + index += 1; + } + } + } + + fn trim_acked_prefix(&mut self) { + while matches!( + self.segments.front(), + Some(segment) if segment.state == SendState::Acked + ) { + let len = self.segments.pop_front().unwrap().len; + self.bytes.drain(..len); + self.base_offset = self.base_offset.saturating_add(len as u64); + } + + if self.final_offset.is_some_and(|final_offset| { + final_offset.state == SendState::Acked && final_offset.offset == self.base_offset + }) { + self.final_offset = None; + } + } +} + +#[cfg(test)] +mod tests { + use super::{StreamTx, StreamTxRange}; + + #[test] + fn append_tracks_unsent_tail() { + let mut tx = StreamTx::new(); + tx.append(b"abc"); + tx.append(b"de"); + + assert_eq!( + tx.next_range(8, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 5, + fin: false, + }) + ); + } + + #[test] + fn lost_range_is_selected_before_unsent_tail() { + let mut tx = StreamTx::new(); + tx.append(b"abcdef"); + + let first = tx.next_range(3, u64::MAX).unwrap(); + tx.mark_in_flight(first); + tx.mark_lost(first); + + assert_eq!( + tx.next_range(3, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 3, + fin: false, + }) + ); + } + + #[test] + fn acked_prefix_is_trimmed() { + let mut tx = StreamTx::new(); + tx.append(b"abcdef"); + + let first = tx.next_range(3, u64::MAX).unwrap(); + tx.mark_in_flight(first); + tx.mark_acked(first); + + assert_eq!( + tx.next_range(3, u64::MAX), + Some(StreamTxRange { + offset: 3, + len: 3, + fin: false, + }) + ); + } + + #[test] + fn empty_fin_is_tracked_separately() { + let mut tx = StreamTx::new(); + tx.queue_fin(); + + let range = tx.next_range(16, u64::MAX).unwrap(); + assert_eq!( + range, + StreamTxRange { + offset: 0, + len: 0, + fin: true, + } + ); + + tx.mark_in_flight(range); + tx.mark_acked(range); + assert!(tx.is_empty()); + } +} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 4b66c0ed..3e98a102 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -2,7 +2,7 @@ use std::time::{Duration, Instant}; use ql_wire::{ CloseCode, CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionFrame, SessionRecord, - StreamClose, StreamData, StreamId, XID, + StreamCloseVec, StreamData, StreamId, XID, }; use super::{state::StreamParity, SessionEvent, SessionFsm, SessionFsmConfig}; @@ -24,14 +24,16 @@ fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { } fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option { - let (write_id, record) = fsm.take_next_write(now)?; + let (write_id, builder) = fsm.take_next_write(now)?; fsm.confirm_write(now, write_id); - Some(record) + Some(SessionRecord::decode(builder.bytes()).unwrap()) } fn receive_events(fsm: &mut SessionFsm, now: Instant, record: SessionRecord) -> Vec { + let bytes = record.encode(); + let (seq, frames) = SessionRecord::parse(&bytes).unwrap(); let mut events = Vec::new(); - fsm.receive(now, record, |event| events.push(event)); + fsm.receive(now, seq, frames, |event| events.push(event)); events } @@ -197,11 +199,12 @@ fn remote_stream_close_is_reliable_and_retried() { ) .unwrap(); - let (write_id, first) = fsm.take_next_write(now).unwrap(); + let (write_id, builder) = fsm.take_next_write(now).unwrap(); fsm.confirm_write(now, write_id); + let first = SessionRecord::decode(builder.bytes()).unwrap(); assert!(matches!( first.frames.as_slice(), - [SessionFrame::StreamClose(StreamClose { stream_id: id, .. })] if *id == stream_id + [SessionFrame::StreamClose(StreamCloseVec { stream_id: id, .. })] if *id == stream_id )); fsm.on_timer(now + Duration::from_millis(200), |_| {}); diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index e09d821c..fcd71243 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -33,7 +33,7 @@ pub enum HandshakeResponder { pub struct RecentReady { pub hello: Hello, pub reply: HelloReply, - pub ready: Ready, + pub ready: Ready>, pub expires_at: Instant, } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index abee0d40..4c681722 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -8,8 +8,8 @@ use std::{ use libcrux_aesgcm::AesGcm256Key; use ql_wire::{ - self, generate_ml_dsa_keypair, generate_ml_kem_keypair, EncryptedMessage, QlCrypto, QlIdentity, - QlPayload, QlRecord, SessionKey, XID, + self, generate_ml_dsa_keypair, generate_ml_kem_keypair, QlCrypto, QlIdentity, QlPayload, + QlRecord, SessionKey, XID, ENCRYPTED_MESSAGE_AUTH_SIZE, }; use sha2::{Digest, Sha256}; @@ -55,10 +55,10 @@ impl QlCrypto for TestCrypto { nonce: &ql_wire::Nonce, aad: &[u8], buffer: &mut [u8], - ) -> [u8; EncryptedMessage::AUTH_SIZE] { + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); let plaintext = buffer.to_vec(); - let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; + let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; key.encrypt( buffer, (&mut auth).into(), @@ -76,7 +76,7 @@ impl QlCrypto for TestCrypto { nonce: &ql_wire::Nonce, aad: &[u8], buffer: &mut [u8], - auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], ) -> bool { let key: AesGcm256Key = (*key.data()).into(); let ciphertext = buffer.to_vec(); diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index adb4b642..88ed03c7 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -147,11 +147,24 @@ pub struct CappedByteChunks { pub limit: usize, } +#[derive(Debug, Clone, Copy)] +pub struct RangedByteChunks { + pub inner: T, + pub offset: usize, + pub len: usize, +} + pub struct CappedByteChunksIter { inner: I, remaining: usize, } +pub struct RangedByteChunksIter { + inner: I, + skip: usize, + remaining: usize, +} + impl<'a, I> Iterator for CappedByteChunksIter where I: Iterator, @@ -174,6 +187,35 @@ where } } +impl<'a, I> Iterator for RangedByteChunksIter +where + I: Iterator, +{ + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + while self.remaining > 0 { + let chunk = self.inner.next()?; + if self.skip >= chunk.len() { + self.skip -= chunk.len(); + continue; + } + + let chunk = &chunk[self.skip..]; + self.skip = 0; + if chunk.is_empty() { + continue; + } + + let len = chunk.len().min(self.remaining); + self.remaining -= len; + return Some(&chunk[..len]); + } + + None + } +} + impl ByteChunks for CappedByteChunks { type Chunks<'a> = CappedByteChunksIter> @@ -192,11 +234,30 @@ impl ByteChunks for CappedByteChunks { } } +impl ByteChunks for RangedByteChunks { + type Chunks<'a> + = RangedByteChunksIter> + where + Self: 'a; + + fn len(&self) -> usize { + self.inner.len().saturating_sub(self.offset).min(self.len) + } + + fn chunks(&self) -> Self::Chunks<'_> { + RangedByteChunksIter { + inner: self.inner.chunks(), + skip: self.offset, + remaining: self.len(), + } + } +} + #[cfg(test)] mod tests { use std::collections::VecDeque; - use super::{ByteChunks, ByteSlice, ByteSliceMut, CappedByteChunks}; + use super::{ByteChunks, ByteSlice, ByteSliceMut, CappedByteChunks, RangedByteChunks}; #[test] fn shared_slice_split_at() { @@ -278,4 +339,36 @@ mod tests { assert_eq!(capped.len(), 4); assert_eq!(chunks.concat(), b"cdef"); } + + #[test] + fn ranged_byte_chunks_slice_middle() { + let bytes: &[u8] = b"abcdef"; + let ranged = RangedByteChunks { + inner: bytes, + offset: 2, + len: 3, + }; + + let chunks = ranged.chunks().collect::>(); + assert_eq!(ranged.len(), 3); + assert_eq!(chunks, vec![b"cde".as_slice()]); + } + + #[test] + fn ranged_byte_chunks_borrowed_vec_deque_middle() { + let mut bytes = VecDeque::with_capacity(8); + bytes.extend(b"abcd".iter().copied()); + bytes.drain(..2); + bytes.extend(b"efgh".iter().copied()); + + let ranged = RangedByteChunks { + inner: &bytes, + offset: 1, + len: 4, + }; + + let chunks = ranged.chunks().collect::>(); + assert_eq!(ranged.len(), 4); + assert_eq!(chunks.concat(), b"defg"); + } } From c7a98b5e0742a1dafdc170477d677cd0da5fb65b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 29 Mar 2026 08:27:09 -0400 Subject: [PATCH 050/304] ql: get rid of stream close payload --- ql-fsm/src/implementation/fsm.rs | 3 +- ql-fsm/src/lib.rs | 7 ++--- ql-fsm/src/session/mod.rs | 16 ++++------- ql-fsm/src/session/state.rs | 8 +++--- ql-fsm/src/session/tests.rs | 5 ++-- ql-wire/src/encrypted/builder.rs | 2 +- ql-wire/src/encrypted/mod.rs | 5 ++-- ql-wire/src/encrypted/stream_close.rs | 41 +++++++-------------------- ql-wire/src/tests.rs | 10 ++----- 9 files changed, 31 insertions(+), 66 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 5f2146c0..750cacc3 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -218,10 +218,9 @@ pub fn close_stream( stream_id: StreamId, target: CloseTarget, code: CloseCode, - payload: Vec, ) -> Result<(), QlFsmError> { ensure_peer_bound(fsm)?; - Ok(fsm.session.close_stream(stream_id, target, code, payload)?) + Ok(fsm.session.close_stream(stream_id, target, code)?) } pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index e9980a5c..3953c153 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -31,7 +31,7 @@ use std::time::{Duration, Instant}; pub use error::QlFsmError; use ql_wire::{ CloseCode, CloseTarget, MlDsaPublicKey, MlKemPublicKey, QlCrypto, QlIdentity, QlRecord, - SessionCloseBody, StreamCloseVec, StreamId, XID, + SessionCloseBody, StreamClose, StreamId, XID, }; pub use session::stream_rx::StreamReadIter; @@ -102,7 +102,7 @@ pub enum QlSessionEvent { /// the peer finished writing this stream Finished(StreamId), /// a stream was closed - Closed(StreamCloseVec), + Closed(StreamClose), /// local writes on this stream are closed WritableClosed(StreamId), /// the peer requested unpairing @@ -330,9 +330,8 @@ impl QlFsm { stream_id: StreamId, target: CloseTarget, code: CloseCode, - payload: Vec, ) -> Result<(), QlFsmError> { - implementation::close_stream(self, stream_id, target, code, payload) + implementation::close_stream(self, stream_id, target, code) } /// queues a ping on the active session diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index bbb32fde..16571f10 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -10,8 +10,7 @@ use std::time::{Duration, Instant}; use indexmap::map::Entry; use ql_wire::{ CloseCode, CloseTarget, RecordAck, RecordSeq, SessionCloseBody, SessionFrame, - SessionRecordBuilder, StreamClose, StreamCloseVec, StreamData, StreamId, StreamWindow, - WireError, + SessionRecordBuilder, StreamClose, StreamData, StreamId, StreamWindow, WireError, }; use self::{ @@ -59,7 +58,7 @@ pub enum SessionEvent { Readable(StreamId), Writable(StreamId), Finished(StreamId), - Closed(StreamCloseVec), + Closed(StreamClose), WritableClosed(StreamId), SessionClosed(SessionCloseBody), } @@ -171,7 +170,6 @@ impl SessionFsm { stream_id: StreamId, target: CloseTarget, code: CloseCode, - payload: Vec, ) -> Result<(), StreamError> { self.ensure_session_open()?; { @@ -185,7 +183,6 @@ impl SessionFsm { stream_id, target, code, - payload, }); } self.try_reap_stream(stream_id); @@ -293,10 +290,7 @@ impl SessionFsm { } } SessionFrame::StreamClose(frame) => { - if self - .handle_stream_close(frame.into_owned(), &mut emit) - .is_err() - { + if self.handle_stream_close(frame, &mut emit).is_err() { return; } } @@ -848,7 +842,7 @@ impl SessionFsm { fn handle_stream_close( &mut self, - frame: StreamCloseVec, + frame: StreamClose, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { let created = match self.state.streams.entry(frame.stream_id) { @@ -933,7 +927,7 @@ impl SessionFsm { matches!(target, CloseTarget::Both) || role.outbound_target() == target } - fn restore_stream_close(&mut self, close: StreamCloseVec) { + fn restore_stream_close(&mut self, close: StreamClose) { if let Some(stream) = self.state.streams.get_mut(&close.stream_id) { stream.pending_close = Some(close); } diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index c41af21d..e85d23d1 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeSet, time::Instant}; use indexmap::IndexMap; use ql_wire::{ - CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionCloseBody, StreamCloseVec, StreamId, + CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionCloseBody, StreamClose, StreamId, XID, }; @@ -87,7 +87,7 @@ pub enum OutboundState { pub enum InboundState { Open, Finished, - Closed(StreamCloseVec), + Closed(StreamClose), Discarding, } @@ -96,7 +96,7 @@ pub struct StreamState { pub role: StreamRole, pub rx: StreamRx, pub tx: StreamTx, - pub pending_close: Option, + pub pending_close: Option, pub peer_max_offset: u64, pub outbound_state: OutboundState, pub inbound_state: InboundState, @@ -149,7 +149,7 @@ impl StreamState { #[derive(Debug, Clone)] pub enum ReliableFrame { StreamData(StreamDataManifest), - StreamClose(StreamCloseVec), + StreamClose(StreamClose), Close(SessionCloseBody), } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 3e98a102..b4dca534 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -2,7 +2,7 @@ use std::time::{Duration, Instant}; use ql_wire::{ CloseCode, CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionFrame, SessionRecord, - StreamCloseVec, StreamData, StreamId, XID, + StreamClose, StreamData, StreamId, XID, }; use super::{state::StreamParity, SessionEvent, SessionFsm, SessionFsmConfig}; @@ -195,7 +195,6 @@ fn remote_stream_close_is_reliable_and_retried() { stream_id, CloseTarget::Both, CloseCode::CANCELLED, - b"bye".to_vec(), ) .unwrap(); @@ -204,7 +203,7 @@ fn remote_stream_close_is_reliable_and_retried() { let first = SessionRecord::decode(builder.bytes()).unwrap(); assert!(matches!( first.frames.as_slice(), - [SessionFrame::StreamClose(StreamCloseVec { stream_id: id, .. })] if *id == stream_id + [SessionFrame::StreamClose(StreamClose { stream_id: id, .. })] if *id == stream_id )); fsm.on_timer(now + Duration::from_millis(200), |_| {}); diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index df106391..87429be3 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -91,7 +91,7 @@ impl SessionRecordBuilder { true } - pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { + pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { if !self.can_push_len(1 + SIZE_LEN + frame.encoded_len()) { return false; } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 362bcced..245fe2ae 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -40,13 +40,12 @@ pub enum SessionFrame { Ack(RecordAck), StreamData(StreamData), StreamWindow(StreamWindow), - StreamClose(StreamClose), + StreamClose(StreamClose), Close(SessionCloseBody), } pub type SessionFrameVec = SessionFrame>; pub type StreamDataVec = StreamData>; -pub type StreamCloseVec = StreamClose>; pub(crate) const SIZE_LEN: usize = size_of::(); @@ -171,7 +170,7 @@ impl SessionFrame { Self::Ack(frame) => SessionFrame::Ack(frame), Self::StreamData(frame) => SessionFrame::StreamData(frame.into_owned()), Self::StreamWindow(frame) => SessionFrame::StreamWindow(frame), - Self::StreamClose(frame) => SessionFrame::StreamClose(frame.into_owned()), + Self::StreamClose(frame) => SessionFrame::StreamClose(frame), Self::Close(frame) => SessionFrame::Close(frame), } } diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 396e5fa9..97f962af 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -1,7 +1,7 @@ use std::mem::size_of; use super::StreamId; -use crate::{codec, ByteChunks, ByteSlice, WireError}; +use crate::{codec, ByteSlice, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] @@ -49,54 +49,33 @@ impl CloseCode { /// aborts one or both directions of a stream with a close code. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamClose { +pub struct StreamClose { pub stream_id: StreamId, pub target: CloseTarget, pub code: CloseCode, - pub payload: B, } -impl StreamClose { - pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); -} +impl StreamClose { + pub const WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); -impl StreamClose { - pub fn parse(bytes: B) -> Result { + pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); - Ok(Self { + let close = Self { stream_id: StreamId(reader.take_u32()?), target: CloseTarget::try_from(reader.take_u8()?)?, code: CloseCode(reader.take_u16()?), - payload: reader.take_rest(), - }) + }; + reader.finish()?; + Ok(close) } -} -impl StreamClose { - pub fn into_owned(self) -> StreamClose> - where - B: ByteSlice, - { - StreamClose { - stream_id: self.stream_id, - target: self.target, - code: self.code, - payload: self.payload.to_vec(), - } - } -} - -impl StreamClose { pub fn encoded_len(&self) -> usize { - Self::MIN_WIRE_SIZE + self.payload.len() + Self::WIRE_SIZE } pub fn encode_into(&self, out: &mut Vec) { codec::push_u32(out, self.stream_id.0); codec::push_u8(out, self.target.to_wire()); codec::push_u16(out, self.code.0); - for chunk in self.payload.chunks() { - codec::push_bytes(out, chunk); - } } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index be7c90fc..00898645 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -96,7 +96,6 @@ fn encrypted_session_record_round_trip_and_decrypt() { stream_id: StreamId(9), target: CloseTarget::Both, code: CloseCode::PROTOCOL, - payload: b"bye".to_vec(), }), SessionFrame::Close(SessionCloseBody { code: CloseCode::TIMEOUT, @@ -152,7 +151,6 @@ fn decrypted_session_record_iterates_zero_copy_frames() { stream_id: StreamId(1), target: CloseTarget::Response, code: CloseCode::CANCELLED, - payload: b"later".to_vec(), }), ], }; @@ -190,10 +188,9 @@ fn decrypted_session_record_iterates_zero_copy_frames() { } match frames.next().unwrap().unwrap() { SessionFrame::StreamClose(frame) => { - let owned = frame.into_owned(); - assert_eq!(owned.stream_id, StreamId(1)); - assert_eq!(owned.target, CloseTarget::Response); - assert_eq!(owned.payload, b"later".to_vec()); + assert_eq!(frame.stream_id, StreamId(1)); + assert_eq!(frame.target, CloseTarget::Response); + assert_eq!(frame.code, CloseCode::CANCELLED); } other => panic!("expected stream close, got {}", frame_name(&other)), } @@ -627,7 +624,6 @@ fn protocol_record_size_breakdown() { stream_id: StreamId(1), target: CloseTarget::Both, code: CloseCode::PROTOCOL, - payload: Vec::new(), })], }, ); From 551b3d7e5bc6ee249908c98ebbb5705d8e492775 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 29 Mar 2026 08:44:37 -0400 Subject: [PATCH 051/304] ql-wire: get rid of duplicate types --- ql-fsm/src/implementation/fsm.rs | 24 +++---- ql-fsm/src/implementation/mod.rs | 2 +- ql-fsm/src/implementation/peer.rs | 5 +- ql-fsm/src/lib.rs | 4 +- ql-fsm/src/state.rs | 2 +- ql-fsm/src/tests/mod.rs | 10 +-- ql-wire/src/encrypted/builder.rs | 2 +- ql-wire/src/encrypted/mod.rs | 2 +- ql-wire/src/pair/crypto.rs | 2 +- ql-wire/src/record.rs | 105 +++++++++--------------------- ql-wire/src/tests.rs | 40 ++++++------ ql-wire/src/unpair/crypto.rs | 2 +- 12 files changed, 78 insertions(+), 122 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 750cacc3..21cd2ec2 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -1,6 +1,6 @@ use std::time::Instant; -use ql_wire::{self as wire, CloseCode, CloseTarget, Nonce, QlCrypto, QlPayloadRef, StreamId}; +use ql_wire::{self as wire, CloseCode, CloseTarget, Nonce, QlCrypto, QlPayload, StreamId}; use crate::{ OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId, StreamReadIter, @@ -11,14 +11,14 @@ pub fn receive( mut bytes: Vec, crypto: &impl QlCrypto, ) -> Result<(), QlFsmError> { - let wire::QlRecordRef { header, payload } = wire::QlRecord::parse_mut(&mut bytes)?; + let wire::QlRecord { header, payload } = wire::QlRecord::parse(&mut bytes[..])?; if header.recipient != fsm.identity.xid { return Err(QlFsmError::InvalidXid); } match &payload { - QlPayloadRef::PairRequest(_) => {} - QlPayloadRef::Unpair(_) => { + QlPayload::PairRequest(_) => {} + QlPayload::Unpair(_) => { let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.xid) else { return Ok(()); }; @@ -37,25 +37,25 @@ pub fn receive( } match payload { - QlPayloadRef::PairRequest(request) => { + QlPayload::PairRequest(request) => { super::handle_pair(fsm, crypto, &header, request)?; } - QlPayloadRef::Unpair(unpair) => { + QlPayload::Unpair(unpair) => { super::handle_unpair(fsm, crypto, &header, &unpair)?; } - QlPayloadRef::Hello(hello) => { + QlPayload::Hello(hello) => { super::handle_hello(fsm, crypto, &header, &hello)?; } - QlPayloadRef::HelloReply(reply) => { + QlPayload::HelloReply(reply) => { super::handle_hello_reply(fsm, crypto, &header, &reply)?; } - QlPayloadRef::Confirm(confirm) => { + QlPayload::Confirm(confirm) => { super::handle_confirm(fsm, crypto, &header, &confirm)?; } - QlPayloadRef::Ready(ready) => { + QlPayload::Ready(ready) => { super::handle_ready(fsm, crypto, &header, ready)?; } - QlPayloadRef::Session(encrypted) => { + QlPayload::Session(encrypted) => { let Some((_, session_key)) = super::peer_session(fsm) else { return Err(QlFsmError::NoSession); }; @@ -228,7 +228,7 @@ pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { Ok(fsm.session.queue_ping()?) } -pub fn unpair(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { +pub fn unpair(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option>> { super::handle_unpair_local(fsm, crypto) } diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index 3d8b9261..efae6dbe 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -32,7 +32,7 @@ fn next_control_meta(fsm: &mut QlFsm, lifetime: Duration) -> ControlMeta { } } -fn enqueue_handshake(fsm: &mut QlFsm, peer: XID, payload: QlPayload) { +fn enqueue_handshake(fsm: &mut QlFsm, peer: XID, payload: QlPayload>) { fsm.state.outbound.push_back(QlRecord { header: QlHeader { sender: fsm.identity.xid, diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs index b58659f2..09fcfcbc 100644 --- a/ql-fsm/src/implementation/peer.rs +++ b/ql-fsm/src/implementation/peer.rs @@ -24,7 +24,10 @@ pub fn handle_pair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), Ok(()) } -pub fn handle_unpair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { +pub fn handle_unpair_local( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, +) -> Option>> { let peer = fsm.peer.as_ref()?.peer.clone(); let meta = next_control_meta(fsm, fsm.config.control_expiration); let record = wire::build_unpair(crypto, &fsm.identity, peer.xid, meta); diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 3953c153..b64bf5fe 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -119,7 +119,7 @@ pub struct SessionWriteId(pub(crate) u64); #[derive(Debug, Clone, PartialEq)] pub struct OutboundWrite { /// record to hand to the transport - pub record: QlRecord, + pub record: QlRecord>, /// write handle that must be confirmed or rejected pub session_write_id: Option, } @@ -340,7 +340,7 @@ impl QlFsm { } /// clears the bound peer locally and returns a best-effort unpair record - pub fn unpair(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Option { + pub fn unpair(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Option>> { self.state.now = now; implementation::unpair(self, crypto) } diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index fcd71243..2c577b47 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -93,7 +93,7 @@ impl PeerRecord { pub struct QlFsmState { pub replay_cache: ReplayCache, pub next_control_id: u32, - pub outbound: VecDeque, + pub outbound: VecDeque>>, pub events: VecDeque, pub session_events: VecDeque, pub now: FsmTime, diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 4c681722..9b475cc9 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -189,7 +189,7 @@ impl Harness { self.unix_secs = self.unix_secs.saturating_add(duration.as_secs()); } - fn next_outbound_a(&mut self) -> Option { + fn next_outbound_a(&mut self) -> Option>> { let write = self.a.fsm.take_next_write(self.time(), &self.a.crypto)?; if let Some(id) = write.session_write_id { self.a.fsm.confirm_session_write(self.time(), id); @@ -197,7 +197,7 @@ impl Harness { Some(write.record) } - fn next_outbound_b(&mut self) -> Option { + fn next_outbound_b(&mut self) -> Option>> { let write = self.b.fsm.take_next_write(self.time(), &self.b.crypto)?; if let Some(id) = write.session_write_id { self.b.fsm.confirm_session_write(self.time(), id); @@ -209,14 +209,14 @@ impl Harness { self.a.fsm.take_next_write(self.time(), &self.a.crypto) } - fn deliver_to_a(&mut self, record: QlRecord) { + fn deliver_to_a(&mut self, record: QlRecord>) { self.a .fsm .receive(self.time(), record.encode(), &self.a.crypto) .unwrap(); } - fn deliver_to_b(&mut self, record: QlRecord) { + fn deliver_to_b(&mut self, record: QlRecord>) { self.b .fsm .receive(self.time(), record.encode(), &self.b.crypto) @@ -277,7 +277,7 @@ fn peer_from_identity(identity: &QlIdentity) -> Peer { fn decrypt_record( crypto: &impl QlCrypto, - record: &QlRecord, + record: &QlRecord>, session_key: &SessionKey, ) -> ql_wire::SessionRecord { let aad = record.header.aad(); diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 87429be3..949510fc 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -127,7 +127,7 @@ impl SessionRecordBuilder { header: QlHeader, session_key: &SessionKey, nonce: Nonce, - ) -> QlRecord { + ) -> QlRecord> { let aad = header.aad(); let encrypted = EncryptedMessage::encrypt(crypto, session_key, self.bytes, &aad, nonce); QlRecord { diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 245fe2ae..5b8b6252 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -204,7 +204,7 @@ pub fn encrypt_record( session_key: &SessionKey, body: &SessionRecord, nonce: crate::Nonce, -) -> QlRecord { +) -> QlRecord> { let mut builder = SessionRecordBuilder::new(body.seq, body.encoded_len()); for frame in &body.frames { let pushed = builder.push_frame(frame); diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs index b42be525..192ac2de 100644 --- a/ql-wire/src/pair/crypto.rs +++ b/ql-wire/src/pair/crypto.rs @@ -10,7 +10,7 @@ pub fn build_pair_request( recipient: XID, recipient_encapsulation_key: &MlKemPublicKey, meta: ControlMeta, -) -> QlRecord { +) -> QlRecord> { let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(crypto); let header = QlHeader { sender: identity.xid, diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 693262e6..bd61c3fc 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -5,34 +5,17 @@ use crate::{ header::{decode_record_header, encode_record_header, QlHeader}, pair::PairRequestRecord, unpair::Unpair, - WireError, QL_WIRE_VERSION, + ByteSlice, WireError, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct QlRecord { +pub struct QlRecord { pub header: QlHeader, - pub payload: QlPayload, + pub payload: QlPayload, } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlPayload { - PairRequest(PairRequestRecord>), - Unpair(Unpair), - Hello(Hello), - HelloReply(HelloReply), - Confirm(Confirm), - Ready(Ready>), - Session(EncryptedMessage>), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct QlRecordRef { - pub header: QlHeader, - pub payload: QlPayloadRef, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlPayloadRef { +pub enum QlPayload { PairRequest(PairRequestRecord), Unpair(Unpair), Hello(Hello), @@ -72,7 +55,7 @@ impl TryFrom for RecordKind { } impl RecordKind { - fn for_payload(payload: &QlPayload) -> Self { + fn for_payload(payload: &QlPayload) -> Self { match payload { QlPayload::PairRequest(_) => Self::PairRequest, QlPayload::Unpair(_) => Self::Unpair, @@ -85,7 +68,7 @@ impl RecordKind { } } -impl QlRecord { +impl> QlRecord { pub fn encode(&self) -> Vec { let mut out = Vec::new(); codec::push_u8(&mut out, QL_WIRE_VERSION); @@ -105,21 +88,15 @@ impl QlRecord { } out } +} +impl QlRecord> { pub fn decode(bytes: &[u8]) -> Result { - Ok(Self::parse(bytes)?.to_owned()) - } - - pub fn parse(bytes: &[u8]) -> Result, WireError> { - QlRecordRef::parse(bytes) - } - - pub fn parse_mut(bytes: &mut [u8]) -> Result, WireError> { - QlRecordRef::parse(bytes) + Ok(QlRecord::parse(bytes)?.into_owned()) } } -impl QlRecordRef { +impl QlRecord { pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); if reader.take_u8()? != QL_WIRE_VERSION { @@ -132,61 +109,37 @@ impl QlRecordRef { payload, }) } -} -impl> QlRecordRef { - pub fn to_owned(&self) -> QlRecord { + pub fn into_owned(self) -> QlRecord> { QlRecord { header: self.header, - payload: self.payload.to_owned(), + payload: self.payload.into_owned(), } } } -impl> QlPayloadRef { - pub fn to_owned(&self) -> QlPayload { +impl QlPayload { + pub fn into_owned(self) -> QlPayload> { match self { - Self::PairRequest(request) => QlPayload::PairRequest(PairRequestRecord { - kem_ct: request.kem_ct.clone(), - encrypted: EncryptedMessage { - nonce: request.encrypted.nonce, - auth: request.encrypted.auth, - ciphertext: request.encrypted.ciphertext.as_ref().to_vec(), - }, - }), - Self::Unpair(unpair) => QlPayload::Unpair(unpair.clone()), - Self::Hello(hello) => QlPayload::Hello(hello.clone()), - Self::HelloReply(reply) => QlPayload::HelloReply(reply.clone()), - Self::Confirm(confirm) => QlPayload::Confirm(confirm.clone()), - Self::Ready(ready) => QlPayload::Ready(Ready { - encrypted: EncryptedMessage { - nonce: ready.encrypted.nonce, - auth: ready.encrypted.auth, - ciphertext: ready.encrypted.ciphertext.as_ref().to_vec(), - }, - }), - Self::Session(encrypted) => QlPayload::Session(EncryptedMessage { - nonce: encrypted.nonce, - auth: encrypted.auth, - ciphertext: encrypted.ciphertext.as_ref().to_vec(), - }), + Self::PairRequest(request) => QlPayload::PairRequest(request.into_owned()), + Self::Unpair(unpair) => QlPayload::Unpair(unpair), + Self::Hello(hello) => QlPayload::Hello(hello), + Self::HelloReply(reply) => QlPayload::HelloReply(reply), + Self::Confirm(confirm) => QlPayload::Confirm(confirm), + Self::Ready(ready) => QlPayload::Ready(ready.into_owned()), + Self::Session(encrypted) => QlPayload::Session(encrypted.into_owned()), } } } -fn parse_payload( - kind: RecordKind, - payload: B, -) -> Result, WireError> { +fn parse_payload(kind: RecordKind, payload: B) -> Result, WireError> { match kind { - RecordKind::PairRequest => Ok(QlPayloadRef::PairRequest(PairRequestRecord::parse( - payload, - )?)), - RecordKind::Unpair => Ok(QlPayloadRef::Unpair(Unpair::decode(&payload[..])?)), - RecordKind::Hello => Ok(QlPayloadRef::Hello(handshake::Hello::decode(&payload[..])?)), - RecordKind::HelloReply => Ok(QlPayloadRef::HelloReply(HelloReply::decode(&payload[..])?)), - RecordKind::Confirm => Ok(QlPayloadRef::Confirm(Confirm::decode(&payload[..])?)), - RecordKind::Ready => Ok(QlPayloadRef::Ready(Ready::parse(payload)?)), - RecordKind::Session => Ok(QlPayloadRef::Session(EncryptedMessage::parse(payload)?)), + RecordKind::PairRequest => Ok(QlPayload::PairRequest(PairRequestRecord::parse(payload)?)), + RecordKind::Unpair => Ok(QlPayload::Unpair(Unpair::decode(&payload[..])?)), + RecordKind::Hello => Ok(QlPayload::Hello(handshake::Hello::decode(&payload[..])?)), + RecordKind::HelloReply => Ok(QlPayload::HelloReply(HelloReply::decode(&payload[..])?)), + RecordKind::Confirm => Ok(QlPayload::Confirm(Confirm::decode(&payload[..])?)), + RecordKind::Ready => Ok(QlPayload::Ready(Ready::parse(payload)?)), + RecordKind::Session => Ok(QlPayload::Session(EncryptedMessage::parse(payload)?)), } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 00898645..ecbac195 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -116,12 +116,12 @@ fn encrypted_session_record_round_trip_and_decrypt() { assert_eq!(decoded.header, header); assert!(matches!(decoded.payload, QlPayload::Session(_))); - let parsed = QlRecord::parse(&bytes).unwrap(); - assert_eq!(parsed.to_owned(), record); + let parsed = QlRecord::parse(bytes.as_slice()).unwrap(); + assert_eq!(parsed.into_owned(), record); let mut bytes = bytes; - let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::Session(encrypted) = payload else { + let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); + let QlPayload::Session(encrypted) = payload else { panic!("expected session payload"); }; let decrypted = encrypted::decrypt_record(&crypto, &header, encrypted, &session_key).unwrap(); @@ -164,8 +164,8 @@ fn decrypted_session_record_iterates_zero_copy_frames() { ); let mut bytes = record.encode(); - let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::Session(encrypted) = payload else { + let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); + let QlPayload::Session(encrypted) = payload else { panic!("expected session payload"); }; let decrypted = encrypted::decrypt_record(&crypto, &header, encrypted, &session_key).unwrap(); @@ -232,8 +232,8 @@ fn pair_request_round_trip_and_decrypt() { ); let mut bytes = record.encode(); - let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::PairRequest(request) = payload else { + let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); + let QlPayload::PairRequest(request) = payload else { panic!("expected pair request"); }; let body = pair::decrypt_pair_request(&crypto, &recipient, &header, request, 100).unwrap(); @@ -262,7 +262,7 @@ fn ready_round_trip_and_decrypt() { meta, Nonce([12; Nonce::SIZE]), ); - let record = QlRecord { + let record: QlRecord> = QlRecord { header, payload: QlPayload::Ready(ready), }; @@ -271,8 +271,8 @@ fn ready_round_trip_and_decrypt() { let parsed = QlRecord::decode(&bytes).unwrap(); assert_eq!(parsed, record); - let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::Ready(ready) = payload else { + let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); + let QlPayload::Ready(ready) = payload else { panic!("expected ready payload"); }; let body = handshake::decrypt_ready(&crypto, &header, ready, &session_key, 100).unwrap(); @@ -302,8 +302,8 @@ fn unpair_round_trip_and_verify() { let parsed = QlRecord::decode(&bytes).unwrap(); assert_eq!(parsed, record); - let QlRecordRef { header, payload } = QlRecord::parse_mut(&mut bytes).unwrap(); - let QlPayloadRef::Unpair(unpair) = payload else { + let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); + let QlPayload::Unpair(unpair) = payload else { panic!("expected unpair payload"); }; unpair::verify_unpair(&crypto, &header, &sender_signing_public, &unpair, 100).unwrap(); @@ -505,7 +505,7 @@ fn protocol_record_size_breakdown() { } } - fn session_record(header: QlHeader, tag: u8, body: SessionRecord) -> QlRecord { + fn session_record(header: QlHeader, tag: u8, body: SessionRecord) -> QlRecord> { let ciphertext_len = body.encode().len(); QlRecord { header, @@ -514,7 +514,7 @@ fn protocol_record_size_breakdown() { } let header = header(); - let hello = QlRecord { + let hello: QlRecord> = QlRecord { header, payload: QlPayload::Hello(handshake::Hello { meta: meta(1), @@ -523,7 +523,7 @@ fn protocol_record_size_breakdown() { signature: MlDsaSignature::from_data([5; MlDsaSignature::SIZE]), }), }; - let hello_reply = QlRecord { + let hello_reply: QlRecord> = QlRecord { header, payload: QlPayload::HelloReply(handshake::HelloReply { meta: meta(2), @@ -532,28 +532,28 @@ fn protocol_record_size_breakdown() { signature: MlDsaSignature::from_data([8; MlDsaSignature::SIZE]), }), }; - let confirm = QlRecord { + let confirm: QlRecord> = QlRecord { header, payload: QlPayload::Confirm(handshake::Confirm { meta: meta(3), signature: MlDsaSignature::from_data([9; MlDsaSignature::SIZE]), }), }; - let pair_request = QlRecord { + let pair_request: QlRecord> = QlRecord { header, payload: QlPayload::PairRequest(pair::PairRequestRecord { kem_ct: MlKemCiphertext::from_data([10; MlKemCiphertext::SIZE]), encrypted: encrypted(11, 0), }), }; - let unpair = QlRecord { + let unpair: QlRecord> = QlRecord { header, payload: QlPayload::Unpair(unpair::Unpair { meta: meta(4), signature: MlDsaSignature::from_data([12; MlDsaSignature::SIZE]), }), }; - let ready = QlRecord { + let ready: QlRecord> = QlRecord { header, payload: QlPayload::Ready(handshake::Ready { encrypted: encrypted(13, 0), diff --git a/ql-wire/src/unpair/crypto.rs b/ql-wire/src/unpair/crypto.rs index 20451681..f4849387 100644 --- a/ql-wire/src/unpair/crypto.rs +++ b/ql-wire/src/unpair/crypto.rs @@ -9,7 +9,7 @@ pub fn build_unpair( identity: &QlIdentity, recipient: XID, meta: ControlMeta, -) -> QlRecord { +) -> QlRecord> { let header = QlHeader { sender: identity.xid, recipient, From c92b9cb1beb5c4217dc25c332c73ed6a5ef8608f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 29 Mar 2026 10:47:19 -0400 Subject: [PATCH 052/304] ql-wire: noise handshake --- Cargo.lock | 26 - ql-wire/Cargo.toml | 5 - ql-wire/src/codec.rs | 10 +- ql-wire/src/crypto.rs | 57 +++ ql-wire/src/encrypted/builder.rs | 12 +- ql-wire/src/encrypted/mod.rs | 10 +- ql-wire/src/encrypted_message.rs | 6 +- ql-wire/src/error.rs | 24 +- ql-wire/src/handshake/crypto.rs | 463 ----------------- ql-wire/src/handshake/kk.rs | 312 +++++++++++ ql-wire/src/handshake/mod.rs | 456 +++++++++++++---- ql-wire/src/handshake/xx.rs | 391 ++++++++++++++ ql-wire/src/header.rs | 102 ++-- ql-wire/src/identity.rs | 104 +++- ql-wire/src/lib.rs | 31 +- ql-wire/src/pair/crypto.rs | 131 ----- ql-wire/src/pair/mod.rs | 86 ---- ql-wire/src/pq.rs | 152 +----- ql-wire/src/record.rs | 258 +++++++--- ql-wire/src/tests.rs | 853 ++++++++++--------------------- ql-wire/src/unpair/crypto.rs | 61 --- ql-wire/src/unpair/mod.rs | 29 -- ql-wire/src/x25519.rs | 47 ++ 23 files changed, 1831 insertions(+), 1795 deletions(-) create mode 100644 ql-wire/src/crypto.rs delete mode 100644 ql-wire/src/handshake/crypto.rs create mode 100644 ql-wire/src/handshake/kk.rs create mode 100644 ql-wire/src/handshake/xx.rs delete mode 100644 ql-wire/src/pair/crypto.rs delete mode 100644 ql-wire/src/pair/mod.rs delete mode 100644 ql-wire/src/unpair/crypto.rs delete mode 100644 ql-wire/src/unpair/mod.rs create mode 100644 ql-wire/src/x25519.rs diff --git a/Cargo.lock b/Cargo.lock index 1206f164..90dfe15f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1499,30 +1499,6 @@ dependencies = [ "hax-lib", ] -[[package]] -name = "libcrux-macros" -version = "0.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffd6aa2dcd5be681662001b81d493f1569c6d49a32361f470b0c955465cd0338" -dependencies = [ - "quote", - "syn 2.0.106", -] - -[[package]] -name = "libcrux-ml-dsa" -version = "0.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b34d977eb95b8fe93e6eb87197b55ee21e50e725bc3f206a7cb3a0d7d719c4b" -dependencies = [ - "core-models", - "hax-lib", - "libcrux-intrinsics", - "libcrux-macros", - "libcrux-platform", - "libcrux-sha3", -] - [[package]] name = "libcrux-ml-kem" version = "0.0.7" @@ -2190,10 +2166,8 @@ name = "ql-wire" version = "0.1.0" dependencies = [ "libcrux-aesgcm", - "libcrux-ml-dsa", "libcrux-ml-kem", "sha2", - "thiserror", ] [[package]] diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml index 4147ec23..3a94f700 100644 --- a/ql-wire/Cargo.toml +++ b/ql-wire/Cargo.toml @@ -6,15 +6,10 @@ description = "Quantum Link wire format types and crypto helpers" license = "Proprietary" [dependencies] -libcrux-ml-dsa = { version = "0.0.7", default-features = false, features = [ - "std", - "mldsa87", -] } libcrux-ml-kem = { version = "0.0.7", default-features = false, features = [ "std", "mlkem1024", ] } -thiserror = { version = "2" } [dev-dependencies] libcrux-aesgcm = "0.0.7" diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index e5e0ebdd..45fb9a41 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -1,4 +1,4 @@ -use crate::{ByteSlice, QlHeader, WireError}; +use crate::{ByteSlice, WireError}; pub fn push_u8(out: &mut Vec, value: u8) { out.push(value); @@ -106,11 +106,3 @@ pub fn append_framed_bytes(out: &mut Vec, value: &[u8]) { out.extend_from_slice(&u64::try_from(value.len()).unwrap().to_le_bytes()); out.extend_from_slice(value); } - -pub fn header_aad(header: &QlHeader) -> Vec { - let mut aad = Vec::new(); - append_field(&mut aad, b"domain", b"ql-wire:header-aad:v1"); - append_field(&mut aad, b"sender", &header.sender.0); - append_field(&mut aad, b"recipient", &header.recipient.0); - aad -} diff --git a/ql-wire/src/crypto.rs b/ql-wire/src/crypto.rs new file mode 100644 index 00000000..503c185c --- /dev/null +++ b/ql-wire/src/crypto.rs @@ -0,0 +1,57 @@ +use crate::{ + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, SessionKey, + X25519KeyPair, X25519PrivateKey, X25519PublicKey, ENCRYPTED_MESSAGE_AUTH_SIZE, +}; + +pub trait QlRandom { + fn fill_random_bytes(&self, out: &mut [u8]); +} + +pub trait QlHash { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32]; +} + +pub trait QlAead { + fn aes256_gcm_encrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + + fn aes256_gcm_decrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool; +} + +pub trait QlDh { + fn x25519_generate_keypair(&self) -> X25519KeyPair; + + fn x25519_agree( + &self, + private_key: &X25519PrivateKey, + public_key: &X25519PublicKey, + ) -> SessionKey; +} + +pub trait QlKem { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair; + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey); + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + ciphertext: &MlKemCiphertext, + ) -> SessionKey; +} + +pub trait QlCrypto: QlRandom + QlHash + QlAead + QlDh + QlKem {} + +impl QlCrypto for T where T: QlRandom + QlHash + QlAead + QlDh + QlKem {} diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 949510fc..e8f01660 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -3,8 +3,8 @@ use super::{ StreamClose, StreamData, StreamWindow, SIZE_LEN, }; use crate::{ - codec, encrypted_message::EncryptedMessage, ByteChunks, Nonce, QlCrypto, QlHeader, QlPayload, - QlRecord, SessionKey, + codec, encrypted_message::EncryptedMessage, ByteChunks, Nonce, QlCrypto, QlSessionRecord, + SessionHeader, SessionKey, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -124,15 +124,15 @@ impl SessionRecordBuilder { pub fn encrypt( self, crypto: &impl QlCrypto, - header: QlHeader, + header: SessionHeader, session_key: &SessionKey, nonce: Nonce, - ) -> QlRecord> { + ) -> QlSessionRecord> { let aad = header.aad(); let encrypted = EncryptedMessage::encrypt(crypto, session_key, self.bytes, &aad, nonce); - QlRecord { + QlSessionRecord { header, - payload: QlPayload::Session(encrypted), + payload: encrypted, } } } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 5b8b6252..a73ce7d2 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,8 +1,8 @@ use std::mem::size_of; use crate::{ - codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, QlCrypto, QlHeader, - QlRecord, SessionKey, WireError, + codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, QlCrypto, QlSessionRecord, + SessionHeader, SessionKey, WireError, }; mod ack; @@ -200,11 +200,11 @@ impl<'a> Iterator for SessionFrameIter<'a> { pub fn encrypt_record( crypto: &impl QlCrypto, - header: QlHeader, + header: SessionHeader, session_key: &SessionKey, body: &SessionRecord, nonce: crate::Nonce, -) -> QlRecord> { +) -> QlSessionRecord> { let mut builder = SessionRecordBuilder::new(body.seq, body.encoded_len()); for frame in &body.frames { let pushed = builder.push_frame(frame); @@ -215,7 +215,7 @@ pub fn encrypt_record( pub fn decrypt_record>( crypto: &impl QlCrypto, - header: &QlHeader, + header: &SessionHeader, encrypted: EncryptedMessage, session_key: &SessionKey, ) -> Result { diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index e0b821f1..b1a98863 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -56,7 +56,7 @@ impl> EncryptedMessage { aad: &[u8], ) -> Result, WireError> { let mut plaintext = self.ciphertext.as_ref().to_vec(); - if !crypto.decrypt_with_aead(key, &self.nonce, aad, &mut plaintext, &self.auth) { + if !crypto.aes256_gcm_decrypt(key, &self.nonce, aad, &mut plaintext, &self.auth) { return Err(WireError::DecryptFailed); } Ok(plaintext) @@ -71,7 +71,7 @@ impl> EncryptedMessage { aad: &[u8], ) -> Result { let ciphertext = self.ciphertext.as_mut(); - if !crypto.decrypt_with_aead(key, &self.nonce, aad, ciphertext, &self.auth) { + if !crypto.aes256_gcm_decrypt(key, &self.nonce, aad, ciphertext, &self.auth) { return Err(WireError::DecryptFailed); } Ok(self.ciphertext) @@ -86,7 +86,7 @@ impl EncryptedMessage> { aad: &[u8], nonce: Nonce, ) -> Self { - let auth = crypto.encrypt_with_aead(key, &nonce, aad, &mut plaintext); + let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut plaintext); Self { nonce, auth, diff --git a/ql-wire/src/error.rs b/ql-wire/src/error.rs index a1861866..6f17d648 100644 --- a/ql-wire/src/error.rs +++ b/ql-wire/src/error.rs @@ -1,13 +1,23 @@ -use thiserror::Error; +use core::fmt; -#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum WireError { - #[error("invalid payload")] InvalidPayload, - #[error("invalid signature")] - InvalidSignature, - #[error("expired")] Expired, - #[error("decryption failed")] DecryptFailed, + InvalidState, } + +impl fmt::Display for WireError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let message = match self { + Self::InvalidPayload => "invalid payload", + Self::Expired => "expired", + Self::DecryptFailed => "decryption failed", + Self::InvalidState => "invalid state", + }; + f.write_str(message) + } +} + +impl std::error::Error for WireError {} diff --git a/ql-wire/src/handshake/crypto.rs b/ql-wire/src/handshake/crypto.rs deleted file mode 100644 index 0033cb19..00000000 --- a/ql-wire/src/handshake/crypto.rs +++ /dev/null @@ -1,463 +0,0 @@ -use super::{Confirm, Hello, HelloReply, Ready, ReadyBody}; -use crate::{ - pq::ML_KEM_SUITE_TAG, ControlMeta, EncryptedMessage, MlDsaPublicKey, MlDsaSignature, - MlKemCiphertext, MlKemPublicKey, Nonce, QlCrypto, QlHeader, QlIdentity, SessionKey, WireError, - XID, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ResponderSecrets { - pub initiator_secret: SessionKey, - pub responder_secret: SessionKey, -} - -pub fn build_hello( - crypto: &impl QlCrypto, - identity: &QlIdentity, - recipient: XID, - recipient_encapsulation_key: &MlKemPublicKey, - meta: ControlMeta, -) -> (Hello, SessionKey) { - let nonce = next_nonce(crypto); - let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(crypto); - let proof_data = hash_hello_proof_data( - crypto, - identity.xid, - recipient, - &meta, - &nonce.0, - kem_ct.as_bytes(), - ); - let signature = identity.signing_private_key.sign(crypto, &proof_data); - ( - Hello { - meta, - nonce, - kem_ct, - signature, - }, - session_key, - ) -} - -pub fn verify_hello( - crypto: &impl QlCrypto, - initiator: XID, - responder: XID, - initiator_signing_key: &MlDsaPublicKey, - hello: &Hello, - now_seconds: u64, -) -> Result<(), WireError> { - hello.meta.ensure_not_expired(now_seconds)?; - let proof_data = hash_hello_proof_data( - crypto, - initiator, - responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - ); - verify_signature_bytes( - initiator_signing_key, - hello.signature.as_bytes(), - &proof_data, - ) -} - -pub fn respond_hello( - crypto: &impl QlCrypto, - identity: &QlIdentity, - initiator: XID, - initiator_signing_key: &MlDsaPublicKey, - initiator_encapsulation_key: &MlKemPublicKey, - hello: &Hello, - meta: ControlMeta, - now_seconds: u64, -) -> Result<(HelloReply, ResponderSecrets), WireError> { - verify_hello( - crypto, - initiator, - identity.xid, - initiator_signing_key, - hello, - now_seconds, - )?; - let initiator_secret = identity - .encapsulation_private_key - .decapsulate_shared_secret_bytes(hello.kem_ct.as_bytes()); - let nonce = next_nonce(crypto); - let (responder_secret, kem_ct) = - initiator_encapsulation_key.encapsulate_new_shared_secret(crypto); - let transcript = hash_handshake_transcript( - crypto, - initiator, - identity.xid, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - &meta, - &nonce.0, - kem_ct.as_bytes(), - ); - let signature = identity.signing_private_key.sign(crypto, &transcript); - Ok(( - HelloReply { - meta, - nonce, - kem_ct, - signature, - }, - ResponderSecrets { - initiator_secret, - responder_secret, - }, - )) -} - -pub fn build_confirm( - crypto: &impl QlCrypto, - identity: &QlIdentity, - responder: XID, - responder_signing_key: &MlDsaPublicKey, - hello: &Hello, - reply: &HelloReply, - initiator_secret: &SessionKey, - meta: ControlMeta, - now_seconds: u64, -) -> Result<(Confirm, SessionKey), WireError> { - reply.meta.ensure_not_expired(now_seconds)?; - let transcript = hash_handshake_transcript( - crypto, - identity.xid, - responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - &reply.meta, - &reply.nonce.0, - reply.kem_ct.as_bytes(), - ); - verify_signature_bytes( - responder_signing_key, - reply.signature.as_bytes(), - &transcript, - )?; - let responder_secret = identity - .encapsulation_private_key - .decapsulate_shared_secret_bytes(reply.kem_ct.as_bytes()); - let proof_data = hash_confirm_proof_data( - crypto, - &meta, - identity.xid, - responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - &reply.meta, - &reply.nonce.0, - reply.kem_ct.as_bytes(), - ); - let signature = identity.signing_private_key.sign(crypto, &proof_data); - let session_key = derive_session_key( - crypto, - initiator_secret, - &responder_secret, - identity.xid, - responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - &reply.meta, - &reply.nonce.0, - reply.kem_ct.as_bytes(), - ); - Ok((Confirm { meta, signature }, session_key)) -} - -pub fn finalize_confirm( - crypto: &impl QlCrypto, - initiator: XID, - responder: XID, - initiator_signing_key: &MlDsaPublicKey, - hello: &Hello, - reply: &HelloReply, - confirm: &Confirm, - secrets: &ResponderSecrets, - now_seconds: u64, -) -> Result { - verify_confirm( - crypto, - initiator, - responder, - initiator_signing_key, - hello, - reply, - confirm, - now_seconds, - )?; - Ok(derive_session_key( - crypto, - &secrets.initiator_secret, - &secrets.responder_secret, - initiator, - responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - &reply.meta, - &reply.nonce.0, - reply.kem_ct.as_bytes(), - )) -} - -pub fn verify_confirm( - crypto: &impl QlCrypto, - initiator: XID, - responder: XID, - initiator_signing_key: &MlDsaPublicKey, - hello: &Hello, - reply: &HelloReply, - confirm: &Confirm, - now_seconds: u64, -) -> Result<(), WireError> { - confirm.meta.ensure_not_expired(now_seconds)?; - let proof_data = hash_confirm_proof_data( - crypto, - &confirm.meta, - initiator, - responder, - &hello.meta, - &hello.nonce.0, - hello.kem_ct.as_bytes(), - &reply.meta, - &reply.nonce.0, - reply.kem_ct.as_bytes(), - ); - verify_signature_bytes( - initiator_signing_key, - confirm.signature.as_bytes(), - &proof_data, - ) -} - -pub fn build_ready( - crypto: &impl QlCrypto, - header: QlHeader, - session_key: &SessionKey, - meta: ControlMeta, - nonce: Nonce, -) -> Ready> { - let aad = header.aad(); - let body_bytes = ReadyBody { meta }.encode(); - Ready { - encrypted: EncryptedMessage::encrypt(crypto, session_key, body_bytes, &aad, nonce), - } -} - -pub fn decrypt_ready>( - crypto: &impl QlCrypto, - header: &QlHeader, - ready: Ready, - session_key: &SessionKey, - now_seconds: u64, -) -> Result { - let aad = header.aad(); - let mut plaintext = ready - .encrypted - .decrypt_in_place(crypto, session_key, &aad)?; - let body = ReadyBody::decode(plaintext.as_mut())?; - body.meta.ensure_not_expired(now_seconds)?; - Ok(body) -} - -fn hash_hello_proof_data( - crypto: &impl QlCrypto, - initiator: XID, - responder: XID, - meta: &ControlMeta, - nonce: &[u8; Nonce::SIZE], - kem_ct: &[u8; MlKemCiphertext::SIZE], -) -> [u8; 32] { - let control_id = meta.control_id.0.to_le_bytes(); - let valid_until = meta.valid_until.to_le_bytes(); - crypto.hash(&[ - b"ql-wire:hello-proof:v1", - b"initiator", - &initiator.0, - b"responder", - &responder.0, - b"control-id", - &control_id, - b"valid-until", - &valid_until, - b"nonce", - nonce, - b"kem-suite", - ML_KEM_SUITE_TAG, - b"kem-ct", - kem_ct, - ]) -} - -fn hash_handshake_transcript( - crypto: &impl QlCrypto, - initiator: XID, - responder: XID, - hello_meta: &ControlMeta, - initiator_nonce: &[u8; Nonce::SIZE], - initiator_kem_ct: &[u8; MlKemCiphertext::SIZE], - reply_meta: &ControlMeta, - responder_nonce: &[u8; Nonce::SIZE], - responder_kem_ct: &[u8; MlKemCiphertext::SIZE], -) -> [u8; 32] { - let hello_control_id = hello_meta.control_id.0.to_le_bytes(); - let hello_valid_until = hello_meta.valid_until.to_le_bytes(); - let reply_control_id = reply_meta.control_id.0.to_le_bytes(); - let reply_valid_until = reply_meta.valid_until.to_le_bytes(); - crypto.hash(&[ - b"ql-wire:handshake-transcript:v1", - b"initiator", - &initiator.0, - b"responder", - &responder.0, - b"hello-control-id", - &hello_control_id, - b"hello-valid-until", - &hello_valid_until, - b"initiator-nonce", - initiator_nonce, - b"initiator-kem-suite", - ML_KEM_SUITE_TAG, - b"initiator-kem-ct", - initiator_kem_ct, - b"reply-control-id", - &reply_control_id, - b"reply-valid-until", - &reply_valid_until, - b"responder-nonce", - responder_nonce, - b"responder-kem-suite", - ML_KEM_SUITE_TAG, - b"responder-kem-ct", - responder_kem_ct, - ]) -} - -fn hash_confirm_proof_data( - crypto: &impl QlCrypto, - confirm_meta: &ControlMeta, - initiator: XID, - responder: XID, - hello_meta: &ControlMeta, - initiator_nonce: &[u8; Nonce::SIZE], - initiator_kem_ct: &[u8; MlKemCiphertext::SIZE], - reply_meta: &ControlMeta, - responder_nonce: &[u8; Nonce::SIZE], - responder_kem_ct: &[u8; MlKemCiphertext::SIZE], -) -> [u8; 32] { - let confirm_control_id = confirm_meta.control_id.0.to_le_bytes(); - let confirm_valid_until = confirm_meta.valid_until.to_le_bytes(); - let hello_control_id = hello_meta.control_id.0.to_le_bytes(); - let hello_valid_until = hello_meta.valid_until.to_le_bytes(); - let reply_control_id = reply_meta.control_id.0.to_le_bytes(); - let reply_valid_until = reply_meta.valid_until.to_le_bytes(); - crypto.hash(&[ - b"ql-wire:confirm-proof:v1", - b"confirm-control-id", - &confirm_control_id, - b"confirm-valid-until", - &confirm_valid_until, - b"initiator", - &initiator.0, - b"responder", - &responder.0, - b"hello-control-id", - &hello_control_id, - b"hello-valid-until", - &hello_valid_until, - b"initiator-nonce", - initiator_nonce, - b"initiator-kem-suite", - ML_KEM_SUITE_TAG, - b"initiator-kem-ct", - initiator_kem_ct, - b"reply-control-id", - &reply_control_id, - b"reply-valid-until", - &reply_valid_until, - b"responder-nonce", - responder_nonce, - b"responder-kem-suite", - ML_KEM_SUITE_TAG, - b"responder-kem-ct", - responder_kem_ct, - ]) -} - -fn next_nonce(crypto: &impl QlCrypto) -> Nonce { - let mut data = [0u8; Nonce::SIZE]; - crypto.fill_random_bytes(&mut data); - Nonce(data) -} - -fn derive_session_key( - crypto: &impl QlCrypto, - initiator_secret: &SessionKey, - responder_secret: &SessionKey, - initiator: XID, - responder: XID, - hello_meta: &ControlMeta, - initiator_nonce: &[u8; Nonce::SIZE], - initiator_kem_ct: &[u8; MlKemCiphertext::SIZE], - reply_meta: &ControlMeta, - responder_nonce: &[u8; Nonce::SIZE], - responder_kem_ct: &[u8; MlKemCiphertext::SIZE], -) -> SessionKey { - let hello_control_id = hello_meta.control_id.0.to_le_bytes(); - let hello_valid_until = hello_meta.valid_until.to_le_bytes(); - let reply_control_id = reply_meta.control_id.0.to_le_bytes(); - let reply_valid_until = reply_meta.valid_until.to_le_bytes(); - SessionKey::from_data(crypto.hash(&[ - b"ql-wire:session-key:v1", - b"initiator-secret", - initiator_secret.as_bytes(), - b"responder-secret", - responder_secret.as_bytes(), - b"initiator", - &initiator.0, - b"responder", - &responder.0, - b"hello-control-id", - &hello_control_id, - b"hello-valid-until", - &hello_valid_until, - b"initiator-nonce", - initiator_nonce, - b"initiator-kem-suite", - ML_KEM_SUITE_TAG, - b"initiator-kem-ct", - initiator_kem_ct, - b"reply-control-id", - &reply_control_id, - b"reply-valid-until", - &reply_valid_until, - b"responder-nonce", - responder_nonce, - b"responder-kem-suite", - ML_KEM_SUITE_TAG, - b"responder-kem-ct", - responder_kem_ct, - ])) -} - -fn verify_signature_bytes( - signing_key: &MlDsaPublicKey, - signature: &[u8; MlDsaSignature::SIZE], - proof_data: &[u8], -) -> Result<(), WireError> { - if signing_key.verify_bytes(signature, proof_data) { - Ok(()) - } else { - Err(WireError::InvalidSignature) - } -} diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs new file mode 100644 index 00000000..c92ba3a4 --- /dev/null +++ b/ql-wire/src/handshake/kk.rs @@ -0,0 +1,312 @@ +use super::{ + decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, + generate_ephemeral_keypair, init_kk_symmetric, mix_hash_ephemeral, EncryptedMlKemCiphertext, + FinalizedHandshake, HybridEphemeralKeyPair, HybridEphemeralPublic, Role, SymmetricState, + ENCRYPTED_MLKEM_CIPHERTEXT_LEN, +}; +use crate::{codec, ControlMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Kk1 { + pub meta: ControlMeta, + pub skem_ciphertext: MlKemCiphertext, + pub ephemeral: HybridEphemeralPublic, +} + +impl Kk1 { + pub const ENCODED_LEN: usize = + ControlMeta::ENCODED_LEN + MlKemCiphertext::SIZE + HybridEphemeralPublic::ENCODED_LEN; + + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, self.skem_ciphertext.as_bytes()); + self.ephemeral.encode_into(out); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let meta = ControlMeta::decode_from(&mut reader)?; + let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); + let ephemeral = + HybridEphemeralPublic::decode(&reader.take_bytes(HybridEphemeralPublic::ENCODED_LEN)?)?; + reader.finish()?; + Ok(Self { + meta, + skem_ciphertext, + ephemeral, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Kk2 { + pub meta: ControlMeta, + pub ekem_ciphertext: MlKemCiphertext, + pub skem_ciphertext: EncryptedMlKemCiphertext, + pub ephemeral: HybridEphemeralPublic, +} + +impl Kk2 { + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + + MlKemCiphertext::SIZE + + ENCRYPTED_MLKEM_CIPHERTEXT_LEN + + HybridEphemeralPublic::ENCODED_LEN; + + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); + codec::push_bytes(out, self.skem_ciphertext.as_bytes()); + self.ephemeral.encode_into(out); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let meta = ControlMeta::decode_from(&mut reader)?; + let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); + let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); + let ephemeral = + HybridEphemeralPublic::decode(&reader.take_bytes(HybridEphemeralPublic::ENCODED_LEN)?)?; + reader.finish()?; + Ok(Self { + meta, + ekem_ciphertext, + skem_ciphertext, + ephemeral, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum KkMessage { + Message1(Kk1), + Message2(Kk2), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum KkStep { + Send1, + Recv1, + Send2, + Recv2, + Done, +} + +#[derive(Debug, Clone)] +pub struct KkHandshake { + role: Role, + step: KkStep, + symmetric: SymmetricState, + local: QlIdentity, + remote_bundle: PeerBundle, + local_ephemeral: Option, + remote_ephemeral: Option, +} + +impl KkHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + ) -> Self { + let symmetric = init_kk_symmetric(crypto, &local.bundle(), &remote_bundle); + Self { + role: Role::Initiator, + step: KkStep::Send1, + symmetric, + local, + remote_bundle, + local_ephemeral: None, + remote_ephemeral: None, + } + } + + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + ) -> Self { + let symmetric = init_kk_symmetric(crypto, &remote_bundle, &local.bundle()); + Self { + role: Role::Responder, + step: KkStep::Recv1, + symmetric, + local, + remote_bundle, + local_ephemeral: None, + remote_ephemeral: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == KkStep::Done + } + + pub fn write_message( + &mut self, + crypto: &impl QlCrypto, + meta: ControlMeta, + ) -> Result { + match self.step { + KkStep::Send1 => { + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); + self.symmetric + .encrypt_and_hash(crypto, skem_ciphertext.as_bytes())?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + + let es = crypto.x25519_agree( + &local_ephemeral.x25519.private, + &self.remote_bundle.x25519_public_key, + ); + self.symmetric.mix_key(crypto, es.as_bytes()); + + let ss = crypto.x25519_agree( + &self.local.x25519_private_key, + &self.remote_bundle.x25519_public_key, + ); + self.symmetric.mix_key(crypto, ss.as_bytes()); + + self.local_ephemeral = Some(local_ephemeral); + self.step = KkStep::Recv2; + Ok(KkMessage::Message1(Kk1 { + meta, + skem_ciphertext, + ephemeral: public, + })) + } + KkStep::Send2 => { + let remote_ephemeral = self + .remote_ephemeral + .clone() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + + let ee = crypto.x25519_agree( + &local_ephemeral.x25519.private, + &remote_ephemeral.x25519_public_key, + ); + self.symmetric.mix_key(crypto, ee.as_bytes()); + + let se = crypto.x25519_agree( + &local_ephemeral.x25519.private, + &self.remote_bundle.x25519_public_key, + ); + self.symmetric.mix_key(crypto, se.as_bytes()); + + self.local_ephemeral = Some(local_ephemeral); + self.step = KkStep::Done; + Ok(KkMessage::Message2(Kk2 { + meta, + ekem_ciphertext, + skem_ciphertext, + ephemeral: public, + })) + } + _ => Err(WireError::InvalidState), + } + } + + pub fn read_message( + &mut self, + crypto: &impl QlCrypto, + message: &KkMessage, + ) -> Result<(), WireError> { + match (&self.step, message) { + (KkStep::Recv1, KkMessage::Message1(message)) => { + self.symmetric + .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; + let skem_secret = crypto + .mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + + let es = crypto.x25519_agree( + &self.local.x25519_private_key, + &message.ephemeral.x25519_public_key, + ); + self.symmetric.mix_key(crypto, es.as_bytes()); + + let ss = crypto.x25519_agree( + &self.local.x25519_private_key, + &self.remote_bundle.x25519_public_key, + ); + self.symmetric.mix_key(crypto, ss.as_bytes()); + self.step = KkStep::Send2; + Ok(()) + } + (KkStep::Recv2, KkMessage::Message2(message)) => { + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = crypto + .mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let skem_ciphertext = decrypt_mlkem_ciphertext( + crypto, + &mut self.symmetric, + &message.skem_ciphertext, + )?; + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + + let ee = crypto.x25519_agree( + &local_ephemeral.x25519.private, + &message.ephemeral.x25519_public_key, + ); + self.symmetric.mix_key(crypto, ee.as_bytes()); + + let se = crypto.x25519_agree( + &self.local.x25519_private_key, + &message.ephemeral.x25519_public_key, + ); + self.symmetric.mix_key(crypto, se.as_bytes()); + self.step = KkStep::Done; + Ok(()) + } + _ => Err(WireError::InvalidState), + } + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + Ok(finalize_handshake( + crypto, + self.symmetric, + self.role, + self.remote_bundle, + )) + } +} diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 5aca8b82..a00fbdf4 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,151 +1,417 @@ use crate::{ - codec, encrypted_message::EncryptedMessage, ByteSlice, ControlMeta, MlDsaSignature, - MlKemCiphertext, Nonce, WireError, + codec, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, Nonce, PeerBundle, + QlCrypto, SessionKey, WireError, X25519KeyPair, X25519PublicKey, ENCRYPTED_MESSAGE_AUTH_SIZE, }; -mod crypto; -pub use crypto::*; +mod kk; +mod xx; + +pub use kk::{Kk1, Kk2, KkHandshake, KkMessage}; +pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake, XxMessage}; + +const SHA256_BLOCK_LEN: usize = 64; +const PROTOCOL_XX: &[u8] = b"ql-wire:hybrid-xx:v1"; +const PROTOCOL_KK: &[u8] = b"ql-wire:hybrid-kk:v1"; +const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; + +pub const ENCRYPTED_MLKEM_CIPHERTEXT_LEN: usize = + MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; +pub const ENCRYPTED_PEER_BUNDLE_LEN: usize = PeerBundle::ENCODED_LEN + ENCRYPTED_MESSAGE_AUTH_SIZE; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Hello { - pub meta: ControlMeta, - pub nonce: Nonce, - pub kem_ct: MlKemCiphertext, - pub signature: MlDsaSignature, +pub struct HybridEphemeralPublic { + pub x25519_public_key: X25519PublicKey, + pub mlkem_public_key: MlKemPublicKey, } -impl Hello { - pub const ENCODED_LEN: usize = - ControlMeta::ENCODED_LEN + Nonce::SIZE + MlKemCiphertext::SIZE + MlDsaSignature::SIZE; +impl HybridEphemeralPublic { + pub const ENCODED_LEN: usize = X25519PublicKey::SIZE + MlKemPublicKey::SIZE; pub fn encode_into(&self, out: &mut Vec) { - self.meta.encode_into(out); - codec::push_bytes(out, &self.nonce.0); - codec::push_bytes(out, self.kem_ct.as_bytes()); - codec::push_bytes(out, self.signature.as_bytes()); + codec::push_bytes(out, self.x25519_public_key.as_bytes()); + codec::push_bytes(out, self.mlkem_public_key.as_bytes()); } pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let hello = Self { - meta: ControlMeta::decode_from(&mut reader)?, - nonce: Nonce(reader.take_array()?), - kem_ct: MlKemCiphertext::from_data(reader.take_array()?), - signature: MlDsaSignature::from_data(reader.take_array()?), + let value = Self { + x25519_public_key: X25519PublicKey::from_data(reader.take_array()?), + mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), }; reader.finish()?; - Ok(hello) + Ok(value) } } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct HelloReply { - pub meta: ControlMeta, - pub nonce: Nonce, - pub kem_ct: MlKemCiphertext, - pub signature: MlDsaSignature, +pub struct EncryptedMlKemCiphertext(Box<[u8; ENCRYPTED_MLKEM_CIPHERTEXT_LEN]>); + +impl EncryptedMlKemCiphertext { + pub fn from_data(data: [u8; ENCRYPTED_MLKEM_CIPHERTEXT_LEN]) -> Self { + Self(Box::new(data)) + } + + pub fn as_bytes(&self) -> &[u8; ENCRYPTED_MLKEM_CIPHERTEXT_LEN] { + self.0.as_ref() + } } -impl HelloReply { - pub const ENCODED_LEN: usize = - ControlMeta::ENCODED_LEN + Nonce::SIZE + MlKemCiphertext::SIZE + MlDsaSignature::SIZE; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedPeerBundle(Box<[u8; ENCRYPTED_PEER_BUNDLE_LEN]>); - pub fn encode_into(&self, out: &mut Vec) { - self.meta.encode_into(out); - codec::push_bytes(out, &self.nonce.0); - codec::push_bytes(out, self.kem_ct.as_bytes()); - codec::push_bytes(out, self.signature.as_bytes()); +impl EncryptedPeerBundle { + pub fn from_data(data: [u8; ENCRYPTED_PEER_BUNDLE_LEN]) -> Self { + Self(Box::new(data)) } - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let reply = Self { - meta: ControlMeta::decode_from(&mut reader)?, - nonce: Nonce(reader.take_array()?), - kem_ct: MlKemCiphertext::from_data(reader.take_array()?), - signature: MlDsaSignature::from_data(reader.take_array()?), - }; - reader.finish()?; - Ok(reply) + pub fn as_bytes(&self) -> &[u8; ENCRYPTED_PEER_BUNDLE_LEN] { + self.0.as_ref() } } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Confirm { - pub meta: ControlMeta, - pub signature: MlDsaSignature, +pub struct FinalizedHandshake { + pub tx_key: SessionKey, + pub rx_key: SessionKey, + pub tx_connection_id: ConnectionId, + pub rx_connection_id: ConnectionId, + pub handshake_hash: [u8; 32], + pub remote_bundle: PeerBundle, } -impl Confirm { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + MlDsaSignature::SIZE; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Role { + Initiator, + Responder, +} - pub fn encode_into(&self, out: &mut Vec) { - self.meta.encode_into(out); - codec::push_bytes(out, self.signature.as_bytes()); +#[derive(Debug, Clone)] +struct HybridEphemeralKeyPair { + x25519: X25519KeyPair, + mlkem: MlKemKeyPair, +} + +impl HybridEphemeralKeyPair { + fn public(&self) -> HybridEphemeralPublic { + HybridEphemeralPublic { + x25519_public_key: self.x25519.public, + mlkem_public_key: self.mlkem.public.clone(), + } } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let confirm = Self { - meta: ControlMeta::decode_from(&mut reader)?, - signature: MlDsaSignature::from_data(reader.take_array()?), - }; - reader.finish()?; - Ok(confirm) +#[derive(Debug, Clone)] +struct CipherState { + key: Option, + nonce: u64, +} + +impl CipherState { + fn new() -> Self { + Self { + key: None, + nonce: 0, + } + } + + fn initialize_key(&mut self, key: SessionKey) { + self.key = Some(key); + self.nonce = 0; + } + + fn has_key(&self) -> bool { + self.key.is_some() + } + + fn encrypt( + &mut self, + crypto: &impl QlCrypto, + aad: &[u8], + plaintext: &[u8], + ) -> Result, WireError> { + let key = self.key.as_ref().ok_or(WireError::InvalidState)?; + let nonce = noise_nonce(self.nonce); + let mut ciphertext = plaintext.to_vec(); + let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut ciphertext); + self.nonce = self.nonce.wrapping_add(1); + ciphertext.extend_from_slice(&auth); + Ok(ciphertext) + } + + fn decrypt( + &mut self, + crypto: &impl QlCrypto, + aad: &[u8], + ciphertext: &[u8], + ) -> Result, WireError> { + if ciphertext.len() < ENCRYPTED_MESSAGE_AUTH_SIZE { + return Err(WireError::InvalidPayload); + } + let split = ciphertext.len() - ENCRYPTED_MESSAGE_AUTH_SIZE; + let (ciphertext, auth) = ciphertext.split_at(split); + let mut plaintext = ciphertext.to_vec(); + let key = self.key.as_ref().ok_or(WireError::InvalidState)?; + let nonce = noise_nonce(self.nonce); + let mut auth_tag = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + auth_tag.copy_from_slice(auth); + if !crypto.aes256_gcm_decrypt(key, &nonce, aad, &mut plaintext, &auth_tag) { + return Err(WireError::DecryptFailed); + } + self.nonce = self.nonce.wrapping_add(1); + Ok(plaintext) } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Ready { - pub encrypted: EncryptedMessage, +#[derive(Debug, Clone)] +struct SymmetricState { + chaining_key: [u8; 32], + handshake_hash: [u8; 32], + cipher: CipherState, } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ReadyBody { - pub meta: ControlMeta, +impl SymmetricState { + fn new(crypto: &impl QlCrypto, protocol_name: &[u8]) -> Self { + let h = crypto.sha256(&[protocol_name]); + Self { + chaining_key: h, + handshake_hash: h, + cipher: CipherState::new(), + } + } + + fn mix_hash(&mut self, crypto: &impl QlCrypto, data: &[u8]) { + self.handshake_hash = crypto.sha256(&[&self.handshake_hash, data]); + } + + fn mix_key(&mut self, crypto: &impl QlCrypto, input_key_material: &[u8]) { + let (chaining_key, cipher_key) = hkdf2(crypto, &self.chaining_key, input_key_material); + self.chaining_key = chaining_key; + self.cipher.initialize_key(cipher_key); + } + + fn mix_key_and_hash(&mut self, crypto: &impl QlCrypto, input_key_material: &[u8]) { + let (chaining_key, hash_input, cipher_key) = + hkdf3(crypto, &self.chaining_key, input_key_material); + self.chaining_key = chaining_key; + self.mix_hash(crypto, &hash_input); + self.cipher.initialize_key(cipher_key); + } + + fn encrypt_and_hash( + &mut self, + crypto: &impl QlCrypto, + plaintext: &[u8], + ) -> Result, WireError> { + if self.cipher.has_key() { + let ciphertext = self + .cipher + .encrypt(crypto, &self.handshake_hash, plaintext)?; + self.mix_hash(crypto, &ciphertext); + Ok(ciphertext) + } else { + self.mix_hash(crypto, plaintext); + Ok(plaintext.to_vec()) + } + } + + fn decrypt_and_hash( + &mut self, + crypto: &impl QlCrypto, + ciphertext: &[u8], + ) -> Result, WireError> { + if self.cipher.has_key() { + let plaintext = self + .cipher + .decrypt(crypto, &self.handshake_hash, ciphertext)?; + self.mix_hash(crypto, ciphertext); + Ok(plaintext) + } else { + self.mix_hash(crypto, ciphertext); + Ok(ciphertext.to_vec()) + } + } + + fn split_for_role( + &self, + crypto: &impl QlCrypto, + role: Role, + ) -> (SessionKey, SessionKey) { + let temp_key = hmac_sha256(crypto, &self.chaining_key, &[&[]]); + let k1 = SessionKey::from_data(hmac_sha256(crypto, &temp_key, &[&[1]])); + let k2 = SessionKey::from_data(hmac_sha256(crypto, &temp_key, &[k1.as_bytes(), &[2]])); + match role { + Role::Initiator => (k1, k2), + Role::Responder => (k2, k1), + } + } +} + +fn init_kk_symmetric( + crypto: &impl QlCrypto, + initiator_bundle: &PeerBundle, + responder_bundle: &PeerBundle, +) -> SymmetricState { + let mut symmetric = SymmetricState::new(crypto, PROTOCOL_KK); + symmetric.mix_hash(crypto, &initiator_bundle.encode()); + symmetric.mix_hash(crypto, &responder_bundle.encode()); + symmetric } -impl Ready { - pub fn parse(bytes: B) -> Result { - Ok(Self { - encrypted: EncryptedMessage::parse(bytes)?, - }) +fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> HybridEphemeralKeyPair { + HybridEphemeralKeyPair { + x25519: crypto.x25519_generate_keypair(), + mlkem: crypto.mlkem_generate_keypair(), } } -impl Ready { - pub fn into_owned(self) -> Ready> - where - B: ByteSlice, - { - Ready { - encrypted: self.encrypted.into_owned(), - } +fn mix_hash_ephemeral( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + public: &HybridEphemeralPublic, +) { + symmetric.mix_hash(crypto, public.x25519_public_key.as_bytes()); + symmetric.mix_hash(crypto, public.mlkem_public_key.as_bytes()); +} + +fn encrypt_peer_bundle( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + bundle: &PeerBundle, +) -> Result { + let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode())?; + if ciphertext.len() != ENCRYPTED_PEER_BUNDLE_LEN { + return Err(WireError::InvalidState); } + let mut out = [0u8; ENCRYPTED_PEER_BUNDLE_LEN]; + out.copy_from_slice(&ciphertext); + Ok(EncryptedPeerBundle::from_data(out)) } -impl> Ready { - pub fn encode_into(&self, out: &mut Vec) { - self.encrypted.encode_into(out); +fn decrypt_peer_bundle( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + bundle: &EncryptedPeerBundle, +) -> Result { + let plaintext = symmetric.decrypt_and_hash(crypto, bundle.as_bytes())?; + PeerBundle::decode(&plaintext) +} + +fn encrypt_mlkem_ciphertext( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + ciphertext: &MlKemCiphertext, +) -> Result { + let encrypted = symmetric.encrypt_and_hash(crypto, ciphertext.as_bytes())?; + if encrypted.len() != ENCRYPTED_MLKEM_CIPHERTEXT_LEN { + return Err(WireError::InvalidState); } + let mut out = [0u8; ENCRYPTED_MLKEM_CIPHERTEXT_LEN]; + out.copy_from_slice(&encrypted); + Ok(EncryptedMlKemCiphertext::from_data(out)) } -impl Ready> { - pub fn decode(bytes: &[u8]) -> Result { - EncryptedMessage::parse(bytes).map(|encrypted| Self { - encrypted: encrypted.into_owned(), - }) +fn decrypt_mlkem_ciphertext( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + ciphertext: &EncryptedMlKemCiphertext, +) -> Result { + let plaintext = symmetric.decrypt_and_hash(crypto, ciphertext.as_bytes())?; + if plaintext.len() != MlKemCiphertext::SIZE { + return Err(WireError::InvalidPayload); } + let mut out = [0u8; MlKemCiphertext::SIZE]; + out.copy_from_slice(&plaintext); + Ok(MlKemCiphertext::from_data(out)) } -impl ReadyBody { - pub fn encode(&self) -> Vec { - self.meta.encode() +fn finalize_handshake( + crypto: &impl QlCrypto, + symmetric: SymmetricState, + role: Role, + remote_bundle: PeerBundle, +) -> FinalizedHandshake { + let handshake_hash = symmetric.handshake_hash; + let (tx_key, rx_key) = symmetric.split_for_role(crypto, role); + let (initiator_rx, responder_rx) = derive_connection_ids(crypto, &handshake_hash); + let (tx_connection_id, rx_connection_id) = match role { + Role::Initiator => (responder_rx, initiator_rx), + Role::Responder => (initiator_rx, responder_rx), + }; + FinalizedHandshake { + tx_key, + rx_key, + tx_connection_id, + rx_connection_id, + handshake_hash, + remote_bundle, } +} - pub fn decode(bytes: &[u8]) -> Result { - Ok(Self { - meta: ControlMeta::decode(bytes)?, - }) +fn derive_connection_ids( + crypto: &impl QlCrypto, + handshake_hash: &[u8; 32], +) -> (ConnectionId, ConnectionId) { + let initiator = crypto.sha256(&[CONNECTION_ID_DOMAIN, handshake_hash, b"initiator-rx"]); + let responder = crypto.sha256(&[CONNECTION_ID_DOMAIN, handshake_hash, b"responder-rx"]); + let mut initiator_rx = [0u8; ConnectionId::SIZE]; + let mut responder_rx = [0u8; ConnectionId::SIZE]; + initiator_rx.copy_from_slice(&initiator[..ConnectionId::SIZE]); + responder_rx.copy_from_slice(&responder[..ConnectionId::SIZE]); + ( + ConnectionId::from_data(initiator_rx), + ConnectionId::from_data(responder_rx), + ) +} + +fn noise_nonce(counter: u64) -> Nonce { + let mut nonce = [0u8; Nonce::SIZE]; + nonce[4..].copy_from_slice(&counter.to_le_bytes()); + Nonce(nonce) +} + +fn hkdf2( + crypto: &impl QlCrypto, + chaining_key: &[u8; 32], + input_key_material: &[u8], +) -> ([u8; 32], SessionKey) { + let temp_key = hmac_sha256(crypto, chaining_key, &[input_key_material]); + let out1 = hmac_sha256(crypto, &temp_key, &[&[1]]); + let out2 = hmac_sha256(crypto, &temp_key, &[&out1, &[2]]); + (out1, SessionKey::from_data(out2)) +} + +fn hkdf3( + crypto: &impl QlCrypto, + chaining_key: &[u8; 32], + input_key_material: &[u8], +) -> ([u8; 32], [u8; 32], SessionKey) { + let temp_key = hmac_sha256(crypto, chaining_key, &[input_key_material]); + let out1 = hmac_sha256(crypto, &temp_key, &[&[1]]); + let out2 = hmac_sha256(crypto, &temp_key, &[&out1, &[2]]); + let out3 = hmac_sha256(crypto, &temp_key, &[&out2, &[3]]); + (out1, out2, SessionKey::from_data(out3)) +} + +fn hmac_sha256(crypto: &impl QlCrypto, key: &[u8], parts: &[&[u8]]) -> [u8; 32] { + let mut key_block = [0u8; SHA256_BLOCK_LEN]; + if key.len() > SHA256_BLOCK_LEN { + key_block[..32].copy_from_slice(&crypto.sha256(&[key])); + } else { + key_block[..key.len()].copy_from_slice(key); } + + let mut ipad = [0x36u8; SHA256_BLOCK_LEN]; + let mut opad = [0x5cu8; SHA256_BLOCK_LEN]; + for (dst, src) in ipad.iter_mut().zip(key_block.iter()) { + *dst ^= *src; + } + for (dst, src) in opad.iter_mut().zip(key_block.iter()) { + *dst ^= *src; + } + + let mut inner_parts: Vec<&[u8]> = Vec::with_capacity(parts.len() + 1); + inner_parts.push(&ipad); + inner_parts.extend_from_slice(parts); + let inner = crypto.sha256(&inner_parts); + crypto.sha256(&[&opad, &inner]) } diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs new file mode 100644 index 00000000..b3428e5b --- /dev/null +++ b/ql-wire/src/handshake/xx.rs @@ -0,0 +1,391 @@ +use super::{ + decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, + finalize_handshake, generate_ephemeral_keypair, mix_hash_ephemeral, EncryptedMlKemCiphertext, + EncryptedPeerBundle, FinalizedHandshake, HybridEphemeralKeyPair, HybridEphemeralPublic, Role, + SymmetricState, ENCRYPTED_MLKEM_CIPHERTEXT_LEN, ENCRYPTED_PEER_BUNDLE_LEN, PROTOCOL_XX, +}; +use crate::{codec, ControlMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx1 { + pub meta: ControlMeta, + pub ephemeral: HybridEphemeralPublic, +} + +impl Xx1 { + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + HybridEphemeralPublic::ENCODED_LEN; + + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + self.ephemeral.encode_into(out); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let meta = ControlMeta::decode_from(&mut reader)?; + let ephemeral = + HybridEphemeralPublic::decode(&reader.take_bytes(HybridEphemeralPublic::ENCODED_LEN)?)?; + reader.finish()?; + Ok(Self { meta, ephemeral }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx2 { + pub meta: ControlMeta, + pub ekem_ciphertext: MlKemCiphertext, + pub ephemeral: HybridEphemeralPublic, + pub static_bundle: EncryptedPeerBundle, +} + +impl Xx2 { + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + + MlKemCiphertext::SIZE + + HybridEphemeralPublic::ENCODED_LEN + + ENCRYPTED_PEER_BUNDLE_LEN; + + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); + self.ephemeral.encode_into(out); + codec::push_bytes(out, self.static_bundle.as_bytes()); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let meta = ControlMeta::decode_from(&mut reader)?; + let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); + let ephemeral = + HybridEphemeralPublic::decode(&reader.take_bytes(HybridEphemeralPublic::ENCODED_LEN)?)?; + let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); + reader.finish()?; + Ok(Self { + meta, + ekem_ciphertext, + ephemeral, + static_bundle, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx3 { + pub meta: ControlMeta, + pub skem_ciphertext: EncryptedMlKemCiphertext, + pub static_bundle: EncryptedPeerBundle, +} + +impl Xx3 { + pub const ENCODED_LEN: usize = + ControlMeta::ENCODED_LEN + ENCRYPTED_MLKEM_CIPHERTEXT_LEN + ENCRYPTED_PEER_BUNDLE_LEN; + + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, self.skem_ciphertext.as_bytes()); + codec::push_bytes(out, self.static_bundle.as_bytes()); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let meta = ControlMeta::decode_from(&mut reader)?; + let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); + let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); + reader.finish()?; + Ok(Self { + meta, + skem_ciphertext, + static_bundle, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx4 { + pub meta: ControlMeta, + pub skem_ciphertext: EncryptedMlKemCiphertext, +} + +impl Xx4 { + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + ENCRYPTED_MLKEM_CIPHERTEXT_LEN; + + pub fn encode_into(&self, out: &mut Vec) { + self.meta.encode_into(out); + codec::push_bytes(out, self.skem_ciphertext.as_bytes()); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let meta = ControlMeta::decode_from(&mut reader)?; + let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); + reader.finish()?; + Ok(Self { + meta, + skem_ciphertext, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum XxMessage { + Message1(Xx1), + Message2(Xx2), + Message3(Xx3), + Message4(Xx4), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum XxStep { + Send1, + Recv1, + Send2, + Recv2, + Send3, + Recv3, + Send4, + Recv4, + Done, +} + +#[derive(Debug, Clone)] +pub struct XxHandshake { + role: Role, + step: XxStep, + symmetric: SymmetricState, + local: QlIdentity, + local_ephemeral: Option, + remote_ephemeral: Option, + remote_bundle: Option, +} + +impl XxHandshake { + pub fn new_initiator(crypto: &impl QlCrypto, local: QlIdentity) -> Self { + Self { + role: Role::Initiator, + step: XxStep::Send1, + symmetric: SymmetricState::new(crypto, PROTOCOL_XX), + local, + local_ephemeral: None, + remote_ephemeral: None, + remote_bundle: None, + } + } + + pub fn new_responder(crypto: &impl QlCrypto, local: QlIdentity) -> Self { + Self { + role: Role::Responder, + step: XxStep::Recv1, + symmetric: SymmetricState::new(crypto, PROTOCOL_XX), + local, + local_ephemeral: None, + remote_ephemeral: None, + remote_bundle: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == XxStep::Done + } + + pub fn write_message( + &mut self, + crypto: &impl QlCrypto, + meta: ControlMeta, + ) -> Result { + match self.step { + XxStep::Send1 => { + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + self.local_ephemeral = Some(local_ephemeral); + self.step = XxStep::Recv2; + Ok(XxMessage::Message1(Xx1 { + meta, + ephemeral: public, + })) + } + XxStep::Send2 => { + let remote_ephemeral = self + .remote_ephemeral + .clone() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + let ee = crypto.x25519_agree( + &local_ephemeral.x25519.private, + &remote_ephemeral.x25519_public_key, + ); + self.symmetric.mix_key(crypto, ee.as_bytes()); + + let static_bundle = + encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + let es = crypto.x25519_agree( + &self.local.x25519_private_key, + &remote_ephemeral.x25519_public_key, + ); + self.symmetric.mix_key(crypto, es.as_bytes()); + + self.local_ephemeral = Some(local_ephemeral); + self.step = XxStep::Recv3; + + Ok(XxMessage::Message2(Xx2 { + meta, + ekem_ciphertext, + ephemeral: public, + static_bundle, + })) + } + XxStep::Send3 => { + let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; + let remote_ephemeral = self + .remote_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let static_bundle = + encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + let se = crypto.x25519_agree( + &self.local.x25519_private_key, + &remote_ephemeral.x25519_public_key, + ); + self.symmetric.mix_key(crypto, se.as_bytes()); + self.step = XxStep::Recv4; + + Ok(XxMessage::Message3(Xx3 { + meta, + skem_ciphertext, + static_bundle, + })) + } + XxStep::Send4 => { + let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + self.step = XxStep::Done; + + Ok(XxMessage::Message4(Xx4 { + meta, + skem_ciphertext, + })) + } + _ => Err(WireError::InvalidState), + } + } + + pub fn read_message( + &mut self, + crypto: &impl QlCrypto, + message: &XxMessage, + ) -> Result<(), WireError> { + match (&self.step, message) { + (XxStep::Recv1, XxMessage::Message1(message)) => { + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + self.step = XxStep::Send2; + Ok(()) + } + (XxStep::Recv2, XxMessage::Message2(message)) => { + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = crypto + .mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + + let ee = crypto.x25519_agree( + &local_ephemeral.x25519.private, + &message.ephemeral.x25519_public_key, + ); + self.symmetric.mix_key(crypto, ee.as_bytes()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + let es = crypto.x25519_agree( + &local_ephemeral.x25519.private, + &remote_bundle.x25519_public_key, + ); + self.symmetric.mix_key(crypto, es.as_bytes()); + self.remote_bundle = Some(remote_bundle); + self.step = XxStep::Send3; + Ok(()) + } + (XxStep::Recv3, XxMessage::Message3(message)) => { + let skem_ciphertext = decrypt_mlkem_ciphertext( + crypto, + &mut self.symmetric, + &message.skem_ciphertext, + )?; + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + let se = crypto.x25519_agree( + &local_ephemeral.x25519.private, + &remote_bundle.x25519_public_key, + ); + self.symmetric.mix_key(crypto, se.as_bytes()); + self.remote_bundle = Some(remote_bundle); + self.step = XxStep::Send4; + Ok(()) + } + (XxStep::Recv4, XxMessage::Message4(message)) => { + let skem_ciphertext = decrypt_mlkem_ciphertext( + crypto, + &mut self.symmetric, + &message.skem_ciphertext, + )?; + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + self.step = XxStep::Done; + Ok(()) + } + _ => Err(WireError::InvalidState), + } + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; + Ok(finalize_handshake( + crypto, + self.symmetric, + self.role, + remote_bundle, + )) + } +} diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index dda25f3b..8f09d47d 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -1,41 +1,85 @@ -use crate::{codec, record::RecordKind, ByteSlice, WireError, XID}; +use crate::{codec, QL_WIRE_VERSION, XID}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct ConnectionId(pub [u8; Self::SIZE]); + +impl ConnectionId { + pub const SIZE: usize = 16; + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct QlHeader { +pub struct HandshakeHeader { pub sender: XID, pub recipient: XID, } -impl QlHeader { - pub fn aad(&self) -> Vec { - codec::header_aad(self) - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SessionHeader { + pub connection_id: ConnectionId, } -#[derive(Debug, Clone, Copy)] -pub(crate) struct DecodedRecordHeader { - pub(crate) kind: RecordKind, - pub(crate) header: QlHeader, -} +impl HandshakeHeader { + pub const ENCODED_LEN: usize = XID::SIZE * 2; + + pub fn encode_into(&self, out: &mut Vec) { + codec::push_bytes(out, &self.sender.0); + codec::push_bytes(out, &self.recipient.0); + } -pub(crate) fn encode_record_header(out: &mut Vec, header: &QlHeader, kind: RecordKind) { - codec::push_u8(out, kind as u8); - codec::push_bytes(out, &header.sender.0); - codec::push_bytes(out, &header.recipient.0); + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let header = Self::decode_from(&mut reader)?; + reader.finish()?; + Ok(header) + } + + pub fn decode_from( + reader: &mut codec::Reader, + ) -> Result { + Ok(Self { + sender: XID(reader.take_array()?), + recipient: XID(reader.take_array()?), + }) + } } -pub(crate) fn decode_record_header( - bytes: B, -) -> Result<(DecodedRecordHeader, B), WireError> { - let mut reader = codec::Reader::new(bytes); - let kind = RecordKind::try_from(reader.take_u8()?)?; - let sender = XID(reader.take_array()?); - let recipient = XID(reader.take_array()?); - Ok(( - DecodedRecordHeader { - kind, - header: QlHeader { sender, recipient }, - }, - reader.take_rest(), - )) +impl SessionHeader { + pub const ENCODED_LEN: usize = ConnectionId::SIZE; + + pub fn encode_into(&self, out: &mut Vec) { + codec::push_bytes(out, self.connection_id.as_bytes()); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let header = Self::decode_from(&mut reader)?; + reader.finish()?; + Ok(header) + } + + pub fn decode_from( + reader: &mut codec::Reader, + ) -> Result { + Ok(Self { + connection_id: ConnectionId::from_data(reader.take_array()?), + }) + } + + pub fn aad(&self) -> Vec { + let mut aad = Vec::new(); + codec::append_field(&mut aad, b"domain", b"ql-wire:session-aad:v1"); + codec::append_field(&mut aad, b"wire-version", &[QL_WIRE_VERSION]); + codec::append_field(&mut aad, b"record-kind", b"session"); + codec::append_field(&mut aad, b"connection-id", self.connection_id.as_bytes()); + aad + } } diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 574031cd..60e55f75 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -1,28 +1,106 @@ -use crate::{MlDsaPrivateKey, MlDsaPublicKey, MlKemPrivateKey, MlKemPublicKey, XID}; +use crate::{ + codec, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, WireError, X25519KeyPair, + X25519PrivateKey, X25519PublicKey, XID, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PeerBundle { + pub version: u16, + pub capabilities: u32, + pub x25519_public_key: X25519PublicKey, + pub mlkem_public_key: MlKemPublicKey, +} + +impl PeerBundle { + pub const VERSION: u16 = 1; + pub const ENCODED_LEN: usize = core::mem::size_of::() + + core::mem::size_of::() + + X25519PublicKey::SIZE + + MlKemPublicKey::SIZE; + + pub fn encode_into(&self, out: &mut Vec) { + codec::push_u16(out, self.version); + codec::push_u32(out, self.capabilities); + codec::push_bytes(out, self.x25519_public_key.as_bytes()); + codec::push_bytes(out, self.mlkem_public_key.as_bytes()); + } + + pub fn encode(&self) -> Vec { + let mut out = Vec::with_capacity(Self::ENCODED_LEN); + self.encode_into(&mut out); + out + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let bundle = Self { + version: reader.take_u16()?, + capabilities: reader.take_u32()?, + x25519_public_key: X25519PublicKey::from_data(reader.take_array()?), + mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), + }; + reader.finish()?; + Ok(bundle) + } +} #[derive(Debug, Clone)] pub struct QlIdentity { pub xid: XID, - pub signing_private_key: MlDsaPrivateKey, - pub signing_public_key: MlDsaPublicKey, - pub encapsulation_private_key: MlKemPrivateKey, - pub encapsulation_public_key: MlKemPublicKey, + pub x25519_private_key: X25519PrivateKey, + pub x25519_public_key: X25519PublicKey, + pub mlkem_private_key: MlKemPrivateKey, + pub mlkem_public_key: MlKemPublicKey, + pub capabilities: u32, } impl QlIdentity { pub fn new( xid: XID, - signing_private_key: MlDsaPrivateKey, - signing_public_key: MlDsaPublicKey, - encapsulation_private_key: MlKemPrivateKey, - encapsulation_public_key: MlKemPublicKey, + x25519_private_key: X25519PrivateKey, + x25519_public_key: X25519PublicKey, + mlkem_private_key: MlKemPrivateKey, + mlkem_public_key: MlKemPublicKey, ) -> Self { Self { xid, - signing_private_key, - signing_public_key, - encapsulation_private_key, - encapsulation_public_key, + x25519_private_key, + x25519_public_key, + mlkem_private_key, + mlkem_public_key, + capabilities: 0, + } + } + + pub fn with_capabilities(mut self, capabilities: u32) -> Self { + self.capabilities = capabilities; + self + } + + pub fn bundle(&self) -> PeerBundle { + PeerBundle { + version: PeerBundle::VERSION, + capabilities: self.capabilities, + x25519_public_key: self.x25519_public_key, + mlkem_public_key: self.mlkem_public_key.clone(), } } } + +pub fn generate_identity(crypto: &impl QlCrypto, xid: XID) -> QlIdentity { + let X25519KeyPair { + private: x25519_private_key, + public: x25519_public_key, + } = crypto.x25519_generate_keypair(); + let MlKemKeyPair { + private: mlkem_private_key, + public: mlkem_public_key, + } = crypto.mlkem_generate_keypair(); + QlIdentity::new( + xid, + x25519_private_key, + x25519_public_key, + mlkem_private_key, + mlkem_public_key, + ) +} diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index abed8470..4a9a2c27 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -7,6 +7,7 @@ mod bytes; mod codec; mod control; +mod crypto; mod encrypted; mod encrypted_message; mod error; @@ -14,14 +15,14 @@ mod handshake; mod header; mod identity; mod nonce; -mod pair; mod pq; mod record; -mod unpair; +mod x25519; mod xid; pub use bytes::*; pub use control::*; +pub use crypto::*; pub use encrypted::*; pub use encrypted_message::*; pub use error::*; @@ -29,37 +30,13 @@ pub use handshake::*; pub use header::*; pub use identity::*; pub use nonce::*; -pub use pair::*; pub use pq::*; pub use record::*; -pub use unpair::*; +pub use x25519::*; pub use xid::*; pub const QL_WIRE_VERSION: u8 = 1; pub const ENCRYPTED_MESSAGE_AUTH_SIZE: usize = 16; -pub trait QlCrypto { - fn fill_random_bytes(&self, data: &mut [u8]); - - fn hash(&self, parts: &[&[u8]]) -> [u8; 32]; - - fn encrypt_with_aead( - &self, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - buffer: &mut [u8], - ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; - - fn decrypt_with_aead( - &self, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - buffer: &mut [u8], - auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool; -} - #[cfg(test)] mod tests; diff --git a/ql-wire/src/pair/crypto.rs b/ql-wire/src/pair/crypto.rs deleted file mode 100644 index 192ac2de..00000000 --- a/ql-wire/src/pair/crypto.rs +++ /dev/null @@ -1,131 +0,0 @@ -use super::{PairRequestBody, PairRequestRecord}; -use crate::{ - pq::ML_KEM_SUITE_TAG, ControlMeta, MlDsaPublicKey, MlKemCiphertext, MlKemPublicKey, QlCrypto, - QlHeader, QlIdentity, QlPayload, QlRecord, WireError, XID, -}; - -pub fn build_pair_request( - crypto: &impl QlCrypto, - identity: &QlIdentity, - recipient: XID, - recipient_encapsulation_key: &MlKemPublicKey, - meta: ControlMeta, -) -> QlRecord> { - let (session_key, kem_ct) = recipient_encapsulation_key.encapsulate_new_shared_secret(crypto); - let header = QlHeader { - sender: identity.xid, - recipient, - }; - let signing_pub_key = identity.signing_public_key.clone(); - let sender_encapsulation_key = identity.encapsulation_public_key.clone(); - let proof_data = hash_pairing_proof_data( - crypto, - &header, - &kem_ct, - &meta, - identity.xid, - &signing_pub_key, - &sender_encapsulation_key, - ); - let proof = identity.signing_private_key.sign(crypto, &proof_data); - let body = PairRequestBody { - meta, - xid: identity.xid, - signing_pub_key, - encapsulation_pub_key: sender_encapsulation_key, - proof, - }; - let body_bytes = body.encode(); - let aad = pairing_aad(&header, &kem_ct); - let mut nonce = [0u8; crate::Nonce::SIZE]; - crypto.fill_random_bytes(&mut nonce); - let encrypted = crate::encrypted_message::EncryptedMessage::encrypt( - crypto, - &session_key, - body_bytes, - &aad, - crate::Nonce(nonce), - ); - QlRecord { - header, - payload: QlPayload::PairRequest(PairRequestRecord { kem_ct, encrypted }), - } -} - -pub fn decrypt_pair_request>( - crypto: &impl QlCrypto, - identity: &QlIdentity, - header: &QlHeader, - request: PairRequestRecord, - now_seconds: u64, -) -> Result { - let PairRequestRecord { kem_ct, encrypted } = request; - let aad = pairing_aad(header, &kem_ct); - let session_key = identity - .encapsulation_private_key - .decapsulate_shared_secret(&kem_ct); - let mut plaintext = encrypted.decrypt_in_place(crypto, &session_key, &aad)?; - let decrypted = PairRequestBody::decode(plaintext.as_mut())?; - decrypted.meta.ensure_not_expired(now_seconds)?; - if decrypted.xid != header.sender { - return Err(WireError::InvalidPayload); - } - let proof_data = hash_pairing_proof_data( - crypto, - header, - &kem_ct, - &decrypted.meta, - decrypted.xid, - &decrypted.signing_pub_key, - &decrypted.encapsulation_pub_key, - ); - if decrypted - .signing_pub_key - .verify(&decrypted.proof, &proof_data) - { - Ok(decrypted) - } else { - Err(WireError::InvalidSignature) - } -} - -fn hash_pairing_proof_data( - crypto: &impl QlCrypto, - header: &QlHeader, - kem_ct: &MlKemCiphertext, - meta: &ControlMeta, - xid: XID, - signing_pub_key: &MlDsaPublicKey, - encapsulation_pub_key: &MlKemPublicKey, -) -> [u8; 32] { - let aad = pairing_aad(header, kem_ct); - let control_id = meta.control_id.0.to_le_bytes(); - let valid_until = meta.valid_until.to_le_bytes(); - crypto.hash(&[ - b"ql-wire:pair-proof:v1", - b"aad", - &aad, - b"control-id", - &control_id, - b"valid-until", - &valid_until, - b"xid", - &xid.0, - b"signing-pub-key", - signing_pub_key.as_bytes(), - b"encapsulation-pub-key-suite", - ML_KEM_SUITE_TAG, - b"encapsulation-pub-key", - encapsulation_pub_key.as_bytes(), - ]) -} - -fn pairing_aad(header: &QlHeader, kem_ct: &MlKemCiphertext) -> Vec { - let mut aad = Vec::new(); - crate::codec::append_field(&mut aad, b"domain", b"ql-wire:pair-aad:v1"); - crate::codec::append_field(&mut aad, b"sender", &header.sender.0); - crate::codec::append_field(&mut aad, b"recipient", &header.recipient.0); - crate::codec::append_field(&mut aad, b"kem-suite", ML_KEM_SUITE_TAG); - crate::codec::append_field(&mut aad, b"kem-ct", kem_ct.as_bytes()); - aad -} diff --git a/ql-wire/src/pair/mod.rs b/ql-wire/src/pair/mod.rs deleted file mode 100644 index b0d956d0..00000000 --- a/ql-wire/src/pair/mod.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::{ - codec, encrypted_message::EncryptedMessage, ByteSlice, ControlMeta, MlDsaPublicKey, - MlDsaSignature, MlKemCiphertext, MlKemPublicKey, WireError, XID, -}; - -mod crypto; -pub use crypto::*; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct PairRequestRecord { - pub kem_ct: MlKemCiphertext, - pub encrypted: EncryptedMessage, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct PairRequestBody { - pub meta: ControlMeta, - pub xid: XID, - pub signing_pub_key: MlDsaPublicKey, - pub encapsulation_pub_key: MlKemPublicKey, - pub proof: MlDsaSignature, -} - -impl PairRequestRecord { - pub fn parse(bytes: B) -> Result { - let mut reader = codec::Reader::new(bytes); - Ok(Self { - kem_ct: MlKemCiphertext::from_data(reader.take_array()?), - encrypted: EncryptedMessage::parse(reader.take_rest())?, - }) - } -} - -impl PairRequestRecord { - pub fn into_owned(self) -> PairRequestRecord> - where - B: ByteSlice, - { - PairRequestRecord { - kem_ct: self.kem_ct, - encrypted: self.encrypted.into_owned(), - } - } -} - -impl> PairRequestRecord { - pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, self.kem_ct.as_bytes()); - self.encrypted.encode_into(out); - } -} - -impl PairRequestBody { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN - + XID::SIZE - + MlDsaPublicKey::SIZE - + MlKemPublicKey::SIZE - + MlDsaSignature::SIZE; - - pub fn encode_into(&self, out: &mut Vec) { - self.meta.encode_into(out); - codec::push_bytes(out, &self.xid.0); - codec::push_bytes(out, self.signing_pub_key.as_bytes()); - codec::push_bytes(out, self.encapsulation_pub_key.as_bytes()); - codec::push_bytes(out, self.proof.as_bytes()); - } - - pub fn encode(&self) -> Vec { - let mut out = Vec::with_capacity(Self::ENCODED_LEN); - self.encode_into(&mut out); - out - } - - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let body = Self { - meta: ControlMeta::decode_from(&mut reader)?, - xid: XID(reader.take_array()?), - signing_pub_key: MlDsaPublicKey::from_data(reader.take_array()?), - encapsulation_pub_key: MlKemPublicKey::from_data(reader.take_array()?), - proof: MlDsaSignature::from_data(reader.take_array()?), - }; - reader.finish()?; - Ok(body) - } -} diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index c6783c1d..fa3c133f 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -1,9 +1,8 @@ -use libcrux_ml_dsa::{ml_dsa_87, KEY_GENERATION_RANDOMNESS_SIZE, SIGNING_RANDOMNESS_SIZE}; -use libcrux_ml_kem::{mlkem1024, KEY_GENERATION_SEED_SIZE, SHARED_SECRET_SIZE}; +use libcrux_ml_kem::{mlkem1024, SHARED_SECRET_SIZE}; use crate::QlCrypto; -pub(crate) const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; +pub const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct SessionKey([u8; Self::SIZE]); @@ -30,102 +29,6 @@ impl AsRef<[u8]> for SessionKey { } } -macro_rules! impl_byte_traits { - ($name:ident) => { - impl std::fmt::Debug for $name { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple(stringify!($name)) - .field(&self.as_bytes()) - .finish() - } - } - - impl PartialEq for $name { - fn eq(&self, other: &Self) -> bool { - self.as_bytes() == other.as_bytes() - } - } - - impl Eq for $name {} - - impl std::hash::Hash for $name { - fn hash(&self, state: &mut H) { - self.as_bytes().hash(state); - } - } - }; -} - -#[derive(Clone)] -pub struct MlDsaPrivateKey(Box); - -impl_byte_traits!(MlDsaPrivateKey); - -impl MlDsaPrivateKey { - pub const SIZE: usize = ml_dsa_87::MLDSA87SigningKey::len(); - - pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(ml_dsa_87::MLDSA87SigningKey::new(data))) - } - - pub fn as_bytes(&self) -> &[u8; Self::SIZE] { - self.0.as_ref().as_ref() - } - - pub fn sign(&self, crypto: &impl QlCrypto, message: &[u8]) -> MlDsaSignature { - let mut randomness = [0u8; SIGNING_RANDOMNESS_SIZE]; - crypto.fill_random_bytes(&mut randomness); - // Safe: we always sign with the empty context, so the only remaining - // error is libcrux's negligible-probability rejection-sampling failure. - let signature = ml_dsa_87::sign(self.0.as_ref(), message, b"", randomness) - .expect("ML-DSA signing should not fail"); - MlDsaSignature(Box::new(signature)) - } -} - -#[derive(Clone)] -pub struct MlDsaPublicKey(Box); - -impl_byte_traits!(MlDsaPublicKey); - -impl MlDsaPublicKey { - pub const SIZE: usize = ml_dsa_87::MLDSA87VerificationKey::len(); - - pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(ml_dsa_87::MLDSA87VerificationKey::new(data))) - } - - pub fn as_bytes(&self) -> &[u8; Self::SIZE] { - self.0.as_ref().as_ref() - } - - pub fn verify(&self, signature: &MlDsaSignature, message: &[u8]) -> bool { - ml_dsa_87::verify(self.0.as_ref(), message, b"", signature.0.as_ref()).is_ok() - } - - pub fn verify_bytes(&self, signature: &[u8; MlDsaSignature::SIZE], message: &[u8]) -> bool { - let signature = ml_dsa_87::MLDSA87Signature::new(*signature); - ml_dsa_87::verify(self.0.as_ref(), message, b"", &signature).is_ok() - } -} - -#[derive(Clone)] -pub struct MlDsaSignature(Box); - -impl_byte_traits!(MlDsaSignature); - -impl MlDsaSignature { - pub const SIZE: usize = ml_dsa_87::MLDSA87Signature::len(); - - pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(ml_dsa_87::MLDSA87Signature::new(data))) - } - - pub fn as_bytes(&self) -> &[u8; Self::SIZE] { - ml_dsa_87::MLDSA87Signature::as_ref(self.0.as_ref()) - } -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MlKemPublicKey(Box<[u8; MlKemPublicKey::SIZE]>); @@ -139,20 +42,6 @@ impl MlKemPublicKey { pub fn as_bytes(&self) -> &[u8; Self::SIZE] { self.0.as_ref() } - - pub fn encapsulate_new_shared_secret( - &self, - crypto: &impl QlCrypto, - ) -> (SessionKey, MlKemCiphertext) { - let mut randomness = [0u8; SHARED_SECRET_SIZE]; - crypto.fill_random_bytes(&mut randomness); - let public_key = mlkem1024::MlKem1024PublicKey::from(self.as_bytes()); - let (ciphertext, shared_secret) = mlkem1024::encapsulate(&public_key, randomness); - ( - SessionKey::from_data(shared_secret), - MlKemCiphertext::from_data(*ciphertext.as_slice()), - ) - } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -168,20 +57,6 @@ impl MlKemPrivateKey { pub fn as_bytes(&self) -> &[u8; Self::SIZE] { self.0.as_ref() } - - pub fn decapsulate_shared_secret(&self, ciphertext: &MlKemCiphertext) -> SessionKey { - self.decapsulate_shared_secret_bytes(ciphertext.as_bytes()) - } - - pub fn decapsulate_shared_secret_bytes( - &self, - ciphertext: &[u8; MlKemCiphertext::SIZE], - ) -> SessionKey { - let private_key = mlkem1024::MlKem1024PrivateKey::from(self.as_bytes()); - let ciphertext = mlkem1024::MlKem1024Ciphertext::from(ciphertext); - let shared_secret = mlkem1024::decapsulate(&private_key, &ciphertext); - SessionKey::from_data(shared_secret) - } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -199,23 +74,12 @@ impl MlKemCiphertext { } } -pub fn generate_ml_dsa_keypair(crypto: &impl QlCrypto) -> (MlDsaPrivateKey, MlDsaPublicKey) { - let mut randomness = [0u8; KEY_GENERATION_RANDOMNESS_SIZE]; - crypto.fill_random_bytes(&mut randomness); - let key_pair = ml_dsa_87::generate_key_pair(randomness); - ( - MlDsaPrivateKey(Box::new(key_pair.signing_key)), - MlDsaPublicKey(Box::new(key_pair.verification_key)), - ) +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemKeyPair { + pub private: MlKemPrivateKey, + pub public: MlKemPublicKey, } -pub fn generate_ml_kem_keypair(crypto: &impl QlCrypto) -> (MlKemPrivateKey, MlKemPublicKey) { - let mut randomness = [0u8; KEY_GENERATION_SEED_SIZE]; - crypto.fill_random_bytes(&mut randomness); - let key_pair = mlkem1024::generate_key_pair(randomness); - let (private_key, public_key) = key_pair.into_parts(); - ( - MlKemPrivateKey::from_data(*private_key.as_slice()), - MlKemPublicKey::from_data(*public_key.as_slice()), - ) +pub fn generate_ml_kem_keypair(crypto: &impl QlCrypto) -> MlKemKeyPair { + crypto.mlkem_generate_keypair() } diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index bd61c3fc..047ffda2 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -1,145 +1,233 @@ use crate::{ codec, encrypted_message::EncryptedMessage, - handshake::{self, Confirm, Hello, HelloReply, Ready}, - header::{decode_record_header, encode_record_header, QlHeader}, - pair::PairRequestRecord, - unpair::Unpair, - ByteSlice, WireError, QL_WIRE_VERSION, + handshake::{Kk1, Kk2, Xx1, Xx2, Xx3, Xx4}, + ByteSlice, HandshakeHeader, SessionHeader, WireError, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct QlRecord { - pub header: QlHeader, - pub payload: QlPayload, +pub struct QlHandshakeRecord { + pub header: HandshakeHeader, + pub payload: HandshakePayload, } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlPayload { - PairRequest(PairRequestRecord), - Unpair(Unpair), - Hello(Hello), - HelloReply(HelloReply), - Confirm(Confirm), - Ready(Ready), - Session(EncryptedMessage), +pub struct QlSessionRecord { + pub header: SessionHeader, + pub payload: EncryptedMessage, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlRecord { + Handshake(QlHandshakeRecord), + Session(QlSessionRecord), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HandshakePayload { + Xx1(Xx1), + Xx2(Xx2), + Xx3(Xx3), + Xx4(Xx4), + Kk1(Kk1), + Kk2(Kk2), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] -pub(crate) enum RecordKind { - PairRequest = 1, - Hello = 2, - HelloReply = 3, - Confirm = 4, - Ready = 5, - Session = 6, - Unpair = 7, +pub enum RecordType { + Handshake = 1, + Session = 2, } -impl TryFrom for RecordKind { +impl TryFrom for RecordType { type Error = WireError; fn try_from(value: u8) -> Result { match value { - 1 => Ok(Self::PairRequest), - 2 => Ok(Self::Hello), - 3 => Ok(Self::HelloReply), - 4 => Ok(Self::Confirm), - 5 => Ok(Self::Ready), - 6 => Ok(Self::Session), - 7 => Ok(Self::Unpair), + 1 => Ok(Self::Handshake), + 2 => Ok(Self::Session), _ => Err(WireError::InvalidPayload), } } } -impl RecordKind { - fn for_payload(payload: &QlPayload) -> Self { - match payload { - QlPayload::PairRequest(_) => Self::PairRequest, - QlPayload::Unpair(_) => Self::Unpair, - QlPayload::Hello(_) => Self::Hello, - QlPayload::HelloReply(_) => Self::HelloReply, - QlPayload::Confirm(_) => Self::Confirm, - QlPayload::Ready(_) => Self::Ready, - QlPayload::Session(_) => Self::Session, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum HandshakeKind { + Xx1 = 1, + Xx2 = 2, + Xx3 = 3, + Xx4 = 4, + Kk1 = 5, + Kk2 = 6, +} + +impl TryFrom for HandshakeKind { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Xx1), + 2 => Ok(Self::Xx2), + 3 => Ok(Self::Xx3), + 4 => Ok(Self::Xx4), + 5 => Ok(Self::Kk1), + 6 => Ok(Self::Kk2), + _ => Err(WireError::InvalidPayload), } } } -impl> QlRecord { +impl HandshakePayload { + pub fn kind(&self) -> HandshakeKind { + match self { + Self::Xx1(_) => HandshakeKind::Xx1, + Self::Xx2(_) => HandshakeKind::Xx2, + Self::Xx3(_) => HandshakeKind::Xx3, + Self::Xx4(_) => HandshakeKind::Xx4, + Self::Kk1(_) => HandshakeKind::Kk1, + Self::Kk2(_) => HandshakeKind::Kk2, + } + } + + fn encode_into(&self, out: &mut Vec) { + match self { + Self::Xx1(message) => message.encode_into(out), + Self::Xx2(message) => message.encode_into(out), + Self::Xx3(message) => message.encode_into(out), + Self::Xx4(message) => message.encode_into(out), + Self::Kk1(message) => message.encode_into(out), + Self::Kk2(message) => message.encode_into(out), + } + } + + fn decode(kind: HandshakeKind, bytes: &[u8]) -> Result { + match kind { + HandshakeKind::Xx1 => Ok(Self::Xx1(Xx1::decode(bytes)?)), + HandshakeKind::Xx2 => Ok(Self::Xx2(Xx2::decode(bytes)?)), + HandshakeKind::Xx3 => Ok(Self::Xx3(Xx3::decode(bytes)?)), + HandshakeKind::Xx4 => Ok(Self::Xx4(Xx4::decode(bytes)?)), + HandshakeKind::Kk1 => Ok(Self::Kk1(Kk1::decode(bytes)?)), + HandshakeKind::Kk2 => Ok(Self::Kk2(Kk2::decode(bytes)?)), + } + } +} + +impl QlHandshakeRecord { pub fn encode(&self) -> Vec { let mut out = Vec::new(); codec::push_u8(&mut out, QL_WIRE_VERSION); - encode_record_header( - &mut out, - &self.header, - RecordKind::for_payload(&self.payload), - ); - match &self.payload { - QlPayload::PairRequest(request) => request.encode_into(&mut out), - QlPayload::Unpair(unpair) => unpair.encode_into(&mut out), - QlPayload::Hello(hello) => hello.encode_into(&mut out), - QlPayload::HelloReply(reply) => reply.encode_into(&mut out), - QlPayload::Confirm(confirm) => confirm.encode_into(&mut out), - QlPayload::Ready(ready) => ready.encode_into(&mut out), - QlPayload::Session(encrypted) => encrypted.encode_into(&mut out), + codec::push_u8(&mut out, RecordType::Handshake as u8); + self.header.encode_into(&mut out); + codec::push_u8(&mut out, self.payload.kind() as u8); + self.payload.encode_into(&mut out); + out + } + + pub fn decode(bytes: &[u8]) -> Result { + Ok(Self::parse(bytes)?) + } + + pub fn parse(bytes: B) -> Result { + let mut reader = codec::Reader::new(bytes); + if reader.take_u8()? != QL_WIRE_VERSION { + return Err(WireError::InvalidPayload); + } + if RecordType::try_from(reader.take_u8()?)? != RecordType::Handshake { + return Err(WireError::InvalidPayload); } + parse_handshake_record(reader.take_rest()) + } +} + +impl> QlSessionRecord { + pub fn encode(&self) -> Vec { + let mut out = Vec::new(); + codec::push_u8(&mut out, QL_WIRE_VERSION); + codec::push_u8(&mut out, RecordType::Session as u8); + self.header.encode_into(&mut out); + self.payload.encode_into(&mut out); out } } -impl QlRecord> { +impl QlSessionRecord> { pub fn decode(bytes: &[u8]) -> Result { - Ok(QlRecord::parse(bytes)?.into_owned()) + Ok(QlSessionRecord::parse(bytes)?.into_owned()) } } -impl QlRecord { +impl QlSessionRecord { pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); if reader.take_u8()? != QL_WIRE_VERSION { return Err(WireError::InvalidPayload); } - let (header, payload_bytes) = decode_record_header(reader.take_rest())?; - let payload = parse_payload(header.kind, payload_bytes)?; - Ok(Self { - header: header.header, - payload, - }) + if RecordType::try_from(reader.take_u8()?)? != RecordType::Session { + return Err(WireError::InvalidPayload); + } + parse_session_record(reader.take_rest()) } - pub fn into_owned(self) -> QlRecord> { - QlRecord { + pub fn into_owned(self) -> QlSessionRecord> { + QlSessionRecord { header: self.header, payload: self.payload.into_owned(), } } } -impl QlPayload { - pub fn into_owned(self) -> QlPayload> { +impl> QlRecord { + pub fn encode(&self) -> Vec { match self { - Self::PairRequest(request) => QlPayload::PairRequest(request.into_owned()), - Self::Unpair(unpair) => QlPayload::Unpair(unpair), - Self::Hello(hello) => QlPayload::Hello(hello), - Self::HelloReply(reply) => QlPayload::HelloReply(reply), - Self::Confirm(confirm) => QlPayload::Confirm(confirm), - Self::Ready(ready) => QlPayload::Ready(ready.into_owned()), - Self::Session(encrypted) => QlPayload::Session(encrypted.into_owned()), + Self::Handshake(record) => record.encode(), + Self::Session(record) => record.encode(), } } } -fn parse_payload(kind: RecordKind, payload: B) -> Result, WireError> { - match kind { - RecordKind::PairRequest => Ok(QlPayload::PairRequest(PairRequestRecord::parse(payload)?)), - RecordKind::Unpair => Ok(QlPayload::Unpair(Unpair::decode(&payload[..])?)), - RecordKind::Hello => Ok(QlPayload::Hello(handshake::Hello::decode(&payload[..])?)), - RecordKind::HelloReply => Ok(QlPayload::HelloReply(HelloReply::decode(&payload[..])?)), - RecordKind::Confirm => Ok(QlPayload::Confirm(Confirm::decode(&payload[..])?)), - RecordKind::Ready => Ok(QlPayload::Ready(Ready::parse(payload)?)), - RecordKind::Session => Ok(QlPayload::Session(EncryptedMessage::parse(payload)?)), +impl QlRecord> { + pub fn decode(bytes: &[u8]) -> Result { + Ok(QlRecord::parse(bytes)?.into_owned()) + } +} + +impl QlRecord { + pub fn parse(bytes: B) -> Result { + let mut reader = codec::Reader::new(bytes); + if reader.take_u8()? != QL_WIRE_VERSION { + return Err(WireError::InvalidPayload); + } + + let record_type = RecordType::try_from(reader.take_u8()?)?; + let remaining = reader.take_rest(); + match record_type { + RecordType::Handshake => Ok(Self::Handshake(parse_handshake_record(remaining)?)), + RecordType::Session => Ok(Self::Session(parse_session_record(remaining)?)), + } } + + pub fn into_owned(self) -> QlRecord> { + match self { + Self::Handshake(record) => QlRecord::Handshake(record), + Self::Session(record) => QlRecord::Session(record.into_owned()), + } + } +} + +fn parse_handshake_record(bytes: B) -> Result { + let mut reader = codec::Reader::new(bytes); + let header = HandshakeHeader::decode_from(&mut reader)?; + let kind = HandshakeKind::try_from(reader.take_u8()?)?; + let payload = reader.take_rest(); + let payload = HandshakePayload::decode(kind, &payload[..])?; + Ok(QlHandshakeRecord { header, payload }) +} + +fn parse_session_record(bytes: B) -> Result, WireError> { + let mut reader = codec::Reader::new(bytes); + let header = SessionHeader::decode_from(&mut reader)?; + let payload = EncryptedMessage::parse(reader.take_rest())?; + Ok(QlSessionRecord { header, payload }) } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index ecbac195..056ff205 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -1,35 +1,41 @@ -use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::atomic::{AtomicU64, Ordering}; use libcrux_aesgcm::AesGcm256Key; use sha2::{Digest, Sha256}; use super::*; -struct TestCrypto(AtomicU8); +struct TestCrypto { + counter: AtomicU64, +} impl TestCrypto { - fn new(seed: u8) -> Self { - Self(AtomicU8::new(seed)) + fn new(seed: u64) -> Self { + Self { + counter: AtomicU64::new(seed), + } + } + + fn next_block(&self) -> [u8; 32] { + let value = self.counter.fetch_add(1, Ordering::Relaxed).to_le_bytes(); + sha256_parts(&[b"ql-wire:test-rng:v1", &value]) } } -impl QlCrypto for TestCrypto { - fn fill_random_bytes(&self, data: &mut [u8]) { - let seed = self.0.fetch_add(1, Ordering::Relaxed); - for (index, byte) in data.iter_mut().enumerate() { - *byte = seed.wrapping_add(index as u8); - } +impl QlRandom for TestCrypto { + fn fill_random_bytes(&self, out: &mut [u8]) { + fill_expanded(self, &[b"ql-wire:test-fill:v1"], out); } +} - fn hash(&self, parts: &[&[u8]]) -> [u8; 32] { - let mut hasher = Sha256::new(); - for part in parts { - hasher.update(part); - } - hasher.finalize().into() +impl QlHash for TestCrypto { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { + sha256_parts(parts) } +} - fn encrypt_with_aead( +impl QlAead for TestCrypto { + fn aes256_gcm_encrypt( &self, key: &SessionKey, nonce: &Nonce, @@ -50,7 +56,7 @@ impl QlCrypto for TestCrypto { auth } - fn decrypt_with_aead( + fn aes256_gcm_decrypt( &self, key: &SessionKey, nonce: &Nonce, @@ -65,617 +71,322 @@ impl QlCrypto for TestCrypto { } } -#[test] -fn encrypted_session_record_round_trip_and_decrypt() { - let crypto = TestCrypto::new(1); - let header = QlHeader { - sender: XID([1; XID::SIZE]), - recipient: XID([2; XID::SIZE]), - }; - let body = SessionRecord { - seq: RecordSeq(11), - frames: vec![ - SessionFrame::Ping, - SessionFrame::Ack(RecordAck { - ranges: vec![ - RecordAckRange { start: 12, end: 14 }, - RecordAckRange { start: 20, end: 24 }, - ], - }), - SessionFrame::StreamWindow(StreamWindow { - stream_id: StreamId(9), - maximum_offset: 65_536, - }), - SessionFrame::StreamData(StreamData { - stream_id: StreamId(9), - offset: 1024, - bytes: b"hello".to_vec(), - fin: true, - }), - SessionFrame::StreamClose(StreamClose { - stream_id: StreamId(9), - target: CloseTarget::Both, - code: CloseCode::PROTOCOL, - }), - SessionFrame::Close(SessionCloseBody { - code: CloseCode::TIMEOUT, - }), - ], - }; - let session_key = SessionKey::from_data([7; SessionKey::SIZE]); - let record = encrypted::encrypt_record( - &crypto, - header, - &session_key, - &body, - Nonce([8; Nonce::SIZE]), - ); +impl QlDh for TestCrypto { + fn x25519_generate_keypair(&self) -> X25519KeyPair { + let private = self.next_block(); + X25519KeyPair { + private: X25519PrivateKey::from_data(private), + public: X25519PublicKey::from_data(private), + } + } - let bytes = record.encode(); - let decoded = QlRecord::decode(&bytes).unwrap(); - assert_eq!(decoded.header, header); - assert!(matches!(decoded.payload, QlPayload::Session(_))); + fn x25519_agree( + &self, + private_key: &X25519PrivateKey, + public_key: &X25519PublicKey, + ) -> SessionKey { + let left = *private_key.as_bytes(); + let right = *public_key.as_bytes(); + let (first, second) = if left <= right { + (left, right) + } else { + (right, left) + }; + SessionKey::from_data(self.sha256(&[b"ql-wire:test-x25519:v1", &first, &second])) + } +} - let parsed = QlRecord::parse(bytes.as_slice()).unwrap(); - assert_eq!(parsed.into_owned(), record); +impl QlKem for TestCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + let seed = self.next_block(); + let key_id = self.sha256(&[b"ql-wire:test-mlkem:key-id:v1", &seed]); - let mut bytes = bytes; - let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); - let QlPayload::Session(encrypted) = payload else { - panic!("expected session payload"); - }; - let decrypted = encrypted::decrypt_record(&crypto, &header, encrypted, &session_key).unwrap(); - assert_eq!(SessionRecord::decode(decrypted).unwrap(), body); -} + let mut public = [0u8; MlKemPublicKey::SIZE]; + fill_expanded(self, &[b"ql-wire:test-mlkem:public:v1", &seed], &mut public); + public[..key_id.len()].copy_from_slice(&key_id); -#[test] -fn decrypted_session_record_iterates_zero_copy_frames() { - let crypto = TestCrypto::new(2); - let header = QlHeader { - sender: XID([9; XID::SIZE]), - recipient: XID([10; XID::SIZE]), - }; - let body = SessionRecord { - seq: RecordSeq(7), - frames: vec![ - SessionFrame::StreamData(StreamData { - stream_id: StreamId(1), - offset: 5, - fin: false, - bytes: b"abc".to_vec(), - }), - SessionFrame::Ack(RecordAck { - ranges: vec![RecordAckRange { start: 3, end: 8 }], - }), - SessionFrame::StreamClose(StreamClose { - stream_id: StreamId(1), - target: CloseTarget::Response, - code: CloseCode::CANCELLED, - }), - ], - }; - let session_key = SessionKey::from_data([3; SessionKey::SIZE]); - let record = encrypted::encrypt_record( - &crypto, - header, - &session_key, - &body, - Nonce([4; Nonce::SIZE]), - ); + let mut private = [0u8; MlKemPrivateKey::SIZE]; + fill_expanded( + self, + &[b"ql-wire:test-mlkem:private:v1", &seed], + &mut private, + ); + private[..key_id.len()].copy_from_slice(&key_id); - let mut bytes = record.encode(); - let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); - let QlPayload::Session(encrypted) = payload else { - panic!("expected session payload"); - }; - let decrypted = encrypted::decrypt_record(&crypto, &header, encrypted, &session_key).unwrap(); - let (seq, mut frames) = SessionRecord::parse(decrypted).unwrap(); - assert_eq!(seq, RecordSeq(7)); - match frames.next().unwrap().unwrap() { - SessionFrame::StreamData(frame) => { - assert_eq!(frame.stream_id, StreamId(1)); - assert_eq!(frame.offset, 5); - assert!(!frame.fin); - assert_eq!(frame.bytes, b"abc"); + MlKemKeyPair { + private: MlKemPrivateKey::from_data(private), + public: MlKemPublicKey::from_data(public), } - other => panic!("expected stream data, got {}", frame_name(&other)), } - match frames.next().unwrap().unwrap() { - SessionFrame::Ack(frame) => { - assert_eq!(frame.ranges, vec![RecordAckRange { start: 3, end: 8 }]); - } - other => panic!("expected ack, got {}", frame_name(&other)), + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + let mut encaps_seed = [0u8; 32]; + self.fill_random_bytes(&mut encaps_seed); + let key_id = &public_key.as_bytes()[..32]; + + let mut ciphertext = [0u8; MlKemCiphertext::SIZE]; + fill_expanded( + self, + &[b"ql-wire:test-mlkem:ciphertext:v1", &encaps_seed], + &mut ciphertext, + ); + ciphertext[..encaps_seed.len()].copy_from_slice(&encaps_seed); + + let shared = self.sha256(&[b"ql-wire:test-mlkem:shared:v1", key_id, &encaps_seed]); + ( + MlKemCiphertext::from_data(ciphertext), + SessionKey::from_data(shared), + ) } - match frames.next().unwrap().unwrap() { - SessionFrame::StreamClose(frame) => { - assert_eq!(frame.stream_id, StreamId(1)); - assert_eq!(frame.target, CloseTarget::Response); - assert_eq!(frame.code, CloseCode::CANCELLED); - } - other => panic!("expected stream close, got {}", frame_name(&other)), + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + ciphertext: &MlKemCiphertext, + ) -> SessionKey { + let key_id = &private_key.as_bytes()[..32]; + let encaps_seed = &ciphertext.as_bytes()[..32]; + SessionKey::from_data(self.sha256(&[b"ql-wire:test-mlkem:shared:v1", key_id, encaps_seed])) } - assert!(frames.next().is_none()); } -#[test] -fn pair_request_round_trip_and_decrypt() { - let crypto = TestCrypto::new(9); - let sender_signing = generate_ml_dsa_keypair(&crypto); - let sender_kem = generate_ml_kem_keypair(&crypto); - let recipient_signing = generate_ml_dsa_keypair(&crypto); - let recipient_kem = generate_ml_kem_keypair(&crypto); - - let sender = QlIdentity::new( - XID([3; XID::SIZE]), - sender_signing.0, - sender_signing.1, - sender_kem.0, - sender_kem.1, - ); - let recipient = QlIdentity::new( - XID([4; XID::SIZE]), - recipient_signing.0, - recipient_signing.1, - recipient_kem.0, - recipient_kem.1, - ); - let meta = ControlMeta { - control_id: ControlId(55), - valid_until: 999, - }; - let record = pair::build_pair_request( - &crypto, - &sender, - recipient.xid, - &recipient.encapsulation_public_key, - meta, - ); +fn sha256_parts(parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() +} - let mut bytes = record.encode(); - let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); - let QlPayload::PairRequest(request) = payload else { - panic!("expected pair request"); - }; - let body = pair::decrypt_pair_request(&crypto, &recipient, &header, request, 100).unwrap(); - assert_eq!(body.meta, meta); - assert_eq!(body.xid, sender.xid); - assert_eq!(body.signing_pub_key, sender.signing_public_key); - assert_eq!(body.encapsulation_pub_key, sender.encapsulation_public_key); +fn fill_expanded(crypto: &TestCrypto, parts: &[&[u8]], out: &mut [u8]) { + let mut written = 0usize; + let mut counter = 0u64; + while written < out.len() { + let random = crypto.next_block(); + let counter_bytes = counter.to_le_bytes(); + let mut inputs = Vec::with_capacity(parts.len() + 3); + inputs.push(b"ql-wire:test-expand:v1".as_slice()); + inputs.push(&random); + inputs.push(&counter_bytes); + inputs.extend_from_slice(parts); + let block = sha256_parts(&inputs); + let take = (out.len() - written).min(block.len()); + out[written..written + take].copy_from_slice(&block[..take]); + written += take; + counter = counter.wrapping_add(1); + } } -#[test] -fn ready_round_trip_and_decrypt() { - let crypto = TestCrypto::new(30); - let header = QlHeader { - sender: XID([5; XID::SIZE]), - recipient: XID([6; XID::SIZE]), - }; - let session_key = SessionKey::from_data([11; SessionKey::SIZE]); - let meta = ControlMeta { - control_id: ControlId(77), - valid_until: 500, - }; - let ready = handshake::build_ready( - &crypto, - header, - &session_key, - meta, - Nonce([12; Nonce::SIZE]), - ); - let record: QlRecord> = QlRecord { - header, - payload: QlPayload::Ready(ready), - }; +fn xid(byte: u8) -> XID { + XID([byte; XID::SIZE]) +} - let mut bytes = record.encode(); - let parsed = QlRecord::decode(&bytes).unwrap(); - assert_eq!(parsed, record); +fn control_meta(id: u32) -> ControlMeta { + ControlMeta { + control_id: ControlId(id), + valid_until: 10_000 + u64::from(id), + } +} - let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); - let QlPayload::Ready(ready) = payload else { - panic!("expected ready payload"); - }; - let body = handshake::decrypt_ready(&crypto, &header, ready, &session_key, 100).unwrap(); - assert_eq!(body.meta, meta); +fn make_identity(crypto: &impl QlCrypto, byte: u8) -> QlIdentity { + generate_identity(crypto, xid(byte)) } #[test] -fn unpair_round_trip_and_verify() { - let crypto = TestCrypto::new(40); - let (sender_signing_private, sender_signing_public) = generate_ml_dsa_keypair(&crypto); - let sender_kem = generate_ml_kem_keypair(&crypto); - let identity = QlIdentity::new( - XID([7; XID::SIZE]), - sender_signing_private, - sender_signing_public.clone(), - sender_kem.0, - sender_kem.1, - ); - let recipient = XID([8; XID::SIZE]); - let meta = ControlMeta { - control_id: ControlId(88), - valid_until: 600, - }; - let record = unpair::build_unpair(&crypto, &identity, recipient, meta); +fn peer_bundle_round_trip() { + let crypto = TestCrypto::new(1); + let identity = make_identity(&crypto, 7).with_capabilities(0x55aa_33cc); + let bundle = identity.bundle(); - let mut bytes = record.encode(); - let parsed = QlRecord::decode(&bytes).unwrap(); - assert_eq!(parsed, record); + let encoded = bundle.encode(); + let decoded = PeerBundle::decode(&encoded).unwrap(); - let QlRecord { header, payload } = QlRecord::parse(&mut bytes[..]).unwrap(); - let QlPayload::Unpair(unpair) = payload else { - panic!("expected unpair payload"); - }; - unpair::verify_unpair(&crypto, &header, &sender_signing_public, &unpair, 100).unwrap(); + assert_eq!(decoded, bundle); } #[test] -fn session_record_rejects_malformed_frames() { - let invalid_cases = [ - { - let mut bytes = Vec::new(); - bytes.extend_from_slice(&1u32.to_le_bytes()); - bytes - }, - { - let mut bytes = 1u64.to_le_bytes().to_vec(); - bytes.push(0xff); - bytes - }, - { - let mut bytes = 1u64.to_le_bytes().to_vec(); - bytes.push(SessionFrameKind::StreamData as u8); - bytes.push(1); - bytes - }, - { - let mut bytes = 1u64.to_le_bytes().to_vec(); - bytes.push(SessionFrameKind::StreamData as u8); - bytes.extend_from_slice(&13u16.to_le_bytes()); - bytes.extend_from_slice(&1u32.to_le_bytes()); - bytes.extend_from_slice(&4u64.to_le_bytes()); - bytes.push(0); - bytes.extend_from_slice(b"abc"); - bytes +fn handshake_record_round_trip_uses_handshake_header() { + let message = Xx1 { + meta: control_meta(1), + ephemeral: HybridEphemeralPublic { + x25519_public_key: X25519PublicKey::from_data([3; X25519PublicKey::SIZE]), + mlkem_public_key: MlKemPublicKey::from_data([9; MlKemPublicKey::SIZE]), }, - { - let mut bytes = 1u64.to_le_bytes().to_vec(); - bytes.push(SessionFrameKind::Ack as u8); - bytes.extend_from_slice(&0u16.to_le_bytes()); - bytes - }, - { - let mut bytes = 1u64.to_le_bytes().to_vec(); - bytes.push(SessionFrameKind::Ack as u8); - bytes.extend_from_slice(&8u16.to_le_bytes()); - bytes.extend_from_slice(&5u64.to_le_bytes()); - bytes - }, - { - let mut bytes = 1u64.to_le_bytes().to_vec(); - bytes.push(SessionFrameKind::Ack as u8); - bytes.extend_from_slice(&32u16.to_le_bytes()); - bytes.extend_from_slice(&6u64.to_le_bytes()); - bytes.extend_from_slice(&8u64.to_le_bytes()); - bytes.extend_from_slice(&7u64.to_le_bytes()); - bytes.extend_from_slice(&9u64.to_le_bytes()); - bytes - }, - { - let mut bytes = 1u64.to_le_bytes().to_vec(); - bytes.push(SessionFrameKind::StreamClose as u8); - bytes.extend_from_slice(&9u16.to_le_bytes()); - bytes.extend_from_slice(&1u32.to_le_bytes()); - bytes.push(CloseTarget::Both as u8); - bytes.extend_from_slice(&CloseCode::PROTOCOL.0.to_le_bytes()); - bytes.extend_from_slice(b"abc"); - bytes - }, - { - let mut bytes = 1u64.to_le_bytes().to_vec(); - bytes.push(SessionFrameKind::Close as u8); - bytes.push(0); - bytes + }; + let record = QlHandshakeRecord { + header: HandshakeHeader { + sender: xid(1), + recipient: xid(2), }, - ]; - - for bytes in invalid_cases { - assert_eq!( - SessionRecord::decode(bytes.as_slice()), - Err(WireError::InvalidPayload) - ); - } -} - -#[test] -fn session_record_supports_empty_fin_stream_data_and_empty_ping() { - let record = SessionRecord { - seq: RecordSeq(99), - frames: vec![ - SessionFrame::Ping, - SessionFrame::StreamData(StreamData { - stream_id: StreamId(42), - offset: 999, - fin: true, - bytes: Vec::new(), - }), - ], + payload: HandshakePayload::Xx1(message), }; let encoded = record.encode(); - assert_eq!(&encoded[..8], &99u64.to_le_bytes()); - assert_eq!(encoded[8], SessionFrameKind::Ping as u8); + let decoded = QlHandshakeRecord::decode(&encoded).unwrap(); - let decoded = SessionRecord::decode(&encoded).unwrap(); assert_eq!(decoded, record); + + let decoded = QlRecord::decode(&encoded).unwrap(); + assert_eq!(decoded, QlRecord::Handshake(record)); } #[test] -fn session_record_builder_writes_frames_without_temp_record_allocation() { - let mut builder = SessionRecordBuilder::new(RecordSeq(55), 12); - let stream = StreamData { - stream_id: StreamId(3), - offset: 7, - fin: true, - bytes: b"hello", - }; - assert!(builder.push_stream_data(&stream)); - assert_eq!(builder.remaining_capacity(), 0); - assert!(!builder.push_ping()); +fn xx_handshake_round_trip_derives_matching_transport() { + let crypto = TestCrypto::new(10); + let initiator = make_identity(&crypto, 1); + let responder = make_identity(&crypto, 2); - let close = SessionCloseBody { - code: CloseCode::PROTOCOL, - }; - assert!(!builder.push_close(&close)); + let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator.clone()); + let mut responder_state = XxHandshake::new_responder(&crypto, responder.clone()); - let encoded = builder.into_plaintext(); - let decoded = SessionRecord::decode(&encoded).unwrap(); - assert_eq!( - decoded, - SessionRecord { - seq: RecordSeq(55), - frames: vec![SessionFrame::StreamData(StreamData { - stream_id: StreamId(3), - offset: 7, - fin: true, - bytes: b"hello".to_vec(), - })], - } - ); -} + let m1 = initiator_state + .write_message(&crypto, control_meta(1)) + .unwrap(); + responder_state.read_message(&crypto, &m1).unwrap(); -#[test] -fn session_record_builder_encodes_borrowed_vec_deque_stream_data() { - use std::collections::VecDeque; + let m2 = responder_state + .write_message(&crypto, control_meta(2)) + .unwrap(); + initiator_state.read_message(&crypto, &m2).unwrap(); - let mut payload = VecDeque::with_capacity(8); - payload.extend(b"abcd".iter().copied()); - payload.drain(..2); - payload.extend(b"efgh".iter().copied()); + let m3 = initiator_state + .write_message(&crypto, control_meta(3)) + .unwrap(); + responder_state.read_message(&crypto, &m3).unwrap(); - let mut builder = SessionRecordBuilder::new( - RecordSeq(56), - 1 + std::mem::size_of::() + StreamData::<&VecDeque>::MIN_WIRE_SIZE + payload.len(), - ); - let stream = StreamData { - stream_id: StreamId(4), - offset: 9, - fin: false, - bytes: &payload, - }; - assert!(builder.push_stream_data(&stream)); + let m4 = responder_state + .write_message(&crypto, control_meta(4)) + .unwrap(); + initiator_state.read_message(&crypto, &m4).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); - let encoded = builder.into_plaintext(); - let decoded = SessionRecord::decode(&encoded).unwrap(); assert_eq!( - decoded, - SessionRecord { - seq: RecordSeq(56), - frames: vec![SessionFrame::StreamData(StreamData { - stream_id: StreamId(4), - offset: 9, - fin: false, - bytes: b"cdefgh".to_vec(), - })], - } + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); } #[test] -fn protocol_record_size_breakdown() { - fn meta(id: u32) -> ControlMeta { - ControlMeta { - control_id: ControlId(id), - valid_until: 1_000, - } - } - - fn header() -> QlHeader { - QlHeader { - sender: XID([1; XID::SIZE]), - recipient: XID([2; XID::SIZE]), - } - } - - fn encrypted(tag: u8, ciphertext_len: usize) -> EncryptedMessage> { - EncryptedMessage { - nonce: Nonce([tag; Nonce::SIZE]), - auth: [tag; ENCRYPTED_MESSAGE_AUTH_SIZE], - ciphertext: vec![tag; ciphertext_len], - } - } +fn kk_handshake_round_trip_derives_matching_transport() { + let crypto = TestCrypto::new(20); + let initiator = make_identity(&crypto, 3); + let responder = make_identity(&crypto, 4); + + let mut initiator_state = + KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); + let mut responder_state = + KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); + + let m1 = initiator_state + .write_message(&crypto, control_meta(11)) + .unwrap(); + responder_state.read_message(&crypto, &m1).unwrap(); - fn session_record(header: QlHeader, tag: u8, body: SessionRecord) -> QlRecord> { - let ciphertext_len = body.encode().len(); - QlRecord { - header, - payload: QlPayload::Session(encrypted(tag, ciphertext_len)), - } - } + let m2 = responder_state + .write_message(&crypto, control_meta(12)) + .unwrap(); + initiator_state.read_message(&crypto, &m2).unwrap(); - let header = header(); - let hello: QlRecord> = QlRecord { - header, - payload: QlPayload::Hello(handshake::Hello { - meta: meta(1), - nonce: Nonce([3; Nonce::SIZE]), - kem_ct: MlKemCiphertext::from_data([4; MlKemCiphertext::SIZE]), - signature: MlDsaSignature::from_data([5; MlDsaSignature::SIZE]), - }), - }; - let hello_reply: QlRecord> = QlRecord { - header, - payload: QlPayload::HelloReply(handshake::HelloReply { - meta: meta(2), - nonce: Nonce([6; Nonce::SIZE]), - kem_ct: MlKemCiphertext::from_data([7; MlKemCiphertext::SIZE]), - signature: MlDsaSignature::from_data([8; MlDsaSignature::SIZE]), - }), - }; - let confirm: QlRecord> = QlRecord { - header, - payload: QlPayload::Confirm(handshake::Confirm { - meta: meta(3), - signature: MlDsaSignature::from_data([9; MlDsaSignature::SIZE]), - }), - }; - let pair_request: QlRecord> = QlRecord { - header, - payload: QlPayload::PairRequest(pair::PairRequestRecord { - kem_ct: MlKemCiphertext::from_data([10; MlKemCiphertext::SIZE]), - encrypted: encrypted(11, 0), - }), - }; - let unpair: QlRecord> = QlRecord { - header, - payload: QlPayload::Unpair(unpair::Unpair { - meta: meta(4), - signature: MlDsaSignature::from_data([12; MlDsaSignature::SIZE]), - }), - }; - let ready: QlRecord> = QlRecord { - header, - payload: QlPayload::Ready(handshake::Ready { - encrypted: encrypted(13, 0), - }), - }; + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); - let session_ping = session_record( - header, - 14, - SessionRecord { - seq: RecordSeq(1), - frames: vec![SessionFrame::Ping], - }, - ); - let session_ack = session_record( - header, - 15, - SessionRecord { - seq: RecordSeq(2), - frames: vec![SessionFrame::Ack(RecordAck { - ranges: vec![RecordAckRange { start: 1, end: 3 }], - })], - }, + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash ); - let session_stream_window = session_record( - header, - 16, - SessionRecord { - seq: RecordSeq(3), - frames: vec![SessionFrame::StreamWindow(StreamWindow { - stream_id: StreamId(1), - maximum_offset: 65_536, - })], - }, + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id ); - let session_stream_empty = session_record( - header, - 18, - SessionRecord { - seq: RecordSeq(4), - frames: vec![SessionFrame::StreamData(StreamData { - stream_id: StreamId(1), - offset: 0, - fin: false, - bytes: Vec::new(), - })], - }, + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id ); - let session_stream_fin = session_record( - header, - 19, - SessionRecord { - seq: RecordSeq(5), - frames: vec![SessionFrame::StreamData(StreamData { - stream_id: StreamId(1), - offset: 0, + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); +} + +#[test] +fn encrypted_session_record_round_trip_uses_connection_id_header() { + let crypto = TestCrypto::new(30); + let header = SessionHeader { + connection_id: ConnectionId::from_data([0x44; ConnectionId::SIZE]), + }; + let body = SessionRecord { + seq: RecordSeq(11), + frames: vec![ + SessionFrame::Ping, + SessionFrame::Ack(RecordAck { + ranges: vec![ + RecordAckRange { start: 12, end: 14 }, + RecordAckRange { start: 20, end: 24 }, + ], + }), + SessionFrame::StreamWindow(StreamWindow { + stream_id: StreamId(9), + maximum_offset: 65_536, + }), + SessionFrame::StreamData(StreamData { + stream_id: StreamId(9), + offset: 1024, + bytes: b"hello".to_vec(), fin: true, - bytes: Vec::new(), - })], - }, - ); - let session_stream_close = session_record( - header, - 20, - SessionRecord { - seq: RecordSeq(6), - frames: vec![SessionFrame::StreamClose(StreamClose { - stream_id: StreamId(1), + }), + SessionFrame::StreamClose(StreamClose { + stream_id: StreamId(9), target: CloseTarget::Both, code: CloseCode::PROTOCOL, - })], - }, - ); - let session_close = session_record( + }), + SessionFrame::Close(SessionCloseBody { + code: CloseCode::TIMEOUT, + }), + ], + }; + let session_key = SessionKey::from_data([7; SessionKey::SIZE]); + let record = encrypted::encrypt_record( + &crypto, header, - 21, - SessionRecord { - seq: RecordSeq(7), - frames: vec![SessionFrame::Close(SessionCloseBody { - code: CloseCode::PROTOCOL, - })], - }, + &session_key, + &body, + Nonce([8; Nonce::SIZE]), ); - let print_size = |label: &str, size: usize| { - println!("{label:<32}: {size} bytes"); + let bytes = record.encode(); + let decoded = QlRecord::decode(&bytes).unwrap(); + let QlRecord::Session(decoded) = decoded else { + panic!("expected session payload"); }; + assert_eq!(decoded.header, header); + let encrypted = decoded.payload; - print_size("ql-wire hello", hello.encode().len()); - print_size("ql-wire hello_reply", hello_reply.encode().len()); - print_size("ql-wire confirm", confirm.encode().len()); - print_size("ql-wire pair_request empty", pair_request.encode().len()); - print_size("ql-wire unpair", unpair.encode().len()); - print_size("ql-wire ready empty", ready.encode().len()); - print_size("ql-wire session ping", session_ping.encode().len()); - print_size("ql-wire session ack", session_ack.encode().len()); - print_size( - "ql-wire session stream window", - session_stream_window.encode().len(), - ); - print_size( - "ql-wire session stream empty", - session_stream_empty.encode().len(), - ); - print_size( - "ql-wire session stream fin", - session_stream_fin.encode().len(), - ); - print_size( - "ql-wire session stream close", - session_stream_close.encode().len(), - ); - print_size("ql-wire session close", session_close.encode().len()); -} + let decrypted = + encrypted::decrypt_record(&crypto, &header, encrypted.clone(), &session_key).unwrap(); + assert_eq!(SessionRecord::decode(&decrypted).unwrap(), body); -fn frame_name(frame: &SessionFrame<&[u8]>) -> &'static str { - match frame { - SessionFrame::Ping => "ping", - SessionFrame::Ack(_) => "ack", - SessionFrame::StreamData(_) => "stream_data", - SessionFrame::StreamWindow(_) => "stream_window", - SessionFrame::StreamClose(_) => "stream_close", - SessionFrame::Close(_) => "close", - } + let decoded = QlSessionRecord::decode(&bytes).unwrap(); + assert_eq!(decoded.header, header); + + let wrong_header = SessionHeader { + connection_id: ConnectionId::from_data([0x99; ConnectionId::SIZE]), + }; + assert_eq!( + encrypted::decrypt_record(&crypto, &wrong_header, encrypted, &session_key), + Err(WireError::DecryptFailed) + ); } diff --git a/ql-wire/src/unpair/crypto.rs b/ql-wire/src/unpair/crypto.rs deleted file mode 100644 index f4849387..00000000 --- a/ql-wire/src/unpair/crypto.rs +++ /dev/null @@ -1,61 +0,0 @@ -use super::Unpair; -use crate::{ - ControlMeta, MlDsaPublicKey, QlCrypto, QlHeader, QlIdentity, QlPayload, QlRecord, WireError, - XID, -}; - -pub fn build_unpair( - crypto: &impl QlCrypto, - identity: &QlIdentity, - recipient: XID, - meta: ControlMeta, -) -> QlRecord> { - let header = QlHeader { - sender: identity.xid, - recipient, - }; - let signature = identity - .signing_private_key - .sign(crypto, &hash_unpair_signature_data(crypto, &header, &meta)); - QlRecord { - header, - payload: QlPayload::Unpair(Unpair { meta, signature }), - } -} - -pub fn verify_unpair( - crypto: &impl QlCrypto, - header: &QlHeader, - signer: &MlDsaPublicKey, - unpair: &Unpair, - now_seconds: u64, -) -> Result<(), WireError> { - unpair.meta.ensure_not_expired(now_seconds)?; - if signer.verify_bytes( - unpair.signature.as_bytes(), - &hash_unpair_signature_data(crypto, header, &unpair.meta), - ) { - Ok(()) - } else { - Err(WireError::InvalidSignature) - } -} - -fn hash_unpair_signature_data( - crypto: &impl QlCrypto, - header: &QlHeader, - meta: &ControlMeta, -) -> [u8; 32] { - let aad = header.aad(); - let control_id = meta.control_id.0.to_le_bytes(); - let valid_until = meta.valid_until.to_le_bytes(); - crypto.hash(&[ - b"ql-wire:unpair:v1", - b"aad", - &aad, - b"control-id", - &control_id, - b"valid-until", - &valid_until, - ]) -} diff --git a/ql-wire/src/unpair/mod.rs b/ql-wire/src/unpair/mod.rs deleted file mode 100644 index 593c3b7b..00000000 --- a/ql-wire/src/unpair/mod.rs +++ /dev/null @@ -1,29 +0,0 @@ -use crate::{codec, ControlMeta, MlDsaSignature, WireError}; - -mod crypto; -pub use crypto::*; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Unpair { - pub meta: ControlMeta, - pub signature: MlDsaSignature, -} - -impl Unpair { - pub const WIRE_SIZE: usize = ControlMeta::ENCODED_LEN + MlDsaSignature::SIZE; - - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let unpair = Self { - meta: ControlMeta::decode_from(&mut reader)?, - signature: MlDsaSignature::from_data(reader.take_array()?), - }; - reader.finish()?; - Ok(unpair) - } - - pub fn encode_into(&self, out: &mut Vec) { - self.meta.encode_into(out); - codec::push_bytes(out, self.signature.as_bytes()); - } -} diff --git a/ql-wire/src/x25519.rs b/ql-wire/src/x25519.rs new file mode 100644 index 00000000..647e719e --- /dev/null +++ b/ql-wire/src/x25519.rs @@ -0,0 +1,47 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct X25519PrivateKey([u8; Self::SIZE]); + +impl X25519PrivateKey { + pub const SIZE: usize = 32; + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +impl AsRef<[u8]> for X25519PrivateKey { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct X25519PublicKey([u8; Self::SIZE]); + +impl X25519PublicKey { + pub const SIZE: usize = 32; + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +impl AsRef<[u8]> for X25519PublicKey { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct X25519KeyPair { + pub private: X25519PrivateKey, + pub public: X25519PublicKey, +} From 1bf27cafae56bfdb6a402f739d551fc29842ca70 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 29 Mar 2026 15:42:35 -0400 Subject: [PATCH 053/304] ql-wire: remove X25519 --- ql-wire/src/crypto.rs | 16 ++------ ql-wire/src/handshake/kk.rs | 79 ++++-------------------------------- ql-wire/src/handshake/mod.rs | 39 +++++++----------- ql-wire/src/handshake/xx.rs | 75 ++++------------------------------ ql-wire/src/identity.rs | 33 ++------------- ql-wire/src/lib.rs | 2 - ql-wire/src/tests.rs | 28 +------------ ql-wire/src/x25519.rs | 47 --------------------- 8 files changed, 38 insertions(+), 281 deletions(-) delete mode 100644 ql-wire/src/x25519.rs diff --git a/ql-wire/src/crypto.rs b/ql-wire/src/crypto.rs index 503c185c..96ace383 100644 --- a/ql-wire/src/crypto.rs +++ b/ql-wire/src/crypto.rs @@ -1,6 +1,6 @@ use crate::{ MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, SessionKey, - X25519KeyPair, X25519PrivateKey, X25519PublicKey, ENCRYPTED_MESSAGE_AUTH_SIZE, + ENCRYPTED_MESSAGE_AUTH_SIZE, }; pub trait QlRandom { @@ -30,16 +30,6 @@ pub trait QlAead { ) -> bool; } -pub trait QlDh { - fn x25519_generate_keypair(&self) -> X25519KeyPair; - - fn x25519_agree( - &self, - private_key: &X25519PrivateKey, - public_key: &X25519PublicKey, - ) -> SessionKey; -} - pub trait QlKem { fn mlkem_generate_keypair(&self) -> MlKemKeyPair; @@ -52,6 +42,6 @@ pub trait QlKem { ) -> SessionKey; } -pub trait QlCrypto: QlRandom + QlHash + QlAead + QlDh + QlKem {} +pub trait QlCrypto: QlRandom + QlHash + QlAead + QlKem {} -impl QlCrypto for T where T: QlRandom + QlHash + QlAead + QlDh + QlKem {} +impl QlCrypto for T where T: QlRandom + QlHash + QlAead + QlKem {} diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index c92ba3a4..5c71b0a4 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -1,7 +1,7 @@ use super::{ decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, generate_ephemeral_keypair, init_kk_symmetric, mix_hash_ephemeral, EncryptedMlKemCiphertext, - FinalizedHandshake, HybridEphemeralKeyPair, HybridEphemeralPublic, Role, SymmetricState, + EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, ENCRYPTED_MLKEM_CIPHERTEXT_LEN, }; use crate::{codec, ControlMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; @@ -10,12 +10,12 @@ use crate::{codec, ControlMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentit pub struct Kk1 { pub meta: ControlMeta, pub skem_ciphertext: MlKemCiphertext, - pub ephemeral: HybridEphemeralPublic, + pub ephemeral: EphemeralPublicKey, } impl Kk1 { pub const ENCODED_LEN: usize = - ControlMeta::ENCODED_LEN + MlKemCiphertext::SIZE + HybridEphemeralPublic::ENCODED_LEN; + ControlMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EphemeralPublicKey::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -28,7 +28,7 @@ impl Kk1 { let meta = ControlMeta::decode_from(&mut reader)?; let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let ephemeral = - HybridEphemeralPublic::decode(&reader.take_bytes(HybridEphemeralPublic::ENCODED_LEN)?)?; + EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; reader.finish()?; Ok(Self { meta, @@ -43,20 +43,16 @@ pub struct Kk2 { pub meta: ControlMeta, pub ekem_ciphertext: MlKemCiphertext, pub skem_ciphertext: EncryptedMlKemCiphertext, - pub ephemeral: HybridEphemeralPublic, } impl Kk2 { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN - + MlKemCiphertext::SIZE - + ENCRYPTED_MLKEM_CIPHERTEXT_LEN - + HybridEphemeralPublic::ENCODED_LEN; + pub const ENCODED_LEN: usize = + ControlMeta::ENCODED_LEN + MlKemCiphertext::SIZE + ENCRYPTED_MLKEM_CIPHERTEXT_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); codec::push_bytes(out, self.skem_ciphertext.as_bytes()); - self.ephemeral.encode_into(out); } pub fn decode(bytes: &[u8]) -> Result { @@ -64,14 +60,11 @@ impl Kk2 { let meta = ControlMeta::decode_from(&mut reader)?; let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); - let ephemeral = - HybridEphemeralPublic::decode(&reader.take_bytes(HybridEphemeralPublic::ENCODED_LEN)?)?; reader.finish()?; Ok(Self { meta, ekem_ciphertext, skem_ciphertext, - ephemeral, }) } } @@ -98,8 +91,8 @@ pub struct KkHandshake { symmetric: SymmetricState, local: QlIdentity, remote_bundle: PeerBundle, - local_ephemeral: Option, - remote_ephemeral: Option, + local_ephemeral: Option, + remote_ephemeral: Option, } impl KkHandshake { @@ -159,18 +152,6 @@ impl KkHandshake { let public = local_ephemeral.public(); mix_hash_ephemeral(&mut self.symmetric, crypto, &public); - let es = crypto.x25519_agree( - &local_ephemeral.x25519.private, - &self.remote_bundle.x25519_public_key, - ); - self.symmetric.mix_key(crypto, es.as_bytes()); - - let ss = crypto.x25519_agree( - &self.local.x25519_private_key, - &self.remote_bundle.x25519_public_key, - ); - self.symmetric.mix_key(crypto, ss.as_bytes()); - self.local_ephemeral = Some(local_ephemeral); self.step = KkStep::Recv2; Ok(KkMessage::Message1(Kk1 { @@ -196,29 +177,11 @@ impl KkHandshake { self.symmetric .mix_key_and_hash(crypto, skem_secret.as_bytes()); - let local_ephemeral = generate_ephemeral_keypair(crypto); - let public = local_ephemeral.public(); - mix_hash_ephemeral(&mut self.symmetric, crypto, &public); - - let ee = crypto.x25519_agree( - &local_ephemeral.x25519.private, - &remote_ephemeral.x25519_public_key, - ); - self.symmetric.mix_key(crypto, ee.as_bytes()); - - let se = crypto.x25519_agree( - &local_ephemeral.x25519.private, - &self.remote_bundle.x25519_public_key, - ); - self.symmetric.mix_key(crypto, se.as_bytes()); - - self.local_ephemeral = Some(local_ephemeral); self.step = KkStep::Done; Ok(KkMessage::Message2(Kk2 { meta, ekem_ciphertext, skem_ciphertext, - ephemeral: public, })) } _ => Err(WireError::InvalidState), @@ -241,18 +204,6 @@ impl KkHandshake { mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); self.remote_ephemeral = Some(message.ephemeral.clone()); - - let es = crypto.x25519_agree( - &self.local.x25519_private_key, - &message.ephemeral.x25519_public_key, - ); - self.symmetric.mix_key(crypto, es.as_bytes()); - - let ss = crypto.x25519_agree( - &self.local.x25519_private_key, - &self.remote_bundle.x25519_public_key, - ); - self.symmetric.mix_key(crypto, ss.as_bytes()); self.step = KkStep::Send2; Ok(()) } @@ -277,20 +228,6 @@ impl KkHandshake { self.symmetric .mix_key_and_hash(crypto, skem_secret.as_bytes()); - mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); - self.remote_ephemeral = Some(message.ephemeral.clone()); - - let ee = crypto.x25519_agree( - &local_ephemeral.x25519.private, - &message.ephemeral.x25519_public_key, - ); - self.symmetric.mix_key(crypto, ee.as_bytes()); - - let se = crypto.x25519_agree( - &self.local.x25519_private_key, - &message.ephemeral.x25519_public_key, - ); - self.symmetric.mix_key(crypto, se.as_bytes()); self.step = KkStep::Done; Ok(()) } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index a00fbdf4..7ae69e65 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,6 +1,6 @@ use crate::{ codec, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, Nonce, PeerBundle, - QlCrypto, SessionKey, WireError, X25519KeyPair, X25519PublicKey, ENCRYPTED_MESSAGE_AUTH_SIZE, + QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, }; mod kk; @@ -10,8 +10,8 @@ pub use kk::{Kk1, Kk2, KkHandshake, KkMessage}; pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake, XxMessage}; const SHA256_BLOCK_LEN: usize = 64; -const PROTOCOL_XX: &[u8] = b"ql-wire:hybrid-xx:v1"; -const PROTOCOL_KK: &[u8] = b"ql-wire:hybrid-kk:v1"; +const PROTOCOL_XX: &[u8] = b"ql-wire:pq-xx:v1"; +const PROTOCOL_KK: &[u8] = b"ql-wire:pq-kk:v1"; const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; pub const ENCRYPTED_MLKEM_CIPHERTEXT_LEN: usize = @@ -19,23 +19,20 @@ pub const ENCRYPTED_MLKEM_CIPHERTEXT_LEN: usize = pub const ENCRYPTED_PEER_BUNDLE_LEN: usize = PeerBundle::ENCODED_LEN + ENCRYPTED_MESSAGE_AUTH_SIZE; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct HybridEphemeralPublic { - pub x25519_public_key: X25519PublicKey, +pub struct EphemeralPublicKey { pub mlkem_public_key: MlKemPublicKey, } -impl HybridEphemeralPublic { - pub const ENCODED_LEN: usize = X25519PublicKey::SIZE + MlKemPublicKey::SIZE; +impl EphemeralPublicKey { + pub const ENCODED_LEN: usize = MlKemPublicKey::SIZE; pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, self.x25519_public_key.as_bytes()); codec::push_bytes(out, self.mlkem_public_key.as_bytes()); } pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); let value = Self { - x25519_public_key: X25519PublicKey::from_data(reader.take_array()?), mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), }; reader.finish()?; @@ -86,15 +83,13 @@ enum Role { } #[derive(Debug, Clone)] -struct HybridEphemeralKeyPair { - x25519: X25519KeyPair, +struct EphemeralKeyPair { mlkem: MlKemKeyPair, } -impl HybridEphemeralKeyPair { - fn public(&self) -> HybridEphemeralPublic { - HybridEphemeralPublic { - x25519_public_key: self.x25519.public, +impl EphemeralKeyPair { + fn public(&self) -> EphemeralPublicKey { + EphemeralPublicKey { mlkem_public_key: self.mlkem.public.clone(), } } @@ -231,11 +226,7 @@ impl SymmetricState { } } - fn split_for_role( - &self, - crypto: &impl QlCrypto, - role: Role, - ) -> (SessionKey, SessionKey) { + fn split_for_role(&self, crypto: &impl QlCrypto, role: Role) -> (SessionKey, SessionKey) { let temp_key = hmac_sha256(crypto, &self.chaining_key, &[&[]]); let k1 = SessionKey::from_data(hmac_sha256(crypto, &temp_key, &[&[1]])); let k2 = SessionKey::from_data(hmac_sha256(crypto, &temp_key, &[k1.as_bytes(), &[2]])); @@ -257,9 +248,8 @@ fn init_kk_symmetric( symmetric } -fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> HybridEphemeralKeyPair { - HybridEphemeralKeyPair { - x25519: crypto.x25519_generate_keypair(), +fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> EphemeralKeyPair { + EphemeralKeyPair { mlkem: crypto.mlkem_generate_keypair(), } } @@ -267,9 +257,8 @@ fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> HybridEphemeralKeyPair fn mix_hash_ephemeral( symmetric: &mut SymmetricState, crypto: &impl QlCrypto, - public: &HybridEphemeralPublic, + public: &EphemeralPublicKey, ) { - symmetric.mix_hash(crypto, public.x25519_public_key.as_bytes()); symmetric.mix_hash(crypto, public.mlkem_public_key.as_bytes()); } diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index b3428e5b..e77c462d 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -1,7 +1,7 @@ use super::{ decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, finalize_handshake, generate_ephemeral_keypair, mix_hash_ephemeral, EncryptedMlKemCiphertext, - EncryptedPeerBundle, FinalizedHandshake, HybridEphemeralKeyPair, HybridEphemeralPublic, Role, + EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, ENCRYPTED_MLKEM_CIPHERTEXT_LEN, ENCRYPTED_PEER_BUNDLE_LEN, PROTOCOL_XX, }; use crate::{codec, ControlMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; @@ -9,11 +9,11 @@ use crate::{codec, ControlMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentit #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx1 { pub meta: ControlMeta, - pub ephemeral: HybridEphemeralPublic, + pub ephemeral: EphemeralPublicKey, } impl Xx1 { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + HybridEphemeralPublic::ENCODED_LEN; + pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + EphemeralPublicKey::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -24,7 +24,7 @@ impl Xx1 { let mut reader = codec::Reader::new(bytes); let meta = ControlMeta::decode_from(&mut reader)?; let ephemeral = - HybridEphemeralPublic::decode(&reader.take_bytes(HybridEphemeralPublic::ENCODED_LEN)?)?; + EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; reader.finish()?; Ok(Self { meta, ephemeral }) } @@ -34,20 +34,16 @@ impl Xx1 { pub struct Xx2 { pub meta: ControlMeta, pub ekem_ciphertext: MlKemCiphertext, - pub ephemeral: HybridEphemeralPublic, pub static_bundle: EncryptedPeerBundle, } impl Xx2 { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN - + MlKemCiphertext::SIZE - + HybridEphemeralPublic::ENCODED_LEN - + ENCRYPTED_PEER_BUNDLE_LEN; + pub const ENCODED_LEN: usize = + ControlMeta::ENCODED_LEN + MlKemCiphertext::SIZE + ENCRYPTED_PEER_BUNDLE_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); - self.ephemeral.encode_into(out); codec::push_bytes(out, self.static_bundle.as_bytes()); } @@ -55,14 +51,11 @@ impl Xx2 { let mut reader = codec::Reader::new(bytes); let meta = ControlMeta::decode_from(&mut reader)?; let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); - let ephemeral = - HybridEphemeralPublic::decode(&reader.take_bytes(HybridEphemeralPublic::ENCODED_LEN)?)?; let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); reader.finish()?; Ok(Self { meta, ekem_ciphertext, - ephemeral, static_bundle, }) } @@ -152,8 +145,8 @@ pub struct XxHandshake { step: XxStep, symmetric: SymmetricState, local: QlIdentity, - local_ephemeral: Option, - remote_ephemeral: Option, + local_ephemeral: Option, + remote_ephemeral: Option, remote_bundle: Option, } @@ -213,41 +206,18 @@ impl XxHandshake { self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); - let local_ephemeral = generate_ephemeral_keypair(crypto); - let public = local_ephemeral.public(); - mix_hash_ephemeral(&mut self.symmetric, crypto, &public); - let ee = crypto.x25519_agree( - &local_ephemeral.x25519.private, - &remote_ephemeral.x25519_public_key, - ); - self.symmetric.mix_key(crypto, ee.as_bytes()); - let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; - let es = crypto.x25519_agree( - &self.local.x25519_private_key, - &remote_ephemeral.x25519_public_key, - ); - self.symmetric.mix_key(crypto, es.as_bytes()); - - self.local_ephemeral = Some(local_ephemeral); self.step = XxStep::Recv3; - Ok(XxMessage::Message2(Xx2 { meta, ekem_ciphertext, - ephemeral: public, static_bundle, })) } XxStep::Send3 => { let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; - let remote_ephemeral = self - .remote_ephemeral - .as_ref() - .ok_or(WireError::InvalidState)?; - let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); let skem_ciphertext = @@ -258,13 +228,7 @@ impl XxHandshake { let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; - let se = crypto.x25519_agree( - &self.local.x25519_private_key, - &remote_ephemeral.x25519_public_key, - ); - self.symmetric.mix_key(crypto, se.as_bytes()); self.step = XxStep::Recv4; - Ok(XxMessage::Message3(Xx3 { meta, skem_ciphertext, @@ -313,22 +277,8 @@ impl XxHandshake { .mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); - mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); - self.remote_ephemeral = Some(message.ephemeral.clone()); - - let ee = crypto.x25519_agree( - &local_ephemeral.x25519.private, - &message.ephemeral.x25519_public_key, - ); - self.symmetric.mix_key(crypto, ee.as_bytes()); - let remote_bundle = decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; - let es = crypto.x25519_agree( - &local_ephemeral.x25519.private, - &remote_bundle.x25519_public_key, - ); - self.symmetric.mix_key(crypto, es.as_bytes()); self.remote_bundle = Some(remote_bundle); self.step = XxStep::Send3; Ok(()) @@ -346,15 +296,6 @@ impl XxHandshake { let remote_bundle = decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; - let local_ephemeral = self - .local_ephemeral - .as_ref() - .ok_or(WireError::InvalidState)?; - let se = crypto.x25519_agree( - &local_ephemeral.x25519.private, - &remote_bundle.x25519_public_key, - ); - self.symmetric.mix_key(crypto, se.as_bytes()); self.remote_bundle = Some(remote_bundle); self.step = XxStep::Send4; Ok(()) diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 60e55f75..b4d72355 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -1,27 +1,20 @@ -use crate::{ - codec, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, WireError, X25519KeyPair, - X25519PrivateKey, X25519PublicKey, XID, -}; +use crate::{codec, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, WireError, XID}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct PeerBundle { pub version: u16, pub capabilities: u32, - pub x25519_public_key: X25519PublicKey, pub mlkem_public_key: MlKemPublicKey, } impl PeerBundle { pub const VERSION: u16 = 1; - pub const ENCODED_LEN: usize = core::mem::size_of::() - + core::mem::size_of::() - + X25519PublicKey::SIZE - + MlKemPublicKey::SIZE; + pub const ENCODED_LEN: usize = + core::mem::size_of::() + core::mem::size_of::() + MlKemPublicKey::SIZE; pub fn encode_into(&self, out: &mut Vec) { codec::push_u16(out, self.version); codec::push_u32(out, self.capabilities); - codec::push_bytes(out, self.x25519_public_key.as_bytes()); codec::push_bytes(out, self.mlkem_public_key.as_bytes()); } @@ -36,7 +29,6 @@ impl PeerBundle { let bundle = Self { version: reader.take_u16()?, capabilities: reader.take_u32()?, - x25519_public_key: X25519PublicKey::from_data(reader.take_array()?), mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), }; reader.finish()?; @@ -47,8 +39,6 @@ impl PeerBundle { #[derive(Debug, Clone)] pub struct QlIdentity { pub xid: XID, - pub x25519_private_key: X25519PrivateKey, - pub x25519_public_key: X25519PublicKey, pub mlkem_private_key: MlKemPrivateKey, pub mlkem_public_key: MlKemPublicKey, pub capabilities: u32, @@ -57,15 +47,11 @@ pub struct QlIdentity { impl QlIdentity { pub fn new( xid: XID, - x25519_private_key: X25519PrivateKey, - x25519_public_key: X25519PublicKey, mlkem_private_key: MlKemPrivateKey, mlkem_public_key: MlKemPublicKey, ) -> Self { Self { xid, - x25519_private_key, - x25519_public_key, mlkem_private_key, mlkem_public_key, capabilities: 0, @@ -81,26 +67,15 @@ impl QlIdentity { PeerBundle { version: PeerBundle::VERSION, capabilities: self.capabilities, - x25519_public_key: self.x25519_public_key, mlkem_public_key: self.mlkem_public_key.clone(), } } } pub fn generate_identity(crypto: &impl QlCrypto, xid: XID) -> QlIdentity { - let X25519KeyPair { - private: x25519_private_key, - public: x25519_public_key, - } = crypto.x25519_generate_keypair(); let MlKemKeyPair { private: mlkem_private_key, public: mlkem_public_key, } = crypto.mlkem_generate_keypair(); - QlIdentity::new( - xid, - x25519_private_key, - x25519_public_key, - mlkem_private_key, - mlkem_public_key, - ) + QlIdentity::new(xid, mlkem_private_key, mlkem_public_key) } diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 4a9a2c27..3a688850 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -17,7 +17,6 @@ mod identity; mod nonce; mod pq; mod record; -mod x25519; mod xid; pub use bytes::*; @@ -32,7 +31,6 @@ pub use identity::*; pub use nonce::*; pub use pq::*; pub use record::*; -pub use x25519::*; pub use xid::*; pub const QL_WIRE_VERSION: u8 = 1; diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 056ff205..5ac5c4f6 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -71,31 +71,6 @@ impl QlAead for TestCrypto { } } -impl QlDh for TestCrypto { - fn x25519_generate_keypair(&self) -> X25519KeyPair { - let private = self.next_block(); - X25519KeyPair { - private: X25519PrivateKey::from_data(private), - public: X25519PublicKey::from_data(private), - } - } - - fn x25519_agree( - &self, - private_key: &X25519PrivateKey, - public_key: &X25519PublicKey, - ) -> SessionKey { - let left = *private_key.as_bytes(); - let right = *public_key.as_bytes(); - let (first, second) = if left <= right { - (left, right) - } else { - (right, left) - }; - SessionKey::from_data(self.sha256(&[b"ql-wire:test-x25519:v1", &first, &second])) - } -} - impl QlKem for TestCrypto { fn mlkem_generate_keypair(&self) -> MlKemKeyPair { let seed = self.next_block(); @@ -208,8 +183,7 @@ fn peer_bundle_round_trip() { fn handshake_record_round_trip_uses_handshake_header() { let message = Xx1 { meta: control_meta(1), - ephemeral: HybridEphemeralPublic { - x25519_public_key: X25519PublicKey::from_data([3; X25519PublicKey::SIZE]), + ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::from_data([9; MlKemPublicKey::SIZE]), }, }; diff --git a/ql-wire/src/x25519.rs b/ql-wire/src/x25519.rs deleted file mode 100644 index 647e719e..00000000 --- a/ql-wire/src/x25519.rs +++ /dev/null @@ -1,47 +0,0 @@ -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct X25519PrivateKey([u8; Self::SIZE]); - -impl X25519PrivateKey { - pub const SIZE: usize = 32; - - pub const fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(data) - } - - pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { - &self.0 - } -} - -impl AsRef<[u8]> for X25519PrivateKey { - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct X25519PublicKey([u8; Self::SIZE]); - -impl X25519PublicKey { - pub const SIZE: usize = 32; - - pub const fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(data) - } - - pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { - &self.0 - } -} - -impl AsRef<[u8]> for X25519PublicKey { - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct X25519KeyPair { - pub private: X25519PrivateKey, - pub public: X25519PublicKey, -} From 98c9dfac1df1af251dc0fcc91b0bb7547c537f50 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 29 Mar 2026 15:48:56 -0400 Subject: [PATCH 054/304] ql-wire: add back size tests --- ql-wire/src/tests.rs | 143 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 5ac5c4f6..54ff04cc 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -167,6 +167,24 @@ fn make_identity(crypto: &impl QlCrypto, byte: u8) -> QlIdentity { generate_identity(crypto, xid(byte)) } +fn xx_record(header: HandshakeHeader, message: XxMessage) -> QlHandshakeRecord { + let payload = match message { + XxMessage::Message1(message) => HandshakePayload::Xx1(message), + XxMessage::Message2(message) => HandshakePayload::Xx2(message), + XxMessage::Message3(message) => HandshakePayload::Xx3(message), + XxMessage::Message4(message) => HandshakePayload::Xx4(message), + }; + QlHandshakeRecord { header, payload } +} + +fn kk_record(header: HandshakeHeader, message: KkMessage) -> QlHandshakeRecord { + let payload = match message { + KkMessage::Message1(message) => HandshakePayload::Kk1(message), + KkMessage::Message2(message) => HandshakePayload::Kk2(message), + }; + QlHandshakeRecord { header, payload } +} + #[test] fn peer_bundle_round_trip() { let crypto = TestCrypto::new(1); @@ -364,3 +382,128 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { Err(WireError::DecryptFailed) ); } + +#[test] +fn protocol_record_size_breakdown() { + fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { + HandshakeHeader { + sender: xid(sender), + recipient: xid(recipient), + } + } + + fn print_size(label: &str, size: usize) { + println!("{label:<32}: {size} bytes"); + } + + let crypto = TestCrypto::new(40); + let initiator = make_identity(&crypto, 1); + let responder = make_identity(&crypto, 2); + + let mut xx_initiator = XxHandshake::new_initiator(&crypto, initiator.clone()); + let mut xx_responder = XxHandshake::new_responder(&crypto, responder.clone()); + + let xx1 = xx_initiator + .write_message(&crypto, control_meta(101)) + .unwrap(); + xx_responder.read_message(&crypto, &xx1).unwrap(); + + let xx2 = xx_responder + .write_message(&crypto, control_meta(102)) + .unwrap(); + xx_initiator.read_message(&crypto, &xx2).unwrap(); + + let xx3 = xx_initiator + .write_message(&crypto, control_meta(103)) + .unwrap(); + xx_responder.read_message(&crypto, &xx3).unwrap(); + + let xx4 = xx_responder + .write_message(&crypto, control_meta(104)) + .unwrap(); + xx_initiator.read_message(&crypto, &xx4).unwrap(); + + let xx1 = xx_record(handshake_header(1, 2), xx1); + let xx2 = xx_record(handshake_header(2, 1), xx2); + let xx3 = xx_record(handshake_header(1, 2), xx3); + let xx4 = xx_record(handshake_header(2, 1), xx4); + + let mut kk_initiator = + KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); + let mut kk_responder = + KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); + + let kk1 = kk_initiator + .write_message(&crypto, control_meta(201)) + .unwrap(); + kk_responder.read_message(&crypto, &kk1).unwrap(); + + let kk2 = kk_responder + .write_message(&crypto, control_meta(202)) + .unwrap(); + kk_initiator.read_message(&crypto, &kk2).unwrap(); + + let kk1 = kk_record(handshake_header(1, 2), kk1); + let kk2 = kk_record(handshake_header(2, 1), kk2); + + let session = xx_initiator.finalize(&crypto).unwrap(); + let session_ping = encrypted::encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + }, + &session.tx_key, + &SessionRecord { + seq: RecordSeq(1), + frames: vec![SessionFrame::Ping], + }, + Nonce([0x41; Nonce::SIZE]), + ); + let session_stream_empty = encrypted::encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + }, + &session.tx_key, + &SessionRecord { + seq: RecordSeq(2), + frames: vec![SessionFrame::StreamData(StreamData { + stream_id: StreamId(1), + offset: 0, + fin: false, + bytes: Vec::new(), + })], + }, + Nonce([0x42; Nonce::SIZE]), + ); + let session_close = encrypted::encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + }, + &session.tx_key, + &SessionRecord { + seq: RecordSeq(3), + frames: vec![SessionFrame::Close(SessionCloseBody { + code: CloseCode::PROTOCOL, + })], + }, + Nonce([0x43; Nonce::SIZE]), + ); + + print_size("ql-wire peer bundle", initiator.bundle().encode().len()); + print_size("ql-wire mlkem public key", MlKemPublicKey::SIZE); + print_size("ql-wire mlkem ciphertext", MlKemCiphertext::SIZE); + print_size("ql-wire pq xx1", xx1.encode().len()); + print_size("ql-wire pq xx2", xx2.encode().len()); + print_size("ql-wire pq xx3", xx3.encode().len()); + print_size("ql-wire pq xx4", xx4.encode().len()); + print_size("ql-wire pq kk1", kk1.encode().len()); + print_size("ql-wire pq kk2", kk2.encode().len()); + print_size("ql-wire session ping", session_ping.encode().len()); + print_size( + "ql-wire session stream empty", + session_stream_empty.encode().len(), + ); + print_size("ql-wire session close", session_close.encode().len()); +} From 833d38d1f5ad05d93e92716e1dae396d18daaa31 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 29 Mar 2026 20:02:53 -0400 Subject: [PATCH 055/304] ql-wire: pq no libs + drop impl --- Cargo.lock | 50 ---------------------------------------------- ql-wire/Cargo.toml | 6 ------ ql-wire/src/pq.rs | 50 ++++++++++++++++++++++++++++++++++------------ 3 files changed, 37 insertions(+), 69 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 90dfe15f..b67bce2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1499,22 +1499,6 @@ dependencies = [ "hax-lib", ] -[[package]] -name = "libcrux-ml-kem" -version = "0.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aca7de713c6dddcf7aaf76e8ef9dc0097c8d7ce23a8eadf04c8761734714e184" -dependencies = [ - "hax-lib", - "libcrux-intrinsics", - "libcrux-platform", - "libcrux-secrets", - "libcrux-sha3", - "libcrux-traits", - "rand 0.9.2", - "tls_codec", -] - [[package]] name = "libcrux-platform" version = "0.0.3" @@ -1533,18 +1517,6 @@ dependencies = [ "hax-lib", ] -[[package]] -name = "libcrux-sha3" -version = "0.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c50f6e04a184511b782c5cc1eb6a227c6d36f2c935e93d698655a93a99696b5" -dependencies = [ - "hax-lib", - "libcrux-intrinsics", - "libcrux-platform", - "libcrux-traits", -] - [[package]] name = "libcrux-traits" version = "0.0.6" @@ -2166,7 +2138,6 @@ name = "ql-wire" version = "0.1.0" dependencies = [ "libcrux-aesgcm", - "libcrux-ml-kem", "sha2", ] @@ -2743,27 +2714,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "tls_codec" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" -dependencies = [ - "tls_codec_derive", - "zeroize", -] - -[[package]] -name = "tls_codec_derive" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "tokio" version = "1.47.1" diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml index 3a94f700..25bb291d 100644 --- a/ql-wire/Cargo.toml +++ b/ql-wire/Cargo.toml @@ -5,12 +5,6 @@ edition = "2021" description = "Quantum Link wire format types and crypto helpers" license = "Proprietary" -[dependencies] -libcrux-ml-kem = { version = "0.0.7", default-features = false, features = [ - "std", - "mlkem1024", -] } - [dev-dependencies] libcrux-aesgcm = "0.0.7" sha2 = "0.10" diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index fa3c133f..ba8753d0 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -1,14 +1,18 @@ -use libcrux_ml_kem::{mlkem1024, SHARED_SECRET_SIZE}; - -use crate::QlCrypto; - pub const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// ql-wire fixes the protocol to ML-KEM-1024 on the wire, but the host +// platform is free to satisfy QlKem with any backend that produces the same +// serialized sizes. +const ML_KEM_1024_SHARED_SECRET_SIZE: usize = 32; +const ML_KEM_1024_PUBLIC_KEY_SIZE: usize = 1568; +const ML_KEM_1024_PRIVATE_KEY_SIZE: usize = 3168; +const ML_KEM_1024_CIPHERTEXT_SIZE: usize = 1568; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SessionKey([u8; Self::SIZE]); impl SessionKey { - pub const SIZE: usize = SHARED_SECRET_SIZE; + pub const SIZE: usize = ML_KEM_1024_SHARED_SECRET_SIZE; pub const fn from_data(data: [u8; Self::SIZE]) -> Self { Self(data) @@ -29,11 +33,17 @@ impl AsRef<[u8]> for SessionKey { } } +impl Drop for SessionKey { + fn drop(&mut self) { + self.0.fill(0); + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MlKemPublicKey(Box<[u8; MlKemPublicKey::SIZE]>); impl MlKemPublicKey { - pub const SIZE: usize = mlkem1024::MlKem1024PublicKey::len(); + pub const SIZE: usize = ML_KEM_1024_PUBLIC_KEY_SIZE; pub fn from_data(data: [u8; Self::SIZE]) -> Self { Self(Box::new(data)) @@ -44,11 +54,17 @@ impl MlKemPublicKey { } } +impl Drop for MlKemPublicKey { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MlKemPrivateKey(Box<[u8; MlKemPrivateKey::SIZE]>); impl MlKemPrivateKey { - pub const SIZE: usize = mlkem1024::MlKem1024PrivateKey::len(); + pub const SIZE: usize = ML_KEM_1024_PRIVATE_KEY_SIZE; pub fn from_data(data: [u8; Self::SIZE]) -> Self { Self(Box::new(data)) @@ -59,11 +75,17 @@ impl MlKemPrivateKey { } } +impl Drop for MlKemPrivateKey { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MlKemCiphertext(Box<[u8; MlKemCiphertext::SIZE]>); impl MlKemCiphertext { - pub const SIZE: usize = mlkem1024::MlKem1024Ciphertext::len(); + pub const SIZE: usize = ML_KEM_1024_CIPHERTEXT_SIZE; pub fn from_data(data: [u8; Self::SIZE]) -> Self { Self(Box::new(data)) @@ -74,12 +96,14 @@ impl MlKemCiphertext { } } +impl Drop for MlKemCiphertext { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MlKemKeyPair { pub private: MlKemPrivateKey, pub public: MlKemPublicKey, } - -pub fn generate_ml_kem_keypair(crypto: &impl QlCrypto) -> MlKemKeyPair { - crypto.mlkem_generate_keypair() -} From de80a900adc43d8197e6649fe46c5bf49f1bffa0 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 08:15:56 -0400 Subject: [PATCH 056/304] ql-wire: handshake meta --- ql-wire/src/handshake/kk.rs | 33 ++++++---- ql-wire/src/{control.rs => handshake/meta.rs} | 14 ++--- ql-wire/src/handshake/mod.rs | 39 ++++++++++++ ql-wire/src/handshake/xx.rs | 50 +++++++++++----- ql-wire/src/lib.rs | 2 - ql-wire/src/tests.rs | 60 ++++++++++++++----- 6 files changed, 147 insertions(+), 51 deletions(-) rename ql-wire/src/{control.rs => handshake/meta.rs} (80%) diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 5c71b0a4..d79823ad 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -1,21 +1,21 @@ use super::{ decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, - generate_ephemeral_keypair, init_kk_symmetric, mix_hash_ephemeral, EncryptedMlKemCiphertext, - EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, - ENCRYPTED_MLKEM_CIPHERTEXT_LEN, + generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, + mix_hash_handshake_meta, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, + EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, ENCRYPTED_MLKEM_CIPHERTEXT_LEN, }; -use crate::{codec, ControlMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; +use crate::{codec, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Kk1 { - pub meta: ControlMeta, + pub meta: HandshakeMeta, pub skem_ciphertext: MlKemCiphertext, pub ephemeral: EphemeralPublicKey, } impl Kk1 { pub const ENCODED_LEN: usize = - ControlMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EphemeralPublicKey::ENCODED_LEN; + HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EphemeralPublicKey::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -25,7 +25,7 @@ impl Kk1 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let meta = ControlMeta::decode_from(&mut reader)?; + let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let ephemeral = EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; @@ -40,14 +40,14 @@ impl Kk1 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Kk2 { - pub meta: ControlMeta, + pub meta: HandshakeMeta, pub ekem_ciphertext: MlKemCiphertext, pub skem_ciphertext: EncryptedMlKemCiphertext, } impl Kk2 { pub const ENCODED_LEN: usize = - ControlMeta::ENCODED_LEN + MlKemCiphertext::SIZE + ENCRYPTED_MLKEM_CIPHERTEXT_LEN; + HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + ENCRYPTED_MLKEM_CIPHERTEXT_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -57,7 +57,7 @@ impl Kk2 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let meta = ControlMeta::decode_from(&mut reader)?; + let meta = HandshakeMeta::decode_from(&mut reader)?; let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); reader.finish()?; @@ -93,6 +93,7 @@ pub struct KkHandshake { remote_bundle: PeerBundle, local_ephemeral: Option, remote_ephemeral: Option, + handshake_meta: Option, } impl KkHandshake { @@ -110,6 +111,7 @@ impl KkHandshake { remote_bundle, local_ephemeral: None, remote_ephemeral: None, + handshake_meta: None, } } @@ -127,6 +129,7 @@ impl KkHandshake { remote_bundle, local_ephemeral: None, remote_ephemeral: None, + handshake_meta: None, } } @@ -137,10 +140,12 @@ impl KkHandshake { pub fn write_message( &mut self, crypto: &impl QlCrypto, - meta: ControlMeta, + meta: HandshakeMeta, ) -> Result { match self.step { KkStep::Send1 => { + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"kk1", &meta); let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); self.symmetric @@ -161,6 +166,8 @@ impl KkHandshake { })) } KkStep::Send2 => { + require_handshake_meta(&self.handshake_meta, meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"kk2", &meta); let remote_ephemeral = self .remote_ephemeral .clone() @@ -195,6 +202,8 @@ impl KkHandshake { ) -> Result<(), WireError> { match (&self.step, message) { (KkStep::Recv1, KkMessage::Message1(message)) => { + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"kk1", &message.meta); self.symmetric .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; let skem_secret = crypto @@ -208,6 +217,8 @@ impl KkHandshake { Ok(()) } (KkStep::Recv2, KkMessage::Message2(message)) => { + require_handshake_meta(&self.handshake_meta, message.meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"kk2", &message.meta); let local_ephemeral = self .local_ephemeral .as_ref() diff --git a/ql-wire/src/control.rs b/ql-wire/src/handshake/meta.rs similarity index 80% rename from ql-wire/src/control.rs rename to ql-wire/src/handshake/meta.rs index 17dbfd67..747d7e18 100644 --- a/ql-wire/src/control.rs +++ b/ql-wire/src/handshake/meta.rs @@ -2,15 +2,15 @@ use crate::{codec, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct ControlId(pub u32); +pub struct HandshakeId(pub u32); #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct ControlMeta { - pub control_id: ControlId, +pub struct HandshakeMeta { + pub handshake_id: HandshakeId, pub valid_until: u64, } -impl ControlMeta { +impl HandshakeMeta { pub const ENCODED_LEN: usize = core::mem::size_of::() + core::mem::size_of::(); pub fn ensure_not_expired(&self, now_seconds: u64) -> Result<(), WireError> { @@ -22,7 +22,7 @@ impl ControlMeta { } pub fn encode_into(&self, out: &mut Vec) { - codec::push_u32(out, self.control_id.0); + codec::push_u32(out, self.handshake_id.0); codec::push_u64(out, self.valid_until); } @@ -35,7 +35,7 @@ impl ControlMeta { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); let meta = Self { - control_id: ControlId(reader.take_u32()?), + handshake_id: HandshakeId(reader.take_u32()?), valid_until: reader.take_u64()?, }; reader.finish()?; @@ -46,7 +46,7 @@ impl ControlMeta { reader: &mut codec::Reader, ) -> Result { Ok(Self { - control_id: ControlId(reader.take_u32()?), + handshake_id: HandshakeId(reader.take_u32()?), valid_until: reader.take_u64()?, }) } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 7ae69e65..f343bfe8 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -3,9 +3,11 @@ use crate::{ QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, }; +mod meta; mod kk; mod xx; +pub use meta::{HandshakeId, HandshakeMeta}; pub use kk::{Kk1, Kk2, KkHandshake, KkMessage}; pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake, XxMessage}; @@ -13,6 +15,7 @@ const SHA256_BLOCK_LEN: usize = 64; const PROTOCOL_XX: &[u8] = b"ql-wire:pq-xx:v1"; const PROTOCOL_KK: &[u8] = b"ql-wire:pq-kk:v1"; const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; +const HANDSHAKE_META_DOMAIN: &[u8] = b"ql-wire:handshake-meta:v1"; pub const ENCRYPTED_MLKEM_CIPHERTEXT_LEN: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; @@ -262,6 +265,42 @@ fn mix_hash_ephemeral( symmetric.mix_hash(crypto, public.mlkem_public_key.as_bytes()); } +fn mix_hash_handshake_meta( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + message_name: &[u8], + meta: &HandshakeMeta, +) { + let encoded = meta.encode(); + symmetric.mix_hash(crypto, HANDSHAKE_META_DOMAIN); + symmetric.mix_hash(crypto, message_name); + symmetric.mix_hash(crypto, &encoded); +} + +fn initialize_handshake_meta( + expected: &mut Option, + meta: HandshakeMeta, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored != meta => Err(WireError::InvalidPayload), + Some(_) => Ok(()), + None => { + *expected = Some(meta); + Ok(()) + } + } +} + +fn require_handshake_meta( + expected: &Option, + meta: HandshakeMeta, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored == meta => Ok(()), + _ => Err(WireError::InvalidPayload), + } +} + fn encrypt_peer_bundle( crypto: &impl QlCrypto, symmetric: &mut SymmetricState, diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index e77c462d..6563274a 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -1,19 +1,20 @@ use super::{ decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, - finalize_handshake, generate_ephemeral_keypair, mix_hash_ephemeral, EncryptedMlKemCiphertext, + finalize_handshake, generate_ephemeral_keypair, initialize_handshake_meta, mix_hash_ephemeral, + mix_hash_handshake_meta, require_handshake_meta, EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, ENCRYPTED_MLKEM_CIPHERTEXT_LEN, ENCRYPTED_PEER_BUNDLE_LEN, PROTOCOL_XX, }; -use crate::{codec, ControlMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; +use crate::{codec, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx1 { - pub meta: ControlMeta, + pub meta: HandshakeMeta, pub ephemeral: EphemeralPublicKey, } impl Xx1 { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + EphemeralPublicKey::ENCODED_LEN; + pub const ENCODED_LEN: usize = HandshakeMeta::ENCODED_LEN + EphemeralPublicKey::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -22,7 +23,7 @@ impl Xx1 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let meta = ControlMeta::decode_from(&mut reader)?; + let meta = HandshakeMeta::decode_from(&mut reader)?; let ephemeral = EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; reader.finish()?; @@ -32,14 +33,14 @@ impl Xx1 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx2 { - pub meta: ControlMeta, + pub meta: HandshakeMeta, pub ekem_ciphertext: MlKemCiphertext, pub static_bundle: EncryptedPeerBundle, } impl Xx2 { pub const ENCODED_LEN: usize = - ControlMeta::ENCODED_LEN + MlKemCiphertext::SIZE + ENCRYPTED_PEER_BUNDLE_LEN; + HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + ENCRYPTED_PEER_BUNDLE_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -49,7 +50,7 @@ impl Xx2 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let meta = ControlMeta::decode_from(&mut reader)?; + let meta = HandshakeMeta::decode_from(&mut reader)?; let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); reader.finish()?; @@ -63,14 +64,14 @@ impl Xx2 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx3 { - pub meta: ControlMeta, + pub meta: HandshakeMeta, pub skem_ciphertext: EncryptedMlKemCiphertext, pub static_bundle: EncryptedPeerBundle, } impl Xx3 { pub const ENCODED_LEN: usize = - ControlMeta::ENCODED_LEN + ENCRYPTED_MLKEM_CIPHERTEXT_LEN + ENCRYPTED_PEER_BUNDLE_LEN; + HandshakeMeta::ENCODED_LEN + ENCRYPTED_MLKEM_CIPHERTEXT_LEN + ENCRYPTED_PEER_BUNDLE_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -80,7 +81,7 @@ impl Xx3 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let meta = ControlMeta::decode_from(&mut reader)?; + let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); reader.finish()?; @@ -94,12 +95,12 @@ impl Xx3 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx4 { - pub meta: ControlMeta, + pub meta: HandshakeMeta, pub skem_ciphertext: EncryptedMlKemCiphertext, } impl Xx4 { - pub const ENCODED_LEN: usize = ControlMeta::ENCODED_LEN + ENCRYPTED_MLKEM_CIPHERTEXT_LEN; + pub const ENCODED_LEN: usize = HandshakeMeta::ENCODED_LEN + ENCRYPTED_MLKEM_CIPHERTEXT_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -108,7 +109,7 @@ impl Xx4 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let meta = ControlMeta::decode_from(&mut reader)?; + let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); reader.finish()?; Ok(Self { @@ -148,6 +149,7 @@ pub struct XxHandshake { local_ephemeral: Option, remote_ephemeral: Option, remote_bundle: Option, + handshake_meta: Option, } impl XxHandshake { @@ -160,6 +162,7 @@ impl XxHandshake { local_ephemeral: None, remote_ephemeral: None, remote_bundle: None, + handshake_meta: None, } } @@ -172,6 +175,7 @@ impl XxHandshake { local_ephemeral: None, remote_ephemeral: None, remote_bundle: None, + handshake_meta: None, } } @@ -182,10 +186,12 @@ impl XxHandshake { pub fn write_message( &mut self, crypto: &impl QlCrypto, - meta: ControlMeta, + meta: HandshakeMeta, ) -> Result { match self.step { XxStep::Send1 => { + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx1", &meta); let local_ephemeral = generate_ephemeral_keypair(crypto); let public = local_ephemeral.public(); mix_hash_ephemeral(&mut self.symmetric, crypto, &public); @@ -197,6 +203,8 @@ impl XxHandshake { })) } XxStep::Send2 => { + require_handshake_meta(&self.handshake_meta, meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx2", &meta); let remote_ephemeral = self .remote_ephemeral .clone() @@ -217,6 +225,8 @@ impl XxHandshake { })) } XxStep::Send3 => { + require_handshake_meta(&self.handshake_meta, meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx3", &meta); let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); @@ -236,6 +246,8 @@ impl XxHandshake { })) } XxStep::Send4 => { + require_handshake_meta(&self.handshake_meta, meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx4", &meta); let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); @@ -261,12 +273,16 @@ impl XxHandshake { ) -> Result<(), WireError> { match (&self.step, message) { (XxStep::Recv1, XxMessage::Message1(message)) => { + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx1", &message.meta); mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); self.remote_ephemeral = Some(message.ephemeral.clone()); self.step = XxStep::Send2; Ok(()) } (XxStep::Recv2, XxMessage::Message2(message)) => { + require_handshake_meta(&self.handshake_meta, message.meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx2", &message.meta); let local_ephemeral = self .local_ephemeral .as_ref() @@ -284,6 +300,8 @@ impl XxHandshake { Ok(()) } (XxStep::Recv3, XxMessage::Message3(message)) => { + require_handshake_meta(&self.handshake_meta, message.meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx3", &message.meta); let skem_ciphertext = decrypt_mlkem_ciphertext( crypto, &mut self.symmetric, @@ -301,6 +319,8 @@ impl XxHandshake { Ok(()) } (XxStep::Recv4, XxMessage::Message4(message)) => { + require_handshake_meta(&self.handshake_meta, message.meta)?; + mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx4", &message.meta); let skem_ciphertext = decrypt_mlkem_ciphertext( crypto, &mut self.symmetric, diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 3a688850..ba77f53c 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -6,7 +6,6 @@ mod bytes; mod codec; -mod control; mod crypto; mod encrypted; mod encrypted_message; @@ -20,7 +19,6 @@ mod record; mod xid; pub use bytes::*; -pub use control::*; pub use crypto::*; pub use encrypted::*; pub use encrypted_message::*; diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 54ff04cc..9fb6ab08 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -156,9 +156,9 @@ fn xid(byte: u8) -> XID { XID([byte; XID::SIZE]) } -fn control_meta(id: u32) -> ControlMeta { - ControlMeta { - control_id: ControlId(id), +fn handshake_meta(id: u32) -> HandshakeMeta { + HandshakeMeta { + handshake_id: HandshakeId(id), valid_until: 10_000 + u64::from(id), } } @@ -200,7 +200,7 @@ fn peer_bundle_round_trip() { #[test] fn handshake_record_round_trip_uses_handshake_header() { let message = Xx1 { - meta: control_meta(1), + meta: handshake_meta(1), ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::from_data([9; MlKemPublicKey::SIZE]), }, @@ -222,6 +222,34 @@ fn handshake_record_round_trip_uses_handshake_header() { assert_eq!(decoded, QlRecord::Handshake(record)); } +#[test] +fn xx_handshake_rejects_tampered_handshake_meta() { + let crypto = TestCrypto::new(9); + let initiator = make_identity(&crypto, 1); + let responder = make_identity(&crypto, 2); + + let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); + let mut responder_state = XxHandshake::new_responder(&crypto, responder); + + let m1 = initiator_state + .write_message(&crypto, handshake_meta(77)) + .unwrap(); + responder_state.read_message(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_message(&crypto, handshake_meta(77)) + .unwrap(); + let XxMessage::Message2(message) = &mut m2 else { + panic!("expected xx2"); + }; + message.meta.handshake_id = HandshakeId(78); + + assert_eq!( + initiator_state.read_message(&crypto, &m2), + Err(WireError::InvalidPayload) + ); +} + #[test] fn xx_handshake_round_trip_derives_matching_transport() { let crypto = TestCrypto::new(10); @@ -232,22 +260,22 @@ fn xx_handshake_round_trip_derives_matching_transport() { let mut responder_state = XxHandshake::new_responder(&crypto, responder.clone()); let m1 = initiator_state - .write_message(&crypto, control_meta(1)) + .write_message(&crypto, handshake_meta(1)) .unwrap(); responder_state.read_message(&crypto, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, control_meta(2)) + .write_message(&crypto, handshake_meta(1)) .unwrap(); initiator_state.read_message(&crypto, &m2).unwrap(); let m3 = initiator_state - .write_message(&crypto, control_meta(3)) + .write_message(&crypto, handshake_meta(1)) .unwrap(); responder_state.read_message(&crypto, &m3).unwrap(); let m4 = responder_state - .write_message(&crypto, control_meta(4)) + .write_message(&crypto, handshake_meta(1)) .unwrap(); initiator_state.read_message(&crypto, &m4).unwrap(); @@ -284,12 +312,12 @@ fn kk_handshake_round_trip_derives_matching_transport() { KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); let m1 = initiator_state - .write_message(&crypto, control_meta(11)) + .write_message(&crypto, handshake_meta(11)) .unwrap(); responder_state.read_message(&crypto, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, control_meta(12)) + .write_message(&crypto, handshake_meta(11)) .unwrap(); initiator_state.read_message(&crypto, &m2).unwrap(); @@ -404,22 +432,22 @@ fn protocol_record_size_breakdown() { let mut xx_responder = XxHandshake::new_responder(&crypto, responder.clone()); let xx1 = xx_initiator - .write_message(&crypto, control_meta(101)) + .write_message(&crypto, handshake_meta(101)) .unwrap(); xx_responder.read_message(&crypto, &xx1).unwrap(); let xx2 = xx_responder - .write_message(&crypto, control_meta(102)) + .write_message(&crypto, handshake_meta(101)) .unwrap(); xx_initiator.read_message(&crypto, &xx2).unwrap(); let xx3 = xx_initiator - .write_message(&crypto, control_meta(103)) + .write_message(&crypto, handshake_meta(101)) .unwrap(); xx_responder.read_message(&crypto, &xx3).unwrap(); let xx4 = xx_responder - .write_message(&crypto, control_meta(104)) + .write_message(&crypto, handshake_meta(101)) .unwrap(); xx_initiator.read_message(&crypto, &xx4).unwrap(); @@ -434,12 +462,12 @@ fn protocol_record_size_breakdown() { KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); let kk1 = kk_initiator - .write_message(&crypto, control_meta(201)) + .write_message(&crypto, handshake_meta(201)) .unwrap(); kk_responder.read_message(&crypto, &kk1).unwrap(); let kk2 = kk_responder - .write_message(&crypto, control_meta(202)) + .write_message(&crypto, handshake_meta(201)) .unwrap(); kk_initiator.read_message(&crypto, &kk2).unwrap(); From 4e99e0e46873388e17456a18f561ae73d664bf3a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 08:38:55 -0400 Subject: [PATCH 057/304] ql-wire: move record seq to public header --- ql-wire/src/encrypted/builder.rs | 20 ++++++------- ql-wire/src/encrypted/mod.rs | 49 ++++++++++++-------------------- ql-wire/src/encrypted_message.rs | 17 +++++------ ql-wire/src/handshake/mod.rs | 10 ++----- ql-wire/src/header.rs | 32 +++++++++++++-------- ql-wire/src/nonce.rs | 6 ++++ ql-wire/src/tests.rs | 31 ++++++++++---------- 7 files changed, 77 insertions(+), 88 deletions(-) diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index e8f01660..6182aad5 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,9 +1,9 @@ use super::{ - push_variable_len, RecordAck, RecordSeq, SessionCloseBody, SessionFrame, SessionFrameKind, - StreamClose, StreamData, StreamWindow, SIZE_LEN, + push_variable_len, RecordAck, SessionCloseBody, SessionFrame, SessionFrameKind, StreamClose, + StreamData, StreamWindow, SIZE_LEN, }; use crate::{ - codec, encrypted_message::EncryptedMessage, ByteChunks, Nonce, QlCrypto, QlSessionRecord, + encrypted_message::EncryptedMessage, ByteChunks, Nonce, QlCrypto, QlSessionRecord, SessionHeader, SessionKey, }; @@ -14,12 +14,8 @@ pub struct SessionRecordBuilder { } impl SessionRecordBuilder { - pub const HEADER_LEN: usize = std::mem::size_of::(); - - pub fn new(seq: RecordSeq, max_capacity: usize) -> Self { - let max_capacity = max_capacity.max(Self::HEADER_LEN); - let mut bytes = Vec::with_capacity(max_capacity); - codec::push_u64(&mut bytes, seq.0); + pub fn new(max_capacity: usize) -> Self { + let bytes = Vec::with_capacity(max_capacity); Self { max_capacity, bytes, @@ -35,7 +31,7 @@ impl SessionRecordBuilder { } pub fn is_empty(&self) -> bool { - self.bytes.len() == Self::HEADER_LEN + self.bytes.is_empty() } pub fn remaining_capacity(&self) -> usize { @@ -126,10 +122,10 @@ impl SessionRecordBuilder { crypto: &impl QlCrypto, header: SessionHeader, session_key: &SessionKey, - nonce: Nonce, ) -> QlSessionRecord> { let aad = header.aad(); - let encrypted = EncryptedMessage::encrypt(crypto, session_key, self.bytes, &aad, nonce); + let nonce = Nonce::from_counter(header.seq.0); + let encrypted = EncryptedMessage::encrypt(crypto, session_key, self.bytes, &nonce, &aad); QlSessionRecord { header, payload: encrypted, diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index a73ce7d2..d423434c 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,8 +1,8 @@ use std::mem::size_of; use crate::{ - codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, QlCrypto, QlSessionRecord, - SessionHeader, SessionKey, WireError, + codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, + QlSessionRecord, SessionHeader, SessionKey, WireError, }; mod ack; @@ -24,13 +24,8 @@ pub use stream_window::*; #[repr(transparent)] pub struct StreamId(pub u32); -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[repr(transparent)] -pub struct RecordSeq(pub u64); - #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecord { - pub seq: RecordSeq, pub frames: Vec, } @@ -81,38 +76,30 @@ impl TryFrom for SessionFrameKind { } impl SessionRecord { - pub const HEADER_LEN: usize = size_of::(); - - pub fn parse(bytes: &[u8]) -> Result<(RecordSeq, SessionFrameIter<'_>), WireError> { - let mut reader = codec::Reader::new(bytes); - let seq = RecordSeq(reader.take_u64()?); - Ok(( - seq, - SessionFrameIter { - remaining: reader.take_rest(), - }, - )) + pub fn parse(bytes: &[u8]) -> Result, WireError> { + let reader = codec::Reader::new(bytes); + Ok(SessionFrameIter { + remaining: reader.take_rest(), + }) } pub fn decode(bytes: &[u8]) -> Result { - let (seq, frames) = Self::parse(bytes)?; + let frames = Self::parse(bytes)?; let frames = frames .map(|frame| frame.map(SessionFrame::into_owned)) .collect::, _>>()?; - Ok(Self { seq, frames }) + Ok(Self { frames }) } pub fn encoded_len(&self) -> usize { - Self::HEADER_LEN - + self - .frames - .iter() - .map(SessionFrame::encoded_len) - .sum::() + self.frames + .iter() + .map(SessionFrame::encoded_len) + .sum::() } pub fn encode(&self) -> Vec { - let mut out = SessionRecordBuilder::new(self.seq, self.encoded_len()); + let mut out = SessionRecordBuilder::new(self.encoded_len()); for frame in &self.frames { let pushed = out.push_frame(frame); debug_assert!(pushed); @@ -203,14 +190,13 @@ pub fn encrypt_record( header: SessionHeader, session_key: &SessionKey, body: &SessionRecord, - nonce: crate::Nonce, ) -> QlSessionRecord> { - let mut builder = SessionRecordBuilder::new(body.seq, body.encoded_len()); + let mut builder = SessionRecordBuilder::new(body.encoded_len()); for frame in &body.frames { let pushed = builder.push_frame(frame); debug_assert!(pushed); } - builder.encrypt(crypto, header, session_key, nonce) + builder.encrypt(crypto, header, session_key) } pub fn decrypt_record>( @@ -220,7 +206,8 @@ pub fn decrypt_record>( session_key: &SessionKey, ) -> Result { let aad = header.aad(); - encrypted.decrypt_in_place(crypto, session_key, &aad) + let nonce = Nonce::from_counter(header.seq.0); + encrypted.decrypt_in_place(crypto, session_key, &nonce, &aad) } fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireError> { diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index b1a98863..062c4ecd 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -4,21 +4,19 @@ use crate::{ #[derive(Debug, Clone, PartialEq, Eq)] pub struct EncryptedMessage { - pub nonce: Nonce, pub auth: [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], pub ciphertext: B, } impl EncryptedMessage { pub const AUTH_SIZE: usize = ENCRYPTED_MESSAGE_AUTH_SIZE; - pub const HEADER_LEN: usize = Nonce::SIZE + Self::AUTH_SIZE; + pub const HEADER_LEN: usize = Self::AUTH_SIZE; pub fn into_owned(self) -> EncryptedMessage> where B: ByteSlice, { EncryptedMessage { - nonce: self.nonce, auth: self.auth, ciphertext: self.ciphertext.to_vec(), } @@ -29,7 +27,6 @@ impl EncryptedMessage { pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); Ok(Self { - nonce: Nonce(reader.take_array()?), auth: reader.take_array()?, ciphertext: reader.take_rest(), }) @@ -38,7 +35,6 @@ impl EncryptedMessage { impl> EncryptedMessage { pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, &self.nonce.0); codec::push_bytes(out, &self.auth); codec::push_bytes(out, self.ciphertext.as_ref()); } @@ -53,10 +49,11 @@ impl> EncryptedMessage { &self, crypto: &impl QlCrypto, key: &SessionKey, + nonce: &Nonce, aad: &[u8], ) -> Result, WireError> { let mut plaintext = self.ciphertext.as_ref().to_vec(); - if !crypto.aes256_gcm_decrypt(key, &self.nonce, aad, &mut plaintext, &self.auth) { + if !crypto.aes256_gcm_decrypt(key, nonce, aad, &mut plaintext, &self.auth) { return Err(WireError::DecryptFailed); } Ok(plaintext) @@ -68,10 +65,11 @@ impl> EncryptedMessage { mut self, crypto: &impl QlCrypto, key: &SessionKey, + nonce: &Nonce, aad: &[u8], ) -> Result { let ciphertext = self.ciphertext.as_mut(); - if !crypto.aes256_gcm_decrypt(key, &self.nonce, aad, ciphertext, &self.auth) { + if !crypto.aes256_gcm_decrypt(key, nonce, aad, ciphertext, &self.auth) { return Err(WireError::DecryptFailed); } Ok(self.ciphertext) @@ -83,12 +81,11 @@ impl EncryptedMessage> { crypto: &impl QlCrypto, key: &SessionKey, mut plaintext: Vec, + nonce: &Nonce, aad: &[u8], - nonce: Nonce, ) -> Self { - let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut plaintext); + let auth = crypto.aes256_gcm_encrypt(key, nonce, aad, &mut plaintext); Self { - nonce, auth, ciphertext: plaintext, } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index f343bfe8..309f3b05 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -128,7 +128,7 @@ impl CipherState { plaintext: &[u8], ) -> Result, WireError> { let key = self.key.as_ref().ok_or(WireError::InvalidState)?; - let nonce = noise_nonce(self.nonce); + let nonce = Nonce::from_counter(self.nonce); let mut ciphertext = plaintext.to_vec(); let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut ciphertext); self.nonce = self.nonce.wrapping_add(1); @@ -149,7 +149,7 @@ impl CipherState { let (ciphertext, auth) = ciphertext.split_at(split); let mut plaintext = ciphertext.to_vec(); let key = self.key.as_ref().ok_or(WireError::InvalidState)?; - let nonce = noise_nonce(self.nonce); + let nonce = Nonce::from_counter(self.nonce); let mut auth_tag = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; auth_tag.copy_from_slice(auth); if !crypto.aes256_gcm_decrypt(key, &nonce, aad, &mut plaintext, &auth_tag) { @@ -391,12 +391,6 @@ fn derive_connection_ids( ) } -fn noise_nonce(counter: u64) -> Nonce { - let mut nonce = [0u8; Nonce::SIZE]; - nonce[4..].copy_from_slice(&counter.to_le_bytes()); - Nonce(nonce) -} - fn hkdf2( crypto: &impl QlCrypto, chaining_key: &[u8; 32], diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index 8f09d47d..845697a3 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -1,5 +1,21 @@ use crate::{codec, QL_WIRE_VERSION, XID}; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HandshakeHeader { + pub sender: XID, + pub recipient: XID, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SessionHeader { + pub connection_id: ConnectionId, + pub seq: RecordSeq, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RecordSeq(pub u64); + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] pub struct ConnectionId(pub [u8; Self::SIZE]); @@ -16,17 +32,6 @@ impl ConnectionId { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct HandshakeHeader { - pub sender: XID, - pub recipient: XID, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct SessionHeader { - pub connection_id: ConnectionId, -} - impl HandshakeHeader { pub const ENCODED_LEN: usize = XID::SIZE * 2; @@ -53,10 +58,11 @@ impl HandshakeHeader { } impl SessionHeader { - pub const ENCODED_LEN: usize = ConnectionId::SIZE; + pub const ENCODED_LEN: usize = ConnectionId::SIZE + core::mem::size_of::(); pub fn encode_into(&self, out: &mut Vec) { codec::push_bytes(out, self.connection_id.as_bytes()); + codec::push_u64(out, self.seq.0); } pub fn decode(bytes: &[u8]) -> Result { @@ -71,6 +77,7 @@ impl SessionHeader { ) -> Result { Ok(Self { connection_id: ConnectionId::from_data(reader.take_array()?), + seq: RecordSeq(reader.take_u64()?), }) } @@ -80,6 +87,7 @@ impl SessionHeader { codec::append_field(&mut aad, b"wire-version", &[QL_WIRE_VERSION]); codec::append_field(&mut aad, b"record-kind", b"session"); codec::append_field(&mut aad, b"connection-id", self.connection_id.as_bytes()); + codec::append_field(&mut aad, b"record-seq", &self.seq.0.to_le_bytes()); aad } } diff --git a/ql-wire/src/nonce.rs b/ql-wire/src/nonce.rs index fd913400..c7e6d793 100644 --- a/ql-wire/src/nonce.rs +++ b/ql-wire/src/nonce.rs @@ -4,4 +4,10 @@ pub struct Nonce(pub [u8; Self::SIZE]); impl Nonce { pub const SIZE: usize = 12; + + pub fn from_counter(counter: u64) -> Self { + let mut nonce = [0u8; Self::SIZE]; + nonce[4..].copy_from_slice(&counter.to_le_bytes()); + Self(nonce) + } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 9fb6ab08..c348296a 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -347,9 +347,9 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let crypto = TestCrypto::new(30); let header = SessionHeader { connection_id: ConnectionId::from_data([0x44; ConnectionId::SIZE]), + seq: RecordSeq(11), }; let body = SessionRecord { - seq: RecordSeq(11), frames: vec![ SessionFrame::Ping, SessionFrame::Ack(RecordAck { @@ -379,13 +379,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { ], }; let session_key = SessionKey::from_data([7; SessionKey::SIZE]); - let record = encrypted::encrypt_record( - &crypto, - header, - &session_key, - &body, - Nonce([8; Nonce::SIZE]), - ); + let record = encrypted::encrypt_record(&crypto, header, &session_key, &body); let bytes = record.encode(); let decoded = QlRecord::decode(&bytes).unwrap(); @@ -404,9 +398,19 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let wrong_header = SessionHeader { connection_id: ConnectionId::from_data([0x99; ConnectionId::SIZE]), + seq: header.seq, }; assert_eq!( - encrypted::decrypt_record(&crypto, &wrong_header, encrypted, &session_key), + encrypted::decrypt_record(&crypto, &wrong_header, encrypted.clone(), &session_key), + Err(WireError::DecryptFailed) + ); + + let wrong_seq_header = SessionHeader { + connection_id: header.connection_id, + seq: RecordSeq(header.seq.0 + 1), + }; + assert_eq!( + encrypted::decrypt_record(&crypto, &wrong_seq_header, encrypted, &session_key), Err(WireError::DecryptFailed) ); } @@ -479,22 +483,21 @@ fn protocol_record_size_breakdown() { &crypto, SessionHeader { connection_id: session.tx_connection_id, + seq: RecordSeq(1), }, &session.tx_key, &SessionRecord { - seq: RecordSeq(1), frames: vec![SessionFrame::Ping], }, - Nonce([0x41; Nonce::SIZE]), ); let session_stream_empty = encrypted::encrypt_record( &crypto, SessionHeader { connection_id: session.tx_connection_id, + seq: RecordSeq(2), }, &session.tx_key, &SessionRecord { - seq: RecordSeq(2), frames: vec![SessionFrame::StreamData(StreamData { stream_id: StreamId(1), offset: 0, @@ -502,21 +505,19 @@ fn protocol_record_size_breakdown() { bytes: Vec::new(), })], }, - Nonce([0x42; Nonce::SIZE]), ); let session_close = encrypted::encrypt_record( &crypto, SessionHeader { connection_id: session.tx_connection_id, + seq: RecordSeq(3), }, &session.tx_key, &SessionRecord { - seq: RecordSeq(3), frames: vec![SessionFrame::Close(SessionCloseBody { code: CloseCode::PROTOCOL, })], }, - Nonce([0x43; Nonce::SIZE]), ); print_size("ql-wire peer bundle", initiator.bundle().encode().len()); From b496f1e5ce729513d5c3fed2d6682fe063be7614 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 09:02:25 -0400 Subject: [PATCH 058/304] ql: verify handshake record contents --- ql-wire/src/handshake/kk.rs | 41 +++++++++-- ql-wire/src/handshake/mod.rs | 23 +++--- ql-wire/src/handshake/xx.rs | 77 ++++++++++++++++---- ql-wire/src/tests.rs | 134 ++++++++++++++++++++++++++--------- 4 files changed, 213 insertions(+), 62 deletions(-) diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index d79823ad..02c254a3 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -1,10 +1,13 @@ use super::{ decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, - mix_hash_handshake_meta, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, + mix_hash_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, ENCRYPTED_MLKEM_CIPHERTEXT_LEN, }; -use crate::{codec, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; +use crate::{ + codec, HandshakeHeader, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, + QlIdentity, WireError, +}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Kk1 { @@ -140,12 +143,19 @@ impl KkHandshake { pub fn write_message( &mut self, crypto: &impl QlCrypto, + header: HandshakeHeader, meta: HandshakeMeta, ) -> Result { match self.step { KkStep::Send1 => { initialize_handshake_meta(&mut self.handshake_meta, meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"kk1", &meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk1, + &meta, + ); let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); self.symmetric @@ -167,7 +177,13 @@ impl KkHandshake { } KkStep::Send2 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"kk2", &meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk2, + &meta, + ); let remote_ephemeral = self .remote_ephemeral .clone() @@ -198,12 +214,19 @@ impl KkHandshake { pub fn read_message( &mut self, crypto: &impl QlCrypto, + header: HandshakeHeader, message: &KkMessage, ) -> Result<(), WireError> { match (&self.step, message) { (KkStep::Recv1, KkMessage::Message1(message)) => { initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"kk1", &message.meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk1, + &message.meta, + ); self.symmetric .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; let skem_secret = crypto @@ -218,7 +241,13 @@ impl KkHandshake { } (KkStep::Recv2, KkMessage::Message2(message)) => { require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"kk2", &message.meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk2, + &message.meta, + ); let local_ephemeral = self .local_ephemeral .as_ref() diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 309f3b05..ffaf2bac 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,21 +1,22 @@ use crate::{ - codec, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, Nonce, PeerBundle, - QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, + codec, ConnectionId, HandshakeHeader, HandshakeKind, MlKemCiphertext, MlKemKeyPair, + MlKemPublicKey, Nonce, PeerBundle, QlCrypto, SessionKey, WireError, + ENCRYPTED_MESSAGE_AUTH_SIZE, }; -mod meta; mod kk; +mod meta; mod xx; -pub use meta::{HandshakeId, HandshakeMeta}; pub use kk::{Kk1, Kk2, KkHandshake, KkMessage}; +pub use meta::{HandshakeId, HandshakeMeta}; pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake, XxMessage}; const SHA256_BLOCK_LEN: usize = 64; const PROTOCOL_XX: &[u8] = b"ql-wire:pq-xx:v1"; const PROTOCOL_KK: &[u8] = b"ql-wire:pq-kk:v1"; const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; -const HANDSHAKE_META_DOMAIN: &[u8] = b"ql-wire:handshake-meta:v1"; +const HANDSHAKE_PREAMBLE_DOMAIN: &[u8] = b"ql-wire:handshake-preamble:v1"; pub const ENCRYPTED_MLKEM_CIPHERTEXT_LEN: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; @@ -265,15 +266,19 @@ fn mix_hash_ephemeral( symmetric.mix_hash(crypto, public.mlkem_public_key.as_bytes()); } -fn mix_hash_handshake_meta( +fn mix_hash_handshake( symmetric: &mut SymmetricState, crypto: &impl QlCrypto, - message_name: &[u8], + header: HandshakeHeader, + kind: HandshakeKind, meta: &HandshakeMeta, ) { + let mut encoded_header = Vec::with_capacity(HandshakeHeader::ENCODED_LEN); + header.encode_into(&mut encoded_header); let encoded = meta.encode(); - symmetric.mix_hash(crypto, HANDSHAKE_META_DOMAIN); - symmetric.mix_hash(crypto, message_name); + symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); + symmetric.mix_hash(crypto, &encoded_header); + symmetric.mix_hash(crypto, &[kind as u8]); symmetric.mix_hash(crypto, &encoded); } diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index 6563274a..b1bdef76 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -1,11 +1,14 @@ use super::{ decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, finalize_handshake, generate_ephemeral_keypair, initialize_handshake_meta, mix_hash_ephemeral, - mix_hash_handshake_meta, require_handshake_meta, EncryptedMlKemCiphertext, - EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, - SymmetricState, ENCRYPTED_MLKEM_CIPHERTEXT_LEN, ENCRYPTED_PEER_BUNDLE_LEN, PROTOCOL_XX, + mix_hash_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EncryptedPeerBundle, + EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, + ENCRYPTED_MLKEM_CIPHERTEXT_LEN, ENCRYPTED_PEER_BUNDLE_LEN, PROTOCOL_XX, +}; +use crate::{ + codec, HandshakeHeader, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, + QlIdentity, WireError, }; -use crate::{codec, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, WireError}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx1 { @@ -186,12 +189,19 @@ impl XxHandshake { pub fn write_message( &mut self, crypto: &impl QlCrypto, + header: HandshakeHeader, meta: HandshakeMeta, ) -> Result { match self.step { XxStep::Send1 => { initialize_handshake_meta(&mut self.handshake_meta, meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx1", &meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx1, + &meta, + ); let local_ephemeral = generate_ephemeral_keypair(crypto); let public = local_ephemeral.public(); mix_hash_ephemeral(&mut self.symmetric, crypto, &public); @@ -204,7 +214,13 @@ impl XxHandshake { } XxStep::Send2 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx2", &meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx2, + &meta, + ); let remote_ephemeral = self .remote_ephemeral .clone() @@ -226,7 +242,13 @@ impl XxHandshake { } XxStep::Send3 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx3", &meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx3, + &meta, + ); let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); @@ -247,7 +269,13 @@ impl XxHandshake { } XxStep::Send4 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx4", &meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx4, + &meta, + ); let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); @@ -269,12 +297,19 @@ impl XxHandshake { pub fn read_message( &mut self, crypto: &impl QlCrypto, + header: HandshakeHeader, message: &XxMessage, ) -> Result<(), WireError> { match (&self.step, message) { (XxStep::Recv1, XxMessage::Message1(message)) => { initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx1", &message.meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx1, + &message.meta, + ); mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); self.remote_ephemeral = Some(message.ephemeral.clone()); self.step = XxStep::Send2; @@ -282,7 +317,13 @@ impl XxHandshake { } (XxStep::Recv2, XxMessage::Message2(message)) => { require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx2", &message.meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx2, + &message.meta, + ); let local_ephemeral = self .local_ephemeral .as_ref() @@ -301,7 +342,13 @@ impl XxHandshake { } (XxStep::Recv3, XxMessage::Message3(message)) => { require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx3", &message.meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx3, + &message.meta, + ); let skem_ciphertext = decrypt_mlkem_ciphertext( crypto, &mut self.symmetric, @@ -320,7 +367,13 @@ impl XxHandshake { } (XxStep::Recv4, XxMessage::Message4(message)) => { require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_handshake_meta(&mut self.symmetric, crypto, b"xx4", &message.meta); + mix_hash_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx4, + &message.meta, + ); let skem_ciphertext = decrypt_mlkem_ciphertext( crypto, &mut self.symmetric, diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index c348296a..07eba3ea 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -167,6 +167,13 @@ fn make_identity(crypto: &impl QlCrypto, byte: u8) -> QlIdentity { generate_identity(crypto, xid(byte)) } +fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { + HandshakeHeader { + sender: xid(sender), + recipient: xid(recipient), + } +} + fn xx_record(header: HandshakeHeader, message: XxMessage) -> QlHandshakeRecord { let payload = match message { XxMessage::Message1(message) => HandshakePayload::Xx1(message), @@ -230,14 +237,18 @@ fn xx_handshake_rejects_tampered_handshake_meta() { let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); let mut responder_state = XxHandshake::new_responder(&crypto, responder); + let initiator_header = handshake_header(1, 2); + let responder_header = handshake_header(2, 1); let m1 = initiator_state - .write_message(&crypto, handshake_meta(77)) + .write_message(&crypto, initiator_header, handshake_meta(77)) + .unwrap(); + responder_state + .read_message(&crypto, initiator_header, &m1) .unwrap(); - responder_state.read_message(&crypto, &m1).unwrap(); let mut m2 = responder_state - .write_message(&crypto, handshake_meta(77)) + .write_message(&crypto, responder_header, handshake_meta(77)) .unwrap(); let XxMessage::Message2(message) = &mut m2 else { panic!("expected xx2"); @@ -245,11 +256,39 @@ fn xx_handshake_rejects_tampered_handshake_meta() { message.meta.handshake_id = HandshakeId(78); assert_eq!( - initiator_state.read_message(&crypto, &m2), + initiator_state.read_message(&crypto, responder_header, &m2), Err(WireError::InvalidPayload) ); } +#[test] +fn xx_handshake_rejects_tampered_handshake_header() { + let crypto = TestCrypto::new(10); + let initiator = make_identity(&crypto, 1); + let responder = make_identity(&crypto, 2); + + let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); + let mut responder_state = XxHandshake::new_responder(&crypto, responder); + let initiator_header = handshake_header(1, 2); + let responder_header = handshake_header(2, 1); + + let m1 = initiator_state + .write_message(&crypto, initiator_header, handshake_meta(88)) + .unwrap(); + responder_state + .read_message(&crypto, initiator_header, &m1) + .unwrap(); + + let m2 = responder_state + .write_message(&crypto, responder_header, handshake_meta(88)) + .unwrap(); + + assert_eq!( + initiator_state.read_message(&crypto, handshake_header(9, 1), &m2), + Err(WireError::DecryptFailed) + ); +} + #[test] fn xx_handshake_round_trip_derives_matching_transport() { let crypto = TestCrypto::new(10); @@ -258,26 +297,36 @@ fn xx_handshake_round_trip_derives_matching_transport() { let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator.clone()); let mut responder_state = XxHandshake::new_responder(&crypto, responder.clone()); + let initiator_header = handshake_header(1, 2); + let responder_header = handshake_header(2, 1); let m1 = initiator_state - .write_message(&crypto, handshake_meta(1)) + .write_message(&crypto, initiator_header, handshake_meta(1)) + .unwrap(); + responder_state + .read_message(&crypto, initiator_header, &m1) .unwrap(); - responder_state.read_message(&crypto, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, handshake_meta(1)) + .write_message(&crypto, responder_header, handshake_meta(1)) + .unwrap(); + initiator_state + .read_message(&crypto, responder_header, &m2) .unwrap(); - initiator_state.read_message(&crypto, &m2).unwrap(); let m3 = initiator_state - .write_message(&crypto, handshake_meta(1)) + .write_message(&crypto, initiator_header, handshake_meta(1)) + .unwrap(); + responder_state + .read_message(&crypto, initiator_header, &m3) .unwrap(); - responder_state.read_message(&crypto, &m3).unwrap(); let m4 = responder_state - .write_message(&crypto, handshake_meta(1)) + .write_message(&crypto, responder_header, handshake_meta(1)) + .unwrap(); + initiator_state + .read_message(&crypto, responder_header, &m4) .unwrap(); - initiator_state.read_message(&crypto, &m4).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -310,16 +359,22 @@ fn kk_handshake_round_trip_derives_matching_transport() { KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); let mut responder_state = KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); + let initiator_header = handshake_header(3, 4); + let responder_header = handshake_header(4, 3); let m1 = initiator_state - .write_message(&crypto, handshake_meta(11)) + .write_message(&crypto, initiator_header, handshake_meta(11)) + .unwrap(); + responder_state + .read_message(&crypto, initiator_header, &m1) .unwrap(); - responder_state.read_message(&crypto, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, handshake_meta(11)) + .write_message(&crypto, responder_header, handshake_meta(11)) + .unwrap(); + initiator_state + .read_message(&crypto, responder_header, &m2) .unwrap(); - initiator_state.read_message(&crypto, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -417,13 +472,6 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { #[test] fn protocol_record_size_breakdown() { - fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { - HandshakeHeader { - sender: xid(sender), - recipient: xid(recipient), - } - } - fn print_size(label: &str, size: usize) { println!("{label:<32}: {size} bytes"); } @@ -434,26 +482,36 @@ fn protocol_record_size_breakdown() { let mut xx_initiator = XxHandshake::new_initiator(&crypto, initiator.clone()); let mut xx_responder = XxHandshake::new_responder(&crypto, responder.clone()); + let xx_initiator_header = handshake_header(1, 2); + let xx_responder_header = handshake_header(2, 1); let xx1 = xx_initiator - .write_message(&crypto, handshake_meta(101)) + .write_message(&crypto, xx_initiator_header, handshake_meta(101)) + .unwrap(); + xx_responder + .read_message(&crypto, xx_initiator_header, &xx1) .unwrap(); - xx_responder.read_message(&crypto, &xx1).unwrap(); let xx2 = xx_responder - .write_message(&crypto, handshake_meta(101)) + .write_message(&crypto, xx_responder_header, handshake_meta(101)) + .unwrap(); + xx_initiator + .read_message(&crypto, xx_responder_header, &xx2) .unwrap(); - xx_initiator.read_message(&crypto, &xx2).unwrap(); let xx3 = xx_initiator - .write_message(&crypto, handshake_meta(101)) + .write_message(&crypto, xx_initiator_header, handshake_meta(101)) + .unwrap(); + xx_responder + .read_message(&crypto, xx_initiator_header, &xx3) .unwrap(); - xx_responder.read_message(&crypto, &xx3).unwrap(); let xx4 = xx_responder - .write_message(&crypto, handshake_meta(101)) + .write_message(&crypto, xx_responder_header, handshake_meta(101)) + .unwrap(); + xx_initiator + .read_message(&crypto, xx_responder_header, &xx4) .unwrap(); - xx_initiator.read_message(&crypto, &xx4).unwrap(); let xx1 = xx_record(handshake_header(1, 2), xx1); let xx2 = xx_record(handshake_header(2, 1), xx2); @@ -464,16 +522,22 @@ fn protocol_record_size_breakdown() { KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); let mut kk_responder = KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); + let kk_initiator_header = handshake_header(1, 2); + let kk_responder_header = handshake_header(2, 1); let kk1 = kk_initiator - .write_message(&crypto, handshake_meta(201)) + .write_message(&crypto, kk_initiator_header, handshake_meta(201)) + .unwrap(); + kk_responder + .read_message(&crypto, kk_initiator_header, &kk1) .unwrap(); - kk_responder.read_message(&crypto, &kk1).unwrap(); let kk2 = kk_responder - .write_message(&crypto, handshake_meta(201)) + .write_message(&crypto, kk_responder_header, handshake_meta(201)) + .unwrap(); + kk_initiator + .read_message(&crypto, kk_responder_header, &kk2) .unwrap(); - kk_initiator.read_message(&crypto, &kk2).unwrap(); let kk1 = kk_record(handshake_header(1, 2), kk1); let kk2 = kk_record(handshake_header(2, 1), kk2); From 27dd813aa6d422c1b3fd24f24bbe041eac5c42d4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 09:09:38 -0400 Subject: [PATCH 059/304] ql: verify handshake message not expired --- ql-wire/src/handshake/kk.rs | 3 ++ ql-wire/src/handshake/xx.rs | 5 ++++ ql-wire/src/tests.rs | 59 +++++++++++++++++++++++++++---------- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 02c254a3..827ce8f0 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -215,10 +215,12 @@ impl KkHandshake { &mut self, crypto: &impl QlCrypto, header: HandshakeHeader, + now_seconds: u64, message: &KkMessage, ) -> Result<(), WireError> { match (&self.step, message) { (KkStep::Recv1, KkMessage::Message1(message)) => { + message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; mix_hash_handshake( &mut self.symmetric, @@ -240,6 +242,7 @@ impl KkHandshake { Ok(()) } (KkStep::Recv2, KkMessage::Message2(message)) => { + message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; mix_hash_handshake( &mut self.symmetric, diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index b1bdef76..02ffddfa 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -298,10 +298,12 @@ impl XxHandshake { &mut self, crypto: &impl QlCrypto, header: HandshakeHeader, + now_seconds: u64, message: &XxMessage, ) -> Result<(), WireError> { match (&self.step, message) { (XxStep::Recv1, XxMessage::Message1(message)) => { + message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; mix_hash_handshake( &mut self.symmetric, @@ -316,6 +318,7 @@ impl XxHandshake { Ok(()) } (XxStep::Recv2, XxMessage::Message2(message)) => { + message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; mix_hash_handshake( &mut self.symmetric, @@ -341,6 +344,7 @@ impl XxHandshake { Ok(()) } (XxStep::Recv3, XxMessage::Message3(message)) => { + message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; mix_hash_handshake( &mut self.symmetric, @@ -366,6 +370,7 @@ impl XxHandshake { Ok(()) } (XxStep::Recv4, XxMessage::Message4(message)) => { + message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; mix_hash_handshake( &mut self.symmetric, diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 07eba3ea..fddd00ad 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -244,7 +244,7 @@ fn xx_handshake_rejects_tampered_handshake_meta() { .write_message(&crypto, initiator_header, handshake_meta(77)) .unwrap(); responder_state - .read_message(&crypto, initiator_header, &m1) + .read_message(&crypto, initiator_header, 0, &m1) .unwrap(); let mut m2 = responder_state @@ -256,7 +256,7 @@ fn xx_handshake_rejects_tampered_handshake_meta() { message.meta.handshake_id = HandshakeId(78); assert_eq!( - initiator_state.read_message(&crypto, responder_header, &m2), + initiator_state.read_message(&crypto, responder_header, 0, &m2), Err(WireError::InvalidPayload) ); } @@ -276,7 +276,7 @@ fn xx_handshake_rejects_tampered_handshake_header() { .write_message(&crypto, initiator_header, handshake_meta(88)) .unwrap(); responder_state - .read_message(&crypto, initiator_header, &m1) + .read_message(&crypto, initiator_header, 0, &m1) .unwrap(); let m2 = responder_state @@ -284,11 +284,38 @@ fn xx_handshake_rejects_tampered_handshake_header() { .unwrap(); assert_eq!( - initiator_state.read_message(&crypto, handshake_header(9, 1), &m2), + initiator_state.read_message(&crypto, handshake_header(9, 1), 0, &m2), Err(WireError::DecryptFailed) ); } +#[test] +fn xx_handshake_rejects_expired_message() { + let crypto = TestCrypto::new(11); + let initiator = make_identity(&crypto, 1); + let responder = make_identity(&crypto, 2); + + let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); + let mut responder_state = XxHandshake::new_responder(&crypto, responder); + let initiator_header = handshake_header(1, 2); + + let m1 = initiator_state + .write_message( + &crypto, + initiator_header, + HandshakeMeta { + handshake_id: HandshakeId(90), + valid_until: 5, + }, + ) + .unwrap(); + + assert_eq!( + responder_state.read_message(&crypto, initiator_header, 6, &m1), + Err(WireError::Expired) + ); +} + #[test] fn xx_handshake_round_trip_derives_matching_transport() { let crypto = TestCrypto::new(10); @@ -304,28 +331,28 @@ fn xx_handshake_round_trip_derives_matching_transport() { .write_message(&crypto, initiator_header, handshake_meta(1)) .unwrap(); responder_state - .read_message(&crypto, initiator_header, &m1) + .read_message(&crypto, initiator_header, 0, &m1) .unwrap(); let m2 = responder_state .write_message(&crypto, responder_header, handshake_meta(1)) .unwrap(); initiator_state - .read_message(&crypto, responder_header, &m2) + .read_message(&crypto, responder_header, 0, &m2) .unwrap(); let m3 = initiator_state .write_message(&crypto, initiator_header, handshake_meta(1)) .unwrap(); responder_state - .read_message(&crypto, initiator_header, &m3) + .read_message(&crypto, initiator_header, 0, &m3) .unwrap(); let m4 = responder_state .write_message(&crypto, responder_header, handshake_meta(1)) .unwrap(); initiator_state - .read_message(&crypto, responder_header, &m4) + .read_message(&crypto, responder_header, 0, &m4) .unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); @@ -366,14 +393,14 @@ fn kk_handshake_round_trip_derives_matching_transport() { .write_message(&crypto, initiator_header, handshake_meta(11)) .unwrap(); responder_state - .read_message(&crypto, initiator_header, &m1) + .read_message(&crypto, initiator_header, 0, &m1) .unwrap(); let m2 = responder_state .write_message(&crypto, responder_header, handshake_meta(11)) .unwrap(); initiator_state - .read_message(&crypto, responder_header, &m2) + .read_message(&crypto, responder_header, 0, &m2) .unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); @@ -489,28 +516,28 @@ fn protocol_record_size_breakdown() { .write_message(&crypto, xx_initiator_header, handshake_meta(101)) .unwrap(); xx_responder - .read_message(&crypto, xx_initiator_header, &xx1) + .read_message(&crypto, xx_initiator_header, 0, &xx1) .unwrap(); let xx2 = xx_responder .write_message(&crypto, xx_responder_header, handshake_meta(101)) .unwrap(); xx_initiator - .read_message(&crypto, xx_responder_header, &xx2) + .read_message(&crypto, xx_responder_header, 0, &xx2) .unwrap(); let xx3 = xx_initiator .write_message(&crypto, xx_initiator_header, handshake_meta(101)) .unwrap(); xx_responder - .read_message(&crypto, xx_initiator_header, &xx3) + .read_message(&crypto, xx_initiator_header, 0, &xx3) .unwrap(); let xx4 = xx_responder .write_message(&crypto, xx_responder_header, handshake_meta(101)) .unwrap(); xx_initiator - .read_message(&crypto, xx_responder_header, &xx4) + .read_message(&crypto, xx_responder_header, 0, &xx4) .unwrap(); let xx1 = xx_record(handshake_header(1, 2), xx1); @@ -529,14 +556,14 @@ fn protocol_record_size_breakdown() { .write_message(&crypto, kk_initiator_header, handshake_meta(201)) .unwrap(); kk_responder - .read_message(&crypto, kk_initiator_header, &kk1) + .read_message(&crypto, kk_initiator_header, 0, &kk1) .unwrap(); let kk2 = kk_responder .write_message(&crypto, kk_responder_header, handshake_meta(201)) .unwrap(); kk_initiator - .read_message(&crypto, kk_responder_header, &kk2) + .read_message(&crypto, kk_responder_header, 0, &kk2) .unwrap(); let kk1 = kk_record(handshake_header(1, 2), kk1); From 3cff1ec1bf92d59e4dbc36787c3d82b978ec7dfe Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 09:49:15 -0400 Subject: [PATCH 060/304] ql-fsm: noise handshake --- Cargo.lock | 51 ++ ql-fsm/Cargo.toml | 1 + ql-fsm/src/error.rs | 4 +- ql-fsm/src/implementation/fsm.rs | 134 ++-- ql-fsm/src/implementation/handshake.rs | 677 --------------------- ql-fsm/src/implementation/handshake/kk.rs | 130 ++++ ql-fsm/src/implementation/handshake/mod.rs | 207 +++++++ ql-fsm/src/implementation/handshake/xx.rs | 230 +++++++ ql-fsm/src/implementation/mod.rs | 52 +- ql-fsm/src/implementation/peer.rs | 99 --- ql-fsm/src/lib.rs | 42 +- ql-fsm/src/replay_cache.rs | 8 +- ql-fsm/src/session/mod.rs | 29 +- ql-fsm/src/session/state.rs | 6 +- ql-fsm/src/session/tests.rs | 120 ++-- ql-fsm/src/state.rs | 92 ++- ql-fsm/src/tests/handshake.rs | 315 +++------- ql-fsm/src/tests/mod.rs | 284 ++++++--- ql-fsm/src/tests/session.rs | 119 ++-- ql-wire/Cargo.toml | 1 + ql-wire/src/encrypted/builder.rs | 6 +- ql-wire/src/encrypted/close.rs | 4 +- ql-wire/src/encrypted/mod.rs | 10 +- ql-wire/src/encrypted/stream_close.rs | 9 +- ql-wire/src/tests.rs | 53 +- 25 files changed, 1181 insertions(+), 1502 deletions(-) delete mode 100644 ql-fsm/src/implementation/handshake.rs create mode 100644 ql-fsm/src/implementation/handshake/kk.rs create mode 100644 ql-fsm/src/implementation/handshake/mod.rs create mode 100644 ql-fsm/src/implementation/handshake/xx.rs delete mode 100644 ql-fsm/src/implementation/peer.rs diff --git a/Cargo.lock b/Cargo.lock index b67bce2e..6c47fec1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1499,6 +1499,22 @@ dependencies = [ "hax-lib", ] +[[package]] +name = "libcrux-ml-kem" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aca7de713c6dddcf7aaf76e8ef9dc0097c8d7ce23a8eadf04c8761734714e184" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-sha3", + "libcrux-traits", + "rand 0.9.2", + "tls_codec", +] + [[package]] name = "libcrux-platform" version = "0.0.3" @@ -1517,6 +1533,18 @@ dependencies = [ "hax-lib", ] +[[package]] +name = "libcrux-sha3" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c50f6e04a184511b782c5cc1eb6a227c6d36f2c935e93d698655a93a99696b5" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-traits", +] + [[package]] name = "libcrux-traits" version = "0.0.6" @@ -2101,6 +2129,7 @@ version = "0.1.0" dependencies = [ "indexmap", "libcrux-aesgcm", + "libcrux-ml-kem", "ql-wire", "sha2", "thiserror", @@ -2138,6 +2167,7 @@ name = "ql-wire" version = "0.1.0" dependencies = [ "libcrux-aesgcm", + "libcrux-ml-kem", "sha2", ] @@ -2714,6 +2744,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tls_codec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" +dependencies = [ + "tls_codec_derive", + "zeroize", +] + +[[package]] +name = "tls_codec_derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "tokio" version = "1.47.1" diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml index d3dba528..89c68339 100644 --- a/ql-fsm/Cargo.toml +++ b/ql-fsm/Cargo.toml @@ -12,4 +12,5 @@ thiserror = { version = "2" } [dev-dependencies] libcrux-aesgcm = "0.0.7" +libcrux-ml-kem = "0.0.7" sha2 = "0.10" diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 1642222c..e933a50e 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -7,8 +7,6 @@ use crate::session::StreamError; pub enum QlFsmError { #[error("invalid payload")] InvalidPayload, - #[error("invalid signature")] - InvalidSignature, #[error("expired")] Expired, #[error("decryption failed")] @@ -33,9 +31,9 @@ impl From for QlFsmError { fn from(value: WireError) -> Self { match value { WireError::InvalidPayload => Self::InvalidPayload, - WireError::InvalidSignature => Self::InvalidSignature, WireError::Expired => Self::Expired, WireError::DecryptFailed => Self::DecryptFailed, + WireError::InvalidState => Self::InvalidPayload, } } } diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 21cd2ec2..94a38793 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -1,85 +1,44 @@ use std::time::Instant; -use ql_wire::{self as wire, CloseCode, CloseTarget, Nonce, QlCrypto, QlPayload, StreamId}; +use ql_wire::{self as wire, CloseCode, CloseTarget, QlCrypto, SessionHeader, StreamId}; -use crate::{ - OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId, StreamReadIter, -}; +use crate::{OutboundWrite, QlFsm, QlFsmError, QlSessionEvent, SessionWriteId, StreamReadIter}; pub fn receive( fsm: &mut QlFsm, mut bytes: Vec, crypto: &impl QlCrypto, ) -> Result<(), QlFsmError> { - let wire::QlRecord { header, payload } = wire::QlRecord::parse(&mut bytes[..])?; - - if header.recipient != fsm.identity.xid { - return Err(QlFsmError::InvalidXid); - } - match &payload { - QlPayload::PairRequest(_) => {} - QlPayload::Unpair(_) => { - let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.xid) else { - return Ok(()); - }; - if header.sender != peer { - return Err(QlFsmError::InvalidXid); - } - } - _ => { - let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.xid) else { - return Err(QlFsmError::NoPeerBound); - }; - if header.sender != peer { - return Err(QlFsmError::InvalidXid); + match wire::QlRecord::parse(&mut bytes[..])? { + wire::QlRecord::Handshake(record) => super::handle_handshake_record(fsm, crypto, &record), + wire::QlRecord::Session(record) => { + let (_, transport) = super::peer_transport(fsm).ok_or(QlFsmError::NoSession)?; + if record.header.connection_id != transport.rx_connection_id { + return Err(QlFsmError::InvalidPayload); } - } - } - match payload { - QlPayload::PairRequest(request) => { - super::handle_pair(fsm, crypto, &header, request)?; - } - QlPayload::Unpair(unpair) => { - super::handle_unpair(fsm, crypto, &header, &unpair)?; - } - QlPayload::Hello(hello) => { - super::handle_hello(fsm, crypto, &header, &hello)?; - } - QlPayload::HelloReply(reply) => { - super::handle_hello_reply(fsm, crypto, &header, &reply)?; - } - QlPayload::Confirm(confirm) => { - super::handle_confirm(fsm, crypto, &header, &confirm)?; - } - QlPayload::Ready(ready) => { - super::handle_ready(fsm, crypto, &header, ready)?; - } - QlPayload::Session(encrypted) => { - let Some((_, session_key)) = super::peer_session(fsm) else { - return Err(QlFsmError::NoSession); - }; - let plaintext = wire::decrypt_record(crypto, &header, encrypted, &session_key)?; - let (seq, frames) = wire::SessionRecord::parse(plaintext.as_ref())?; + let plaintext = + wire::decrypt_record(crypto, &record.header, record.payload, &transport.rx_key)?; + let frames = wire::SessionRecord::parse(plaintext.as_ref())?; let mut session_closed = false; - fsm.session.receive(fsm.state.now.instant, seq, frames, { - let session_events = &mut fsm.state.session_events; - |event| { - session_closed |= super::forward_session_event(session_events, event); - } - }); + fsm.session + .receive(fsm.state.now.instant, record.header.seq, frames, { + let session_events = &mut fsm.state.session_events; + |event| { + session_closed |= super::forward_session_event(session_events, event); + } + }); if session_closed { super::apply_session_closed(fsm); } + Ok(()) } } - - Ok(()) } pub fn on_timer(fsm: &mut QlFsm) { super::handle_timer(fsm); - if super::peer_session(fsm).is_some() { + if super::peer_transport(fsm).is_some() { let mut session_closed = false; fsm.session.on_timer(fsm.state.now.instant, { let session_events = &mut fsm.state.session_events; @@ -96,7 +55,7 @@ pub fn on_timer(fsm: &mut QlFsm) { pub fn next_deadline(fsm: &QlFsm) -> Option { [ super::next_handshake_deadline(fsm), - super::peer_session(fsm).and_then(|_| fsm.session.next_deadline()), + super::peer_transport(fsm).and_then(|_| fsm.session.next_deadline()), ] .into_iter() .flatten() @@ -104,9 +63,9 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { } pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { - if let Some(record) = fsm.state.outbound.pop_front() { + if let Some(record) = fsm.state.handshake.take() { return Some(OutboundWrite { - record, + record: wire::QlRecord::Handshake(record), session_write_id: None, }); } @@ -117,26 +76,26 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Option { - fsm.state.events.pop_front() -} - -pub fn take_next_session_event(fsm: &mut QlFsm) -> Option { - fsm.state.session_events.pop_front() -} - pub fn open_stream(fsm: &mut QlFsm) -> Result { ensure_peer_bound(fsm)?; Ok(fsm.session.open_stream()?) @@ -228,10 +176,6 @@ pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { Ok(fsm.session.queue_ping()?) } -pub fn unpair(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option>> { - super::handle_unpair_local(fsm, crypto) -} - fn ensure_peer_bound(fsm: &QlFsm) -> Result<(), QlFsmError> { fsm.peer.as_ref().map(|_| ()).ok_or(QlFsmError::NoPeerBound) } @@ -241,7 +185,7 @@ fn ensure_session_open(fsm: &QlFsm) -> Result<(), QlFsmError> { if fsm .peer .as_ref() - .and_then(|entry| entry.session.session_key()) + .and_then(|entry| entry.session.transport()) .is_none() { return Err(QlFsmError::SessionClosed); diff --git a/ql-fsm/src/implementation/handshake.rs b/ql-fsm/src/implementation/handshake.rs deleted file mode 100644 index dac02825..00000000 --- a/ql-fsm/src/implementation/handshake.rs +++ /dev/null @@ -1,677 +0,0 @@ -use std::{cmp::Ordering, time::Instant}; - -use ql_wire::{ - self as wire, Confirm, Hello, HelloReply, MlDsaPublicKey, Nonce, QlCrypto, QlHeader, - QlPayload, Ready, SessionKey, XID, -}; - -use super::{ - emit_peer_status, enqueue_handshake, fail_pending_connect_session, is_replayed_control, - next_control_meta, -}; -use crate::{ - state::{ConnectionState, HandshakeInitiator, HandshakeResponder, RecentReady}, - Peer, QlFsm, QlFsmError, -}; - -#[derive(Debug)] -enum HelloAction { - StartResponder, - ResendReply { reply: HelloReply }, - Ignore, -} - -#[derive(Debug)] -enum HelloReplyAction { - Advance { - hello: Hello, - initiator_secret: SessionKey, - responder_signing_key: MlDsaPublicKey, - }, - ResendConfirm { - confirm: Confirm, - }, -} - -#[derive(Debug, Clone)] -enum RetryAction { - Hello { peer: XID, hello: Hello }, - HelloReply { peer: XID, reply: HelloReply }, - Confirm { peer: XID, confirm: Confirm }, -} - -pub fn handle_connect(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - start_initiator_handshake(fsm, crypto) -} - -pub fn handle_hello( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - header: &QlHeader, - hello: &Hello, -) -> Result<(), QlFsmError> { - let action = { - let Some(entry) = fsm.peer.as_ref() else { - return Ok(()); - }; - match &entry.session { - ConnectionState::Initiator { - hello: local_hello, .. - } => { - if peer_hello_wins_ref(local_hello, fsm.identity.xid, hello, header.sender) { - HelloAction::StartResponder - } else { - HelloAction::Ignore - } - } - ConnectionState::Responder { - hello: stored, - reply, - stage: HandshakeResponder::WaitingConfirm { .. }, - .. - } => { - if same_hello_ref(stored, hello) { - HelloAction::ResendReply { - reply: reply.clone(), - } - } else { - HelloAction::StartResponder - } - } - ConnectionState::Disconnected | ConnectionState::Connected { .. } => { - HelloAction::StartResponder - } - } - }; - let peer = fsm.peer.as_ref().map(|entry| entry.peer.clone()).unwrap(); - wire::verify_hello( - crypto, - header.sender, - fsm.identity.xid, - &peer.signing_key, - hello, - fsm.state.now.unix_secs, - )?; - - match action { - HelloAction::Ignore => {} - HelloAction::ResendReply { reply } => { - enqueue_handshake(fsm, header.sender, QlPayload::HelloReply(reply)); - } - HelloAction::StartResponder => { - if is_replayed_control(fsm, header.sender, hello.meta) { - return Ok(()); - } - - let reply_meta = next_control_meta(fsm, fsm.config.handshake_timeout); - let (reply, secrets) = wire::respond_hello( - crypto, - &fsm.identity, - peer.xid, - &peer.signing_key, - &peer.encapsulation_key, - hello, - reply_meta, - fsm.state.now.unix_secs, - )?; - - let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; - let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Responder { - hello: hello.clone(), - reply: reply.clone(), - deadline, - stage: HandshakeResponder::WaitingConfirm { - secrets, - retry_count: 0, - retry_at, - }, - }; - } - enqueue_handshake(fsm, header.sender, QlPayload::HelloReply(reply)); - emit_peer_status(fsm); - } - } - - Ok(()) -} - -pub fn handle_hello_reply( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - header: &QlHeader, - reply: &HelloReply, -) -> Result<(), QlFsmError> { - let action = { - let Some(entry) = fsm.peer.as_ref() else { - return Ok(()); - }; - match &entry.session { - ConnectionState::Initiator { - hello, - stage: - HandshakeInitiator::WaitingHelloReply { - initiator_secret, .. - }, - .. - } => HelloReplyAction::Advance { - hello: hello.clone(), - initiator_secret: *initiator_secret, - responder_signing_key: entry.peer.signing_key.clone(), - }, - ConnectionState::Initiator { - stage: - HandshakeInitiator::WaitingReady { - reply: stored, - confirm, - .. - }, - .. - } if same_reply_ref(stored, reply) => HelloReplyAction::ResendConfirm { - confirm: confirm.clone(), - }, - _ => return Ok(()), - } - }; - - match action { - HelloReplyAction::ResendConfirm { confirm } => { - enqueue_handshake(fsm, header.sender, QlPayload::Confirm(confirm)); - } - HelloReplyAction::Advance { - hello, - initiator_secret, - responder_signing_key, - } => { - let confirm_meta = next_control_meta(fsm, fsm.config.handshake_timeout); - let (confirm, session_key) = wire::build_confirm( - crypto, - &fsm.identity, - header.sender, - &responder_signing_key, - &hello, - reply, - &initiator_secret, - confirm_meta, - fsm.state.now.unix_secs, - )?; - - if is_replayed_control(fsm, header.sender, reply.meta) { - return Ok(()); - } - - let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; - let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Initiator { - hello, - deadline, - stage: HandshakeInitiator::WaitingReady { - reply: reply.clone(), - confirm: confirm.clone(), - session_key, - retry_count: 0, - retry_at, - }, - }; - } - enqueue_handshake(fsm, header.sender, QlPayload::Confirm(confirm)); - } - } - - Ok(()) -} - -pub fn handle_confirm( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - header: &QlHeader, - confirm: &Confirm, -) -> Result<(), QlFsmError> { - if let Some(ready) = recent_ready_resend(fsm, crypto, header.sender, confirm) { - enqueue_handshake(fsm, header.sender, QlPayload::Ready(ready)); - return Ok(()); - } - - let outcome = { - let Some(entry) = fsm.peer.as_ref() else { - return Ok(()); - }; - let ConnectionState::Responder { - hello, - reply, - deadline, - stage: HandshakeResponder::WaitingConfirm { secrets, .. }, - } = &entry.session - else { - return Ok(()); - }; - - wire::finalize_confirm( - crypto, - header.sender, - fsm.identity.xid, - &entry.peer.signing_key, - hello, - reply, - confirm, - secrets, - fsm.state.now.unix_secs, - ) - .map(|session_key| (hello.clone(), reply.clone(), *deadline, session_key)) - }; - - let (hello, reply, deadline, session_key) = outcome?; - - if is_replayed_control(fsm, header.sender, confirm.meta) { - return Ok(()); - } - - let ready = wire::build_ready( - crypto, - QlHeader { - sender: fsm.identity.xid, - recipient: header.sender, - }, - &session_key, - next_control_meta(fsm, fsm.config.handshake_timeout), - next_encrypted_nonce(crypto), - ); - - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Connected { - session_key, - recent_ready: Some(RecentReady { - hello, - reply, - ready: ready.clone(), - expires_at: deadline, - }), - }; - } - - enqueue_handshake(fsm, header.sender, QlPayload::Ready(ready)); - emit_peer_status(fsm); - Ok(()) -} - -pub fn handle_ready( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - header: &QlHeader, - ready: Ready<&mut [u8]>, -) -> Result<(), QlFsmError> { - let session_key = { - let Some(entry) = fsm.peer.as_ref() else { - return Ok(()); - }; - match &entry.session { - ConnectionState::Initiator { - stage: HandshakeInitiator::WaitingReady { session_key, .. }, - .. - } => *session_key, - _ => return Ok(()), - } - }; - - let body = wire::decrypt_ready(crypto, header, ready, &session_key, fsm.state.now.unix_secs)?; - if is_replayed_control(fsm, header.sender, body.meta) { - return Ok(()); - } - - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Connected { - session_key, - recent_ready: None, - }; - } - emit_peer_status(fsm); - Ok(()) -} - -pub fn handle_timer(fsm: &mut QlFsm) { - let now = fsm.state.now.instant; - if let Some(ConnectionState::Connected { - recent_ready: Some(recent_ready), - .. - }) = fsm.peer.as_mut().map(|entry| &mut entry.session) - { - if recent_ready.expires_at <= now { - if let Some(entry) = fsm.peer.as_mut() { - if let ConnectionState::Connected { recent_ready, .. } = &mut entry.session { - *recent_ready = None; - } - } - } - } - - let mut retry_action = None; - let mut disconnected = false; - - if let Some(entry) = fsm.peer.as_mut() { - match &mut entry.session { - ConnectionState::Initiator { - hello, - deadline, - stage, - } => { - if *deadline <= now { - entry.session = ConnectionState::Disconnected; - disconnected = true; - } else { - retry_action = handle_initiator_retry( - &entry.peer, - hello, - stage, - now, - fsm.config.handshake_retry_interval, - fsm.config.max_handshake_retries, - ); - if retry_action.is_none() && initiator_retries_exhausted(stage) { - entry.session = ConnectionState::Disconnected; - disconnected = true; - } - } - } - ConnectionState::Responder { - reply, - deadline, - stage, - .. - } => { - if *deadline <= now { - entry.session = ConnectionState::Disconnected; - disconnected = true; - } else { - retry_action = handle_responder_retry( - &entry.peer, - reply, - stage, - now, - fsm.config.handshake_retry_interval, - fsm.config.max_handshake_retries, - ); - if retry_action.is_none() && responder_retries_exhausted(stage) { - entry.session = ConnectionState::Disconnected; - disconnected = true; - } - } - } - ConnectionState::Disconnected | ConnectionState::Connected { .. } => {} - } - } - - if disconnected { - fail_pending_connect_session(fsm, ql_wire::CloseCode::TIMEOUT); - emit_peer_status(fsm); - } - - if let Some(action) = retry_action { - match action { - RetryAction::Hello { peer, hello } => { - enqueue_handshake(fsm, peer, QlPayload::Hello(hello)); - } - RetryAction::HelloReply { peer, reply } => { - enqueue_handshake(fsm, peer, QlPayload::HelloReply(reply)); - } - RetryAction::Confirm { peer, confirm } => { - enqueue_handshake(fsm, peer, QlPayload::Confirm(confirm)); - } - } - } -} - -pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { - let mut deadline = None; - if let Some(entry) = fsm.peer.as_ref() { - match &entry.session { - ConnectionState::Initiator { - deadline: session_deadline, - stage, - .. - } => { - deadline = Some(*session_deadline); - deadline = min_optional(deadline, initiator_retry_at(stage)); - } - ConnectionState::Responder { - deadline: session_deadline, - stage, - .. - } => { - deadline = Some(*session_deadline); - deadline = min_optional(deadline, responder_retry_at(stage)); - } - ConnectionState::Connected { - recent_ready: Some(recent_ready), - .. - } => { - deadline = Some(recent_ready.expires_at); - } - ConnectionState::Disconnected | ConnectionState::Connected { .. } => {} - } - } - deadline -} - -fn start_initiator_handshake(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - let Some(entry) = fsm.peer.as_ref() else { - return Err(QlFsmError::NoPeerBound); - }; - if !matches!(entry.session, ConnectionState::Disconnected) { - return Ok(()); - } - - let peer = entry.peer.clone(); - let meta = next_control_meta(fsm, fsm.config.handshake_timeout); - let (hello, initiator_secret) = wire::build_hello( - crypto, - &fsm.identity, - peer.xid, - &peer.encapsulation_key, - meta, - ); - let deadline = fsm.state.now.instant + fsm.config.handshake_timeout; - let retry_at = Some(fsm.state.now.instant + fsm.config.handshake_retry_interval); - - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Initiator { - hello: hello.clone(), - deadline, - stage: HandshakeInitiator::WaitingHelloReply { - initiator_secret, - retry_count: 0, - retry_at, - }, - }; - } - - enqueue_handshake(fsm, peer.xid, QlPayload::Hello(hello)); - emit_peer_status(fsm); - Ok(()) -} - -fn recent_ready_resend( - fsm: &QlFsm, - crypto: &impl QlCrypto, - peer: XID, - confirm: &Confirm, -) -> Option>> { - let entry = fsm.peer.as_ref()?; - let ConnectionState::Connected { - recent_ready: Some(recent_ready), - .. - } = &entry.session - else { - return None; - }; - if recent_ready.expires_at <= fsm.state.now.instant { - return None; - } - wire::verify_confirm( - crypto, - peer, - fsm.identity.xid, - &entry.peer.signing_key, - &recent_ready.hello, - &recent_ready.reply, - confirm, - fsm.state.now.unix_secs, - ) - .ok()?; - Some(recent_ready.ready.clone()) -} - -fn handle_initiator_retry( - peer: &Peer, - hello: &Hello, - stage: &mut HandshakeInitiator, - now: Instant, - retry_interval: std::time::Duration, - max_retries: u8, -) -> Option { - match stage { - HandshakeInitiator::WaitingHelloReply { - retry_count, - retry_at, - .. - } => { - if retry_due(*retry_at, now) { - if *retry_count >= max_retries { - *retry_at = None; - None - } else { - *retry_count = retry_count.saturating_add(1); - *retry_at = Some(now + retry_interval); - Some(RetryAction::Hello { - peer: peer.xid, - hello: hello.clone(), - }) - } - } else { - None - } - } - HandshakeInitiator::WaitingReady { - confirm, - retry_count, - retry_at, - .. - } => { - if retry_due(*retry_at, now) { - if *retry_count >= max_retries { - *retry_at = None; - None - } else { - *retry_count = retry_count.saturating_add(1); - *retry_at = Some(now + retry_interval); - Some(RetryAction::Confirm { - peer: peer.xid, - confirm: confirm.clone(), - }) - } - } else { - None - } - } - } -} - -fn handle_responder_retry( - peer: &Peer, - reply: &HelloReply, - stage: &mut HandshakeResponder, - now: Instant, - retry_interval: std::time::Duration, - max_retries: u8, -) -> Option { - match stage { - HandshakeResponder::WaitingConfirm { - retry_count, - retry_at, - .. - } => { - if retry_due(*retry_at, now) { - if *retry_count >= max_retries { - *retry_at = None; - None - } else { - *retry_count = retry_count.saturating_add(1); - *retry_at = Some(now + retry_interval); - Some(RetryAction::HelloReply { - peer: peer.xid, - reply: reply.clone(), - }) - } - } else { - None - } - } - } -} - -fn initiator_retries_exhausted(stage: &HandshakeInitiator) -> bool { - match stage { - HandshakeInitiator::WaitingHelloReply { retry_at, .. } - | HandshakeInitiator::WaitingReady { retry_at, .. } => retry_at.is_none(), - } -} - -fn responder_retries_exhausted(stage: &HandshakeResponder) -> bool { - match stage { - HandshakeResponder::WaitingConfirm { retry_at, .. } => retry_at.is_none(), - } -} - -fn initiator_retry_at(stage: &HandshakeInitiator) -> Option { - match stage { - HandshakeInitiator::WaitingHelloReply { retry_at, .. } - | HandshakeInitiator::WaitingReady { retry_at, .. } => *retry_at, - } -} - -fn responder_retry_at(stage: &HandshakeResponder) -> Option { - match stage { - HandshakeResponder::WaitingConfirm { retry_at, .. } => *retry_at, - } -} - -fn same_hello_ref(stored: &Hello, incoming: &Hello) -> bool { - stored.meta.control_id == incoming.meta.control_id && stored.nonce == incoming.nonce -} - -fn same_reply_ref(stored: &HelloReply, incoming: &HelloReply) -> bool { - stored.meta.control_id == incoming.meta.control_id && stored.nonce == incoming.nonce -} - -fn peer_hello_wins_ref( - local_hello: &Hello, - local_sender: XID, - peer_hello: &Hello, - peer_sender: XID, -) -> bool { - match peer_hello.nonce.0.cmp(&local_hello.nonce.0) { - Ordering::Less => true, - Ordering::Greater => false, - Ordering::Equal => peer_sender.0.cmp(&local_sender.0) == Ordering::Less, - } -} - -fn next_encrypted_nonce(crypto: &impl QlCrypto) -> Nonce { - let mut bytes = [0u8; Nonce::SIZE]; - crypto.fill_random_bytes(&mut bytes); - Nonce(bytes) -} - -fn retry_due(retry_at: Option, now: Instant) -> bool { - retry_at.is_some_and(|deadline| deadline <= now) -} - -fn min_optional(current: Option, other: Option) -> Option { - match (current, other) { - (Some(left), Some(right)) => Some(left.min(right)), - (Some(left), None) => Some(left), - (None, Some(right)) => Some(right), - (None, None) => None, - } -} diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs new file mode 100644 index 00000000..3b2162fe --- /dev/null +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -0,0 +1,130 @@ +use ql_wire::{ + self as wire, HandshakeHeader, HandshakePayload, Kk1, Kk2, KkMessage, PeerBundle, QlCrypto, + WireError, XID, +}; + +use super::{ + ensure_bound_peer, ensure_bound_peer_with_bundle, finish_handshake, + reset_connected_session_if_needed, should_ignore_inbound_handshake_start, +}; +use crate::{ + implementation::{emit_peer_status, enqueue_handshake, is_replayed_handshake_start}, + state::{ConnectionState, HandshakeMode, HandshakeState, SessionTransport}, + QlFsm, QlFsmError, +}; + +pub fn start_initiator( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + peer: XID, + bundle: PeerBundle, +) -> Result<(), QlFsmError> { + let header = HandshakeHeader { + sender: fsm.identity.xid, + recipient: peer, + }; + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::KkHandshake::new_initiator(crypto, fsm.identity.clone(), bundle); + let message = handshake.write_message(crypto, header, meta)?; + let payload = kk_payload(message); + let initial_ephemeral = match &payload { + HandshakePayload::Kk1(message) => Some(message.ephemeral.clone()), + _ => None, + }; + + if let Some(entry) = fsm.peer.as_mut() { + entry.session = ConnectionState::Handshaking(HandshakeState { + id: meta.handshake_id, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + mode: HandshakeMode::KkInitiator(handshake), + initial_ephemeral, + }); + } + enqueue_handshake(fsm, peer, payload); + emit_peer_status(fsm); + Ok(()) +} + +pub fn handle_kk1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + header: HandshakeHeader, + message: &Kk1, +) -> Result<(), QlFsmError> { + if should_ignore_inbound_handshake_start(fsm, header.sender, false, &message.ephemeral) { + return Ok(()); + } + + if is_replayed_handshake_start(fsm, header.sender, message.meta) { + return Ok(()); + } + ensure_bound_peer_with_bundle(fsm, header.sender)?; + reset_connected_session_if_needed(fsm); + + let bundle = fsm + .peer + .as_ref() + .and_then(|entry| entry.peer.bundle.clone()) + .ok_or(QlFsmError::NoPeerBound)?; + let mut handshake = wire::KkHandshake::new_responder(crypto, fsm.identity.clone(), bundle); + handshake.read_message( + crypto, + header, + fsm.state.now.unix_secs, + &KkMessage::Message1(message.clone()), + )?; + let outbound = handshake.write_message( + crypto, + HandshakeHeader { + sender: fsm.identity.xid, + recipient: header.sender, + }, + message.meta, + )?; + let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, header.sender, kk_payload(outbound)); + Ok(()) +} + +pub fn handle_kk2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + header: HandshakeHeader, + message: &Kk2, +) -> Result<(), QlFsmError> { + ensure_bound_peer(fsm, header.sender)?; + let session = match fsm.peer.as_ref() { + Some(entry) => entry.session.clone(), + None => return Ok(()), + }; + let ConnectionState::Handshaking(HandshakeState { + mode: HandshakeMode::KkInitiator(mut handshake), + .. + }) = session + else { + return Ok(()); + }; + + match handshake.read_message( + crypto, + header, + fsm.state.now.unix_secs, + &KkMessage::Message2(message.clone()), + ) { + Ok(()) => {} + Err(WireError::InvalidState) => return Ok(()), + Err(error) => return Err(error.into()), + } + + let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle) +} + +fn kk_payload(message: KkMessage) -> HandshakePayload { + match message { + KkMessage::Message1(message) => HandshakePayload::Kk1(message), + KkMessage::Message2(message) => HandshakePayload::Kk2(message), + } +} diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs new file mode 100644 index 00000000..4c3cc0fe --- /dev/null +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -0,0 +1,207 @@ +mod kk; +mod xx; + +use std::cmp::Ordering; + +use ql_wire::{ + self as wire, EphemeralPublicKey, HandshakeHeader, HandshakeMeta, HandshakePayload, QlCrypto, + QlHandshakeRecord, XID, +}; + +use super::{emit_peer_status, fail_pending_connect_session, reset_session}; +use crate::{ + state::{ConnectionState, HandshakeMode, HandshakeState, SessionTransport}, + Peer, QlFsm, QlFsmError, +}; + +pub fn handle_connect(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.clone()) else { + return Err(QlFsmError::NoPeerBound); + }; + if !matches!( + fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Disconnected) + ) { + return Ok(()); + } + + match &peer.bundle { + Some(bundle) => kk::start_initiator(fsm, crypto, peer.xid, bundle.clone()), + None => xx::start_initiator(fsm, crypto, peer.xid), + } +} + +pub fn next_handshake_meta(fsm: &mut QlFsm) -> wire::HandshakeMeta { + let handshake_id = wire::HandshakeId(fsm.state.next_control_id); + fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); + wire::HandshakeMeta { + handshake_id, + valid_until: super::deadline_after_secs( + fsm.state.now.unix_secs, + fsm.config.handshake_timeout, + ), + } +} + +pub fn enqueue_handshake(fsm: &mut QlFsm, peer: XID, payload: HandshakePayload) { + debug_assert!(fsm.state.handshake.is_none()); + fsm.state.handshake = Some(QlHandshakeRecord { + header: HandshakeHeader { + sender: fsm.identity.xid, + recipient: peer, + }, + payload, + }); +} + +pub fn is_replayed_handshake_start(fsm: &mut QlFsm, peer: XID, meta: HandshakeMeta) -> bool { + fsm.state + .replay_cache + .check_and_store_valid_until(peer, meta, fsm.state.now.unix_secs) +} + +pub fn handle_handshake_record( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + record: &QlHandshakeRecord, +) -> Result<(), QlFsmError> { + if record.header.recipient != fsm.identity.xid { + return Err(QlFsmError::InvalidXid); + } + + match &record.payload { + HandshakePayload::Xx1(message) => xx::handle_xx1(fsm, crypto, record.header, message), + HandshakePayload::Xx2(message) => xx::handle_xx2(fsm, crypto, record.header, message), + HandshakePayload::Xx3(message) => xx::handle_xx3(fsm, crypto, record.header, message), + HandshakePayload::Xx4(message) => xx::handle_xx4(fsm, crypto, record.header, message), + HandshakePayload::Kk1(message) => kk::handle_kk1(fsm, crypto, record.header, message), + HandshakePayload::Kk2(message) => kk::handle_kk2(fsm, crypto, record.header, message), + } +} + +pub fn handle_timer(fsm: &mut QlFsm) { + let timed_out = matches!( + fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Handshaking(HandshakeState { deadline, .. })) + if *deadline <= fsm.state.now.instant + ); + + if !timed_out { + return; + } + + if let Some(entry) = fsm.peer.as_mut() { + entry.session = ConnectionState::Disconnected; + } + fsm.state.handshake = None; + fail_pending_connect_session(fsm, ql_wire::CloseCode::TIMEOUT); + emit_peer_status(fsm); +} + +pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { + match fsm.peer.as_ref().map(|entry| &entry.session) { + Some(ConnectionState::Handshaking(HandshakeState { deadline, .. })) => Some(*deadline), + _ => None, + } +} + +fn ensure_or_bind_peer( + fsm: &mut QlFsm, + xid: XID, + bundle: Option, +) -> Result<(), QlFsmError> { + match fsm.peer.as_ref() { + Some(entry) if entry.peer.xid == xid => Ok(()), + Some(_) => Err(QlFsmError::InvalidXid), + None => { + super::handle_bind_peer(fsm, Peer { xid, bundle }); + Ok(()) + } + } +} + +fn ensure_bound_peer(fsm: &QlFsm, xid: XID) -> Result<(), QlFsmError> { + match fsm.peer.as_ref() { + Some(entry) if entry.peer.xid == xid => Ok(()), + Some(_) => Err(QlFsmError::InvalidXid), + None => Ok(()), + } +} + +fn ensure_bound_peer_with_bundle(fsm: &QlFsm, xid: XID) -> Result<(), QlFsmError> { + match fsm.peer.as_ref() { + Some(entry) if entry.peer.xid == xid && entry.peer.bundle.is_some() => Ok(()), + Some(entry) if entry.peer.xid == xid => Err(QlFsmError::InvalidPayload), + Some(_) => Err(QlFsmError::InvalidXid), + None => Err(QlFsmError::NoPeerBound), + } +} + +fn finish_handshake( + fsm: &mut QlFsm, + transport: SessionTransport, + remote_bundle: wire::PeerBundle, +) -> Result<(), QlFsmError> { + let Some(entry) = fsm.peer.as_mut() else { + return Err(QlFsmError::NoPeerBound); + }; + + match &entry.peer.bundle { + Some(existing) if existing != &remote_bundle => return Err(QlFsmError::InvalidPayload), + Some(_) => {} + None => entry.peer.bundle = Some(remote_bundle), + } + + entry.session = ConnectionState::Connected(transport); + emit_peer_status(fsm); + Ok(()) +} + +fn reset_connected_session_if_needed(fsm: &mut QlFsm) { + if matches!( + fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected(_)) + ) { + reset_session(fsm); + } +} + +fn should_ignore_inbound_handshake_start( + fsm: &QlFsm, + sender: XID, + inbound_xx: bool, + inbound_ephemeral: &EphemeralPublicKey, +) -> bool { + let Some(entry) = fsm.peer.as_ref() else { + return false; + }; + if entry.peer.xid != sender { + return false; + } + + let ConnectionState::Handshaking(HandshakeState { + mode, + initial_ephemeral: Some(local_ephemeral), + .. + }) = &entry.session + else { + return false; + }; + + match (mode, inbound_xx) { + (HandshakeMode::KkInitiator(_), true) => false, + (HandshakeMode::XxInitiator(_), false) => true, + (HandshakeMode::XxInitiator(_), true) | (HandshakeMode::KkInitiator(_), false) => { + match inbound_ephemeral + .mlkem_public_key + .as_bytes() + .cmp(local_ephemeral.mlkem_public_key.as_bytes()) + { + Ordering::Less => false, + Ordering::Greater => true, + Ordering::Equal => sender.0.cmp(&fsm.identity.xid.0) != Ordering::Less, + } + } + _ => false, + } +} diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/implementation/handshake/xx.rs new file mode 100644 index 00000000..c5c66400 --- /dev/null +++ b/ql-fsm/src/implementation/handshake/xx.rs @@ -0,0 +1,230 @@ +use ql_wire::{ + self as wire, HandshakeHeader, HandshakePayload, QlCrypto, WireError, Xx1, Xx2, Xx3, Xx4, + XxMessage, XID, +}; + +use super::{ + ensure_bound_peer, ensure_or_bind_peer, finish_handshake, reset_connected_session_if_needed, + should_ignore_inbound_handshake_start, +}; +use crate::{ + implementation::{emit_peer_status, enqueue_handshake, is_replayed_handshake_start}, + state::{ConnectionState, HandshakeMode, HandshakeState, SessionTransport}, + QlFsm, QlFsmError, +}; + +pub fn start_initiator( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + peer: XID, +) -> Result<(), QlFsmError> { + let header = HandshakeHeader { + sender: fsm.identity.xid, + recipient: peer, + }; + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::XxHandshake::new_initiator(crypto, fsm.identity.clone()); + let message = handshake.write_message(crypto, header, meta)?; + let payload = xx_payload(message); + let initial_ephemeral = match &payload { + HandshakePayload::Xx1(message) => Some(message.ephemeral.clone()), + _ => None, + }; + + if let Some(entry) = fsm.peer.as_mut() { + entry.session = ConnectionState::Handshaking(HandshakeState { + id: meta.handshake_id, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + mode: HandshakeMode::XxInitiator(handshake), + initial_ephemeral, + }); + } + enqueue_handshake(fsm, peer, payload); + emit_peer_status(fsm); + Ok(()) +} + +pub fn handle_xx1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + header: HandshakeHeader, + message: &Xx1, +) -> Result<(), QlFsmError> { + if should_ignore_inbound_handshake_start(fsm, header.sender, true, &message.ephemeral) { + return Ok(()); + } + + if is_replayed_handshake_start(fsm, header.sender, message.meta) { + return Ok(()); + } + ensure_or_bind_peer(fsm, header.sender, None)?; + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::XxHandshake::new_responder(crypto, fsm.identity.clone()); + handshake.read_message( + crypto, + header, + fsm.state.now.unix_secs, + &XxMessage::Message1(message.clone()), + )?; + let outbound = handshake.write_message( + crypto, + HandshakeHeader { + sender: fsm.identity.xid, + recipient: header.sender, + }, + message.meta, + )?; + + if let Some(entry) = fsm.peer.as_mut() { + entry.session = ConnectionState::Handshaking(HandshakeState { + id: message.meta.handshake_id, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + mode: HandshakeMode::XxResponder(handshake), + initial_ephemeral: None, + }); + } + fsm.state.handshake = None; + enqueue_handshake(fsm, header.sender, xx_payload(outbound)); + emit_peer_status(fsm); + Ok(()) +} + +pub fn handle_xx2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + header: HandshakeHeader, + message: &Xx2, +) -> Result<(), QlFsmError> { + ensure_bound_peer(fsm, header.sender)?; + let session = match fsm.peer.as_ref() { + Some(entry) => entry.session.clone(), + None => return Ok(()), + }; + let ConnectionState::Handshaking(HandshakeState { + id, + deadline, + mode: HandshakeMode::XxInitiator(mut handshake), + initial_ephemeral, + }) = session + else { + return Ok(()); + }; + + match handshake.read_message( + crypto, + header, + fsm.state.now.unix_secs, + &XxMessage::Message2(message.clone()), + ) { + Ok(()) => {} + Err(WireError::InvalidState) => return Ok(()), + Err(error) => return Err(error.into()), + } + + let outbound = handshake.write_message( + crypto, + HandshakeHeader { + sender: fsm.identity.xid, + recipient: header.sender, + }, + message.meta, + )?; + if let Some(entry) = fsm.peer.as_mut() { + entry.session = ConnectionState::Handshaking(HandshakeState { + id, + deadline, + mode: HandshakeMode::XxInitiator(handshake), + initial_ephemeral, + }); + } + enqueue_handshake(fsm, header.sender, xx_payload(outbound)); + Ok(()) +} + +pub fn handle_xx3( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + header: HandshakeHeader, + message: &Xx3, +) -> Result<(), QlFsmError> { + ensure_bound_peer(fsm, header.sender)?; + let session = match fsm.peer.as_ref() { + Some(entry) => entry.session.clone(), + None => return Ok(()), + }; + let ConnectionState::Handshaking(HandshakeState { + mode: HandshakeMode::XxResponder(mut handshake), + .. + }) = session + else { + return Ok(()); + }; + + match handshake.read_message( + crypto, + header, + fsm.state.now.unix_secs, + &XxMessage::Message3(message.clone()), + ) { + Ok(()) => {} + Err(WireError::InvalidState) => return Ok(()), + Err(error) => return Err(error.into()), + } + + let outbound = handshake.write_message( + crypto, + HandshakeHeader { + sender: fsm.identity.xid, + recipient: header.sender, + }, + message.meta, + )?; + let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle)?; + enqueue_handshake(fsm, header.sender, xx_payload(outbound)); + Ok(()) +} + +pub fn handle_xx4( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + header: HandshakeHeader, + message: &Xx4, +) -> Result<(), QlFsmError> { + ensure_bound_peer(fsm, header.sender)?; + let session = match fsm.peer.as_ref() { + Some(entry) => entry.session.clone(), + None => return Ok(()), + }; + let ConnectionState::Handshaking(HandshakeState { + mode: HandshakeMode::XxInitiator(mut handshake), + .. + }) = session + else { + return Ok(()); + }; + + match handshake.read_message( + crypto, + header, + fsm.state.now.unix_secs, + &XxMessage::Message4(message.clone()), + ) { + Ok(()) => {} + Err(WireError::InvalidState) => return Ok(()), + Err(error) => return Err(error.into()), + } + + let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle) +} + +fn xx_payload(message: XxMessage) -> HandshakePayload { + match message { + XxMessage::Message1(message) => HandshakePayload::Xx1(message), + XxMessage::Message2(message) => HandshakePayload::Xx2(message), + XxMessage::Message3(message) => HandshakePayload::Xx3(message), + XxMessage::Message4(message) => HandshakePayload::Xx4(message), + } +} diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index efae6dbe..9707ccec 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -1,17 +1,16 @@ mod fsm; mod handshake; -mod peer; use std::{collections::VecDeque, time::Duration}; pub use fsm::*; pub use handshake::*; -pub use peer::*; -use ql_wire::{ControlId, ControlMeta, QlHeader, QlPayload, QlRecord, SessionKey, XID}; +use ql_wire::XID; use crate::{ + state::PeerRecord, session::{state::StreamParity, SessionEvent, SessionFsmConfig}, - QlFsm, QlFsmEvent, QlSessionEvent, + Peer, QlFsm, QlFsmEvent, QlSessionEvent, }; fn emit_peer_status(fsm: &mut QlFsm) { @@ -23,35 +22,10 @@ fn emit_peer_status(fsm: &mut QlFsm) { } } -fn next_control_meta(fsm: &mut QlFsm, lifetime: Duration) -> ControlMeta { - let control_id = ControlId(fsm.state.next_control_id); - fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); - ControlMeta { - control_id, - valid_until: deadline_after_secs(fsm.state.now.unix_secs, lifetime), - } -} - -fn enqueue_handshake(fsm: &mut QlFsm, peer: XID, payload: QlPayload>) { - fsm.state.outbound.push_back(QlRecord { - header: QlHeader { - sender: fsm.identity.xid, - recipient: peer, - }, - payload, - }); -} - -fn is_replayed_control(fsm: &mut QlFsm, peer: XID, meta: ControlMeta) -> bool { - fsm.state - .replay_cache - .check_and_store_valid_until(peer, meta, fsm.state.now.unix_secs) -} - -fn peer_session(fsm: &QlFsm) -> Option<(XID, SessionKey)> { +fn peer_transport(fsm: &QlFsm) -> Option<(XID, crate::state::SessionTransport)> { let entry = fsm.peer.as_ref()?; - let session_key = *entry.session.session_key()?; - Some((entry.peer.xid, session_key)) + let transport = entry.session.transport()?.clone(); + Some((entry.peer.xid, transport)) } fn reset_session(fsm: &mut QlFsm) { @@ -75,14 +49,12 @@ fn reset_session(fsm: &mut QlFsm) { ); } -fn clear_bound_peer(fsm: &mut QlFsm) { - if fsm.peer.take().is_none() { - return; - } - fsm.state.outbound.clear(); +pub fn handle_bind_peer(fsm: &mut QlFsm, peer: Peer) { + fsm.state.handshake = None; + fsm.peer = Some(PeerRecord::new(peer.clone())); reset_session(fsm); - fsm.state.session_events.push_back(QlSessionEvent::Unpaired); - fsm.state.events.push_back(QlFsmEvent::ClearPeer); + fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); + emit_peer_status(fsm); } fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::CloseCode) { @@ -92,7 +64,7 @@ fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::CloseCode) { reset_session(fsm); fsm.state .session_events - .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionCloseBody { + .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { code, })); } diff --git a/ql-fsm/src/implementation/peer.rs b/ql-fsm/src/implementation/peer.rs deleted file mode 100644 index 09fcfcbc..00000000 --- a/ql-fsm/src/implementation/peer.rs +++ /dev/null @@ -1,99 +0,0 @@ -use ql_wire::{self as wire, PairRequestRecord, QlCrypto, QlHeader, Unpair}; - -use super::{ - clear_bound_peer, emit_peer_status, handshake, is_replayed_control, next_control_meta, - reset_session, -}; -use crate::{state::PeerRecord, Peer, QlFsm, QlFsmError, QlFsmEvent}; - -pub fn handle_bind_peer(fsm: &mut QlFsm, peer: Peer) { - bind_peer_record(fsm, peer); -} - -pub fn handle_pair_local(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - let meta = next_control_meta(fsm, fsm.config.control_expiration); - let peer = fsm.peer.as_ref().ok_or(QlFsmError::NoPeerBound)?; - let record = wire::build_pair_request( - crypto, - &fsm.identity, - peer.peer.xid, - &peer.peer.encapsulation_key, - meta, - ); - fsm.state.outbound.push_back(record); - Ok(()) -} - -pub fn handle_unpair_local( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, -) -> Option>> { - let peer = fsm.peer.as_ref()?.peer.clone(); - let meta = next_control_meta(fsm, fsm.config.control_expiration); - let record = wire::build_unpair(crypto, &fsm.identity, peer.xid, meta); - clear_bound_peer(fsm); - Some(record) -} - -pub fn handle_pair( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - header: &QlHeader, - request: PairRequestRecord<&mut [u8]>, -) -> Result<(), QlFsmError> { - let payload = wire::decrypt_pair_request( - crypto, - &fsm.identity, - header, - request, - fsm.state.now.unix_secs, - )?; - let peer = Peer { - xid: payload.xid, - signing_key: payload.signing_pub_key, - encapsulation_key: payload.encapsulation_pub_key, - }; - if is_replayed_control(fsm, peer.xid, payload.meta) { - return Ok(()); - } - - match fsm.peer.as_ref() { - Some(existing) if existing.peer != peer => return Err(QlFsmError::InvalidXid), - Some(_) => {} - None => bind_peer_record(fsm, peer.clone()), - } - - handshake::handle_connect(fsm, crypto) -} - -pub fn handle_unpair( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - header: &QlHeader, - unpair: &Unpair, -) -> Result<(), QlFsmError> { - let Some(entry) = fsm.peer.as_ref() else { - return Ok(()); - }; - - wire::verify_unpair( - crypto, - header, - &entry.peer.signing_key, - unpair, - fsm.state.now.unix_secs, - )?; - if is_replayed_control(fsm, header.sender, unpair.meta) { - return Ok(()); - } - - clear_bound_peer(fsm); - Ok(()) -} - -fn bind_peer_record(fsm: &mut QlFsm, peer: Peer) { - fsm.peer = Some(PeerRecord::new(peer.clone())); - reset_session(fsm); - fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); - emit_peer_status(fsm); -} diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index b64bf5fe..389ad87b 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -3,14 +3,13 @@ //! a caller drives `QlFsm` inside its own event loop //! //! inputs to that loop usually include -//! - app actions like `bind_peer`, `pair`, `connect`, `unpair`, `open_stream`, or `write_stream` +//! - app actions like `bind_peer`, `connect`, `open_stream`, or `write_stream` //! - inbound transport bytes passed to `receive` //! - a deadline expiring, handled by calling `on_timer` //! - transport write results passed to `confirm_session_write` or `reject_session_write` //! //! outputs from `QlFsm` are //! - outbound session and handshake records from `take_next_write` -//! - a best-effort peer unpair record returned directly from `unpair` //! - peer events from `take_next_event` //! - session events from `take_next_session_event` //! @@ -30,8 +29,8 @@ use std::time::{Duration, Instant}; pub use error::QlFsmError; use ql_wire::{ - CloseCode, CloseTarget, MlDsaPublicKey, MlKemPublicKey, QlCrypto, QlIdentity, QlRecord, - SessionCloseBody, StreamClose, StreamId, XID, + CloseCode, CloseTarget, PeerBundle, QlCrypto, QlIdentity, QlRecord, SessionClose, StreamClose, + StreamId, XID, }; pub use session::stream_rx::StreamReadIter; @@ -55,10 +54,8 @@ pub struct FsmTime { pub struct Peer { /// peer xid pub xid: XID, - /// peer signing public key - pub signing_key: MlDsaPublicKey, - /// peer encapsulation public key - pub encapsulation_key: MlKemPublicKey, + /// peer static bundle when known + pub bundle: Option, } /// connection state for the bound peer @@ -108,7 +105,7 @@ pub enum QlSessionEvent { /// the peer requested unpairing Unpaired, /// the encrypted session was closed - SessionClosed(SessionCloseBody), + SessionClosed(SessionClose), } /// handle for a session write returned by `QlFsm::take_next_write` @@ -129,12 +126,6 @@ pub struct OutboundWrite { pub struct QlFsmConfig { /// overall time limit for one handshake attempt pub handshake_timeout: Duration, - /// delay before retrying the current handshake message - pub handshake_retry_interval: Duration, - /// maximum retries for each handshake step - pub max_handshake_retries: u8, - /// how far into the future control messages remain valid - pub control_expiration: Duration, /// delay before sending a pure record ack pub session_record_ack_delay: Duration, /// how long to wait before resending unacked session records @@ -155,9 +146,6 @@ impl Default for QlFsmConfig { fn default() -> Self { Self { handshake_timeout: Duration::from_secs(5), - handshake_retry_interval: Duration::from_millis(750), - max_handshake_retries: 3, - control_expiration: Duration::from_secs(30), session_record_ack_delay: Duration::from_millis(5), session_record_retransmit_timeout: Duration::from_millis(150), session_keepalive_interval: Duration::from_secs(10), @@ -203,7 +191,7 @@ impl QlFsm { state: QlFsmState { replay_cache: ReplayCache::default(), next_control_id: 1, - outbound: Default::default(), + handshake: None, events: Default::default(), session_events: Default::default(), now, @@ -216,12 +204,6 @@ impl QlFsm { implementation::handle_bind_peer(self, peer); } - /// queues a pair request for the bound peer - pub fn pair(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - self.state.now = now; - implementation::handle_pair_local(self, crypto) - } - /// starts or resumes the encrypted session handshake pub fn connect(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; @@ -287,7 +269,7 @@ impl QlFsm { /// returns the next peer-level event pub fn take_next_event(&mut self) -> Option { - implementation::take_next_event(self) + self.state.events.pop_front() } /// opens a new outgoing stream @@ -339,14 +321,8 @@ impl QlFsm { implementation::queue_ping(self) } - /// clears the bound peer locally and returns a best-effort unpair record - pub fn unpair(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Option>> { - self.state.now = now; - implementation::unpair(self, crypto) - } - /// returns the next session or stream event pub fn take_next_session_event(&mut self) -> Option { - implementation::take_next_session_event(self) + self.state.session_events.pop_front() } } diff --git a/ql-fsm/src/replay_cache.rs b/ql-fsm/src/replay_cache.rs index 470d4100..335c75c8 100644 --- a/ql-fsm/src/replay_cache.rs +++ b/ql-fsm/src/replay_cache.rs @@ -1,11 +1,11 @@ use std::collections::{hash_map::Entry, HashMap}; -use ql_wire::{ControlId, ControlMeta, XID}; +use ql_wire::{HandshakeId, HandshakeMeta, XID}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct ReplayKey { peer: XID, - control_id: ControlId, + handshake_id: HandshakeId, } #[derive(Debug, Default)] @@ -17,7 +17,7 @@ impl ReplayCache { pub fn check_and_store_valid_until( &mut self, peer: XID, - meta: ControlMeta, + meta: HandshakeMeta, now_secs: u64, ) -> bool { self.valid_until_by_key @@ -25,7 +25,7 @@ impl ReplayCache { let key = ReplayKey { peer, - control_id: meta.control_id, + handshake_id: meta.handshake_id, }; match self.valid_until_by_key.entry(key) { diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 16571f10..589c901e 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -9,7 +9,7 @@ use std::time::{Duration, Instant}; use indexmap::map::Entry; use ql_wire::{ - CloseCode, CloseTarget, RecordAck, RecordSeq, SessionCloseBody, SessionFrame, + CloseCode, CloseTarget, RecordAck, RecordSeq, SessionClose, SessionFrame, SessionRecordBuilder, StreamClose, StreamData, StreamId, StreamWindow, WireError, }; @@ -60,7 +60,7 @@ pub enum SessionEvent { Finished(StreamId), Closed(StreamClose), WritableClosed(StreamId), - SessionClosed(SessionCloseBody), + SessionClosed(SessionClose), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -263,7 +263,7 @@ impl SessionFsm { Ok(frame) => frame, Err(_) => { self.fail_session( - SessionCloseBody { + SessionClose { code: CloseCode::PROTOCOL, }, &mut emit, @@ -345,7 +345,7 @@ impl SessionFsm { && self.state.last_inbound_at + self.config.peer_timeout <= self.state.now { self.fail_session( - SessionCloseBody { + SessionClose { code: CloseCode::TIMEOUT, }, &mut emit, @@ -400,20 +400,21 @@ impl SessionFsm { }) } - pub fn take_next_write(&mut self, now: Instant) -> Option<(u64, SessionRecordBuilder)> { + pub fn take_next_write(&mut self, now: Instant) -> Option<(u64, RecordSeq, SessionRecordBuilder)> { self.state.now = now; self.collect_timeouts(); let (builder, outbound) = self.build_next_record()?; let write_id = self.state.next_write_id; self.state.next_write_id = self.state.next_write_id.wrapping_add(1); + let seq = outbound.seq; self.state.outbound_records.insert(write_id, outbound); - Some((write_id, builder)) + Some((write_id, seq, builder)) } fn build_next_record(&mut self) -> Option<(SessionRecordBuilder, OutboundRecord)> { let seq = self.state.next_record_seq; - let mut builder = SessionRecordBuilder::new(seq, self.config.record_size); + let mut builder = SessionRecordBuilder::new(self.config.record_size); let mut outbound = OutboundRecord { seq, reliable: Vec::new(), @@ -750,7 +751,7 @@ impl SessionFsm { Entry::Vacant(entry) => { if !self.config.local_parity.remote().matches(stream_id) { self.fail_session( - SessionCloseBody { + SessionClose { code: CloseCode::PROTOCOL, }, emit, @@ -773,7 +774,7 @@ impl SessionFsm { return Ok(()); } self.fail_session( - SessionCloseBody { + SessionClose { code: CloseCode::PROTOCOL, }, emit, @@ -804,7 +805,7 @@ impl SessionFsm { | Err(StreamRxError::TooManyMissingRanges) | Err(StreamRxError::OffsetOverflow) => { self.fail_session( - SessionCloseBody { + SessionClose { code: CloseCode::PROTOCOL, }, emit, @@ -822,7 +823,7 @@ impl SessionFsm { ) -> Result<(), ()> { let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { self.fail_session( - SessionCloseBody { + SessionClose { code: CloseCode::PROTOCOL, }, emit, @@ -850,7 +851,7 @@ impl SessionFsm { Entry::Vacant(entry) => { if !self.config.local_parity.remote().matches(frame.stream_id) { self.fail_session( - SessionCloseBody { + SessionClose { code: CloseCode::PROTOCOL, }, emit, @@ -894,7 +895,7 @@ impl SessionFsm { fn handle_session_close( &mut self, - close: SessionCloseBody, + close: SessionClose, emit: &mut impl FnMut(SessionEvent), ) { if self.state.session_state == SessionState::Closed { @@ -1008,7 +1009,7 @@ impl SessionFsm { } } - fn fail_session(&mut self, close: SessionCloseBody, emit: &mut impl FnMut(SessionEvent)) { + fn fail_session(&mut self, close: SessionClose, emit: &mut impl FnMut(SessionEvent)) { if self.state.session_state == SessionState::Closed { return; } diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index e85d23d1..2fbc4a09 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeSet, time::Instant}; use indexmap::IndexMap; use ql_wire::{ - CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionCloseBody, StreamClose, StreamId, + CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionClose, StreamClose, StreamId, XID, }; @@ -150,7 +150,7 @@ impl StreamState { pub enum ReliableFrame { StreamData(StreamDataManifest), StreamClose(StreamClose), - Close(SessionCloseBody), + Close(SessionClose), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -174,7 +174,7 @@ pub struct OutboundRecord { #[derive(Debug, Clone, Default)] pub struct PendingSessionControl { pub ping: bool, - pub close: Option, + pub close: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index b4dca534..1109b749 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -23,15 +23,15 @@ fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { out } -fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option { - let (write_id, builder) = fsm.take_next_write(now)?; +fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option<(RecordSeq, SessionRecord)> { + let (write_id, seq, builder) = fsm.take_next_write(now)?; fsm.confirm_write(now, write_id); - Some(SessionRecord::decode(builder.bytes()).unwrap()) + Some((seq, SessionRecord::decode(builder.bytes()).unwrap())) } -fn receive_events(fsm: &mut SessionFsm, now: Instant, record: SessionRecord) -> Vec { +fn receive_events(fsm: &mut SessionFsm, now: Instant, seq: RecordSeq, record: SessionRecord) -> Vec { let bytes = record.encode(); - let (seq, frames) = SessionRecord::parse(&bytes).unwrap(); + let frames = SessionRecord::parse(&bytes).unwrap(); let mut events = Vec::new(); fsm.receive(now, seq, frames, |event| events.push(event)); events @@ -44,13 +44,13 @@ fn outbound_record_seq_increments_monotonically() { let stream_id = fsm.open_stream().unwrap(); assert_eq!(fsm.write_stream(stream_id, b"one").unwrap(), 3); - let first = next_outbound(&mut fsm, now).unwrap(); + let (first_seq, _) = next_outbound(&mut fsm, now).unwrap(); assert_eq!(fsm.write_stream(stream_id, b"two").unwrap(), 3); - let second = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + let (second_seq, _) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); - assert_eq!(first.seq, RecordSeq(0)); - assert_eq!(second.seq, RecordSeq(1)); + assert_eq!(first_seq, RecordSeq(0)); + assert_eq!(second_seq, RecordSeq(1)); } #[test] @@ -60,12 +60,12 @@ fn retransmit_uses_new_record_seq() { let stream_id = fsm.open_stream().unwrap(); assert_eq!(fsm.write_stream(stream_id, b"retry").unwrap(), 5); - let first = next_outbound(&mut fsm, now).unwrap(); + let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); fsm.on_timer(now + Duration::from_millis(200), |_| {}); - let retried = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); + let (retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); - assert_ne!(first.seq, retried.seq); + assert_ne!(first_seq, retried_seq); assert_eq!(first.frames, retried.frames); } @@ -87,12 +87,16 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { assert_eq!(fsm.write_stream(stream_id_a, &payload_a).unwrap(), 40); assert_eq!(fsm.write_stream(stream_id_b, &payload_b).unwrap(), 40); - let first = next_outbound(&mut fsm, now).unwrap(); - let second = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); - assert_ne!(first.seq, second.seq); + let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); + let (second_seq, _second) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + assert_ne!(first_seq, second_seq); + assert!(first + .frames + .iter() + .any(|frame| matches!(frame, SessionFrame::StreamData(frame) if frame.stream_id == stream_id_a))); assert_eq!(fsm.write_stream(stream_id_b, b"b-2").unwrap(), 3); - let third = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + let (_third_seq, third) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); let stream_ids: Vec<_> = third .frames @@ -106,7 +110,7 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { } #[test] -fn write_stream_is_partial_and_ack_emits_writable() { +fn ack_reopens_write_capacity() { let now = Instant::now(); let mut fsm = SessionFsm::new( SessionFsmConfig { @@ -118,20 +122,22 @@ fn write_stream_is_partial_and_ack_emits_writable() { let stream_id = fsm.open_stream().unwrap(); assert_eq!(fsm.write_stream(stream_id, b"abcd").unwrap(), 4); - assert_eq!(fsm.write_stream(stream_id, b"z").unwrap(), 0); + let (seq, _record) = next_outbound(&mut fsm, now).unwrap(); - let sent = next_outbound(&mut fsm, now).unwrap(); - let ack = SessionRecord { - seq: RecordSeq(99), - frames: vec![SessionFrame::Ack(RecordAck { + let mut events = Vec::new(); + fsm.receive( + now + Duration::from_millis(1), + RecordSeq(9), + std::iter::once(Ok(SessionFrame::Ack(RecordAck { ranges: vec![RecordAckRange { - start: sent.seq.0, - end: sent.seq.0 + 1, + start: seq.0, + end: seq.0 + 1, }], - })], - }; - let events = receive_events(&mut fsm, now + Duration::from_millis(1), ack); - assert_eq!(events, vec![SessionEvent::Writable(stream_id)]); + }))), + |event| events.push(event), + ); + + assert!(events.contains(&SessionEvent::Writable(stream_id))); assert_eq!(fsm.write_stream(stream_id, b"z").unwrap(), 1); } @@ -148,7 +154,6 @@ fn commit_stream_read_is_what_advances_stream_window() { ); let stream_id = StreamId(1); let data = SessionRecord { - seq: RecordSeq(7), frames: vec![SessionFrame::StreamData(StreamData { stream_id, offset: 0, @@ -156,7 +161,7 @@ fn commit_stream_read_is_what_advances_stream_window() { bytes: b"hi".to_vec(), })], }; - let events = receive_events(&mut fsm, now, data); + let events = receive_events(&mut fsm, now, RecordSeq(7), data); assert_eq!( events, vec![ @@ -165,7 +170,7 @@ fn commit_stream_read_is_what_advances_stream_window() { ] ); - let first = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + let (_first_seq, first) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); assert!(matches!(first.frames.as_slice(), [SessionFrame::Ack(_)])); let read = fsm @@ -178,27 +183,49 @@ fn commit_stream_read_is_what_advances_stream_window() { assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); fsm.stream_read_commit(stream_id, 2).unwrap(); - let second = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); + let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); assert!(matches!( second.frames.as_slice(), [SessionFrame::StreamWindow(window)] if window.stream_id == stream_id )); } +#[test] +fn inbound_stream_data_emits_opened_and_readable() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = ql_wire::StreamId(1); + let record = SessionRecord { + frames: vec![SessionFrame::StreamData(ql_wire::StreamData { + stream_id, + offset: 0, + fin: true, + bytes: b"hello".to_vec(), + })], + }; + + let events = receive_events(&mut fsm, now, RecordSeq(0), record); + assert_eq!( + events, + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Readable(stream_id), + SessionEvent::Finished(stream_id) + ] + ); + assert_eq!(read_stream_all(&mut fsm, stream_id), b"hello".to_vec()); +} + #[test] fn remote_stream_close_is_reliable_and_retried() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); - fsm.close_stream( - stream_id, - CloseTarget::Both, - CloseCode::CANCELLED, - ) - .unwrap(); + fsm.close_stream(stream_id, CloseTarget::Both, CloseCode::CANCELLED) + .unwrap(); - let (write_id, builder) = fsm.take_next_write(now).unwrap(); + let (write_id, _seq, builder) = fsm.take_next_write(now).unwrap(); fsm.confirm_write(now, write_id); let first = SessionRecord::decode(builder.bytes()).unwrap(); assert!(matches!( @@ -207,8 +234,7 @@ fn remote_stream_close_is_reliable_and_retried() { )); fsm.on_timer(now + Duration::from_millis(200), |_| {}); - let retried = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); - assert_ne!(first.seq, retried.seq); + let (_retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); assert_eq!(first.frames, retried.frames); } @@ -247,7 +273,6 @@ fn duplicate_stream_data_is_not_redelivered() { let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = StreamId(1); let record = SessionRecord { - seq: RecordSeq(1), frames: vec![SessionFrame::StreamData(StreamData { stream_id, offset: 0, @@ -255,15 +280,8 @@ fn duplicate_stream_data_is_not_redelivered() { bytes: b"hi".to_vec(), })], }; - let _ = receive_events(&mut fsm, now, record.clone()); - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(1), - SessionRecord { - seq: RecordSeq(2), - ..record - }, - ); + let _ = receive_events(&mut fsm, now, RecordSeq(1), record.clone()); + let _ = receive_events(&mut fsm, now + Duration::from_millis(1), RecordSeq(2), record); assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); } diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 2c577b47..8a075744 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -1,75 +1,73 @@ use std::{collections::VecDeque, time::Instant}; -use ql_wire::{Confirm, Hello, HelloReply, QlRecord, Ready, ResponderSecrets, SessionKey}; +use ql_wire::{ + ConnectionId, EphemeralPublicKey, HandshakeId, KkHandshake, PeerBundle, QlHandshakeRecord, + SessionKey, XxHandshake, +}; use crate::{replay_cache::ReplayCache, FsmTime, Peer, PeerStatus, QlFsmEvent, QlSessionEvent}; #[derive(Debug, Clone)] -pub enum HandshakeInitiator { - WaitingHelloReply { - initiator_secret: SessionKey, - retry_count: u8, - retry_at: Option, - }, - WaitingReady { - reply: HelloReply, - confirm: Confirm, - session_key: SessionKey, - retry_count: u8, - retry_at: Option, - }, +pub enum HandshakeMode { + XxInitiator(XxHandshake), + XxResponder(XxHandshake), + KkInitiator(KkHandshake), } #[derive(Debug, Clone)] -pub enum HandshakeResponder { - WaitingConfirm { - secrets: ResponderSecrets, - retry_count: u8, - retry_at: Option, - }, +pub struct HandshakeState { + pub id: HandshakeId, + pub deadline: Instant, + pub mode: HandshakeMode, + pub initial_ephemeral: Option, } -#[derive(Debug, Clone)] -pub struct RecentReady { - pub hello: Hello, - pub reply: HelloReply, - pub ready: Ready>, - pub expires_at: Instant, +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionTransport { + pub tx_key: SessionKey, + pub rx_key: SessionKey, + pub tx_connection_id: ConnectionId, + pub rx_connection_id: ConnectionId, +} + +impl SessionTransport { + pub fn from_finalized(finalized: ql_wire::FinalizedHandshake) -> (Self, PeerBundle) { + ( + Self { + tx_key: finalized.tx_key, + rx_key: finalized.rx_key, + tx_connection_id: finalized.tx_connection_id, + rx_connection_id: finalized.rx_connection_id, + }, + finalized.remote_bundle, + ) + } } #[derive(Debug, Clone)] pub enum ConnectionState { Disconnected, - Initiator { - hello: Hello, - deadline: Instant, - stage: HandshakeInitiator, - }, - Responder { - hello: Hello, - reply: HelloReply, - deadline: Instant, - stage: HandshakeResponder, - }, - Connected { - session_key: SessionKey, - recent_ready: Option, - }, + Handshaking(HandshakeState), + Connected(SessionTransport), } impl ConnectionState { pub fn status(&self) -> PeerStatus { match self { Self::Disconnected => PeerStatus::Disconnected, - Self::Initiator { .. } => PeerStatus::Initiator, - Self::Responder { .. } => PeerStatus::Responder, - Self::Connected { .. } => PeerStatus::Connected, + Self::Handshaking(HandshakeState { mode, .. }) => match mode { + HandshakeMode::XxInitiator(_) | HandshakeMode::KkInitiator(_) => { + PeerStatus::Initiator + } + HandshakeMode::XxResponder(_) => PeerStatus::Responder, + }, + Self::Connected(_) => PeerStatus::Connected, } } - pub fn session_key(&self) -> Option<&SessionKey> { + pub fn transport(&self) -> Option<&SessionTransport> { match self { - Self::Connected { session_key, .. } => Some(session_key), + Self::Connected(transport) => Some(transport), _ => None, } } @@ -93,7 +91,7 @@ impl PeerRecord { pub struct QlFsmState { pub replay_cache: ReplayCache, pub next_control_id: u32, - pub outbound: VecDeque>>, + pub handshake: Option, pub events: VecDeque, pub session_events: VecDeque, pub now: FsmTime, diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 44624b95..5c882d21 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -1,286 +1,158 @@ use std::time::Duration; -use ql_wire::{QlPayload, XID}; +use ql_wire::{HandshakePayload, QlRecord}; use super::*; -use crate::state::{ConnectionState, HandshakeInitiator, HandshakeResponder}; +use crate::state::ConnectionState; #[test] -fn handshake_deadline_is_derived_from_peer_state() { - let config = QlFsmConfig { - handshake_timeout: Duration::from_secs(5), - handshake_retry_interval: Duration::from_secs(10), - max_handshake_retries: 0, - session_keepalive_interval: Duration::from_millis(1), - session_peer_timeout: Duration::from_millis(2), - ..QlFsmConfig::default() - }; - let mut harness = Harness::paired(config); +fn kk_connect_round_trip_establishes_transport() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); harness .a .fsm .connect(harness.time(), &harness.a.crypto) .unwrap(); - assert_eq!( - harness.a.fsm.next_deadline(), - Some(harness.now + config.handshake_timeout) - ); + harness.pump(); - let _hello = harness.next_outbound_a().unwrap(); - harness.advance(Duration::from_secs(4)); - harness.a.fsm.on_timer(harness.time()); assert!(matches!( harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Initiator { .. }) + Some(ConnectionState::Connected(_)) )); - - harness.advance(Duration::from_secs(1)); - harness.a.fsm.on_timer(harness.time()); assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Disconnected) + harness.b.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected(_)) )); } #[test] -fn initiator_retries_hello_after_retry_interval() { - let config = QlFsmConfig { - handshake_retry_interval: Duration::from_millis(250), - max_handshake_retries: 2, - ..QlFsmConfig::default() - }; - let mut harness = Harness::paired(config); +fn xx_connect_round_trip_learns_peer_bundles() { + let mut harness = Harness::paired_unknown(QlFsmConfig::default()); harness .a .fsm .connect(harness.time(), &harness.a.crypto) .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - - harness.advance(config.handshake_retry_interval); - harness.a.fsm.on_timer(harness.time()); + harness.pump(); - assert_eq!(harness.next_outbound_a(), Some(hello)); + assert_eq!( + harness.a.fsm.peer.as_ref().unwrap().peer.bundle, + Some(harness.b.fsm.identity.bundle()) + ); + assert_eq!( + harness.b.fsm.peer.as_ref().unwrap().peer.bundle, + Some(harness.a.fsm.identity.bundle()) + ); assert!(matches!( harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Initiator { - stage: HandshakeInitiator::WaitingHelloReply { retry_count: 1, .. }, - .. - }) + Some(ConnectionState::Connected(_)) + )); + assert!(matches!( + harness.b.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected(_)) )); } #[test] -fn responder_retries_hello_reply_after_retry_interval() { - let config = QlFsmConfig { - handshake_retry_interval: Duration::from_millis(250), - max_handshake_retries: 2, - ..QlFsmConfig::default() - }; - let mut harness = Harness::paired(config); +fn inbound_xx1_auto_binds_unbound_responder() { + let mut harness = Harness::responder_unbound_unknown(QlFsmConfig::default()); harness .a .fsm .connect(harness.time(), &harness.a.crypto) .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(hello); - let reply = harness.next_outbound_b().unwrap(); - - harness.advance(config.handshake_retry_interval); - harness.b.fsm.on_timer(harness.time()); + harness.pump(); - assert_eq!(harness.next_outbound_b(), Some(reply)); - assert!(matches!( - harness.b.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Responder { - stage: HandshakeResponder::WaitingConfirm { retry_count: 1, .. }, - .. - }) - )); + assert_eq!( + harness.b.fsm.peer.as_ref().map(|entry| entry.peer.xid), + Some(harness.a.fsm.identity.xid) + ); + assert_eq!( + harness.b.fsm.peer.as_ref().unwrap().peer.bundle, + Some(harness.a.fsm.identity.bundle()) + ); } #[test] -fn initiator_retries_confirm_after_retry_interval() { +fn handshake_timeout_drops_single_attempt_without_resend() { let config = QlFsmConfig { - handshake_retry_interval: Duration::from_millis(250), - max_handshake_retries: 2, + handshake_timeout: Duration::from_millis(60), ..QlFsmConfig::default() }; - let mut harness = Harness::paired(config); + let mut harness = Harness::paired_unknown(config); harness .a .fsm .connect(harness.time(), &harness.a.crypto) .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(hello); - let reply = harness.next_outbound_b().unwrap(); - harness.deliver_to_a(reply); - let confirm = harness.next_outbound_a().unwrap(); - - harness.advance(config.handshake_retry_interval); - harness.a.fsm.on_timer(harness.time()); - - assert_eq!(harness.next_outbound_a(), Some(confirm)); + let first = harness.next_outbound_a().unwrap(); assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Initiator { - stage: HandshakeInitiator::WaitingReady { retry_count: 1, .. }, + first, + QlRecord::Handshake(ql_wire::QlHandshakeRecord { + payload: HandshakePayload::Xx1(_), .. }) )); -} - -#[test] -fn duplicate_hello_resends_hello_reply() { - let mut harness = Harness::paired(QlFsmConfig::default()); - - harness - .a - .fsm - .connect(harness.time(), &harness.a.crypto) - .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - - harness.deliver_to_b(hello.clone()); - let reply = harness.next_outbound_b().unwrap(); - - harness.deliver_to_b(hello); - assert_eq!(harness.next_outbound_b(), Some(reply)); -} - -#[test] -fn duplicate_hello_reply_resends_confirm() { - let mut harness = Harness::paired(QlFsmConfig::default()); - - harness - .a - .fsm - .connect(harness.time(), &harness.a.crypto) - .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(hello); - let reply = harness.next_outbound_b().unwrap(); + assert!(harness.next_outbound_a().is_none()); - harness.deliver_to_a(reply.clone()); - let confirm = harness.next_outbound_a().unwrap(); - - harness.deliver_to_a(reply); - assert_eq!(harness.next_outbound_a(), Some(confirm)); -} - -#[test] -fn responder_resends_ready_for_duplicate_confirm_after_connecting() { - let mut harness = Harness::paired(QlFsmConfig::default()); - - harness - .a - .fsm - .connect(harness.time(), &harness.a.crypto) - .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(hello); - let reply = harness.next_outbound_b().unwrap(); - harness.deliver_to_a(reply); - let confirm = harness.next_outbound_a().unwrap(); - - harness.deliver_to_b(confirm.clone()); - let ready = harness.next_outbound_b().unwrap(); + harness.advance(config.handshake_timeout); + harness.a.fsm.on_timer(harness.time()); assert!(matches!( - harness.b.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected { - recent_ready: Some(_), - .. - }) + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Disconnected) )); - - harness.deliver_to_b(confirm); - assert_eq!(harness.next_outbound_b(), Some(ready)); + assert!(harness.next_outbound_a().is_none()); } #[test] -fn initiator_waits_for_ready_before_connecting() { - let mut harness = Harness::paired(QlFsmConfig::default()); +fn handshake_timeout_clears_queued_handshake_output() { + let config = QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired_unknown(config); harness .a .fsm .connect(harness.time(), &harness.a.crypto) .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(hello); - let reply = harness.next_outbound_b().unwrap(); - harness.deliver_to_a(reply); - - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Initiator { - stage: HandshakeInitiator::WaitingReady { .. }, - .. - }) - )); - let stream_id = harness.a.fsm.open_stream().unwrap(); - harness.a.fsm.write_stream(stream_id, b"queued").unwrap(); - let confirm = harness.next_outbound_a().unwrap(); - assert!(matches!(confirm.payload, QlPayload::Confirm(_))); - harness.deliver_to_b(confirm); - let ready = harness.next_outbound_b().unwrap(); - - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Initiator { - stage: HandshakeInitiator::WaitingReady { .. }, - .. - }) - )); + harness.advance(config.handshake_timeout); + harness.a.fsm.on_timer(harness.time()); - harness.deliver_to_a(ready); assert!(matches!( harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected { .. }) + Some(ConnectionState::Disconnected) )); - let record = harness.next_outbound_a().unwrap(); - assert!(matches!(record.payload, QlPayload::Session(_))); + assert!(harness.next_outbound_a().is_none()); } #[test] -fn handshake_retry_limit_disconnects_initiator() { - let config = QlFsmConfig { - handshake_retry_interval: Duration::from_millis(250), - max_handshake_retries: 1, - ..QlFsmConfig::default() - }; - let mut harness = Harness::paired(config); +fn bind_peer_clears_queued_handshake_output() { + let mut harness = Harness::paired_unknown(QlFsmConfig::default()); harness .a .fsm .connect(harness.time(), &harness.a.crypto) .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - - harness.advance(config.handshake_retry_interval); - harness.a.fsm.on_timer(harness.time()); - assert_eq!(harness.next_outbound_a(), Some(hello)); + harness.a.fsm.bind_peer(Peer { + xid: test_identity(99).xid, + bundle: None, + }); - harness.advance(config.handshake_retry_interval); - harness.a.fsm.on_timer(harness.time()); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Disconnected) - )); + assert!(harness.next_outbound_a().is_none()); } #[test] -fn simultaneous_connect_converges_to_connected_peers() { - let mut harness = Harness::paired(QlFsmConfig::default()); +fn simultaneous_xx_connect_converges() { + let mut harness = Harness::paired_unknown(QlFsmConfig::default()); harness .a @@ -292,63 +164,44 @@ fn simultaneous_connect_converges_to_connected_peers() { .fsm .connect(harness.time(), &harness.b.crypto) .unwrap(); - - let hello_a = harness.next_outbound_a().unwrap(); - let hello_b = harness.next_outbound_b().unwrap(); - - harness.deliver_to_a(hello_b); - harness.deliver_to_b(hello_a); harness.pump(); assert!(matches!( harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected { .. }) + Some(ConnectionState::Connected(_)) )); assert!(matches!( harness.b.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected { .. }) + Some(ConnectionState::Connected(_)) )); } #[test] -fn receive_surfaces_invalid_xid_for_wrong_recipient() { - let mut harness = Harness::paired(QlFsmConfig::default()); +fn simultaneous_xx_and_kk_connect_prefers_xx() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, true); harness .a .fsm .connect(harness.time(), &harness.a.crypto) .unwrap(); - let mut hello = harness.next_outbound_a().unwrap(); - hello.header.recipient = XID([0xAA; XID::SIZE]); - - assert_eq!( - harness - .b - .fsm - .receive(harness.time(), hello.encode(), &harness.b.crypto), - Err(crate::QlFsmError::InvalidXid) - ); -} - -#[test] -fn receive_surfaces_invalid_signature_for_tampered_hello() { - let mut harness = Harness::paired(QlFsmConfig::default()); - harness - .a + .b .fsm - .connect(harness.time(), &harness.a.crypto) + .connect(harness.time(), &harness.b.crypto) .unwrap(); - let hello = harness.next_outbound_a().unwrap(); - let mut bytes = hello.encode(); - *bytes.last_mut().unwrap() ^= 0x01; + harness.pump(); assert_eq!( - harness - .b - .fsm - .receive(harness.time(), bytes, &harness.b.crypto), - Err(crate::QlFsmError::InvalidSignature) + harness.a.fsm.peer.as_ref().unwrap().peer.bundle, + Some(harness.b.fsm.identity.bundle()) ); + assert!(matches!( + harness.a.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected(_)) + )); + assert!(matches!( + harness.b.fsm.peer.as_ref().map(|entry| &entry.session), + Some(ConnectionState::Connected(_)) + )); } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 9b475cc9..bee7c9ea 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -7,22 +7,24 @@ use std::{ }; use libcrux_aesgcm::AesGcm256Key; +use libcrux_ml_kem::mlkem1024; use ql_wire::{ - self, generate_ml_dsa_keypair, generate_ml_kem_keypair, QlCrypto, QlIdentity, QlPayload, - QlRecord, SessionKey, XID, ENCRYPTED_MESSAGE_AUTH_SIZE, + self, generate_identity, ConnectionId, ENCRYPTED_MESSAGE_AUTH_SIZE, MlKemCiphertext, + MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, + QlKem, QlRandom, QlRecord, SessionKey, XID, }; use sha2::{Digest, Sha256}; use crate::{ session::{state::StreamParity, SessionFsm, SessionFsmConfig}, - state::ConnectionState, + state::{ConnectionState, SessionTransport}, FsmTime, OutboundWrite, Peer, QlFsm, QlFsmConfig, SessionWriteId, }; #[derive(Clone)] struct TestCrypto { seed: u8, - counter: Cell, + counter: Cell, } impl TestCrypto { @@ -32,27 +34,37 @@ impl TestCrypto { counter: Cell::new(0), } } + + fn next_block(&self) -> [u8; 32] { + let counter = self.counter.get(); + self.counter.set(counter.wrapping_add(1)); + sha256_parts(&[b"ql-fsm:test-rng:v1", &[self.seed], &counter.to_le_bytes()]) + } + + fn random_array(&self) -> [u8; L] { + let mut out = [0u8; L]; + self.fill_random_bytes(&mut out); + out + } } -impl QlCrypto for TestCrypto { - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self.seed.wrapping_add(self.counter.get()); - self.counter.set(self.counter.get().wrapping_add(1)); - data.fill(value); +impl QlRandom for TestCrypto { + fn fill_random_bytes(&self, out: &mut [u8]) { + fill_expanded(self, &[b"ql-fsm:test-fill:v1"], out); } +} - fn hash(&self, parts: &[&[u8]]) -> [u8; 32] { - let mut hasher = Sha256::new(); - for part in parts { - hasher.update(part); - } - hasher.finalize().into() +impl QlHash for TestCrypto { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { + sha256_parts(parts) } +} - fn encrypt_with_aead( +impl QlAead for TestCrypto { + fn aes256_gcm_encrypt( &self, key: &SessionKey, - nonce: &ql_wire::Nonce, + nonce: &Nonce, aad: &[u8], buffer: &mut [u8], ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { @@ -70,10 +82,10 @@ impl QlCrypto for TestCrypto { auth } - fn decrypt_with_aead( + fn aes256_gcm_decrypt( &self, key: &SessionKey, - nonce: &ql_wire::Nonce, + nonce: &Nonce, aad: &[u8], buffer: &mut [u8], auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], @@ -85,6 +97,48 @@ impl QlCrypto for TestCrypto { } } +impl QlKem for TestCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + let key_pair = mlkem1024::generate_key_pair(self.random_array()); + let mut public = [0u8; MlKemPublicKey::SIZE]; + public.copy_from_slice(key_pair.pk()); + let mut private = [0u8; MlKemPrivateKey::SIZE]; + private.copy_from_slice(key_pair.sk()); + + MlKemKeyPair { + private: MlKemPrivateKey::from_data(private), + public: MlKemPublicKey::from_data(public), + } + } + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + let public_key = public_key.as_bytes().into(); + let (ciphertext_value, shared_value) = + mlkem1024::encapsulate(&public_key, self.random_array()); + let mut ciphertext = [0u8; MlKemCiphertext::SIZE]; + ciphertext.copy_from_slice(ciphertext_value.as_slice()); + let mut shared = [0u8; SessionKey::SIZE]; + shared.copy_from_slice(shared_value.as_slice()); + ( + MlKemCiphertext::from_data(ciphertext), + SessionKey::from_data(shared), + ) + } + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + ciphertext: &MlKemCiphertext, + ) -> SessionKey { + let private_key = private_key.as_bytes().into(); + let ciphertext = ciphertext.as_bytes().into(); + let shared = mlkem1024::decapsulate(&private_key, &ciphertext); + let mut out = [0u8; SessionKey::SIZE]; + out.copy_from_slice(shared.as_slice()); + SessionKey::from_data(out) + } +} + struct Node { fsm: QlFsm, crypto: TestCrypto, @@ -98,11 +152,17 @@ struct Harness { } impl Harness { - fn paired(config: QlFsmConfig) -> Self { + fn paired_known(config: QlFsmConfig) -> Self { + Self::paired(config, true, true) + } + + fn paired_unknown(config: QlFsmConfig) -> Self { + Self::paired(config, false, false) + } + + fn paired(config: QlFsmConfig, know_a: bool, know_b: bool) -> Self { let identity_a = test_identity(11); let identity_b = test_identity(73); - let peer_a = peer_from_identity(&identity_b); - let peer_b = peer_from_identity(&identity_a); let now = Instant::now(); let time = FsmTime { instant: now, @@ -113,67 +173,75 @@ impl Harness { now, unix_secs: time.unix_secs, a: Node { - fsm: QlFsm::new(config, identity_a, time), + fsm: QlFsm::new(config, identity_a.clone(), time), crypto: TestCrypto::new(1), }, b: Node { - fsm: QlFsm::new(config, identity_b, time), + fsm: QlFsm::new(config, identity_b.clone(), time), crypto: TestCrypto::new(2), }, }; - harness.a.fsm.bind_peer(peer_a); - harness.b.fsm.bind_peer(peer_b); + harness.a.fsm.bind_peer(peer_from_identity(&identity_b, know_a)); + harness.b.fsm.bind_peer(peer_from_identity(&identity_a, know_b)); while harness.a.fsm.take_next_event().is_some() {} while harness.b.fsm.take_next_event().is_some() {} harness } - fn connected(config: QlFsmConfig) -> Self { - let mut harness = Self::paired(config); - let session_key = SessionKey::from_data([7; SessionKey::SIZE]); - - harness.a.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected { - session_key, - recent_ready: None, - }; - harness.b.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected { - session_key, - recent_ready: None, + fn responder_unbound_unknown(config: QlFsmConfig) -> Self { + let identity_a = test_identity(11); + let identity_b = test_identity(73); + let now = Instant::now(); + let time = FsmTime { + instant: now, + unix_secs: 1_700_000_000, }; - harness.a.fsm.session = SessionFsm::new( - SessionFsmConfig { - local_parity: StreamParity::for_local( - harness.a.fsm.identity.xid, - harness.a.fsm.peer.as_ref().unwrap().peer.xid, - ), - record_size: config.session_record_size, - ack_delay: config.session_record_ack_delay, - retransmit_timeout: config.session_record_retransmit_timeout, - keepalive_interval: config.session_keepalive_interval, - peer_timeout: config.session_peer_timeout, - stream_send_buffer_size: config.session_stream_send_buffer_size, - stream_receive_buffer_size: config.session_stream_receive_buffer_size, + + let mut harness = Self { + now, + unix_secs: time.unix_secs, + a: Node { + fsm: QlFsm::new(config, identity_a, time), + crypto: TestCrypto::new(1), }, - harness.now, - ); - harness.b.fsm.session = SessionFsm::new( - SessionFsmConfig { - local_parity: StreamParity::for_local( - harness.b.fsm.identity.xid, - harness.b.fsm.peer.as_ref().unwrap().peer.xid, - ), - record_size: config.session_record_size, - ack_delay: config.session_record_ack_delay, - retransmit_timeout: config.session_record_retransmit_timeout, - keepalive_interval: config.session_keepalive_interval, - peer_timeout: config.session_peer_timeout, - stream_send_buffer_size: config.session_stream_send_buffer_size, - stream_receive_buffer_size: config.session_stream_receive_buffer_size, + b: Node { + fsm: QlFsm::new(config, identity_b.clone(), time), + crypto: TestCrypto::new(2), }, - harness.now, - ); + }; + + harness + .a + .fsm + .bind_peer(Peer { xid: identity_b.xid, bundle: None }); + while harness.a.fsm.take_next_event().is_some() {} + + harness + } + + fn connected(config: QlFsmConfig) -> Self { + let mut harness = Self::paired_known(config); + let a_to_b_key = SessionKey::from_data([7; SessionKey::SIZE]); + let b_to_a_key = SessionKey::from_data([9; SessionKey::SIZE]); + let a_to_b_conn = ConnectionId::from_data([0xA1; ConnectionId::SIZE]); + let b_to_a_conn = ConnectionId::from_data([0xB2; ConnectionId::SIZE]); + + harness.a.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected(SessionTransport { + tx_key: a_to_b_key.clone(), + rx_key: b_to_a_key.clone(), + tx_connection_id: a_to_b_conn, + rx_connection_id: b_to_a_conn, + }); + harness.b.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected(SessionTransport { + tx_key: b_to_a_key, + rx_key: a_to_b_key, + tx_connection_id: b_to_a_conn, + rx_connection_id: a_to_b_conn, + }); + harness.a.fsm.session = SessionFsm::new(session_config(&harness, true), harness.now); + harness.b.fsm.session = SessionFsm::new(session_config(&harness, false), harness.now); harness } @@ -256,22 +324,40 @@ impl Harness { fn test_identity(seed: u8) -> QlIdentity { let crypto = TestCrypto::new(seed); - let (signing_private, signing_public) = generate_ml_dsa_keypair(&crypto); - let (encapsulation_private, encapsulation_public) = generate_ml_kem_keypair(&crypto); - QlIdentity::new( - XID([seed; XID::SIZE]), - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - ) + generate_identity(&crypto, XID([seed; XID::SIZE])) } -fn peer_from_identity(identity: &QlIdentity) -> Peer { +fn peer_from_identity(identity: &QlIdentity, know_bundle: bool) -> Peer { Peer { xid: identity.xid, - signing_key: identity.signing_public_key.clone(), - encapsulation_key: identity.encapsulation_public_key.clone(), + bundle: know_bundle.then(|| identity.bundle()), + } +} + +fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { + let (local, peer, config) = if a { + ( + harness.a.fsm.identity.xid, + harness.a.fsm.peer.as_ref().unwrap().peer.xid, + harness.a.fsm.config, + ) + } else { + ( + harness.b.fsm.identity.xid, + harness.b.fsm.peer.as_ref().unwrap().peer.xid, + harness.b.fsm.config, + ) + }; + + SessionFsmConfig { + local_parity: StreamParity::for_local(local, peer), + record_size: config.session_record_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, } } @@ -279,11 +365,39 @@ fn decrypt_record( crypto: &impl QlCrypto, record: &QlRecord>, session_key: &SessionKey, -) -> ql_wire::SessionRecord { - let aad = record.header.aad(); - let QlPayload::Session(encrypted) = &record.payload else { - panic!("expected encrypted payload"); +) -> (ql_wire::SessionHeader, ql_wire::SessionRecord) { + let ql_wire::QlRecord::Session(record) = record else { + panic!("expected encrypted session record"); }; - let plaintext = encrypted.decrypt(crypto, session_key, &aad).unwrap(); - ql_wire::SessionRecord::decode(&plaintext).unwrap() + let plaintext = + ql_wire::decrypt_record(crypto, &record.header, record.payload.clone(), session_key) + .unwrap(); + (record.header, ql_wire::SessionRecord::decode(&plaintext).unwrap()) +} + +fn sha256_parts(parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() +} + +fn fill_expanded(crypto: &TestCrypto, parts: &[&[u8]], out: &mut [u8]) { + let mut written = 0usize; + let mut counter = 0u64; + while written < out.len() { + let random = crypto.next_block(); + let counter_bytes = counter.to_le_bytes(); + let mut inputs = Vec::with_capacity(parts.len() + 3); + inputs.push(b"ql-fsm:test-expand:v1".as_slice()); + inputs.push(&random); + inputs.push(&counter_bytes); + inputs.extend_from_slice(parts); + let block = sha256_parts(&inputs); + let take = (out.len() - written).min(block.len()); + out[written..written + take].copy_from_slice(&block[..take]); + written += take; + counter = counter.wrapping_add(1); + } } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index e003cfa0..4009e24b 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use ql_wire::{SessionCloseBody, SessionFrame, StreamId}; +use ql_wire::{SessionClose, StreamId}; use super::*; use crate::{session::state::StreamParity, QlFsmEvent, QlSessionEvent}; @@ -50,7 +50,7 @@ fn connected_fsms_deliver_stream_data() { } #[test] -fn lost_record_is_retried_with_new_record_seq() { +fn session_retransmit_uses_new_record_seq() { let config = QlFsmConfig::default(); let mut harness = Harness::connected(config); @@ -58,24 +58,26 @@ fn lost_record_is_retried_with_new_record_seq() { assert_eq!(harness.a.fsm.write_stream(stream_id, b"retry").unwrap(), 5); let first = harness.next_outbound_a().unwrap(); - let session_key = *harness + let first_transport = harness .b .fsm .peer .as_ref() .unwrap() .session - .session_key() - .unwrap(); - let first_record = decrypt_record(&harness.b.crypto, &first, &session_key); + .transport() + .unwrap() + .clone(); + let (first_header, first_record) = decrypt_record(&harness.b.crypto, &first, &first_transport.rx_key); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); harness.a.fsm.on_timer(harness.time()); let retried = harness.next_outbound_a().unwrap(); - let retried_record = decrypt_record(&harness.b.crypto, &retried, &session_key); + let (retried_header, retried_record) = + decrypt_record(&harness.b.crypto, &retried, &first_transport.rx_key); - assert_ne!(retried_record.seq, first_record.seq); + assert_ne!(retried_header.seq, first_header.seq); assert_eq!(retried_record.frames, first_record.frames); harness.deliver_to_b(retried); @@ -92,10 +94,7 @@ fn lost_record_is_retried_with_new_record_seq() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id)) ); - assert_eq!( - read_stream_all(&mut harness.b.fsm, stream_id), - b"retry".to_vec() - ); + assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"retry".to_vec()); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); harness.a.fsm.on_timer(harness.time()); @@ -119,14 +118,8 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { .matches(stream_id_b) ); - assert_eq!( - harness.a.fsm.write_stream(stream_id_a, b"from-a").unwrap(), - 6 - ); - assert_eq!( - harness.b.fsm.write_stream(stream_id_b, b"from-b").unwrap(), - 6 - ); + assert_eq!(harness.a.fsm.write_stream(stream_id_a, b"from-a").unwrap(), 6); + assert_eq!(harness.b.fsm.write_stream(stream_id_b, b"from-b").unwrap(), 6); harness.pump(); @@ -138,10 +131,7 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { harness.a.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id_b)) ); - assert_eq!( - read_stream_all(&mut harness.a.fsm, stream_id_b), - b"from-b".to_vec() - ); + assert_eq!(read_stream_all(&mut harness.a.fsm, stream_id_b), b"from-b".to_vec()); assert_eq!( harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Opened(stream_id_a)) @@ -150,15 +140,12 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id_a)) ); - assert_eq!( - read_stream_all(&mut harness.b.fsm, stream_id_a), - b"from-a".to_vec() - ); + assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id_a), b"from-a".to_vec()); } #[test] fn queued_stream_work_auto_connects_and_drains_after_handshake() { - let mut harness = Harness::paired(QlFsmConfig::default()); + let mut harness = Harness::paired_known(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); @@ -166,14 +153,6 @@ fn queued_stream_work_auto_connects_and_drains_after_handshake() { harness.pump(); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(crate::state::ConnectionState::Connected { .. }) - )); - assert!(matches!( - harness.b.fsm.peer.as_ref().map(|entry| &entry.session), - Some(crate::state::ConnectionState::Connected { .. }) - )); assert_eq!( harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Opened(stream_id)) @@ -195,27 +174,21 @@ fn queued_stream_work_auto_connects_and_drains_after_handshake() { #[test] fn queued_stream_work_is_failed_when_handshake_times_out() { let config = QlFsmConfig { - handshake_retry_interval: Duration::from_millis(50), - max_handshake_retries: 0, + handshake_timeout: Duration::from_millis(50), ..QlFsmConfig::default() }; - let mut harness = Harness::paired(config); + let mut harness = Harness::paired_unknown(config); let stream_id = harness.a.fsm.open_stream().unwrap(); assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); - let _hello = harness.next_outbound_a().unwrap(); - - harness.advance(config.handshake_retry_interval); + let _first = harness.next_outbound_a().unwrap(); + harness.advance(config.handshake_timeout); harness.a.fsm.on_timer(harness.time()); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(crate::state::ConnectionState::Disconnected) - )); assert_eq!( harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::SessionClosed(SessionCloseBody { + Some(QlSessionEvent::SessionClosed(SessionClose { code: ql_wire::CloseCode::TIMEOUT })) ); @@ -232,26 +205,28 @@ fn returned_session_write_is_reissued_with_new_record_seq() { let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); let record = write.record; - let session_key = *harness + let session_key = harness .b .fsm .peer .as_ref() .unwrap() .session - .session_key() - .unwrap(); - let first = decrypt_record(&harness.b.crypto, &record, &session_key); + .transport() + .unwrap() + .rx_key + .clone(); + let (first_header, first) = decrypt_record(&harness.b.crypto, &record, &session_key); harness.return_write_a(id); let write = harness.next_write_a().unwrap(); let reissued_id = write.session_write_id.expect("expected reissued write"); let record = write.record; - let reissued = decrypt_record(&harness.b.crypto, &record, &session_key); + let (reissued_header, reissued) = decrypt_record(&harness.b.crypto, &record, &session_key); assert_ne!(reissued_id, id); - assert_ne!(reissued.seq, first.seq); + assert_ne!(reissued_header.seq, first_header.seq); assert_eq!(reissued.frames, first.frames); harness.confirm_write_a(reissued_id); @@ -266,10 +241,7 @@ fn returned_session_write_is_reissued_with_new_record_seq() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id)) ); - assert_eq!( - read_stream_all(&mut harness.b.fsm, stream_id), - b"retry".to_vec() - ); + assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"retry".to_vec()); } #[test] @@ -283,16 +255,18 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); let record = write.record; - let session_key = *harness + let session_key = harness .b .fsm .peer .as_ref() .unwrap() .session - .session_key() - .unwrap(); - let first = decrypt_record(&harness.b.crypto, &record, &session_key); + .transport() + .unwrap() + .rx_key + .clone(); + let (first_header, first) = decrypt_record(&harness.b.crypto, &record, &session_key); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); harness.a.fsm.on_timer(harness.time()); @@ -304,9 +278,9 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let write = harness.next_write_a().unwrap(); let record = write.record; - let retried = decrypt_record(&harness.b.crypto, &record, &session_key); + let (retried_header, retried) = decrypt_record(&harness.b.crypto, &record, &session_key); - assert_ne!(retried.seq, first.seq); + assert_ne!(retried_header.seq, first_header.seq); assert_eq!(retried.frames, first.frames); } @@ -347,7 +321,7 @@ fn kill_session_disconnects_locally() { )); assert_eq!( harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::SessionClosed(SessionCloseBody { + Some(QlSessionEvent::SessionClosed(SessionClose { code: ql_wire::CloseCode::CANCELLED })) ); @@ -371,18 +345,17 @@ fn session_records_contain_ack_frames_after_delivery() { harness.b.fsm.on_timer(harness.time()); let ack = harness.next_outbound_b().unwrap(); - let session_key = *harness + let session_key = harness .a .fsm .peer .as_ref() .unwrap() .session - .session_key() - .unwrap(); - let ack_record = decrypt_record(&harness.a.crypto, &ack, &session_key); - assert!(matches!( - ack_record.frames.as_slice(), - [SessionFrame::Ack(_)] - )); + .transport() + .unwrap() + .rx_key + .clone(); + let (_ack_header, ack_record) = decrypt_record(&harness.a.crypto, &ack, &session_key); + assert!(matches!(ack_record.frames.as_slice(), [ql_wire::SessionFrame::Ack(_)])); } diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml index 25bb291d..42db6996 100644 --- a/ql-wire/Cargo.toml +++ b/ql-wire/Cargo.toml @@ -7,4 +7,5 @@ license = "Proprietary" [dev-dependencies] libcrux-aesgcm = "0.0.7" +libcrux-ml-kem = "0.0.7" sha2 = "0.10" diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 6182aad5..54e9a1f8 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,5 +1,5 @@ use super::{ - push_variable_len, RecordAck, SessionCloseBody, SessionFrame, SessionFrameKind, StreamClose, + push_variable_len, RecordAck, SessionClose, SessionFrame, SessionFrameKind, StreamClose, StreamData, StreamWindow, SIZE_LEN, }; use crate::{ @@ -97,8 +97,8 @@ impl SessionRecordBuilder { true } - pub fn push_close(&mut self, close: &SessionCloseBody) -> bool { - if !self.can_push_len(1 + SessionCloseBody::WIRE_SIZE) { + pub fn push_close(&mut self, close: &SessionClose) -> bool { + if !self.can_push_len(1 + SessionClose::WIRE_SIZE) { return false; } self.bytes.push(SessionFrameKind::Close as u8); diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index 4702566a..9ff0b072 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -8,11 +8,11 @@ use crate::{ /// closes the whole session immediately with a close code. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct SessionCloseBody { +pub struct SessionClose { pub code: CloseCode, } -impl SessionCloseBody { +impl SessionClose { pub const WIRE_SIZE: usize = size_of::(); pub fn encode_into(&self, out: &mut Vec) { diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index d423434c..9a1b7692 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -36,7 +36,7 @@ pub enum SessionFrame { StreamData(StreamData), StreamWindow(StreamWindow), StreamClose(StreamClose), - Close(SessionCloseBody), + Close(SessionClose), } pub type SessionFrameVec = SessionFrame>; @@ -116,7 +116,7 @@ impl SessionFrame { Self::StreamData(frame) => SIZE_LEN + frame.encoded_len(), Self::StreamWindow(_) => StreamWindow::WIRE_SIZE, Self::StreamClose(frame) => SIZE_LEN + frame.encoded_len(), - Self::Close(_) => SessionCloseBody::WIRE_SIZE, + Self::Close(_) => SessionClose::WIRE_SIZE, } } @@ -237,11 +237,11 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr Ok((SessionFrame::StreamClose(StreamClose::parse(frame)?), rest)) } SessionFrameKind::Close => { - if rest.len() < SessionCloseBody::WIRE_SIZE { + if rest.len() < SessionClose::WIRE_SIZE { return Err(WireError::InvalidPayload); } - let (frame, rest) = rest.split_at(SessionCloseBody::WIRE_SIZE); - Ok((SessionFrame::Close(SessionCloseBody::decode(frame)?), rest)) + let (frame, rest) = rest.split_at(SessionClose::WIRE_SIZE); + Ok((SessionFrame::Close(SessionClose::decode(frame)?), rest)) } } } diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 97f962af..a4e8dc65 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -37,14 +37,7 @@ pub struct CloseCode(pub u16); impl CloseCode { pub const CANCELLED: Self = Self(0); pub const PROTOCOL: Self = Self(1); - pub const INVALID_DATA: Self = Self(2); - pub const TIMEOUT: Self = Self(3); - - pub const UNKNOWN: Self = Self(16); - pub const UNKNOWN_ROUTE: Self = Self(17); - pub const INVALID_HEAD: Self = Self(18); - pub const BUSY: Self = Self(19); - pub const UNHANDLED: Self = Self(20); + pub const TIMEOUT: Self = Self(2); } /// aborts one or both directions of a stream with a close code. diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index fddd00ad..f12020e6 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -1,6 +1,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use libcrux_aesgcm::AesGcm256Key; +use libcrux_ml_kem::mlkem1024; use sha2::{Digest, Sha256}; use super::*; @@ -20,6 +21,12 @@ impl TestCrypto { let value = self.counter.fetch_add(1, Ordering::Relaxed).to_le_bytes(); sha256_parts(&[b"ql-wire:test-rng:v1", &value]) } + + fn random_array(&self) -> [u8; L] { + let mut out = [0u8; L]; + self.fill_random_bytes(&mut out); + out + } } impl QlRandom for TestCrypto { @@ -73,20 +80,11 @@ impl QlAead for TestCrypto { impl QlKem for TestCrypto { fn mlkem_generate_keypair(&self) -> MlKemKeyPair { - let seed = self.next_block(); - let key_id = self.sha256(&[b"ql-wire:test-mlkem:key-id:v1", &seed]); - + let key_pair = mlkem1024::generate_key_pair(self.random_array()); let mut public = [0u8; MlKemPublicKey::SIZE]; - fill_expanded(self, &[b"ql-wire:test-mlkem:public:v1", &seed], &mut public); - public[..key_id.len()].copy_from_slice(&key_id); - + public.copy_from_slice(key_pair.pk()); let mut private = [0u8; MlKemPrivateKey::SIZE]; - fill_expanded( - self, - &[b"ql-wire:test-mlkem:private:v1", &seed], - &mut private, - ); - private[..key_id.len()].copy_from_slice(&key_id); + private.copy_from_slice(key_pair.sk()); MlKemKeyPair { private: MlKemPrivateKey::from_data(private), @@ -95,19 +93,13 @@ impl QlKem for TestCrypto { } fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { - let mut encaps_seed = [0u8; 32]; - self.fill_random_bytes(&mut encaps_seed); - let key_id = &public_key.as_bytes()[..32]; - + let public_key = public_key.as_bytes().into(); + let (ciphertext_value, shared_value) = + mlkem1024::encapsulate(&public_key, self.random_array()); let mut ciphertext = [0u8; MlKemCiphertext::SIZE]; - fill_expanded( - self, - &[b"ql-wire:test-mlkem:ciphertext:v1", &encaps_seed], - &mut ciphertext, - ); - ciphertext[..encaps_seed.len()].copy_from_slice(&encaps_seed); - - let shared = self.sha256(&[b"ql-wire:test-mlkem:shared:v1", key_id, &encaps_seed]); + ciphertext.copy_from_slice(ciphertext_value.as_slice()); + let mut shared = [0u8; SessionKey::SIZE]; + shared.copy_from_slice(shared_value.as_slice()); ( MlKemCiphertext::from_data(ciphertext), SessionKey::from_data(shared), @@ -119,9 +111,12 @@ impl QlKem for TestCrypto { private_key: &MlKemPrivateKey, ciphertext: &MlKemCiphertext, ) -> SessionKey { - let key_id = &private_key.as_bytes()[..32]; - let encaps_seed = &ciphertext.as_bytes()[..32]; - SessionKey::from_data(self.sha256(&[b"ql-wire:test-mlkem:shared:v1", key_id, encaps_seed])) + let private_key = private_key.as_bytes().into(); + let ciphertext = ciphertext.as_bytes().into(); + let shared = mlkem1024::decapsulate(&private_key, &ciphertext); + let mut out = [0u8; SessionKey::SIZE]; + out.copy_from_slice(shared.as_slice()); + SessionKey::from_data(out) } } @@ -455,7 +450,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { target: CloseTarget::Both, code: CloseCode::PROTOCOL, }), - SessionFrame::Close(SessionCloseBody { + SessionFrame::Close(SessionClose { code: CloseCode::TIMEOUT, }), ], @@ -605,7 +600,7 @@ fn protocol_record_size_breakdown() { }, &session.tx_key, &SessionRecord { - frames: vec![SessionFrame::Close(SessionCloseBody { + frames: vec![SessionFrame::Close(SessionClose { code: CloseCode::PROTOCOL, })], }, From 031d5cd129de3183d48707cf5414bb10c6b33421 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 11:05:54 -0400 Subject: [PATCH 061/304] ql: separte streamclosecode and sessionclosecode --- ql-fsm/src/implementation/fsm.rs | 13 ++-- ql-fsm/src/implementation/handshake/mod.rs | 2 +- ql-fsm/src/implementation/mod.rs | 6 +- ql-fsm/src/lib.rs | 8 +-- ql-fsm/src/session/mod.rs | 21 +++--- ql-fsm/src/session/tests.rs | 6 +- ql-fsm/src/tests/session.rs | 9 ++- ql-wire/src/encrypted/close.rs | 17 +++-- ql-wire/src/encrypted/stream_close.rs | 75 ++++++++++------------ ql-wire/src/tests.rs | 6 +- 10 files changed, 85 insertions(+), 78 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 94a38793..5bcfcb5b 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -1,6 +1,9 @@ use std::time::Instant; -use ql_wire::{self as wire, CloseCode, CloseTarget, QlCrypto, SessionHeader, StreamId}; +use ql_wire::{ + self as wire, CloseTarget, QlCrypto, SessionClose, SessionCloseCode, SessionHeader, + StreamCloseCode, StreamId, +}; use crate::{OutboundWrite, QlFsm, QlFsmError, QlSessionEvent, SessionWriteId, StreamReadIter}; @@ -108,7 +111,7 @@ pub fn reject_session_write(fsm: &mut QlFsm, write_id: SessionWriteId) { fsm.session.reject_write(write_id.0); } -pub fn kill_session(fsm: &mut QlFsm, code: CloseCode) { +pub fn kill_session(fsm: &mut QlFsm, code: SessionCloseCode) { let Some(entry) = fsm.peer.as_mut() else { return; }; @@ -121,9 +124,7 @@ pub fn kill_session(fsm: &mut QlFsm, code: CloseCode) { super::reset_session(fsm); fsm.state .session_events - .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { - code, - })); + .push_back(QlSessionEvent::SessionClosed(SessionClose { code })); } pub fn open_stream(fsm: &mut QlFsm) -> Result { @@ -165,7 +166,7 @@ pub fn close_stream( fsm: &mut QlFsm, stream_id: StreamId, target: CloseTarget, - code: CloseCode, + code: StreamCloseCode, ) -> Result<(), QlFsmError> { ensure_peer_bound(fsm)?; Ok(fsm.session.close_stream(stream_id, target, code)?) diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 4c3cc0fe..a1e500c4 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -94,7 +94,7 @@ pub fn handle_timer(fsm: &mut QlFsm) { entry.session = ConnectionState::Disconnected; } fsm.state.handshake = None; - fail_pending_connect_session(fsm, ql_wire::CloseCode::TIMEOUT); + fail_pending_connect_session(fsm, ql_wire::SessionCloseCode::TIMEOUT); emit_peer_status(fsm); } diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index 9707ccec..d5131014 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -57,16 +57,14 @@ pub fn handle_bind_peer(fsm: &mut QlFsm, peer: Peer) { emit_peer_status(fsm); } -fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::CloseCode) { +fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::SessionCloseCode) { if !fsm.session.has_pending_stream_work() { return; } reset_session(fsm); fsm.state .session_events - .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { - code, - })); + .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { code })); } fn forward_session_event( diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 389ad87b..22ccd081 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -29,8 +29,8 @@ use std::time::{Duration, Instant}; pub use error::QlFsmError; use ql_wire::{ - CloseCode, CloseTarget, PeerBundle, QlCrypto, QlIdentity, QlRecord, SessionClose, StreamClose, - StreamId, XID, + CloseTarget, PeerBundle, QlCrypto, QlIdentity, QlRecord, SessionClose, SessionCloseCode, + StreamClose, StreamCloseCode, StreamId, XID, }; pub use session::stream_rx::StreamReadIter; @@ -263,7 +263,7 @@ impl QlFsm { } /// closes the current encrypted session locally - pub fn kill_session(&mut self, code: CloseCode) { + pub fn kill_session(&mut self, code: SessionCloseCode) { implementation::kill_session(self, code); } @@ -311,7 +311,7 @@ impl QlFsm { &mut self, stream_id: StreamId, target: CloseTarget, - code: CloseCode, + code: StreamCloseCode, ) -> Result<(), QlFsmError> { implementation::close_stream(self, stream_id, target, code) } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 589c901e..e4e6b93f 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -9,8 +9,9 @@ use std::time::{Duration, Instant}; use indexmap::map::Entry; use ql_wire::{ - CloseCode, CloseTarget, RecordAck, RecordSeq, SessionClose, SessionFrame, - SessionRecordBuilder, StreamClose, StreamData, StreamId, StreamWindow, WireError, + CloseTarget, RecordAck, RecordSeq, SessionClose, SessionCloseCode, SessionFrame, + SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamId, StreamWindow, + WireError, }; use self::{ @@ -169,7 +170,7 @@ impl SessionFsm { &mut self, stream_id: StreamId, target: CloseTarget, - code: CloseCode, + code: StreamCloseCode, ) -> Result<(), StreamError> { self.ensure_session_open()?; { @@ -264,7 +265,7 @@ impl SessionFsm { Err(_) => { self.fail_session( SessionClose { - code: CloseCode::PROTOCOL, + code: SessionCloseCode::PROTOCOL, }, &mut emit, ); @@ -346,7 +347,7 @@ impl SessionFsm { { self.fail_session( SessionClose { - code: CloseCode::TIMEOUT, + code: SessionCloseCode::TIMEOUT, }, &mut emit, ); @@ -752,7 +753,7 @@ impl SessionFsm { if !self.config.local_parity.remote().matches(stream_id) { self.fail_session( SessionClose { - code: CloseCode::PROTOCOL, + code: SessionCloseCode::PROTOCOL, }, emit, ); @@ -775,7 +776,7 @@ impl SessionFsm { } self.fail_session( SessionClose { - code: CloseCode::PROTOCOL, + code: SessionCloseCode::PROTOCOL, }, emit, ); @@ -806,7 +807,7 @@ impl SessionFsm { | Err(StreamRxError::OffsetOverflow) => { self.fail_session( SessionClose { - code: CloseCode::PROTOCOL, + code: SessionCloseCode::PROTOCOL, }, emit, ); @@ -824,7 +825,7 @@ impl SessionFsm { let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { self.fail_session( SessionClose { - code: CloseCode::PROTOCOL, + code: SessionCloseCode::PROTOCOL, }, emit, ); @@ -852,7 +853,7 @@ impl SessionFsm { if !self.config.local_parity.remote().matches(frame.stream_id) { self.fail_session( SessionClose { - code: CloseCode::PROTOCOL, + code: SessionCloseCode::PROTOCOL, }, emit, ); diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 1109b749..781f72b1 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -1,8 +1,8 @@ use std::time::{Duration, Instant}; use ql_wire::{ - CloseCode, CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionFrame, SessionRecord, - StreamClose, StreamData, StreamId, XID, + CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionFrame, SessionRecord, StreamClose, + StreamCloseCode, StreamData, StreamId, XID, }; use super::{state::StreamParity, SessionEvent, SessionFsm, SessionFsmConfig}; @@ -222,7 +222,7 @@ fn remote_stream_close_is_reliable_and_retried() { let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); - fsm.close_stream(stream_id, CloseTarget::Both, CloseCode::CANCELLED) + fsm.close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)) .unwrap(); let (write_id, _seq, builder) = fsm.take_next_write(now).unwrap(); diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 4009e24b..c920beb7 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -189,7 +189,7 @@ fn queued_stream_work_is_failed_when_handshake_times_out() { assert_eq!( harness.a.fsm.take_next_session_event(), Some(QlSessionEvent::SessionClosed(SessionClose { - code: ql_wire::CloseCode::TIMEOUT + code: ql_wire::SessionCloseCode::TIMEOUT })) ); assert!(harness.next_outbound_a().is_none()); @@ -313,7 +313,10 @@ fn ack_frame_releases_stream_capacity_and_emits_writable() { fn kill_session_disconnects_locally() { let mut harness = Harness::connected(QlFsmConfig::default()); - harness.a.fsm.kill_session(ql_wire::CloseCode::CANCELLED); + harness + .a + .fsm + .kill_session(ql_wire::SessionCloseCode::CANCELLED); assert!(matches!( harness.a.fsm.peer.as_ref().map(|entry| &entry.session), @@ -322,7 +325,7 @@ fn kill_session_disconnects_locally() { assert_eq!( harness.a.fsm.take_next_session_event(), Some(QlSessionEvent::SessionClosed(SessionClose { - code: ql_wire::CloseCode::CANCELLED + code: ql_wire::SessionCloseCode::CANCELLED })) ); assert!(matches!( diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index 9ff0b072..31782a01 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -1,6 +1,5 @@ use std::mem::size_of; -use super::CloseCode; use crate::{ codec::{self, Reader}, WireError, @@ -9,11 +8,11 @@ use crate::{ /// closes the whole session immediately with a close code. #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionClose { - pub code: CloseCode, + pub code: SessionCloseCode, } impl SessionClose { - pub const WIRE_SIZE: usize = size_of::(); + pub const WIRE_SIZE: usize = size_of::(); pub fn encode_into(&self, out: &mut Vec) { codec::push_u16(out, self.code.0); @@ -23,7 +22,17 @@ impl SessionClose { let mut reader = Reader::new(bytes); let code = reader.take_u16()?; Ok(Self { - code: CloseCode(code), + code: SessionCloseCode(code), }) } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct SessionCloseCode(pub u16); + +impl SessionCloseCode { + pub const CANCELLED: Self = Self(0); + pub const PROTOCOL: Self = Self(1); + pub const TIMEOUT: Self = Self(2); +} diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index a4e8dc65..f492e259 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -3,60 +3,24 @@ use std::mem::size_of; use super::StreamId; use crate::{codec, ByteSlice, WireError}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum CloseTarget { - Request = 1, - Response = 2, - Both = 3, -} - -impl CloseTarget { - pub const fn to_wire(self) -> u8 { - self as u8 - } -} - -impl TryFrom for CloseTarget { - type Error = WireError; - - fn try_from(value: u8) -> Result { - match value { - 1 => Ok(Self::Request), - 2 => Ok(Self::Response), - 3 => Ok(Self::Both), - _ => Err(WireError::InvalidPayload), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(transparent)] -pub struct CloseCode(pub u16); - -impl CloseCode { - pub const CANCELLED: Self = Self(0); - pub const PROTOCOL: Self = Self(1); - pub const TIMEOUT: Self = Self(2); -} - /// aborts one or both directions of a stream with a close code. #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamClose { pub stream_id: StreamId, pub target: CloseTarget, - pub code: CloseCode, + pub code: StreamCloseCode, } impl StreamClose { - pub const WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); + pub const WIRE_SIZE: usize = + size_of::() + size_of::() + size_of::(); pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); let close = Self { stream_id: StreamId(reader.take_u32()?), target: CloseTarget::try_from(reader.take_u8()?)?, - code: CloseCode(reader.take_u16()?), + code: StreamCloseCode(reader.take_u16()?), }; reader.finish()?; Ok(close) @@ -72,3 +36,34 @@ impl StreamClose { codec::push_u16(out, self.code.0); } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum CloseTarget { + Request = 1, + Response = 2, + Both = 3, +} + +impl CloseTarget { + pub const fn to_wire(self) -> u8 { + self as u8 + } +} + +impl TryFrom for CloseTarget { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Request), + 2 => Ok(Self::Response), + 3 => Ok(Self::Both), + _ => Err(WireError::InvalidPayload), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct StreamCloseCode(pub u16); diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index f12020e6..f8aac3c8 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -448,10 +448,10 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { SessionFrame::StreamClose(StreamClose { stream_id: StreamId(9), target: CloseTarget::Both, - code: CloseCode::PROTOCOL, + code: StreamCloseCode(0), }), SessionFrame::Close(SessionClose { - code: CloseCode::TIMEOUT, + code: SessionCloseCode::TIMEOUT, }), ], }; @@ -601,7 +601,7 @@ fn protocol_record_size_breakdown() { &session.tx_key, &SessionRecord { frames: vec![SessionFrame::Close(SessionClose { - code: CloseCode::PROTOCOL, + code: SessionCloseCode::PROTOCOL, })], }, ); From a629f4fed22caf61987c79dc7be682cd40f12514 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 11:22:45 -0400 Subject: [PATCH 062/304] ql-wire: associated constants --- ql-wire/src/handshake/kk.rs | 4 ++-- ql-wire/src/handshake/mod.rs | 28 ++++++++++++++-------------- ql-wire/src/handshake/xx.rs | 12 +++++++----- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 827ce8f0..4cffa2dc 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -2,7 +2,7 @@ use super::{ decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, mix_hash_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, - EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, ENCRYPTED_MLKEM_CIPHERTEXT_LEN, + EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, }; use crate::{ codec, HandshakeHeader, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, @@ -50,7 +50,7 @@ pub struct Kk2 { impl Kk2 { pub const ENCODED_LEN: usize = - HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + ENCRYPTED_MLKEM_CIPHERTEXT_LEN; + HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index ffaf2bac..05035407 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -18,10 +18,6 @@ const PROTOCOL_KK: &[u8] = b"ql-wire:pq-kk:v1"; const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; const HANDSHAKE_PREAMBLE_DOMAIN: &[u8] = b"ql-wire:handshake-preamble:v1"; -pub const ENCRYPTED_MLKEM_CIPHERTEXT_LEN: usize = - MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; -pub const ENCRYPTED_PEER_BUNDLE_LEN: usize = PeerBundle::ENCODED_LEN + ENCRYPTED_MESSAGE_AUTH_SIZE; - #[derive(Debug, Clone, PartialEq, Eq)] pub struct EphemeralPublicKey { pub mlkem_public_key: MlKemPublicKey, @@ -45,27 +41,31 @@ impl EphemeralPublicKey { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct EncryptedMlKemCiphertext(Box<[u8; ENCRYPTED_MLKEM_CIPHERTEXT_LEN]>); +pub struct EncryptedMlKemCiphertext(Box<[u8; Self::ENCODED_LEN]>); impl EncryptedMlKemCiphertext { - pub fn from_data(data: [u8; ENCRYPTED_MLKEM_CIPHERTEXT_LEN]) -> Self { + pub const ENCODED_LEN: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn from_data(data: [u8; Self::ENCODED_LEN]) -> Self { Self(Box::new(data)) } - pub fn as_bytes(&self) -> &[u8; ENCRYPTED_MLKEM_CIPHERTEXT_LEN] { + pub fn as_bytes(&self) -> &[u8; Self::ENCODED_LEN] { self.0.as_ref() } } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct EncryptedPeerBundle(Box<[u8; ENCRYPTED_PEER_BUNDLE_LEN]>); +pub struct EncryptedPeerBundle(Box<[u8; Self::ENCODED_LEN]>); impl EncryptedPeerBundle { - pub fn from_data(data: [u8; ENCRYPTED_PEER_BUNDLE_LEN]) -> Self { + pub const ENCODED_LEN: usize = PeerBundle::ENCODED_LEN + ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn from_data(data: [u8; Self::ENCODED_LEN]) -> Self { Self(Box::new(data)) } - pub fn as_bytes(&self) -> &[u8; ENCRYPTED_PEER_BUNDLE_LEN] { + pub fn as_bytes(&self) -> &[u8; Self::ENCODED_LEN] { self.0.as_ref() } } @@ -312,10 +312,10 @@ fn encrypt_peer_bundle( bundle: &PeerBundle, ) -> Result { let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode())?; - if ciphertext.len() != ENCRYPTED_PEER_BUNDLE_LEN { + if ciphertext.len() != EncryptedPeerBundle::ENCODED_LEN { return Err(WireError::InvalidState); } - let mut out = [0u8; ENCRYPTED_PEER_BUNDLE_LEN]; + let mut out = [0u8; EncryptedPeerBundle::ENCODED_LEN]; out.copy_from_slice(&ciphertext); Ok(EncryptedPeerBundle::from_data(out)) } @@ -335,10 +335,10 @@ fn encrypt_mlkem_ciphertext( ciphertext: &MlKemCiphertext, ) -> Result { let encrypted = symmetric.encrypt_and_hash(crypto, ciphertext.as_bytes())?; - if encrypted.len() != ENCRYPTED_MLKEM_CIPHERTEXT_LEN { + if encrypted.len() != EncryptedMlKemCiphertext::ENCODED_LEN { return Err(WireError::InvalidState); } - let mut out = [0u8; ENCRYPTED_MLKEM_CIPHERTEXT_LEN]; + let mut out = [0u8; EncryptedMlKemCiphertext::ENCODED_LEN]; out.copy_from_slice(&encrypted); Ok(EncryptedMlKemCiphertext::from_data(out)) } diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index 02ffddfa..37b1e3ed 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -2,8 +2,7 @@ use super::{ decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, finalize_handshake, generate_ephemeral_keypair, initialize_handshake_meta, mix_hash_ephemeral, mix_hash_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EncryptedPeerBundle, - EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, - ENCRYPTED_MLKEM_CIPHERTEXT_LEN, ENCRYPTED_PEER_BUNDLE_LEN, PROTOCOL_XX, + EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, PROTOCOL_XX, }; use crate::{ codec, HandshakeHeader, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, @@ -43,7 +42,7 @@ pub struct Xx2 { impl Xx2 { pub const ENCODED_LEN: usize = - HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + ENCRYPTED_PEER_BUNDLE_LEN; + HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EncryptedPeerBundle::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -74,7 +73,9 @@ pub struct Xx3 { impl Xx3 { pub const ENCODED_LEN: usize = - HandshakeMeta::ENCODED_LEN + ENCRYPTED_MLKEM_CIPHERTEXT_LEN + ENCRYPTED_PEER_BUNDLE_LEN; + HandshakeMeta::ENCODED_LEN + + EncryptedMlKemCiphertext::ENCODED_LEN + + EncryptedPeerBundle::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -103,7 +104,8 @@ pub struct Xx4 { } impl Xx4 { - pub const ENCODED_LEN: usize = HandshakeMeta::ENCODED_LEN + ENCRYPTED_MLKEM_CIPHERTEXT_LEN; + pub const ENCODED_LEN: usize = + HandshakeMeta::ENCODED_LEN + EncryptedMlKemCiphertext::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); From 85f34e1d9c500c905d0d68fa83dca42e2f694dff Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 11:34:10 -0400 Subject: [PATCH 063/304] ql-wire: get rid of capped bytes --- ql-wire/src/bytes.rs | 77 +------------------------------------------- 1 file changed, 1 insertion(+), 76 deletions(-) diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index 88ed03c7..ca839476 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -154,39 +154,12 @@ pub struct RangedByteChunks { pub len: usize, } -pub struct CappedByteChunksIter { - inner: I, - remaining: usize, -} - pub struct RangedByteChunksIter { inner: I, skip: usize, remaining: usize, } -impl<'a, I> Iterator for CappedByteChunksIter -where - I: Iterator, -{ - type Item = &'a [u8]; - - fn next(&mut self) -> Option { - while self.remaining > 0 { - let chunk = self.inner.next()?; - if chunk.is_empty() { - continue; - } - - let len = chunk.len().min(self.remaining); - self.remaining -= len; - return Some(&chunk[..len]); - } - - None - } -} - impl<'a, I> Iterator for RangedByteChunksIter where I: Iterator, @@ -216,24 +189,6 @@ where } } -impl ByteChunks for CappedByteChunks { - type Chunks<'a> - = CappedByteChunksIter> - where - Self: 'a; - - fn len(&self) -> usize { - self.inner.len().min(self.limit) - } - - fn chunks(&self) -> Self::Chunks<'_> { - CappedByteChunksIter { - inner: self.inner.chunks(), - remaining: self.len(), - } - } -} - impl ByteChunks for RangedByteChunks { type Chunks<'a> = RangedByteChunksIter> @@ -257,7 +212,7 @@ impl ByteChunks for RangedByteChunks { mod tests { use std::collections::VecDeque; - use super::{ByteChunks, ByteSlice, ByteSliceMut, CappedByteChunks, RangedByteChunks}; + use super::{ByteChunks, ByteSlice, ByteSliceMut, RangedByteChunks}; #[test] fn shared_slice_split_at() { @@ -310,36 +265,6 @@ mod tests { assert!(chunks.len() >= 1); } - #[test] - fn capped_byte_chunks_truncate_slice() { - let bytes: &[u8] = b"abcdef"; - let capped = CappedByteChunks { - inner: bytes, - limit: 4, - }; - - let chunks = capped.chunks().collect::>(); - assert_eq!(capped.len(), 4); - assert_eq!(chunks, vec![b"abcd".as_slice()]); - } - - #[test] - fn capped_byte_chunks_truncate_borrowed_vec_deque() { - let mut bytes = VecDeque::with_capacity(8); - bytes.extend(b"abcd".iter().copied()); - bytes.drain(..2); - bytes.extend(b"efgh".iter().copied()); - - let capped = CappedByteChunks { - inner: &bytes, - limit: 4, - }; - - let chunks = capped.chunks().collect::>(); - assert_eq!(capped.len(), 4); - assert_eq!(chunks.concat(), b"cdef"); - } - #[test] fn ranged_byte_chunks_slice_middle() { let bytes: &[u8] = b"abcdef"; From 626ae31931f8f841041ef1bbf7580b01a816799d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 12:48:45 -0400 Subject: [PATCH 064/304] ql-wire: remove header from xx.rs --- ql-wire/src/handshake/kk.rs | 99 ++++++++++++++--- ql-wire/src/handshake/mod.rs | 25 +++-- ql-wire/src/handshake/xx.rs | 59 +++------- ql-wire/src/header.rs | 33 +----- ql-wire/src/identity.rs | 10 +- ql-wire/src/record.rs | 25 ++--- ql-wire/src/tests.rs | 205 +++++++++++++++-------------------- 7 files changed, 218 insertions(+), 238 deletions(-) diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 4cffa2dc..44a12ab7 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -1,26 +1,61 @@ use super::{ decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, - mix_hash_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, + mix_hash_kk_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, }; use crate::{ - codec, HandshakeHeader, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, - QlIdentity, WireError, + codec, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, + WireError, XID, }; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct KkHandshakeHeader { + pub sender: XID, + pub recipient: XID, +} + +impl KkHandshakeHeader { + pub const ENCODED_LEN: usize = XID::SIZE * 2; + + pub fn encode_into(&self, out: &mut Vec) { + codec::push_bytes(out, &self.sender.0); + codec::push_bytes(out, &self.recipient.0); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let header = Self::decode_from(&mut reader)?; + reader.finish()?; + Ok(header) + } + + pub fn decode_from( + reader: &mut codec::Reader, + ) -> Result { + Ok(Self { + sender: XID(reader.take_array()?), + recipient: XID(reader.take_array()?), + }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Kk1 { + pub header: KkHandshakeHeader, pub meta: HandshakeMeta, pub skem_ciphertext: MlKemCiphertext, pub ephemeral: EphemeralPublicKey, } impl Kk1 { - pub const ENCODED_LEN: usize = - HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EphemeralPublicKey::ENCODED_LEN; + pub const ENCODED_LEN: usize = KkHandshakeHeader::ENCODED_LEN + + HandshakeMeta::ENCODED_LEN + + MlKemCiphertext::SIZE + + EphemeralPublicKey::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { + self.header.encode_into(out); self.meta.encode_into(out); codec::push_bytes(out, self.skem_ciphertext.as_bytes()); self.ephemeral.encode_into(out); @@ -28,12 +63,14 @@ impl Kk1 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); + let header = KkHandshakeHeader::decode_from(&mut reader)?; let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let ephemeral = EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; reader.finish()?; Ok(Self { + header, meta, skem_ciphertext, ephemeral, @@ -43,16 +80,20 @@ impl Kk1 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Kk2 { + pub header: KkHandshakeHeader, pub meta: HandshakeMeta, pub ekem_ciphertext: MlKemCiphertext, pub skem_ciphertext: EncryptedMlKemCiphertext, } impl Kk2 { - pub const ENCODED_LEN: usize = - HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::ENCODED_LEN; + pub const ENCODED_LEN: usize = KkHandshakeHeader::ENCODED_LEN + + HandshakeMeta::ENCODED_LEN + + MlKemCiphertext::SIZE + + EncryptedMlKemCiphertext::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { + self.header.encode_into(out); self.meta.encode_into(out); codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); codec::push_bytes(out, self.skem_ciphertext.as_bytes()); @@ -60,11 +101,13 @@ impl Kk2 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); + let header = KkHandshakeHeader::decode_from(&mut reader)?; let meta = HandshakeMeta::decode_from(&mut reader)?; let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); reader.finish()?; Ok(Self { + header, meta, ekem_ciphertext, skem_ciphertext, @@ -140,16 +183,38 @@ impl KkHandshake { self.step == KkStep::Done } + fn outbound_header(&self) -> KkHandshakeHeader { + KkHandshakeHeader { + sender: self.local.xid, + recipient: self.remote_bundle.xid, + } + } + + fn inbound_header(&self) -> KkHandshakeHeader { + KkHandshakeHeader { + sender: self.remote_bundle.xid, + recipient: self.local.xid, + } + } + + fn ensure_inbound_header(&self, header: KkHandshakeHeader) -> Result<(), WireError> { + if header == self.inbound_header() { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } + pub fn write_message( &mut self, crypto: &impl QlCrypto, - header: HandshakeHeader, meta: HandshakeMeta, ) -> Result { match self.step { KkStep::Send1 => { initialize_handshake_meta(&mut self.handshake_meta, meta)?; - mix_hash_handshake( + let header = self.outbound_header(); + mix_hash_kk_handshake( &mut self.symmetric, crypto, header, @@ -170,6 +235,7 @@ impl KkHandshake { self.local_ephemeral = Some(local_ephemeral); self.step = KkStep::Recv2; Ok(KkMessage::Message1(Kk1 { + header, meta, skem_ciphertext, ephemeral: public, @@ -177,7 +243,8 @@ impl KkHandshake { } KkStep::Send2 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_handshake( + let header = self.outbound_header(); + mix_hash_kk_handshake( &mut self.symmetric, crypto, header, @@ -202,6 +269,7 @@ impl KkHandshake { self.step = KkStep::Done; Ok(KkMessage::Message2(Kk2 { + header, meta, ekem_ciphertext, skem_ciphertext, @@ -214,7 +282,6 @@ impl KkHandshake { pub fn read_message( &mut self, crypto: &impl QlCrypto, - header: HandshakeHeader, now_seconds: u64, message: &KkMessage, ) -> Result<(), WireError> { @@ -222,10 +289,11 @@ impl KkHandshake { (KkStep::Recv1, KkMessage::Message1(message)) => { message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - mix_hash_handshake( + self.ensure_inbound_header(message.header)?; + mix_hash_kk_handshake( &mut self.symmetric, crypto, - header, + message.header, HandshakeKind::Kk1, &message.meta, ); @@ -244,10 +312,11 @@ impl KkHandshake { (KkStep::Recv2, KkMessage::Message2(message)) => { message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_handshake( + self.ensure_inbound_header(message.header)?; + mix_hash_kk_handshake( &mut self.symmetric, crypto, - header, + message.header, HandshakeKind::Kk2, &message.meta, ); diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 05035407..a1292d8d 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,14 +1,13 @@ use crate::{ - codec, ConnectionId, HandshakeHeader, HandshakeKind, MlKemCiphertext, MlKemKeyPair, - MlKemPublicKey, Nonce, PeerBundle, QlCrypto, SessionKey, WireError, - ENCRYPTED_MESSAGE_AUTH_SIZE, + codec, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, Nonce, + PeerBundle, QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, }; mod kk; mod meta; mod xx; -pub use kk::{Kk1, Kk2, KkHandshake, KkMessage}; +pub use kk::{Kk1, Kk2, KkHandshake, KkHandshakeHeader, KkMessage}; pub use meta::{HandshakeId, HandshakeMeta}; pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake, XxMessage}; @@ -266,14 +265,26 @@ fn mix_hash_ephemeral( symmetric.mix_hash(crypto, public.mlkem_public_key.as_bytes()); } -fn mix_hash_handshake( +fn mix_hash_xx_handshake( symmetric: &mut SymmetricState, crypto: &impl QlCrypto, - header: HandshakeHeader, kind: HandshakeKind, meta: &HandshakeMeta, ) { - let mut encoded_header = Vec::with_capacity(HandshakeHeader::ENCODED_LEN); + let encoded = meta.encode(); + symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); + symmetric.mix_hash(crypto, &[kind as u8]); + symmetric.mix_hash(crypto, &encoded); +} + +fn mix_hash_kk_handshake( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: KkHandshakeHeader, + kind: HandshakeKind, + meta: &HandshakeMeta, +) { + let mut encoded_header = Vec::with_capacity(KkHandshakeHeader::ENCODED_LEN); header.encode_into(&mut encoded_header); let encoded = meta.encode(); symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index 37b1e3ed..5b21197a 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -1,12 +1,12 @@ use super::{ decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, finalize_handshake, generate_ephemeral_keypair, initialize_handshake_meta, mix_hash_ephemeral, - mix_hash_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EncryptedPeerBundle, + mix_hash_xx_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, PROTOCOL_XX, }; use crate::{ - codec, HandshakeHeader, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, - QlIdentity, WireError, + codec, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, + WireError, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -72,10 +72,9 @@ pub struct Xx3 { } impl Xx3 { - pub const ENCODED_LEN: usize = - HandshakeMeta::ENCODED_LEN - + EncryptedMlKemCiphertext::ENCODED_LEN - + EncryptedPeerBundle::ENCODED_LEN; + pub const ENCODED_LEN: usize = HandshakeMeta::ENCODED_LEN + + EncryptedMlKemCiphertext::ENCODED_LEN + + EncryptedPeerBundle::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { self.meta.encode_into(out); @@ -191,19 +190,12 @@ impl XxHandshake { pub fn write_message( &mut self, crypto: &impl QlCrypto, - header: HandshakeHeader, meta: HandshakeMeta, ) -> Result { match self.step { XxStep::Send1 => { initialize_handshake_meta(&mut self.handshake_meta, meta)?; - mix_hash_handshake( - &mut self.symmetric, - crypto, - header, - HandshakeKind::Xx1, - &meta, - ); + mix_hash_xx_handshake(&mut self.symmetric, crypto, HandshakeKind::Xx1, &meta); let local_ephemeral = generate_ephemeral_keypair(crypto); let public = local_ephemeral.public(); mix_hash_ephemeral(&mut self.symmetric, crypto, &public); @@ -216,13 +208,7 @@ impl XxHandshake { } XxStep::Send2 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_handshake( - &mut self.symmetric, - crypto, - header, - HandshakeKind::Xx2, - &meta, - ); + mix_hash_xx_handshake(&mut self.symmetric, crypto, HandshakeKind::Xx2, &meta); let remote_ephemeral = self .remote_ephemeral .clone() @@ -244,13 +230,7 @@ impl XxHandshake { } XxStep::Send3 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_handshake( - &mut self.symmetric, - crypto, - header, - HandshakeKind::Xx3, - &meta, - ); + mix_hash_xx_handshake(&mut self.symmetric, crypto, HandshakeKind::Xx3, &meta); let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); @@ -271,13 +251,7 @@ impl XxHandshake { } XxStep::Send4 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_handshake( - &mut self.symmetric, - crypto, - header, - HandshakeKind::Xx4, - &meta, - ); + mix_hash_xx_handshake(&mut self.symmetric, crypto, HandshakeKind::Xx4, &meta); let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); @@ -299,7 +273,6 @@ impl XxHandshake { pub fn read_message( &mut self, crypto: &impl QlCrypto, - header: HandshakeHeader, now_seconds: u64, message: &XxMessage, ) -> Result<(), WireError> { @@ -307,10 +280,9 @@ impl XxHandshake { (XxStep::Recv1, XxMessage::Message1(message)) => { message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - mix_hash_handshake( + mix_hash_xx_handshake( &mut self.symmetric, crypto, - header, HandshakeKind::Xx1, &message.meta, ); @@ -322,10 +294,9 @@ impl XxHandshake { (XxStep::Recv2, XxMessage::Message2(message)) => { message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_handshake( + mix_hash_xx_handshake( &mut self.symmetric, crypto, - header, HandshakeKind::Xx2, &message.meta, ); @@ -348,10 +319,9 @@ impl XxHandshake { (XxStep::Recv3, XxMessage::Message3(message)) => { message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_handshake( + mix_hash_xx_handshake( &mut self.symmetric, crypto, - header, HandshakeKind::Xx3, &message.meta, ); @@ -374,10 +344,9 @@ impl XxHandshake { (XxStep::Recv4, XxMessage::Message4(message)) => { message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_handshake( + mix_hash_xx_handshake( &mut self.symmetric, crypto, - header, HandshakeKind::Xx4, &message.meta, ); diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index 845697a3..7d194880 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -1,10 +1,4 @@ -use crate::{codec, QL_WIRE_VERSION, XID}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct HandshakeHeader { - pub sender: XID, - pub recipient: XID, -} +use crate::{codec, QL_WIRE_VERSION}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SessionHeader { @@ -32,31 +26,6 @@ impl ConnectionId { } } -impl HandshakeHeader { - pub const ENCODED_LEN: usize = XID::SIZE * 2; - - pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, &self.sender.0); - codec::push_bytes(out, &self.recipient.0); - } - - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let header = Self::decode_from(&mut reader)?; - reader.finish()?; - Ok(header) - } - - pub fn decode_from( - reader: &mut codec::Reader, - ) -> Result { - Ok(Self { - sender: XID(reader.take_array()?), - recipient: XID(reader.take_array()?), - }) - } -} - impl SessionHeader { pub const ENCODED_LEN: usize = ConnectionId::SIZE + core::mem::size_of::(); diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index b4d72355..1f5e2510 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -3,17 +3,21 @@ use crate::{codec, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, Wire #[derive(Debug, Clone, PartialEq, Eq)] pub struct PeerBundle { pub version: u16, + pub xid: XID, pub capabilities: u32, pub mlkem_public_key: MlKemPublicKey, } impl PeerBundle { pub const VERSION: u16 = 1; - pub const ENCODED_LEN: usize = - core::mem::size_of::() + core::mem::size_of::() + MlKemPublicKey::SIZE; + pub const ENCODED_LEN: usize = core::mem::size_of::() + + XID::SIZE + + core::mem::size_of::() + + MlKemPublicKey::SIZE; pub fn encode_into(&self, out: &mut Vec) { codec::push_u16(out, self.version); + codec::push_bytes(out, &self.xid.0); codec::push_u32(out, self.capabilities); codec::push_bytes(out, self.mlkem_public_key.as_bytes()); } @@ -28,6 +32,7 @@ impl PeerBundle { let mut reader = codec::Reader::new(bytes); let bundle = Self { version: reader.take_u16()?, + xid: XID(reader.take_array()?), capabilities: reader.take_u32()?, mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), }; @@ -66,6 +71,7 @@ impl QlIdentity { pub fn bundle(&self) -> PeerBundle { PeerBundle { version: PeerBundle::VERSION, + xid: self.xid, capabilities: self.capabilities, mlkem_public_key: self.mlkem_public_key.clone(), } diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 047ffda2..3f0b4069 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -2,15 +2,9 @@ use crate::{ codec, encrypted_message::EncryptedMessage, handshake::{Kk1, Kk2, Xx1, Xx2, Xx3, Xx4}, - ByteSlice, HandshakeHeader, SessionHeader, WireError, QL_WIRE_VERSION, + ByteSlice, SessionHeader, WireError, QL_WIRE_VERSION, }; -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct QlHandshakeRecord { - pub header: HandshakeHeader, - pub payload: HandshakePayload, -} - #[derive(Debug, Clone, PartialEq, Eq)] pub struct QlSessionRecord { pub header: SessionHeader, @@ -24,7 +18,7 @@ pub enum QlRecord { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum HandshakePayload { +pub enum QlHandshakeRecord { Xx1(Xx1), Xx2(Xx2), Xx3(Xx3), @@ -79,7 +73,7 @@ impl TryFrom for HandshakeKind { } } -impl HandshakePayload { +impl QlHandshakeRecord { pub fn kind(&self) -> HandshakeKind { match self { Self::Xx1(_) => HandshakeKind::Xx1, @@ -102,7 +96,7 @@ impl HandshakePayload { } } - fn decode(kind: HandshakeKind, bytes: &[u8]) -> Result { + fn decode_payload(kind: HandshakeKind, bytes: &[u8]) -> Result { match kind { HandshakeKind::Xx1 => Ok(Self::Xx1(Xx1::decode(bytes)?)), HandshakeKind::Xx2 => Ok(Self::Xx2(Xx2::decode(bytes)?)), @@ -112,16 +106,13 @@ impl HandshakePayload { HandshakeKind::Kk2 => Ok(Self::Kk2(Kk2::decode(bytes)?)), } } -} -impl QlHandshakeRecord { pub fn encode(&self) -> Vec { let mut out = Vec::new(); codec::push_u8(&mut out, QL_WIRE_VERSION); codec::push_u8(&mut out, RecordType::Handshake as u8); - self.header.encode_into(&mut out); - codec::push_u8(&mut out, self.payload.kind() as u8); - self.payload.encode_into(&mut out); + codec::push_u8(&mut out, self.kind() as u8); + self.encode_into(&mut out); out } @@ -218,11 +209,9 @@ impl QlRecord { fn parse_handshake_record(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); - let header = HandshakeHeader::decode_from(&mut reader)?; let kind = HandshakeKind::try_from(reader.take_u8()?)?; let payload = reader.take_rest(); - let payload = HandshakePayload::decode(kind, &payload[..])?; - Ok(QlHandshakeRecord { header, payload }) + QlHandshakeRecord::decode_payload(kind, &payload[..]) } fn parse_session_record(bytes: B) -> Result, WireError> { diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index f8aac3c8..da607b04 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -162,29 +162,27 @@ fn make_identity(crypto: &impl QlCrypto, byte: u8) -> QlIdentity { generate_identity(crypto, xid(byte)) } -fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { - HandshakeHeader { +fn kk_handshake_header(sender: u8, recipient: u8) -> KkHandshakeHeader { + KkHandshakeHeader { sender: xid(sender), recipient: xid(recipient), } } -fn xx_record(header: HandshakeHeader, message: XxMessage) -> QlHandshakeRecord { - let payload = match message { - XxMessage::Message1(message) => HandshakePayload::Xx1(message), - XxMessage::Message2(message) => HandshakePayload::Xx2(message), - XxMessage::Message3(message) => HandshakePayload::Xx3(message), - XxMessage::Message4(message) => HandshakePayload::Xx4(message), - }; - QlHandshakeRecord { header, payload } +fn xx_record(message: XxMessage) -> QlHandshakeRecord { + match message { + XxMessage::Message1(message) => QlHandshakeRecord::Xx1(message), + XxMessage::Message2(message) => QlHandshakeRecord::Xx2(message), + XxMessage::Message3(message) => QlHandshakeRecord::Xx3(message), + XxMessage::Message4(message) => QlHandshakeRecord::Xx4(message), + } } -fn kk_record(header: HandshakeHeader, message: KkMessage) -> QlHandshakeRecord { - let payload = match message { - KkMessage::Message1(message) => HandshakePayload::Kk1(message), - KkMessage::Message2(message) => HandshakePayload::Kk2(message), - }; - QlHandshakeRecord { header, payload } +fn kk_record(message: KkMessage) -> QlHandshakeRecord { + match message { + KkMessage::Message1(message) => QlHandshakeRecord::Kk1(message), + KkMessage::Message2(message) => QlHandshakeRecord::Kk2(message), + } } #[test] @@ -200,28 +198,34 @@ fn peer_bundle_round_trip() { } #[test] -fn handshake_record_round_trip_uses_handshake_header() { - let message = Xx1 { +fn handshake_record_round_trip_supports_xx_and_kk() { + let xx = QlHandshakeRecord::Xx1(Xx1 { meta: handshake_meta(1), ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::from_data([9; MlKemPublicKey::SIZE]), }, - }; - let record = QlHandshakeRecord { - header: HandshakeHeader { - sender: xid(1), - recipient: xid(2), - }, - payload: HandshakePayload::Xx1(message), - }; - - let encoded = record.encode(); - let decoded = QlHandshakeRecord::decode(&encoded).unwrap(); - - assert_eq!(decoded, record); + }); + let xx_encoded = xx.encode(); + assert_eq!(QlHandshakeRecord::decode(&xx_encoded).unwrap(), xx); + assert_eq!( + QlRecord::decode(&xx_encoded).unwrap(), + QlRecord::Handshake(xx) + ); - let decoded = QlRecord::decode(&encoded).unwrap(); - assert_eq!(decoded, QlRecord::Handshake(record)); + let kk = QlHandshakeRecord::Kk1(Kk1 { + header: kk_handshake_header(1, 2), + meta: handshake_meta(2), + skem_ciphertext: MlKemCiphertext::from_data([7; MlKemCiphertext::SIZE]), + ephemeral: EphemeralPublicKey { + mlkem_public_key: MlKemPublicKey::from_data([11; MlKemPublicKey::SIZE]), + }, + }); + let kk_encoded = kk.encode(); + assert_eq!(QlHandshakeRecord::decode(&kk_encoded).unwrap(), kk); + assert_eq!( + QlRecord::decode(&kk_encoded).unwrap(), + QlRecord::Handshake(kk) + ); } #[test] @@ -232,18 +236,14 @@ fn xx_handshake_rejects_tampered_handshake_meta() { let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); let mut responder_state = XxHandshake::new_responder(&crypto, responder); - let initiator_header = handshake_header(1, 2); - let responder_header = handshake_header(2, 1); let m1 = initiator_state - .write_message(&crypto, initiator_header, handshake_meta(77)) - .unwrap(); - responder_state - .read_message(&crypto, initiator_header, 0, &m1) + .write_message(&crypto, handshake_meta(77)) .unwrap(); + responder_state.read_message(&crypto, 0, &m1).unwrap(); let mut m2 = responder_state - .write_message(&crypto, responder_header, handshake_meta(77)) + .write_message(&crypto, handshake_meta(77)) .unwrap(); let XxMessage::Message2(message) = &mut m2 else { panic!("expected xx2"); @@ -251,36 +251,37 @@ fn xx_handshake_rejects_tampered_handshake_meta() { message.meta.handshake_id = HandshakeId(78); assert_eq!( - initiator_state.read_message(&crypto, responder_header, 0, &m2), + initiator_state.read_message(&crypto, 0, &m2), Err(WireError::InvalidPayload) ); } #[test] -fn xx_handshake_rejects_tampered_handshake_header() { +fn kk_handshake_rejects_tampered_handshake_header() { let crypto = TestCrypto::new(10); let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); - let mut responder_state = XxHandshake::new_responder(&crypto, responder); - let initiator_header = handshake_header(1, 2); - let responder_header = handshake_header(2, 1); + let mut initiator_state = + KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); + let mut responder_state = KkHandshake::new_responder(&crypto, responder, initiator.bundle()); let m1 = initiator_state - .write_message(&crypto, initiator_header, handshake_meta(88)) - .unwrap(); - responder_state - .read_message(&crypto, initiator_header, 0, &m1) + .write_message(&crypto, handshake_meta(88)) .unwrap(); + responder_state.read_message(&crypto, 0, &m1).unwrap(); - let m2 = responder_state - .write_message(&crypto, responder_header, handshake_meta(88)) + let mut m2 = responder_state + .write_message(&crypto, handshake_meta(88)) .unwrap(); + let KkMessage::Message2(message) = &mut m2 else { + panic!("expected kk2"); + }; + message.header = kk_handshake_header(9, 1); assert_eq!( - initiator_state.read_message(&crypto, handshake_header(9, 1), 0, &m2), - Err(WireError::DecryptFailed) + initiator_state.read_message(&crypto, 0, &m2), + Err(WireError::InvalidPayload) ); } @@ -292,12 +293,10 @@ fn xx_handshake_rejects_expired_message() { let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); let mut responder_state = XxHandshake::new_responder(&crypto, responder); - let initiator_header = handshake_header(1, 2); let m1 = initiator_state .write_message( &crypto, - initiator_header, HandshakeMeta { handshake_id: HandshakeId(90), valid_until: 5, @@ -306,7 +305,7 @@ fn xx_handshake_rejects_expired_message() { .unwrap(); assert_eq!( - responder_state.read_message(&crypto, initiator_header, 6, &m1), + responder_state.read_message(&crypto, 6, &m1), Err(WireError::Expired) ); } @@ -319,36 +318,26 @@ fn xx_handshake_round_trip_derives_matching_transport() { let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator.clone()); let mut responder_state = XxHandshake::new_responder(&crypto, responder.clone()); - let initiator_header = handshake_header(1, 2); - let responder_header = handshake_header(2, 1); let m1 = initiator_state - .write_message(&crypto, initiator_header, handshake_meta(1)) - .unwrap(); - responder_state - .read_message(&crypto, initiator_header, 0, &m1) + .write_message(&crypto, handshake_meta(1)) .unwrap(); + responder_state.read_message(&crypto, 0, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, responder_header, handshake_meta(1)) - .unwrap(); - initiator_state - .read_message(&crypto, responder_header, 0, &m2) + .write_message(&crypto, handshake_meta(1)) .unwrap(); + initiator_state.read_message(&crypto, 0, &m2).unwrap(); let m3 = initiator_state - .write_message(&crypto, initiator_header, handshake_meta(1)) - .unwrap(); - responder_state - .read_message(&crypto, initiator_header, 0, &m3) + .write_message(&crypto, handshake_meta(1)) .unwrap(); + responder_state.read_message(&crypto, 0, &m3).unwrap(); let m4 = responder_state - .write_message(&crypto, responder_header, handshake_meta(1)) - .unwrap(); - initiator_state - .read_message(&crypto, responder_header, 0, &m4) + .write_message(&crypto, handshake_meta(1)) .unwrap(); + initiator_state.read_message(&crypto, 0, &m4).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -381,22 +370,16 @@ fn kk_handshake_round_trip_derives_matching_transport() { KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); let mut responder_state = KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); - let initiator_header = handshake_header(3, 4); - let responder_header = handshake_header(4, 3); let m1 = initiator_state - .write_message(&crypto, initiator_header, handshake_meta(11)) - .unwrap(); - responder_state - .read_message(&crypto, initiator_header, 0, &m1) + .write_message(&crypto, handshake_meta(11)) .unwrap(); + responder_state.read_message(&crypto, 0, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, responder_header, handshake_meta(11)) - .unwrap(); - initiator_state - .read_message(&crypto, responder_header, 0, &m2) + .write_message(&crypto, handshake_meta(11)) .unwrap(); + initiator_state.read_message(&crypto, 0, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -504,65 +487,49 @@ fn protocol_record_size_breakdown() { let mut xx_initiator = XxHandshake::new_initiator(&crypto, initiator.clone()); let mut xx_responder = XxHandshake::new_responder(&crypto, responder.clone()); - let xx_initiator_header = handshake_header(1, 2); - let xx_responder_header = handshake_header(2, 1); let xx1 = xx_initiator - .write_message(&crypto, xx_initiator_header, handshake_meta(101)) - .unwrap(); - xx_responder - .read_message(&crypto, xx_initiator_header, 0, &xx1) + .write_message(&crypto, handshake_meta(101)) .unwrap(); + xx_responder.read_message(&crypto, 0, &xx1).unwrap(); let xx2 = xx_responder - .write_message(&crypto, xx_responder_header, handshake_meta(101)) - .unwrap(); - xx_initiator - .read_message(&crypto, xx_responder_header, 0, &xx2) + .write_message(&crypto, handshake_meta(101)) .unwrap(); + xx_initiator.read_message(&crypto, 0, &xx2).unwrap(); let xx3 = xx_initiator - .write_message(&crypto, xx_initiator_header, handshake_meta(101)) - .unwrap(); - xx_responder - .read_message(&crypto, xx_initiator_header, 0, &xx3) + .write_message(&crypto, handshake_meta(101)) .unwrap(); + xx_responder.read_message(&crypto, 0, &xx3).unwrap(); let xx4 = xx_responder - .write_message(&crypto, xx_responder_header, handshake_meta(101)) - .unwrap(); - xx_initiator - .read_message(&crypto, xx_responder_header, 0, &xx4) + .write_message(&crypto, handshake_meta(101)) .unwrap(); + xx_initiator.read_message(&crypto, 0, &xx4).unwrap(); - let xx1 = xx_record(handshake_header(1, 2), xx1); - let xx2 = xx_record(handshake_header(2, 1), xx2); - let xx3 = xx_record(handshake_header(1, 2), xx3); - let xx4 = xx_record(handshake_header(2, 1), xx4); + let xx1 = xx_record(xx1); + let xx2 = xx_record(xx2); + let xx3 = xx_record(xx3); + let xx4 = xx_record(xx4); let mut kk_initiator = KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); let mut kk_responder = KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); - let kk_initiator_header = handshake_header(1, 2); - let kk_responder_header = handshake_header(2, 1); let kk1 = kk_initiator - .write_message(&crypto, kk_initiator_header, handshake_meta(201)) - .unwrap(); - kk_responder - .read_message(&crypto, kk_initiator_header, 0, &kk1) + .write_message(&crypto, handshake_meta(201)) .unwrap(); + kk_responder.read_message(&crypto, 0, &kk1).unwrap(); let kk2 = kk_responder - .write_message(&crypto, kk_responder_header, handshake_meta(201)) - .unwrap(); - kk_initiator - .read_message(&crypto, kk_responder_header, 0, &kk2) + .write_message(&crypto, handshake_meta(201)) .unwrap(); + kk_initiator.read_message(&crypto, 0, &kk2).unwrap(); - let kk1 = kk_record(handshake_header(1, 2), kk1); - let kk2 = kk_record(handshake_header(2, 1), kk2); + let kk1 = kk_record(kk1); + let kk2 = kk_record(kk2); let session = xx_initiator.finalize(&crypto).unwrap(); let session_ping = encrypted::encrypt_record( From 85ec6f7645a4f7e6554cd3a738c20f96c7890713 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 14:45:11 -0400 Subject: [PATCH 065/304] ql-wire: different handshake headers --- ql-fsm/src/implementation/fsm.rs | 22 +-- ql-fsm/src/implementation/handshake/kk.rs | 114 ++++++------- ql-fsm/src/implementation/handshake/mod.rs | 183 +++++---------------- ql-fsm/src/implementation/handshake/xx.rs | 179 +++++++------------- ql-fsm/src/implementation/mod.rs | 38 ++--- ql-fsm/src/lib.rs | 20 +-- ql-fsm/src/replay_cache.rs | 26 +-- ql-fsm/src/session/mod.rs | 11 +- ql-fsm/src/session/state.rs | 3 +- ql-fsm/src/session/stream_tx.rs | 9 +- ql-fsm/src/session/tests.rs | 24 ++- ql-fsm/src/state.rs | 85 +++++----- ql-fsm/src/tests/handshake.rs | 88 ++-------- ql-fsm/src/tests/mod.rs | 52 +++--- ql-fsm/src/tests/session.rs | 94 +++++------ 15 files changed, 337 insertions(+), 611 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 5bcfcb5b..56877f1f 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -73,10 +73,9 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Result<(), QlFsmError> { fn ensure_session_open(fsm: &QlFsm) -> Result<(), QlFsmError> { ensure_peer_bound(fsm)?; - if fsm - .peer - .as_ref() - .and_then(|entry| entry.session.transport()) - .is_none() - { + if fsm.state.link.transport().is_none() { return Err(QlFsmError::SessionClosed); } Ok(()) diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 3b2162fe..d22a9b17 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -1,46 +1,35 @@ use ql_wire::{ - self as wire, HandshakeHeader, HandshakePayload, Kk1, Kk2, KkMessage, PeerBundle, QlCrypto, - WireError, XID, + self as wire, Kk1, Kk2, KkMessage, PeerBundle, QlCrypto, QlHandshakeRecord, WireError, }; use super::{ - ensure_bound_peer, ensure_bound_peer_with_bundle, finish_handshake, - reset_connected_session_if_needed, should_ignore_inbound_handshake_start, + enqueue_handshake, finish_handshake, is_replayed_handshake_start, + reset_connected_session_if_needed, }; use crate::{ - implementation::{emit_peer_status, enqueue_handshake, is_replayed_handshake_start}, - state::{ConnectionState, HandshakeMode, HandshakeState, SessionTransport}, + implementation::emit_peer_status, + state::{LinkState, SessionTransport}, QlFsm, QlFsmError, }; pub fn start_initiator( fsm: &mut QlFsm, crypto: &impl QlCrypto, - peer: XID, - bundle: PeerBundle, + peer: PeerBundle, ) -> Result<(), QlFsmError> { - let header = HandshakeHeader { - sender: fsm.identity.xid, - recipient: peer, - }; let meta = super::next_handshake_meta(fsm); - let mut handshake = wire::KkHandshake::new_initiator(crypto, fsm.identity.clone(), bundle); - let message = handshake.write_message(crypto, header, meta)?; - let payload = kk_payload(message); - let initial_ephemeral = match &payload { - HandshakePayload::Kk1(message) => Some(message.ephemeral.clone()), - _ => None, + let mut handshake = wire::KkHandshake::new_initiator(crypto, fsm.identity.clone(), peer); + let message = handshake.write_message(crypto, meta)?; + let KkMessage::Message1(message) = message else { + return Err(QlFsmError::InvalidPayload); }; - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Handshaking(HandshakeState { - id: meta.handshake_id, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, - mode: HandshakeMode::KkInitiator(handshake), - initial_ephemeral, - }); - } - enqueue_handshake(fsm, peer, payload); + fsm.state.link = LinkState::KkInitiator { + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + }; + enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); emit_peer_status(fsm); Ok(()) } @@ -48,68 +37,54 @@ pub fn start_initiator( pub fn handle_kk1( fsm: &mut QlFsm, crypto: &impl QlCrypto, - header: HandshakeHeader, message: &Kk1, ) -> Result<(), QlFsmError> { - if should_ignore_inbound_handshake_start(fsm, header.sender, false, &message.ephemeral) { + if should_ignore_inbound(fsm, message) { return Ok(()); } - - if is_replayed_handshake_start(fsm, header.sender, message.meta) { + if is_replayed_handshake_start(fsm, message.meta) { return Ok(()); } - ensure_bound_peer_with_bundle(fsm, header.sender)?; + + let Some(peer) = fsm.peer.clone() else { + return Err(QlFsmError::InvalidPayload); + }; + if message.header.recipient != fsm.identity.xid || message.header.sender != peer.xid { + return Err(QlFsmError::InvalidXid); + } + reset_connected_session_if_needed(fsm); - let bundle = fsm - .peer - .as_ref() - .and_then(|entry| entry.peer.bundle.clone()) - .ok_or(QlFsmError::NoPeerBound)?; - let mut handshake = wire::KkHandshake::new_responder(crypto, fsm.identity.clone(), bundle); + let mut handshake = wire::KkHandshake::new_responder(crypto, fsm.identity.clone(), peer); handshake.read_message( crypto, - header, fsm.state.now.unix_secs, &KkMessage::Message1(message.clone()), )?; - let outbound = handshake.write_message( - crypto, - HandshakeHeader { - sender: fsm.identity.xid, - recipient: header.sender, - }, - message.meta, - )?; + let outbound = handshake.write_message(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); finish_handshake(fsm, transport, remote_bundle)?; fsm.state.handshake = None; - enqueue_handshake(fsm, header.sender, kk_payload(outbound)); + enqueue_handshake(fsm, kk_record(outbound)); Ok(()) } pub fn handle_kk2( fsm: &mut QlFsm, crypto: &impl QlCrypto, - header: HandshakeHeader, message: &Kk2, ) -> Result<(), QlFsmError> { - ensure_bound_peer(fsm, header.sender)?; - let session = match fsm.peer.as_ref() { - Some(entry) => entry.session.clone(), - None => return Ok(()), - }; - let ConnectionState::Handshaking(HandshakeState { - mode: HandshakeMode::KkInitiator(mut handshake), - .. - }) = session + let LinkState::KkInitiator { + mut handshake, + deadline: _, + initial_ephemeral: _, + } = fsm.state.link.clone() else { return Ok(()); }; match handshake.read_message( crypto, - header, fsm.state.now.unix_secs, &KkMessage::Message2(message.clone()), ) { @@ -122,9 +97,24 @@ pub fn handle_kk2( finish_handshake(fsm, transport, remote_bundle) } -fn kk_payload(message: KkMessage) -> HandshakePayload { +fn kk_record(message: KkMessage) -> QlHandshakeRecord { match message { - KkMessage::Message1(message) => HandshakePayload::Kk1(message), - KkMessage::Message2(message) => HandshakePayload::Kk2(message), + KkMessage::Message1(message) => QlHandshakeRecord::Kk1(message), + KkMessage::Message2(message) => QlHandshakeRecord::Kk2(message), + } +} + +pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { + match &fsm.state.link { + LinkState::Idle | LinkState::Connected(_) => false, + LinkState::XxInitiator { .. } | LinkState::XxResponder { .. } => true, + LinkState::KkInitiator { + initial_ephemeral, .. + } => { + if fsm.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { + return false; + } + super::local_start_wins(initial_ephemeral, &message.ephemeral) + } } } diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index a1e500c4..fd694ce0 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -1,40 +1,29 @@ mod kk; mod xx; -use std::cmp::Ordering; - -use ql_wire::{ - self as wire, EphemeralPublicKey, HandshakeHeader, HandshakeMeta, HandshakePayload, QlCrypto, - QlHandshakeRecord, XID, -}; +use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHandshakeRecord}; use super::{emit_peer_status, fail_pending_connect_session, reset_session}; use crate::{ - state::{ConnectionState, HandshakeMode, HandshakeState, SessionTransport}, - Peer, QlFsm, QlFsmError, + state::{LinkState, SessionTransport}, + QlFsm, QlFsmError, QlFsmEvent, }; pub fn handle_connect(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - let Some(peer) = fsm.peer.as_ref().map(|entry| entry.peer.clone()) else { - return Err(QlFsmError::NoPeerBound); - }; - if !matches!( - fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Disconnected) - ) { + if !matches!(fsm.state.link, LinkState::Idle) { return Ok(()); } - match &peer.bundle { - Some(bundle) => kk::start_initiator(fsm, crypto, peer.xid, bundle.clone()), - None => xx::start_initiator(fsm, crypto, peer.xid), + match fsm.peer.clone() { + Some(peer) => kk::start_initiator(fsm, crypto, peer), + None => xx::start_initiator(fsm, crypto), } } -pub fn next_handshake_meta(fsm: &mut QlFsm) -> wire::HandshakeMeta { +pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { let handshake_id = wire::HandshakeId(fsm.state.next_control_id); fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); - wire::HandshakeMeta { + HandshakeMeta { handshake_id, valid_until: super::deadline_after_secs( fsm.state.now.unix_secs, @@ -43,21 +32,15 @@ pub fn next_handshake_meta(fsm: &mut QlFsm) -> wire::HandshakeMeta { } } -pub fn enqueue_handshake(fsm: &mut QlFsm, peer: XID, payload: HandshakePayload) { +pub fn enqueue_handshake(fsm: &mut QlFsm, record: QlHandshakeRecord) { debug_assert!(fsm.state.handshake.is_none()); - fsm.state.handshake = Some(QlHandshakeRecord { - header: HandshakeHeader { - sender: fsm.identity.xid, - recipient: peer, - }, - payload, - }); + fsm.state.handshake = Some(record); } -pub fn is_replayed_handshake_start(fsm: &mut QlFsm, peer: XID, meta: HandshakeMeta) -> bool { +pub fn is_replayed_handshake_start(fsm: &mut QlFsm, meta: HandshakeMeta) -> bool { fsm.state .replay_cache - .check_and_store_valid_until(peer, meta, fsm.state.now.unix_secs) + .check_and_store_valid_until(meta, fsm.state.now.unix_secs) } pub fn handle_handshake_record( @@ -65,143 +48,63 @@ pub fn handle_handshake_record( crypto: &impl QlCrypto, record: &QlHandshakeRecord, ) -> Result<(), QlFsmError> { - if record.header.recipient != fsm.identity.xid { - return Err(QlFsmError::InvalidXid); - } - - match &record.payload { - HandshakePayload::Xx1(message) => xx::handle_xx1(fsm, crypto, record.header, message), - HandshakePayload::Xx2(message) => xx::handle_xx2(fsm, crypto, record.header, message), - HandshakePayload::Xx3(message) => xx::handle_xx3(fsm, crypto, record.header, message), - HandshakePayload::Xx4(message) => xx::handle_xx4(fsm, crypto, record.header, message), - HandshakePayload::Kk1(message) => kk::handle_kk1(fsm, crypto, record.header, message), - HandshakePayload::Kk2(message) => kk::handle_kk2(fsm, crypto, record.header, message), + match record { + QlHandshakeRecord::Xx1(message) => xx::handle_xx1(fsm, crypto, message), + QlHandshakeRecord::Xx2(message) => xx::handle_xx2(fsm, crypto, message), + QlHandshakeRecord::Xx3(message) => xx::handle_xx3(fsm, crypto, message), + QlHandshakeRecord::Xx4(message) => xx::handle_xx4(fsm, crypto, message), + QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message), + QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message), } } pub fn handle_timer(fsm: &mut QlFsm) { - let timed_out = matches!( - fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Handshaking(HandshakeState { deadline, .. })) - if *deadline <= fsm.state.now.instant - ); - - if !timed_out { + let Some(deadline) = fsm.state.link.handshake_deadline() else { + return; + }; + if deadline > fsm.state.now.instant { return; } - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Disconnected; - } + fsm.state.link = LinkState::Idle; fsm.state.handshake = None; fail_pending_connect_session(fsm, ql_wire::SessionCloseCode::TIMEOUT); emit_peer_status(fsm); } pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { - match fsm.peer.as_ref().map(|entry| &entry.session) { - Some(ConnectionState::Handshaking(HandshakeState { deadline, .. })) => Some(*deadline), - _ => None, - } + fsm.state.link.handshake_deadline() } -fn ensure_or_bind_peer( - fsm: &mut QlFsm, - xid: XID, - bundle: Option, -) -> Result<(), QlFsmError> { - match fsm.peer.as_ref() { - Some(entry) if entry.peer.xid == xid => Ok(()), - Some(_) => Err(QlFsmError::InvalidXid), - None => { - super::handle_bind_peer(fsm, Peer { xid, bundle }); - Ok(()) - } - } -} - -fn ensure_bound_peer(fsm: &QlFsm, xid: XID) -> Result<(), QlFsmError> { - match fsm.peer.as_ref() { - Some(entry) if entry.peer.xid == xid => Ok(()), - Some(_) => Err(QlFsmError::InvalidXid), - None => Ok(()), - } -} - -fn ensure_bound_peer_with_bundle(fsm: &QlFsm, xid: XID) -> Result<(), QlFsmError> { - match fsm.peer.as_ref() { - Some(entry) if entry.peer.xid == xid && entry.peer.bundle.is_some() => Ok(()), - Some(entry) if entry.peer.xid == xid => Err(QlFsmError::InvalidPayload), - Some(_) => Err(QlFsmError::InvalidXid), - None => Err(QlFsmError::NoPeerBound), - } -} - -fn finish_handshake( +pub fn finish_handshake( fsm: &mut QlFsm, transport: SessionTransport, remote_bundle: wire::PeerBundle, ) -> Result<(), QlFsmError> { - let Some(entry) = fsm.peer.as_mut() else { - return Err(QlFsmError::NoPeerBound); - }; - - match &entry.peer.bundle { - Some(existing) if existing != &remote_bundle => return Err(QlFsmError::InvalidPayload), - Some(_) => {} - None => entry.peer.bundle = Some(remote_bundle), + if let Some(peer) = fsm.peer.as_ref() { + if peer != &remote_bundle { + return Err(QlFsmError::InvalidPayload); + } + } else { + fsm.peer = Some(remote_bundle.clone()); + reset_session(fsm); + fsm.state + .events + .push_back(QlFsmEvent::NewPeer(remote_bundle.clone())); } - entry.session = ConnectionState::Connected(transport); + fsm.state.link = LinkState::Connected(transport); emit_peer_status(fsm); Ok(()) } -fn reset_connected_session_if_needed(fsm: &mut QlFsm) { - if matches!( - fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - ) { +pub fn reset_connected_session_if_needed(fsm: &mut QlFsm) { + if matches!(fsm.state.link, LinkState::Connected(_)) { + fsm.state.link = LinkState::Idle; reset_session(fsm); } } -fn should_ignore_inbound_handshake_start( - fsm: &QlFsm, - sender: XID, - inbound_xx: bool, - inbound_ephemeral: &EphemeralPublicKey, -) -> bool { - let Some(entry) = fsm.peer.as_ref() else { - return false; - }; - if entry.peer.xid != sender { - return false; - } - - let ConnectionState::Handshaking(HandshakeState { - mode, - initial_ephemeral: Some(local_ephemeral), - .. - }) = &entry.session - else { - return false; - }; - - match (mode, inbound_xx) { - (HandshakeMode::KkInitiator(_), true) => false, - (HandshakeMode::XxInitiator(_), false) => true, - (HandshakeMode::XxInitiator(_), true) | (HandshakeMode::KkInitiator(_), false) => { - match inbound_ephemeral - .mlkem_public_key - .as_bytes() - .cmp(local_ephemeral.mlkem_public_key.as_bytes()) - { - Ordering::Less => false, - Ordering::Greater => true, - Ordering::Equal => sender.0.cmp(&fsm.identity.xid.0) != Ordering::Less, - } - } - _ => false, - } +fn local_start_wins(local: &EphemeralPublicKey, inbound: &EphemeralPublicKey) -> bool { + local.mlkem_public_key.as_bytes() <= inbound.mlkem_public_key.as_bytes() } diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/implementation/handshake/xx.rs index c5c66400..659d6fda 100644 --- a/ql-fsm/src/implementation/handshake/xx.rs +++ b/ql-fsm/src/implementation/handshake/xx.rs @@ -1,45 +1,31 @@ use ql_wire::{ - self as wire, HandshakeHeader, HandshakePayload, QlCrypto, WireError, Xx1, Xx2, Xx3, Xx4, - XxMessage, XID, + self as wire, QlCrypto, QlHandshakeRecord, WireError, Xx1, Xx2, Xx3, Xx4, XxMessage, }; use super::{ - ensure_bound_peer, ensure_or_bind_peer, finish_handshake, reset_connected_session_if_needed, - should_ignore_inbound_handshake_start, + enqueue_handshake, finish_handshake, is_replayed_handshake_start, + reset_connected_session_if_needed, }; use crate::{ - implementation::{emit_peer_status, enqueue_handshake, is_replayed_handshake_start}, - state::{ConnectionState, HandshakeMode, HandshakeState, SessionTransport}, + implementation::emit_peer_status, + state::{LinkState, SessionTransport}, QlFsm, QlFsmError, }; -pub fn start_initiator( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - peer: XID, -) -> Result<(), QlFsmError> { - let header = HandshakeHeader { - sender: fsm.identity.xid, - recipient: peer, - }; +pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::XxHandshake::new_initiator(crypto, fsm.identity.clone()); - let message = handshake.write_message(crypto, header, meta)?; - let payload = xx_payload(message); - let initial_ephemeral = match &payload { - HandshakePayload::Xx1(message) => Some(message.ephemeral.clone()), - _ => None, + let message = handshake.write_message(crypto, meta)?; + let XxMessage::Message1(message) = message else { + return Err(QlFsmError::InvalidPayload); }; - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Handshaking(HandshakeState { - id: meta.handshake_id, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, - mode: HandshakeMode::XxInitiator(handshake), - initial_ephemeral, - }); - } - enqueue_handshake(fsm, peer, payload); + fsm.state.link = LinkState::XxInitiator { + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + }; + enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); emit_peer_status(fsm); Ok(()) } @@ -47,45 +33,31 @@ pub fn start_initiator( pub fn handle_xx1( fsm: &mut QlFsm, crypto: &impl QlCrypto, - header: HandshakeHeader, message: &Xx1, ) -> Result<(), QlFsmError> { - if should_ignore_inbound_handshake_start(fsm, header.sender, true, &message.ephemeral) { + if should_ignore_inbound(fsm, message) { return Ok(()); } - - if is_replayed_handshake_start(fsm, header.sender, message.meta) { + if is_replayed_handshake_start(fsm, message.meta) { return Ok(()); } - ensure_or_bind_peer(fsm, header.sender, None)?; + reset_connected_session_if_needed(fsm); let mut handshake = wire::XxHandshake::new_responder(crypto, fsm.identity.clone()); handshake.read_message( crypto, - header, fsm.state.now.unix_secs, &XxMessage::Message1(message.clone()), )?; - let outbound = handshake.write_message( - crypto, - HandshakeHeader { - sender: fsm.identity.xid, - recipient: header.sender, - }, - message.meta, - )?; + let outbound = handshake.write_message(crypto, message.meta)?; - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Handshaking(HandshakeState { - id: message.meta.handshake_id, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, - mode: HandshakeMode::XxResponder(handshake), - initial_ephemeral: None, - }); - } fsm.state.handshake = None; - enqueue_handshake(fsm, header.sender, xx_payload(outbound)); + fsm.state.link = LinkState::XxResponder { + handshake, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + }; + enqueue_handshake(fsm, xx_record(outbound)); emit_peer_status(fsm); Ok(()) } @@ -93,27 +65,19 @@ pub fn handle_xx1( pub fn handle_xx2( fsm: &mut QlFsm, crypto: &impl QlCrypto, - header: HandshakeHeader, message: &Xx2, ) -> Result<(), QlFsmError> { - ensure_bound_peer(fsm, header.sender)?; - let session = match fsm.peer.as_ref() { - Some(entry) => entry.session.clone(), - None => return Ok(()), - }; - let ConnectionState::Handshaking(HandshakeState { - id, + let LinkState::XxInitiator { + mut handshake, deadline, - mode: HandshakeMode::XxInitiator(mut handshake), initial_ephemeral, - }) = session + } = fsm.state.link.clone() else { return Ok(()); }; match handshake.read_message( crypto, - header, fsm.state.now.unix_secs, &XxMessage::Message2(message.clone()), ) { @@ -122,48 +86,32 @@ pub fn handle_xx2( Err(error) => return Err(error.into()), } - let outbound = handshake.write_message( - crypto, - HandshakeHeader { - sender: fsm.identity.xid, - recipient: header.sender, - }, - message.meta, - )?; - if let Some(entry) = fsm.peer.as_mut() { - entry.session = ConnectionState::Handshaking(HandshakeState { - id, - deadline, - mode: HandshakeMode::XxInitiator(handshake), - initial_ephemeral, - }); - } - enqueue_handshake(fsm, header.sender, xx_payload(outbound)); + let outbound = handshake.write_message(crypto, message.meta)?; + fsm.state.handshake = None; + fsm.state.link = LinkState::XxInitiator { + handshake, + deadline, + initial_ephemeral, + }; + enqueue_handshake(fsm, xx_record(outbound)); Ok(()) } pub fn handle_xx3( fsm: &mut QlFsm, crypto: &impl QlCrypto, - header: HandshakeHeader, message: &Xx3, ) -> Result<(), QlFsmError> { - ensure_bound_peer(fsm, header.sender)?; - let session = match fsm.peer.as_ref() { - Some(entry) => entry.session.clone(), - None => return Ok(()), - }; - let ConnectionState::Handshaking(HandshakeState { - mode: HandshakeMode::XxResponder(mut handshake), - .. - }) = session + let LinkState::XxResponder { + mut handshake, + deadline: _, + } = fsm.state.link.clone() else { return Ok(()); }; match handshake.read_message( crypto, - header, fsm.state.now.unix_secs, &XxMessage::Message3(message.clone()), ) { @@ -172,42 +120,30 @@ pub fn handle_xx3( Err(error) => return Err(error.into()), } - let outbound = handshake.write_message( - crypto, - HandshakeHeader { - sender: fsm.identity.xid, - recipient: header.sender, - }, - message.meta, - )?; + let outbound = handshake.write_message(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); finish_handshake(fsm, transport, remote_bundle)?; - enqueue_handshake(fsm, header.sender, xx_payload(outbound)); + fsm.state.handshake = None; + enqueue_handshake(fsm, xx_record(outbound)); Ok(()) } pub fn handle_xx4( fsm: &mut QlFsm, crypto: &impl QlCrypto, - header: HandshakeHeader, message: &Xx4, ) -> Result<(), QlFsmError> { - ensure_bound_peer(fsm, header.sender)?; - let session = match fsm.peer.as_ref() { - Some(entry) => entry.session.clone(), - None => return Ok(()), - }; - let ConnectionState::Handshaking(HandshakeState { - mode: HandshakeMode::XxInitiator(mut handshake), - .. - }) = session + let LinkState::XxInitiator { + mut handshake, + deadline: _, + initial_ephemeral: _, + } = fsm.state.link.clone() else { return Ok(()); }; match handshake.read_message( crypto, - header, fsm.state.now.unix_secs, &XxMessage::Message4(message.clone()), ) { @@ -220,11 +156,22 @@ pub fn handle_xx4( finish_handshake(fsm, transport, remote_bundle) } -fn xx_payload(message: XxMessage) -> HandshakePayload { +fn xx_record(message: XxMessage) -> QlHandshakeRecord { match message { - XxMessage::Message1(message) => HandshakePayload::Xx1(message), - XxMessage::Message2(message) => HandshakePayload::Xx2(message), - XxMessage::Message3(message) => HandshakePayload::Xx3(message), - XxMessage::Message4(message) => HandshakePayload::Xx4(message), + XxMessage::Message1(message) => QlHandshakeRecord::Xx1(message), + XxMessage::Message2(message) => QlHandshakeRecord::Xx2(message), + XxMessage::Message3(message) => QlHandshakeRecord::Xx3(message), + XxMessage::Message4(message) => QlHandshakeRecord::Xx4(message), + } +} + +pub fn should_ignore_inbound(fsm: &QlFsm, message: &Xx1) -> bool { + match &fsm.state.link { + LinkState::Idle | LinkState::Connected(_) => false, + LinkState::XxResponder { .. } => true, + LinkState::KkInitiator { .. } => false, + LinkState::XxInitiator { + initial_ephemeral, .. + } => super::local_start_wins(initial_ephemeral, &message.ephemeral), } } diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index d5131014..16912086 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -8,31 +8,31 @@ pub use handshake::*; use ql_wire::XID; use crate::{ - state::PeerRecord, session::{state::StreamParity, SessionEvent, SessionFsmConfig}, - Peer, QlFsm, QlFsmEvent, QlSessionEvent, + state::LinkState, + QlFsm, QlFsmEvent, QlSessionEvent, }; fn emit_peer_status(fsm: &mut QlFsm) { - if let Some(entry) = fsm.peer.as_ref() { + if let Some(peer) = fsm.peer.as_ref() { fsm.state.events.push_back(QlFsmEvent::PeerStatusChanged { - peer: entry.peer.xid, - status: entry.session.status(), + peer: peer.xid, + status: fsm.state.link.status(), }); } } fn peer_transport(fsm: &QlFsm) -> Option<(XID, crate::state::SessionTransport)> { - let entry = fsm.peer.as_ref()?; - let transport = entry.session.transport()?.clone(); - Some((entry.peer.xid, transport)) + let peer = fsm.peer.as_ref()?; + let transport = fsm.state.link.transport()?.clone(); + Some((peer.xid, transport)) } fn reset_session(fsm: &mut QlFsm) { let local_parity = fsm .peer .as_ref() - .map(|peer| StreamParity::for_local(fsm.identity.xid, peer.peer.xid)) + .map(|peer| StreamParity::for_local(fsm.identity.xid, peer.xid)) .unwrap_or(StreamParity::Even); fsm.session = crate::session::SessionFsm::new( SessionFsmConfig { @@ -49,9 +49,10 @@ fn reset_session(fsm: &mut QlFsm) { ); } -pub fn handle_bind_peer(fsm: &mut QlFsm, peer: Peer) { +pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { fsm.state.handshake = None; - fsm.peer = Some(PeerRecord::new(peer.clone())); + fsm.state.link = LinkState::Idle; + fsm.peer = Some(peer.clone()); reset_session(fsm); fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); emit_peer_status(fsm); @@ -64,7 +65,9 @@ fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::SessionCloseCode reset_session(fsm); fsm.state .session_events - .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { code })); + .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { + code, + })); } fn forward_session_event( @@ -104,14 +107,9 @@ fn forward_session_event( } fn apply_session_closed(fsm: &mut QlFsm) { - if let Some(entry) = fsm.peer.as_mut() { - if matches!( - entry.session, - crate::state::ConnectionState::Connected { .. } - ) { - entry.session = crate::state::ConnectionState::Disconnected; - emit_peer_status(fsm); - } + if matches!(fsm.state.link, crate::state::LinkState::Connected(_)) { + fsm.state.link = crate::state::LinkState::Idle; + emit_peer_status(fsm); } reset_session(fsm); } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 22ccd081..5f8f547e 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -37,7 +37,7 @@ pub use session::stream_rx::StreamReadIter; use crate::{ replay_cache::ReplayCache, session::SessionFsm, - state::{PeerRecord, QlFsmState}, + state::{LinkState, QlFsmState}, }; /// time input for `QlFsm` @@ -49,15 +49,6 @@ pub struct FsmTime { pub unix_secs: u64, } -/// bound remote peer identity and public keys -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Peer { - /// peer xid - pub xid: XID, - /// peer static bundle when known - pub bundle: Option, -} - /// connection state for the bound peer #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PeerStatus { @@ -75,7 +66,7 @@ pub enum PeerStatus { #[derive(Debug, Clone)] pub enum QlFsmEvent { /// a peer was bound or replaced - NewPeer(Peer), + NewPeer(PeerBundle), /// the bound peer was cleared ClearPeer, /// the peer changed connection state @@ -157,13 +148,13 @@ impl Default for QlFsmConfig { } } -/// synchronous driver for pairing, handshake, and encrypted streams +/// synchronous driver for peer binding, handshake, and encrypted streams pub struct QlFsm { /// active configuration pub config: QlFsmConfig, /// local identity and private keys pub identity: QlIdentity, - pub(crate) peer: Option, + pub(crate) peer: Option, pub(crate) session: SessionFsm, pub(crate) state: QlFsmState, } @@ -192,6 +183,7 @@ impl QlFsm { replay_cache: ReplayCache::default(), next_control_id: 1, handshake: None, + link: LinkState::Idle, events: Default::default(), session_events: Default::default(), now, @@ -200,7 +192,7 @@ impl QlFsm { } /// binds or replaces the remote peer - pub fn bind_peer(&mut self, peer: Peer) { + pub fn bind_peer(&mut self, peer: PeerBundle) { implementation::handle_bind_peer(self, peer); } diff --git a/ql-fsm/src/replay_cache.rs b/ql-fsm/src/replay_cache.rs index 335c75c8..547c0507 100644 --- a/ql-fsm/src/replay_cache.rs +++ b/ql-fsm/src/replay_cache.rs @@ -1,34 +1,18 @@ use std::collections::{hash_map::Entry, HashMap}; -use ql_wire::{HandshakeId, HandshakeMeta, XID}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -struct ReplayKey { - peer: XID, - handshake_id: HandshakeId, -} +use ql_wire::{HandshakeId, HandshakeMeta}; #[derive(Debug, Default)] pub struct ReplayCache { - valid_until_by_key: HashMap, + valid_until_by_id: HashMap, } impl ReplayCache { - pub fn check_and_store_valid_until( - &mut self, - peer: XID, - meta: HandshakeMeta, - now_secs: u64, - ) -> bool { - self.valid_until_by_key + pub fn check_and_store_valid_until(&mut self, meta: HandshakeMeta, now_secs: u64) -> bool { + self.valid_until_by_id .retain(|_, stored_valid_until| *stored_valid_until > now_secs); - let key = ReplayKey { - peer, - handshake_id: meta.handshake_id, - }; - - match self.valid_until_by_key.entry(key) { + match self.valid_until_by_id.entry(meta.handshake_id) { Entry::Occupied(_) => true, Entry::Vacant(entry) => { entry.insert(meta.valid_until); diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index e4e6b93f..71d2571d 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -401,7 +401,10 @@ impl SessionFsm { }) } - pub fn take_next_write(&mut self, now: Instant) -> Option<(u64, RecordSeq, SessionRecordBuilder)> { + pub fn take_next_write( + &mut self, + now: Instant, + ) -> Option<(u64, RecordSeq, SessionRecordBuilder)> { self.state.now = now; self.collect_timeouts(); @@ -894,11 +897,7 @@ impl SessionFsm { Ok(()) } - fn handle_session_close( - &mut self, - close: SessionClose, - emit: &mut impl FnMut(SessionEvent), - ) { + fn handle_session_close(&mut self, close: SessionClose, emit: &mut impl FnMut(SessionEvent)) { if self.state.session_state == SessionState::Closed { return; } diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 2fbc4a09..fe4af462 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -2,8 +2,7 @@ use std::{collections::BTreeSet, time::Instant}; use indexmap::IndexMap; use ql_wire::{ - CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionClose, StreamClose, StreamId, - XID, + CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionClose, StreamClose, StreamId, XID, }; use super::{ diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 7c1a11d4..4a2e21a5 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -66,9 +66,9 @@ impl StreamTx { self.segments .iter() .any(|segment| matches!(segment.state, SendState::Unsent | SendState::Lost)) - || self - .final_offset - .is_some_and(|final_offset| matches!(final_offset.state, SendState::Unsent | SendState::Lost)) + || self.final_offset.is_some_and(|final_offset| { + matches!(final_offset.state, SendState::Unsent | SendState::Lost) + }) } pub fn is_empty(&self) -> bool { @@ -211,7 +211,8 @@ impl StreamTx { self.segments[index].state = state; } else { let segment = self.segments.remove(index).unwrap(); - self.segments.insert(index, SendSegment { offset, len, state }); + self.segments + .insert(index, SendSegment { offset, len, state }); self.segments.insert( index + 1, SendSegment { diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 781f72b1..d0077006 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -29,7 +29,12 @@ fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option<(RecordSeq, Sessi Some((seq, SessionRecord::decode(builder.bytes()).unwrap())) } -fn receive_events(fsm: &mut SessionFsm, now: Instant, seq: RecordSeq, record: SessionRecord) -> Vec { +fn receive_events( + fsm: &mut SessionFsm, + now: Instant, + seq: RecordSeq, + record: SessionRecord, +) -> Vec { let bytes = record.encode(); let frames = SessionRecord::parse(&bytes).unwrap(); let mut events = Vec::new(); @@ -90,10 +95,9 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); let (second_seq, _second) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); assert_ne!(first_seq, second_seq); - assert!(first - .frames - .iter() - .any(|frame| matches!(frame, SessionFrame::StreamData(frame) if frame.stream_id == stream_id_a))); + assert!(first.frames.iter().any( + |frame| matches!(frame, SessionFrame::StreamData(frame) if frame.stream_id == stream_id_a) + )); assert_eq!(fsm.write_stream(stream_id_b, b"b-2").unwrap(), 3); let (_third_seq, third) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); @@ -234,7 +238,8 @@ fn remote_stream_close_is_reliable_and_retried() { )); fsm.on_timer(now + Duration::from_millis(200), |_| {}); - let (_retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); + let (_retried_seq, retried) = + next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); assert_eq!(first.frames, retried.frames); } @@ -281,7 +286,12 @@ fn duplicate_stream_data_is_not_redelivered() { })], }; let _ = receive_events(&mut fsm, now, RecordSeq(1), record.clone()); - let _ = receive_events(&mut fsm, now + Duration::from_millis(1), RecordSeq(2), record); + let _ = receive_events( + &mut fsm, + now + Duration::from_millis(1), + RecordSeq(2), + record, + ); assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); } diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 8a075744..d2163de0 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -1,25 +1,20 @@ use std::{collections::VecDeque, time::Instant}; use ql_wire::{ - ConnectionId, EphemeralPublicKey, HandshakeId, KkHandshake, PeerBundle, QlHandshakeRecord, - SessionKey, XxHandshake, + ConnectionId, EphemeralPublicKey, KkHandshake, PeerBundle, QlHandshakeRecord, SessionKey, + XxHandshake, }; -use crate::{replay_cache::ReplayCache, FsmTime, Peer, PeerStatus, QlFsmEvent, QlSessionEvent}; +use crate::{replay_cache::ReplayCache, FsmTime, PeerStatus, QlFsmEvent, QlSessionEvent}; -#[derive(Debug, Clone)] -pub enum HandshakeMode { - XxInitiator(XxHandshake), - XxResponder(XxHandshake), - KkInitiator(KkHandshake), -} - -#[derive(Debug, Clone)] -pub struct HandshakeState { - pub id: HandshakeId, - pub deadline: Instant, - pub mode: HandshakeMode, - pub initial_ephemeral: Option, +pub struct QlFsmState { + pub replay_cache: ReplayCache, + pub next_control_id: u32, + pub handshake: Option, + pub link: LinkState, + pub events: VecDeque, + pub session_events: VecDeque, + pub now: FsmTime, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -45,22 +40,31 @@ impl SessionTransport { } #[derive(Debug, Clone)] -pub enum ConnectionState { - Disconnected, - Handshaking(HandshakeState), +pub enum LinkState { + Idle, + XxInitiator { + handshake: XxHandshake, + deadline: Instant, + initial_ephemeral: EphemeralPublicKey, + }, + XxResponder { + handshake: XxHandshake, + deadline: Instant, + }, + KkInitiator { + handshake: KkHandshake, + deadline: Instant, + initial_ephemeral: EphemeralPublicKey, + }, Connected(SessionTransport), } -impl ConnectionState { +impl LinkState { pub fn status(&self) -> PeerStatus { match self { - Self::Disconnected => PeerStatus::Disconnected, - Self::Handshaking(HandshakeState { mode, .. }) => match mode { - HandshakeMode::XxInitiator(_) | HandshakeMode::KkInitiator(_) => { - PeerStatus::Initiator - } - HandshakeMode::XxResponder(_) => PeerStatus::Responder, - }, + Self::Idle => PeerStatus::Disconnected, + Self::XxInitiator { .. } | Self::KkInitiator { .. } => PeerStatus::Initiator, + Self::XxResponder { .. } => PeerStatus::Responder, Self::Connected(_) => PeerStatus::Connected, } } @@ -71,28 +75,13 @@ impl ConnectionState { _ => None, } } -} -#[derive(Debug, Clone)] -pub struct PeerRecord { - pub peer: Peer, - pub session: ConnectionState, -} - -impl PeerRecord { - pub fn new(peer: Peer) -> Self { - Self { - peer, - session: ConnectionState::Disconnected, + pub fn handshake_deadline(&self) -> Option { + match self { + Self::Idle | Self::Connected(_) => None, + Self::XxInitiator { deadline, .. } + | Self::XxResponder { deadline, .. } + | Self::KkInitiator { deadline, .. } => Some(*deadline), } } } - -pub struct QlFsmState { - pub replay_cache: ReplayCache, - pub next_control_id: u32, - pub handshake: Option, - pub events: VecDeque, - pub session_events: VecDeque, - pub now: FsmTime, -} diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 5c882d21..67d7ff01 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -1,9 +1,9 @@ use std::time::Duration; -use ql_wire::{HandshakePayload, QlRecord}; +use ql_wire::QlRecord; use super::*; -use crate::state::ConnectionState; +use crate::state::LinkState; #[test] fn kk_connect_round_trip_establishes_transport() { @@ -16,14 +16,8 @@ fn kk_connect_round_trip_establishes_transport() { .unwrap(); harness.pump(); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - )); - assert!(matches!( - harness.b.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } #[test] @@ -37,22 +31,10 @@ fn xx_connect_round_trip_learns_peer_bundles() { .unwrap(); harness.pump(); - assert_eq!( - harness.a.fsm.peer.as_ref().unwrap().peer.bundle, - Some(harness.b.fsm.identity.bundle()) - ); - assert_eq!( - harness.b.fsm.peer.as_ref().unwrap().peer.bundle, - Some(harness.a.fsm.identity.bundle()) - ); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - )); - assert!(matches!( - harness.b.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - )); + assert_eq!(harness.a.fsm.peer, Some(harness.b.fsm.identity.bundle())); + assert_eq!(harness.b.fsm.peer, Some(harness.a.fsm.identity.bundle())); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } #[test] @@ -66,14 +48,7 @@ fn inbound_xx1_auto_binds_unbound_responder() { .unwrap(); harness.pump(); - assert_eq!( - harness.b.fsm.peer.as_ref().map(|entry| entry.peer.xid), - Some(harness.a.fsm.identity.xid) - ); - assert_eq!( - harness.b.fsm.peer.as_ref().unwrap().peer.bundle, - Some(harness.a.fsm.identity.bundle()) - ); + assert_eq!(harness.b.fsm.peer, Some(harness.a.fsm.identity.bundle())); } #[test] @@ -92,20 +67,14 @@ fn handshake_timeout_drops_single_attempt_without_resend() { let first = harness.next_outbound_a().unwrap(); assert!(matches!( first, - QlRecord::Handshake(ql_wire::QlHandshakeRecord { - payload: HandshakePayload::Xx1(_), - .. - }) + QlRecord::Handshake(ql_wire::QlHandshakeRecord::Xx1(_)) )); assert!(harness.next_outbound_a().is_none()); harness.advance(config.handshake_timeout); harness.a.fsm.on_timer(harness.time()); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Disconnected) - )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert!(harness.next_outbound_a().is_none()); } @@ -126,10 +95,7 @@ fn handshake_timeout_clears_queued_handshake_output() { harness.advance(config.handshake_timeout); harness.a.fsm.on_timer(harness.time()); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Disconnected) - )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert!(harness.next_outbound_a().is_none()); } @@ -142,10 +108,7 @@ fn bind_peer_clears_queued_handshake_output() { .fsm .connect(harness.time(), &harness.a.crypto) .unwrap(); - harness.a.fsm.bind_peer(Peer { - xid: test_identity(99).xid, - bundle: None, - }); + harness.a.fsm.bind_peer(test_identity(99).bundle()); assert!(harness.next_outbound_a().is_none()); } @@ -166,14 +129,8 @@ fn simultaneous_xx_connect_converges() { .unwrap(); harness.pump(); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - )); - assert!(matches!( - harness.b.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } #[test] @@ -192,16 +149,7 @@ fn simultaneous_xx_and_kk_connect_prefers_xx() { .unwrap(); harness.pump(); - assert_eq!( - harness.a.fsm.peer.as_ref().unwrap().peer.bundle, - Some(harness.b.fsm.identity.bundle()) - ); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - )); - assert!(matches!( - harness.b.fsm.peer.as_ref().map(|entry| &entry.session), - Some(ConnectionState::Connected(_)) - )); + assert_eq!(harness.a.fsm.peer, Some(harness.b.fsm.identity.bundle())); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index bee7c9ea..5a54a455 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -9,16 +9,16 @@ use std::{ use libcrux_aesgcm::AesGcm256Key; use libcrux_ml_kem::mlkem1024; use ql_wire::{ - self, generate_identity, ConnectionId, ENCRYPTED_MESSAGE_AUTH_SIZE, MlKemCiphertext, - MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, - QlKem, QlRandom, QlRecord, SessionKey, XID, + self, generate_identity, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, + MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, QlRecord, + SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; use sha2::{Digest, Sha256}; use crate::{ session::{state::StreamParity, SessionFsm, SessionFsmConfig}, - state::{ConnectionState, SessionTransport}, - FsmTime, OutboundWrite, Peer, QlFsm, QlFsmConfig, SessionWriteId, + state::{LinkState, SessionTransport}, + FsmTime, OutboundWrite, QlFsm, QlFsmConfig, SessionWriteId, }; #[derive(Clone)] @@ -182,8 +182,12 @@ impl Harness { }, }; - harness.a.fsm.bind_peer(peer_from_identity(&identity_b, know_a)); - harness.b.fsm.bind_peer(peer_from_identity(&identity_a, know_b)); + if know_a { + harness.a.fsm.bind_peer(identity_b.bundle()); + } + if know_b { + harness.b.fsm.bind_peer(identity_a.bundle()); + } while harness.a.fsm.take_next_event().is_some() {} while harness.b.fsm.take_next_event().is_some() {} @@ -199,7 +203,7 @@ impl Harness { unix_secs: 1_700_000_000, }; - let mut harness = Self { + Self { now, unix_secs: time.unix_secs, a: Node { @@ -207,18 +211,10 @@ impl Harness { crypto: TestCrypto::new(1), }, b: Node { - fsm: QlFsm::new(config, identity_b.clone(), time), + fsm: QlFsm::new(config, identity_b, time), crypto: TestCrypto::new(2), }, - }; - - harness - .a - .fsm - .bind_peer(Peer { xid: identity_b.xid, bundle: None }); - while harness.a.fsm.take_next_event().is_some() {} - - harness + } } fn connected(config: QlFsmConfig) -> Self { @@ -228,13 +224,13 @@ impl Harness { let a_to_b_conn = ConnectionId::from_data([0xA1; ConnectionId::SIZE]); let b_to_a_conn = ConnectionId::from_data([0xB2; ConnectionId::SIZE]); - harness.a.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected(SessionTransport { + harness.a.fsm.state.link = LinkState::Connected(SessionTransport { tx_key: a_to_b_key.clone(), rx_key: b_to_a_key.clone(), tx_connection_id: a_to_b_conn, rx_connection_id: b_to_a_conn, }); - harness.b.fsm.peer.as_mut().unwrap().session = ConnectionState::Connected(SessionTransport { + harness.b.fsm.state.link = LinkState::Connected(SessionTransport { tx_key: b_to_a_key, rx_key: a_to_b_key, tx_connection_id: b_to_a_conn, @@ -327,24 +323,17 @@ fn test_identity(seed: u8) -> QlIdentity { generate_identity(&crypto, XID([seed; XID::SIZE])) } -fn peer_from_identity(identity: &QlIdentity, know_bundle: bool) -> Peer { - Peer { - xid: identity.xid, - bundle: know_bundle.then(|| identity.bundle()), - } -} - fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { let (local, peer, config) = if a { ( harness.a.fsm.identity.xid, - harness.a.fsm.peer.as_ref().unwrap().peer.xid, + harness.a.fsm.peer.as_ref().unwrap().xid, harness.a.fsm.config, ) } else { ( harness.b.fsm.identity.xid, - harness.b.fsm.peer.as_ref().unwrap().peer.xid, + harness.b.fsm.peer.as_ref().unwrap().xid, harness.b.fsm.config, ) }; @@ -372,7 +361,10 @@ fn decrypt_record( let plaintext = ql_wire::decrypt_record(crypto, &record.header, record.payload.clone(), session_key) .unwrap(); - (record.header, ql_wire::SessionRecord::decode(&plaintext).unwrap()) + ( + record.header, + ql_wire::SessionRecord::decode(&plaintext).unwrap(), + ) } fn sha256_parts(parts: &[&[u8]]) -> [u8; 32] { diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index c920beb7..944964c3 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ql_wire::{SessionClose, StreamId}; use super::*; -use crate::{session::state::StreamParity, QlFsmEvent, QlSessionEvent}; +use crate::{session::state::StreamParity, state::LinkState, QlFsmEvent, QlSessionEvent}; fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); @@ -58,17 +58,9 @@ fn session_retransmit_uses_new_record_seq() { assert_eq!(harness.a.fsm.write_stream(stream_id, b"retry").unwrap(), 5); let first = harness.next_outbound_a().unwrap(); - let first_transport = harness - .b - .fsm - .peer - .as_ref() - .unwrap() - .session - .transport() - .unwrap() - .clone(); - let (first_header, first_record) = decrypt_record(&harness.b.crypto, &first, &first_transport.rx_key); + let first_transport = harness.b.fsm.state.link.transport().unwrap().clone(); + let (first_header, first_record) = + decrypt_record(&harness.b.crypto, &first, &first_transport.rx_key); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); harness.a.fsm.on_timer(harness.time()); @@ -94,7 +86,10 @@ fn session_retransmit_uses_new_record_seq() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id)) ); - assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"retry".to_vec()); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"retry".to_vec() + ); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); harness.a.fsm.on_timer(harness.time()); @@ -118,8 +113,14 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { .matches(stream_id_b) ); - assert_eq!(harness.a.fsm.write_stream(stream_id_a, b"from-a").unwrap(), 6); - assert_eq!(harness.b.fsm.write_stream(stream_id_b, b"from-b").unwrap(), 6); + assert_eq!( + harness.a.fsm.write_stream(stream_id_a, b"from-a").unwrap(), + 6 + ); + assert_eq!( + harness.b.fsm.write_stream(stream_id_b, b"from-b").unwrap(), + 6 + ); harness.pump(); @@ -131,7 +132,10 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { harness.a.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id_b)) ); - assert_eq!(read_stream_all(&mut harness.a.fsm, stream_id_b), b"from-b".to_vec()); + assert_eq!( + read_stream_all(&mut harness.a.fsm, stream_id_b), + b"from-b".to_vec() + ); assert_eq!( harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Opened(stream_id_a)) @@ -140,7 +144,10 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id_a)) ); - assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id_a), b"from-a".to_vec()); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id_a), + b"from-a".to_vec() + ); } #[test] @@ -177,7 +184,7 @@ fn queued_stream_work_is_failed_when_handshake_times_out() { handshake_timeout: Duration::from_millis(50), ..QlFsmConfig::default() }; - let mut harness = Harness::paired_unknown(config); + let mut harness = Harness::paired_known(config); let stream_id = harness.a.fsm.open_stream().unwrap(); assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); @@ -205,17 +212,7 @@ fn returned_session_write_is_reissued_with_new_record_seq() { let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); let record = write.record; - let session_key = harness - .b - .fsm - .peer - .as_ref() - .unwrap() - .session - .transport() - .unwrap() - .rx_key - .clone(); + let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); let (first_header, first) = decrypt_record(&harness.b.crypto, &record, &session_key); harness.return_write_a(id); @@ -241,7 +238,10 @@ fn returned_session_write_is_reissued_with_new_record_seq() { harness.b.fsm.take_next_session_event(), Some(QlSessionEvent::Readable(stream_id)) ); - assert_eq!(read_stream_all(&mut harness.b.fsm, stream_id), b"retry".to_vec()); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"retry".to_vec() + ); } #[test] @@ -255,17 +255,7 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); let record = write.record; - let session_key = harness - .b - .fsm - .peer - .as_ref() - .unwrap() - .session - .transport() - .unwrap() - .rx_key - .clone(); + let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); let (first_header, first) = decrypt_record(&harness.b.crypto, &record, &session_key); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); @@ -318,10 +308,7 @@ fn kill_session_disconnects_locally() { .fsm .kill_session(ql_wire::SessionCloseCode::CANCELLED); - assert!(matches!( - harness.a.fsm.peer.as_ref().map(|entry| &entry.session), - Some(crate::state::ConnectionState::Disconnected) - )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert_eq!( harness.a.fsm.take_next_session_event(), Some(QlSessionEvent::SessionClosed(SessionClose { @@ -348,17 +335,10 @@ fn session_records_contain_ack_frames_after_delivery() { harness.b.fsm.on_timer(harness.time()); let ack = harness.next_outbound_b().unwrap(); - let session_key = harness - .a - .fsm - .peer - .as_ref() - .unwrap() - .session - .transport() - .unwrap() - .rx_key - .clone(); + let session_key = harness.a.fsm.state.link.transport().unwrap().rx_key.clone(); let (_ack_header, ack_record) = decrypt_record(&harness.a.crypto, &ack, &session_key); - assert!(matches!(ack_record.frames.as_slice(), [ql_wire::SessionFrame::Ack(_)])); + assert!(matches!( + ack_record.frames.as_slice(), + [ql_wire::SessionFrame::Ack(_)] + )); } From 8b52d6f382cb6cfed2850afc09de66fd969e829e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 15:03:21 -0400 Subject: [PATCH 066/304] ql-fsm: move peer into state --- ql-fsm/src/implementation/fsm.rs | 29 +++++++++++----------- ql-fsm/src/implementation/handshake/kk.rs | 4 +-- ql-fsm/src/implementation/handshake/mod.rs | 6 ++--- ql-fsm/src/implementation/mod.rs | 12 +++------ ql-fsm/src/lib.rs | 3 +-- ql-fsm/src/state.rs | 10 ++++++++ ql-fsm/src/tests/handshake.rs | 20 ++++++++++++--- ql-fsm/src/tests/mod.rs | 4 +-- 8 files changed, 51 insertions(+), 37 deletions(-) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 56877f1f..2f5d59af 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -15,7 +15,7 @@ pub fn receive( match wire::QlRecord::parse(&mut bytes[..])? { wire::QlRecord::Handshake(record) => super::handle_handshake_record(fsm, crypto, &record), wire::QlRecord::Session(record) => { - let (_, transport) = super::peer_transport(fsm).ok_or(QlFsmError::NoSession)?; + let transport = fsm.state.link.transport().ok_or(QlFsmError::NoSession)?; if record.header.connection_id != transport.rx_connection_id { return Err(QlFsmError::InvalidPayload); } @@ -41,7 +41,7 @@ pub fn receive( pub fn on_timer(fsm: &mut QlFsm) { super::handle_timer(fsm); - if super::peer_transport(fsm).is_some() { + if fsm.state.link.transport().is_some() { let mut session_closed = false; fsm.session.on_timer(fsm.state.now.instant, { let session_events = &mut fsm.state.session_events; @@ -58,7 +58,10 @@ pub fn on_timer(fsm: &mut QlFsm) { pub fn next_deadline(fsm: &QlFsm) -> Option { [ super::next_handshake_deadline(fsm), - super::peer_transport(fsm).and_then(|_| fsm.session.next_deadline()), + fsm.state + .link + .transport() + .and_then(|_| fsm.session.next_deadline()), ] .into_iter() .flatten() @@ -74,7 +77,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Option Result { - ensure_peer_bound(fsm)?; + fsm.state.ensure_peer_bound()?; Ok(fsm.session.open_stream()?) } @@ -136,7 +139,7 @@ pub fn write_stream( stream_id: StreamId, bytes: &[u8], ) -> Result { - ensure_peer_bound(fsm)?; + fsm.state.ensure_peer_bound()?; Ok(fsm.session.write_stream(stream_id, bytes)?) } @@ -157,7 +160,7 @@ pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Result Result<(), QlFsmError> { - ensure_peer_bound(fsm)?; + fsm.state.ensure_peer_bound()?; Ok(fsm.session.finish_stream(stream_id)?) } @@ -167,7 +170,7 @@ pub fn close_stream( target: CloseTarget, code: StreamCloseCode, ) -> Result<(), QlFsmError> { - ensure_peer_bound(fsm)?; + fsm.state.ensure_peer_bound()?; Ok(fsm.session.close_stream(stream_id, target, code)?) } @@ -176,12 +179,8 @@ pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { Ok(fsm.session.queue_ping()?) } -fn ensure_peer_bound(fsm: &QlFsm) -> Result<(), QlFsmError> { - fsm.peer.as_ref().map(|_| ()).ok_or(QlFsmError::NoPeerBound) -} - fn ensure_session_open(fsm: &QlFsm) -> Result<(), QlFsmError> { - ensure_peer_bound(fsm)?; + fsm.state.ensure_peer_bound()?; if fsm.state.link.transport().is_none() { return Err(QlFsmError::SessionClosed); } diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index d22a9b17..3215f48c 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -46,7 +46,7 @@ pub fn handle_kk1( return Ok(()); } - let Some(peer) = fsm.peer.clone() else { + let Some(peer) = fsm.state.peer.clone() else { return Err(QlFsmError::InvalidPayload); }; if message.header.recipient != fsm.identity.xid || message.header.sender != peer.xid { @@ -111,7 +111,7 @@ pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { LinkState::KkInitiator { initial_ephemeral, .. } => { - if fsm.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { + if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { return false; } super::local_start_wins(initial_ephemeral, &message.ephemeral) diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index fd694ce0..ea0213b4 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -14,7 +14,7 @@ pub fn handle_connect(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlF return Ok(()); } - match fsm.peer.clone() { + match fsm.state.peer.clone() { Some(peer) => kk::start_initiator(fsm, crypto, peer), None => xx::start_initiator(fsm, crypto), } @@ -81,12 +81,12 @@ pub fn finish_handshake( transport: SessionTransport, remote_bundle: wire::PeerBundle, ) -> Result<(), QlFsmError> { - if let Some(peer) = fsm.peer.as_ref() { + if let Some(peer) = fsm.state.peer.as_ref() { if peer != &remote_bundle { return Err(QlFsmError::InvalidPayload); } } else { - fsm.peer = Some(remote_bundle.clone()); + fsm.state.peer = Some(remote_bundle.clone()); reset_session(fsm); fsm.state .events diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index 16912086..79d2f3f2 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -5,7 +5,6 @@ use std::{collections::VecDeque, time::Duration}; pub use fsm::*; pub use handshake::*; -use ql_wire::XID; use crate::{ session::{state::StreamParity, SessionEvent, SessionFsmConfig}, @@ -14,7 +13,7 @@ use crate::{ }; fn emit_peer_status(fsm: &mut QlFsm) { - if let Some(peer) = fsm.peer.as_ref() { + if let Some(peer) = fsm.state.peer.as_ref() { fsm.state.events.push_back(QlFsmEvent::PeerStatusChanged { peer: peer.xid, status: fsm.state.link.status(), @@ -22,14 +21,9 @@ fn emit_peer_status(fsm: &mut QlFsm) { } } -fn peer_transport(fsm: &QlFsm) -> Option<(XID, crate::state::SessionTransport)> { - let peer = fsm.peer.as_ref()?; - let transport = fsm.state.link.transport()?.clone(); - Some((peer.xid, transport)) -} - fn reset_session(fsm: &mut QlFsm) { let local_parity = fsm + .state .peer .as_ref() .map(|peer| StreamParity::for_local(fsm.identity.xid, peer.xid)) @@ -52,7 +46,7 @@ fn reset_session(fsm: &mut QlFsm) { pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { fsm.state.handshake = None; fsm.state.link = LinkState::Idle; - fsm.peer = Some(peer.clone()); + fsm.state.peer = Some(peer.clone()); reset_session(fsm); fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); emit_peer_status(fsm); diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 5f8f547e..dd914162 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -154,7 +154,6 @@ pub struct QlFsm { pub config: QlFsmConfig, /// local identity and private keys pub identity: QlIdentity, - pub(crate) peer: Option, pub(crate) session: SessionFsm, pub(crate) state: QlFsmState, } @@ -165,7 +164,6 @@ impl QlFsm { Self { config, identity, - peer: None, session: session::SessionFsm::new( session::SessionFsmConfig { local_parity: session::state::StreamParity::Even, @@ -182,6 +180,7 @@ impl QlFsm { state: QlFsmState { replay_cache: ReplayCache::default(), next_control_id: 1, + peer: None, handshake: None, link: LinkState::Idle, events: Default::default(), diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index d2163de0..cb78022f 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -10,6 +10,7 @@ use crate::{replay_cache::ReplayCache, FsmTime, PeerStatus, QlFsmEvent, QlSessio pub struct QlFsmState { pub replay_cache: ReplayCache, pub next_control_id: u32, + pub peer: Option, pub handshake: Option, pub link: LinkState, pub events: VecDeque, @@ -85,3 +86,12 @@ impl LinkState { } } } + +impl QlFsmState { + pub fn ensure_peer_bound(&self) -> Result<(), crate::QlFsmError> { + self.peer + .as_ref() + .map(|_| ()) + .ok_or(crate::QlFsmError::NoPeerBound) + } +} diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 67d7ff01..e4633ad7 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -31,8 +31,14 @@ fn xx_connect_round_trip_learns_peer_bundles() { .unwrap(); harness.pump(); - assert_eq!(harness.a.fsm.peer, Some(harness.b.fsm.identity.bundle())); - assert_eq!(harness.b.fsm.peer, Some(harness.a.fsm.identity.bundle())); + assert_eq!( + harness.a.fsm.state.peer, + Some(harness.b.fsm.identity.bundle()) + ); + assert_eq!( + harness.b.fsm.state.peer, + Some(harness.a.fsm.identity.bundle()) + ); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } @@ -48,7 +54,10 @@ fn inbound_xx1_auto_binds_unbound_responder() { .unwrap(); harness.pump(); - assert_eq!(harness.b.fsm.peer, Some(harness.a.fsm.identity.bundle())); + assert_eq!( + harness.b.fsm.state.peer, + Some(harness.a.fsm.identity.bundle()) + ); } #[test] @@ -149,7 +158,10 @@ fn simultaneous_xx_and_kk_connect_prefers_xx() { .unwrap(); harness.pump(); - assert_eq!(harness.a.fsm.peer, Some(harness.b.fsm.identity.bundle())); + assert_eq!( + harness.a.fsm.state.peer, + Some(harness.b.fsm.identity.bundle()) + ); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 5a54a455..92d0c1cd 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -327,13 +327,13 @@ fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { let (local, peer, config) = if a { ( harness.a.fsm.identity.xid, - harness.a.fsm.peer.as_ref().unwrap().xid, + harness.a.fsm.state.peer.as_ref().unwrap().xid, harness.a.fsm.config, ) } else { ( harness.b.fsm.identity.xid, - harness.b.fsm.peer.as_ref().unwrap().xid, + harness.b.fsm.state.peer.as_ref().unwrap().xid, harness.b.fsm.config, ) }; From 0e8dcc94a2686be76c0fff0ef69e397d3c9e337d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 30 Mar 2026 15:51:42 -0400 Subject: [PATCH 067/304] ql-fsm: remove intermediate allocations --- ql-fsm/src/session/mod.rs | 286 ++++++++++++++++++++------------------ 1 file changed, 154 insertions(+), 132 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 71d2571d..dbb51db5 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -7,7 +7,7 @@ mod tests; use std::time::{Duration, Instant}; -use indexmap::map::Entry; +use indexmap::{map::Entry, IndexMap}; use ql_wire::{ CloseTarget, RecordAck, RecordSeq, SessionClose, SessionCloseCode, SessionFrame, SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamId, StreamWindow, @@ -331,7 +331,14 @@ impl SessionFsm { let Some(record) = self.state.outbound_records.shift_remove(&write_id) else { return; }; - self.restore_outbound_record(record); + restore_outbound_record( + self.state.now, + self.config.ack_delay, + &mut self.state.ack_state, + &mut self.state.pending_control, + &mut self.state.streams, + record, + ); } pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { @@ -616,43 +623,32 @@ impl SessionFsm { } fn process_record_ack(&mut self, ack: RecordAck, emit: &mut impl FnMut(SessionEvent)) { - let acked: Vec = self - .state - .outbound_records - .iter() - .filter_map(|(write_id, record)| { - record - .sent_at - .filter(|_| Self::ack_covers(&ack, record.seq)) - .map(|_| *write_id) - }) - .collect(); - - for write_id in acked { - let Some(record) = self.state.outbound_records.shift_remove(&write_id) else { - continue; - }; - for frame in record.reliable { - self.acknowledge_reliable_frame(frame, emit); + let stream_send_buffer_size = self.config.stream_send_buffer_size; + { + let outbound_records = &mut self.state.outbound_records; + let streams = &mut self.state.streams; + for (_, record) in outbound_records.extract_if(.., |_, record| { + record.sent_at.is_some() + && ack + .ranges + .iter() + .any(|range| range.start <= record.seq.0 && record.seq.0 < range.end) + }) { + for frame in &record.reliable { + acknowledge_reliable_frame(streams, stream_send_buffer_size, frame, emit); + } } } - } - - fn ack_covers(ack: &RecordAck, seq: RecordSeq) -> bool { - ack.ranges - .iter() - .any(|range| range.start <= seq.0 && seq.0 < range.end) + self.reap_reapable_streams(); } fn schedule_ack(&mut self, immediate: bool) { - self.state.ack_state = match self.state.ack_state { - AckState::Immediate => AckState::Immediate, - _ if immediate || self.config.ack_delay.is_zero() => AckState::Immediate, - AckState::Delayed { due_at } => AckState::Delayed { due_at }, - AckState::Idle => AckState::Delayed { - due_at: self.state.now + self.config.ack_delay, - }, - }; + schedule_ack( + &mut self.state.ack_state, + self.state.now, + self.config.ack_delay, + immediate, + ); } fn should_send_ack(&self) -> bool { @@ -667,80 +663,20 @@ impl SessionFsm { } fn collect_timeouts(&mut self) { - let expired: Vec = self - .state - .outbound_records - .iter() - .filter_map(|(write_id, record)| { - record - .sent_at - .filter(|sent_at| *sent_at + self.config.retransmit_timeout <= self.state.now) - .map(|_| *write_id) - }) - .collect(); - - for write_id in expired { - let Some(record) = self.state.outbound_records.shift_remove(&write_id) else { - continue; - }; - self.restore_outbound_record(record); - } - } - - fn restore_outbound_record(&mut self, record: OutboundRecord) { - if record.ack_included { - self.schedule_ack(true); - } - if record.ping_included { - self.state.pending_control.ping = true; - } - for (stream_id, maximum_offset) in record.window_updates { - if let Some(stream) = self.state.streams.get_mut(&stream_id) { - if stream.recv_limit() >= maximum_offset { - stream.pending_window = true; - } - } - } - for frame in record.reliable { - self.requeue_reliable_frame(frame); - } - } - - fn requeue_reliable_frame(&mut self, frame: ReliableFrame) { - match frame { - ReliableFrame::Close(close) => { - self.state.pending_control.close = Some(close); - } - ReliableFrame::StreamClose(close) => self.restore_stream_close(close), - ReliableFrame::StreamData(frame) => self.restore_stream_data(frame), - } - } - - fn acknowledge_reliable_frame( - &mut self, - frame: ReliableFrame, - emit: &mut impl FnMut(SessionEvent), - ) { - match frame { - ReliableFrame::Close(_) => {} - ReliableFrame::StreamClose(frame) => { - self.try_reap_stream(frame.stream_id); - } - ReliableFrame::StreamData(frame) => { - let stream_id = frame.stream_id; - if let Some(stream) = self.state.streams.get_mut(&stream_id) { - let was_full = stream.send_capacity(self.config.stream_send_buffer_size) == 0; - stream.tx.mark_acked(StreamTxRange { - offset: frame.offset, - len: frame.len, - fin: frame.fin, - }); - if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { - emit(SessionEvent::Writable(stream_id)); - } - } - self.try_reap_stream(stream_id); - } + let retransmit_timeout = self.config.retransmit_timeout; + for (_, record) in self.state.outbound_records.extract_if(.., |_, record| { + record + .sent_at + .is_some_and(|sent_at| sent_at + retransmit_timeout <= self.state.now) + }) { + restore_outbound_record( + self.state.now, + self.config.ack_delay, + &mut self.state.ack_state, + &mut self.state.pending_control, + &mut self.state.streams, + record, + ); } } @@ -928,30 +864,6 @@ impl SessionFsm { matches!(target, CloseTarget::Both) || role.outbound_target() == target } - fn restore_stream_close(&mut self, close: StreamClose) { - if let Some(stream) = self.state.streams.get_mut(&close.stream_id) { - stream.pending_close = Some(close); - } - } - - fn restore_stream_data(&mut self, frame: StreamDataManifest) { - if let Some(stream) = self.state.streams.get_mut(&frame.stream_id) { - if matches!(stream.outbound_state, OutboundState::Closed) { - return; - } - stream.tx.mark_lost(StreamTxRange { - offset: frame.offset, - len: frame.len, - fin: frame.fin, - }); - if frame.fin { - if matches!(stream.outbound_state, OutboundState::Finished) { - stream.outbound_state = OutboundState::FinQueued; - } - } - } - } - fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { let outbound_refs_stream = self.state.outbound_records.values().any(|record| { record.window_updates.iter().any(|(id, _)| *id == stream_id) @@ -982,6 +894,18 @@ impl SessionFsm { ) } + fn reap_reapable_streams(&mut self) { + let mut index = 0usize; + while index < self.state.streams.len() { + let stream_id = *self.state.streams.get_index(index).unwrap().0; + let len_before = self.state.streams.len(); + self.try_reap_stream(stream_id); + if self.state.streams.len() == len_before { + index += 1; + } + } + } + fn try_reap_stream(&mut self, stream_id: StreamId) { let should_reap = self .state @@ -1027,3 +951,101 @@ impl SessionFsm { self.state.streams.clear(); } } + +fn schedule_ack(ack_state: &mut AckState, now: Instant, ack_delay: Duration, immediate: bool) { + *ack_state = match *ack_state { + AckState::Immediate => AckState::Immediate, + _ if immediate || ack_delay.is_zero() => AckState::Immediate, + AckState::Delayed { due_at } => AckState::Delayed { due_at }, + AckState::Idle => AckState::Delayed { + due_at: now + ack_delay, + }, + }; +} + +fn restore_outbound_record( + now: Instant, + ack_delay: Duration, + ack_state: &mut AckState, + pending_control: &mut state::PendingSessionControl, + streams: &mut IndexMap, + record: OutboundRecord, +) { + if record.ack_included { + schedule_ack(ack_state, now, ack_delay, true); + } + if record.ping_included { + pending_control.ping = true; + } + for (stream_id, maximum_offset) in record.window_updates { + if let Some(stream) = streams.get_mut(&stream_id) { + if stream.recv_limit() >= maximum_offset { + stream.pending_window = true; + } + } + } + for frame in record.reliable { + requeue_reliable_frame(pending_control, streams, frame); + } +} + +fn requeue_reliable_frame( + pending_control: &mut state::PendingSessionControl, + streams: &mut IndexMap, + frame: ReliableFrame, +) { + match frame { + ReliableFrame::Close(close) => { + pending_control.close = Some(close); + } + ReliableFrame::StreamClose(close) => restore_stream_close(streams, close), + ReliableFrame::StreamData(frame) => restore_stream_data(streams, frame), + } +} + +fn restore_stream_close(streams: &mut IndexMap, close: StreamClose) { + if let Some(stream) = streams.get_mut(&close.stream_id) { + stream.pending_close = Some(close); + } +} + +fn restore_stream_data(streams: &mut IndexMap, frame: StreamDataManifest) { + if let Some(stream) = streams.get_mut(&frame.stream_id) { + if matches!(stream.outbound_state, OutboundState::Closed) { + return; + } + stream.tx.mark_lost(StreamTxRange { + offset: frame.offset, + len: frame.len, + fin: frame.fin, + }); + if frame.fin && matches!(stream.outbound_state, OutboundState::Finished) { + stream.outbound_state = OutboundState::FinQueued; + } + } +} + +fn acknowledge_reliable_frame( + streams: &mut IndexMap, + stream_send_buffer_size: usize, + frame: &ReliableFrame, + emit: &mut impl FnMut(SessionEvent), +) { + match frame { + ReliableFrame::Close(_) | ReliableFrame::StreamClose(_) => {} + ReliableFrame::StreamData(frame) => { + let stream_id = frame.stream_id; + if let Some(stream) = streams.get_mut(&stream_id) { + let was_full = stream.send_capacity(stream_send_buffer_size) == 0; + stream.tx.mark_acked(StreamTxRange { + offset: frame.offset, + len: frame.len, + fin: frame.fin, + }); + if was_full && stream.send_capacity(stream_send_buffer_size) > 0 { + emit(SessionEvent::Writable(stream_id)); + } + } + } + } +} From d3ce08c7e6b15f51b3c595acf316ad999ac7d56e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 2 Apr 2026 10:03:34 -0400 Subject: [PATCH 068/304] ql-wire: remove xx handshake for ik handshake --- ql-wire/src/handshake/{xx.rs => ik.rs} | 343 ++++++++++++------------- ql-wire/src/handshake/kk.rs | 67 ++--- ql-wire/src/handshake/mod.rs | 65 +++-- ql-wire/src/lib.rs | 2 +- ql-wire/src/record.rs | 46 ++-- ql-wire/src/tests.rs | 204 ++++++++++----- 6 files changed, 376 insertions(+), 351 deletions(-) rename ql-wire/src/handshake/{xx.rs => ik.rs} (56%) diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/ik.rs similarity index 56% rename from ql-wire/src/handshake/xx.rs rename to ql-wire/src/handshake/ik.rs index 5b21197a..d8a4be0d 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/ik.rs @@ -1,8 +1,9 @@ use super::{ decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, - finalize_handshake, generate_ephemeral_keypair, initialize_handshake_meta, mix_hash_ephemeral, - mix_hash_xx_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EncryptedPeerBundle, - EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, PROTOCOL_XX, + finalize_handshake, generate_ephemeral_keypair, init_ik_symmetric, initialize_handshake_meta, + mix_hash_ephemeral, mix_hash_routed_handshake, require_handshake_meta, + EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, + FinalizedHandshake, HandshakeHeader, Role, SymmetricState, }; use crate::{ codec, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, @@ -10,205 +11,228 @@ use crate::{ }; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Xx1 { +pub struct Ik1 { + pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub skem_ciphertext: MlKemCiphertext, pub ephemeral: EphemeralPublicKey, -} - -impl Xx1 { - pub const ENCODED_LEN: usize = HandshakeMeta::ENCODED_LEN + EphemeralPublicKey::ENCODED_LEN; - - pub fn encode_into(&self, out: &mut Vec) { - self.meta.encode_into(out); - self.ephemeral.encode_into(out); - } - - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let meta = HandshakeMeta::decode_from(&mut reader)?; - let ephemeral = - EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; - reader.finish()?; - Ok(Self { meta, ephemeral }) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Xx2 { - pub meta: HandshakeMeta, - pub ekem_ciphertext: MlKemCiphertext, - pub static_bundle: EncryptedPeerBundle, -} - -impl Xx2 { - pub const ENCODED_LEN: usize = - HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EncryptedPeerBundle::ENCODED_LEN; - - pub fn encode_into(&self, out: &mut Vec) { - self.meta.encode_into(out); - codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); - codec::push_bytes(out, self.static_bundle.as_bytes()); - } - - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let meta = HandshakeMeta::decode_from(&mut reader)?; - let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); - let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); - reader.finish()?; - Ok(Self { - meta, - ekem_ciphertext, - static_bundle, - }) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Xx3 { - pub meta: HandshakeMeta, - pub skem_ciphertext: EncryptedMlKemCiphertext, pub static_bundle: EncryptedPeerBundle, } -impl Xx3 { - pub const ENCODED_LEN: usize = HandshakeMeta::ENCODED_LEN - + EncryptedMlKemCiphertext::ENCODED_LEN +impl Ik1 { + pub const ENCODED_LEN: usize = HandshakeHeader::ENCODED_LEN + + HandshakeMeta::ENCODED_LEN + + MlKemCiphertext::SIZE + + EphemeralPublicKey::ENCODED_LEN + EncryptedPeerBundle::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { + self.header.encode_into(out); self.meta.encode_into(out); codec::push_bytes(out, self.skem_ciphertext.as_bytes()); + self.ephemeral.encode_into(out); codec::push_bytes(out, self.static_bundle.as_bytes()); } pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); + let header = HandshakeHeader::decode_from(&mut reader)?; let meta = HandshakeMeta::decode_from(&mut reader)?; - let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); + let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); + let ephemeral = + EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); reader.finish()?; Ok(Self { + header, meta, skem_ciphertext, + ephemeral, static_bundle, }) } } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Xx4 { +pub struct Ik2 { + pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub ekem_ciphertext: MlKemCiphertext, pub skem_ciphertext: EncryptedMlKemCiphertext, } -impl Xx4 { - pub const ENCODED_LEN: usize = - HandshakeMeta::ENCODED_LEN + EncryptedMlKemCiphertext::ENCODED_LEN; +impl Ik2 { + pub const ENCODED_LEN: usize = HandshakeHeader::ENCODED_LEN + + HandshakeMeta::ENCODED_LEN + + MlKemCiphertext::SIZE + + EncryptedMlKemCiphertext::ENCODED_LEN; pub fn encode_into(&self, out: &mut Vec) { + self.header.encode_into(out); self.meta.encode_into(out); + codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); codec::push_bytes(out, self.skem_ciphertext.as_bytes()); } pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); + let header = HandshakeHeader::decode_from(&mut reader)?; let meta = HandshakeMeta::decode_from(&mut reader)?; + let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); reader.finish()?; Ok(Self { + header, meta, + ekem_ciphertext, skem_ciphertext, }) } } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum XxMessage { - Message1(Xx1), - Message2(Xx2), - Message3(Xx3), - Message4(Xx4), +pub enum IkMessage { + Message1(Ik1), + Message2(Ik2), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum XxStep { +enum IkStep { Send1, Recv1, Send2, Recv2, - Send3, - Recv3, - Send4, - Recv4, Done, } #[derive(Debug, Clone)] -pub struct XxHandshake { +pub struct IkHandshake { role: Role, - step: XxStep, + step: IkStep, symmetric: SymmetricState, local: QlIdentity, + remote_bundle: Option, local_ephemeral: Option, remote_ephemeral: Option, - remote_bundle: Option, handshake_meta: Option, } -impl XxHandshake { - pub fn new_initiator(crypto: &impl QlCrypto, local: QlIdentity) -> Self { +impl IkHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + ) -> Self { + let symmetric = init_ik_symmetric(crypto, &remote_bundle); Self { role: Role::Initiator, - step: XxStep::Send1, - symmetric: SymmetricState::new(crypto, PROTOCOL_XX), + step: IkStep::Send1, + symmetric, local, + remote_bundle: Some(remote_bundle), local_ephemeral: None, remote_ephemeral: None, - remote_bundle: None, handshake_meta: None, } } - pub fn new_responder(crypto: &impl QlCrypto, local: QlIdentity) -> Self { + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + expected_remote: Option, + ) -> Self { + let symmetric = init_ik_symmetric(crypto, &local.bundle()); Self { role: Role::Responder, - step: XxStep::Recv1, - symmetric: SymmetricState::new(crypto, PROTOCOL_XX), + step: IkStep::Recv1, + symmetric, local, + remote_bundle: expected_remote, local_ephemeral: None, remote_ephemeral: None, - remote_bundle: None, handshake_meta: None, } } pub fn is_finished(&self) -> bool { - self.step == XxStep::Done + self.step == IkStep::Done + } + + fn outbound_header(&self) -> Result { + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + Ok(HandshakeHeader { + sender: self.local.xid, + recipient: remote_bundle.xid, + }) + } + + fn ensure_inbound_recipient(&self, header: HandshakeHeader) -> Result<(), WireError> { + if header.recipient == self.local.xid { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } + + fn ensure_known_remote_sender(&self, header: HandshakeHeader) -> Result<(), WireError> { + if let Some(remote_bundle) = self.remote_bundle.as_ref() { + if header.sender != remote_bundle.xid { + return Err(WireError::InvalidPayload); + } + } + Ok(()) } pub fn write_message( &mut self, crypto: &impl QlCrypto, meta: HandshakeMeta, - ) -> Result { + ) -> Result { match self.step { - XxStep::Send1 => { + IkStep::Send1 => { initialize_handshake_meta(&mut self.handshake_meta, meta)?; - mix_hash_xx_handshake(&mut self.symmetric, crypto, HandshakeKind::Xx1, &meta); + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let header = self.outbound_header()?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Ik1, + &meta, + ); + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + self.symmetric.mix_hash(crypto, skem_ciphertext.as_bytes()); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + let local_ephemeral = generate_ephemeral_keypair(crypto); let public = local_ephemeral.public(); mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + + let static_bundle = + encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + self.local_ephemeral = Some(local_ephemeral); - self.step = XxStep::Recv2; - Ok(XxMessage::Message1(Xx1 { + self.step = IkStep::Recv2; + Ok(IkMessage::Message1(Ik1 { + header, meta, + skem_ciphertext, ephemeral: public, + static_bundle, })) } - XxStep::Send2 => { + IkStep::Send2 => { require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_xx_handshake(&mut self.symmetric, crypto, HandshakeKind::Xx2, &meta); + let header = self.outbound_header()?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Ik2, + &meta, + ); let remote_ephemeral = self .remote_ephemeral .clone() @@ -218,51 +242,19 @@ impl XxHandshake { self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); - let static_bundle = - encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; - - self.step = XxStep::Recv3; - Ok(XxMessage::Message2(Xx2 { - meta, - ekem_ciphertext, - static_bundle, - })) - } - XxStep::Send3 => { - require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_xx_handshake(&mut self.symmetric, crypto, HandshakeKind::Xx3, &meta); - let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; - let (skem_ciphertext, skem_secret) = - crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); - let skem_ciphertext = - encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); - - let static_bundle = - encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; - - self.step = XxStep::Recv4; - Ok(XxMessage::Message3(Xx3 { - meta, - skem_ciphertext, - static_bundle, - })) - } - XxStep::Send4 => { - require_handshake_meta(&self.handshake_meta, meta)?; - mix_hash_xx_handshake(&mut self.symmetric, crypto, HandshakeKind::Xx4, &meta); - let remote_bundle = self.remote_bundle.clone().ok_or(WireError::InvalidState)?; + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); let skem_ciphertext = encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; self.symmetric .mix_key_and_hash(crypto, skem_secret.as_bytes()); - self.step = XxStep::Done; - Ok(XxMessage::Message4(Xx4 { + self.step = IkStep::Done; + Ok(IkMessage::Message2(Ik2 { + header, meta, + ekem_ciphertext, skem_ciphertext, })) } @@ -274,30 +266,56 @@ impl XxHandshake { &mut self, crypto: &impl QlCrypto, now_seconds: u64, - message: &XxMessage, + message: &IkMessage, ) -> Result<(), WireError> { match (&self.step, message) { - (XxStep::Recv1, XxMessage::Message1(message)) => { + (IkStep::Recv1, IkMessage::Message1(message)) => { message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - mix_hash_xx_handshake( + self.ensure_inbound_recipient(message.header)?; + self.ensure_known_remote_sender(message.header)?; + mix_hash_routed_handshake( &mut self.symmetric, crypto, - HandshakeKind::Xx1, + message.header, + HandshakeKind::Ik1, &message.meta, ); + self.symmetric + .mix_hash(crypto, message.skem_ciphertext.as_bytes()); + let skem_secret = crypto + .mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); self.remote_ephemeral = Some(message.ephemeral.clone()); - self.step = XxStep::Send2; + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + if remote_bundle.xid != message.header.sender { + return Err(WireError::InvalidPayload); + } + match self.remote_bundle.as_ref() { + Some(expected) if expected != &remote_bundle => { + return Err(WireError::InvalidPayload); + } + Some(_) => {} + None => self.remote_bundle = Some(remote_bundle), + } + self.step = IkStep::Send2; Ok(()) } - (XxStep::Recv2, XxMessage::Message2(message)) => { + (IkStep::Recv2, IkMessage::Message2(message)) => { message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_xx_handshake( + self.ensure_inbound_recipient(message.header)?; + self.ensure_known_remote_sender(message.header)?; + mix_hash_routed_handshake( &mut self.symmetric, crypto, - HandshakeKind::Xx2, + message.header, + HandshakeKind::Ik2, &message.meta, ); let local_ephemeral = self @@ -310,21 +328,6 @@ impl XxHandshake { .mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); - let remote_bundle = - decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; - self.remote_bundle = Some(remote_bundle); - self.step = XxStep::Send3; - Ok(()) - } - (XxStep::Recv3, XxMessage::Message3(message)) => { - message.meta.ensure_not_expired(now_seconds)?; - require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_xx_handshake( - &mut self.symmetric, - crypto, - HandshakeKind::Xx3, - &message.meta, - ); let skem_ciphertext = decrypt_mlkem_ciphertext( crypto, &mut self.symmetric, @@ -335,31 +338,7 @@ impl XxHandshake { self.symmetric .mix_key_and_hash(crypto, skem_secret.as_bytes()); - let remote_bundle = - decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; - self.remote_bundle = Some(remote_bundle); - self.step = XxStep::Send4; - Ok(()) - } - (XxStep::Recv4, XxMessage::Message4(message)) => { - message.meta.ensure_not_expired(now_seconds)?; - require_handshake_meta(&self.handshake_meta, message.meta)?; - mix_hash_xx_handshake( - &mut self.symmetric, - crypto, - HandshakeKind::Xx4, - &message.meta, - ); - let skem_ciphertext = decrypt_mlkem_ciphertext( - crypto, - &mut self.symmetric, - &message.skem_ciphertext, - )?; - let skem_secret = - crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); - self.step = XxStep::Done; + self.step = IkStep::Done; Ok(()) } _ => Err(WireError::InvalidState), diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 44a12ab7..a7b3c3fc 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -1,55 +1,24 @@ use super::{ decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, - mix_hash_kk_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, - EphemeralPublicKey, FinalizedHandshake, Role, SymmetricState, + mix_hash_routed_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, + EphemeralPublicKey, FinalizedHandshake, HandshakeHeader, Role, SymmetricState, }; use crate::{ codec, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, - WireError, XID, + WireError, }; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct KkHandshakeHeader { - pub sender: XID, - pub recipient: XID, -} - -impl KkHandshakeHeader { - pub const ENCODED_LEN: usize = XID::SIZE * 2; - - pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, &self.sender.0); - codec::push_bytes(out, &self.recipient.0); - } - - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let header = Self::decode_from(&mut reader)?; - reader.finish()?; - Ok(header) - } - - pub fn decode_from( - reader: &mut codec::Reader, - ) -> Result { - Ok(Self { - sender: XID(reader.take_array()?), - recipient: XID(reader.take_array()?), - }) - } -} - #[derive(Debug, Clone, PartialEq, Eq)] pub struct Kk1 { - pub header: KkHandshakeHeader, + pub header: HandshakeHeader, pub meta: HandshakeMeta, pub skem_ciphertext: MlKemCiphertext, pub ephemeral: EphemeralPublicKey, } impl Kk1 { - pub const ENCODED_LEN: usize = KkHandshakeHeader::ENCODED_LEN + pub const ENCODED_LEN: usize = HandshakeHeader::ENCODED_LEN + HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EphemeralPublicKey::ENCODED_LEN; @@ -63,7 +32,7 @@ impl Kk1 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let header = KkHandshakeHeader::decode_from(&mut reader)?; + let header = HandshakeHeader::decode_from(&mut reader)?; let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let ephemeral = @@ -80,14 +49,14 @@ impl Kk1 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Kk2 { - pub header: KkHandshakeHeader, + pub header: HandshakeHeader, pub meta: HandshakeMeta, pub ekem_ciphertext: MlKemCiphertext, pub skem_ciphertext: EncryptedMlKemCiphertext, } impl Kk2 { - pub const ENCODED_LEN: usize = KkHandshakeHeader::ENCODED_LEN + pub const ENCODED_LEN: usize = HandshakeHeader::ENCODED_LEN + HandshakeMeta::ENCODED_LEN + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::ENCODED_LEN; @@ -101,7 +70,7 @@ impl Kk2 { pub fn decode(bytes: &[u8]) -> Result { let mut reader = codec::Reader::new(bytes); - let header = KkHandshakeHeader::decode_from(&mut reader)?; + let header = HandshakeHeader::decode_from(&mut reader)?; let meta = HandshakeMeta::decode_from(&mut reader)?; let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); @@ -183,21 +152,21 @@ impl KkHandshake { self.step == KkStep::Done } - fn outbound_header(&self) -> KkHandshakeHeader { - KkHandshakeHeader { + fn outbound_header(&self) -> HandshakeHeader { + HandshakeHeader { sender: self.local.xid, recipient: self.remote_bundle.xid, } } - fn inbound_header(&self) -> KkHandshakeHeader { - KkHandshakeHeader { + fn inbound_header(&self) -> HandshakeHeader { + HandshakeHeader { sender: self.remote_bundle.xid, recipient: self.local.xid, } } - fn ensure_inbound_header(&self, header: KkHandshakeHeader) -> Result<(), WireError> { + fn ensure_inbound_header(&self, header: HandshakeHeader) -> Result<(), WireError> { if header == self.inbound_header() { Ok(()) } else { @@ -214,7 +183,7 @@ impl KkHandshake { KkStep::Send1 => { initialize_handshake_meta(&mut self.handshake_meta, meta)?; let header = self.outbound_header(); - mix_hash_kk_handshake( + mix_hash_routed_handshake( &mut self.symmetric, crypto, header, @@ -244,7 +213,7 @@ impl KkHandshake { KkStep::Send2 => { require_handshake_meta(&self.handshake_meta, meta)?; let header = self.outbound_header(); - mix_hash_kk_handshake( + mix_hash_routed_handshake( &mut self.symmetric, crypto, header, @@ -290,7 +259,7 @@ impl KkHandshake { message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; self.ensure_inbound_header(message.header)?; - mix_hash_kk_handshake( + mix_hash_routed_handshake( &mut self.symmetric, crypto, message.header, @@ -313,7 +282,7 @@ impl KkHandshake { message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(&self.handshake_meta, message.meta)?; self.ensure_inbound_header(message.header)?; - mix_hash_kk_handshake( + mix_hash_routed_handshake( &mut self.symmetric, crypto, message.header, diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index a1292d8d..eca64321 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,22 +1,53 @@ use crate::{ codec, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, Nonce, - PeerBundle, QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, + PeerBundle, QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; +mod ik; mod kk; mod meta; -mod xx; -pub use kk::{Kk1, Kk2, KkHandshake, KkHandshakeHeader, KkMessage}; +pub use ik::{Ik1, Ik2, IkHandshake, IkMessage}; +pub use kk::{Kk1, Kk2, KkHandshake, KkMessage}; pub use meta::{HandshakeId, HandshakeMeta}; -pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake, XxMessage}; const SHA256_BLOCK_LEN: usize = 64; -const PROTOCOL_XX: &[u8] = b"ql-wire:pq-xx:v1"; +const PROTOCOL_IK: &[u8] = b"ql-wire:pq-ik:v1"; const PROTOCOL_KK: &[u8] = b"ql-wire:pq-kk:v1"; const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; const HANDSHAKE_PREAMBLE_DOMAIN: &[u8] = b"ql-wire:handshake-preamble:v1"; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HandshakeHeader { + pub sender: XID, + pub recipient: XID, +} + +impl HandshakeHeader { + pub const ENCODED_LEN: usize = XID::SIZE * 2; + + pub fn encode_into(&self, out: &mut Vec) { + codec::push_bytes(out, &self.sender.0); + codec::push_bytes(out, &self.recipient.0); + } + + pub fn decode(bytes: &[u8]) -> Result { + let mut reader = codec::Reader::new(bytes); + let header = Self::decode_from(&mut reader)?; + reader.finish()?; + Ok(header) + } + + pub fn decode_from( + reader: &mut codec::Reader, + ) -> Result { + Ok(Self { + sender: XID(reader.take_array()?), + recipient: XID(reader.take_array()?), + }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct EphemeralPublicKey { pub mlkem_public_key: MlKemPublicKey, @@ -251,6 +282,12 @@ fn init_kk_symmetric( symmetric } +fn init_ik_symmetric(crypto: &impl QlCrypto, responder_bundle: &PeerBundle) -> SymmetricState { + let mut symmetric = SymmetricState::new(crypto, PROTOCOL_IK); + symmetric.mix_hash(crypto, &responder_bundle.encode()); + symmetric +} + fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> EphemeralKeyPair { EphemeralKeyPair { mlkem: crypto.mlkem_generate_keypair(), @@ -265,26 +302,14 @@ fn mix_hash_ephemeral( symmetric.mix_hash(crypto, public.mlkem_public_key.as_bytes()); } -fn mix_hash_xx_handshake( - symmetric: &mut SymmetricState, - crypto: &impl QlCrypto, - kind: HandshakeKind, - meta: &HandshakeMeta, -) { - let encoded = meta.encode(); - symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); - symmetric.mix_hash(crypto, &[kind as u8]); - symmetric.mix_hash(crypto, &encoded); -} - -fn mix_hash_kk_handshake( +fn mix_hash_routed_handshake( symmetric: &mut SymmetricState, crypto: &impl QlCrypto, - header: KkHandshakeHeader, + header: HandshakeHeader, kind: HandshakeKind, meta: &HandshakeMeta, ) { - let mut encoded_header = Vec::with_capacity(KkHandshakeHeader::ENCODED_LEN); + let mut encoded_header = Vec::with_capacity(HandshakeHeader::ENCODED_LEN); header.encode_into(&mut encoded_header); let encoded = meta.encode(); symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index ba77f53c..c82f2fa4 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -31,7 +31,7 @@ pub use pq::*; pub use record::*; pub use xid::*; -pub const QL_WIRE_VERSION: u8 = 1; +pub const QL_WIRE_VERSION: u8 = 2; pub const ENCRYPTED_MESSAGE_AUTH_SIZE: usize = 16; #[cfg(test)] diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 3f0b4069..77c4ceb5 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -1,7 +1,7 @@ use crate::{ codec, encrypted_message::EncryptedMessage, - handshake::{Kk1, Kk2, Xx1, Xx2, Xx3, Xx4}, + handshake::{Ik1, Ik2, Kk1, Kk2}, ByteSlice, SessionHeader, WireError, QL_WIRE_VERSION, }; @@ -19,10 +19,8 @@ pub enum QlRecord { #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlHandshakeRecord { - Xx1(Xx1), - Xx2(Xx2), - Xx3(Xx3), - Xx4(Xx4), + Ik1(Ik1), + Ik2(Ik2), Kk1(Kk1), Kk2(Kk2), } @@ -49,12 +47,10 @@ impl TryFrom for RecordType { #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum HandshakeKind { - Xx1 = 1, - Xx2 = 2, - Xx3 = 3, - Xx4 = 4, - Kk1 = 5, - Kk2 = 6, + Ik1 = 1, + Ik2 = 2, + Kk1 = 3, + Kk2 = 4, } impl TryFrom for HandshakeKind { @@ -62,12 +58,10 @@ impl TryFrom for HandshakeKind { fn try_from(value: u8) -> Result { match value { - 1 => Ok(Self::Xx1), - 2 => Ok(Self::Xx2), - 3 => Ok(Self::Xx3), - 4 => Ok(Self::Xx4), - 5 => Ok(Self::Kk1), - 6 => Ok(Self::Kk2), + 1 => Ok(Self::Ik1), + 2 => Ok(Self::Ik2), + 3 => Ok(Self::Kk1), + 4 => Ok(Self::Kk2), _ => Err(WireError::InvalidPayload), } } @@ -76,10 +70,8 @@ impl TryFrom for HandshakeKind { impl QlHandshakeRecord { pub fn kind(&self) -> HandshakeKind { match self { - Self::Xx1(_) => HandshakeKind::Xx1, - Self::Xx2(_) => HandshakeKind::Xx2, - Self::Xx3(_) => HandshakeKind::Xx3, - Self::Xx4(_) => HandshakeKind::Xx4, + Self::Ik1(_) => HandshakeKind::Ik1, + Self::Ik2(_) => HandshakeKind::Ik2, Self::Kk1(_) => HandshakeKind::Kk1, Self::Kk2(_) => HandshakeKind::Kk2, } @@ -87,10 +79,8 @@ impl QlHandshakeRecord { fn encode_into(&self, out: &mut Vec) { match self { - Self::Xx1(message) => message.encode_into(out), - Self::Xx2(message) => message.encode_into(out), - Self::Xx3(message) => message.encode_into(out), - Self::Xx4(message) => message.encode_into(out), + Self::Ik1(message) => message.encode_into(out), + Self::Ik2(message) => message.encode_into(out), Self::Kk1(message) => message.encode_into(out), Self::Kk2(message) => message.encode_into(out), } @@ -98,10 +88,8 @@ impl QlHandshakeRecord { fn decode_payload(kind: HandshakeKind, bytes: &[u8]) -> Result { match kind { - HandshakeKind::Xx1 => Ok(Self::Xx1(Xx1::decode(bytes)?)), - HandshakeKind::Xx2 => Ok(Self::Xx2(Xx2::decode(bytes)?)), - HandshakeKind::Xx3 => Ok(Self::Xx3(Xx3::decode(bytes)?)), - HandshakeKind::Xx4 => Ok(Self::Xx4(Xx4::decode(bytes)?)), + HandshakeKind::Ik1 => Ok(Self::Ik1(Ik1::decode(bytes)?)), + HandshakeKind::Ik2 => Ok(Self::Ik2(Ik2::decode(bytes)?)), HandshakeKind::Kk1 => Ok(Self::Kk1(Kk1::decode(bytes)?)), HandshakeKind::Kk2 => Ok(Self::Kk2(Kk2::decode(bytes)?)), } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index da607b04..4ddb27f9 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -162,19 +162,17 @@ fn make_identity(crypto: &impl QlCrypto, byte: u8) -> QlIdentity { generate_identity(crypto, xid(byte)) } -fn kk_handshake_header(sender: u8, recipient: u8) -> KkHandshakeHeader { - KkHandshakeHeader { +fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { + HandshakeHeader { sender: xid(sender), recipient: xid(recipient), } } -fn xx_record(message: XxMessage) -> QlHandshakeRecord { +fn ik_record(message: IkMessage) -> QlHandshakeRecord { match message { - XxMessage::Message1(message) => QlHandshakeRecord::Xx1(message), - XxMessage::Message2(message) => QlHandshakeRecord::Xx2(message), - XxMessage::Message3(message) => QlHandshakeRecord::Xx3(message), - XxMessage::Message4(message) => QlHandshakeRecord::Xx4(message), + IkMessage::Message1(message) => QlHandshakeRecord::Ik1(message), + IkMessage::Message2(message) => QlHandshakeRecord::Ik2(message), } } @@ -198,26 +196,29 @@ fn peer_bundle_round_trip() { } #[test] -fn handshake_record_round_trip_supports_xx_and_kk() { - let xx = QlHandshakeRecord::Xx1(Xx1 { +fn handshake_record_round_trip_supports_ik_and_kk() { + let ik = QlHandshakeRecord::Ik1(Ik1 { + header: handshake_header(1, 2), meta: handshake_meta(1), + skem_ciphertext: MlKemCiphertext::from_data([7; MlKemCiphertext::SIZE]), ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::from_data([9; MlKemPublicKey::SIZE]), }, + static_bundle: EncryptedPeerBundle::from_data([13; EncryptedPeerBundle::ENCODED_LEN]), }); - let xx_encoded = xx.encode(); - assert_eq!(QlHandshakeRecord::decode(&xx_encoded).unwrap(), xx); + let ik_encoded = ik.encode(); + assert_eq!(QlHandshakeRecord::decode(&ik_encoded).unwrap(), ik); assert_eq!( - QlRecord::decode(&xx_encoded).unwrap(), - QlRecord::Handshake(xx) + QlRecord::decode(&ik_encoded).unwrap(), + QlRecord::Handshake(ik) ); let kk = QlHandshakeRecord::Kk1(Kk1 { - header: kk_handshake_header(1, 2), + header: handshake_header(1, 2), meta: handshake_meta(2), - skem_ciphertext: MlKemCiphertext::from_data([7; MlKemCiphertext::SIZE]), + skem_ciphertext: MlKemCiphertext::from_data([11; MlKemCiphertext::SIZE]), ephemeral: EphemeralPublicKey { - mlkem_public_key: MlKemPublicKey::from_data([11; MlKemPublicKey::SIZE]), + mlkem_public_key: MlKemPublicKey::from_data([15; MlKemPublicKey::SIZE]), }, }); let kk_encoded = kk.encode(); @@ -229,13 +230,13 @@ fn handshake_record_round_trip_supports_xx_and_kk() { } #[test] -fn xx_handshake_rejects_tampered_handshake_meta() { +fn ik_handshake_rejects_tampered_handshake_meta() { let crypto = TestCrypto::new(9); let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); - let mut responder_state = XxHandshake::new_responder(&crypto, responder); + let mut initiator_state = IkHandshake::new_initiator(&crypto, initiator, responder.bundle()); + let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); let m1 = initiator_state .write_message(&crypto, handshake_meta(77)) @@ -245,8 +246,8 @@ fn xx_handshake_rejects_tampered_handshake_meta() { let mut m2 = responder_state .write_message(&crypto, handshake_meta(77)) .unwrap(); - let XxMessage::Message2(message) = &mut m2 else { - panic!("expected xx2"); + let IkMessage::Message2(message) = &mut m2 else { + panic!("expected ik2"); }; message.meta.handshake_id = HandshakeId(78); @@ -277,7 +278,7 @@ fn kk_handshake_rejects_tampered_handshake_header() { let KkMessage::Message2(message) = &mut m2 else { panic!("expected kk2"); }; - message.header = kk_handshake_header(9, 1); + message.header = handshake_header(9, 1); assert_eq!( initiator_state.read_message(&crypto, 0, &m2), @@ -286,19 +287,62 @@ fn kk_handshake_rejects_tampered_handshake_header() { } #[test] -fn xx_handshake_rejects_expired_message() { +fn ik_handshake_rejects_tampered_handshake_header() { let crypto = TestCrypto::new(11); let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator); - let mut responder_state = XxHandshake::new_responder(&crypto, responder); + let mut initiator_state = IkHandshake::new_initiator(&crypto, initiator, responder.bundle()); + let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); + + let mut m1 = initiator_state + .write_message(&crypto, handshake_meta(90)) + .unwrap(); + let IkMessage::Message1(message) = &mut m1 else { + panic!("expected ik1"); + }; + message.header.sender = xid(9); + + assert_eq!( + responder_state.read_message(&crypto, 0, &m1), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn ik_handshake_rejects_bound_remote_bundle_mismatch() { + let crypto = TestCrypto::new(12); + let initiator = make_identity(&crypto, 1); + let bogus = make_identity(&crypto, 1); + let responder = make_identity(&crypto, 2); + + let mut initiator_state = IkHandshake::new_initiator(&crypto, initiator, responder.bundle()); + let mut responder_state = IkHandshake::new_responder(&crypto, responder, Some(bogus.bundle())); + + let m1 = initiator_state + .write_message(&crypto, handshake_meta(91)) + .unwrap(); + + assert_eq!( + responder_state.read_message(&crypto, 0, &m1), + Err(WireError::InvalidPayload) + ); +} + +#[test] +fn ik_handshake_rejects_expired_message() { + let crypto = TestCrypto::new(13); + let initiator = make_identity(&crypto, 1); + let responder = make_identity(&crypto, 2); + + let mut initiator_state = IkHandshake::new_initiator(&crypto, initiator, responder.bundle()); + let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); let m1 = initiator_state .write_message( &crypto, HandshakeMeta { - handshake_id: HandshakeId(90), + handshake_id: HandshakeId(92), valid_until: 5, }, ) @@ -311,33 +355,66 @@ fn xx_handshake_rejects_expired_message() { } #[test] -fn xx_handshake_round_trip_derives_matching_transport() { - let crypto = TestCrypto::new(10); - let initiator = make_identity(&crypto, 1); - let responder = make_identity(&crypto, 2); +fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { + let crypto = TestCrypto::new(20); + let initiator = make_identity(&crypto, 3); + let responder = make_identity(&crypto, 4); - let mut initiator_state = XxHandshake::new_initiator(&crypto, initiator.clone()); - let mut responder_state = XxHandshake::new_responder(&crypto, responder.clone()); + let mut initiator_state = + IkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); + let mut responder_state = IkHandshake::new_responder(&crypto, responder.clone(), None); let m1 = initiator_state - .write_message(&crypto, handshake_meta(1)) + .write_message(&crypto, handshake_meta(11)) .unwrap(); responder_state.read_message(&crypto, 0, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, handshake_meta(1)) + .write_message(&crypto, handshake_meta(11)) .unwrap(); initiator_state.read_message(&crypto, 0, &m2).unwrap(); - let m3 = initiator_state - .write_message(&crypto, handshake_meta(1)) + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); +} + +#[test] +fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { + let crypto = TestCrypto::new(21); + let initiator = make_identity(&crypto, 3); + let responder = make_identity(&crypto, 4); + + let mut initiator_state = + IkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder.clone(), Some(initiator.bundle())); + + let m1 = initiator_state + .write_message(&crypto, handshake_meta(12)) .unwrap(); - responder_state.read_message(&crypto, 0, &m3).unwrap(); + responder_state.read_message(&crypto, 0, &m1).unwrap(); - let m4 = responder_state - .write_message(&crypto, handshake_meta(1)) + let m2 = responder_state + .write_message(&crypto, handshake_meta(12)) .unwrap(); - initiator_state.read_message(&crypto, 0, &m4).unwrap(); + initiator_state.read_message(&crypto, 0, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -362,7 +439,7 @@ fn xx_handshake_round_trip_derives_matching_transport() { #[test] fn kk_handshake_round_trip_derives_matching_transport() { - let crypto = TestCrypto::new(20); + let crypto = TestCrypto::new(30); let initiator = make_identity(&crypto, 3); let responder = make_identity(&crypto, 4); @@ -372,12 +449,12 @@ fn kk_handshake_round_trip_derives_matching_transport() { KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); let m1 = initiator_state - .write_message(&crypto, handshake_meta(11)) + .write_message(&crypto, handshake_meta(21)) .unwrap(); responder_state.read_message(&crypto, 0, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, handshake_meta(11)) + .write_message(&crypto, handshake_meta(21)) .unwrap(); initiator_state.read_message(&crypto, 0, &m2).unwrap(); @@ -404,7 +481,7 @@ fn kk_handshake_round_trip_derives_matching_transport() { #[test] fn encrypted_session_record_round_trip_uses_connection_id_header() { - let crypto = TestCrypto::new(30); + let crypto = TestCrypto::new(40); let header = SessionHeader { connection_id: ConnectionId::from_data([0x44; ConnectionId::SIZE]), seq: RecordSeq(11), @@ -481,37 +558,26 @@ fn protocol_record_size_breakdown() { println!("{label:<32}: {size} bytes"); } - let crypto = TestCrypto::new(40); + let crypto = TestCrypto::new(50); let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut xx_initiator = XxHandshake::new_initiator(&crypto, initiator.clone()); - let mut xx_responder = XxHandshake::new_responder(&crypto, responder.clone()); - - let xx1 = xx_initiator - .write_message(&crypto, handshake_meta(101)) - .unwrap(); - xx_responder.read_message(&crypto, 0, &xx1).unwrap(); - - let xx2 = xx_responder - .write_message(&crypto, handshake_meta(101)) - .unwrap(); - xx_initiator.read_message(&crypto, 0, &xx2).unwrap(); + let mut ik_initiator = + IkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); + let mut ik_responder = IkHandshake::new_responder(&crypto, responder.clone(), None); - let xx3 = xx_initiator + let ik1 = ik_initiator .write_message(&crypto, handshake_meta(101)) .unwrap(); - xx_responder.read_message(&crypto, 0, &xx3).unwrap(); + ik_responder.read_message(&crypto, 0, &ik1).unwrap(); - let xx4 = xx_responder + let ik2 = ik_responder .write_message(&crypto, handshake_meta(101)) .unwrap(); - xx_initiator.read_message(&crypto, 0, &xx4).unwrap(); + ik_initiator.read_message(&crypto, 0, &ik2).unwrap(); - let xx1 = xx_record(xx1); - let xx2 = xx_record(xx2); - let xx3 = xx_record(xx3); - let xx4 = xx_record(xx4); + let ik1 = ik_record(ik1); + let ik2 = ik_record(ik2); let mut kk_initiator = KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); @@ -531,7 +597,7 @@ fn protocol_record_size_breakdown() { let kk1 = kk_record(kk1); let kk2 = kk_record(kk2); - let session = xx_initiator.finalize(&crypto).unwrap(); + let session = ik_initiator.finalize(&crypto).unwrap(); let session_ping = encrypted::encrypt_record( &crypto, SessionHeader { @@ -576,10 +642,8 @@ fn protocol_record_size_breakdown() { print_size("ql-wire peer bundle", initiator.bundle().encode().len()); print_size("ql-wire mlkem public key", MlKemPublicKey::SIZE); print_size("ql-wire mlkem ciphertext", MlKemCiphertext::SIZE); - print_size("ql-wire pq xx1", xx1.encode().len()); - print_size("ql-wire pq xx2", xx2.encode().len()); - print_size("ql-wire pq xx3", xx3.encode().len()); - print_size("ql-wire pq xx4", xx4.encode().len()); + print_size("ql-wire pq ik1", ik1.encode().len()); + print_size("ql-wire pq ik2", ik2.encode().len()); print_size("ql-wire pq kk1", kk1.encode().len()); print_size("ql-wire pq kk2", kk2.encode().len()); print_size("ql-wire session ping", session_ping.encode().len()); From 460757a9341dcdc4edf31e5a177d8e6f8b4e423c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 2 Apr 2026 10:26:50 -0400 Subject: [PATCH 069/304] ql-fsm: remove xx handshake --- ql-fsm/src/error.rs | 2 + ql-fsm/src/implementation/fsm.rs | 15 +- ql-fsm/src/implementation/handshake/ik.rs | 122 ++++++++++++++ ql-fsm/src/implementation/handshake/kk.rs | 2 +- ql-fsm/src/implementation/handshake/mod.rs | 23 +-- ql-fsm/src/implementation/handshake/xx.rs | 177 --------------------- ql-fsm/src/lib.rs | 16 +- ql-fsm/src/state.rs | 21 +-- ql-fsm/src/tests/handshake.rs | 93 ++++++----- ql-fsm/src/tests/mod.rs | 27 ---- ql-fsm/src/tests/session.rs | 14 +- 11 files changed, 228 insertions(+), 284 deletions(-) create mode 100644 ql-fsm/src/implementation/handshake/ik.rs delete mode 100644 ql-fsm/src/implementation/handshake/xx.rs diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index e933a50e..2c033e4d 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -23,6 +23,8 @@ pub enum QlFsmError { SessionClosed, #[error("no peer bound")] NoPeerBound, + #[error("fsm is busy")] + Busy, #[error("no active session")] NoSession, } diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/fsm.rs index 2f5d59af..d5ede6b6 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/fsm.rs @@ -75,19 +75,8 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Result<(), QlFsmError> { + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::IkHandshake::new_initiator(crypto, fsm.identity.clone(), peer); + let message = handshake.write_message(crypto, meta)?; + let IkMessage::Message1(message) = message else { + return Err(QlFsmError::InvalidPayload); + }; + + fsm.state.link = LinkState::IkInitiator { + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + }; + enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); + emit_peer_status(fsm); + Ok(()) +} + +pub fn handle_ik1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Ik1, +) -> Result<(), QlFsmError> { + if should_ignore_inbound(fsm, message) { + return Ok(()); + } + if is_replayed_handshake_start(fsm, message.meta) { + return Ok(()); + } + if message.header.recipient != fsm.identity.xid { + return Err(QlFsmError::InvalidXid); + } + if let Some(peer) = fsm.state.peer.as_ref() { + if message.header.sender != peer.xid { + return Err(QlFsmError::InvalidXid); + } + } + + reset_connected_session_if_needed(fsm); + + let mut handshake = + wire::IkHandshake::new_responder(crypto, fsm.identity.clone(), fsm.state.peer.clone()); + handshake.read_message( + crypto, + fsm.state.now.unix_secs, + &IkMessage::Message1(message.clone()), + )?; + let outbound = handshake.write_message(crypto, message.meta)?; + let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, ik_record(outbound)); + Ok(()) +} + +pub fn handle_ik2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Ik2, +) -> Result<(), QlFsmError> { + let LinkState::IkInitiator { + mut handshake, + deadline: _, + initial_ephemeral: _, + } = fsm.state.link.clone() + else { + return Ok(()); + }; + + match handshake.read_message( + crypto, + fsm.state.now.unix_secs, + &IkMessage::Message2(message.clone()), + ) { + Ok(()) => {} + Err(WireError::InvalidState) => return Ok(()), + Err(error) => return Err(error.into()), + } + + let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle) +} + +fn ik_record(message: IkMessage) -> QlHandshakeRecord { + match message { + IkMessage::Message1(message) => QlHandshakeRecord::Ik1(message), + IkMessage::Message2(message) => QlHandshakeRecord::Ik2(message), + } +} + +pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { + match &fsm.state.link { + LinkState::Idle | LinkState::Connected(_) => false, + LinkState::IkInitiator { + initial_ephemeral, .. + } => { + if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { + return false; + } + super::local_start_wins(initial_ephemeral, &message.ephemeral) + } + LinkState::KkInitiator { .. } => false, + } +} diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 3215f48c..8c642ee4 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -107,7 +107,7 @@ fn kk_record(message: KkMessage) -> QlHandshakeRecord { pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, - LinkState::XxInitiator { .. } | LinkState::XxResponder { .. } => true, + LinkState::IkInitiator { .. } => true, LinkState::KkInitiator { initial_ephemeral, .. } => { diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index ea0213b4..d641d65e 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -1,5 +1,5 @@ +mod ik; mod kk; -mod xx; use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHandshakeRecord}; @@ -9,15 +9,20 @@ use crate::{ QlFsm, QlFsmError, QlFsmEvent, }; -pub fn handle_connect(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { +pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { if !matches!(fsm.state.link, LinkState::Idle) { - return Ok(()); + return Err(QlFsmError::Busy); } + let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; + ik::start_initiator(fsm, crypto, peer) +} - match fsm.state.peer.clone() { - Some(peer) => kk::start_initiator(fsm, crypto, peer), - None => xx::start_initiator(fsm, crypto), +pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + if !matches!(fsm.state.link, LinkState::Idle) { + return Err(QlFsmError::Busy); } + let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; + kk::start_initiator(fsm, crypto, peer) } pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { @@ -49,10 +54,8 @@ pub fn handle_handshake_record( record: &QlHandshakeRecord, ) -> Result<(), QlFsmError> { match record { - QlHandshakeRecord::Xx1(message) => xx::handle_xx1(fsm, crypto, message), - QlHandshakeRecord::Xx2(message) => xx::handle_xx2(fsm, crypto, message), - QlHandshakeRecord::Xx3(message) => xx::handle_xx3(fsm, crypto, message), - QlHandshakeRecord::Xx4(message) => xx::handle_xx4(fsm, crypto, message), + QlHandshakeRecord::Ik1(message) => ik::handle_ik1(fsm, crypto, message), + QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message), QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message), QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message), } diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/implementation/handshake/xx.rs deleted file mode 100644 index 659d6fda..00000000 --- a/ql-fsm/src/implementation/handshake/xx.rs +++ /dev/null @@ -1,177 +0,0 @@ -use ql_wire::{ - self as wire, QlCrypto, QlHandshakeRecord, WireError, Xx1, Xx2, Xx3, Xx4, XxMessage, -}; - -use super::{ - enqueue_handshake, finish_handshake, is_replayed_handshake_start, - reset_connected_session_if_needed, -}; -use crate::{ - implementation::emit_peer_status, - state::{LinkState, SessionTransport}, - QlFsm, QlFsmError, -}; - -pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - let meta = super::next_handshake_meta(fsm); - let mut handshake = wire::XxHandshake::new_initiator(crypto, fsm.identity.clone()); - let message = handshake.write_message(crypto, meta)?; - let XxMessage::Message1(message) = message else { - return Err(QlFsmError::InvalidPayload); - }; - - fsm.state.link = LinkState::XxInitiator { - initial_ephemeral: message.ephemeral.clone(), - handshake, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, - }; - enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); - emit_peer_status(fsm); - Ok(()) -} - -pub fn handle_xx1( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - message: &Xx1, -) -> Result<(), QlFsmError> { - if should_ignore_inbound(fsm, message) { - return Ok(()); - } - if is_replayed_handshake_start(fsm, message.meta) { - return Ok(()); - } - - reset_connected_session_if_needed(fsm); - - let mut handshake = wire::XxHandshake::new_responder(crypto, fsm.identity.clone()); - handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &XxMessage::Message1(message.clone()), - )?; - let outbound = handshake.write_message(crypto, message.meta)?; - - fsm.state.handshake = None; - fsm.state.link = LinkState::XxResponder { - handshake, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, - }; - enqueue_handshake(fsm, xx_record(outbound)); - emit_peer_status(fsm); - Ok(()) -} - -pub fn handle_xx2( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - message: &Xx2, -) -> Result<(), QlFsmError> { - let LinkState::XxInitiator { - mut handshake, - deadline, - initial_ephemeral, - } = fsm.state.link.clone() - else { - return Ok(()); - }; - - match handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &XxMessage::Message2(message.clone()), - ) { - Ok(()) => {} - Err(WireError::InvalidState) => return Ok(()), - Err(error) => return Err(error.into()), - } - - let outbound = handshake.write_message(crypto, message.meta)?; - fsm.state.handshake = None; - fsm.state.link = LinkState::XxInitiator { - handshake, - deadline, - initial_ephemeral, - }; - enqueue_handshake(fsm, xx_record(outbound)); - Ok(()) -} - -pub fn handle_xx3( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - message: &Xx3, -) -> Result<(), QlFsmError> { - let LinkState::XxResponder { - mut handshake, - deadline: _, - } = fsm.state.link.clone() - else { - return Ok(()); - }; - - match handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &XxMessage::Message3(message.clone()), - ) { - Ok(()) => {} - Err(WireError::InvalidState) => return Ok(()), - Err(error) => return Err(error.into()), - } - - let outbound = handshake.write_message(crypto, message.meta)?; - let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle)?; - fsm.state.handshake = None; - enqueue_handshake(fsm, xx_record(outbound)); - Ok(()) -} - -pub fn handle_xx4( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - message: &Xx4, -) -> Result<(), QlFsmError> { - let LinkState::XxInitiator { - mut handshake, - deadline: _, - initial_ephemeral: _, - } = fsm.state.link.clone() - else { - return Ok(()); - }; - - match handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &XxMessage::Message4(message.clone()), - ) { - Ok(()) => {} - Err(WireError::InvalidState) => return Ok(()), - Err(error) => return Err(error.into()), - } - - let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle) -} - -fn xx_record(message: XxMessage) -> QlHandshakeRecord { - match message { - XxMessage::Message1(message) => QlHandshakeRecord::Xx1(message), - XxMessage::Message2(message) => QlHandshakeRecord::Xx2(message), - XxMessage::Message3(message) => QlHandshakeRecord::Xx3(message), - XxMessage::Message4(message) => QlHandshakeRecord::Xx4(message), - } -} - -pub fn should_ignore_inbound(fsm: &QlFsm, message: &Xx1) -> bool { - match &fsm.state.link { - LinkState::Idle | LinkState::Connected(_) => false, - LinkState::XxResponder { .. } => true, - LinkState::KkInitiator { .. } => false, - LinkState::XxInitiator { - initial_ephemeral, .. - } => super::local_start_wins(initial_ephemeral, &message.ephemeral), - } -} diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index dd914162..0ed04024 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -3,7 +3,7 @@ //! a caller drives `QlFsm` inside its own event loop //! //! inputs to that loop usually include -//! - app actions like `bind_peer`, `connect`, `open_stream`, or `write_stream` +//! - app actions like `bind_peer`, `connect_ik`, `connect_kk`, `open_stream`, or `write_stream` //! - inbound transport bytes passed to `receive` //! - a deadline expiring, handled by calling `on_timer` //! - transport write results passed to `confirm_session_write` or `reject_session_write` @@ -56,8 +56,6 @@ pub enum PeerStatus { Disconnected, /// we are driving the handshake Initiator, - /// the peer is driving the handshake - Responder, /// the encrypted session is up Connected, } @@ -195,10 +193,16 @@ impl QlFsm { implementation::handle_bind_peer(self, peer); } - /// starts or resumes the encrypted session handshake - pub fn connect(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + /// starts an IK handshake with the currently bound peer + pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; - implementation::handle_connect(self, crypto) + implementation::handle_connect_ik(self, crypto) + } + + /// starts a KK handshake with the currently bound peer + pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + self.state.now = now; + implementation::handle_connect_kk(self, crypto) } /// handles one inbound wire message diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index cb78022f..73942284 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -1,8 +1,8 @@ use std::{collections::VecDeque, time::Instant}; use ql_wire::{ - ConnectionId, EphemeralPublicKey, KkHandshake, PeerBundle, QlHandshakeRecord, SessionKey, - XxHandshake, + ConnectionId, EphemeralPublicKey, IkHandshake, KkHandshake, PeerBundle, QlHandshakeRecord, + SessionKey, }; use crate::{replay_cache::ReplayCache, FsmTime, PeerStatus, QlFsmEvent, QlSessionEvent}; @@ -43,15 +43,11 @@ impl SessionTransport { #[derive(Debug, Clone)] pub enum LinkState { Idle, - XxInitiator { - handshake: XxHandshake, + IkInitiator { + handshake: IkHandshake, deadline: Instant, initial_ephemeral: EphemeralPublicKey, }, - XxResponder { - handshake: XxHandshake, - deadline: Instant, - }, KkInitiator { handshake: KkHandshake, deadline: Instant, @@ -64,8 +60,7 @@ impl LinkState { pub fn status(&self) -> PeerStatus { match self { Self::Idle => PeerStatus::Disconnected, - Self::XxInitiator { .. } | Self::KkInitiator { .. } => PeerStatus::Initiator, - Self::XxResponder { .. } => PeerStatus::Responder, + Self::IkInitiator { .. } | Self::KkInitiator { .. } => PeerStatus::Initiator, Self::Connected(_) => PeerStatus::Connected, } } @@ -80,9 +75,9 @@ impl LinkState { pub fn handshake_deadline(&self) -> Option { match self { Self::Idle | Self::Connected(_) => None, - Self::XxInitiator { deadline, .. } - | Self::XxResponder { deadline, .. } - | Self::KkInitiator { deadline, .. } => Some(*deadline), + Self::IkInitiator { deadline, .. } | Self::KkInitiator { deadline, .. } => { + Some(*deadline) + } } } } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index e4633ad7..ef368609 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -3,16 +3,16 @@ use std::time::Duration; use ql_wire::QlRecord; use super::*; -use crate::state::LinkState; +use crate::{state::LinkState, QlFsmError}; #[test] -fn kk_connect_round_trip_establishes_transport() { +fn ik_connect_round_trip_establishes_transport() { let mut harness = Harness::paired_known(QlFsmConfig::default()); harness .a .fsm - .connect(harness.time(), &harness.a.crypto) + .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); harness.pump(); @@ -21,36 +21,59 @@ fn kk_connect_round_trip_establishes_transport() { } #[test] -fn xx_connect_round_trip_learns_peer_bundles() { - let mut harness = Harness::paired_unknown(QlFsmConfig::default()); +fn kk_connect_round_trip_establishes_transport() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); harness .a .fsm - .connect(harness.time(), &harness.a.crypto) + .connect_kk(harness.time(), &harness.a.crypto) .unwrap(); harness.pump(); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn connect_methods_require_bound_peer() { + let time = Harness::paired_known(QlFsmConfig::default()).time(); + let identity = test_identity(55); + let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); + let crypto = TestCrypto::new(9); + + assert_eq!(fsm.connect_ik(time, &crypto), Err(QlFsmError::NoPeerBound)); + assert_eq!(fsm.connect_kk(time, &crypto), Err(QlFsmError::NoPeerBound)); +} + +#[test] +fn connect_methods_return_busy_when_link_is_not_idle() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness + .a + .fsm + .connect_ik(harness.time(), &harness.a.crypto) + .unwrap(); + assert_eq!( - harness.a.fsm.state.peer, - Some(harness.b.fsm.identity.bundle()) + harness.a.fsm.connect_ik(harness.time(), &harness.a.crypto), + Err(QlFsmError::Busy) ); assert_eq!( - harness.b.fsm.state.peer, - Some(harness.a.fsm.identity.bundle()) + harness.a.fsm.connect_kk(harness.time(), &harness.a.crypto), + Err(QlFsmError::Busy) ); - assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); - assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } #[test] -fn inbound_xx1_auto_binds_unbound_responder() { - let mut harness = Harness::responder_unbound_unknown(QlFsmConfig::default()); +fn inbound_ik1_auto_binds_unbound_responder() { + let mut harness = Harness::paired(QlFsmConfig::default(), true, false); harness .a .fsm - .connect(harness.time(), &harness.a.crypto) + .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); harness.pump(); @@ -58,25 +81,27 @@ fn inbound_xx1_auto_binds_unbound_responder() { harness.b.fsm.state.peer, Some(harness.a.fsm.identity.bundle()) ); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } #[test] -fn handshake_timeout_drops_single_attempt_without_resend() { +fn handshake_timeout_drops_single_ik_attempt_without_resend() { let config = QlFsmConfig { handshake_timeout: Duration::from_millis(60), ..QlFsmConfig::default() }; - let mut harness = Harness::paired_unknown(config); + let mut harness = Harness::paired_known(config); harness .a .fsm - .connect(harness.time(), &harness.a.crypto) + .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); let first = harness.next_outbound_a().unwrap(); assert!(matches!( first, - QlRecord::Handshake(ql_wire::QlHandshakeRecord::Xx1(_)) + QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik1(_)) )); assert!(harness.next_outbound_a().is_none()); @@ -88,17 +113,17 @@ fn handshake_timeout_drops_single_attempt_without_resend() { } #[test] -fn handshake_timeout_clears_queued_handshake_output() { +fn handshake_timeout_clears_queued_kk_output() { let config = QlFsmConfig { handshake_timeout: Duration::from_millis(60), ..QlFsmConfig::default() }; - let mut harness = Harness::paired_unknown(config); + let mut harness = Harness::paired_known(config); harness .a .fsm - .connect(harness.time(), &harness.a.crypto) + .connect_kk(harness.time(), &harness.a.crypto) .unwrap(); harness.advance(config.handshake_timeout); @@ -110,12 +135,12 @@ fn handshake_timeout_clears_queued_handshake_output() { #[test] fn bind_peer_clears_queued_handshake_output() { - let mut harness = Harness::paired_unknown(QlFsmConfig::default()); + let mut harness = Harness::paired_known(QlFsmConfig::default()); harness .a .fsm - .connect(harness.time(), &harness.a.crypto) + .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); harness.a.fsm.bind_peer(test_identity(99).bundle()); @@ -123,18 +148,18 @@ fn bind_peer_clears_queued_handshake_output() { } #[test] -fn simultaneous_xx_connect_converges() { - let mut harness = Harness::paired_unknown(QlFsmConfig::default()); +fn simultaneous_ik_connect_converges() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); harness .a .fsm - .connect(harness.time(), &harness.a.crypto) + .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); harness .b .fsm - .connect(harness.time(), &harness.b.crypto) + .connect_ik(harness.time(), &harness.b.crypto) .unwrap(); harness.pump(); @@ -143,25 +168,21 @@ fn simultaneous_xx_connect_converges() { } #[test] -fn simultaneous_xx_and_kk_connect_prefers_xx() { - let mut harness = Harness::paired(QlFsmConfig::default(), false, true); +fn simultaneous_ik_and_kk_connect_prefers_ik() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); harness .a .fsm - .connect(harness.time(), &harness.a.crypto) + .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); harness .b .fsm - .connect(harness.time(), &harness.b.crypto) + .connect_kk(harness.time(), &harness.b.crypto) .unwrap(); harness.pump(); - assert_eq!( - harness.a.fsm.state.peer, - Some(harness.b.fsm.identity.bundle()) - ); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 92d0c1cd..f31dbc42 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -156,10 +156,6 @@ impl Harness { Self::paired(config, true, true) } - fn paired_unknown(config: QlFsmConfig) -> Self { - Self::paired(config, false, false) - } - fn paired(config: QlFsmConfig, know_a: bool, know_b: bool) -> Self { let identity_a = test_identity(11); let identity_b = test_identity(73); @@ -194,29 +190,6 @@ impl Harness { harness } - fn responder_unbound_unknown(config: QlFsmConfig) -> Self { - let identity_a = test_identity(11); - let identity_b = test_identity(73); - let now = Instant::now(); - let time = FsmTime { - instant: now, - unix_secs: 1_700_000_000, - }; - - Self { - now, - unix_secs: time.unix_secs, - a: Node { - fsm: QlFsm::new(config, identity_a, time), - crypto: TestCrypto::new(1), - }, - b: Node { - fsm: QlFsm::new(config, identity_b, time), - crypto: TestCrypto::new(2), - }, - } - } - fn connected(config: QlFsmConfig) -> Self { let mut harness = Self::paired_known(config); let a_to_b_key = SessionKey::from_data([7; SessionKey::SIZE]); diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 944964c3..2ec9573b 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -151,13 +151,20 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { } #[test] -fn queued_stream_work_auto_connects_and_drains_after_handshake() { +fn queued_stream_work_waits_for_explicit_connect_and_then_drains() { let mut harness = Harness::paired_known(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); harness.a.fsm.finish_stream(stream_id).unwrap(); + assert!(harness.next_outbound_a().is_none()); + + harness + .a + .fsm + .connect_ik(harness.time(), &harness.a.crypto) + .unwrap(); harness.pump(); assert_eq!( @@ -189,6 +196,11 @@ fn queued_stream_work_is_failed_when_handshake_times_out() { let stream_id = harness.a.fsm.open_stream().unwrap(); assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); + harness + .a + .fsm + .connect_ik(harness.time(), &harness.a.crypto) + .unwrap(); let _first = harness.next_outbound_a().unwrap(); harness.advance(config.handshake_timeout); harness.a.fsm.on_timer(harness.time()); From e84ea3042abf8595a2510c682b432bff11687824 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 2 Apr 2026 11:30:32 -0400 Subject: [PATCH 070/304] ql-fsm: clean up clones of handshake methods --- ql-fsm/src/error.rs | 6 +- ql-fsm/src/implementation/handshake/ik.rs | 54 +++++++------- ql-fsm/src/implementation/handshake/kk.rs | 54 +++++++------- ql-fsm/src/implementation/handshake/mod.rs | 13 ++-- ql-fsm/src/lib.rs | 4 +- ql-fsm/src/state.rs | 43 +++++++---- ql-fsm/src/tests/handshake.rs | 86 +++++++++++++++++++--- 7 files changed, 168 insertions(+), 92 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 2c033e4d..c9114e02 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -7,6 +7,8 @@ use crate::session::StreamError; pub enum QlFsmError { #[error("invalid payload")] InvalidPayload, + #[error("invalid state")] + InvalidState, #[error("expired")] Expired, #[error("decryption failed")] @@ -23,8 +25,6 @@ pub enum QlFsmError { SessionClosed, #[error("no peer bound")] NoPeerBound, - #[error("fsm is busy")] - Busy, #[error("no active session")] NoSession, } @@ -33,9 +33,9 @@ impl From for QlFsmError { fn from(value: WireError) -> Self { match value { WireError::InvalidPayload => Self::InvalidPayload, + WireError::InvalidState => Self::InvalidState, WireError::Expired => Self::Expired, WireError::DecryptFailed => Self::DecryptFailed, - WireError::InvalidState => Self::InvalidPayload, } } } diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index 0554048d..487934eb 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -1,6 +1,4 @@ -use ql_wire::{ - self as wire, Ik1, Ik2, IkMessage, PeerBundle, QlCrypto, QlHandshakeRecord, WireError, -}; +use ql_wire::{self as wire, Ik1, Ik2, IkMessage, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ enqueue_handshake, finish_handshake, is_replayed_handshake_start, @@ -8,7 +6,7 @@ use super::{ }; use crate::{ implementation::emit_peer_status, - state::{LinkState, SessionTransport}, + state::{IkInitiatorState, LinkState, SessionTransport}, QlFsm, QlFsmError, }; @@ -24,11 +22,12 @@ pub fn start_initiator( return Err(QlFsmError::InvalidPayload); }; - fsm.state.link = LinkState::IkInitiator { + fsm.state.link = LinkState::IkInitiator(IkInitiatorState { + handshake_id: meta.handshake_id, initial_ephemeral: message.ephemeral.clone(), handshake, deadline: fsm.state.now.instant + fsm.config.handshake_timeout, - }; + }); enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); emit_peer_status(fsm); Ok(()) @@ -76,26 +75,27 @@ pub fn handle_ik2( crypto: &impl QlCrypto, message: &Ik2, ) -> Result<(), QlFsmError> { - let LinkState::IkInitiator { - mut handshake, - deadline: _, - initial_ephemeral: _, - } = fsm.state.link.clone() - else { - return Ok(()); - }; + { + let LinkState::IkInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; - match handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &IkMessage::Message2(message.clone()), - ) { - Ok(()) => {} - Err(WireError::InvalidState) => return Ok(()), - Err(error) => return Err(error.into()), + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state.handshake.read_message( + crypto, + fsm.state.now.unix_secs, + &IkMessage::Message2(message.clone()), + )?; } - let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + let LinkState::IkInitiator(state) = fsm.state.link.take() else { + unreachable!("active IK initiator was checked above"); + }; + let (transport, remote_bundle) = + SessionTransport::from_finalized(state.handshake.finalize(crypto)?); finish_handshake(fsm, transport, remote_bundle) } @@ -109,14 +109,12 @@ fn ik_record(message: IkMessage) -> QlHandshakeRecord { pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, - LinkState::IkInitiator { - initial_ephemeral, .. - } => { + LinkState::IkInitiator(state) => { if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { return false; } - super::local_start_wins(initial_ephemeral, &message.ephemeral) + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) } - LinkState::KkInitiator { .. } => false, + LinkState::KkInitiator(_) => false, } } diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 8c642ee4..8fe66ecc 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -1,6 +1,4 @@ -use ql_wire::{ - self as wire, Kk1, Kk2, KkMessage, PeerBundle, QlCrypto, QlHandshakeRecord, WireError, -}; +use ql_wire::{self as wire, Kk1, Kk2, KkMessage, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ enqueue_handshake, finish_handshake, is_replayed_handshake_start, @@ -8,7 +6,7 @@ use super::{ }; use crate::{ implementation::emit_peer_status, - state::{LinkState, SessionTransport}, + state::{KkInitiatorState, LinkState, SessionTransport}, QlFsm, QlFsmError, }; @@ -24,11 +22,12 @@ pub fn start_initiator( return Err(QlFsmError::InvalidPayload); }; - fsm.state.link = LinkState::KkInitiator { + fsm.state.link = LinkState::KkInitiator(KkInitiatorState { + handshake_id: meta.handshake_id, initial_ephemeral: message.ephemeral.clone(), handshake, deadline: fsm.state.now.instant + fsm.config.handshake_timeout, - }; + }); enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); emit_peer_status(fsm); Ok(()) @@ -74,26 +73,27 @@ pub fn handle_kk2( crypto: &impl QlCrypto, message: &Kk2, ) -> Result<(), QlFsmError> { - let LinkState::KkInitiator { - mut handshake, - deadline: _, - initial_ephemeral: _, - } = fsm.state.link.clone() - else { - return Ok(()); - }; + { + let LinkState::KkInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; - match handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &KkMessage::Message2(message.clone()), - ) { - Ok(()) => {} - Err(WireError::InvalidState) => return Ok(()), - Err(error) => return Err(error.into()), + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state.handshake.read_message( + crypto, + fsm.state.now.unix_secs, + &KkMessage::Message2(message.clone()), + )?; } - let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + let LinkState::KkInitiator(state) = fsm.state.link.take() else { + unreachable!("active KK initiator was checked above"); + }; + let (transport, remote_bundle) = + SessionTransport::from_finalized(state.handshake.finalize(crypto)?); finish_handshake(fsm, transport, remote_bundle) } @@ -107,14 +107,12 @@ fn kk_record(message: KkMessage) -> QlHandshakeRecord { pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, - LinkState::IkInitiator { .. } => true, - LinkState::KkInitiator { - initial_ephemeral, .. - } => { + LinkState::IkInitiator(_) => true, + LinkState::KkInitiator(state) => { if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { return false; } - super::local_start_wins(initial_ephemeral, &message.ephemeral) + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) } } } diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index d641d65e..8942aa56 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -10,18 +10,14 @@ use crate::{ }; pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - if !matches!(fsm.state.link, LinkState::Idle) { - return Err(QlFsmError::Busy); - } let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; + prepare_for_outbound_connect(fsm); ik::start_initiator(fsm, crypto, peer) } pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - if !matches!(fsm.state.link, LinkState::Idle) { - return Err(QlFsmError::Busy); - } let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; + prepare_for_outbound_connect(fsm); kk::start_initiator(fsm, crypto, peer) } @@ -42,6 +38,11 @@ pub fn enqueue_handshake(fsm: &mut QlFsm, record: QlHandshakeRecord) { fsm.state.handshake = Some(record); } +pub fn prepare_for_outbound_connect(fsm: &mut QlFsm) { + fsm.state.handshake = None; + reset_connected_session_if_needed(fsm); +} + pub fn is_replayed_handshake_start(fsm: &mut QlFsm, meta: HandshakeMeta) -> bool { fsm.state .replay_cache diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 0ed04024..e9d130f3 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -193,13 +193,13 @@ impl QlFsm { implementation::handle_bind_peer(self, peer); } - /// starts an IK handshake with the currently bound peer + /// starts or replaces an IK handshake with the currently bound peer pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; implementation::handle_connect_ik(self, crypto) } - /// starts a KK handshake with the currently bound peer + /// starts or replaces a KK handshake with the currently bound peer pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; implementation::handle_connect_kk(self, crypto) diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 73942284..80b2c826 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -1,8 +1,8 @@ use std::{collections::VecDeque, time::Instant}; use ql_wire::{ - ConnectionId, EphemeralPublicKey, IkHandshake, KkHandshake, PeerBundle, QlHandshakeRecord, - SessionKey, + ConnectionId, EphemeralPublicKey, HandshakeId, IkHandshake, KkHandshake, PeerBundle, + QlHandshakeRecord, SessionKey, }; use crate::{replay_cache::ReplayCache, FsmTime, PeerStatus, QlFsmEvent, QlSessionEvent}; @@ -43,24 +43,36 @@ impl SessionTransport { #[derive(Debug, Clone)] pub enum LinkState { Idle, - IkInitiator { - handshake: IkHandshake, - deadline: Instant, - initial_ephemeral: EphemeralPublicKey, - }, - KkInitiator { - handshake: KkHandshake, - deadline: Instant, - initial_ephemeral: EphemeralPublicKey, - }, + IkInitiator(IkInitiatorState), + KkInitiator(KkInitiatorState), Connected(SessionTransport), } +#[derive(Debug, Clone)] +pub struct IkInitiatorState { + pub handshake: IkHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + +#[derive(Debug, Clone)] +pub struct KkInitiatorState { + pub handshake: KkHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + impl LinkState { + pub fn take(&mut self) -> Self { + std::mem::replace(self, Self::Idle) + } + pub fn status(&self) -> PeerStatus { match self { Self::Idle => PeerStatus::Disconnected, - Self::IkInitiator { .. } | Self::KkInitiator { .. } => PeerStatus::Initiator, + Self::IkInitiator(_) | Self::KkInitiator(_) => PeerStatus::Initiator, Self::Connected(_) => PeerStatus::Connected, } } @@ -75,9 +87,8 @@ impl LinkState { pub fn handshake_deadline(&self) -> Option { match self { Self::Idle | Self::Connected(_) => None, - Self::IkInitiator { deadline, .. } | Self::KkInitiator { deadline, .. } => { - Some(*deadline) - } + Self::IkInitiator(state) => Some(state.deadline), + Self::KkInitiator(state) => Some(state.deadline), } } } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index ef368609..11857149 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -47,7 +47,7 @@ fn connect_methods_require_bound_peer() { } #[test] -fn connect_methods_return_busy_when_link_is_not_idle() { +fn connect_ik_replaces_in_flight_attempt_and_ignores_stale_reply() { let mut harness = Harness::paired_known(QlFsmConfig::default()); harness @@ -55,15 +55,73 @@ fn connect_methods_return_busy_when_link_is_not_idle() { .fsm .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); + let first = harness.next_outbound_a().unwrap(); + let first_id = handshake_id(&first); - assert_eq!( - harness.a.fsm.connect_ik(harness.time(), &harness.a.crypto), - Err(QlFsmError::Busy) - ); - assert_eq!( - harness.a.fsm.connect_kk(harness.time(), &harness.a.crypto), - Err(QlFsmError::Busy) - ); + harness + .a + .fsm + .connect_ik(harness.time(), &harness.a.crypto) + .unwrap(); + let second = harness.next_outbound_a().unwrap(); + let second_id = handshake_id(&second); + + assert_ne!(first_id, second_id); + + harness.deliver_to_b(first); + let stale_reply = harness.next_outbound_b().unwrap(); + assert_eq!(handshake_id(&stale_reply), first_id); + + harness.deliver_to_a(stale_reply); + assert!(matches!( + harness.a.fsm.state.link, + LinkState::IkInitiator(_) + )); + + harness.deliver_to_b(second); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn connect_kk_replaces_in_flight_attempt_and_ignores_stale_reply() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness + .a + .fsm + .connect_kk(harness.time(), &harness.a.crypto) + .unwrap(); + let first = harness.next_outbound_a().unwrap(); + let first_id = handshake_id(&first); + + harness + .a + .fsm + .connect_kk(harness.time(), &harness.a.crypto) + .unwrap(); + let second = harness.next_outbound_a().unwrap(); + let second_id = handshake_id(&second); + + assert_ne!(first_id, second_id); + + harness.deliver_to_b(first); + let stale_reply = harness.next_outbound_b().unwrap(); + assert_eq!(handshake_id(&stale_reply), first_id); + + harness.deliver_to_a(stale_reply); + assert!(matches!( + harness.a.fsm.state.link, + LinkState::KkInitiator(_) + )); + + harness.deliver_to_b(second); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } #[test] @@ -186,3 +244,13 @@ fn simultaneous_ik_and_kk_connect_prefers_ik() { assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } + +fn handshake_id(record: &QlRecord>) -> ql_wire::HandshakeId { + match record { + QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik1(message)) => message.meta.handshake_id, + QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik2(message)) => message.meta.handshake_id, + QlRecord::Handshake(ql_wire::QlHandshakeRecord::Kk1(message)) => message.meta.handshake_id, + QlRecord::Handshake(ql_wire::QlHandshakeRecord::Kk2(message)) => message.meta.handshake_id, + QlRecord::Session(_) => panic!("expected handshake record"), + } +} From 153effc91e8524e65b2b2415a38b882ceacc19eb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 2 Apr 2026 12:47:58 -0400 Subject: [PATCH 071/304] ql: remove ikmessage and kkmessage wrapper enums --- ql-fsm/src/implementation/handshake/ik.rs | 32 +-- ql-fsm/src/implementation/handshake/kk.rs | 32 +-- ql-wire/src/handshake/ik.rs | 296 +++++++++++----------- ql-wire/src/handshake/kk.rs | 263 +++++++++---------- ql-wire/src/handshake/mod.rs | 4 +- ql-wire/src/tests.rs | 113 +++------ 6 files changed, 341 insertions(+), 399 deletions(-) diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index 487934eb..ffcdd569 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, Ik1, Ik2, IkMessage, PeerBundle, QlCrypto, QlHandshakeRecord}; +use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ enqueue_handshake, finish_handshake, is_replayed_handshake_start, @@ -17,10 +17,7 @@ pub fn start_initiator( ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::IkHandshake::new_initiator(crypto, fsm.identity.clone(), peer); - let message = handshake.write_message(crypto, meta)?; - let IkMessage::Message1(message) = message else { - return Err(QlFsmError::InvalidPayload); - }; + let message = handshake.write_1(crypto, meta)?; fsm.state.link = LinkState::IkInitiator(IkInitiatorState { handshake_id: meta.handshake_id, @@ -57,16 +54,12 @@ pub fn handle_ik1( let mut handshake = wire::IkHandshake::new_responder(crypto, fsm.identity.clone(), fsm.state.peer.clone()); - handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &IkMessage::Message1(message.clone()), - )?; - let outbound = handshake.write_message(crypto, message.meta)?; + handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; + let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); finish_handshake(fsm, transport, remote_bundle)?; fsm.state.handshake = None; - enqueue_handshake(fsm, ik_record(outbound)); + enqueue_handshake(fsm, QlHandshakeRecord::Ik2(outbound)); Ok(()) } @@ -84,11 +77,9 @@ pub fn handle_ik2( return Ok(()); } - state.handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &IkMessage::Message2(message.clone()), - )?; + state + .handshake + .read_2(crypto, fsm.state.now.unix_secs, message)?; } let LinkState::IkInitiator(state) = fsm.state.link.take() else { @@ -99,13 +90,6 @@ pub fn handle_ik2( finish_handshake(fsm, transport, remote_bundle) } -fn ik_record(message: IkMessage) -> QlHandshakeRecord { - match message { - IkMessage::Message1(message) => QlHandshakeRecord::Ik1(message), - IkMessage::Message2(message) => QlHandshakeRecord::Ik2(message), - } -} - pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 8fe66ecc..63b03b6b 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, Kk1, Kk2, KkMessage, PeerBundle, QlCrypto, QlHandshakeRecord}; +use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ enqueue_handshake, finish_handshake, is_replayed_handshake_start, @@ -17,10 +17,7 @@ pub fn start_initiator( ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::KkHandshake::new_initiator(crypto, fsm.identity.clone(), peer); - let message = handshake.write_message(crypto, meta)?; - let KkMessage::Message1(message) = message else { - return Err(QlFsmError::InvalidPayload); - }; + let message = handshake.write_1(crypto, meta)?; fsm.state.link = LinkState::KkInitiator(KkInitiatorState { handshake_id: meta.handshake_id, @@ -55,16 +52,12 @@ pub fn handle_kk1( reset_connected_session_if_needed(fsm); let mut handshake = wire::KkHandshake::new_responder(crypto, fsm.identity.clone(), peer); - handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &KkMessage::Message1(message.clone()), - )?; - let outbound = handshake.write_message(crypto, message.meta)?; + handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; + let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); finish_handshake(fsm, transport, remote_bundle)?; fsm.state.handshake = None; - enqueue_handshake(fsm, kk_record(outbound)); + enqueue_handshake(fsm, QlHandshakeRecord::Kk2(outbound)); Ok(()) } @@ -82,11 +75,9 @@ pub fn handle_kk2( return Ok(()); } - state.handshake.read_message( - crypto, - fsm.state.now.unix_secs, - &KkMessage::Message2(message.clone()), - )?; + state + .handshake + .read_2(crypto, fsm.state.now.unix_secs, message)?; } let LinkState::KkInitiator(state) = fsm.state.link.take() else { @@ -97,13 +88,6 @@ pub fn handle_kk2( finish_handshake(fsm, transport, remote_bundle) } -fn kk_record(message: KkMessage) -> QlHandshakeRecord { - match message { - KkMessage::Message1(message) => QlHandshakeRecord::Kk1(message), - KkMessage::Message2(message) => QlHandshakeRecord::Kk2(message), - } -} - pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index d8a4be0d..7be407a2 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -90,12 +90,6 @@ impl Ik2 { } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum IkMessage { - Message1(Ik1), - Message2(Ik2), -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum IkStep { Send1, @@ -183,166 +177,174 @@ impl IkHandshake { Ok(()) } - pub fn write_message( + pub fn write_1( &mut self, crypto: &impl QlCrypto, meta: HandshakeMeta, - ) -> Result { - match self.step { - IkStep::Send1 => { - initialize_handshake_meta(&mut self.handshake_meta, meta)?; - let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; - let header = self.outbound_header()?; - mix_hash_routed_handshake( - &mut self.symmetric, - crypto, - header, - HandshakeKind::Ik1, - &meta, - ); - let (skem_ciphertext, skem_secret) = - crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); - self.symmetric.mix_hash(crypto, skem_ciphertext.as_bytes()); - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); - - let local_ephemeral = generate_ephemeral_keypair(crypto); - let public = local_ephemeral.public(); - mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + ) -> Result { + if self.step != IkStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let header = self.outbound_header()?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Ik1, + &meta, + ); + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + self.symmetric.mix_hash(crypto, skem_ciphertext.as_bytes()); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); - let static_bundle = - encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); - self.local_ephemeral = Some(local_ephemeral); - self.step = IkStep::Recv2; - Ok(IkMessage::Message1(Ik1 { - header, - meta, - skem_ciphertext, - ephemeral: public, - static_bundle, - })) - } - IkStep::Send2 => { - require_handshake_meta(&self.handshake_meta, meta)?; - let header = self.outbound_header()?; - mix_hash_routed_handshake( - &mut self.symmetric, - crypto, - header, - HandshakeKind::Ik2, - &meta, - ); - let remote_ephemeral = self - .remote_ephemeral - .clone() - .ok_or(WireError::InvalidState)?; - let (ekem_ciphertext, ekem_secret) = - crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); - self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); - self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; - let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; - let (skem_ciphertext, skem_secret) = - crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); - let skem_ciphertext = - encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); + self.local_ephemeral = Some(local_ephemeral); + self.step = IkStep::Recv2; + Ok(Ik1 { + header, + meta, + skem_ciphertext, + ephemeral: public, + static_bundle, + }) + } - self.step = IkStep::Done; - Ok(IkMessage::Message2(Ik2 { - header, - meta, - ekem_ciphertext, - skem_ciphertext, - })) - } - _ => Err(WireError::InvalidState), + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != IkStep::Send2 { + return Err(WireError::InvalidState); } + require_handshake_meta(&self.handshake_meta, meta)?; + let header = self.outbound_header()?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Ik2, + &meta, + ); + let remote_ephemeral = self + .remote_ephemeral + .clone() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = IkStep::Done; + Ok(Ik2 { + header, + meta, + ekem_ciphertext, + skem_ciphertext, + }) } - pub fn read_message( + pub fn read_1( &mut self, crypto: &impl QlCrypto, now_seconds: u64, - message: &IkMessage, + message: &Ik1, ) -> Result<(), WireError> { - match (&self.step, message) { - (IkStep::Recv1, IkMessage::Message1(message)) => { - message.meta.ensure_not_expired(now_seconds)?; - initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - self.ensure_inbound_recipient(message.header)?; - self.ensure_known_remote_sender(message.header)?; - mix_hash_routed_handshake( - &mut self.symmetric, - crypto, - message.header, - HandshakeKind::Ik1, - &message.meta, - ); - self.symmetric - .mix_hash(crypto, message.skem_ciphertext.as_bytes()); - let skem_secret = crypto - .mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); + if self.step != IkStep::Recv1 { + return Err(WireError::InvalidState); + } + message.meta.ensure_not_expired(now_seconds)?; + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_recipient(message.header)?; + self.ensure_known_remote_sender(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Ik1, + &message.meta, + ); + self.symmetric + .mix_hash(crypto, message.skem_ciphertext.as_bytes()); + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); - mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); - self.remote_ephemeral = Some(message.ephemeral.clone()); + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); - let remote_bundle = - decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; - if remote_bundle.xid != message.header.sender { - return Err(WireError::InvalidPayload); - } - match self.remote_bundle.as_ref() { - Some(expected) if expected != &remote_bundle => { - return Err(WireError::InvalidPayload); - } - Some(_) => {} - None => self.remote_bundle = Some(remote_bundle), - } - self.step = IkStep::Send2; - Ok(()) + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + if remote_bundle.xid != message.header.sender { + return Err(WireError::InvalidPayload); + } + match self.remote_bundle.as_ref() { + Some(expected) if expected != &remote_bundle => { + return Err(WireError::InvalidPayload); } - (IkStep::Recv2, IkMessage::Message2(message)) => { - message.meta.ensure_not_expired(now_seconds)?; - require_handshake_meta(&self.handshake_meta, message.meta)?; - self.ensure_inbound_recipient(message.header)?; - self.ensure_known_remote_sender(message.header)?; - mix_hash_routed_handshake( - &mut self.symmetric, - crypto, - message.header, - HandshakeKind::Ik2, - &message.meta, - ); - let local_ephemeral = self - .local_ephemeral - .as_ref() - .ok_or(WireError::InvalidState)?; - self.symmetric - .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); - let ekem_secret = crypto - .mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); - self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); - - let skem_ciphertext = decrypt_mlkem_ciphertext( - crypto, - &mut self.symmetric, - &message.skem_ciphertext, - )?; - let skem_secret = - crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); + Some(_) => {} + None => self.remote_bundle = Some(remote_bundle), + } + self.step = IkStep::Send2; + Ok(()) + } - self.step = IkStep::Done; - Ok(()) - } - _ => Err(WireError::InvalidState), + pub fn read_2( + &mut self, + crypto: &impl QlCrypto, + now_seconds: u64, + message: &Ik2, + ) -> Result<(), WireError> { + if self.step != IkStep::Recv2 { + return Err(WireError::InvalidState); } + message.meta.ensure_not_expired(now_seconds)?; + require_handshake_meta(&self.handshake_meta, message.meta)?; + self.ensure_inbound_recipient(message.header)?; + self.ensure_known_remote_sender(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Ik2, + &message.meta, + ); + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = IkStep::Done; + Ok(()) } pub fn finalize(self, crypto: &impl QlCrypto) -> Result { diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index a7b3c3fc..5f4ff45e 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -84,12 +84,6 @@ impl Kk2 { } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum KkMessage { - Message1(Kk1), - Message2(Kk2), -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum KkStep { Send1, @@ -174,146 +168,155 @@ impl KkHandshake { } } - pub fn write_message( + pub fn write_1( &mut self, crypto: &impl QlCrypto, meta: HandshakeMeta, - ) -> Result { - match self.step { - KkStep::Send1 => { - initialize_handshake_meta(&mut self.handshake_meta, meta)?; - let header = self.outbound_header(); - mix_hash_routed_handshake( - &mut self.symmetric, - crypto, - header, - HandshakeKind::Kk1, - &meta, - ); - let (skem_ciphertext, skem_secret) = - crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); - self.symmetric - .encrypt_and_hash(crypto, skem_ciphertext.as_bytes())?; - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); - - let local_ephemeral = generate_ephemeral_keypair(crypto); - let public = local_ephemeral.public(); - mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + ) -> Result { + if self.step != KkStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let header = self.outbound_header(); + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk1, + &meta, + ); + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); + self.symmetric + .encrypt_and_hash(crypto, skem_ciphertext.as_bytes())?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); - self.local_ephemeral = Some(local_ephemeral); - self.step = KkStep::Recv2; - Ok(KkMessage::Message1(Kk1 { - header, - meta, - skem_ciphertext, - ephemeral: public, - })) - } - KkStep::Send2 => { - require_handshake_meta(&self.handshake_meta, meta)?; - let header = self.outbound_header(); - mix_hash_routed_handshake( - &mut self.symmetric, - crypto, - header, - HandshakeKind::Kk2, - &meta, - ); - let remote_ephemeral = self - .remote_ephemeral - .clone() - .ok_or(WireError::InvalidState)?; - let (ekem_ciphertext, ekem_secret) = - crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); - self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); - self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); - let (skem_ciphertext, skem_secret) = - crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); - let skem_ciphertext = - encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); + self.local_ephemeral = Some(local_ephemeral); + self.step = KkStep::Recv2; + Ok(Kk1 { + header, + meta, + skem_ciphertext, + ephemeral: public, + }) + } - self.step = KkStep::Done; - Ok(KkMessage::Message2(Kk2 { - header, - meta, - ekem_ciphertext, - skem_ciphertext, - })) - } - _ => Err(WireError::InvalidState), + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != KkStep::Send2 { + return Err(WireError::InvalidState); } + require_handshake_meta(&self.handshake_meta, meta)?; + let header = self.outbound_header(); + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk2, + &meta, + ); + let remote_ephemeral = self + .remote_ephemeral + .clone() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = KkStep::Done; + Ok(Kk2 { + header, + meta, + ekem_ciphertext, + skem_ciphertext, + }) } - pub fn read_message( + pub fn read_1( &mut self, crypto: &impl QlCrypto, now_seconds: u64, - message: &KkMessage, + message: &Kk1, ) -> Result<(), WireError> { - match (&self.step, message) { - (KkStep::Recv1, KkMessage::Message1(message)) => { - message.meta.ensure_not_expired(now_seconds)?; - initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - self.ensure_inbound_header(message.header)?; - mix_hash_routed_handshake( - &mut self.symmetric, - crypto, - message.header, - HandshakeKind::Kk1, - &message.meta, - ); - self.symmetric - .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; - let skem_secret = crypto - .mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); - - mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); - self.remote_ephemeral = Some(message.ephemeral.clone()); - self.step = KkStep::Send2; - Ok(()) - } - (KkStep::Recv2, KkMessage::Message2(message)) => { - message.meta.ensure_not_expired(now_seconds)?; - require_handshake_meta(&self.handshake_meta, message.meta)?; - self.ensure_inbound_header(message.header)?; - mix_hash_routed_handshake( - &mut self.symmetric, - crypto, - message.header, - HandshakeKind::Kk2, - &message.meta, - ); - let local_ephemeral = self - .local_ephemeral - .as_ref() - .ok_or(WireError::InvalidState)?; - self.symmetric - .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); - let ekem_secret = crypto - .mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); - self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + if self.step != KkStep::Recv1 { + return Err(WireError::InvalidState); + } + message.meta.ensure_not_expired(now_seconds)?; + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_header(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Kk1, + &message.meta, + ); + self.symmetric + .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); - let skem_ciphertext = decrypt_mlkem_ciphertext( - crypto, - &mut self.symmetric, - &message.skem_ciphertext, - )?; - let skem_secret = - crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); - self.symmetric - .mix_key_and_hash(crypto, skem_secret.as_bytes()); + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + self.step = KkStep::Send2; + Ok(()) + } - self.step = KkStep::Done; - Ok(()) - } - _ => Err(WireError::InvalidState), + pub fn read_2( + &mut self, + crypto: &impl QlCrypto, + now_seconds: u64, + message: &Kk2, + ) -> Result<(), WireError> { + if self.step != KkStep::Recv2 { + return Err(WireError::InvalidState); } + message.meta.ensure_not_expired(now_seconds)?; + require_handshake_meta(&self.handshake_meta, message.meta)?; + self.ensure_inbound_header(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Kk2, + &message.meta, + ); + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = KkStep::Done; + Ok(()) } pub fn finalize(self, crypto: &impl QlCrypto) -> Result { diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index eca64321..edbbc3a7 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -7,8 +7,8 @@ mod ik; mod kk; mod meta; -pub use ik::{Ik1, Ik2, IkHandshake, IkMessage}; -pub use kk::{Kk1, Kk2, KkHandshake, KkMessage}; +pub use ik::{Ik1, Ik2, IkHandshake}; +pub use kk::{Kk1, Kk2, KkHandshake}; pub use meta::{HandshakeId, HandshakeMeta}; const SHA256_BLOCK_LEN: usize = 64; diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 4ddb27f9..c9ed6245 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -169,20 +169,6 @@ fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { } } -fn ik_record(message: IkMessage) -> QlHandshakeRecord { - match message { - IkMessage::Message1(message) => QlHandshakeRecord::Ik1(message), - IkMessage::Message2(message) => QlHandshakeRecord::Ik2(message), - } -} - -fn kk_record(message: KkMessage) -> QlHandshakeRecord { - match message { - KkMessage::Message1(message) => QlHandshakeRecord::Kk1(message), - KkMessage::Message2(message) => QlHandshakeRecord::Kk2(message), - } -} - #[test] fn peer_bundle_round_trip() { let crypto = TestCrypto::new(1); @@ -239,20 +225,17 @@ fn ik_handshake_rejects_tampered_handshake_meta() { let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); let m1 = initiator_state - .write_message(&crypto, handshake_meta(77)) + .write_1(&crypto, handshake_meta(77)) .unwrap(); - responder_state.read_message(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); let mut m2 = responder_state - .write_message(&crypto, handshake_meta(77)) + .write_2(&crypto, handshake_meta(77)) .unwrap(); - let IkMessage::Message2(message) = &mut m2 else { - panic!("expected ik2"); - }; - message.meta.handshake_id = HandshakeId(78); + m2.meta.handshake_id = HandshakeId(78); assert_eq!( - initiator_state.read_message(&crypto, 0, &m2), + initiator_state.read_2(&crypto, 0, &m2), Err(WireError::InvalidPayload) ); } @@ -268,20 +251,17 @@ fn kk_handshake_rejects_tampered_handshake_header() { let mut responder_state = KkHandshake::new_responder(&crypto, responder, initiator.bundle()); let m1 = initiator_state - .write_message(&crypto, handshake_meta(88)) + .write_1(&crypto, handshake_meta(88)) .unwrap(); - responder_state.read_message(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); let mut m2 = responder_state - .write_message(&crypto, handshake_meta(88)) + .write_2(&crypto, handshake_meta(88)) .unwrap(); - let KkMessage::Message2(message) = &mut m2 else { - panic!("expected kk2"); - }; - message.header = handshake_header(9, 1); + m2.header = handshake_header(9, 1); assert_eq!( - initiator_state.read_message(&crypto, 0, &m2), + initiator_state.read_2(&crypto, 0, &m2), Err(WireError::InvalidPayload) ); } @@ -296,15 +276,12 @@ fn ik_handshake_rejects_tampered_handshake_header() { let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); let mut m1 = initiator_state - .write_message(&crypto, handshake_meta(90)) + .write_1(&crypto, handshake_meta(90)) .unwrap(); - let IkMessage::Message1(message) = &mut m1 else { - panic!("expected ik1"); - }; - message.header.sender = xid(9); + m1.header.sender = xid(9); assert_eq!( - responder_state.read_message(&crypto, 0, &m1), + responder_state.read_1(&crypto, 0, &m1), Err(WireError::DecryptFailed) ); } @@ -320,11 +297,11 @@ fn ik_handshake_rejects_bound_remote_bundle_mismatch() { let mut responder_state = IkHandshake::new_responder(&crypto, responder, Some(bogus.bundle())); let m1 = initiator_state - .write_message(&crypto, handshake_meta(91)) + .write_1(&crypto, handshake_meta(91)) .unwrap(); assert_eq!( - responder_state.read_message(&crypto, 0, &m1), + responder_state.read_1(&crypto, 0, &m1), Err(WireError::InvalidPayload) ); } @@ -339,7 +316,7 @@ fn ik_handshake_rejects_expired_message() { let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); let m1 = initiator_state - .write_message( + .write_1( &crypto, HandshakeMeta { handshake_id: HandshakeId(92), @@ -349,7 +326,7 @@ fn ik_handshake_rejects_expired_message() { .unwrap(); assert_eq!( - responder_state.read_message(&crypto, 6, &m1), + responder_state.read_1(&crypto, 6, &m1), Err(WireError::Expired) ); } @@ -365,14 +342,14 @@ fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { let mut responder_state = IkHandshake::new_responder(&crypto, responder.clone(), None); let m1 = initiator_state - .write_message(&crypto, handshake_meta(11)) + .write_1(&crypto, handshake_meta(11)) .unwrap(); - responder_state.read_message(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, handshake_meta(11)) + .write_2(&crypto, handshake_meta(11)) .unwrap(); - initiator_state.read_message(&crypto, 0, &m2).unwrap(); + initiator_state.read_2(&crypto, 0, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -407,14 +384,14 @@ fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { IkHandshake::new_responder(&crypto, responder.clone(), Some(initiator.bundle())); let m1 = initiator_state - .write_message(&crypto, handshake_meta(12)) + .write_1(&crypto, handshake_meta(12)) .unwrap(); - responder_state.read_message(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, handshake_meta(12)) + .write_2(&crypto, handshake_meta(12)) .unwrap(); - initiator_state.read_message(&crypto, 0, &m2).unwrap(); + initiator_state.read_2(&crypto, 0, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -449,14 +426,14 @@ fn kk_handshake_round_trip_derives_matching_transport() { KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); let m1 = initiator_state - .write_message(&crypto, handshake_meta(21)) + .write_1(&crypto, handshake_meta(21)) .unwrap(); - responder_state.read_message(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); let m2 = responder_state - .write_message(&crypto, handshake_meta(21)) + .write_2(&crypto, handshake_meta(21)) .unwrap(); - initiator_state.read_message(&crypto, 0, &m2).unwrap(); + initiator_state.read_2(&crypto, 0, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -566,36 +543,28 @@ fn protocol_record_size_breakdown() { IkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); let mut ik_responder = IkHandshake::new_responder(&crypto, responder.clone(), None); - let ik1 = ik_initiator - .write_message(&crypto, handshake_meta(101)) - .unwrap(); - ik_responder.read_message(&crypto, 0, &ik1).unwrap(); + let ik1 = ik_initiator.write_1(&crypto, handshake_meta(101)).unwrap(); + ik_responder.read_1(&crypto, 0, &ik1).unwrap(); - let ik2 = ik_responder - .write_message(&crypto, handshake_meta(101)) - .unwrap(); - ik_initiator.read_message(&crypto, 0, &ik2).unwrap(); + let ik2 = ik_responder.write_2(&crypto, handshake_meta(101)).unwrap(); + ik_initiator.read_2(&crypto, 0, &ik2).unwrap(); - let ik1 = ik_record(ik1); - let ik2 = ik_record(ik2); + let ik1 = QlHandshakeRecord::Ik1(ik1); + let ik2 = QlHandshakeRecord::Ik2(ik2); let mut kk_initiator = KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); let mut kk_responder = KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); - let kk1 = kk_initiator - .write_message(&crypto, handshake_meta(201)) - .unwrap(); - kk_responder.read_message(&crypto, 0, &kk1).unwrap(); + let kk1 = kk_initiator.write_1(&crypto, handshake_meta(201)).unwrap(); + kk_responder.read_1(&crypto, 0, &kk1).unwrap(); - let kk2 = kk_responder - .write_message(&crypto, handshake_meta(201)) - .unwrap(); - kk_initiator.read_message(&crypto, 0, &kk2).unwrap(); + let kk2 = kk_responder.write_2(&crypto, handshake_meta(201)).unwrap(); + kk_initiator.read_2(&crypto, 0, &kk2).unwrap(); - let kk1 = kk_record(kk1); - let kk2 = kk_record(kk2); + let kk1 = QlHandshakeRecord::Kk1(kk1); + let kk2 = QlHandshakeRecord::Kk2(kk2); let session = ik_initiator.finalize(&crypto).unwrap(); let session_ping = encrypted::encrypt_record( From 3713ee1baef6bf10338b0a9e36264356716d75b7 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 2 Apr 2026 17:33:43 -0400 Subject: [PATCH 072/304] ql-fsm: bytes --- ql-fsm/src/session/mod.rs | 9 +--- ql-fsm/src/session/stream_rx.rs | 90 ++++++++++++++++++--------------- ql-fsm/src/session/stream_tx.rs | 18 +++---- 3 files changed, 59 insertions(+), 58 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index dbb51db5..979aa363 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -212,10 +212,7 @@ impl SessionFsm { if len > stream.readable_bytes() { return Err(StreamError::InvalidRead); } - stream - .rx - .consume(len) - .map_err(|_| StreamError::InvalidRead)?; + stream.rx.consume(len); if stream.recv_limit() > stream.advertised_max_offset { stream.pending_window = true; } @@ -737,8 +734,7 @@ impl SessionFsm { self.try_reap_stream(stream_id); Ok(()) } - Err(StreamRxError::ConflictingOverlap) - | Err(StreamRxError::OutOfWindow) + Err(StreamRxError::OutOfWindow) | Err(StreamRxError::InconsistentFinalOffset) | Err(StreamRxError::FinalOffsetBeforeBufferedData) | Err(StreamRxError::BeyondFinalOffset) @@ -752,7 +748,6 @@ impl SessionFsm { ); Err(()) } - Err(StreamRxError::ConsumeBeyondReadable) => unreachable!(), } } diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index 34c3c797..31b18ec7 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -29,8 +29,6 @@ pub enum StreamRxError { InconsistentFinalOffset, FinalOffsetBeforeBufferedData, BeyondFinalOffset, - ConflictingOverlap, - ConsumeBeyondReadable, TooManyMissingRanges, } @@ -103,16 +101,6 @@ impl StreamRx { } } - #[cfg(test)] - pub fn copy_readable(&self) -> Vec { - let readable = self.readable_len(); - let mut out = Vec::with_capacity(readable); - for chunk in self.bytes() { - out.extend_from_slice(chunk); - } - out - } - pub fn is_complete(&self) -> bool { matches!(self.final_offset, Some(final_offset) if final_offset == self.buffered_end_offset()) && self.missing.is_empty() @@ -154,22 +142,23 @@ impl StreamRx { self.ensure_within_window(end)?; self.ensure_buffered(end)?; - self.validate_overlap(effective_offset, effective_bytes)?; + #[cfg(test)] + self.assert_valid_overlap(effective_offset, effective_bytes); self.write_bytes(effective_offset, effective_bytes); self.subtract_missing_range(effective_offset, end)?; Ok(self.insert_outcome(was_complete, old_readable)) } - pub fn consume(&mut self, len: usize) -> Result<(), StreamRxError> { + pub fn consume(&mut self, len: usize) { let readable = self.readable_len(); + debug_assert!(len <= readable, "consume beyond readable bytes"); if len > readable { - return Err(StreamRxError::ConsumeBeyondReadable); + return; } self.bytes.drain(..len); self.start_offset = self.start_offset.saturating_add(len as u64); - Ok(()) } fn insert_outcome(&self, was_complete: bool, old_readable: usize) -> InsertOutcome { @@ -234,7 +223,8 @@ impl StreamRx { self.missing.push(range) } - fn validate_overlap(&self, offset: u64, bytes: &[u8]) -> Result<(), StreamRxError> { + #[cfg(test)] + fn assert_valid_overlap(&self, offset: u64, bytes: &[u8]) { let mut gap_index = self.first_gap_index_after(offset); for (index, byte) in bytes.iter().copied().enumerate() { @@ -251,19 +241,31 @@ impl StreamRx { continue; } - if self.byte_at(absolute) != byte { - return Err(StreamRxError::ConflictingOverlap); - } - } + let index = + usize::try_from(absolute - self.start_offset).expect("read index exceeds usize"); - Ok(()) + assert_eq!( + self.bytes[index], byte, + "conflicting overlap at stream offset {absolute}" + ); + } } fn write_bytes(&mut self, offset: u64, bytes: &[u8]) { - let start_index = - usize::try_from(offset - self.start_offset).expect("write index exceeds usize"); - for (index, byte) in bytes.iter().copied().enumerate() { - self.bytes[start_index + index] = byte; + let start = usize::try_from(offset - self.start_offset).expect("write index exceeds usize"); + let (front, back) = self.bytes.as_mut_slices(); + + if start >= front.len() { + let start = start - front.len(); + back[start..start + bytes.len()].copy_from_slice(bytes); + return; + } + + let front_len = (front.len() - start).min(bytes.len()); + front[start..start + front_len].copy_from_slice(&bytes[..front_len]); + + if front_len < bytes.len() { + back[..bytes.len() - front_len].copy_from_slice(&bytes[front_len..]); } } @@ -335,11 +337,6 @@ impl StreamRx { .as_slice() .partition_point(|range| range.end <= offset) } - - fn byte_at(&self, offset: u64) -> u8 { - let index = usize::try_from(offset - self.start_offset).expect("read index exceeds usize"); - self.bytes[index] - } } impl<'a> Iterator for StreamReadIter<'a> { @@ -465,6 +462,15 @@ impl std::ops::IndexMut for MissingRanges { mod tests { use super::{InsertOutcome, MissingRange, StreamRx, StreamRxError}; + pub fn copy_readable(rx: &StreamRx) -> Vec { + let readable = rx.readable_len(); + let mut out = Vec::with_capacity(readable); + for chunk in rx.bytes() { + out.extend_from_slice(chunk); + } + out + } + #[test] fn contiguous_insert_becomes_readable_and_complete() { let mut rx = StreamRx::<8>::new(64); @@ -479,7 +485,7 @@ mod tests { } ); assert_eq!(rx.readable_len(), 5); - assert_eq!(rx.copy_readable(), b"hello"); + assert_eq!(copy_readable(&rx), b"hello"); assert_eq!(rx.final_offset, Some(5)); assert!(rx.is_complete()); assert!(rx.missing.is_empty()); @@ -508,7 +514,7 @@ mod tests { became_complete: true, } ); - assert_eq!(rx.copy_readable(), b"hello world"); + assert_eq!(copy_readable(&rx), b"hello world"); assert!(rx.missing.is_empty()); assert!(rx.is_complete()); } @@ -527,17 +533,16 @@ mod tests { became_complete: false, } ); - assert_eq!(rx.copy_readable(), b"hello"); + assert_eq!(copy_readable(&rx), b"hello"); } #[test] - fn conflicting_overlap_is_rejected() { + #[should_panic(expected = "conflicting overlap at stream offset 3")] + fn conflicting_overlap_panics_in_test_builds() { let mut rx = StreamRx::<8>::new(64); rx.insert(0, false, b"abcdef").unwrap(); - let error = rx.insert(3, false, b"xyz").unwrap_err(); - - assert_eq!(error, StreamRxError::ConflictingOverlap); + rx.insert(3, false, b"xyz").unwrap(); } #[test] @@ -545,9 +550,9 @@ mod tests { let mut rx = StreamRx::<8>::new(64); rx.insert(0, false, b"abcd").unwrap(); - rx.consume(2).unwrap(); + rx.consume(2); assert_eq!(rx.start_offset(), 2); - assert_eq!(rx.copy_readable(), b"cd"); + assert_eq!(copy_readable(&rx), b"cd"); let outcome = rx.insert(1, true, b"bcde").unwrap(); assert_eq!( @@ -557,7 +562,7 @@ mod tests { became_complete: true, } ); - assert_eq!(rx.copy_readable(), b"cde"); + assert_eq!(copy_readable(&rx), b"cde"); assert_eq!(rx.final_offset, Some(5)); assert!(rx.is_complete()); } @@ -599,7 +604,8 @@ mod tests { } ); assert!(rx.missing.is_empty()); - assert_eq!(rx.copy_readable(), b"abcdefghij"); + + assert_eq!(copy_readable(&rx), b"abcdefghij"); assert!(rx.is_complete()); } } diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 4a2e21a5..5146522c 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -81,7 +81,7 @@ impl StreamTx { } let start = self.end_offset(); - self.bytes.extend(bytes.iter().copied()); + self.bytes.extend(bytes); if let Some(last) = self.segments.back_mut() { if last.state == SendState::Unsent && last.end_offset() == start { last.len += bytes.len(); @@ -130,20 +130,20 @@ impl StreamTx { if segment.state == SendState::Lost { return Some(range); } - unsent = Some(range); + if unsent.is_none() { + unsent = Some(range); + } } if let Some(range) = unsent { return Some(range); } - let final_offset = self.final_offset?; - if !matches!(final_offset.state, SendState::Lost | SendState::Unsent) { - return None; - } - if final_offset.offset > peer_max_offset { - return None; - } + let final_offset = self.final_offset.filter(|final_offset| { + matches!(final_offset.state, SendState::Lost | SendState::Unsent) + && final_offset.offset <= peer_max_offset + })?; + Some(StreamTxRange { offset: final_offset.offset, len: 0, From 806c5572e0eabadb3faeced48cc145361e5d1051 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 09:17:50 -0400 Subject: [PATCH 073/304] ql-fsm: break out receivedrecords to module --- ql-fsm/src/session/mod.rs | 9 ++- ql-fsm/src/session/received_records.rs | 74 +++++++++++++++++++++++++ ql-fsm/src/session/state.rs | 77 +------------------------- 3 files changed, 81 insertions(+), 79 deletions(-) create mode 100644 ql-fsm/src/session/received_records.rs diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 979aa363..7125706c 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod received_records; pub(crate) mod state; pub(crate) mod stream_rx; pub(crate) mod stream_tx; @@ -15,17 +16,15 @@ use ql_wire::{ }; use self::{ + received_records::{ReceiveInsertOutcome, ReceivedRecords}, state::{ - AckState, InboundState, OutboundRecord, OutboundState, ReceiveInsertOutcome, - ReceivedRecords, ReliableFrame, SessionFsmState, StreamDataManifest, StreamParity, - StreamRole, StreamState, + AckState, InboundState, OutboundRecord, OutboundState, ReliableFrame, SessionFsmState, + StreamDataManifest, StreamParity, StreamRole, StreamState, }, stream_rx::{StreamReadIter, StreamRxError}, stream_tx::StreamTxRange, }; -pub(crate) const SESSION_RECORD_TRACKED_WINDOW: u64 = 256; - #[derive(Debug, Clone, Copy)] pub struct SessionFsmConfig { pub local_parity: StreamParity, diff --git a/ql-fsm/src/session/received_records.rs b/ql-fsm/src/session/received_records.rs new file mode 100644 index 00000000..82341847 --- /dev/null +++ b/ql-fsm/src/session/received_records.rs @@ -0,0 +1,74 @@ +use std::collections::BTreeSet; + +use ql_wire::{RecordAck, RecordAckRange, RecordSeq}; + +#[derive(Debug, Default)] +pub struct ReceivedRecords { + seen: BTreeSet, + largest: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReceiveInsertOutcome { + New { out_of_order: bool }, + Duplicate, +} + +impl ReceivedRecords { + const TRACKED_WINDOW: u64 = 256; + + pub fn insert(&mut self, seq: RecordSeq) -> ReceiveInsertOutcome { + if self.seen.contains(&seq.0) { + return ReceiveInsertOutcome::Duplicate; + } + + if self + .largest + .is_some_and(|largest| largest.saturating_sub(seq.0) > Self::TRACKED_WINDOW) + { + return ReceiveInsertOutcome::Duplicate; + } + + let out_of_order = self + .largest + .is_some_and(|largest| seq.0 != largest.saturating_add(1)); + self.seen.insert(seq.0); + self.largest = Some(self.largest.map_or(seq.0, |largest| largest.max(seq.0))); + self.prune(); + ReceiveInsertOutcome::New { out_of_order } + } + + pub fn ack(&self) -> Option { + if self.seen.is_empty() { + return None; + } + + let mut ranges = Vec::new(); + let mut iter = self.seen.iter().copied(); + let first = iter.next()?; + let mut start = first; + let mut end = first.saturating_add(1); + + for seq in iter { + if seq == end { + end = end.saturating_add(1); + continue; + } + + ranges.push(RecordAckRange { start, end }); + start = seq; + end = seq.saturating_add(1); + } + + ranges.push(RecordAckRange { start, end }); + Some(RecordAck { ranges }) + } + + fn prune(&mut self) { + let Some(largest) = self.largest else { + return; + }; + let keep_from = largest.saturating_sub(Self::TRACKED_WINDOW); + self.seen.retain(|seq| *seq >= keep_from); + } +} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index fe4af462..fccbe046 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -1,12 +1,10 @@ -use std::{collections::BTreeSet, time::Instant}; +use std::time::Instant; use indexmap::IndexMap; -use ql_wire::{ - CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionClose, StreamClose, StreamId, XID, -}; +use ql_wire::{CloseTarget, RecordSeq, SessionClose, StreamClose, StreamId, XID}; use super::{ - stream_rx::StreamRx, stream_tx::StreamTx, SessionState, SESSION_RECORD_TRACKED_WINDOW, + received_records::ReceivedRecords, stream_rx::StreamRx, stream_tx::StreamTx, SessionState, }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -183,75 +181,6 @@ pub enum AckState { Immediate, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ReceiveInsertOutcome { - New { out_of_order: bool }, - Duplicate, -} - -#[derive(Debug, Default)] -pub struct ReceivedRecords { - seen: BTreeSet, - largest: Option, -} - -impl ReceivedRecords { - pub fn insert(&mut self, seq: RecordSeq) -> ReceiveInsertOutcome { - if self.seen.contains(&seq.0) { - return ReceiveInsertOutcome::Duplicate; - } - - if self - .largest - .is_some_and(|largest| largest.saturating_sub(seq.0) > SESSION_RECORD_TRACKED_WINDOW) - { - return ReceiveInsertOutcome::Duplicate; - } - - let out_of_order = self - .largest - .is_some_and(|largest| seq.0 != largest.saturating_add(1)); - self.seen.insert(seq.0); - self.largest = Some(self.largest.map_or(seq.0, |largest| largest.max(seq.0))); - self.prune(); - ReceiveInsertOutcome::New { out_of_order } - } - - pub fn ack(&self) -> Option { - if self.seen.is_empty() { - return None; - } - - let mut ranges = Vec::new(); - let mut iter = self.seen.iter().copied(); - let first = iter.next()?; - let mut start = first; - let mut end = first.saturating_add(1); - - for seq in iter { - if seq == end { - end = end.saturating_add(1); - continue; - } - - ranges.push(RecordAckRange { start, end }); - start = seq; - end = seq.saturating_add(1); - } - - ranges.push(RecordAckRange { start, end }); - Some(RecordAck { ranges }) - } - - fn prune(&mut self) { - let Some(largest) = self.largest else { - return; - }; - let keep_from = largest.saturating_sub(SESSION_RECORD_TRACKED_WINDOW); - self.seen.retain(|seq| *seq >= keep_from); - } -} - pub struct SessionFsmState { pub now: Instant, pub last_activity_at: Instant, From 7573aae7ae5f407e84a174a912c38ec510e107f0 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 09:27:11 -0400 Subject: [PATCH 074/304] ql-fsm: break out stream_parity --- ql-fsm/src/implementation/mod.rs | 2 +- ql-fsm/src/lib.rs | 2 +- ql-fsm/src/session/mod.rs | 4 +- ql-fsm/src/session/state.rs | 169 +++++++++++----------------- ql-fsm/src/session/stream_parity.rs | 44 ++++++++ ql-fsm/src/session/tests.rs | 3 +- ql-fsm/src/tests/mod.rs | 2 +- ql-fsm/src/tests/session.rs | 2 +- 8 files changed, 116 insertions(+), 112 deletions(-) create mode 100644 ql-fsm/src/session/stream_parity.rs diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index 79d2f3f2..d24f21c3 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -7,7 +7,7 @@ pub use fsm::*; pub use handshake::*; use crate::{ - session::{state::StreamParity, SessionEvent, SessionFsmConfig}, + session::{stream_parity::StreamParity, SessionEvent, SessionFsmConfig}, state::LinkState, QlFsm, QlFsmEvent, QlSessionEvent, }; diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index e9d130f3..ea842817 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -164,7 +164,7 @@ impl QlFsm { identity, session: session::SessionFsm::new( session::SessionFsmConfig { - local_parity: session::state::StreamParity::Even, + local_parity: session::stream_parity::StreamParity::Even, record_size: config.session_record_size, ack_delay: config.session_record_ack_delay, retransmit_timeout: config.session_record_retransmit_timeout, diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 7125706c..9d0fec67 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod received_records; pub(crate) mod state; +pub(crate) mod stream_parity; pub(crate) mod stream_rx; pub(crate) mod stream_tx; @@ -19,8 +20,9 @@ use self::{ received_records::{ReceiveInsertOutcome, ReceivedRecords}, state::{ AckState, InboundState, OutboundRecord, OutboundState, ReliableFrame, SessionFsmState, - StreamDataManifest, StreamParity, StreamRole, StreamState, + StreamDataManifest, StreamRole, StreamState, }, + stream_parity::StreamParity, stream_rx::{StreamReadIter, StreamRxError}, stream_tx::StreamTxRange, }; diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index fccbe046..868c706d 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -1,91 +1,26 @@ use std::time::Instant; use indexmap::IndexMap; -use ql_wire::{CloseTarget, RecordSeq, SessionClose, StreamClose, StreamId, XID}; +use ql_wire::{CloseTarget, RecordSeq, SessionClose, StreamClose, StreamId}; use super::{ received_records::ReceivedRecords, stream_rx::StreamRx, stream_tx::StreamTx, SessionState, }; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamParity { - Even, - Odd, -} - -impl StreamParity { - pub fn for_local(local: XID, peer: XID) -> Self { - match local.0.cmp(&peer.0) { - std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Even, - std::cmp::Ordering::Greater => Self::Odd, - } - } - - pub const fn first_stream_id(self) -> u32 { - match self { - Self::Even => 0, - Self::Odd => 1, - } - } - - pub const fn matches(self, stream_id: StreamId) -> bool { - match self { - Self::Even => stream_id.0 % 2 == 0, - Self::Odd => stream_id.0 % 2 == 1, - } - } - - pub const fn remote(self) -> Self { - match self { - Self::Even => Self::Odd, - Self::Odd => Self::Even, - } - } - - pub fn make_stream_id(self, ordinal: u32) -> StreamId { - StreamId( - self.first_stream_id() - .saturating_add(ordinal.saturating_mul(2)), - ) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamRole { - Initiator, - Responder, -} - -impl StreamRole { - pub fn outbound_target(self) -> CloseTarget { - match self { - Self::Initiator => CloseTarget::Request, - Self::Responder => CloseTarget::Response, - } - } - - pub fn inbound_target(self) -> CloseTarget { - match self { - Self::Initiator => CloseTarget::Response, - Self::Responder => CloseTarget::Request, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum OutboundState { - Open, - FinQueued, - Finished, - Closed, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum InboundState { - Open, - Finished, - Closed(StreamClose), - Discarding, +pub struct SessionFsmState { + pub now: Instant, + pub last_activity_at: Instant, + pub last_inbound_at: Instant, + pub session_state: SessionState, + pub next_stream_ordinal: u32, + pub next_record_seq: RecordSeq, + pub next_write_id: u64, + pub outbound_records: IndexMap, + pub received_records: ReceivedRecords, + pub ack_state: AckState, + pub pending_control: PendingSessionControl, + pub streams: IndexMap, + pub next_stream_index: usize, } #[derive(Debug)] @@ -143,6 +78,54 @@ impl StreamState { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamRole { + Initiator, + Responder, +} + +impl StreamRole { + pub fn outbound_target(self) -> CloseTarget { + match self { + Self::Initiator => CloseTarget::Request, + Self::Responder => CloseTarget::Response, + } + } + + pub fn inbound_target(self) -> CloseTarget { + match self { + Self::Initiator => CloseTarget::Response, + Self::Responder => CloseTarget::Request, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum OutboundState { + Open, + FinQueued, + Finished, + Closed, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InboundState { + Open, + Finished, + Closed(StreamClose), + Discarding, +} + +#[derive(Debug, Clone)] +pub struct OutboundRecord { + pub seq: RecordSeq, + pub reliable: Vec, + pub ack_included: bool, + pub ping_included: bool, + pub window_updates: Vec<(StreamId, u64)>, + pub sent_at: Option, +} + #[derive(Debug, Clone)] pub enum ReliableFrame { StreamData(StreamDataManifest), @@ -158,16 +141,6 @@ pub struct StreamDataManifest { pub fin: bool, } -#[derive(Debug, Clone)] -pub struct OutboundRecord { - pub seq: RecordSeq, - pub reliable: Vec, - pub ack_included: bool, - pub ping_included: bool, - pub window_updates: Vec<(StreamId, u64)>, - pub sent_at: Option, -} - #[derive(Debug, Clone, Default)] pub struct PendingSessionControl { pub ping: bool, @@ -180,19 +153,3 @@ pub enum AckState { Delayed { due_at: Instant }, Immediate, } - -pub struct SessionFsmState { - pub now: Instant, - pub last_activity_at: Instant, - pub last_inbound_at: Instant, - pub session_state: SessionState, - pub next_stream_ordinal: u32, - pub next_record_seq: RecordSeq, - pub next_write_id: u64, - pub outbound_records: IndexMap, - pub received_records: ReceivedRecords, - pub ack_state: AckState, - pub pending_control: PendingSessionControl, - pub streams: IndexMap, - pub next_stream_index: usize, -} diff --git a/ql-fsm/src/session/stream_parity.rs b/ql-fsm/src/session/stream_parity.rs new file mode 100644 index 00000000..8b95ad51 --- /dev/null +++ b/ql-fsm/src/session/stream_parity.rs @@ -0,0 +1,44 @@ +use ql_wire::{StreamId, XID}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamParity { + Even, + Odd, +} + +impl StreamParity { + pub fn for_local(local: XID, peer: XID) -> Self { + match local.0.cmp(&peer.0) { + std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Even, + std::cmp::Ordering::Greater => Self::Odd, + } + } + + pub const fn first_stream_id(self) -> u32 { + match self { + Self::Even => 0, + Self::Odd => 1, + } + } + + pub const fn matches(self, stream_id: StreamId) -> bool { + match self { + Self::Even => stream_id.0 % 2 == 0, + Self::Odd => stream_id.0 % 2 == 1, + } + } + + pub const fn remote(self) -> Self { + match self { + Self::Even => Self::Odd, + Self::Odd => Self::Even, + } + } + + pub fn make_stream_id(self, ordinal: u32) -> StreamId { + StreamId( + self.first_stream_id() + .saturating_add(ordinal.saturating_mul(2)), + ) + } +} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index d0077006..7abc4b9f 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -5,7 +5,8 @@ use ql_wire::{ StreamCloseCode, StreamData, StreamId, XID, }; -use super::{state::StreamParity, SessionEvent, SessionFsm, SessionFsmConfig}; +use super::{SessionEvent, SessionFsm, SessionFsmConfig}; +use crate::session::stream_parity::StreamParity; fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index f31dbc42..9125984f 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -16,7 +16,7 @@ use ql_wire::{ use sha2::{Digest, Sha256}; use crate::{ - session::{state::StreamParity, SessionFsm, SessionFsmConfig}, + session::{stream_parity::StreamParity, SessionFsm, SessionFsmConfig}, state::{LinkState, SessionTransport}, FsmTime, OutboundWrite, QlFsm, QlFsmConfig, SessionWriteId, }; diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 2ec9573b..998462ad 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ql_wire::{SessionClose, StreamId}; use super::*; -use crate::{session::state::StreamParity, state::LinkState, QlFsmEvent, QlSessionEvent}; +use crate::{state::LinkState, QlFsmEvent, QlSessionEvent}; fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); From 457405e21003dfc3375e6bd8911e9d0bec2fa76c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 09:38:24 -0400 Subject: [PATCH 075/304] ql-fsm: break out tracked module --- ql-fsm/src/session/mod.rs | 97 +++++++++++++++++------------------ ql-fsm/src/session/state.rs | 30 ++--------- ql-fsm/src/session/tracked.rs | 30 +++++++++++ 3 files changed, 81 insertions(+), 76 deletions(-) create mode 100644 ql-fsm/src/session/tracked.rs diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 9d0fec67..b281dd19 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod state; pub(crate) mod stream_parity; pub(crate) mod stream_rx; pub(crate) mod stream_tx; +pub(crate) mod tracked; #[cfg(test)] mod tests; @@ -18,13 +19,11 @@ use ql_wire::{ use self::{ received_records::{ReceiveInsertOutcome, ReceivedRecords}, - state::{ - AckState, InboundState, OutboundRecord, OutboundState, ReliableFrame, SessionFsmState, - StreamDataManifest, StreamRole, StreamState, - }, + state::{AckState, InboundState, OutboundState, SessionFsmState, StreamRole, StreamState}, stream_parity::StreamParity, stream_rx::{StreamReadIter, StreamRxError}, stream_tx::StreamTxRange, + tracked::{TrackedFrame, TrackedRecord, TrackedStreamData}, }; #[derive(Debug, Clone, Copy)] @@ -103,7 +102,7 @@ impl SessionFsm { next_stream_ordinal: 0, next_record_seq: RecordSeq(0), next_write_id: 0, - outbound_records: Default::default(), + tracked_records: Default::default(), received_records: ReceivedRecords::default(), ack_state: AckState::Idle, pending_control: Default::default(), @@ -307,7 +306,7 @@ impl SessionFsm { pub fn confirm_write(&mut self, now: Instant, write_id: u64) { self.state.now = now; - let Some(record) = self.state.outbound_records.get_mut(&write_id) else { + let Some(record) = self.state.tracked_records.get_mut(&write_id) else { return; }; if record.sent_at.is_some() { @@ -320,16 +319,16 @@ impl SessionFsm { pub fn reject_write(&mut self, write_id: u64) { if self .state - .outbound_records + .tracked_records .get(&write_id) .is_some_and(|record| record.sent_at.is_some()) { return; } - let Some(record) = self.state.outbound_records.shift_remove(&write_id) else { + let Some(record) = self.state.tracked_records.shift_remove(&write_id) else { return; }; - restore_outbound_record( + restore_tracked_record( self.state.now, self.config.ack_delay, &mut self.state.ack_state, @@ -374,7 +373,7 @@ impl SessionFsm { }; let retransmit_deadline = self .state - .outbound_records + .tracked_records .values() .filter_map(|record| { record @@ -417,16 +416,16 @@ impl SessionFsm { let write_id = self.state.next_write_id; self.state.next_write_id = self.state.next_write_id.wrapping_add(1); let seq = outbound.seq; - self.state.outbound_records.insert(write_id, outbound); + self.state.tracked_records.insert(write_id, outbound); Some((write_id, seq, builder)) } - fn build_next_record(&mut self) -> Option<(SessionRecordBuilder, OutboundRecord)> { + fn build_next_record(&mut self) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; let mut builder = SessionRecordBuilder::new(self.config.record_size); - let mut outbound = OutboundRecord { + let mut outbound = TrackedRecord { seq, - reliable: Vec::new(), + frames: Vec::new(), ack_included: false, ping_included: false, window_updates: Vec::new(), @@ -445,7 +444,7 @@ impl SessionFsm { if let Some(close) = self.state.pending_control.close.clone() { if builder.push_close(&close) { self.state.pending_control.close = None; - outbound.reliable.push(ReliableFrame::Close(close)); + outbound.frames.push(TrackedFrame::Close(close)); } } @@ -471,7 +470,7 @@ impl SessionFsm { fn push_next_pending_stream_close( &mut self, builder: &mut SessionRecordBuilder, - outbound: &mut OutboundRecord, + outbound: &mut TrackedRecord, ) -> bool { let len = self.state.streams.len(); if len == 0 { @@ -493,7 +492,7 @@ impl SessionFsm { let stream = self.state.streams.get_index_mut(index).unwrap().1; self.state.next_stream_index = (index + 1) % len; - outbound.reliable.push(ReliableFrame::StreamClose( + outbound.frames.push(TrackedFrame::StreamClose( stream.pending_close.take().unwrap(), )); return true; @@ -505,7 +504,7 @@ impl SessionFsm { fn push_next_pending_stream_window( &mut self, builder: &mut SessionRecordBuilder, - outbound: &mut OutboundRecord, + outbound: &mut TrackedRecord, ) -> bool { let len = self.state.streams.len(); if len == 0 { @@ -545,7 +544,7 @@ impl SessionFsm { fn push_next_stream_data( &mut self, builder: &mut SessionRecordBuilder, - outbound: &mut OutboundRecord, + outbound: &mut TrackedRecord, ) -> bool { let Some(max_payload) = self.max_stream_data_payload(builder) else { return false; @@ -587,8 +586,8 @@ impl SessionFsm { } self.state.next_stream_index = (index + 1) % len; outbound - .reliable - .push(ReliableFrame::StreamData(StreamDataManifest { + .frames + .push(TrackedFrame::StreamData(TrackedStreamData { stream_id, offset: candidate.offset, len: candidate.len, @@ -623,17 +622,17 @@ impl SessionFsm { fn process_record_ack(&mut self, ack: RecordAck, emit: &mut impl FnMut(SessionEvent)) { let stream_send_buffer_size = self.config.stream_send_buffer_size; { - let outbound_records = &mut self.state.outbound_records; + let tracked_records = &mut self.state.tracked_records; let streams = &mut self.state.streams; - for (_, record) in outbound_records.extract_if(.., |_, record| { + for (_, record) in tracked_records.extract_if(.., |_, record| { record.sent_at.is_some() && ack .ranges .iter() .any(|range| range.start <= record.seq.0 && record.seq.0 < range.end) }) { - for frame in &record.reliable { - acknowledge_reliable_frame(streams, stream_send_buffer_size, frame, emit); + for frame in &record.frames { + acknowledge_tracked_frame(streams, stream_send_buffer_size, frame, emit); } } } @@ -662,12 +661,12 @@ impl SessionFsm { fn collect_timeouts(&mut self) { let retransmit_timeout = self.config.retransmit_timeout; - for (_, record) in self.state.outbound_records.extract_if(.., |_, record| { + for (_, record) in self.state.tracked_records.extract_if(.., |_, record| { record .sent_at .is_some_and(|sent_at| sent_at + retransmit_timeout <= self.state.now) }) { - restore_outbound_record( + restore_tracked_record( self.state.now, self.config.ack_delay, &mut self.state.ack_state, @@ -835,7 +834,7 @@ impl SessionFsm { } self.state.session_state = SessionState::Closed; - self.state.outbound_records.clear(); + self.state.tracked_records.clear(); self.clear_streams(); self.state.pending_control = Default::default(); emit(SessionEvent::SessionClosed(close)); @@ -861,15 +860,15 @@ impl SessionFsm { } fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { - let outbound_refs_stream = self.state.outbound_records.values().any(|record| { + let tracked_refs_stream = self.state.tracked_records.values().any(|record| { record.window_updates.iter().any(|(id, _)| *id == stream_id) - || record.reliable.iter().any(|frame| match frame { - ReliableFrame::StreamData(frame) => frame.stream_id == stream_id, - ReliableFrame::StreamClose(frame) => frame.stream_id == stream_id, - ReliableFrame::Close(_) => false, + || record.frames.iter().any(|frame| match frame { + TrackedFrame::StreamData(frame) => frame.stream_id == stream_id, + TrackedFrame::StreamClose(frame) => frame.stream_id == stream_id, + TrackedFrame::Close(_) => false, }) }); - if outbound_refs_stream { + if tracked_refs_stream { return false; } @@ -935,7 +934,7 @@ impl SessionFsm { } self.state.session_state = SessionState::Closed; - self.state.outbound_records.clear(); + self.state.tracked_records.clear(); self.state.pending_control = Default::default(); self.state.pending_control.close = Some(close.clone()); self.clear_streams(); @@ -959,13 +958,13 @@ fn schedule_ack(ack_state: &mut AckState, now: Instant, ack_delay: Duration, imm }; } -fn restore_outbound_record( +fn restore_tracked_record( now: Instant, ack_delay: Duration, ack_state: &mut AckState, pending_control: &mut state::PendingSessionControl, streams: &mut IndexMap, - record: OutboundRecord, + record: TrackedRecord, ) { if record.ack_included { schedule_ack(ack_state, now, ack_delay, true); @@ -980,22 +979,22 @@ fn restore_outbound_record( } } } - for frame in record.reliable { - requeue_reliable_frame(pending_control, streams, frame); + for frame in record.frames { + requeue_tracked_frame(pending_control, streams, frame); } } -fn requeue_reliable_frame( +fn requeue_tracked_frame( pending_control: &mut state::PendingSessionControl, streams: &mut IndexMap, - frame: ReliableFrame, + frame: TrackedFrame, ) { match frame { - ReliableFrame::Close(close) => { + TrackedFrame::Close(close) => { pending_control.close = Some(close); } - ReliableFrame::StreamClose(close) => restore_stream_close(streams, close), - ReliableFrame::StreamData(frame) => restore_stream_data(streams, frame), + TrackedFrame::StreamClose(close) => restore_stream_close(streams, close), + TrackedFrame::StreamData(frame) => restore_stream_data(streams, frame), } } @@ -1005,7 +1004,7 @@ fn restore_stream_close(streams: &mut IndexMap, close: St } } -fn restore_stream_data(streams: &mut IndexMap, frame: StreamDataManifest) { +fn restore_stream_data(streams: &mut IndexMap, frame: TrackedStreamData) { if let Some(stream) = streams.get_mut(&frame.stream_id) { if matches!(stream.outbound_state, OutboundState::Closed) { return; @@ -1021,15 +1020,15 @@ fn restore_stream_data(streams: &mut IndexMap, frame: Str } } -fn acknowledge_reliable_frame( +fn acknowledge_tracked_frame( streams: &mut IndexMap, stream_send_buffer_size: usize, - frame: &ReliableFrame, + frame: &TrackedFrame, emit: &mut impl FnMut(SessionEvent), ) { match frame { - ReliableFrame::Close(_) | ReliableFrame::StreamClose(_) => {} - ReliableFrame::StreamData(frame) => { + TrackedFrame::Close(_) | TrackedFrame::StreamClose(_) => {} + TrackedFrame::StreamData(frame) => { let stream_id = frame.stream_id; if let Some(stream) = streams.get_mut(&stream_id) { let was_full = stream.send_capacity(stream_send_buffer_size) == 0; diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 868c706d..0a373782 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -4,7 +4,8 @@ use indexmap::IndexMap; use ql_wire::{CloseTarget, RecordSeq, SessionClose, StreamClose, StreamId}; use super::{ - received_records::ReceivedRecords, stream_rx::StreamRx, stream_tx::StreamTx, SessionState, + received_records::ReceivedRecords, stream_rx::StreamRx, stream_tx::StreamTx, + tracked::TrackedRecord, SessionState, }; pub struct SessionFsmState { @@ -15,7 +16,7 @@ pub struct SessionFsmState { pub next_stream_ordinal: u32, pub next_record_seq: RecordSeq, pub next_write_id: u64, - pub outbound_records: IndexMap, + pub tracked_records: IndexMap, pub received_records: ReceivedRecords, pub ack_state: AckState, pub pending_control: PendingSessionControl, @@ -116,31 +117,6 @@ pub enum InboundState { Discarding, } -#[derive(Debug, Clone)] -pub struct OutboundRecord { - pub seq: RecordSeq, - pub reliable: Vec, - pub ack_included: bool, - pub ping_included: bool, - pub window_updates: Vec<(StreamId, u64)>, - pub sent_at: Option, -} - -#[derive(Debug, Clone)] -pub enum ReliableFrame { - StreamData(StreamDataManifest), - StreamClose(StreamClose), - Close(SessionClose), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct StreamDataManifest { - pub stream_id: StreamId, - pub offset: u64, - pub len: usize, - pub fin: bool, -} - #[derive(Debug, Clone, Default)] pub struct PendingSessionControl { pub ping: bool, diff --git a/ql-fsm/src/session/tracked.rs b/ql-fsm/src/session/tracked.rs new file mode 100644 index 00000000..1c7bd798 --- /dev/null +++ b/ql-fsm/src/session/tracked.rs @@ -0,0 +1,30 @@ +//! outbound record tracking state for ack and retransmit handling + +use std::time::Instant; + +use ql_wire::{RecordSeq, SessionClose, StreamClose, StreamId}; + +#[derive(Debug, Clone)] +pub struct TrackedRecord { + pub seq: RecordSeq, + pub frames: Vec, + pub ack_included: bool, + pub ping_included: bool, + pub window_updates: Vec<(StreamId, u64)>, + pub sent_at: Option, +} + +#[derive(Debug, Clone)] +pub enum TrackedFrame { + StreamData(TrackedStreamData), + StreamClose(StreamClose), + Close(SessionClose), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TrackedStreamData { + pub stream_id: StreamId, + pub offset: u64, + pub len: usize, + pub fin: bool, +} From 50594d31dd00fb0736cd4d417ff47b1d4e7b0311 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 09:58:10 -0400 Subject: [PATCH 076/304] ql-fsm: refactor implementation --- ql-fsm/src/implementation/{fsm.rs => core.rs} | 117 ++++++++++++++++-- ql-fsm/src/implementation/handshake/ik.rs | 3 +- ql-fsm/src/implementation/handshake/kk.rs | 3 +- ql-fsm/src/implementation/handshake/mod.rs | 16 ++- ql-fsm/src/implementation/mod.rs | 117 +----------------- 5 files changed, 127 insertions(+), 129 deletions(-) rename ql-fsm/src/implementation/{fsm.rs => core.rs} (57%) diff --git a/ql-fsm/src/implementation/fsm.rs b/ql-fsm/src/implementation/core.rs similarity index 57% rename from ql-fsm/src/implementation/fsm.rs rename to ql-fsm/src/implementation/core.rs index d5ede6b6..07290870 100644 --- a/ql-fsm/src/implementation/fsm.rs +++ b/ql-fsm/src/implementation/core.rs @@ -1,11 +1,27 @@ -use std::time::Instant; +use std::{ + collections::VecDeque, + time::{Duration, Instant}, +}; use ql_wire::{ self as wire, CloseTarget, QlCrypto, SessionClose, SessionCloseCode, SessionHeader, StreamCloseCode, StreamId, }; -use crate::{OutboundWrite, QlFsm, QlFsmError, QlSessionEvent, SessionWriteId, StreamReadIter}; +use crate::{ + session::{stream_parity::StreamParity, SessionEvent, SessionFsmConfig}, + state::LinkState, + OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId, StreamReadIter, +}; + +pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { + fsm.state.handshake = None; + fsm.state.link = LinkState::Idle; + fsm.state.peer = Some(peer.clone()); + reset_session(fsm); + fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); + emit_peer_status(fsm); +} pub fn receive( fsm: &mut QlFsm, @@ -28,11 +44,11 @@ pub fn receive( .receive(fsm.state.now.instant, record.header.seq, frames, { let session_events = &mut fsm.state.session_events; |event| { - session_closed |= super::forward_session_event(session_events, event); + session_closed |= forward_session_event(session_events, event); } }); if session_closed { - super::apply_session_closed(fsm); + apply_session_closed(fsm); } Ok(()) } @@ -46,11 +62,11 @@ pub fn on_timer(fsm: &mut QlFsm) { fsm.session.on_timer(fsm.state.now.instant, { let session_events = &mut fsm.state.session_events; |event| { - session_closed |= super::forward_session_event(session_events, event); + session_closed |= forward_session_event(session_events, event); } }); if session_closed { - super::apply_session_closed(fsm); + apply_session_closed(fsm); } } } @@ -111,8 +127,8 @@ pub fn kill_session(fsm: &mut QlFsm, code: SessionCloseCode) { } fsm.state.link = crate::state::LinkState::Idle; - super::emit_peer_status(fsm); - super::reset_session(fsm); + emit_peer_status(fsm); + reset_session(fsm); fsm.state .session_events .push_back(QlSessionEvent::SessionClosed(SessionClose { code })); @@ -168,6 +184,81 @@ pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { Ok(fsm.session.queue_ping()?) } +pub fn emit_peer_status(fsm: &mut QlFsm) { + if let Some(peer) = fsm.state.peer.as_ref() { + fsm.state.events.push_back(QlFsmEvent::PeerStatusChanged { + peer: peer.xid, + status: fsm.state.link.status(), + }); + } +} + +pub fn reset_session(fsm: &mut QlFsm) { + let local_parity = fsm + .state + .peer + .as_ref() + .map(|peer| StreamParity::for_local(fsm.identity.xid, peer.xid)) + .unwrap_or(StreamParity::Even); + fsm.session = crate::session::SessionFsm::new( + SessionFsmConfig { + local_parity, + record_size: fsm.config.session_record_size, + ack_delay: fsm.config.session_record_ack_delay, + retransmit_timeout: fsm.config.session_record_retransmit_timeout, + keepalive_interval: fsm.config.session_keepalive_interval, + peer_timeout: fsm.config.session_peer_timeout, + stream_send_buffer_size: fsm.config.session_stream_send_buffer_size, + stream_receive_buffer_size: fsm.config.session_stream_receive_buffer_size, + }, + fsm.state.now.instant, + ); +} + +fn forward_session_event( + session_events: &mut VecDeque, + event: SessionEvent, +) -> bool { + match event { + SessionEvent::Opened(stream_id) => { + session_events.push_back(QlSessionEvent::Opened(stream_id)); + false + } + SessionEvent::Readable(stream_id) => { + session_events.push_back(QlSessionEvent::Readable(stream_id)); + false + } + SessionEvent::Writable(stream_id) => { + session_events.push_back(QlSessionEvent::Writable(stream_id)); + false + } + SessionEvent::Finished(stream_id) => { + session_events.push_back(QlSessionEvent::Finished(stream_id)); + false + } + SessionEvent::Closed(frame) => { + session_events.push_back(QlSessionEvent::Closed(frame)); + false + } + SessionEvent::WritableClosed(stream_id) => { + session_events.push_back(QlSessionEvent::WritableClosed(stream_id)); + false + } + SessionEvent::SessionClosed(close) => { + session_events.push_back(QlSessionEvent::SessionClosed(close)); + true + } + } +} + +fn apply_session_closed(fsm: &mut QlFsm) { + if matches!(fsm.state.link, crate::state::LinkState::Connected(_)) { + fsm.state.link = crate::state::LinkState::Idle; + emit_peer_status(fsm); + } + reset_session(fsm); +} + fn ensure_session_open(fsm: &QlFsm) -> Result<(), QlFsmError> { fsm.state.ensure_peer_bound()?; if fsm.state.link.transport().is_none() { @@ -175,3 +266,13 @@ fn ensure_session_open(fsm: &QlFsm) -> Result<(), QlFsmError> { } Ok(()) } + +pub(super) fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { + now_secs.saturating_add(duration_to_secs(duration)) +} + +fn duration_to_secs(duration: Duration) -> u64 { + duration + .as_secs() + .saturating_add(u64::from(duration.subsec_nanos() > 0)) +} diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index ffcdd569..e30088c9 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -1,11 +1,10 @@ use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ - enqueue_handshake, finish_handshake, is_replayed_handshake_start, + emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, reset_connected_session_if_needed, }; use crate::{ - implementation::emit_peer_status, state::{IkInitiatorState, LinkState, SessionTransport}, QlFsm, QlFsmError, }; diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 63b03b6b..7c29943a 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -1,11 +1,10 @@ use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ - enqueue_handshake, finish_handshake, is_replayed_handshake_start, + emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, reset_connected_session_if_needed, }; use crate::{ - implementation::emit_peer_status, state::{KkInitiatorState, LinkState, SessionTransport}, QlFsm, QlFsmError, }; diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 8942aa56..0b2d47c9 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -3,10 +3,10 @@ mod kk; use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHandshakeRecord}; -use super::{emit_peer_status, fail_pending_connect_session, reset_session}; +use super::{emit_peer_status, reset_session}; use crate::{ state::{LinkState, SessionTransport}, - QlFsm, QlFsmError, QlFsmEvent, + QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, }; pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { @@ -109,6 +109,18 @@ pub fn reset_connected_session_if_needed(fsm: &mut QlFsm) { } } +fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::SessionCloseCode) { + if !fsm.session.has_pending_stream_work() { + return; + } + reset_session(fsm); + fsm.state + .session_events + .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { + code, + })); +} + fn local_start_wins(local: &EphemeralPublicKey, inbound: &EphemeralPublicKey) -> bool { local.mlkem_public_key.as_bytes() <= inbound.mlkem_public_key.as_bytes() } diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs index d24f21c3..64b0b3d3 100644 --- a/ql-fsm/src/implementation/mod.rs +++ b/ql-fsm/src/implementation/mod.rs @@ -1,119 +1,6 @@ -mod fsm; +mod core; mod handshake; -use std::{collections::VecDeque, time::Duration}; +pub use core::*; -pub use fsm::*; pub use handshake::*; - -use crate::{ - session::{stream_parity::StreamParity, SessionEvent, SessionFsmConfig}, - state::LinkState, - QlFsm, QlFsmEvent, QlSessionEvent, -}; - -fn emit_peer_status(fsm: &mut QlFsm) { - if let Some(peer) = fsm.state.peer.as_ref() { - fsm.state.events.push_back(QlFsmEvent::PeerStatusChanged { - peer: peer.xid, - status: fsm.state.link.status(), - }); - } -} - -fn reset_session(fsm: &mut QlFsm) { - let local_parity = fsm - .state - .peer - .as_ref() - .map(|peer| StreamParity::for_local(fsm.identity.xid, peer.xid)) - .unwrap_or(StreamParity::Even); - fsm.session = crate::session::SessionFsm::new( - SessionFsmConfig { - local_parity, - record_size: fsm.config.session_record_size, - ack_delay: fsm.config.session_record_ack_delay, - retransmit_timeout: fsm.config.session_record_retransmit_timeout, - keepalive_interval: fsm.config.session_keepalive_interval, - peer_timeout: fsm.config.session_peer_timeout, - stream_send_buffer_size: fsm.config.session_stream_send_buffer_size, - stream_receive_buffer_size: fsm.config.session_stream_receive_buffer_size, - }, - fsm.state.now.instant, - ); -} - -pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { - fsm.state.handshake = None; - fsm.state.link = LinkState::Idle; - fsm.state.peer = Some(peer.clone()); - reset_session(fsm); - fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); - emit_peer_status(fsm); -} - -fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::SessionCloseCode) { - if !fsm.session.has_pending_stream_work() { - return; - } - reset_session(fsm); - fsm.state - .session_events - .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { - code, - })); -} - -fn forward_session_event( - session_events: &mut VecDeque, - event: SessionEvent, -) -> bool { - match event { - SessionEvent::Opened(stream_id) => { - session_events.push_back(QlSessionEvent::Opened(stream_id)); - false - } - SessionEvent::Readable(stream_id) => { - session_events.push_back(QlSessionEvent::Readable(stream_id)); - false - } - SessionEvent::Writable(stream_id) => { - session_events.push_back(QlSessionEvent::Writable(stream_id)); - false - } - SessionEvent::Finished(stream_id) => { - session_events.push_back(QlSessionEvent::Finished(stream_id)); - false - } - SessionEvent::Closed(frame) => { - session_events.push_back(QlSessionEvent::Closed(frame)); - false - } - SessionEvent::WritableClosed(stream_id) => { - session_events.push_back(QlSessionEvent::WritableClosed(stream_id)); - false - } - SessionEvent::SessionClosed(close) => { - session_events.push_back(QlSessionEvent::SessionClosed(close)); - true - } - } -} - -fn apply_session_closed(fsm: &mut QlFsm) { - if matches!(fsm.state.link, crate::state::LinkState::Connected(_)) { - fsm.state.link = crate::state::LinkState::Idle; - emit_peer_status(fsm); - } - reset_session(fsm); -} - -fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { - now_secs.saturating_add(duration_to_secs(duration)) -} - -fn duration_to_secs(duration: Duration) -> u64 { - duration - .as_secs() - .saturating_add(u64::from(duration.subsec_nanos() > 0)) -} From 9a808d9ad0801b86af4dfdf0f3ce68261303a03b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 10:02:28 -0400 Subject: [PATCH 077/304] ql-fsm: use option where applicable --- ql-fsm/src/implementation/core.rs | 8 ++++---- ql-fsm/src/lib.rs | 4 ++-- ql-fsm/src/session/mod.rs | 20 ++++++-------------- 3 files changed, 12 insertions(+), 20 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 07290870..39a48fd3 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -148,8 +148,8 @@ pub fn write_stream( Ok(fsm.session.write_stream(stream_id, bytes)?) } -pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Result, QlFsmError> { - Ok(fsm.session.stream_read(stream_id)?) +pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Option> { + fsm.session.stream_read(stream_id) } pub fn stream_read_commit( @@ -160,8 +160,8 @@ pub fn stream_read_commit( Ok(fsm.session.stream_read_commit(stream_id, len)?) } -pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Result { - Ok(fsm.session.stream_available_bytes(stream_id)?) +pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Option { + fsm.session.stream_available_bytes(stream_id) } pub fn finish_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result<(), QlFsmError> { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index ea842817..62a80445 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -278,7 +278,7 @@ impl QlFsm { } /// returns the readable stream bytes as borrowed chunks without consuming them - pub fn stream_read(&self, stream_id: StreamId) -> Result, QlFsmError> { + pub fn stream_read(&self, stream_id: StreamId) -> Option> { implementation::stream_read(self, stream_id) } @@ -292,7 +292,7 @@ impl QlFsm { } /// returns how many bytes can be read from a stream - pub fn stream_available_bytes(&self, stream_id: StreamId) -> Result { + pub fn stream_available_bytes(&self, stream_id: StreamId) -> Option { implementation::stream_available_bytes(self, stream_id) } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index b281dd19..abe2874b 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -190,13 +190,9 @@ impl SessionFsm { Ok(()) } - pub fn stream_read(&self, stream_id: StreamId) -> Result, StreamError> { - let stream = self - .state - .streams - .get(&stream_id) - .ok_or(StreamError::MissingStream)?; - Ok(stream.rx.bytes()) + pub fn stream_read(&self, stream_id: StreamId) -> Option> { + let stream = self.state.streams.get(&stream_id)?; + Some(stream.rx.bytes()) } pub fn stream_read_commit( @@ -220,13 +216,9 @@ impl SessionFsm { Ok(()) } - pub fn stream_available_bytes(&self, stream_id: StreamId) -> Result { - let stream = self - .state - .streams - .get(&stream_id) - .ok_or(StreamError::MissingStream)?; - Ok(stream.readable_bytes()) + pub fn stream_available_bytes(&self, stream_id: StreamId) -> Option { + let stream = self.state.streams.get(&stream_id)?; + Some(stream.readable_bytes()) } pub fn queue_ping(&mut self) -> Result<(), StreamError> { From 0816b8ce4e718cca79dcc2339047ea14f3abee67 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 10:06:40 -0400 Subject: [PATCH 078/304] ql-fsm: remove un-needed check for already open streams --- ql-fsm/src/implementation/core.rs | 3 --- ql-fsm/src/state.rs | 6 ++---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 39a48fd3..a82e811e 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -144,7 +144,6 @@ pub fn write_stream( stream_id: StreamId, bytes: &[u8], ) -> Result { - fsm.state.ensure_peer_bound()?; Ok(fsm.session.write_stream(stream_id, bytes)?) } @@ -165,7 +164,6 @@ pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Option } pub fn finish_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result<(), QlFsmError> { - fsm.state.ensure_peer_bound()?; Ok(fsm.session.finish_stream(stream_id)?) } @@ -175,7 +173,6 @@ pub fn close_stream( target: CloseTarget, code: StreamCloseCode, ) -> Result<(), QlFsmError> { - fsm.state.ensure_peer_bound()?; Ok(fsm.session.close_stream(stream_id, target, code)?) } diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 80b2c826..57655aae 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -95,9 +95,7 @@ impl LinkState { impl QlFsmState { pub fn ensure_peer_bound(&self) -> Result<(), crate::QlFsmError> { - self.peer - .as_ref() - .map(|_| ()) - .ok_or(crate::QlFsmError::NoPeerBound) + self.peer.as_ref().ok_or(crate::QlFsmError::NoPeerBound)?; + Ok(()) } } From d57858e77a8e15f64ae7acb3245db1a154574f48 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 11:28:47 -0400 Subject: [PATCH 079/304] ql-wire: better encoding --- ql-wire/src/bytes.rs | 6 -- ql-wire/src/codec.rs | 44 ++++---- ql-wire/src/encrypted/ack.rs | 12 ++- ql-wire/src/encrypted/builder.rs | 134 ++++++++++++++----------- ql-wire/src/encrypted/close.rs | 11 +- ql-wire/src/encrypted/mod.rs | 50 ++------- ql-wire/src/encrypted/stream_close.rs | 10 +- ql-wire/src/encrypted/stream_data.rs | 15 ++- ql-wire/src/encrypted/stream_window.rs | 10 +- ql-wire/src/encrypted_message.rs | 10 +- ql-wire/src/handshake/ik.rs | 22 ++-- ql-wire/src/handshake/kk.rs | 20 ++-- ql-wire/src/handshake/meta.rs | 14 +-- ql-wire/src/handshake/mod.rs | 19 ++-- ql-wire/src/header.rs | 34 +++++-- ql-wire/src/identity.rs | 20 ++-- ql-wire/src/record.rs | 32 ++++-- 17 files changed, 229 insertions(+), 234 deletions(-) diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index ca839476..38cf1266 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -141,12 +141,6 @@ impl ByteSlice for &mut [u8] { } } -#[derive(Debug, Clone, Copy)] -pub struct CappedByteChunks { - pub inner: T, - pub limit: usize, -} - #[derive(Debug, Clone, Copy)] pub struct RangedByteChunks { pub inner: T, diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index 45fb9a41..143f93b6 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -1,23 +1,37 @@ use crate::{ByteSlice, WireError}; -pub fn push_u8(out: &mut Vec, value: u8) { - out.push(value); +pub fn write_u8(out: &mut [u8], value: u8) -> &mut [u8] { + let (head, rest) = out.split_at_mut(1); + head[0] = value; + rest } -pub fn push_u16(out: &mut Vec, value: u16) { - out.extend_from_slice(&value.to_le_bytes()); +pub fn write_u16(out: &mut [u8], value: u16) -> &mut [u8] { + let (head, rest) = out.split_at_mut(size_of::()); + head.copy_from_slice(&value.to_le_bytes()); + rest } -pub fn push_u32(out: &mut Vec, value: u32) { - out.extend_from_slice(&value.to_le_bytes()); +pub fn write_u32(out: &mut [u8], value: u32) -> &mut [u8] { + let (head, rest) = out.split_at_mut(size_of::()); + head.copy_from_slice(&value.to_le_bytes()); + rest } -pub fn push_u64(out: &mut Vec, value: u64) { - out.extend_from_slice(&value.to_le_bytes()); +pub fn write_u64(out: &mut [u8], value: u64) -> &mut [u8] { + let (head, rest) = out.split_at_mut(size_of::()); + head.copy_from_slice(&value.to_le_bytes()); + rest } -pub fn push_bytes(out: &mut Vec, bytes: &[u8]) { - out.extend_from_slice(bytes); +pub fn write_bool(out: &mut [u8], value: bool) -> &mut [u8] { + write_u8(out, u8::from(value)) +} + +pub fn write_bytes<'a>(out: &'a mut [u8], bytes: &[u8]) -> &'a mut [u8] { + let (head, rest) = out.split_at_mut(bytes.len()); + head.copy_from_slice(bytes); + rest } pub struct Reader { @@ -96,13 +110,3 @@ impl Reader { } } } - -pub fn append_field(out: &mut Vec, label: &[u8], value: &[u8]) { - append_framed_bytes(out, label); - append_framed_bytes(out, value); -} - -pub fn append_framed_bytes(out: &mut Vec, value: &[u8]) { - out.extend_from_slice(&u64::try_from(value.len()).unwrap().to_le_bytes()); - out.extend_from_slice(value); -} diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 46d068fd..b9649128 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -1,5 +1,3 @@ -use std::mem::size_of; - use crate::{codec, WireError}; #[derive(Debug, Clone, PartialEq, Eq)] @@ -49,10 +47,14 @@ impl RecordAck { self.ranges.len() * Self::RANGE_ENCODED_LEN } - pub fn encode_into(&self, out: &mut Vec) { + pub fn encode_into(&self, out: &mut [u8]) { + assert_eq!(out.len(), self.encoded_len()); + let mut out = out; for range in &self.ranges { - codec::push_u64(out, range.start); - codec::push_u64(out, range.end); + let (encoded, rest) = out.split_at_mut(Self::RANGE_ENCODED_LEN); + let encoded = codec::write_u64(encoded, range.start); + let _ = codec::write_u64(encoded, range.end); + out = rest; } } } diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 54e9a1f8..08ea8c07 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,23 +1,27 @@ -use super::{ - push_variable_len, RecordAck, SessionClose, SessionFrame, SessionFrameKind, StreamClose, - StreamData, StreamWindow, SIZE_LEN, -}; -use crate::{ - encrypted_message::EncryptedMessage, ByteChunks, Nonce, QlCrypto, QlSessionRecord, - SessionHeader, SessionKey, -}; +use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; +use crate::{ByteChunks, Nonce, QlCrypto, RecordType, SessionHeader, SessionKey, QL_WIRE_VERSION}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecordBuilder { max_capacity: usize, + body_start: usize, bytes: Vec, } impl SessionRecordBuilder { - pub fn new(max_capacity: usize) -> Self { - let bytes = Vec::with_capacity(max_capacity); + pub const WIRE_PREFIX_LEN: usize = + 1 + 1 + SessionHeader::ENCODED_LEN + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn new(max_capacity: usize, initial_capacity: usize) -> Self { + assert!(initial_capacity <= max_capacity); + assert!(max_capacity >= Self::WIRE_PREFIX_LEN); + + let body_start = Self::WIRE_PREFIX_LEN; + let mut bytes = Vec::with_capacity(initial_capacity); + bytes.resize(body_start, 0); Self { max_capacity, + body_start, bytes, } } @@ -27,23 +31,26 @@ impl SessionRecordBuilder { } pub fn len(&self) -> usize { - self.bytes.len() + self.bytes.len().saturating_sub(self.body_start) } pub fn is_empty(&self) -> bool { - self.bytes.is_empty() + self.len() == 0 } pub fn remaining_capacity(&self) -> usize { - self.max_capacity.saturating_sub(self.bytes.len()) + self.max_capacity + .saturating_sub(self.body_start) + .saturating_sub(self.len()) } pub fn bytes(&self) -> &[u8] { - &self.bytes + &self.bytes[self.body_start..] } pub fn into_plaintext(self) -> Vec { - self.bytes + let mut bytes = self.bytes; + bytes.split_off(self.body_start) } pub fn can_push_len(&self, len: usize) -> bool { @@ -51,59 +58,50 @@ impl SessionRecordBuilder { } pub fn push_ping(&mut self) -> bool { - if !self.can_push_len(1) { - return false; - } - self.bytes.push(SessionFrameKind::Ping as u8); - true + self.push_encoded_len(1, |out| out[0] = super::SessionFrameKind::Ping as u8) } pub fn push_ack(&mut self, ack: &RecordAck) -> bool { - if !self.can_push_len(1 + SIZE_LEN + ack.encoded_len()) { - return false; - } - self.bytes.push(SessionFrameKind::Ack as u8); - push_variable_len(&mut self.bytes, ack.encoded_len()); - ack.encode_into(&mut self.bytes); - true + let len = 1 + super::SIZE_LEN + ack.encoded_len(); + self.push_encoded_len(len, |out| { + out[0] = super::SessionFrameKind::Ack as u8; + super::push_variable_len(&mut out[1..1 + super::SIZE_LEN], ack.encoded_len()); + ack.encode_into(&mut out[1 + super::SIZE_LEN..]); + }) } pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { - if !self.can_push_len(1 + SIZE_LEN + frame.encoded_len()) { - return false; - } - self.bytes.push(SessionFrameKind::StreamData as u8); - push_variable_len(&mut self.bytes, frame.encoded_len()); - frame.encode_into(&mut self.bytes); - true + let len = 1 + super::SIZE_LEN + frame.encoded_len(); + self.push_encoded_len(len, |out| { + out[0] = super::SessionFrameKind::StreamData as u8; + super::push_variable_len(&mut out[1..1 + super::SIZE_LEN], frame.encoded_len()); + frame.encode_into(&mut out[1 + super::SIZE_LEN..]); + }) } pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { - if !self.can_push_len(1 + StreamWindow::WIRE_SIZE) { - return false; - } - self.bytes.push(SessionFrameKind::StreamWindow as u8); - frame.encode_into(&mut self.bytes); - true + let len = 1 + StreamWindow::WIRE_SIZE; + self.push_encoded_len(len, |out| { + out[0] = super::SessionFrameKind::StreamWindow as u8; + frame.encode_into(&mut out[1..]); + }) } pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { - if !self.can_push_len(1 + SIZE_LEN + frame.encoded_len()) { - return false; - } - self.bytes.push(SessionFrameKind::StreamClose as u8); - push_variable_len(&mut self.bytes, frame.encoded_len()); - frame.encode_into(&mut self.bytes); - true + let len = 1 + super::SIZE_LEN + frame.encoded_len(); + self.push_encoded_len(len, |out| { + out[0] = super::SessionFrameKind::StreamClose as u8; + super::push_variable_len(&mut out[1..1 + super::SIZE_LEN], frame.encoded_len()); + frame.encode_into(&mut out[1 + super::SIZE_LEN..]); + }) } pub fn push_close(&mut self, close: &SessionClose) -> bool { - if !self.can_push_len(1 + SessionClose::WIRE_SIZE) { - return false; - } - self.bytes.push(SessionFrameKind::Close as u8); - close.encode_into(&mut self.bytes); - true + let len = 1 + SessionClose::WIRE_SIZE; + self.push_encoded_len(len, |out| { + out[0] = super::SessionFrameKind::Close as u8; + close.encode_into(&mut out[1..]); + }) } pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { @@ -118,17 +116,35 @@ impl SessionRecordBuilder { } pub fn encrypt( - self, + mut self, crypto: &impl QlCrypto, header: SessionHeader, session_key: &SessionKey, - ) -> QlSessionRecord> { + ) -> Vec { let aad = header.aad(); let nonce = Nonce::from_counter(header.seq.0); - let encrypted = EncryptedMessage::encrypt(crypto, session_key, self.bytes, &nonce, &aad); - QlSessionRecord { - header, - payload: encrypted, + let auth = crypto.aes256_gcm_encrypt( + session_key, + &nonce, + &aad, + &mut self.bytes[self.body_start..], + ); + + let prefix = &mut self.bytes[..self.body_start]; + prefix[0] = QL_WIRE_VERSION; + prefix[1] = RecordType::Session as u8; + header.encode_into(&mut prefix[2..2 + SessionHeader::ENCODED_LEN]); + prefix[2 + SessionHeader::ENCODED_LEN..].copy_from_slice(&auth); + self.bytes + } + + fn push_encoded_len(&mut self, len: usize, encode: impl FnOnce(&mut [u8])) -> bool { + if !self.can_push_len(len) { + return false; } + let start = self.bytes.len(); + self.bytes.resize(start + len, 0); + encode(&mut self.bytes[start..]); + true } } diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index 31782a01..af9e234d 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -1,9 +1,4 @@ -use std::mem::size_of; - -use crate::{ - codec::{self, Reader}, - WireError, -}; +use crate::{codec, codec::Reader, WireError}; /// closes the whole session immediately with a close code. #[derive(Debug, Clone, PartialEq, Eq)] @@ -14,8 +9,8 @@ pub struct SessionClose { impl SessionClose { pub const WIRE_SIZE: usize = size_of::(); - pub fn encode_into(&self, out: &mut Vec) { - codec::push_u16(out, self.code.0); + pub fn encode_into(&self, out: &mut [u8]) { + let _ = codec::write_u16(out, self.code.0); } pub fn decode(bytes: &[u8]) -> Result { diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 9a1b7692..11a520ef 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,5 +1,3 @@ -use std::mem::size_of; - use crate::{ codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, QlSessionRecord, SessionHeader, SessionKey, WireError, @@ -97,15 +95,6 @@ impl SessionRecord { .map(SessionFrame::encoded_len) .sum::() } - - pub fn encode(&self) -> Vec { - let mut out = SessionRecordBuilder::new(self.encoded_len()); - for frame in &self.frames { - let pushed = out.push_frame(frame); - debug_assert!(pushed); - } - out.into_plaintext() - } } impl SessionFrame { @@ -119,35 +108,6 @@ impl SessionFrame { Self::Close(_) => SessionClose::WIRE_SIZE, } } - - pub fn encode_into(&self, out: &mut Vec) { - match self { - Self::Ping => out.push(SessionFrameKind::Ping as u8), - Self::Ack(frame) => { - out.push(SessionFrameKind::Ack as u8); - push_variable_len(out, frame.encoded_len()); - frame.encode_into(out); - } - Self::StreamData(frame) => { - out.push(SessionFrameKind::StreamData as u8); - push_variable_len(out, frame.encoded_len()); - frame.encode_into(out); - } - Self::StreamWindow(frame) => { - out.push(SessionFrameKind::StreamWindow as u8); - frame.encode_into(out); - } - Self::StreamClose(frame) => { - out.push(SessionFrameKind::StreamClose as u8); - push_variable_len(out, frame.encoded_len()); - frame.encode_into(out); - } - Self::Close(body) => { - out.push(SessionFrameKind::Close as u8); - body.encode_into(out); - } - } - } } impl SessionFrame { @@ -191,12 +151,14 @@ pub fn encrypt_record( session_key: &SessionKey, body: &SessionRecord, ) -> QlSessionRecord> { - let mut builder = SessionRecordBuilder::new(body.encoded_len()); + let encoded_len = body.encoded_len() + SessionRecordBuilder::WIRE_PREFIX_LEN; + let mut builder = SessionRecordBuilder::new(encoded_len, encoded_len); for frame in &body.frames { let pushed = builder.push_frame(frame); debug_assert!(pushed); } - builder.encrypt(crypto, header, session_key) + QlSessionRecord::decode(&builder.encrypt(crypto, header, session_key)) + .expect("builder emitted an invalid session record") } pub fn decrypt_record>( @@ -246,9 +208,9 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr } } -fn push_variable_len(out: &mut Vec, len: usize) { +fn push_variable_len(out: &mut [u8], len: usize) { let len = u16::try_from(len).expect("session frame exceeds u16"); - codec::push_u16(out, len); + let _ = codec::write_u16(out, len); } fn split_variable_frame(bytes: &[u8]) -> Result<(&[u8], &[u8]), WireError> { diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index f492e259..1589a14d 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -1,5 +1,3 @@ -use std::mem::size_of; - use super::StreamId; use crate::{codec, ByteSlice, WireError}; @@ -30,10 +28,10 @@ impl StreamClose { Self::WIRE_SIZE } - pub fn encode_into(&self, out: &mut Vec) { - codec::push_u32(out, self.stream_id.0); - codec::push_u8(out, self.target.to_wire()); - codec::push_u16(out, self.code.0); + pub fn encode_into(&self, out: &mut [u8]) { + let out = codec::write_u32(out, self.stream_id.0); + let out = codec::write_u8(out, self.target.to_wire()); + let _ = codec::write_u16(out, self.code.0); } } diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index e7253cc3..2e7deb52 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -1,5 +1,3 @@ -use std::mem::size_of; - use super::StreamId; use crate::{codec, ByteChunks, ByteSlice, WireError}; @@ -13,7 +11,7 @@ pub struct StreamData { } impl StreamData { - pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); + pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); } impl StreamData { @@ -47,12 +45,13 @@ impl StreamData { Self::MIN_WIRE_SIZE + self.bytes.len() } - pub fn encode_into(&self, out: &mut Vec) { - codec::push_u32(out, self.stream_id.0); - codec::push_u64(out, self.offset); - codec::push_u8(out, u8::from(self.fin)); + pub fn encode_into(&self, out: &mut [u8]) { + assert_eq!(out.len(), self.encoded_len()); + let out = codec::write_u32(out, self.stream_id.0); + let out = codec::write_u64(out, self.offset); + let mut out = codec::write_bool(out, self.fin); for chunk in self.bytes.chunks() { - codec::push_bytes(out, chunk); + out = codec::write_bytes(out, chunk); } } } diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index d03f0d02..1f3388c0 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -1,5 +1,3 @@ -use std::mem::size_of; - use super::StreamId; use crate::{codec, WireError}; @@ -11,11 +9,11 @@ pub struct StreamWindow { } impl StreamWindow { - pub const WIRE_SIZE: usize = size_of::() + size_of::(); + pub const WIRE_SIZE: usize = size_of::() + size_of::(); - pub fn encode_into(&self, out: &mut Vec) { - codec::push_u32(out, self.stream_id.0); - codec::push_u64(out, self.maximum_offset); + pub fn encode_into(&self, out: &mut [u8]) { + let out = codec::write_u32(out, self.stream_id.0); + let _ = codec::write_u64(out, self.maximum_offset); } pub fn decode(bytes: &[u8]) -> Result { diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index 062c4ecd..886b50b3 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -34,14 +34,14 @@ impl EncryptedMessage { } impl> EncryptedMessage { - pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, &self.auth); - codec::push_bytes(out, self.ciphertext.as_ref()); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + let out = codec::write_bytes(out, &self.auth); + codec::write_bytes(out, self.ciphertext.as_ref()) } pub fn encode(&self) -> Vec { - let mut out = Vec::with_capacity(Self::HEADER_LEN + self.ciphertext.as_ref().len()); - self.encode_into(&mut out); + let mut out = vec![0; Self::HEADER_LEN + self.ciphertext.as_ref().len()]; + let _ = self.encode_into(&mut out); out } diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 7be407a2..d5b209b2 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -26,12 +26,12 @@ impl Ik1 { + EphemeralPublicKey::ENCODED_LEN + EncryptedPeerBundle::ENCODED_LEN; - pub fn encode_into(&self, out: &mut Vec) { - self.header.encode_into(out); - self.meta.encode_into(out); - codec::push_bytes(out, self.skem_ciphertext.as_bytes()); - self.ephemeral.encode_into(out); - codec::push_bytes(out, self.static_bundle.as_bytes()); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + let out = self.header.encode_into(out); + let out = self.meta.encode_into(out); + let out = codec::write_bytes(out, self.skem_ciphertext.as_bytes()); + let out = self.ephemeral.encode_into(out); + codec::write_bytes(out, self.static_bundle.as_bytes()) } pub fn decode(bytes: &[u8]) -> Result { @@ -67,11 +67,11 @@ impl Ik2 { + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::ENCODED_LEN; - pub fn encode_into(&self, out: &mut Vec) { - self.header.encode_into(out); - self.meta.encode_into(out); - codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); - codec::push_bytes(out, self.skem_ciphertext.as_bytes()); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + let out = self.header.encode_into(out); + let out = self.meta.encode_into(out); + let out = codec::write_bytes(out, self.ekem_ciphertext.as_bytes()); + codec::write_bytes(out, self.skem_ciphertext.as_bytes()) } pub fn decode(bytes: &[u8]) -> Result { diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 5f4ff45e..244bde96 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -23,11 +23,11 @@ impl Kk1 { + MlKemCiphertext::SIZE + EphemeralPublicKey::ENCODED_LEN; - pub fn encode_into(&self, out: &mut Vec) { - self.header.encode_into(out); - self.meta.encode_into(out); - codec::push_bytes(out, self.skem_ciphertext.as_bytes()); - self.ephemeral.encode_into(out); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + let out = self.header.encode_into(out); + let out = self.meta.encode_into(out); + let out = codec::write_bytes(out, self.skem_ciphertext.as_bytes()); + self.ephemeral.encode_into(out) } pub fn decode(bytes: &[u8]) -> Result { @@ -61,11 +61,11 @@ impl Kk2 { + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::ENCODED_LEN; - pub fn encode_into(&self, out: &mut Vec) { - self.header.encode_into(out); - self.meta.encode_into(out); - codec::push_bytes(out, self.ekem_ciphertext.as_bytes()); - codec::push_bytes(out, self.skem_ciphertext.as_bytes()); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + let out = self.header.encode_into(out); + let out = self.meta.encode_into(out); + let out = codec::write_bytes(out, self.ekem_ciphertext.as_bytes()); + codec::write_bytes(out, self.skem_ciphertext.as_bytes()) } pub fn decode(bytes: &[u8]) -> Result { diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs index 747d7e18..bf780750 100644 --- a/ql-wire/src/handshake/meta.rs +++ b/ql-wire/src/handshake/meta.rs @@ -11,7 +11,7 @@ pub struct HandshakeMeta { } impl HandshakeMeta { - pub const ENCODED_LEN: usize = core::mem::size_of::() + core::mem::size_of::(); + pub const ENCODED_LEN: usize = size_of::() + size_of::(); pub fn ensure_not_expired(&self, now_seconds: u64) -> Result<(), WireError> { if now_seconds > self.valid_until { @@ -21,14 +21,14 @@ impl HandshakeMeta { } } - pub fn encode_into(&self, out: &mut Vec) { - codec::push_u32(out, self.handshake_id.0); - codec::push_u64(out, self.valid_until); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + let out = codec::write_u32(out, self.handshake_id.0); + codec::write_u64(out, self.valid_until) } - pub fn encode(&self) -> Vec { - let mut out = Vec::with_capacity(Self::ENCODED_LEN); - self.encode_into(&mut out); + pub fn encode(&self) -> [u8; Self::ENCODED_LEN] { + let mut out = [0; Self::ENCODED_LEN]; + let _ = self.encode_into(&mut out); out } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index edbbc3a7..fa050fdf 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -26,9 +26,15 @@ pub struct HandshakeHeader { impl HandshakeHeader { pub const ENCODED_LEN: usize = XID::SIZE * 2; - pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, &self.sender.0); - codec::push_bytes(out, &self.recipient.0); + pub fn encode(&self) -> [u8; Self::ENCODED_LEN] { + let mut out = [0; Self::ENCODED_LEN]; + let _ = self.encode_into(&mut out); + out + } + + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + let out = codec::write_bytes(out, &self.sender.0); + codec::write_bytes(out, &self.recipient.0) } pub fn decode(bytes: &[u8]) -> Result { @@ -56,8 +62,8 @@ pub struct EphemeralPublicKey { impl EphemeralPublicKey { pub const ENCODED_LEN: usize = MlKemPublicKey::SIZE; - pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, self.mlkem_public_key.as_bytes()); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + codec::write_bytes(out, self.mlkem_public_key.as_bytes()) } pub fn decode(bytes: &[u8]) -> Result { @@ -309,8 +315,7 @@ fn mix_hash_routed_handshake( kind: HandshakeKind, meta: &HandshakeMeta, ) { - let mut encoded_header = Vec::with_capacity(HandshakeHeader::ENCODED_LEN); - header.encode_into(&mut encoded_header); + let encoded_header = header.encode(); let encoded = meta.encode(); symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); symmetric.mix_hash(crypto, &encoded_header); diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index 7d194880..ea289eea 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -27,11 +27,20 @@ impl ConnectionId { } impl SessionHeader { - pub const ENCODED_LEN: usize = ConnectionId::SIZE + core::mem::size_of::(); + pub const ENCODED_LEN: usize = ConnectionId::SIZE + size_of::(); + const AAD_DOMAIN: &[u8] = b"ql-wire:session-aad:v1"; + const AAD_RECORD_KIND_SESSION: u8 = 1; - pub fn encode_into(&self, out: &mut Vec) { - codec::push_bytes(out, self.connection_id.as_bytes()); - codec::push_u64(out, self.seq.0); + pub fn encode(&self) -> [u8; Self::ENCODED_LEN] { + let mut out = [0; Self::ENCODED_LEN]; + self.encode_into(&mut out); + out + } + + pub fn encode_into(&self, out: &mut [u8]) { + assert_eq!(out.len(), Self::ENCODED_LEN); + let out = codec::write_bytes(out, self.connection_id.as_bytes()); + let _ = codec::write_u64(out, self.seq.0); } pub fn decode(bytes: &[u8]) -> Result { @@ -51,12 +60,17 @@ impl SessionHeader { } pub fn aad(&self) -> Vec { - let mut aad = Vec::new(); - codec::append_field(&mut aad, b"domain", b"ql-wire:session-aad:v1"); - codec::append_field(&mut aad, b"wire-version", &[QL_WIRE_VERSION]); - codec::append_field(&mut aad, b"record-kind", b"session"); - codec::append_field(&mut aad, b"connection-id", self.connection_id.as_bytes()); - codec::append_field(&mut aad, b"record-seq", &self.seq.0.to_le_bytes()); + let aad_len = Self::AAD_DOMAIN.len() + + size_of::() + + size_of::() + + ConnectionId::SIZE + + size_of::(); + let mut aad = vec![0; aad_len]; + let out = codec::write_bytes(&mut aad, Self::AAD_DOMAIN); + let out = codec::write_u8(out, QL_WIRE_VERSION); + let out = codec::write_u8(out, Self::AAD_RECORD_KIND_SESSION); + let out = codec::write_bytes(out, self.connection_id.as_bytes()); + let _ = codec::write_u64(out, self.seq.0); aad } } diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 1f5e2510..2dfda418 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -10,21 +10,19 @@ pub struct PeerBundle { impl PeerBundle { pub const VERSION: u16 = 1; - pub const ENCODED_LEN: usize = core::mem::size_of::() - + XID::SIZE - + core::mem::size_of::() - + MlKemPublicKey::SIZE; + pub const ENCODED_LEN: usize = + size_of::() + XID::SIZE + size_of::() + MlKemPublicKey::SIZE; - pub fn encode_into(&self, out: &mut Vec) { - codec::push_u16(out, self.version); - codec::push_bytes(out, &self.xid.0); - codec::push_u32(out, self.capabilities); - codec::push_bytes(out, self.mlkem_public_key.as_bytes()); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + let out = codec::write_u16(out, self.version); + let out = codec::write_bytes(out, &self.xid.0); + let out = codec::write_u32(out, self.capabilities); + codec::write_bytes(out, self.mlkem_public_key.as_bytes()) } pub fn encode(&self) -> Vec { - let mut out = Vec::with_capacity(Self::ENCODED_LEN); - self.encode_into(&mut out); + let mut out = vec![0; Self::ENCODED_LEN]; + let _ = self.encode_into(&mut out); out } diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 77c4ceb5..a01ffd72 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -77,7 +77,16 @@ impl QlHandshakeRecord { } } - fn encode_into(&self, out: &mut Vec) { + fn encoded_len(&self) -> usize { + match self { + Self::Ik1(_) => Ik1::ENCODED_LEN, + Self::Ik2(_) => Ik2::ENCODED_LEN, + Self::Kk1(_) => Kk1::ENCODED_LEN, + Self::Kk2(_) => Kk2::ENCODED_LEN, + } + } + + fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { match self { Self::Ik1(message) => message.encode_into(out), Self::Ik2(message) => message.encode_into(out), @@ -96,11 +105,11 @@ impl QlHandshakeRecord { } pub fn encode(&self) -> Vec { - let mut out = Vec::new(); - codec::push_u8(&mut out, QL_WIRE_VERSION); - codec::push_u8(&mut out, RecordType::Handshake as u8); - codec::push_u8(&mut out, self.kind() as u8); - self.encode_into(&mut out); + let mut out = vec![0; 3 + self.encoded_len()]; + let rest = codec::write_u8(&mut out, QL_WIRE_VERSION); + let rest = codec::write_u8(rest, RecordType::Handshake as u8); + let rest = codec::write_u8(rest, self.kind() as u8); + let _ = self.encode_into(rest); out } @@ -122,11 +131,12 @@ impl QlHandshakeRecord { impl> QlSessionRecord { pub fn encode(&self) -> Vec { - let mut out = Vec::new(); - codec::push_u8(&mut out, QL_WIRE_VERSION); - codec::push_u8(&mut out, RecordType::Session as u8); - self.header.encode_into(&mut out); - self.payload.encode_into(&mut out); + let mut out = + vec![0; 2 + SessionHeader::ENCODED_LEN + EncryptedMessage::<&[u8]>::HEADER_LEN + self.payload.ciphertext.as_ref().len()]; + let rest = codec::write_u8(&mut out, QL_WIRE_VERSION); + let rest = codec::write_u8(rest, RecordType::Session as u8); + let rest = codec::write_bytes(rest, &self.header.encode()); + let _ = self.payload.encode_into(rest); out } } From b2cae3bbc859247f7f5dafe7b4e5b54e4b1e6d10 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 12:07:30 -0400 Subject: [PATCH 080/304] ql-fsm: use new wire encoding functions --- ql-fsm/src/implementation/core.rs | 7 ++++--- ql-fsm/src/lib.rs | 33 ++++++++++++++++++------------- ql-fsm/src/session/mod.rs | 23 ++++++++++++++++----- ql-fsm/src/session/tests.rs | 16 +++++++++++---- ql-fsm/src/tests/handshake.rs | 4 +++- ql-fsm/src/tests/mod.rs | 28 +++++++++++--------------- 6 files changed, 68 insertions(+), 43 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index a82e811e..21285654 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -87,7 +87,7 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { if let Some(record) = fsm.state.handshake.take() { return Some(OutboundWrite { - record: wire::QlRecord::Handshake(record), + record: record.encode(), session_write_id: None, }); } @@ -105,7 +105,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option>, + /// wire bytes to hand to the transport + pub record: Vec, /// write handle that must be confirmed or rejected pub session_write_id: Option, } @@ -123,8 +123,10 @@ pub struct QlFsmConfig { pub session_keepalive_interval: Duration, /// how long to wait before declaring the peer dead pub session_peer_timeout: Duration, - /// target plaintext size for one session record - pub session_record_size: usize, + /// target total wire size for one session record, including header and auth tag + pub session_record_target_size: usize, + /// maximum total wire size for one session record, including header and auth tag + pub session_record_max_size: usize, /// maximum bytes buffered locally for one stream send side pub session_stream_send_buffer_size: usize, /// maximum bytes buffered locally for one stream receive side @@ -133,15 +135,17 @@ pub struct QlFsmConfig { impl Default for QlFsmConfig { fn default() -> Self { + let s = session::SessionFsmConfig::default(); Self { handshake_timeout: Duration::from_secs(5), - session_record_ack_delay: Duration::from_millis(5), - session_record_retransmit_timeout: Duration::from_millis(150), - session_keepalive_interval: Duration::from_secs(10), - session_peer_timeout: Duration::from_secs(30), - session_record_size: 16 * 1024, - session_stream_send_buffer_size: 64 * 1024, - session_stream_receive_buffer_size: 64 * 1024, + session_record_ack_delay: s.ack_delay, + session_record_retransmit_timeout: s.retransmit_timeout, + session_keepalive_interval: s.keepalive_interval, + session_peer_timeout: s.peer_timeout, + session_record_target_size: s.record_target_size, + session_record_max_size: s.record_max_size, + session_stream_send_buffer_size: s.stream_send_buffer_size, + session_stream_receive_buffer_size: s.stream_receive_buffer_size, } } } @@ -165,7 +169,8 @@ impl QlFsm { session: session::SessionFsm::new( session::SessionFsmConfig { local_parity: session::stream_parity::StreamParity::Even, - record_size: config.session_record_size, + record_target_size: config.session_record_target_size, + record_max_size: config.session_record_max_size, ack_delay: config.session_record_ack_delay, retransmit_timeout: config.session_record_retransmit_timeout, keepalive_interval: config.session_keepalive_interval, diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index abe2874b..618a668a 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -29,7 +29,8 @@ use self::{ #[derive(Debug, Clone, Copy)] pub struct SessionFsmConfig { pub local_parity: StreamParity, - pub record_size: usize, + pub record_target_size: usize, + pub record_max_size: usize, pub ack_delay: Duration, pub retransmit_timeout: Duration, pub keepalive_interval: Duration, @@ -42,7 +43,8 @@ impl Default for SessionFsmConfig { fn default() -> Self { Self { local_parity: StreamParity::Even, - record_size: 16 * 1024, + record_target_size: 4 * 1024, + record_max_size: 16 * 1024, ack_delay: Duration::from_millis(5), retransmit_timeout: Duration::from_millis(150), keepalive_interval: Duration::from_secs(10), @@ -89,7 +91,13 @@ pub struct SessionFsm { impl SessionFsm { pub fn new(mut config: SessionFsmConfig, now: Instant) -> Self { - config.record_size = config.record_size.max(64); + config.record_target_size = config + .record_target_size + .max(SessionRecordBuilder::WIRE_PREFIX_LEN); + config.record_max_size = config + .record_max_size + .max(SessionRecordBuilder::WIRE_PREFIX_LEN); + config.record_target_size = config.record_target_size.min(config.record_max_size); config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); Self { @@ -414,7 +422,8 @@ impl SessionFsm { fn build_next_record(&mut self) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; - let mut builder = SessionRecordBuilder::new(self.config.record_size); + let mut builder = + SessionRecordBuilder::new(self.config.record_max_size, self.config.record_target_size); let mut outbound = TrackedRecord { seq, frames: Vec::new(), @@ -597,7 +606,11 @@ impl SessionFsm { if remaining > overhead { Some(remaining - overhead) } else if builder.is_empty() { - Some(self.config.record_size) + Some( + self.config + .record_max_size + .saturating_sub(SessionRecordBuilder::WIRE_PREFIX_LEN), + ) } else { None } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 7abc4b9f..48c7e19e 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -1,8 +1,8 @@ use std::time::{Duration, Instant}; use ql_wire::{ - CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionFrame, SessionRecord, StreamClose, - StreamCloseCode, StreamData, StreamId, XID, + CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionFrame, SessionRecord, + SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamId, XID, }; use super::{SessionEvent, SessionFsm, SessionFsmConfig}; @@ -36,7 +36,14 @@ fn receive_events( seq: RecordSeq, record: SessionRecord, ) -> Vec { - let bytes = record.encode(); + let mut builder = SessionRecordBuilder::new( + SessionRecordBuilder::WIRE_PREFIX_LEN + record.encoded_len(), + SessionRecordBuilder::WIRE_PREFIX_LEN + record.encoded_len(), + ); + for frame in &record.frames { + assert!(builder.push_frame(frame)); + } + let bytes = builder.into_plaintext(); let frames = SessionRecord::parse(&bytes).unwrap(); let mut events = Vec::new(); fsm.receive(now, seq, frames, |event| events.push(event)); @@ -80,7 +87,8 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let now = Instant::now(); let mut fsm = SessionFsm::new( SessionFsmConfig { - record_size: 80, + record_target_size: 80 + SessionRecordBuilder::WIRE_PREFIX_LEN, + record_max_size: 80 + SessionRecordBuilder::WIRE_PREFIX_LEN, ..SessionFsmConfig::default() }, now, diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 11857149..6653ca1d 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -157,6 +157,7 @@ fn handshake_timeout_drops_single_ik_attempt_without_resend() { .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); let first = harness.next_outbound_a().unwrap(); + let first = QlRecord::decode(&first).unwrap(); assert!(matches!( first, QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik1(_)) @@ -245,7 +246,8 @@ fn simultaneous_ik_and_kk_connect_prefers_ik() { assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } -fn handshake_id(record: &QlRecord>) -> ql_wire::HandshakeId { +fn handshake_id(record: &[u8]) -> ql_wire::HandshakeId { + let record = QlRecord::decode(record).unwrap(); match record { QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik1(message)) => message.meta.handshake_id, QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik2(message)) => message.meta.handshake_id, diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 9125984f..12abb202 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -10,8 +10,8 @@ use libcrux_aesgcm::AesGcm256Key; use libcrux_ml_kem::mlkem1024; use ql_wire::{ self, generate_identity, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, - MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, QlRecord, - SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, SessionKey, + ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; use sha2::{Digest, Sha256}; @@ -226,7 +226,7 @@ impl Harness { self.unix_secs = self.unix_secs.saturating_add(duration.as_secs()); } - fn next_outbound_a(&mut self) -> Option>> { + fn next_outbound_a(&mut self) -> Option> { let write = self.a.fsm.take_next_write(self.time(), &self.a.crypto)?; if let Some(id) = write.session_write_id { self.a.fsm.confirm_session_write(self.time(), id); @@ -234,7 +234,7 @@ impl Harness { Some(write.record) } - fn next_outbound_b(&mut self) -> Option>> { + fn next_outbound_b(&mut self) -> Option> { let write = self.b.fsm.take_next_write(self.time(), &self.b.crypto)?; if let Some(id) = write.session_write_id { self.b.fsm.confirm_session_write(self.time(), id); @@ -246,18 +246,12 @@ impl Harness { self.a.fsm.take_next_write(self.time(), &self.a.crypto) } - fn deliver_to_a(&mut self, record: QlRecord>) { - self.a - .fsm - .receive(self.time(), record.encode(), &self.a.crypto) - .unwrap(); + fn deliver_to_a(&mut self, record: Vec) { + self.a.fsm.receive(self.time(), record, &self.a.crypto).unwrap(); } - fn deliver_to_b(&mut self, record: QlRecord>) { - self.b - .fsm - .receive(self.time(), record.encode(), &self.b.crypto) - .unwrap(); + fn deliver_to_b(&mut self, record: Vec) { + self.b.fsm.receive(self.time(), record, &self.b.crypto).unwrap(); } fn confirm_write_a(&mut self, write_id: SessionWriteId) { @@ -313,7 +307,8 @@ fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { SessionFsmConfig { local_parity: StreamParity::for_local(local, peer), - record_size: config.session_record_size, + record_target_size: config.session_record_target_size, + record_max_size: config.session_record_max_size, ack_delay: config.session_record_ack_delay, retransmit_timeout: config.session_record_retransmit_timeout, keepalive_interval: config.session_keepalive_interval, @@ -325,9 +320,10 @@ fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { fn decrypt_record( crypto: &impl QlCrypto, - record: &QlRecord>, + record: &[u8], session_key: &SessionKey, ) -> (ql_wire::SessionHeader, ql_wire::SessionRecord) { + let record = ql_wire::QlRecord::decode(record).unwrap(); let ql_wire::QlRecord::Session(record) = record else { panic!("expected encrypted session record"); }; From 350e3124ee9ab4ef9a7aea98d6c2fc829dbdf8fb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 17:41:20 -0400 Subject: [PATCH 081/304] ql-wire & ql-fsm: fix clippy --- ql-fsm/src/implementation/core.rs | 11 ++--- ql-fsm/src/implementation/handshake/ik.rs | 7 ++-- ql-fsm/src/implementation/handshake/kk.rs | 4 +- ql-fsm/src/implementation/handshake/mod.rs | 4 +- ql-fsm/src/lib.rs | 2 +- ql-fsm/src/session/mod.rs | 49 +++++++++++----------- ql-fsm/src/session/stream_rx.rs | 2 +- ql-fsm/src/session/stream_tx.rs | 2 +- ql-fsm/src/session/tests.rs | 12 +++--- ql-fsm/src/tests/mod.rs | 10 ++++- ql-wire/src/bytes.rs | 10 +++-- ql-wire/src/encrypted/builder.rs | 6 +-- ql-wire/src/handshake/ik.rs | 8 ++-- ql-wire/src/handshake/kk.rs | 8 ++-- ql-wire/src/handshake/mod.rs | 4 +- ql-wire/src/identity.rs | 1 + ql-wire/src/record.rs | 14 ++++--- ql-wire/src/tests.rs | 3 +- 18 files changed, 83 insertions(+), 74 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 21285654..b3ffbda9 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -38,7 +38,7 @@ pub fn receive( let plaintext = wire::decrypt_record(crypto, &record.header, record.payload, &transport.rx_key)?; - let frames = wire::SessionRecord::parse(plaintext.as_ref())?; + let frames = wire::SessionRecord::parse(plaintext)?; let mut session_closed = false; fsm.session .receive(fsm.state.now.instant, record.header.seq, frames, { @@ -191,12 +191,9 @@ pub fn emit_peer_status(fsm: &mut QlFsm) { } pub fn reset_session(fsm: &mut QlFsm) { - let local_parity = fsm - .state - .peer - .as_ref() - .map(|peer| StreamParity::for_local(fsm.identity.xid, peer.xid)) - .unwrap_or(StreamParity::Even); + let local_parity = fsm.state.peer.as_ref().map_or(StreamParity::Even, |peer| { + StreamParity::for_local(fsm.identity.xid, peer.xid) + }); fsm.session = crate::session::SessionFsm::new( SessionFsmConfig { local_parity, diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index e30088c9..d3e3ce35 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -56,7 +56,7 @@ pub fn handle_ik1( handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle)?; + finish_handshake(fsm, transport, &remote_bundle)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Ik2(outbound)); Ok(()) @@ -86,18 +86,17 @@ pub fn handle_ik2( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle) + finish_handshake(fsm, transport, &remote_bundle) } pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { match &fsm.state.link { - LinkState::Idle | LinkState::Connected(_) => false, + LinkState::Idle | LinkState::Connected(_) | LinkState::KkInitiator(_) => false, LinkState::IkInitiator(state) => { if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { return false; } super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) } - LinkState::KkInitiator(_) => false, } } diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 7c29943a..454f43e6 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -54,7 +54,7 @@ pub fn handle_kk1( handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle)?; + finish_handshake(fsm, transport, &remote_bundle)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Kk2(outbound)); Ok(()) @@ -84,7 +84,7 @@ pub fn handle_kk2( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle) + finish_handshake(fsm, transport, &remote_bundle) } pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 0b2d47c9..49202654 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -83,10 +83,10 @@ pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { pub fn finish_handshake( fsm: &mut QlFsm, transport: SessionTransport, - remote_bundle: wire::PeerBundle, + remote_bundle: &wire::PeerBundle, ) -> Result<(), QlFsmError> { if let Some(peer) = fsm.state.peer.as_ref() { - if peer != &remote_bundle { + if peer != remote_bundle { return Err(QlFsmError::InvalidPayload); } } else { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 49edf0bd..19ed58f3 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -102,7 +102,7 @@ pub enum QlSessionEvent { pub struct SessionWriteId(pub(crate) u64); /// outbound record produced by `QlFsm` -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct OutboundWrite { /// wire bytes to hand to the transport pub record: Vec, diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 618a668a..783d386a 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -257,17 +257,14 @@ impl SessionFsm { let closed = self.state.session_state == SessionState::Closed; let mut ack_eliciting = false; for frame in frames { - let frame = match frame { - Ok(frame) => frame, - Err(_) => { - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - &mut emit, - ); - return; - } + let Ok(frame) = frame else { + self.fail_session( + SessionClose { + code: SessionCloseCode::PROTOCOL, + }, + &mut emit, + ); + return; }; ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); if duplicate || closed { @@ -276,19 +273,19 @@ impl SessionFsm { match frame { SessionFrame::Ping => {} - SessionFrame::Ack(ack) => self.process_record_ack(ack, &mut emit), + SessionFrame::Ack(ack) => self.process_record_ack(&ack, &mut emit), SessionFrame::StreamData(frame) => { - if self.handle_stream_data(frame, &mut emit).is_err() { + if self.handle_stream_data(&frame, &mut emit).is_err() { return; } } SessionFrame::StreamWindow(frame) => { - if self.handle_stream_window(frame, &mut emit).is_err() { + if self.handle_stream_window(&frame, &mut emit).is_err() { return; } } SessionFrame::StreamClose(frame) => { - if self.handle_stream_close(frame, &mut emit).is_err() { + if self.handle_stream_close(&frame, &mut emit).is_err() { return; } } @@ -624,7 +621,7 @@ impl SessionFsm { } } - fn process_record_ack(&mut self, ack: RecordAck, emit: &mut impl FnMut(SessionEvent)) { + fn process_record_ack(&mut self, ack: &RecordAck, emit: &mut impl FnMut(SessionEvent)) { let stream_send_buffer_size = self.config.stream_send_buffer_size; { let tracked_records = &mut self.state.tracked_records; @@ -684,7 +681,7 @@ impl SessionFsm { fn handle_stream_data( &mut self, - frame: StreamData<&[u8]>, + frame: &StreamData<&[u8]>, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { let stream_id = frame.stream_id; @@ -739,12 +736,14 @@ impl SessionFsm { self.try_reap_stream(stream_id); Ok(()) } - Err(StreamRxError::OutOfWindow) - | Err(StreamRxError::InconsistentFinalOffset) - | Err(StreamRxError::FinalOffsetBeforeBufferedData) - | Err(StreamRxError::BeyondFinalOffset) - | Err(StreamRxError::TooManyMissingRanges) - | Err(StreamRxError::OffsetOverflow) => { + Err( + StreamRxError::OutOfWindow + | StreamRxError::InconsistentFinalOffset + | StreamRxError::FinalOffsetBeforeBufferedData + | StreamRxError::BeyondFinalOffset + | StreamRxError::TooManyMissingRanges + | StreamRxError::OffsetOverflow, + ) => { self.fail_session( SessionClose { code: SessionCloseCode::PROTOCOL, @@ -758,7 +757,7 @@ impl SessionFsm { fn handle_stream_window( &mut self, - frame: StreamWindow, + frame: &StreamWindow, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { @@ -783,7 +782,7 @@ impl SessionFsm { fn handle_stream_close( &mut self, - frame: StreamClose, + frame: &StreamClose, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { let created = match self.state.streams.entry(frame.stream_id) { diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index 31b18ec7..112e08a1 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -32,7 +32,7 @@ pub enum StreamRxError { TooManyMissingRanges, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub struct StreamReadIter<'a> { front: Option<&'a [u8]>, back: Option<&'a [u8]>, diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 5146522c..a5a5e417 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -111,7 +111,7 @@ impl StreamTx { } let credit_remaining = peer_max_offset.saturating_sub(segment.offset); - let credit_remaining = credit_remaining.min(usize::MAX as u64) as usize; + let credit_remaining = usize::try_from(credit_remaining).unwrap_or(usize::MAX); let len = segment.len.min(max_payload).min(credit_remaining); if len == 0 { continue; diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 48c7e19e..3d86a08f 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -34,7 +34,7 @@ fn receive_events( fsm: &mut SessionFsm, now: Instant, seq: RecordSeq, - record: SessionRecord, + record: &SessionRecord, ) -> Vec { let mut builder = SessionRecordBuilder::new( SessionRecordBuilder::WIRE_PREFIX_LEN + record.encoded_len(), @@ -174,7 +174,7 @@ fn commit_stream_read_is_what_advances_stream_window() { bytes: b"hi".to_vec(), })], }; - let events = receive_events(&mut fsm, now, RecordSeq(7), data); + let events = receive_events(&mut fsm, now, RecordSeq(7), &data); assert_eq!( events, vec![ @@ -189,7 +189,7 @@ fn commit_stream_read_is_what_advances_stream_window() { let read = fsm .stream_read(stream_id) .unwrap() - .map(|chunk| chunk.len()) + .map(<[u8]>::len) .sum::(); assert_eq!(read, 2); @@ -217,7 +217,7 @@ fn inbound_stream_data_emits_opened_and_readable() { })], }; - let events = receive_events(&mut fsm, now, RecordSeq(0), record); + let events = receive_events(&mut fsm, now, RecordSeq(0), &record); assert_eq!( events, vec![ @@ -294,12 +294,12 @@ fn duplicate_stream_data_is_not_redelivered() { bytes: b"hi".to_vec(), })], }; - let _ = receive_events(&mut fsm, now, RecordSeq(1), record.clone()); + let _ = receive_events(&mut fsm, now, RecordSeq(1), &record); let _ = receive_events( &mut fsm, now + Duration::from_millis(1), RecordSeq(2), - record, + &record, ); assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 12abb202..5b68c3a0 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -247,11 +247,17 @@ impl Harness { } fn deliver_to_a(&mut self, record: Vec) { - self.a.fsm.receive(self.time(), record, &self.a.crypto).unwrap(); + self.a + .fsm + .receive(self.time(), record, &self.a.crypto) + .unwrap(); } fn deliver_to_b(&mut self, record: Vec) { - self.b.fsm.receive(self.time(), record, &self.b.crypto).unwrap(); + self.b + .fsm + .receive(self.time(), record, &self.b.crypto) + .unwrap(); } fn confirm_write_a(&mut self, write_id: SessionWriteId) { diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index 38cf1266..1a1294f0 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -24,6 +24,10 @@ pub trait ByteChunks { fn len(&self) -> usize; fn chunks(&self) -> Self::Chunks<'_>; + + fn is_empty(&self) -> bool { + self.len() == 0 + } } impl ByteSliceMut for B where B: ByteSlice + DerefMut {} @@ -95,7 +99,7 @@ impl ByteChunks for Vec { Self: 'a; fn len(&self) -> usize { - Vec::len(self) + Self::len(self) } fn chunks(&self) -> Self::Chunks<'_> { @@ -110,7 +114,7 @@ impl ByteChunks for VecDeque { Self: 'a; fn len(&self) -> usize { - VecDeque::len(self) + Self::len(self) } fn chunks(&self) -> Self::Chunks<'_> { @@ -256,7 +260,7 @@ mod tests { let chunks = ByteChunks::chunks(&bytes).collect::>(); assert_eq!(bytes.len(), 6); assert_eq!(chunks.concat(), b"cdefgh"); - assert!(chunks.len() >= 1); + assert!(!chunks.is_empty()); } #[test] diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 08ea8c07..8dad710d 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -65,7 +65,7 @@ impl SessionRecordBuilder { let len = 1 + super::SIZE_LEN + ack.encoded_len(); self.push_encoded_len(len, |out| { out[0] = super::SessionFrameKind::Ack as u8; - super::push_variable_len(&mut out[1..1 + super::SIZE_LEN], ack.encoded_len()); + super::push_variable_len(&mut out[1..=super::SIZE_LEN], ack.encoded_len()); ack.encode_into(&mut out[1 + super::SIZE_LEN..]); }) } @@ -74,7 +74,7 @@ impl SessionRecordBuilder { let len = 1 + super::SIZE_LEN + frame.encoded_len(); self.push_encoded_len(len, |out| { out[0] = super::SessionFrameKind::StreamData as u8; - super::push_variable_len(&mut out[1..1 + super::SIZE_LEN], frame.encoded_len()); + super::push_variable_len(&mut out[1..=super::SIZE_LEN], frame.encoded_len()); frame.encode_into(&mut out[1 + super::SIZE_LEN..]); }) } @@ -91,7 +91,7 @@ impl SessionRecordBuilder { let len = 1 + super::SIZE_LEN + frame.encoded_len(); self.push_encoded_len(len, |out| { out[0] = super::SessionFrameKind::StreamClose as u8; - super::push_variable_len(&mut out[1..1 + super::SIZE_LEN], frame.encoded_len()); + super::push_variable_len(&mut out[1..=super::SIZE_LEN], frame.encoded_len()); frame.encode_into(&mut out[1 + super::SIZE_LEN..]); }) } diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index d5b209b2..63e86c1b 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -40,7 +40,7 @@ impl Ik1 { let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let ephemeral = - EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; + EphemeralPublicKey::decode(reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); reader.finish()?; Ok(Self { @@ -226,7 +226,7 @@ impl IkHandshake { if self.step != IkStep::Send2 { return Err(WireError::InvalidState); } - require_handshake_meta(&self.handshake_meta, meta)?; + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; let header = self.outbound_header()?; mix_hash_routed_handshake( &mut self.symmetric, @@ -317,7 +317,7 @@ impl IkHandshake { return Err(WireError::InvalidState); } message.meta.ensure_not_expired(now_seconds)?; - require_handshake_meta(&self.handshake_meta, message.meta)?; + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; self.ensure_inbound_recipient(message.header)?; self.ensure_known_remote_sender(message.header)?; mix_hash_routed_handshake( @@ -354,7 +354,7 @@ impl IkHandshake { let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; Ok(finalize_handshake( crypto, - self.symmetric, + &self.symmetric, self.role, remote_bundle, )) diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 244bde96..9df80a57 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -36,7 +36,7 @@ impl Kk1 { let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let ephemeral = - EphemeralPublicKey::decode(&reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; + EphemeralPublicKey::decode(reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; reader.finish()?; Ok(Self { header, @@ -214,7 +214,7 @@ impl KkHandshake { if self.step != KkStep::Send2 { return Err(WireError::InvalidState); } - require_handshake_meta(&self.handshake_meta, meta)?; + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; let header = self.outbound_header(); mix_hash_routed_handshake( &mut self.symmetric, @@ -290,7 +290,7 @@ impl KkHandshake { return Err(WireError::InvalidState); } message.meta.ensure_not_expired(now_seconds)?; - require_handshake_meta(&self.handshake_meta, message.meta)?; + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; self.ensure_inbound_header(message.header)?; mix_hash_routed_handshake( &mut self.symmetric, @@ -325,7 +325,7 @@ impl KkHandshake { } Ok(finalize_handshake( crypto, - self.symmetric, + &self.symmetric, self.role, self.remote_bundle, )) diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index fa050fdf..df689048 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -338,7 +338,7 @@ fn initialize_handshake_meta( } fn require_handshake_meta( - expected: &Option, + expected: Option<&HandshakeMeta>, meta: HandshakeMeta, ) -> Result<(), WireError> { match expected { @@ -400,7 +400,7 @@ fn decrypt_mlkem_ciphertext( fn finalize_handshake( crypto: &impl QlCrypto, - symmetric: SymmetricState, + symmetric: &SymmetricState, role: Role, remote_bundle: PeerBundle, ) -> FinalizedHandshake { diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 2dfda418..1d640893 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -61,6 +61,7 @@ impl QlIdentity { } } + #[must_use] pub fn with_capabilities(mut self, capabilities: u32) -> Self { self.capabilities = capabilities; self diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index a01ffd72..fde3d19a 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -114,7 +114,7 @@ impl QlHandshakeRecord { } pub fn decode(bytes: &[u8]) -> Result { - Ok(Self::parse(bytes)?) + Self::parse(bytes) } pub fn parse(bytes: B) -> Result { @@ -131,8 +131,12 @@ impl QlHandshakeRecord { impl> QlSessionRecord { pub fn encode(&self) -> Vec { - let mut out = - vec![0; 2 + SessionHeader::ENCODED_LEN + EncryptedMessage::<&[u8]>::HEADER_LEN + self.payload.ciphertext.as_ref().len()]; + let mut out = vec![ + 0; + 2 + SessionHeader::ENCODED_LEN + + EncryptedMessage::<&[u8]>::HEADER_LEN + + self.payload.ciphertext.as_ref().len() + ]; let rest = codec::write_u8(&mut out, QL_WIRE_VERSION); let rest = codec::write_u8(rest, RecordType::Session as u8); let rest = codec::write_bytes(rest, &self.header.encode()); @@ -143,7 +147,7 @@ impl> QlSessionRecord { impl QlSessionRecord> { pub fn decode(bytes: &[u8]) -> Result { - Ok(QlSessionRecord::parse(bytes)?.into_owned()) + QlSessionRecord::parse(bytes).map(QlSessionRecord::into_owned) } } @@ -178,7 +182,7 @@ impl> QlRecord { impl QlRecord> { pub fn decode(bytes: &[u8]) -> Result { - Ok(QlRecord::parse(bytes)?.into_owned()) + QlRecord::parse(bytes).map(QlRecord::into_owned) } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index c9ed6245..c6c8aa27 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -554,8 +554,7 @@ fn protocol_record_size_breakdown() { let mut kk_initiator = KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); - let mut kk_responder = - KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); + let mut kk_responder = KkHandshake::new_responder(&crypto, responder, initiator.bundle()); let kk1 = kk_initiator.write_1(&crypto, handshake_meta(201)).unwrap(); kk_responder.read_1(&crypto, 0, &kk1).unwrap(); From 65b78ad8eee96080cdd93bcd98d5c214cbc912a7 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 3 Apr 2026 19:26:47 -0400 Subject: [PATCH 082/304] ql: use bitmap for record ack --- ql-fsm/src/session/mod.rs | 12 +- ql-fsm/src/session/received_records.rs | 197 +++++++++++++++++++------ ql-fsm/src/session/tests.rs | 10 +- ql-wire/src/encrypted/ack.rs | 98 ++++++------ ql-wire/src/encrypted/builder.rs | 5 +- ql-wire/src/encrypted/mod.rs | 13 +- ql-wire/src/tests.rs | 11 +- 7 files changed, 233 insertions(+), 113 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 783d386a..98bbf148 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -18,7 +18,7 @@ use ql_wire::{ }; use self::{ - received_records::{ReceiveInsertOutcome, ReceivedRecords}, + received_records::{ReceiveOutcome, ReceivedRecords}, state::{AckState, InboundState, OutboundState, SessionFsmState, StreamRole, StreamState}, stream_parity::StreamParity, stream_rx::{StreamReadIter, StreamRxError}, @@ -250,8 +250,8 @@ impl SessionFsm { self.state.last_inbound_at = self.state.now; let (duplicate, out_of_order) = match self.state.received_records.insert(seq) { - ReceiveInsertOutcome::Duplicate => (true, false), - ReceiveInsertOutcome::New { out_of_order } => (false, out_of_order), + ReceiveOutcome::Duplicate => (true, false), + ReceiveOutcome::New { out_of_order } => (false, out_of_order), }; let closed = self.state.session_state == SessionState::Closed; @@ -627,11 +627,7 @@ impl SessionFsm { let tracked_records = &mut self.state.tracked_records; let streams = &mut self.state.streams; for (_, record) in tracked_records.extract_if(.., |_, record| { - record.sent_at.is_some() - && ack - .ranges - .iter() - .any(|range| range.start <= record.seq.0 && record.seq.0 < range.end) + record.sent_at.is_some() && ack.contains(record.seq.0) }) { for frame in &record.frames { acknowledge_tracked_frame(streams, stream_send_buffer_size, frame, emit); diff --git a/ql-fsm/src/session/received_records.rs b/ql-fsm/src/session/received_records.rs index 82341847..03bd08fa 100644 --- a/ql-fsm/src/session/received_records.rs +++ b/ql-fsm/src/session/received_records.rs @@ -1,74 +1,185 @@ -use std::collections::BTreeSet; - -use ql_wire::{RecordAck, RecordAckRange, RecordSeq}; +use ql_wire::{RecordAck, RecordSeq}; #[derive(Debug, Default)] pub struct ReceivedRecords { - seen: BTreeSet, + seen: u64, + base: u64, largest: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ReceiveInsertOutcome { +pub enum ReceiveOutcome { New { out_of_order: bool }, Duplicate, } impl ReceivedRecords { - const TRACKED_WINDOW: u64 = 256; + const TRACKED_LEN: u64 = RecordAck::BITMAP_BITS as u64; + const TRACKED_WINDOW: u64 = Self::TRACKED_LEN - 1; + + pub fn insert(&mut self, seq: RecordSeq) -> ReceiveOutcome { + let seq = seq.0; + let Some(largest) = self.largest else { + self.base = seq; + self.seen = 1; + self.largest = Some(seq); + return ReceiveOutcome::New { + out_of_order: false, + }; + }; + + if largest.saturating_sub(seq) > Self::TRACKED_WINDOW { + return ReceiveOutcome::Duplicate; + } - pub fn insert(&mut self, seq: RecordSeq) -> ReceiveInsertOutcome { - if self.seen.contains(&seq.0) { - return ReceiveInsertOutcome::Duplicate; + let out_of_order = seq != largest.saturating_add(1); + if seq > largest { + self.advance_base(seq.saturating_sub(Self::TRACKED_WINDOW)); + self.largest = Some(seq); } - if self - .largest - .is_some_and(|largest| largest.saturating_sub(seq.0) > Self::TRACKED_WINDOW) - { - return ReceiveInsertOutcome::Duplicate; + let Some(bit) = self.bit_for(seq) else { + return ReceiveOutcome::Duplicate; + }; + if self.seen & bit != 0 { + return ReceiveOutcome::Duplicate; } - let out_of_order = self - .largest - .is_some_and(|largest| seq.0 != largest.saturating_add(1)); - self.seen.insert(seq.0); - self.largest = Some(self.largest.map_or(seq.0, |largest| largest.max(seq.0))); - self.prune(); - ReceiveInsertOutcome::New { out_of_order } + self.seen |= bit; + ReceiveOutcome::New { out_of_order } } pub fn ack(&self) -> Option { - if self.seen.is_empty() { + (self.seen != 0).then_some(RecordAck { + base_seq: RecordSeq(self.base), + bits: self.seen, + }) + } + + fn bit_for(&self, seq: u64) -> Option { + if seq < self.base { return None; } - let mut ranges = Vec::new(); - let mut iter = self.seen.iter().copied(); - let first = iter.next()?; - let mut start = first; - let mut end = first.saturating_add(1); + let offset = seq - self.base; + (offset < Self::TRACKED_LEN).then_some(1u64 << offset) + } - for seq in iter { - if seq == end { - end = end.saturating_add(1); - continue; - } + fn advance_base(&mut self, new_base: u64) { + if new_base <= self.base { + return; + } - ranges.push(RecordAckRange { start, end }); - start = seq; - end = seq.saturating_add(1); + let shift = new_base - self.base; + if shift >= Self::TRACKED_LEN { + self.seen = 0; + } else { + self.seen >>= shift; } + self.base = new_base; + } +} + +#[cfg(test)] +mod tests { + use ql_wire::{RecordAck, RecordSeq}; + + use super::{ReceiveOutcome, ReceivedRecords}; + + #[test] + fn inserts_pack_contiguous_bits() { + let mut received = ReceivedRecords::default(); + + assert_eq!( + received.insert(RecordSeq(10)), + ReceiveOutcome::New { + out_of_order: false + } + ); + assert_eq!( + received.insert(RecordSeq(12)), + ReceiveOutcome::New { out_of_order: true } + ); + assert_eq!( + received.insert(RecordSeq(11)), + ReceiveOutcome::New { out_of_order: true } + ); - ranges.push(RecordAckRange { start, end }); - Some(RecordAck { ranges }) + let ack = received.ack().unwrap(); + assert_eq!( + ack, + RecordAck { + base_seq: RecordSeq(10), + bits: 0b111, + } + ); } - fn prune(&mut self) { - let Some(largest) = self.largest else { - return; - }; - let keep_from = largest.saturating_sub(Self::TRACKED_WINDOW); - self.seen.retain(|seq| *seq >= keep_from); + #[test] + fn old_records_fall_out_of_fixed_window() { + let mut received = ReceivedRecords::default(); + + assert_eq!( + received.insert(RecordSeq(0)), + ReceiveOutcome::New { + out_of_order: false + } + ); + assert_eq!( + received.insert(RecordSeq(300)), + ReceiveOutcome::New { out_of_order: true } + ); + assert_eq!(received.insert(RecordSeq(0)), ReceiveOutcome::Duplicate); + + let ack = received.ack().unwrap(); + assert_eq!( + ack, + RecordAck { + base_seq: RecordSeq(237), + bits: 1u64 << 63, + } + ); + } + + #[test] + fn duplicate_in_window_is_rejected() { + let mut received = ReceivedRecords::default(); + + assert_eq!( + received.insert(RecordSeq(7)), + ReceiveOutcome::New { + out_of_order: false + } + ); + assert_eq!(received.insert(RecordSeq(7)), ReceiveOutcome::Duplicate); + } + + #[test] + fn sliding_window_preserves_relative_bits() { + let mut received = ReceivedRecords::default(); + + assert_eq!( + received.insert(RecordSeq(10)), + ReceiveOutcome::New { + out_of_order: false + } + ); + assert_eq!( + received.insert(RecordSeq(12)), + ReceiveOutcome::New { out_of_order: true } + ); + assert_eq!( + received.insert(RecordSeq(70)), + ReceiveOutcome::New { out_of_order: true } + ); + + let ack = received.ack().unwrap(); + assert_eq!( + ack, + RecordAck { + base_seq: RecordSeq(10), + bits: (1u64 << 0) | (1u64 << 2) | (1u64 << 60), + } + ); } } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 3d86a08f..0173176b 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -1,8 +1,8 @@ use std::time::{Duration, Instant}; use ql_wire::{ - CloseTarget, RecordAck, RecordAckRange, RecordSeq, SessionFrame, SessionRecord, - SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamId, XID, + CloseTarget, RecordAck, RecordSeq, SessionFrame, SessionRecord, SessionRecordBuilder, + StreamClose, StreamCloseCode, StreamData, StreamId, XID, }; use super::{SessionEvent, SessionFsm, SessionFsmConfig}; @@ -142,10 +142,8 @@ fn ack_reopens_write_capacity() { now + Duration::from_millis(1), RecordSeq(9), std::iter::once(Ok(SessionFrame::Ack(RecordAck { - ranges: vec![RecordAckRange { - start: seq.0, - end: seq.0 + 1, - }], + base_seq: seq, + bits: 1u64, }))), |event| events.push(event), ); diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index b9649128..391a04b8 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -1,60 +1,72 @@ -use crate::{codec, WireError}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RecordAck { - pub ranges: Vec, -} +use crate::{codec, RecordSeq, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct RecordAckRange { - pub start: u64, - pub end: u64, +pub struct RecordAck { + pub base_seq: RecordSeq, + pub bits: u64, } impl RecordAck { - pub const RANGE_ENCODED_LEN: usize = size_of::() + size_of::(); + pub const BITMAP_BITS: usize = u64::BITS as usize; + pub const ENCODED_LEN: usize = size_of::() + size_of::(); pub fn decode(bytes: &[u8]) -> Result { - if bytes.is_empty() || bytes.len() % Self::RANGE_ENCODED_LEN != 0 { - return Err(WireError::InvalidPayload); + let mut reader = codec::Reader::new(bytes); + Ok(Self { + base_seq: RecordSeq(reader.take_u64()?), + bits: reader.take_u64()?, + }) + } + + pub fn contains(&self, seq: u64) -> bool { + if seq < self.base_seq.0 { + return false; } - let mut reader = codec::Reader::new(bytes); - let mut ranges = Vec::with_capacity(bytes.len() / Self::RANGE_ENCODED_LEN); - let mut previous_end = 0; - - while !reader.is_empty() { - let range = RecordAckRange { - start: reader.take_u64()?, - end: reader.take_u64()?, - }; - - if range.start >= range.end { - return Err(WireError::InvalidPayload); - } - if !ranges.is_empty() && range.start < previous_end { - return Err(WireError::InvalidPayload); - } - - previous_end = range.end; - ranges.push(range); + let offset = seq - self.base_seq.0; + if offset >= Self::BITMAP_BITS as u64 { + return false; } - Ok(Self { ranges }) + (self.bits & (1u64 << offset)) != 0 } - pub fn encoded_len(&self) -> usize { - self.ranges.len() * Self::RANGE_ENCODED_LEN + pub fn encode_into(&self, out: &mut [u8]) { + assert_eq!(out.len(), Self::ENCODED_LEN); + let out = codec::write_u64(out, self.base_seq.0); + let _ = codec::write_u64(out, self.bits); } +} - pub fn encode_into(&self, out: &mut [u8]) { - assert_eq!(out.len(), self.encoded_len()); - let mut out = out; - for range in &self.ranges { - let (encoded, rest) = out.split_at_mut(Self::RANGE_ENCODED_LEN); - let encoded = codec::write_u64(encoded, range.start); - let _ = codec::write_u64(encoded, range.end); - out = rest; - } +#[cfg(test)] +mod tests { + use super::RecordAck; + use crate::RecordSeq; + + #[test] + fn encode_decode_round_trip() { + let ack = RecordAck { + base_seq: RecordSeq(42), + bits: (1u64 << 0) | (1u64 << 17) | (1u64 << 63), + }; + let mut encoded = [0; RecordAck::ENCODED_LEN]; + ack.encode_into(&mut encoded); + + assert_eq!(RecordAck::decode(&encoded).unwrap(), ack); + } + + #[test] + fn contains_matches_bit_membership() { + let ack = RecordAck { + base_seq: RecordSeq(100), + bits: (1u64 << 0) | (1u64 << 5) | (1u64 << 63), + }; + + assert!(ack.contains(100)); + assert!(ack.contains(105)); + assert!(ack.contains(163)); + assert!(!ack.contains(99)); + assert!(!ack.contains(101)); + assert!(!ack.contains(164)); } } diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 8dad710d..b728a204 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -62,11 +62,10 @@ impl SessionRecordBuilder { } pub fn push_ack(&mut self, ack: &RecordAck) -> bool { - let len = 1 + super::SIZE_LEN + ack.encoded_len(); + let len = 1 + RecordAck::ENCODED_LEN; self.push_encoded_len(len, |out| { out[0] = super::SessionFrameKind::Ack as u8; - super::push_variable_len(&mut out[1..=super::SIZE_LEN], ack.encoded_len()); - ack.encode_into(&mut out[1 + super::SIZE_LEN..]); + ack.encode_into(&mut out[1..]); }) } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 11a520ef..4924c9e2 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -101,7 +101,7 @@ impl SessionFrame { pub fn encoded_len(&self) -> usize { 1 + match self { Self::Ping => 0, - Self::Ack(frame) => SIZE_LEN + frame.encoded_len(), + Self::Ack(_) => RecordAck::ENCODED_LEN, Self::StreamData(frame) => SIZE_LEN + frame.encoded_len(), Self::StreamWindow(_) => StreamWindow::WIRE_SIZE, Self::StreamClose(frame) => SIZE_LEN + frame.encoded_len(), @@ -177,7 +177,9 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr match SessionFrameKind::try_from(kind)? { SessionFrameKind::Ping => Ok((SessionFrame::Ping, rest)), SessionFrameKind::Ack => { - let (frame, rest) = split_variable_frame(rest)?; + let (frame, rest) = rest + .split_at_checked(RecordAck::ENCODED_LEN) + .ok_or(WireError::InvalidPayload)?; Ok((SessionFrame::Ack(RecordAck::decode(frame)?), rest)) } SessionFrameKind::StreamData => { @@ -185,10 +187,9 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr Ok((SessionFrame::StreamData(StreamData::parse(frame)?), rest)) } SessionFrameKind::StreamWindow => { - if rest.len() < StreamWindow::WIRE_SIZE { - return Err(WireError::InvalidPayload); - } - let (frame, rest) = rest.split_at(StreamWindow::WIRE_SIZE); + let (frame, rest) = rest + .split_at_checked(StreamWindow::WIRE_SIZE) + .ok_or(WireError::InvalidPayload)?; Ok(( SessionFrame::StreamWindow(StreamWindow::decode(frame)?), rest, diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index c6c8aa27..82f6424b 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -467,10 +467,13 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { frames: vec![ SessionFrame::Ping, SessionFrame::Ack(RecordAck { - ranges: vec![ - RecordAckRange { start: 12, end: 14 }, - RecordAckRange { start: 20, end: 24 }, - ], + base_seq: RecordSeq(12), + bits: (1u64 << 0) + | (1u64 << 1) + | (1u64 << 8) + | (1u64 << 9) + | (1u64 << 10) + | (1u64 << 11), }), SessionFrame::StreamWindow(StreamWindow { stream_id: StreamId(9), From ac2fbdd970deb0e02c076c582fd1c46241838593 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 08:57:21 -0400 Subject: [PATCH 083/304] ql: wire_size over encoded_len --- ql-fsm/src/session/tests.rs | 4 ++-- ql-wire/src/encrypted/ack.rs | 25 ++++++++++++++++---- ql-wire/src/encrypted/builder.rs | 34 +++++++++++++-------------- ql-wire/src/encrypted/mod.rs | 25 ++++++++++---------- ql-wire/src/encrypted/stream_close.rs | 4 ---- ql-wire/src/encrypted/stream_data.rs | 4 ++-- ql-wire/src/handshake/ik.rs | 16 ++++++------- ql-wire/src/handshake/kk.rs | 14 +++++------ ql-wire/src/handshake/meta.rs | 6 ++--- ql-wire/src/handshake/mod.rs | 32 ++++++++++++------------- ql-wire/src/header.rs | 8 +++---- ql-wire/src/identity.rs | 4 ++-- ql-wire/src/record.rs | 14 +++++------ ql-wire/src/tests.rs | 2 +- 14 files changed, 102 insertions(+), 90 deletions(-) diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 0173176b..b44174bf 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -37,8 +37,8 @@ fn receive_events( record: &SessionRecord, ) -> Vec { let mut builder = SessionRecordBuilder::new( - SessionRecordBuilder::WIRE_PREFIX_LEN + record.encoded_len(), - SessionRecordBuilder::WIRE_PREFIX_LEN + record.encoded_len(), + SessionRecordBuilder::WIRE_PREFIX_LEN + record.wire_size(), + SessionRecordBuilder::WIRE_PREFIX_LEN + record.wire_size(), ); for frame in &record.frames { assert!(builder.push_frame(frame)); diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 391a04b8..a5ba6d1c 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -8,9 +8,13 @@ pub struct RecordAck { impl RecordAck { pub const BITMAP_BITS: usize = u64::BITS as usize; - pub const ENCODED_LEN: usize = size_of::() + size_of::(); + pub const WIRE_SIZE: usize = size_of::() + size_of::(); pub fn decode(bytes: &[u8]) -> Result { + if bytes.len() != Self::WIRE_SIZE { + return Err(WireError::InvalidPayload); + } + let mut reader = codec::Reader::new(bytes); Ok(Self { base_seq: RecordSeq(reader.take_u64()?), @@ -32,7 +36,7 @@ impl RecordAck { } pub fn encode_into(&self, out: &mut [u8]) { - assert_eq!(out.len(), Self::ENCODED_LEN); + assert_eq!(out.len(), Self::WIRE_SIZE); let out = codec::write_u64(out, self.base_seq.0); let _ = codec::write_u64(out, self.bits); } @@ -41,7 +45,7 @@ impl RecordAck { #[cfg(test)] mod tests { use super::RecordAck; - use crate::RecordSeq; + use crate::{RecordSeq, WireError}; #[test] fn encode_decode_round_trip() { @@ -49,7 +53,7 @@ mod tests { base_seq: RecordSeq(42), bits: (1u64 << 0) | (1u64 << 17) | (1u64 << 63), }; - let mut encoded = [0; RecordAck::ENCODED_LEN]; + let mut encoded = [0; RecordAck::WIRE_SIZE]; ack.encode_into(&mut encoded); assert_eq!(RecordAck::decode(&encoded).unwrap(), ack); @@ -69,4 +73,17 @@ mod tests { assert!(!ack.contains(101)); assert!(!ack.contains(164)); } + + #[test] + fn decode_rejects_invalid_length() { + assert_eq!(RecordAck::decode(&[]), Err(WireError::InvalidPayload)); + assert_eq!( + RecordAck::decode(&[0; RecordAck::WIRE_SIZE - 1]), + Err(WireError::InvalidPayload) + ); + assert_eq!( + RecordAck::decode(&[0; RecordAck::WIRE_SIZE + 1]), + Err(WireError::InvalidPayload) + ); + } } diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index b728a204..d4092123 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -10,7 +10,7 @@ pub struct SessionRecordBuilder { impl SessionRecordBuilder { pub const WIRE_PREFIX_LEN: usize = - 1 + 1 + SessionHeader::ENCODED_LEN + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; + 1 + 1 + SessionHeader::WIRE_SIZE + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; pub fn new(max_capacity: usize, initial_capacity: usize) -> Self { assert!(initial_capacity <= max_capacity); @@ -58,46 +58,46 @@ impl SessionRecordBuilder { } pub fn push_ping(&mut self) -> bool { - self.push_encoded_len(1, |out| out[0] = super::SessionFrameKind::Ping as u8) + self.push_wire_size(1, |out| out[0] = super::SessionFrameKind::Ping as u8) } pub fn push_ack(&mut self, ack: &RecordAck) -> bool { - let len = 1 + RecordAck::ENCODED_LEN; - self.push_encoded_len(len, |out| { + let len = 1 + RecordAck::WIRE_SIZE; + self.push_wire_size(len, |out| { out[0] = super::SessionFrameKind::Ack as u8; ack.encode_into(&mut out[1..]); }) } pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { - let len = 1 + super::SIZE_LEN + frame.encoded_len(); - self.push_encoded_len(len, |out| { + let len = 1 + super::SIZE_LEN + frame.wire_size(); + self.push_wire_size(len, |out| { out[0] = super::SessionFrameKind::StreamData as u8; - super::push_variable_len(&mut out[1..=super::SIZE_LEN], frame.encoded_len()); + super::push_variable_len(&mut out[1..=super::SIZE_LEN], frame.wire_size()); frame.encode_into(&mut out[1 + super::SIZE_LEN..]); }) } pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { let len = 1 + StreamWindow::WIRE_SIZE; - self.push_encoded_len(len, |out| { + self.push_wire_size(len, |out| { out[0] = super::SessionFrameKind::StreamWindow as u8; frame.encode_into(&mut out[1..]); }) } pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { - let len = 1 + super::SIZE_LEN + frame.encoded_len(); - self.push_encoded_len(len, |out| { + let len = 1 + super::SIZE_LEN + StreamClose::WIRE_SIZE; + self.push_wire_size(len, |out| { out[0] = super::SessionFrameKind::StreamClose as u8; - super::push_variable_len(&mut out[1..=super::SIZE_LEN], frame.encoded_len()); + super::push_variable_len(&mut out[1..=super::SIZE_LEN], StreamClose::WIRE_SIZE); frame.encode_into(&mut out[1 + super::SIZE_LEN..]); }) } pub fn push_close(&mut self, close: &SessionClose) -> bool { let len = 1 + SessionClose::WIRE_SIZE; - self.push_encoded_len(len, |out| { + self.push_wire_size(len, |out| { out[0] = super::SessionFrameKind::Close as u8; close.encode_into(&mut out[1..]); }) @@ -132,17 +132,17 @@ impl SessionRecordBuilder { let prefix = &mut self.bytes[..self.body_start]; prefix[0] = QL_WIRE_VERSION; prefix[1] = RecordType::Session as u8; - header.encode_into(&mut prefix[2..2 + SessionHeader::ENCODED_LEN]); - prefix[2 + SessionHeader::ENCODED_LEN..].copy_from_slice(&auth); + header.encode_into(&mut prefix[2..2 + SessionHeader::WIRE_SIZE]); + prefix[2 + SessionHeader::WIRE_SIZE..].copy_from_slice(&auth); self.bytes } - fn push_encoded_len(&mut self, len: usize, encode: impl FnOnce(&mut [u8])) -> bool { - if !self.can_push_len(len) { + fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut [u8])) -> bool { + if !self.can_push_len(wire_size) { return false; } let start = self.bytes.len(); - self.bytes.resize(start + len, 0); + self.bytes.resize(start + wire_size, 0); encode(&mut self.bytes[start..]); true } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 4924c9e2..7f098b38 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -89,22 +89,22 @@ impl SessionRecord { Ok(Self { frames }) } - pub fn encoded_len(&self) -> usize { + pub fn wire_size(&self) -> usize { self.frames .iter() - .map(SessionFrame::encoded_len) + .map(SessionFrame::wire_size) .sum::() } } impl SessionFrame { - pub fn encoded_len(&self) -> usize { + pub fn wire_size(&self) -> usize { 1 + match self { Self::Ping => 0, - Self::Ack(_) => RecordAck::ENCODED_LEN, - Self::StreamData(frame) => SIZE_LEN + frame.encoded_len(), + Self::Ack(_) => RecordAck::WIRE_SIZE, + Self::StreamData(frame) => SIZE_LEN + frame.wire_size(), Self::StreamWindow(_) => StreamWindow::WIRE_SIZE, - Self::StreamClose(frame) => SIZE_LEN + frame.encoded_len(), + Self::StreamClose(_) => SIZE_LEN + StreamClose::WIRE_SIZE, Self::Close(_) => SessionClose::WIRE_SIZE, } } @@ -151,8 +151,8 @@ pub fn encrypt_record( session_key: &SessionKey, body: &SessionRecord, ) -> QlSessionRecord> { - let encoded_len = body.encoded_len() + SessionRecordBuilder::WIRE_PREFIX_LEN; - let mut builder = SessionRecordBuilder::new(encoded_len, encoded_len); + let wire_size = body.wire_size() + SessionRecordBuilder::WIRE_PREFIX_LEN; + let mut builder = SessionRecordBuilder::new(wire_size, wire_size); for frame in &body.frames { let pushed = builder.push_frame(frame); debug_assert!(pushed); @@ -178,7 +178,7 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr SessionFrameKind::Ping => Ok((SessionFrame::Ping, rest)), SessionFrameKind::Ack => { let (frame, rest) = rest - .split_at_checked(RecordAck::ENCODED_LEN) + .split_at_checked(RecordAck::WIRE_SIZE) .ok_or(WireError::InvalidPayload)?; Ok((SessionFrame::Ack(RecordAck::decode(frame)?), rest)) } @@ -200,10 +200,9 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr Ok((SessionFrame::StreamClose(StreamClose::parse(frame)?), rest)) } SessionFrameKind::Close => { - if rest.len() < SessionClose::WIRE_SIZE { - return Err(WireError::InvalidPayload); - } - let (frame, rest) = rest.split_at(SessionClose::WIRE_SIZE); + let (frame, rest) = rest + .split_at_checked(SessionClose::WIRE_SIZE) + .ok_or(WireError::InvalidPayload)?; Ok((SessionFrame::Close(SessionClose::decode(frame)?), rest)) } } diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 1589a14d..a83cc79a 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -24,10 +24,6 @@ impl StreamClose { Ok(close) } - pub fn encoded_len(&self) -> usize { - Self::WIRE_SIZE - } - pub fn encode_into(&self, out: &mut [u8]) { let out = codec::write_u32(out, self.stream_id.0); let out = codec::write_u8(out, self.target.to_wire()); diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 2e7deb52..a4630ae5 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -41,12 +41,12 @@ impl StreamData { } impl StreamData { - pub fn encoded_len(&self) -> usize { + pub fn wire_size(&self) -> usize { Self::MIN_WIRE_SIZE + self.bytes.len() } pub fn encode_into(&self, out: &mut [u8]) { - assert_eq!(out.len(), self.encoded_len()); + assert_eq!(out.len(), self.wire_size()); let out = codec::write_u32(out, self.stream_id.0); let out = codec::write_u64(out, self.offset); let mut out = codec::write_bool(out, self.fin); diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 63e86c1b..27250d58 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -20,11 +20,11 @@ pub struct Ik1 { } impl Ik1 { - pub const ENCODED_LEN: usize = HandshakeHeader::ENCODED_LEN - + HandshakeMeta::ENCODED_LEN + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + MlKemCiphertext::SIZE - + EphemeralPublicKey::ENCODED_LEN - + EncryptedPeerBundle::ENCODED_LEN; + + EphemeralPublicKey::WIRE_SIZE + + EncryptedPeerBundle::WIRE_SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { let out = self.header.encode_into(out); @@ -40,7 +40,7 @@ impl Ik1 { let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let ephemeral = - EphemeralPublicKey::decode(reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; + EphemeralPublicKey::decode(reader.take_bytes(EphemeralPublicKey::WIRE_SIZE)?)?; let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); reader.finish()?; Ok(Self { @@ -62,10 +62,10 @@ pub struct Ik2 { } impl Ik2 { - pub const ENCODED_LEN: usize = HandshakeHeader::ENCODED_LEN - + HandshakeMeta::ENCODED_LEN + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + MlKemCiphertext::SIZE - + EncryptedMlKemCiphertext::ENCODED_LEN; + + EncryptedMlKemCiphertext::WIRE_SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { let out = self.header.encode_into(out); diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 9df80a57..8daa3593 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -18,10 +18,10 @@ pub struct Kk1 { } impl Kk1 { - pub const ENCODED_LEN: usize = HandshakeHeader::ENCODED_LEN - + HandshakeMeta::ENCODED_LEN + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + MlKemCiphertext::SIZE - + EphemeralPublicKey::ENCODED_LEN; + + EphemeralPublicKey::WIRE_SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { let out = self.header.encode_into(out); @@ -36,7 +36,7 @@ impl Kk1 { let meta = HandshakeMeta::decode_from(&mut reader)?; let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); let ephemeral = - EphemeralPublicKey::decode(reader.take_bytes(EphemeralPublicKey::ENCODED_LEN)?)?; + EphemeralPublicKey::decode(reader.take_bytes(EphemeralPublicKey::WIRE_SIZE)?)?; reader.finish()?; Ok(Self { header, @@ -56,10 +56,10 @@ pub struct Kk2 { } impl Kk2 { - pub const ENCODED_LEN: usize = HandshakeHeader::ENCODED_LEN - + HandshakeMeta::ENCODED_LEN + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + MlKemCiphertext::SIZE - + EncryptedMlKemCiphertext::ENCODED_LEN; + + EncryptedMlKemCiphertext::WIRE_SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { let out = self.header.encode_into(out); diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs index bf780750..50369c15 100644 --- a/ql-wire/src/handshake/meta.rs +++ b/ql-wire/src/handshake/meta.rs @@ -11,7 +11,7 @@ pub struct HandshakeMeta { } impl HandshakeMeta { - pub const ENCODED_LEN: usize = size_of::() + size_of::(); + pub const WIRE_SIZE: usize = size_of::() + size_of::(); pub fn ensure_not_expired(&self, now_seconds: u64) -> Result<(), WireError> { if now_seconds > self.valid_until { @@ -26,8 +26,8 @@ impl HandshakeMeta { codec::write_u64(out, self.valid_until) } - pub fn encode(&self) -> [u8; Self::ENCODED_LEN] { - let mut out = [0; Self::ENCODED_LEN]; + pub fn encode(&self) -> [u8; Self::WIRE_SIZE] { + let mut out = [0; Self::WIRE_SIZE]; let _ = self.encode_into(&mut out); out } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index df689048..9bc47a5a 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -24,10 +24,10 @@ pub struct HandshakeHeader { } impl HandshakeHeader { - pub const ENCODED_LEN: usize = XID::SIZE * 2; + pub const WIRE_SIZE: usize = XID::SIZE * 2; - pub fn encode(&self) -> [u8; Self::ENCODED_LEN] { - let mut out = [0; Self::ENCODED_LEN]; + pub fn encode(&self) -> [u8; Self::WIRE_SIZE] { + let mut out = [0; Self::WIRE_SIZE]; let _ = self.encode_into(&mut out); out } @@ -60,7 +60,7 @@ pub struct EphemeralPublicKey { } impl EphemeralPublicKey { - pub const ENCODED_LEN: usize = MlKemPublicKey::SIZE; + pub const WIRE_SIZE: usize = MlKemPublicKey::SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { codec::write_bytes(out, self.mlkem_public_key.as_bytes()) @@ -77,31 +77,31 @@ impl EphemeralPublicKey { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct EncryptedMlKemCiphertext(Box<[u8; Self::ENCODED_LEN]>); +pub struct EncryptedMlKemCiphertext(Box<[u8; Self::WIRE_SIZE]>); impl EncryptedMlKemCiphertext { - pub const ENCODED_LEN: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; + pub const WIRE_SIZE: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; - pub fn from_data(data: [u8; Self::ENCODED_LEN]) -> Self { + pub fn from_data(data: [u8; Self::WIRE_SIZE]) -> Self { Self(Box::new(data)) } - pub fn as_bytes(&self) -> &[u8; Self::ENCODED_LEN] { + pub fn as_bytes(&self) -> &[u8; Self::WIRE_SIZE] { self.0.as_ref() } } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct EncryptedPeerBundle(Box<[u8; Self::ENCODED_LEN]>); +pub struct EncryptedPeerBundle(Box<[u8; Self::WIRE_SIZE]>); impl EncryptedPeerBundle { - pub const ENCODED_LEN: usize = PeerBundle::ENCODED_LEN + ENCRYPTED_MESSAGE_AUTH_SIZE; + pub const WIRE_SIZE: usize = PeerBundle::WIRE_SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; - pub fn from_data(data: [u8; Self::ENCODED_LEN]) -> Self { + pub fn from_data(data: [u8; Self::WIRE_SIZE]) -> Self { Self(Box::new(data)) } - pub fn as_bytes(&self) -> &[u8; Self::ENCODED_LEN] { + pub fn as_bytes(&self) -> &[u8; Self::WIRE_SIZE] { self.0.as_ref() } } @@ -353,10 +353,10 @@ fn encrypt_peer_bundle( bundle: &PeerBundle, ) -> Result { let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode())?; - if ciphertext.len() != EncryptedPeerBundle::ENCODED_LEN { + if ciphertext.len() != EncryptedPeerBundle::WIRE_SIZE { return Err(WireError::InvalidState); } - let mut out = [0u8; EncryptedPeerBundle::ENCODED_LEN]; + let mut out = [0u8; EncryptedPeerBundle::WIRE_SIZE]; out.copy_from_slice(&ciphertext); Ok(EncryptedPeerBundle::from_data(out)) } @@ -376,10 +376,10 @@ fn encrypt_mlkem_ciphertext( ciphertext: &MlKemCiphertext, ) -> Result { let encrypted = symmetric.encrypt_and_hash(crypto, ciphertext.as_bytes())?; - if encrypted.len() != EncryptedMlKemCiphertext::ENCODED_LEN { + if encrypted.len() != EncryptedMlKemCiphertext::WIRE_SIZE { return Err(WireError::InvalidState); } - let mut out = [0u8; EncryptedMlKemCiphertext::ENCODED_LEN]; + let mut out = [0u8; EncryptedMlKemCiphertext::WIRE_SIZE]; out.copy_from_slice(&encrypted); Ok(EncryptedMlKemCiphertext::from_data(out)) } diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index ea289eea..a60964f3 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -27,18 +27,18 @@ impl ConnectionId { } impl SessionHeader { - pub const ENCODED_LEN: usize = ConnectionId::SIZE + size_of::(); + pub const WIRE_SIZE: usize = ConnectionId::SIZE + size_of::(); const AAD_DOMAIN: &[u8] = b"ql-wire:session-aad:v1"; const AAD_RECORD_KIND_SESSION: u8 = 1; - pub fn encode(&self) -> [u8; Self::ENCODED_LEN] { - let mut out = [0; Self::ENCODED_LEN]; + pub fn encode(&self) -> [u8; Self::WIRE_SIZE] { + let mut out = [0; Self::WIRE_SIZE]; self.encode_into(&mut out); out } pub fn encode_into(&self, out: &mut [u8]) { - assert_eq!(out.len(), Self::ENCODED_LEN); + assert_eq!(out.len(), Self::WIRE_SIZE); let out = codec::write_bytes(out, self.connection_id.as_bytes()); let _ = codec::write_u64(out, self.seq.0); } diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 1d640893..584016d9 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -10,7 +10,7 @@ pub struct PeerBundle { impl PeerBundle { pub const VERSION: u16 = 1; - pub const ENCODED_LEN: usize = + pub const WIRE_SIZE: usize = size_of::() + XID::SIZE + size_of::() + MlKemPublicKey::SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { @@ -21,7 +21,7 @@ impl PeerBundle { } pub fn encode(&self) -> Vec { - let mut out = vec![0; Self::ENCODED_LEN]; + let mut out = vec![0; Self::WIRE_SIZE]; let _ = self.encode_into(&mut out); out } diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index fde3d19a..8c251578 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -77,12 +77,12 @@ impl QlHandshakeRecord { } } - fn encoded_len(&self) -> usize { + fn wire_size(&self) -> usize { match self { - Self::Ik1(_) => Ik1::ENCODED_LEN, - Self::Ik2(_) => Ik2::ENCODED_LEN, - Self::Kk1(_) => Kk1::ENCODED_LEN, - Self::Kk2(_) => Kk2::ENCODED_LEN, + Self::Ik1(_) => Ik1::WIRE_SIZE, + Self::Ik2(_) => Ik2::WIRE_SIZE, + Self::Kk1(_) => Kk1::WIRE_SIZE, + Self::Kk2(_) => Kk2::WIRE_SIZE, } } @@ -105,7 +105,7 @@ impl QlHandshakeRecord { } pub fn encode(&self) -> Vec { - let mut out = vec![0; 3 + self.encoded_len()]; + let mut out = vec![0; 3 + self.wire_size()]; let rest = codec::write_u8(&mut out, QL_WIRE_VERSION); let rest = codec::write_u8(rest, RecordType::Handshake as u8); let rest = codec::write_u8(rest, self.kind() as u8); @@ -133,7 +133,7 @@ impl> QlSessionRecord { pub fn encode(&self) -> Vec { let mut out = vec![ 0; - 2 + SessionHeader::ENCODED_LEN + 2 + SessionHeader::WIRE_SIZE + EncryptedMessage::<&[u8]>::HEADER_LEN + self.payload.ciphertext.as_ref().len() ]; diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 82f6424b..df57fa54 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -190,7 +190,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::from_data([9; MlKemPublicKey::SIZE]), }, - static_bundle: EncryptedPeerBundle::from_data([13; EncryptedPeerBundle::ENCODED_LEN]), + static_bundle: EncryptedPeerBundle::from_data([13; EncryptedPeerBundle::WIRE_SIZE]), }); let ik_encoded = ik.encode(); assert_eq!(QlHandshakeRecord::decode(&ik_encoded).unwrap(), ik); From 25b7d093c6b4cfbddb559374b25b4db93c1dbe5f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 09:22:19 -0400 Subject: [PATCH 084/304] ql-wire: cleanup --- ql-wire/src/encrypted/ack.rs | 4 ---- ql-wire/src/encrypted/mod.rs | 23 ++--------------------- ql-wire/src/tests.rs | 23 +++++++++++++++++++---- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index a5ba6d1c..7eaeb0f6 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -11,10 +11,6 @@ impl RecordAck { pub const WIRE_SIZE: usize = size_of::() + size_of::(); pub fn decode(bytes: &[u8]) -> Result { - if bytes.len() != Self::WIRE_SIZE { - return Err(WireError::InvalidPayload); - } - let mut reader = codec::Reader::new(bytes); Ok(Self { base_seq: RecordSeq(reader.take_u64()?), diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 7f098b38..ac24ee84 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,6 +1,6 @@ use crate::{ codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, - QlSessionRecord, SessionHeader, SessionKey, WireError, + SessionHeader, SessionKey, WireError, }; mod ack; @@ -75,10 +75,7 @@ impl TryFrom for SessionFrameKind { impl SessionRecord { pub fn parse(bytes: &[u8]) -> Result, WireError> { - let reader = codec::Reader::new(bytes); - Ok(SessionFrameIter { - remaining: reader.take_rest(), - }) + Ok(SessionFrameIter { remaining: bytes }) } pub fn decode(bytes: &[u8]) -> Result { @@ -145,22 +142,6 @@ impl<'a> Iterator for SessionFrameIter<'a> { } } -pub fn encrypt_record( - crypto: &impl QlCrypto, - header: SessionHeader, - session_key: &SessionKey, - body: &SessionRecord, -) -> QlSessionRecord> { - let wire_size = body.wire_size() + SessionRecordBuilder::WIRE_PREFIX_LEN; - let mut builder = SessionRecordBuilder::new(wire_size, wire_size); - for frame in &body.frames { - let pushed = builder.push_frame(frame); - debug_assert!(pushed); - } - QlSessionRecord::decode(&builder.encrypt(crypto, header, session_key)) - .expect("builder emitted an invalid session record") -} - pub fn decrypt_record>( crypto: &impl QlCrypto, header: &SessionHeader, diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index df57fa54..ae9b999a 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -169,6 +169,21 @@ fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { } } +fn encrypt_record( + crypto: &impl QlCrypto, + header: SessionHeader, + session_key: &SessionKey, + body: &SessionRecord, +) -> QlSessionRecord> { + let wire_size = body.wire_size() + SessionRecordBuilder::WIRE_PREFIX_LEN; + let mut builder = SessionRecordBuilder::new(wire_size, wire_size); + for frame in &body.frames { + let _pushed = builder.push_frame(frame); + debug_assert!(_pushed); + } + QlSessionRecord::decode(&builder.encrypt(crypto, header, session_key)).unwrap() +} + #[test] fn peer_bundle_round_trip() { let crypto = TestCrypto::new(1); @@ -496,7 +511,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { ], }; let session_key = SessionKey::from_data([7; SessionKey::SIZE]); - let record = encrypted::encrypt_record(&crypto, header, &session_key, &body); + let record = encrypt_record(&crypto, header, &session_key, &body); let bytes = record.encode(); let decoded = QlRecord::decode(&bytes).unwrap(); @@ -569,7 +584,7 @@ fn protocol_record_size_breakdown() { let kk2 = QlHandshakeRecord::Kk2(kk2); let session = ik_initiator.finalize(&crypto).unwrap(); - let session_ping = encrypted::encrypt_record( + let session_ping = encrypt_record( &crypto, SessionHeader { connection_id: session.tx_connection_id, @@ -580,7 +595,7 @@ fn protocol_record_size_breakdown() { frames: vec![SessionFrame::Ping], }, ); - let session_stream_empty = encrypted::encrypt_record( + let session_stream_empty = encrypt_record( &crypto, SessionHeader { connection_id: session.tx_connection_id, @@ -596,7 +611,7 @@ fn protocol_record_size_breakdown() { })], }, ); - let session_close = encrypted::encrypt_record( + let session_close = encrypt_record( &crypto, SessionHeader { connection_id: session.tx_connection_id, From 72763fbb2e3ddc49518a513b76ab975a2dc6ec38 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 09:44:40 -0400 Subject: [PATCH 085/304] ql-wire: introduce WireParse trait --- ql-wire/src/codec.rs | 29 +++++++++----- ql-wire/src/encrypted/ack.rs | 30 +++++++-------- ql-wire/src/encrypted/close.rs | 10 ++--- ql-wire/src/encrypted/mod.rs | 13 ++++--- ql-wire/src/encrypted/stream_close.rs | 21 +++++----- ql-wire/src/encrypted/stream_window.rs | 13 +++---- ql-wire/src/handshake/ik.rs | 44 +++++++++------------ ql-wire/src/handshake/kk.rs | 41 ++++++++------------ ql-wire/src/handshake/meta.rs | 18 ++------- ql-wire/src/handshake/mod.rs | 53 +++++++++----------------- ql-wire/src/header.rs | 27 +++++-------- ql-wire/src/identity.rs | 15 ++++---- ql-wire/src/lib.rs | 1 + ql-wire/src/record.rs | 16 +++----- ql-wire/src/tests.rs | 6 +-- 15 files changed, 147 insertions(+), 190 deletions(-) diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index 143f93b6..7b06c1a0 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -34,6 +34,20 @@ pub fn write_bytes<'a>(out: &'a mut [u8], bytes: &[u8]) -> &'a mut [u8] { rest } +pub trait WireParse: Sized { + fn parse(reader: &mut Reader) -> Result; + + fn parse_bytes(bytes: B) -> Result { + let mut reader = Reader::new(bytes); + let value = Self::parse(&mut reader)?; + if reader.is_empty() { + Ok(value) + } else { + Err(WireError::InvalidPayload) + } + } +} + pub struct Reader { remaining: Option, } @@ -49,10 +63,6 @@ impl Reader { self.remaining.as_ref().unwrap().is_empty() } - pub fn remaining(&self) -> usize { - self.remaining.as_ref().unwrap().len() - } - pub fn take_bytes(&mut self, len: usize) -> Result { let remaining = self.remaining.take().unwrap(); match remaining.split_at(len) { @@ -102,11 +112,10 @@ impl Reader { } } - pub fn finish(self) -> Result<(), WireError> { - if self.is_empty() { - Ok(()) - } else { - Err(WireError::InvalidPayload) - } + pub fn parse(&mut self) -> Result + where + T: WireParse, + { + T::parse(self) } } diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 7eaeb0f6..c7794332 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -1,4 +1,4 @@ -use crate::{codec, RecordSeq, WireError}; +use crate::{codec, ByteSlice, RecordSeq, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RecordAck { @@ -10,14 +10,6 @@ impl RecordAck { pub const BITMAP_BITS: usize = u64::BITS as usize; pub const WIRE_SIZE: usize = size_of::() + size_of::(); - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - Ok(Self { - base_seq: RecordSeq(reader.take_u64()?), - bits: reader.take_u64()?, - }) - } - pub fn contains(&self, seq: u64) -> bool { if seq < self.base_seq.0 { return false; @@ -38,10 +30,19 @@ impl RecordAck { } } +impl codec::WireParse for RecordAck { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self { + base_seq: RecordSeq(reader.take_u64()?), + bits: reader.take_u64()?, + }) + } +} + #[cfg(test)] mod tests { use super::RecordAck; - use crate::{RecordSeq, WireError}; + use crate::{RecordSeq, WireError, WireParse}; #[test] fn encode_decode_round_trip() { @@ -52,7 +53,7 @@ mod tests { let mut encoded = [0; RecordAck::WIRE_SIZE]; ack.encode_into(&mut encoded); - assert_eq!(RecordAck::decode(&encoded).unwrap(), ack); + assert_eq!(RecordAck::parse_bytes(&encoded[..]).unwrap(), ack); } #[test] @@ -71,14 +72,13 @@ mod tests { } #[test] - fn decode_rejects_invalid_length() { - assert_eq!(RecordAck::decode(&[]), Err(WireError::InvalidPayload)); + fn decode_rejects_truncated_payload() { assert_eq!( - RecordAck::decode(&[0; RecordAck::WIRE_SIZE - 1]), + RecordAck::parse_bytes(&[][..]), Err(WireError::InvalidPayload) ); assert_eq!( - RecordAck::decode(&[0; RecordAck::WIRE_SIZE + 1]), + RecordAck::parse_bytes(&[0; RecordAck::WIRE_SIZE - 1][..]), Err(WireError::InvalidPayload) ); } diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index af9e234d..51653643 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -1,4 +1,4 @@ -use crate::{codec, codec::Reader, WireError}; +use crate::{codec, codec::Reader, ByteSlice, WireError}; /// closes the whole session immediately with a close code. #[derive(Debug, Clone, PartialEq, Eq)] @@ -12,12 +12,12 @@ impl SessionClose { pub fn encode_into(&self, out: &mut [u8]) { let _ = codec::write_u16(out, self.code.0); } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = Reader::new(bytes); - let code = reader.take_u16()?; +impl codec::WireParse for SessionClose { + fn parse(reader: &mut Reader) -> Result { Ok(Self { - code: SessionCloseCode(code), + code: SessionCloseCode(reader.take_u16()?), }) } } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index ac24ee84..9f9bc892 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,6 +1,6 @@ use crate::{ codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, - SessionHeader, SessionKey, WireError, + SessionHeader, SessionKey, WireError, WireParse, }; mod ack; @@ -161,7 +161,7 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr let (frame, rest) = rest .split_at_checked(RecordAck::WIRE_SIZE) .ok_or(WireError::InvalidPayload)?; - Ok((SessionFrame::Ack(RecordAck::decode(frame)?), rest)) + Ok((SessionFrame::Ack(RecordAck::parse_bytes(frame)?), rest)) } SessionFrameKind::StreamData => { let (frame, rest) = split_variable_frame(rest)?; @@ -172,19 +172,22 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr .split_at_checked(StreamWindow::WIRE_SIZE) .ok_or(WireError::InvalidPayload)?; Ok(( - SessionFrame::StreamWindow(StreamWindow::decode(frame)?), + SessionFrame::StreamWindow(StreamWindow::parse_bytes(frame)?), rest, )) } SessionFrameKind::StreamClose => { let (frame, rest) = split_variable_frame(rest)?; - Ok((SessionFrame::StreamClose(StreamClose::parse(frame)?), rest)) + Ok(( + SessionFrame::StreamClose(StreamClose::parse_bytes(frame)?), + rest, + )) } SessionFrameKind::Close => { let (frame, rest) = rest .split_at_checked(SessionClose::WIRE_SIZE) .ok_or(WireError::InvalidPayload)?; - Ok((SessionFrame::Close(SessionClose::decode(frame)?), rest)) + Ok((SessionFrame::Close(SessionClose::parse_bytes(frame)?), rest)) } } } diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index a83cc79a..81f9a2fe 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -13,17 +13,6 @@ impl StreamClose { pub const WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); - pub fn parse(bytes: B) -> Result { - let mut reader = codec::Reader::new(bytes); - let close = Self { - stream_id: StreamId(reader.take_u32()?), - target: CloseTarget::try_from(reader.take_u8()?)?, - code: StreamCloseCode(reader.take_u16()?), - }; - reader.finish()?; - Ok(close) - } - pub fn encode_into(&self, out: &mut [u8]) { let out = codec::write_u32(out, self.stream_id.0); let out = codec::write_u8(out, self.target.to_wire()); @@ -31,6 +20,16 @@ impl StreamClose { } } +impl codec::WireParse for StreamClose { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self { + stream_id: StreamId(reader.take_u32()?), + target: CloseTarget::try_from(reader.take_u8()?)?, + code: StreamCloseCode(reader.take_u16()?), + }) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum CloseTarget { diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index 1f3388c0..224e9e43 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -1,5 +1,5 @@ use super::StreamId; -use crate::{codec, WireError}; +use crate::{codec, ByteSlice, WireError}; /// advertises the highest byte offset the peer may send on a stream. #[derive(Debug, Clone, PartialEq, Eq)] @@ -15,14 +15,13 @@ impl StreamWindow { let out = codec::write_u32(out, self.stream_id.0); let _ = codec::write_u64(out, self.maximum_offset); } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let window = Self { +impl codec::WireParse for StreamWindow { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self { stream_id: StreamId(reader.take_u32()?), maximum_offset: reader.take_u64()?, - }; - reader.finish()?; - Ok(window) + }) } } diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 27250d58..e4bac713 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -6,8 +6,8 @@ use super::{ FinalizedHandshake, HandshakeHeader, Role, SymmetricState, }; use crate::{ - codec, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, - WireError, + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, + QlIdentity, WireError, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -33,22 +33,16 @@ impl Ik1 { let out = self.ephemeral.encode_into(out); codec::write_bytes(out, self.static_bundle.as_bytes()) } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let header = HandshakeHeader::decode_from(&mut reader)?; - let meta = HandshakeMeta::decode_from(&mut reader)?; - let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); - let ephemeral = - EphemeralPublicKey::decode(reader.take_bytes(EphemeralPublicKey::WIRE_SIZE)?)?; - let static_bundle = EncryptedPeerBundle::from_data(reader.take_array()?); - reader.finish()?; +impl codec::WireParse for Ik1 { + fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - header, - meta, - skem_ciphertext, - ephemeral, - static_bundle, + header: reader.parse()?, + meta: reader.parse()?, + skem_ciphertext: MlKemCiphertext::from_data(reader.take_array()?), + ephemeral: reader.parse()?, + static_bundle: EncryptedPeerBundle::from_data(reader.take_array()?), }) } } @@ -73,19 +67,15 @@ impl Ik2 { let out = codec::write_bytes(out, self.ekem_ciphertext.as_bytes()); codec::write_bytes(out, self.skem_ciphertext.as_bytes()) } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let header = HandshakeHeader::decode_from(&mut reader)?; - let meta = HandshakeMeta::decode_from(&mut reader)?; - let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); - let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); - reader.finish()?; +impl codec::WireParse for Ik2 { + fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - header, - meta, - ekem_ciphertext, - skem_ciphertext, + header: reader.parse()?, + meta: reader.parse()?, + ekem_ciphertext: MlKemCiphertext::from_data(reader.take_array()?), + skem_ciphertext: EncryptedMlKemCiphertext::from_data(reader.take_array()?), }) } } diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 8daa3593..43cbba3e 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -5,8 +5,8 @@ use super::{ EphemeralPublicKey, FinalizedHandshake, HandshakeHeader, Role, SymmetricState, }; use crate::{ - codec, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, QlIdentity, - WireError, + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, + QlIdentity, WireError, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -29,20 +29,15 @@ impl Kk1 { let out = codec::write_bytes(out, self.skem_ciphertext.as_bytes()); self.ephemeral.encode_into(out) } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let header = HandshakeHeader::decode_from(&mut reader)?; - let meta = HandshakeMeta::decode_from(&mut reader)?; - let skem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); - let ephemeral = - EphemeralPublicKey::decode(reader.take_bytes(EphemeralPublicKey::WIRE_SIZE)?)?; - reader.finish()?; +impl codec::WireParse for Kk1 { + fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - header, - meta, - skem_ciphertext, - ephemeral, + header: reader.parse()?, + meta: reader.parse()?, + skem_ciphertext: MlKemCiphertext::from_data(reader.take_array()?), + ephemeral: reader.parse()?, }) } } @@ -67,19 +62,15 @@ impl Kk2 { let out = codec::write_bytes(out, self.ekem_ciphertext.as_bytes()); codec::write_bytes(out, self.skem_ciphertext.as_bytes()) } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let header = HandshakeHeader::decode_from(&mut reader)?; - let meta = HandshakeMeta::decode_from(&mut reader)?; - let ekem_ciphertext = MlKemCiphertext::from_data(reader.take_array()?); - let skem_ciphertext = EncryptedMlKemCiphertext::from_data(reader.take_array()?); - reader.finish()?; +impl codec::WireParse for Kk2 { + fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - header, - meta, - ekem_ciphertext, - skem_ciphertext, + header: reader.parse()?, + meta: reader.parse()?, + ekem_ciphertext: MlKemCiphertext::from_data(reader.take_array()?), + skem_ciphertext: EncryptedMlKemCiphertext::from_data(reader.take_array()?), }) } } diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs index 50369c15..f74697ad 100644 --- a/ql-wire/src/handshake/meta.rs +++ b/ql-wire/src/handshake/meta.rs @@ -1,4 +1,4 @@ -use crate::{codec, WireError}; +use crate::{codec, ByteSlice, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] @@ -31,20 +31,10 @@ impl HandshakeMeta { let _ = self.encode_into(&mut out); out } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let meta = Self { - handshake_id: HandshakeId(reader.take_u32()?), - valid_until: reader.take_u64()?, - }; - reader.finish()?; - Ok(meta) - } - - pub fn decode_from( - reader: &mut codec::Reader, - ) -> Result { +impl codec::WireParse for HandshakeMeta { + fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { handshake_id: HandshakeId(reader.take_u32()?), valid_until: reader.take_u64()?, diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 9bc47a5a..62a507f7 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,6 +1,7 @@ use crate::{ - codec, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, Nonce, - PeerBundle, QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + codec, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, + Nonce, PeerBundle, QlCrypto, SessionKey, WireError, WireParse, ENCRYPTED_MESSAGE_AUTH_SIZE, + XID, }; mod ik; @@ -36,17 +37,10 @@ impl HandshakeHeader { let out = codec::write_bytes(out, &self.sender.0); codec::write_bytes(out, &self.recipient.0) } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let header = Self::decode_from(&mut reader)?; - reader.finish()?; - Ok(header) - } - - pub fn decode_from( - reader: &mut codec::Reader, - ) -> Result { +impl codec::WireParse for HandshakeHeader { + fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { sender: XID(reader.take_array()?), recipient: XID(reader.take_array()?), @@ -65,14 +59,13 @@ impl EphemeralPublicKey { pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { codec::write_bytes(out, self.mlkem_public_key.as_bytes()) } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let value = Self { +impl codec::WireParse for EphemeralPublicKey { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self { mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), - }; - reader.finish()?; - Ok(value) + }) } } @@ -353,11 +346,8 @@ fn encrypt_peer_bundle( bundle: &PeerBundle, ) -> Result { let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode())?; - if ciphertext.len() != EncryptedPeerBundle::WIRE_SIZE { - return Err(WireError::InvalidState); - } - let mut out = [0u8; EncryptedPeerBundle::WIRE_SIZE]; - out.copy_from_slice(&ciphertext); + let out: [u8; EncryptedPeerBundle::WIRE_SIZE] = + ciphertext.try_into().map_err(|_| WireError::InvalidState)?; Ok(EncryptedPeerBundle::from_data(out)) } @@ -367,7 +357,7 @@ fn decrypt_peer_bundle( bundle: &EncryptedPeerBundle, ) -> Result { let plaintext = symmetric.decrypt_and_hash(crypto, bundle.as_bytes())?; - PeerBundle::decode(&plaintext) + PeerBundle::parse_bytes(plaintext.as_slice()) } fn encrypt_mlkem_ciphertext( @@ -376,11 +366,8 @@ fn encrypt_mlkem_ciphertext( ciphertext: &MlKemCiphertext, ) -> Result { let encrypted = symmetric.encrypt_and_hash(crypto, ciphertext.as_bytes())?; - if encrypted.len() != EncryptedMlKemCiphertext::WIRE_SIZE { - return Err(WireError::InvalidState); - } - let mut out = [0u8; EncryptedMlKemCiphertext::WIRE_SIZE]; - out.copy_from_slice(&encrypted); + let out: [u8; EncryptedMlKemCiphertext::WIRE_SIZE] = + encrypted.try_into().map_err(|_| WireError::InvalidState)?; Ok(EncryptedMlKemCiphertext::from_data(out)) } @@ -390,11 +377,9 @@ fn decrypt_mlkem_ciphertext( ciphertext: &EncryptedMlKemCiphertext, ) -> Result { let plaintext = symmetric.decrypt_and_hash(crypto, ciphertext.as_bytes())?; - if plaintext.len() != MlKemCiphertext::SIZE { - return Err(WireError::InvalidPayload); - } - let mut out = [0u8; MlKemCiphertext::SIZE]; - out.copy_from_slice(&plaintext); + let out: [u8; MlKemCiphertext::SIZE] = plaintext + .try_into() + .map_err(|_| WireError::InvalidPayload)?; Ok(MlKemCiphertext::from_data(out)) } diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index a60964f3..253f2e21 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -1,4 +1,4 @@ -use crate::{codec, QL_WIRE_VERSION}; +use crate::{codec, ByteSlice, QL_WIRE_VERSION}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SessionHeader { @@ -43,22 +43,6 @@ impl SessionHeader { let _ = codec::write_u64(out, self.seq.0); } - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let header = Self::decode_from(&mut reader)?; - reader.finish()?; - Ok(header) - } - - pub fn decode_from( - reader: &mut codec::Reader, - ) -> Result { - Ok(Self { - connection_id: ConnectionId::from_data(reader.take_array()?), - seq: RecordSeq(reader.take_u64()?), - }) - } - pub fn aad(&self) -> Vec { let aad_len = Self::AAD_DOMAIN.len() + size_of::() @@ -74,3 +58,12 @@ impl SessionHeader { aad } } + +impl codec::WireParse for SessionHeader { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self { + connection_id: ConnectionId::from_data(reader.take_array()?), + seq: RecordSeq(reader.take_u64()?), + }) + } +} diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 584016d9..d162eac2 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -1,4 +1,6 @@ -use crate::{codec, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, WireError, XID}; +use crate::{ + codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, WireError, XID, +}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct PeerBundle { @@ -25,17 +27,16 @@ impl PeerBundle { let _ = self.encode_into(&mut out); out } +} - pub fn decode(bytes: &[u8]) -> Result { - let mut reader = codec::Reader::new(bytes); - let bundle = Self { +impl codec::WireParse for PeerBundle { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self { version: reader.take_u16()?, xid: XID(reader.take_array()?), capabilities: reader.take_u32()?, mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), - }; - reader.finish()?; - Ok(bundle) + }) } } diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index c82f2fa4..ecebf1f1 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -19,6 +19,7 @@ mod record; mod xid; pub use bytes::*; +pub use codec::*; pub use crypto::*; pub use encrypted::*; pub use encrypted_message::*; diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 8c251578..e75c5529 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -2,7 +2,7 @@ use crate::{ codec, encrypted_message::EncryptedMessage, handshake::{Ik1, Ik2, Kk1, Kk2}, - ByteSlice, SessionHeader, WireError, QL_WIRE_VERSION, + ByteSlice, SessionHeader, WireError, WireParse, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -97,10 +97,10 @@ impl QlHandshakeRecord { fn decode_payload(kind: HandshakeKind, bytes: &[u8]) -> Result { match kind { - HandshakeKind::Ik1 => Ok(Self::Ik1(Ik1::decode(bytes)?)), - HandshakeKind::Ik2 => Ok(Self::Ik2(Ik2::decode(bytes)?)), - HandshakeKind::Kk1 => Ok(Self::Kk1(Kk1::decode(bytes)?)), - HandshakeKind::Kk2 => Ok(Self::Kk2(Kk2::decode(bytes)?)), + HandshakeKind::Ik1 => Ok(Self::Ik1(Ik1::parse_bytes(bytes)?)), + HandshakeKind::Ik2 => Ok(Self::Ik2(Ik2::parse_bytes(bytes)?)), + HandshakeKind::Kk1 => Ok(Self::Kk1(Kk1::parse_bytes(bytes)?)), + HandshakeKind::Kk2 => Ok(Self::Kk2(Kk2::parse_bytes(bytes)?)), } } @@ -113,10 +113,6 @@ impl QlHandshakeRecord { out } - pub fn decode(bytes: &[u8]) -> Result { - Self::parse(bytes) - } - pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); if reader.take_u8()? != QL_WIRE_VERSION { @@ -218,7 +214,7 @@ fn parse_handshake_record(bytes: B) -> Result(bytes: B) -> Result, WireError> { let mut reader = codec::Reader::new(bytes); - let header = SessionHeader::decode_from(&mut reader)?; + let header = reader.parse::()?; let payload = EncryptedMessage::parse(reader.take_rest())?; Ok(QlSessionRecord { header, payload }) } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index ae9b999a..c91c6a21 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -191,7 +191,7 @@ fn peer_bundle_round_trip() { let bundle = identity.bundle(); let encoded = bundle.encode(); - let decoded = PeerBundle::decode(&encoded).unwrap(); + let decoded = PeerBundle::parse_bytes(encoded.as_slice()).unwrap(); assert_eq!(decoded, bundle); } @@ -208,7 +208,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { static_bundle: EncryptedPeerBundle::from_data([13; EncryptedPeerBundle::WIRE_SIZE]), }); let ik_encoded = ik.encode(); - assert_eq!(QlHandshakeRecord::decode(&ik_encoded).unwrap(), ik); + assert_eq!(QlHandshakeRecord::parse(ik_encoded.as_slice()).unwrap(), ik); assert_eq!( QlRecord::decode(&ik_encoded).unwrap(), QlRecord::Handshake(ik) @@ -223,7 +223,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { }, }); let kk_encoded = kk.encode(); - assert_eq!(QlHandshakeRecord::decode(&kk_encoded).unwrap(), kk); + assert_eq!(QlHandshakeRecord::parse(kk_encoded.as_slice()).unwrap(), kk); assert_eq!( QlRecord::decode(&kk_encoded).unwrap(), QlRecord::Handshake(kk) From f3d1612f9bf2f3d8a931f0de8aa3619541d55c75 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 10:45:58 -0400 Subject: [PATCH 086/304] ql-wire: take_boxed_slice --- ql-fsm/src/tests/mod.rs | 6 +++--- ql-wire/src/codec.rs | 12 ++++++++++++ ql-wire/src/handshake/ik.rs | 8 ++++---- ql-wire/src/handshake/kk.rs | 6 +++--- ql-wire/src/handshake/mod.rs | 25 +++++++++++++------------ ql-wire/src/identity.rs | 2 +- ql-wire/src/pq.rs | 12 ++++++------ ql-wire/src/tests.rs | 16 ++++++++-------- 8 files changed, 50 insertions(+), 37 deletions(-) diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 5b68c3a0..8458d4fe 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -106,8 +106,8 @@ impl QlKem for TestCrypto { private.copy_from_slice(key_pair.sk()); MlKemKeyPair { - private: MlKemPrivateKey::from_data(private), - public: MlKemPublicKey::from_data(public), + private: MlKemPrivateKey::new(Box::new(private)), + public: MlKemPublicKey::new(Box::new(public)), } } @@ -120,7 +120,7 @@ impl QlKem for TestCrypto { let mut shared = [0u8; SessionKey::SIZE]; shared.copy_from_slice(shared_value.as_slice()); ( - MlKemCiphertext::from_data(ciphertext), + MlKemCiphertext::new(Box::new(ciphertext)), SessionKey::from_data(shared), ) } diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index 7b06c1a0..e970d1d5 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -88,6 +88,18 @@ impl Reader { Ok(out) } + pub fn take_boxed_array(&mut self) -> Result, WireError> { + let bytes = self.take_bytes(N)?; + let mut out = Box::<[u8; N]>::new_uninit(); + let src = bytes.as_ptr(); + let dst = out.as_mut_ptr().cast::(); + // SAFETY: `take_bytes(N)` guarantees the source has exactly `N` bytes + unsafe { + std::ptr::copy_nonoverlapping(src, dst, N); + Ok(out.assume_init()) + } + } + pub fn take_u8(&mut self) -> Result { Ok(self.take_bytes(1)?[0]) } diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index e4bac713..61bfdce7 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -40,9 +40,9 @@ impl codec::WireParse for Ik1 { Ok(Self { header: reader.parse()?, meta: reader.parse()?, - skem_ciphertext: MlKemCiphertext::from_data(reader.take_array()?), + skem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), ephemeral: reader.parse()?, - static_bundle: EncryptedPeerBundle::from_data(reader.take_array()?), + static_bundle: EncryptedPeerBundle::new(reader.take_boxed_array()?), }) } } @@ -74,8 +74,8 @@ impl codec::WireParse for Ik2 { Ok(Self { header: reader.parse()?, meta: reader.parse()?, - ekem_ciphertext: MlKemCiphertext::from_data(reader.take_array()?), - skem_ciphertext: EncryptedMlKemCiphertext::from_data(reader.take_array()?), + ekem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), + skem_ciphertext: EncryptedMlKemCiphertext::new(reader.take_boxed_array()?), }) } } diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 43cbba3e..2506cc7a 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -36,7 +36,7 @@ impl codec::WireParse for Kk1 { Ok(Self { header: reader.parse()?, meta: reader.parse()?, - skem_ciphertext: MlKemCiphertext::from_data(reader.take_array()?), + skem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), ephemeral: reader.parse()?, }) } @@ -69,8 +69,8 @@ impl codec::WireParse for Kk2 { Ok(Self { header: reader.parse()?, meta: reader.parse()?, - ekem_ciphertext: MlKemCiphertext::from_data(reader.take_array()?), - skem_ciphertext: EncryptedMlKemCiphertext::from_data(reader.take_array()?), + ekem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), + skem_ciphertext: EncryptedMlKemCiphertext::new(reader.take_boxed_array()?), }) } } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 62a507f7..f0bb7753 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -64,7 +64,7 @@ impl EphemeralPublicKey { impl codec::WireParse for EphemeralPublicKey { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), + mlkem_public_key: MlKemPublicKey::new(reader.take_boxed_array()?), }) } } @@ -75,8 +75,8 @@ pub struct EncryptedMlKemCiphertext(Box<[u8; Self::WIRE_SIZE]>); impl EncryptedMlKemCiphertext { pub const WIRE_SIZE: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; - pub fn from_data(data: [u8; Self::WIRE_SIZE]) -> Self { - Self(Box::new(data)) + pub fn new(data: Box<[u8; Self::WIRE_SIZE]>) -> Self { + Self(data) } pub fn as_bytes(&self) -> &[u8; Self::WIRE_SIZE] { @@ -90,8 +90,8 @@ pub struct EncryptedPeerBundle(Box<[u8; Self::WIRE_SIZE]>); impl EncryptedPeerBundle { pub const WIRE_SIZE: usize = PeerBundle::WIRE_SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; - pub fn from_data(data: [u8; Self::WIRE_SIZE]) -> Self { - Self(Box::new(data)) + pub fn new(data: Box<[u8; Self::WIRE_SIZE]>) -> Self { + Self(data) } pub fn as_bytes(&self) -> &[u8; Self::WIRE_SIZE] { @@ -159,7 +159,8 @@ impl CipherState { ) -> Result, WireError> { let key = self.key.as_ref().ok_or(WireError::InvalidState)?; let nonce = Nonce::from_counter(self.nonce); - let mut ciphertext = plaintext.to_vec(); + let mut ciphertext = Vec::with_capacity(plaintext.len() + ENCRYPTED_MESSAGE_AUTH_SIZE); + ciphertext.extend_from_slice(plaintext); let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut ciphertext); self.nonce = self.nonce.wrapping_add(1); ciphertext.extend_from_slice(&auth); @@ -346,9 +347,9 @@ fn encrypt_peer_bundle( bundle: &PeerBundle, ) -> Result { let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode())?; - let out: [u8; EncryptedPeerBundle::WIRE_SIZE] = + let out: Box<[u8; EncryptedPeerBundle::WIRE_SIZE]> = ciphertext.try_into().map_err(|_| WireError::InvalidState)?; - Ok(EncryptedPeerBundle::from_data(out)) + Ok(EncryptedPeerBundle::new(out)) } fn decrypt_peer_bundle( @@ -366,9 +367,9 @@ fn encrypt_mlkem_ciphertext( ciphertext: &MlKemCiphertext, ) -> Result { let encrypted = symmetric.encrypt_and_hash(crypto, ciphertext.as_bytes())?; - let out: [u8; EncryptedMlKemCiphertext::WIRE_SIZE] = + let out: Box<[u8; EncryptedMlKemCiphertext::WIRE_SIZE]> = encrypted.try_into().map_err(|_| WireError::InvalidState)?; - Ok(EncryptedMlKemCiphertext::from_data(out)) + Ok(EncryptedMlKemCiphertext::new(out)) } fn decrypt_mlkem_ciphertext( @@ -377,10 +378,10 @@ fn decrypt_mlkem_ciphertext( ciphertext: &EncryptedMlKemCiphertext, ) -> Result { let plaintext = symmetric.decrypt_and_hash(crypto, ciphertext.as_bytes())?; - let out: [u8; MlKemCiphertext::SIZE] = plaintext + let out: Box<[u8; MlKemCiphertext::SIZE]> = plaintext .try_into() .map_err(|_| WireError::InvalidPayload)?; - Ok(MlKemCiphertext::from_data(out)) + Ok(MlKemCiphertext::new(out)) } fn finalize_handshake( diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index d162eac2..8533a054 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -35,7 +35,7 @@ impl codec::WireParse for PeerBundle { version: reader.take_u16()?, xid: XID(reader.take_array()?), capabilities: reader.take_u32()?, - mlkem_public_key: MlKemPublicKey::from_data(reader.take_array()?), + mlkem_public_key: MlKemPublicKey::new(reader.take_boxed_array()?), }) } } diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index ba8753d0..7000e406 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -45,8 +45,8 @@ pub struct MlKemPublicKey(Box<[u8; MlKemPublicKey::SIZE]>); impl MlKemPublicKey { pub const SIZE: usize = ML_KEM_1024_PUBLIC_KEY_SIZE; - pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(data)) + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) } pub fn as_bytes(&self) -> &[u8; Self::SIZE] { @@ -66,8 +66,8 @@ pub struct MlKemPrivateKey(Box<[u8; MlKemPrivateKey::SIZE]>); impl MlKemPrivateKey { pub const SIZE: usize = ML_KEM_1024_PRIVATE_KEY_SIZE; - pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(data)) + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) } pub fn as_bytes(&self) -> &[u8; Self::SIZE] { @@ -87,8 +87,8 @@ pub struct MlKemCiphertext(Box<[u8; MlKemCiphertext::SIZE]>); impl MlKemCiphertext { pub const SIZE: usize = ML_KEM_1024_CIPHERTEXT_SIZE; - pub fn from_data(data: [u8; Self::SIZE]) -> Self { - Self(Box::new(data)) + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) } pub fn as_bytes(&self) -> &[u8; Self::SIZE] { diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index c91c6a21..ef9c149b 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -87,8 +87,8 @@ impl QlKem for TestCrypto { private.copy_from_slice(key_pair.sk()); MlKemKeyPair { - private: MlKemPrivateKey::from_data(private), - public: MlKemPublicKey::from_data(public), + private: MlKemPrivateKey::new(Box::new(private)), + public: MlKemPublicKey::new(Box::new(public)), } } @@ -101,7 +101,7 @@ impl QlKem for TestCrypto { let mut shared = [0u8; SessionKey::SIZE]; shared.copy_from_slice(shared_value.as_slice()); ( - MlKemCiphertext::from_data(ciphertext), + MlKemCiphertext::new(Box::new(ciphertext)), SessionKey::from_data(shared), ) } @@ -201,11 +201,11 @@ fn handshake_record_round_trip_supports_ik_and_kk() { let ik = QlHandshakeRecord::Ik1(Ik1 { header: handshake_header(1, 2), meta: handshake_meta(1), - skem_ciphertext: MlKemCiphertext::from_data([7; MlKemCiphertext::SIZE]), + skem_ciphertext: MlKemCiphertext::new(Box::new([7; MlKemCiphertext::SIZE])), ephemeral: EphemeralPublicKey { - mlkem_public_key: MlKemPublicKey::from_data([9; MlKemPublicKey::SIZE]), + mlkem_public_key: MlKemPublicKey::new(Box::new([9; MlKemPublicKey::SIZE])), }, - static_bundle: EncryptedPeerBundle::from_data([13; EncryptedPeerBundle::WIRE_SIZE]), + static_bundle: EncryptedPeerBundle::new(Box::new([13; EncryptedPeerBundle::WIRE_SIZE])), }); let ik_encoded = ik.encode(); assert_eq!(QlHandshakeRecord::parse(ik_encoded.as_slice()).unwrap(), ik); @@ -217,9 +217,9 @@ fn handshake_record_round_trip_supports_ik_and_kk() { let kk = QlHandshakeRecord::Kk1(Kk1 { header: handshake_header(1, 2), meta: handshake_meta(2), - skem_ciphertext: MlKemCiphertext::from_data([11; MlKemCiphertext::SIZE]), + skem_ciphertext: MlKemCiphertext::new(Box::new([11; MlKemCiphertext::SIZE])), ephemeral: EphemeralPublicKey { - mlkem_public_key: MlKemPublicKey::from_data([15; MlKemPublicKey::SIZE]), + mlkem_public_key: MlKemPublicKey::new(Box::new([15; MlKemPublicKey::SIZE])), }, }); let kk_encoded = kk.encode(); From 8c643624dc041cd71c0281ba2cc0b5959b6f0f59 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 11:09:09 -0400 Subject: [PATCH 087/304] ql: remove QlRecord enum --- ql-fsm/src/implementation/core.rs | 13 +++- ql-fsm/src/tests/handshake.rs | 20 ++--- ql-fsm/src/tests/mod.rs | 9 +-- ql-wire/src/codec.rs | 10 +++ ql-wire/src/record.rs | 122 ++++++++++-------------------- ql-wire/src/tests.rs | 45 +++++++---- 6 files changed, 101 insertions(+), 118 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index b3ffbda9..8d99149d 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -5,7 +5,7 @@ use std::{ use ql_wire::{ self as wire, CloseTarget, QlCrypto, SessionClose, SessionCloseCode, SessionHeader, - StreamCloseCode, StreamId, + StreamCloseCode, StreamId, WireParse, }; use crate::{ @@ -28,9 +28,14 @@ pub fn receive( mut bytes: Vec, crypto: &impl QlCrypto, ) -> Result<(), QlFsmError> { - match wire::QlRecord::parse(&mut bytes[..])? { - wire::QlRecord::Handshake(record) => super::handle_handshake_record(fsm, crypto, &record), - wire::QlRecord::Session(record) => { + let header = wire::RecordHeader::parse_prefix(bytes.as_slice())?; + match header.record_type { + wire::RecordType::Handshake => { + let record = wire::QlHandshakeRecord::parse_bytes(bytes.as_slice())?; + super::handle_handshake_record(fsm, crypto, &record) + } + wire::RecordType::Session => { + let record = wire::QlSessionRecord::parse_bytes(&mut bytes[..])?; let transport = fsm.state.link.transport().ok_or(QlFsmError::NoSession)?; if record.header.connection_id != transport.rx_connection_id { return Err(QlFsmError::InvalidPayload); diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 6653ca1d..b6f089f1 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use ql_wire::QlRecord; +use ql_wire::{QlHandshakeRecord, WireParse}; use super::*; use crate::{state::LinkState, QlFsmError}; @@ -157,11 +157,8 @@ fn handshake_timeout_drops_single_ik_attempt_without_resend() { .connect_ik(harness.time(), &harness.a.crypto) .unwrap(); let first = harness.next_outbound_a().unwrap(); - let first = QlRecord::decode(&first).unwrap(); - assert!(matches!( - first, - QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik1(_)) - )); + let first = QlHandshakeRecord::parse_bytes(first.as_slice()).unwrap(); + assert!(matches!(first, ql_wire::QlHandshakeRecord::Ik1(_))); assert!(harness.next_outbound_a().is_none()); harness.advance(config.handshake_timeout); @@ -247,12 +244,11 @@ fn simultaneous_ik_and_kk_connect_prefers_ik() { } fn handshake_id(record: &[u8]) -> ql_wire::HandshakeId { - let record = QlRecord::decode(record).unwrap(); + let record = QlHandshakeRecord::parse_bytes(record).unwrap(); match record { - QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik1(message)) => message.meta.handshake_id, - QlRecord::Handshake(ql_wire::QlHandshakeRecord::Ik2(message)) => message.meta.handshake_id, - QlRecord::Handshake(ql_wire::QlHandshakeRecord::Kk1(message)) => message.meta.handshake_id, - QlRecord::Handshake(ql_wire::QlHandshakeRecord::Kk2(message)) => message.meta.handshake_id, - QlRecord::Session(_) => panic!("expected handshake record"), + ql_wire::QlHandshakeRecord::Ik1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Ik2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Kk1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Kk2(message) => message.meta.handshake_id, } } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 8458d4fe..8dc777b3 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -11,7 +11,7 @@ use libcrux_ml_kem::mlkem1024; use ql_wire::{ self, generate_identity, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, SessionKey, - ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + WireParse, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; use sha2::{Digest, Sha256}; @@ -329,10 +329,9 @@ fn decrypt_record( record: &[u8], session_key: &SessionKey, ) -> (ql_wire::SessionHeader, ql_wire::SessionRecord) { - let record = ql_wire::QlRecord::decode(record).unwrap(); - let ql_wire::QlRecord::Session(record) = record else { - panic!("expected encrypted session record"); - }; + let record = ql_wire::QlSessionRecord::parse_bytes(record) + .unwrap() + .into_owned(); let plaintext = ql_wire::decrypt_record(crypto, &record.header, record.payload.clone(), session_key) .unwrap(); diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index e970d1d5..cb9af122 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -37,6 +37,11 @@ pub fn write_bytes<'a>(out: &'a mut [u8], bytes: &[u8]) -> &'a mut [u8] { pub trait WireParse: Sized { fn parse(reader: &mut Reader) -> Result; + fn parse_prefix(bytes: B) -> Result { + let mut reader = Reader::new(bytes); + Self::parse(&mut reader) + } + fn parse_bytes(bytes: B) -> Result { let mut reader = Reader::new(bytes); let value = Self::parse(&mut reader)?; @@ -63,6 +68,10 @@ impl Reader { self.remaining.as_ref().unwrap().is_empty() } + pub fn remaining_len(&self) -> usize { + self.remaining.as_ref().unwrap().len() + } + pub fn take_bytes(&mut self, len: usize) -> Result { let remaining = self.remaining.take().unwrap(); match remaining.split_at(len) { @@ -124,6 +133,7 @@ impl Reader { } } + #[inline] pub fn parse(&mut self) -> Result where T: WireParse, diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index e75c5529..509d0f2a 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -11,12 +11,6 @@ pub struct QlSessionRecord { pub payload: EncryptedMessage, } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlRecord { - Handshake(QlHandshakeRecord), - Session(QlSessionRecord), -} - #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlHandshakeRecord { Ik1(Ik1), @@ -32,6 +26,12 @@ pub enum RecordType { Session = 2, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecordHeader { + pub version: u8, + pub record_type: RecordType, +} + impl TryFrom for RecordType { type Error = WireError; @@ -44,6 +44,17 @@ impl TryFrom for RecordType { } } +impl WireParse for RecordHeader { + fn parse(reader: &mut codec::Reader) -> Result { + let version = reader.take_u8()?; + let record_type = RecordType::try_from(reader.take_u8()?)?; + Ok(Self { + version, + record_type, + }) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum HandshakeKind { @@ -95,15 +106,6 @@ impl QlHandshakeRecord { } } - fn decode_payload(kind: HandshakeKind, bytes: &[u8]) -> Result { - match kind { - HandshakeKind::Ik1 => Ok(Self::Ik1(Ik1::parse_bytes(bytes)?)), - HandshakeKind::Ik2 => Ok(Self::Ik2(Ik2::parse_bytes(bytes)?)), - HandshakeKind::Kk1 => Ok(Self::Kk1(Kk1::parse_bytes(bytes)?)), - HandshakeKind::Kk2 => Ok(Self::Kk2(Kk2::parse_bytes(bytes)?)), - } - } - pub fn encode(&self) -> Vec { let mut out = vec![0; 3 + self.wire_size()]; let rest = codec::write_u8(&mut out, QL_WIRE_VERSION); @@ -112,16 +114,24 @@ impl QlHandshakeRecord { let _ = self.encode_into(rest); out } +} - pub fn parse(bytes: B) -> Result { - let mut reader = codec::Reader::new(bytes); - if reader.take_u8()? != QL_WIRE_VERSION { +impl WireParse for QlHandshakeRecord { + fn parse(reader: &mut codec::Reader) -> Result { + let header = reader.parse::()?; + if header.version != QL_WIRE_VERSION { return Err(WireError::InvalidPayload); } - if RecordType::try_from(reader.take_u8()?)? != RecordType::Handshake { + if header.record_type != RecordType::Handshake { return Err(WireError::InvalidPayload); } - parse_handshake_record(reader.take_rest()) + let kind = HandshakeKind::try_from(reader.take_u8()?)?; + match kind { + HandshakeKind::Ik1 => Ok(Self::Ik1(reader.parse()?)), + HandshakeKind::Ik2 => Ok(Self::Ik2(reader.parse()?)), + HandshakeKind::Kk1 => Ok(Self::Kk1(reader.parse()?)), + HandshakeKind::Kk2 => Ok(Self::Kk2(reader.parse()?)), + } } } @@ -141,24 +151,7 @@ impl> QlSessionRecord { } } -impl QlSessionRecord> { - pub fn decode(bytes: &[u8]) -> Result { - QlSessionRecord::parse(bytes).map(QlSessionRecord::into_owned) - } -} - impl QlSessionRecord { - pub fn parse(bytes: B) -> Result { - let mut reader = codec::Reader::new(bytes); - if reader.take_u8()? != QL_WIRE_VERSION { - return Err(WireError::InvalidPayload); - } - if RecordType::try_from(reader.take_u8()?)? != RecordType::Session { - return Err(WireError::InvalidPayload); - } - parse_session_record(reader.take_rest()) - } - pub fn into_owned(self) -> QlSessionRecord> { QlSessionRecord { header: self.header, @@ -167,54 +160,17 @@ impl QlSessionRecord { } } -impl> QlRecord { - pub fn encode(&self) -> Vec { - match self { - Self::Handshake(record) => record.encode(), - Self::Session(record) => record.encode(), - } - } -} - -impl QlRecord> { - pub fn decode(bytes: &[u8]) -> Result { - QlRecord::parse(bytes).map(QlRecord::into_owned) - } -} - -impl QlRecord { - pub fn parse(bytes: B) -> Result { - let mut reader = codec::Reader::new(bytes); - if reader.take_u8()? != QL_WIRE_VERSION { +impl WireParse for QlSessionRecord { + fn parse(reader: &mut codec::Reader) -> Result { + let header = reader.parse::()?; + if header.version != QL_WIRE_VERSION { return Err(WireError::InvalidPayload); } - - let record_type = RecordType::try_from(reader.take_u8()?)?; - let remaining = reader.take_rest(); - match record_type { - RecordType::Handshake => Ok(Self::Handshake(parse_handshake_record(remaining)?)), - RecordType::Session => Ok(Self::Session(parse_session_record(remaining)?)), - } - } - - pub fn into_owned(self) -> QlRecord> { - match self { - Self::Handshake(record) => QlRecord::Handshake(record), - Self::Session(record) => QlRecord::Session(record.into_owned()), + if header.record_type != RecordType::Session { + return Err(WireError::InvalidPayload); } + let header = reader.parse::()?; + let payload = EncryptedMessage::parse(reader.take_bytes(reader.remaining_len())?)?; + Ok(QlSessionRecord { header, payload }) } } - -fn parse_handshake_record(bytes: B) -> Result { - let mut reader = codec::Reader::new(bytes); - let kind = HandshakeKind::try_from(reader.take_u8()?)?; - let payload = reader.take_rest(); - QlHandshakeRecord::decode_payload(kind, &payload[..]) -} - -fn parse_session_record(bytes: B) -> Result, WireError> { - let mut reader = codec::Reader::new(bytes); - let header = reader.parse::()?; - let payload = EncryptedMessage::parse(reader.take_rest())?; - Ok(QlSessionRecord { header, payload }) -} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index ef9c149b..b879c87f 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -181,7 +181,9 @@ fn encrypt_record( let _pushed = builder.push_frame(frame); debug_assert!(_pushed); } - QlSessionRecord::decode(&builder.encrypt(crypto, header, session_key)).unwrap() + QlSessionRecord::parse_bytes(builder.encrypt(crypto, header, session_key).as_slice()) + .unwrap() + .into_owned() } #[test] @@ -208,10 +210,16 @@ fn handshake_record_round_trip_supports_ik_and_kk() { static_bundle: EncryptedPeerBundle::new(Box::new([13; EncryptedPeerBundle::WIRE_SIZE])), }); let ik_encoded = ik.encode(); - assert_eq!(QlHandshakeRecord::parse(ik_encoded.as_slice()).unwrap(), ik); assert_eq!( - QlRecord::decode(&ik_encoded).unwrap(), - QlRecord::Handshake(ik) + RecordHeader::parse_prefix(ik_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!( + QlHandshakeRecord::parse_bytes(ik_encoded.as_slice()).unwrap(), + ik ); let kk = QlHandshakeRecord::Kk1(Kk1 { @@ -223,10 +231,16 @@ fn handshake_record_round_trip_supports_ik_and_kk() { }, }); let kk_encoded = kk.encode(); - assert_eq!(QlHandshakeRecord::parse(kk_encoded.as_slice()).unwrap(), kk); assert_eq!( - QlRecord::decode(&kk_encoded).unwrap(), - QlRecord::Handshake(kk) + RecordHeader::parse_prefix(kk_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!( + QlHandshakeRecord::parse_bytes(kk_encoded.as_slice()).unwrap(), + kk ); } @@ -514,10 +528,16 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let record = encrypt_record(&crypto, header, &session_key, &body); let bytes = record.encode(); - let decoded = QlRecord::decode(&bytes).unwrap(); - let QlRecord::Session(decoded) = decoded else { - panic!("expected session payload"); - }; + assert_eq!( + RecordHeader::parse_prefix(bytes.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Session, + } + ); + let decoded = QlSessionRecord::parse_bytes(bytes.as_slice()) + .unwrap() + .into_owned(); assert_eq!(decoded.header, header); let encrypted = decoded.payload; @@ -525,9 +545,6 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { encrypted::decrypt_record(&crypto, &header, encrypted.clone(), &session_key).unwrap(); assert_eq!(SessionRecord::decode(&decrypted).unwrap(), body); - let decoded = QlSessionRecord::decode(&bytes).unwrap(); - assert_eq!(decoded.header, header); - let wrong_header = SessionHeader { connection_id: ConnectionId::from_data([0x99; ConnectionId::SIZE]), seq: header.seq, From d6a0bc63a9f4dcda70c7a351097a4521b7847a5d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 13:09:32 -0400 Subject: [PATCH 088/304] ql: add todos --- ql-fsm/src/session/mod.rs | 3 +++ ql-wire/src/encrypted/builder.rs | 1 + 2 files changed, 4 insertions(+) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 98bbf148..327d890f 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -249,6 +249,9 @@ impl SessionFsm { self.state.last_activity_at = self.state.now; self.state.last_inbound_at = self.state.now; + // TODO: We record the session seq before validating its frames. If later frame + // handling fails with PROTOCOL, a subsequent outbound close can still carry an ack for + // this seq even though its stream data was rejected. let (duplicate, out_of_order) = match self.state.received_records.insert(seq) { ReceiveOutcome::Duplicate => (true, false), ReceiveOutcome::New { out_of_order } => (false, out_of_order), diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index d4092123..4b45b601 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -4,6 +4,7 @@ use crate::{ByteChunks, Nonce, QlCrypto, RecordType, SessionHeader, SessionKey, #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecordBuilder { max_capacity: usize, + // todo: remove body_start: usize, bytes: Vec, } From fe30845eb12e59ff4f17bea3d2215c634ad20bbc Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 13:18:37 -0400 Subject: [PATCH 089/304] ql-wire: add transport params --- ql-wire/src/handshake/ik.rs | 28 ++- ql-wire/src/handshake/kk.rs | 27 +++ ql-wire/src/handshake/mod.rs | 13 +- ql-wire/src/handshake/transport_params.rs | 38 ++++ ql-wire/src/lib.rs | 2 +- ql-wire/src/tests.rs | 235 +++++++++++++++++++--- 6 files changed, 314 insertions(+), 29 deletions(-) create mode 100644 ql-wire/src/handshake/transport_params.rs diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 61bfdce7..8f166ab9 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -3,7 +3,7 @@ use super::{ finalize_handshake, generate_ephemeral_keypair, init_ik_symmetric, initialize_handshake_meta, mix_hash_ephemeral, mix_hash_routed_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, - FinalizedHandshake, HandshakeHeader, Role, SymmetricState, + FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, }; use crate::{ codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, @@ -14,6 +14,7 @@ use crate::{ pub struct Ik1 { pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub transport_params: TransportParams, pub skem_ciphertext: MlKemCiphertext, pub ephemeral: EphemeralPublicKey, pub static_bundle: EncryptedPeerBundle, @@ -22,6 +23,7 @@ pub struct Ik1 { impl Ik1 { pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + MlKemCiphertext::SIZE + EphemeralPublicKey::WIRE_SIZE + EncryptedPeerBundle::WIRE_SIZE; @@ -29,6 +31,7 @@ impl Ik1 { pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { let out = self.header.encode_into(out); let out = self.meta.encode_into(out); + let out = self.transport_params.encode_into(out); let out = codec::write_bytes(out, self.skem_ciphertext.as_bytes()); let out = self.ephemeral.encode_into(out); codec::write_bytes(out, self.static_bundle.as_bytes()) @@ -40,6 +43,7 @@ impl codec::WireParse for Ik1 { Ok(Self { header: reader.parse()?, meta: reader.parse()?, + transport_params: reader.parse()?, skem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), ephemeral: reader.parse()?, static_bundle: EncryptedPeerBundle::new(reader.take_boxed_array()?), @@ -51,6 +55,7 @@ impl codec::WireParse for Ik1 { pub struct Ik2 { pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub transport_params: TransportParams, pub ekem_ciphertext: MlKemCiphertext, pub skem_ciphertext: EncryptedMlKemCiphertext, } @@ -58,12 +63,14 @@ pub struct Ik2 { impl Ik2 { pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::WIRE_SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { let out = self.header.encode_into(out); let out = self.meta.encode_into(out); + let out = self.transport_params.encode_into(out); let out = codec::write_bytes(out, self.ekem_ciphertext.as_bytes()); codec::write_bytes(out, self.skem_ciphertext.as_bytes()) } @@ -74,6 +81,7 @@ impl codec::WireParse for Ik2 { Ok(Self { header: reader.parse()?, meta: reader.parse()?, + transport_params: reader.parse()?, ekem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), skem_ciphertext: EncryptedMlKemCiphertext::new(reader.take_boxed_array()?), }) @@ -99,6 +107,8 @@ pub struct IkHandshake { local_ephemeral: Option, remote_ephemeral: Option, handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, } impl IkHandshake { @@ -106,6 +116,7 @@ impl IkHandshake { crypto: &impl QlCrypto, local: QlIdentity, remote_bundle: PeerBundle, + local_transport_params: TransportParams, ) -> Self { let symmetric = init_ik_symmetric(crypto, &remote_bundle); Self { @@ -117,6 +128,8 @@ impl IkHandshake { local_ephemeral: None, remote_ephemeral: None, handshake_meta: None, + local_transport_params, + remote_transport_params: None, } } @@ -124,6 +137,7 @@ impl IkHandshake { crypto: &impl QlCrypto, local: QlIdentity, expected_remote: Option, + local_transport_params: TransportParams, ) -> Self { let symmetric = init_ik_symmetric(crypto, &local.bundle()); Self { @@ -135,6 +149,8 @@ impl IkHandshake { local_ephemeral: None, remote_ephemeral: None, handshake_meta: None, + local_transport_params, + remote_transport_params: None, } } @@ -184,6 +200,7 @@ impl IkHandshake { header, HandshakeKind::Ik1, &meta, + &self.local_transport_params, ); let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); @@ -202,6 +219,7 @@ impl IkHandshake { Ok(Ik1 { header, meta, + transport_params: self.local_transport_params, skem_ciphertext, ephemeral: public, static_bundle, @@ -224,6 +242,7 @@ impl IkHandshake { header, HandshakeKind::Ik2, &meta, + &self.local_transport_params, ); let remote_ephemeral = self .remote_ephemeral @@ -246,6 +265,7 @@ impl IkHandshake { Ok(Ik2 { header, meta, + transport_params: self.local_transport_params, ekem_ciphertext, skem_ciphertext, }) @@ -270,6 +290,7 @@ impl IkHandshake { message.header, HandshakeKind::Ik1, &message.meta, + &message.transport_params, ); self.symmetric .mix_hash(crypto, message.skem_ciphertext.as_bytes()); @@ -293,6 +314,7 @@ impl IkHandshake { Some(_) => {} None => self.remote_bundle = Some(remote_bundle), } + self.remote_transport_params = Some(message.transport_params); self.step = IkStep::Send2; Ok(()) } @@ -316,6 +338,7 @@ impl IkHandshake { message.header, HandshakeKind::Ik2, &message.meta, + &message.transport_params, ); let local_ephemeral = self .local_ephemeral @@ -333,6 +356,7 @@ impl IkHandshake { self.symmetric .mix_key_and_hash(crypto, skem_secret.as_bytes()); + self.remote_transport_params = Some(message.transport_params); self.step = IkStep::Done; Ok(()) } @@ -342,11 +366,13 @@ impl IkHandshake { return Err(WireError::InvalidState); } let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; + let remote_transport_params = self.remote_transport_params.ok_or(WireError::InvalidState)?; Ok(finalize_handshake( crypto, &self.symmetric, self.role, remote_bundle, + remote_transport_params, )) } } diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 2506cc7a..84afb3f5 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -3,6 +3,7 @@ use super::{ generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, mix_hash_routed_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, HandshakeHeader, Role, SymmetricState, + TransportParams, }; use crate::{ codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, @@ -13,6 +14,7 @@ use crate::{ pub struct Kk1 { pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub transport_params: TransportParams, pub skem_ciphertext: MlKemCiphertext, pub ephemeral: EphemeralPublicKey, } @@ -20,12 +22,14 @@ pub struct Kk1 { impl Kk1 { pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + MlKemCiphertext::SIZE + EphemeralPublicKey::WIRE_SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { let out = self.header.encode_into(out); let out = self.meta.encode_into(out); + let out = self.transport_params.encode_into(out); let out = codec::write_bytes(out, self.skem_ciphertext.as_bytes()); self.ephemeral.encode_into(out) } @@ -36,6 +40,7 @@ impl codec::WireParse for Kk1 { Ok(Self { header: reader.parse()?, meta: reader.parse()?, + transport_params: reader.parse()?, skem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), ephemeral: reader.parse()?, }) @@ -46,6 +51,7 @@ impl codec::WireParse for Kk1 { pub struct Kk2 { pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub transport_params: TransportParams, pub ekem_ciphertext: MlKemCiphertext, pub skem_ciphertext: EncryptedMlKemCiphertext, } @@ -53,12 +59,14 @@ pub struct Kk2 { impl Kk2 { pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::WIRE_SIZE; pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { let out = self.header.encode_into(out); let out = self.meta.encode_into(out); + let out = self.transport_params.encode_into(out); let out = codec::write_bytes(out, self.ekem_ciphertext.as_bytes()); codec::write_bytes(out, self.skem_ciphertext.as_bytes()) } @@ -69,6 +77,7 @@ impl codec::WireParse for Kk2 { Ok(Self { header: reader.parse()?, meta: reader.parse()?, + transport_params: reader.parse()?, ekem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), skem_ciphertext: EncryptedMlKemCiphertext::new(reader.take_boxed_array()?), }) @@ -94,6 +103,8 @@ pub struct KkHandshake { local_ephemeral: Option, remote_ephemeral: Option, handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, } impl KkHandshake { @@ -101,6 +112,7 @@ impl KkHandshake { crypto: &impl QlCrypto, local: QlIdentity, remote_bundle: PeerBundle, + local_transport_params: TransportParams, ) -> Self { let symmetric = init_kk_symmetric(crypto, &local.bundle(), &remote_bundle); Self { @@ -112,6 +124,8 @@ impl KkHandshake { local_ephemeral: None, remote_ephemeral: None, handshake_meta: None, + local_transport_params, + remote_transport_params: None, } } @@ -119,6 +133,7 @@ impl KkHandshake { crypto: &impl QlCrypto, local: QlIdentity, remote_bundle: PeerBundle, + local_transport_params: TransportParams, ) -> Self { let symmetric = init_kk_symmetric(crypto, &remote_bundle, &local.bundle()); Self { @@ -130,6 +145,8 @@ impl KkHandshake { local_ephemeral: None, remote_ephemeral: None, handshake_meta: None, + local_transport_params, + remote_transport_params: None, } } @@ -175,6 +192,7 @@ impl KkHandshake { header, HandshakeKind::Kk1, &meta, + &self.local_transport_params, ); let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); @@ -192,6 +210,7 @@ impl KkHandshake { Ok(Kk1 { header, meta, + transport_params: self.local_transport_params, skem_ciphertext, ephemeral: public, }) @@ -213,6 +232,7 @@ impl KkHandshake { header, HandshakeKind::Kk2, &meta, + &self.local_transport_params, ); let remote_ephemeral = self .remote_ephemeral @@ -234,6 +254,7 @@ impl KkHandshake { Ok(Kk2 { header, meta, + transport_params: self.local_transport_params, ekem_ciphertext, skem_ciphertext, }) @@ -257,6 +278,7 @@ impl KkHandshake { message.header, HandshakeKind::Kk1, &message.meta, + &message.transport_params, ); self.symmetric .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; @@ -267,6 +289,7 @@ impl KkHandshake { mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); self.remote_ephemeral = Some(message.ephemeral.clone()); + self.remote_transport_params = Some(message.transport_params); self.step = KkStep::Send2; Ok(()) } @@ -289,6 +312,7 @@ impl KkHandshake { message.header, HandshakeKind::Kk2, &message.meta, + &message.transport_params, ); let local_ephemeral = self .local_ephemeral @@ -306,6 +330,7 @@ impl KkHandshake { self.symmetric .mix_key_and_hash(crypto, skem_secret.as_bytes()); + self.remote_transport_params = Some(message.transport_params); self.step = KkStep::Done; Ok(()) } @@ -314,11 +339,13 @@ impl KkHandshake { if !self.is_finished() { return Err(WireError::InvalidState); } + let remote_transport_params = self.remote_transport_params.ok_or(WireError::InvalidState)?; Ok(finalize_handshake( crypto, &self.symmetric, self.role, self.remote_bundle, + remote_transport_params, )) } } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index f0bb7753..241ffca1 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -7,10 +7,12 @@ use crate::{ mod ik; mod kk; mod meta; +mod transport_params; pub use ik::{Ik1, Ik2, IkHandshake}; pub use kk::{Kk1, Kk2, KkHandshake}; pub use meta::{HandshakeId, HandshakeMeta}; +pub use transport_params::TransportParams; const SHA256_BLOCK_LEN: usize = 64; const PROTOCOL_IK: &[u8] = b"ql-wire:pq-ik:v1"; @@ -107,6 +109,8 @@ pub struct FinalizedHandshake { pub rx_connection_id: ConnectionId, pub handshake_hash: [u8; 32], pub remote_bundle: PeerBundle, + /// Transport parameters advertised by the remote peer + pub remote_transport_params: TransportParams, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -308,13 +312,16 @@ fn mix_hash_routed_handshake( header: HandshakeHeader, kind: HandshakeKind, meta: &HandshakeMeta, + transport_params: &TransportParams, ) { let encoded_header = header.encode(); - let encoded = meta.encode(); + let encoded_meta = meta.encode(); + let encoded_transport_params = transport_params.encode(); symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); symmetric.mix_hash(crypto, &encoded_header); symmetric.mix_hash(crypto, &[kind as u8]); - symmetric.mix_hash(crypto, &encoded); + symmetric.mix_hash(crypto, &encoded_meta); + symmetric.mix_hash(crypto, &encoded_transport_params); } fn initialize_handshake_meta( @@ -389,6 +396,7 @@ fn finalize_handshake( symmetric: &SymmetricState, role: Role, remote_bundle: PeerBundle, + remote_transport_params: TransportParams, ) -> FinalizedHandshake { let handshake_hash = symmetric.handshake_hash; let (tx_key, rx_key) = symmetric.split_for_role(crypto, role); @@ -404,6 +412,7 @@ fn finalize_handshake( rx_connection_id, handshake_hash, remote_bundle, + remote_transport_params, } } diff --git a/ql-wire/src/handshake/transport_params.rs b/ql-wire/src/handshake/transport_params.rs new file mode 100644 index 00000000..2acf9608 --- /dev/null +++ b/ql-wire/src/handshake/transport_params.rs @@ -0,0 +1,38 @@ +use crate::{codec, ByteSlice, WireError}; + +/// Session parameters advertised in the handshake +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TransportParams { + /// Initial per-stream receive credit granted to the remote peer + pub initial_stream_receive_window: u32, +} + +impl TransportParams { + pub const WIRE_SIZE: usize = size_of::(); + + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + codec::write_u32(out, self.initial_stream_receive_window) + } + + pub fn encode(&self) -> [u8; Self::WIRE_SIZE] { + let mut out = [0; Self::WIRE_SIZE]; + let _ = self.encode_into(&mut out); + out + } +} + +impl Default for TransportParams { + fn default() -> Self { + Self { + initial_stream_receive_window: 16 * 1024, + } + } +} + +impl codec::WireParse for TransportParams { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self { + initial_stream_receive_window: reader.take_u32()?, + }) + } +} diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index ecebf1f1..36568e24 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -32,7 +32,7 @@ pub use pq::*; pub use record::*; pub use xid::*; -pub const QL_WIRE_VERSION: u8 = 2; +pub const QL_WIRE_VERSION: u8 = 1; pub const ENCRYPTED_MESSAGE_AUTH_SIZE: usize = 16; #[cfg(test)] diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index b879c87f..768011c4 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -158,6 +158,12 @@ fn handshake_meta(id: u32) -> HandshakeMeta { } } +fn handshake_transport_params(window: u32) -> TransportParams { + TransportParams { + initial_stream_receive_window: window, + } +} + fn make_identity(crypto: &impl QlCrypto, byte: u8) -> QlIdentity { generate_identity(crypto, xid(byte)) } @@ -203,6 +209,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { let ik = QlHandshakeRecord::Ik1(Ik1 { header: handshake_header(1, 2), meta: handshake_meta(1), + transport_params: handshake_transport_params(65_536), skem_ciphertext: MlKemCiphertext::new(Box::new([7; MlKemCiphertext::SIZE])), ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::new(Box::new([9; MlKemPublicKey::SIZE])), @@ -225,6 +232,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { let kk = QlHandshakeRecord::Kk1(Kk1 { header: handshake_header(1, 2), meta: handshake_meta(2), + transport_params: handshake_transport_params(131_072), skem_ciphertext: MlKemCiphertext::new(Box::new([11; MlKemCiphertext::SIZE])), ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::new(Box::new([15; MlKemPublicKey::SIZE])), @@ -250,8 +258,18 @@ fn ik_handshake_rejects_tampered_handshake_meta() { let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut initiator_state = IkHandshake::new_initiator(&crypto, initiator, responder.bundle()); - let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder, + None, + TransportParams::default(), + ); let m1 = initiator_state .write_1(&crypto, handshake_meta(77)) @@ -276,8 +294,18 @@ fn kk_handshake_rejects_tampered_handshake_header() { let responder = make_identity(&crypto, 2); let mut initiator_state = - KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); - let mut responder_state = KkHandshake::new_responder(&crypto, responder, initiator.bundle()); + KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder, + initiator.bundle(), + TransportParams::default(), + ); let m1 = initiator_state .write_1(&crypto, handshake_meta(88)) @@ -295,14 +323,59 @@ fn kk_handshake_rejects_tampered_handshake_header() { ); } +#[test] +fn ik_handshake_rejects_tampered_transport_params() { + let crypto = TestCrypto::new(10_1); + let initiator = make_identity(&crypto, 1); + let responder = make_identity(&crypto, 2); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + handshake_transport_params(4096), + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder, + None, + handshake_transport_params(8192), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(89)) + .unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(89)) + .unwrap(); + m2.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + initiator_state.read_2(&crypto, 0, &m2), + Err(WireError::DecryptFailed) + ); +} + #[test] fn ik_handshake_rejects_tampered_handshake_header() { let crypto = TestCrypto::new(11); let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut initiator_state = IkHandshake::new_initiator(&crypto, initiator, responder.bundle()); - let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder, + None, + TransportParams::default(), + ); let mut m1 = initiator_state .write_1(&crypto, handshake_meta(90)) @@ -322,8 +395,18 @@ fn ik_handshake_rejects_bound_remote_bundle_mismatch() { let bogus = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut initiator_state = IkHandshake::new_initiator(&crypto, initiator, responder.bundle()); - let mut responder_state = IkHandshake::new_responder(&crypto, responder, Some(bogus.bundle())); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder, + Some(bogus.bundle()), + TransportParams::default(), + ); let m1 = initiator_state .write_1(&crypto, handshake_meta(91)) @@ -341,8 +424,18 @@ fn ik_handshake_rejects_expired_message() { let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut initiator_state = IkHandshake::new_initiator(&crypto, initiator, responder.bundle()); - let mut responder_state = IkHandshake::new_responder(&crypto, responder, None); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder, + None, + TransportParams::default(), + ); let m1 = initiator_state .write_1( @@ -366,9 +459,20 @@ fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { let initiator = make_identity(&crypto, 3); let responder = make_identity(&crypto, 4); - let mut initiator_state = - IkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); - let mut responder_state = IkHandshake::new_responder(&crypto, responder.clone(), None); + let initiator_params = handshake_transport_params(4096); + let responder_params = handshake_transport_params(8192); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder.clone(), + None, + responder_params, + ); let m1 = initiator_state .write_1(&crypto, handshake_meta(11)) @@ -399,6 +503,8 @@ fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { ); assert_eq!(initiator_final.remote_bundle, responder.bundle()); assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); } #[test] @@ -407,10 +513,20 @@ fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { let initiator = make_identity(&crypto, 3); let responder = make_identity(&crypto, 4); - let mut initiator_state = - IkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); - let mut responder_state = - IkHandshake::new_responder(&crypto, responder.clone(), Some(initiator.bundle())); + let initiator_params = handshake_transport_params(16_384); + let responder_params = handshake_transport_params(32_768); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder.clone(), + Some(initiator.bundle()), + responder_params, + ); let m1 = initiator_state .write_1(&crypto, handshake_meta(12)) @@ -441,6 +557,8 @@ fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { ); assert_eq!(initiator_final.remote_bundle, responder.bundle()); assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); } #[test] @@ -449,10 +567,20 @@ fn kk_handshake_round_trip_derives_matching_transport() { let initiator = make_identity(&crypto, 3); let responder = make_identity(&crypto, 4); - let mut initiator_state = - KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); - let mut responder_state = - KkHandshake::new_responder(&crypto, responder.clone(), initiator.bundle()); + let initiator_params = handshake_transport_params(24_576); + let responder_params = handshake_transport_params(49_152); + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder.clone(), + initiator.bundle(), + responder_params, + ); let m1 = initiator_state .write_1(&crypto, handshake_meta(21)) @@ -483,6 +611,43 @@ fn kk_handshake_round_trip_derives_matching_transport() { ); assert_eq!(initiator_final.remote_bundle, responder.bundle()); assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn kk_handshake_rejects_tampered_transport_params() { + let crypto = TestCrypto::new(31); + let initiator = make_identity(&crypto, 3); + let responder = make_identity(&crypto, 4); + + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + handshake_transport_params(12288), + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder, + initiator.bundle(), + handshake_transport_params(24576), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(22)) + .unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(22)) + .unwrap(); + m2.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + initiator_state.read_2(&crypto, 0, &m2), + Err(WireError::DecryptFailed) + ); } #[test] @@ -575,8 +740,18 @@ fn protocol_record_size_breakdown() { let responder = make_identity(&crypto, 2); let mut ik_initiator = - IkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); - let mut ik_responder = IkHandshake::new_responder(&crypto, responder.clone(), None); + IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut ik_responder = IkHandshake::new_responder( + &crypto, + responder.clone(), + None, + TransportParams::default(), + ); let ik1 = ik_initiator.write_1(&crypto, handshake_meta(101)).unwrap(); ik_responder.read_1(&crypto, 0, &ik1).unwrap(); @@ -588,8 +763,18 @@ fn protocol_record_size_breakdown() { let ik2 = QlHandshakeRecord::Ik2(ik2); let mut kk_initiator = - KkHandshake::new_initiator(&crypto, initiator.clone(), responder.bundle()); - let mut kk_responder = KkHandshake::new_responder(&crypto, responder, initiator.bundle()); + KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut kk_responder = KkHandshake::new_responder( + &crypto, + responder, + initiator.bundle(), + TransportParams::default(), + ); let kk1 = kk_initiator.write_1(&crypto, handshake_meta(201)).unwrap(); kk_responder.read_1(&crypto, 0, &kk1).unwrap(); From ae33dc1dd8d493718818cabbc139cc0339552a45 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 13:47:11 -0400 Subject: [PATCH 090/304] ql-fsm: transport params support --- ql-fsm/src/implementation/core.rs | 10 +++++ ql-fsm/src/implementation/handshake/ik.rs | 21 ++++++++-- ql-fsm/src/implementation/handshake/kk.rs | 20 ++++++++-- ql-fsm/src/implementation/handshake/mod.rs | 17 ++++++-- ql-fsm/src/lib.rs | 2 + ql-fsm/src/session/mod.rs | 12 ++++++ ql-fsm/src/session/state.rs | 8 +++- ql-fsm/src/session/tests.rs | 39 ++++++++++++++++++ ql-fsm/src/state.rs | 4 +- ql-fsm/src/tests/handshake.rs | 46 ++++++++++++++++++++++ ql-fsm/src/tests/mod.rs | 40 +++++++++++++++++-- ql-fsm/src/tests/session.rs | 36 +++++++++++++++++ 12 files changed, 238 insertions(+), 17 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 8d99149d..3776e9e5 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -210,6 +210,16 @@ pub fn reset_session(fsm: &mut QlFsm) { peer_timeout: fsm.config.session_peer_timeout, stream_send_buffer_size: fsm.config.session_stream_send_buffer_size, stream_receive_buffer_size: fsm.config.session_stream_receive_buffer_size, + initial_peer_stream_receive_window: fsm + .state + .link + .transport() + .map(|transport| { + transport + .remote_transport_params + .initial_stream_receive_window + }) + .unwrap_or(fsm.config.session_stream_receive_buffer_size as u32), }, fsm.state.now.instant, ); diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index d3e3ce35..493f9ecb 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord}; +use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord, TransportParams}; use super::{ emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, @@ -15,7 +15,14 @@ pub fn start_initiator( peer: PeerBundle, ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); - let mut handshake = wire::IkHandshake::new_initiator(crypto, fsm.identity.clone(), peer); + let mut handshake = wire::IkHandshake::new_initiator( + crypto, + fsm.identity.clone(), + peer, + TransportParams { + initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size as u32, + }, + ); let message = handshake.write_1(crypto, meta)?; fsm.state.link = LinkState::IkInitiator(IkInitiatorState { @@ -51,8 +58,14 @@ pub fn handle_ik1( reset_connected_session_if_needed(fsm); - let mut handshake = - wire::IkHandshake::new_responder(crypto, fsm.identity.clone(), fsm.state.peer.clone()); + let mut handshake = wire::IkHandshake::new_responder( + crypto, + fsm.identity.clone(), + fsm.state.peer.clone(), + TransportParams { + initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size as u32, + }, + ); handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 454f43e6..3af501b0 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord}; +use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord, TransportParams}; use super::{ emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, @@ -15,7 +15,14 @@ pub fn start_initiator( peer: PeerBundle, ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); - let mut handshake = wire::KkHandshake::new_initiator(crypto, fsm.identity.clone(), peer); + let mut handshake = wire::KkHandshake::new_initiator( + crypto, + fsm.identity.clone(), + peer, + TransportParams { + initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size as u32, + }, + ); let message = handshake.write_1(crypto, meta)?; fsm.state.link = LinkState::KkInitiator(KkInitiatorState { @@ -50,7 +57,14 @@ pub fn handle_kk1( reset_connected_session_if_needed(fsm); - let mut handshake = wire::KkHandshake::new_responder(crypto, fsm.identity.clone(), peer); + let mut handshake = wire::KkHandshake::new_responder( + crypto, + fsm.identity.clone(), + peer, + TransportParams { + initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size as u32, + }, + ); handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 49202654..b3af9682 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -85,19 +85,30 @@ pub fn finish_handshake( transport: SessionTransport, remote_bundle: &wire::PeerBundle, ) -> Result<(), QlFsmError> { - if let Some(peer) = fsm.state.peer.as_ref() { + let initial_peer_stream_receive_window = transport + .remote_transport_params + .initial_stream_receive_window; + let new_peer = if let Some(peer) = fsm.state.peer.as_ref() { if peer != remote_bundle { return Err(QlFsmError::InvalidPayload); } + false } else { fsm.state.peer = Some(remote_bundle.clone()); - reset_session(fsm); fsm.state .events .push_back(QlFsmEvent::NewPeer(remote_bundle.clone())); - } + true + }; fsm.state.link = LinkState::Connected(transport); + + if new_peer { + reset_session(fsm); + } else { + fsm.session + .set_initial_peer_stream_receive_window(initial_peer_stream_receive_window); + } emit_peer_status(fsm); Ok(()) } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 19ed58f3..870ca14f 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -177,6 +177,8 @@ impl QlFsm { peer_timeout: config.session_peer_timeout, stream_send_buffer_size: config.session_stream_send_buffer_size, stream_receive_buffer_size: config.session_stream_receive_buffer_size, + initial_peer_stream_receive_window: config + .session_stream_receive_buffer_size as u32, }, now.instant, ), diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 327d890f..157df391 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -37,6 +37,7 @@ pub struct SessionFsmConfig { pub peer_timeout: Duration, pub stream_send_buffer_size: usize, pub stream_receive_buffer_size: usize, + pub initial_peer_stream_receive_window: u32, } impl Default for SessionFsmConfig { @@ -51,6 +52,7 @@ impl Default for SessionFsmConfig { peer_timeout: Duration::from_secs(30), stream_send_buffer_size: 64 * 1024, stream_receive_buffer_size: 64 * 1024, + initial_peer_stream_receive_window: 16 * 1024, } } } @@ -132,6 +134,7 @@ impl SessionFsm { StreamState::new( StreamRole::Initiator, self.config.stream_receive_buffer_size, + self.config.initial_peer_stream_receive_window, ), ); Ok(stream_id) @@ -235,6 +238,13 @@ impl SessionFsm { Ok(()) } + pub fn set_initial_peer_stream_receive_window(&mut self, window: u32) { + self.config.initial_peer_stream_receive_window = window; + for stream in self.state.streams.values_mut() { + stream.peer_max_offset = window as u64; + } + } + pub fn receive<'a, I>( &mut self, now: Instant, @@ -700,6 +710,7 @@ impl SessionFsm { entry.insert(StreamState::new( StreamRole::Responder, self.config.stream_receive_buffer_size, + self.config.initial_peer_stream_receive_window, )) } }; @@ -799,6 +810,7 @@ impl SessionFsm { entry.insert(StreamState::new( StreamRole::Responder, self.config.stream_receive_buffer_size, + self.config.initial_peer_stream_receive_window, )); true } diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 0a373782..060a75fb 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -38,12 +38,16 @@ pub struct StreamState { } impl StreamState { - pub fn new(role: StreamRole, receive_buffer_size: usize) -> Self { + pub fn new( + role: StreamRole, + receive_buffer_size: usize, + initial_peer_stream_receive_window: u32, + ) -> Self { Self { role, tx: StreamTx::new(), pending_close: None, - peer_max_offset: receive_buffer_size as u64, + peer_max_offset: initial_peer_stream_receive_window as u64, outbound_state: OutboundState::Open, inbound_state: InboundState::Open, rx: StreamRx::new(receive_buffer_size), diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index b44174bf..c28a8094 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -302,3 +302,42 @@ fn duplicate_stream_data_is_not_redelivered() { assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); } + +#[test] +fn initial_peer_stream_receive_window_limits_first_send() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionFsmConfig { + initial_peer_stream_receive_window: 3, + ..SessionFsmConfig::default() + }, + now, + ); + let stream_id = fsm.open_stream().unwrap(); + + assert_eq!(fsm.write_stream(stream_id, b"hello").unwrap(), 5); + let (_first_seq, first) = next_outbound(&mut fsm, now).unwrap(); + assert!(matches!( + first.frames.as_slice(), + [SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" + )); + + let events = receive_events( + &mut fsm, + now + Duration::from_millis(1), + RecordSeq(9), + &SessionRecord { + frames: vec![SessionFrame::StreamWindow(ql_wire::StreamWindow { + stream_id, + maximum_offset: 5, + })], + }, + ); + assert!(events.is_empty()); + + let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(matches!( + second.frames.as_slice(), + [SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.offset == 3 && frame.bytes.as_slice() == b"lo" + )); +} diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 57655aae..0ca6a1ff 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -2,7 +2,7 @@ use std::{collections::VecDeque, time::Instant}; use ql_wire::{ ConnectionId, EphemeralPublicKey, HandshakeId, IkHandshake, KkHandshake, PeerBundle, - QlHandshakeRecord, SessionKey, + QlHandshakeRecord, SessionKey, TransportParams, }; use crate::{replay_cache::ReplayCache, FsmTime, PeerStatus, QlFsmEvent, QlSessionEvent}; @@ -24,6 +24,7 @@ pub struct SessionTransport { pub rx_key: SessionKey, pub tx_connection_id: ConnectionId, pub rx_connection_id: ConnectionId, + pub remote_transport_params: TransportParams, } impl SessionTransport { @@ -34,6 +35,7 @@ impl SessionTransport { rx_key: finalized.rx_key, tx_connection_id: finalized.tx_connection_id, rx_connection_id: finalized.rx_connection_id, + remote_transport_params: finalized.remote_transport_params, }, finalized.remote_bundle, ) diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index b6f089f1..21344418 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -35,6 +35,52 @@ fn kk_connect_round_trip_establishes_transport() { assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } +#[test] +fn ik_connect_learns_remote_initial_stream_receive_window() { + let mut harness = Harness::paired_known_with_configs( + QlFsmConfig { + session_stream_receive_buffer_size: 9, + ..QlFsmConfig::default() + }, + QlFsmConfig { + session_stream_receive_buffer_size: 3, + ..QlFsmConfig::default() + }, + ); + + harness + .a + .fsm + .connect_ik(harness.time(), &harness.a.crypto) + .unwrap(); + harness.pump(); + + assert_eq!( + harness + .a + .fsm + .state + .link + .transport() + .unwrap() + .remote_transport_params + .initial_stream_receive_window, + 3 + ); + assert_eq!( + harness + .b + .fsm + .state + .link + .transport() + .unwrap() + .remote_transport_params + .initial_stream_receive_window, + 9 + ); +} + #[test] fn connect_methods_require_bound_peer() { let time = Harness::paired_known(QlFsmConfig::default()).time(); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 8dc777b3..6fd940f6 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -11,7 +11,7 @@ use libcrux_ml_kem::mlkem1024; use ql_wire::{ self, generate_identity, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, SessionKey, - WireParse, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + TransportParams, WireParse, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; use sha2::{Digest, Sha256}; @@ -153,10 +153,23 @@ struct Harness { impl Harness { fn paired_known(config: QlFsmConfig) -> Self { - Self::paired(config, true, true) + Self::paired_with_configs(config, config, true, true) } fn paired(config: QlFsmConfig, know_a: bool, know_b: bool) -> Self { + Self::paired_with_configs(config, config, know_a, know_b) + } + + fn paired_known_with_configs(config_a: QlFsmConfig, config_b: QlFsmConfig) -> Self { + Self::paired_with_configs(config_a, config_b, true, true) + } + + fn paired_with_configs( + config_a: QlFsmConfig, + config_b: QlFsmConfig, + know_a: bool, + know_b: bool, + ) -> Self { let identity_a = test_identity(11); let identity_b = test_identity(73); let now = Instant::now(); @@ -169,11 +182,11 @@ impl Harness { now, unix_secs: time.unix_secs, a: Node { - fsm: QlFsm::new(config, identity_a.clone(), time), + fsm: QlFsm::new(config_a, identity_a.clone(), time), crypto: TestCrypto::new(1), }, b: Node { - fsm: QlFsm::new(config, identity_b.clone(), time), + fsm: QlFsm::new(config_b, identity_b.clone(), time), crypto: TestCrypto::new(2), }, }; @@ -202,12 +215,26 @@ impl Harness { rx_key: b_to_a_key.clone(), tx_connection_id: a_to_b_conn, rx_connection_id: b_to_a_conn, + remote_transport_params: TransportParams { + initial_stream_receive_window: harness + .b + .fsm + .config + .session_stream_receive_buffer_size as u32, + }, }); harness.b.fsm.state.link = LinkState::Connected(SessionTransport { tx_key: b_to_a_key, rx_key: a_to_b_key, tx_connection_id: b_to_a_conn, rx_connection_id: a_to_b_conn, + remote_transport_params: TransportParams { + initial_stream_receive_window: harness + .a + .fsm + .config + .session_stream_receive_buffer_size as u32, + }, }); harness.a.fsm.session = SessionFsm::new(session_config(&harness, true), harness.now); harness.b.fsm.session = SessionFsm::new(session_config(&harness, false), harness.now); @@ -321,6 +348,11 @@ fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { peer_timeout: config.session_peer_timeout, stream_send_buffer_size: config.session_stream_send_buffer_size, stream_receive_buffer_size: config.session_stream_receive_buffer_size, + initial_peer_stream_receive_window: if a { + harness.b.fsm.config.session_stream_receive_buffer_size as u32 + } else { + harness.a.fsm.config.session_stream_receive_buffer_size as u32 + }, } } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 998462ad..ec50e6f1 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -354,3 +354,39 @@ fn session_records_contain_ack_frames_after_delivery() { [ql_wire::SessionFrame::Ack(_)] )); } + +#[test] +fn queued_stream_work_uses_negotiated_initial_peer_credit_after_connect() { + let mut harness = Harness::paired_known_with_configs( + QlFsmConfig { + session_stream_receive_buffer_size: 8, + ..QlFsmConfig::default() + }, + QlFsmConfig { + session_stream_receive_buffer_size: 3, + ..QlFsmConfig::default() + }, + ); + + let stream_id = harness.a.fsm.open_stream().unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"hello").unwrap(), 5); + + harness + .a + .fsm + .connect_ik(harness.time(), &harness.a.crypto) + .unwrap(); + let ik1 = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(ik1); + let ik2 = harness.next_outbound_b().unwrap(); + harness.deliver_to_a(ik2); + + let data = harness.next_outbound_a().unwrap(); + let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); + let (_header, record) = decrypt_record(&harness.b.crypto, &data, &session_key); + + assert!(matches!( + record.frames.as_slice(), + [ql_wire::SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" + )); +} From 8958e20ebb933a42d20f075044ba3d0c2482b17f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 14:16:02 -0400 Subject: [PATCH 091/304] ql-fsm: get rid of queuing stream content before session is connected --- ql-fsm/src/implementation/core.rs | 141 +++++++++------------ ql-fsm/src/implementation/handshake/mod.rs | 57 ++++----- ql-fsm/src/lib.rs | 18 --- ql-fsm/src/session/mod.rs | 13 -- ql-fsm/src/session/stream_tx.rs | 9 -- ql-fsm/src/state.rs | 39 ++++-- ql-fsm/src/tests/mod.rs | 56 ++++---- ql-fsm/src/tests/session.rs | 78 ++++-------- 8 files changed, 172 insertions(+), 239 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 3776e9e5..a219e0ff 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -9,16 +9,14 @@ use ql_wire::{ }; use crate::{ - session::{stream_parity::StreamParity, SessionEvent, SessionFsmConfig}, - state::LinkState, - OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, SessionWriteId, StreamReadIter, + session::SessionEvent, state::LinkState, OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, + QlSessionEvent, SessionWriteId, StreamReadIter, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { fsm.state.handshake = None; fsm.state.link = LinkState::Idle; fsm.state.peer = Some(peer.clone()); - reset_session(fsm); fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); emit_peer_status(fsm); } @@ -36,22 +34,28 @@ pub fn receive( } wire::RecordType::Session => { let record = wire::QlSessionRecord::parse_bytes(&mut bytes[..])?; - let transport = fsm.state.link.transport().ok_or(QlFsmError::NoSession)?; - if record.header.connection_id != transport.rx_connection_id { + let state = fsm.state.link.connected_mut_or_err()?; + if record.header.connection_id != state.transport.rx_connection_id { return Err(QlFsmError::InvalidPayload); } - - let plaintext = - wire::decrypt_record(crypto, &record.header, record.payload, &transport.rx_key)?; + let plaintext = wire::decrypt_record( + crypto, + &record.header, + record.payload, + &state.transport.rx_key, + )?; let frames = wire::SessionRecord::parse(plaintext)?; + let mut session_closed = false; - fsm.session + state + .session .receive(fsm.state.now.instant, record.header.seq, frames, { let session_events = &mut fsm.state.session_events; |event| { session_closed |= forward_session_event(session_events, event); } }); + if session_closed { apply_session_closed(fsm); } @@ -62,17 +66,20 @@ pub fn receive( pub fn on_timer(fsm: &mut QlFsm) { super::handle_timer(fsm); - if fsm.state.link.transport().is_some() { - let mut session_closed = false; - fsm.session.on_timer(fsm.state.now.instant, { - let session_events = &mut fsm.state.session_events; - |event| { - session_closed |= forward_session_event(session_events, event); - } - }); - if session_closed { - apply_session_closed(fsm); + let Some(state) = fsm.state.link.connected_mut() else { + return; + }; + + let mut session_closed = false; + state.session.on_timer(fsm.state.now.instant, { + let session_events = &mut fsm.state.session_events; + |event| { + session_closed |= forward_session_event(session_events, event); } + }); + + if session_closed { + apply_session_closed(fsm); } } @@ -81,8 +88,8 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { super::next_handshake_deadline(fsm), fsm.state .link - .transport() - .and_then(|_| fsm.session.next_deadline()), + .connected() + .and_then(|state| state.session.next_deadline()), ] .into_iter() .flatten() @@ -96,18 +103,16 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Option Result { - fsm.state.ensure_peer_bound()?; - Ok(fsm.session.open_stream()?) + let state = fsm.state.link.connected_mut_or_err()?; + Ok(state.session.open_stream()?) } pub fn write_stream( @@ -149,11 +159,15 @@ pub fn write_stream( stream_id: StreamId, bytes: &[u8], ) -> Result { - Ok(fsm.session.write_stream(stream_id, bytes)?) + let state = fsm.state.link.connected_mut_or_err()?; + Ok(state.session.write_stream(stream_id, bytes)?) } pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Option> { - fsm.session.stream_read(stream_id) + fsm.state + .link + .connected() + .and_then(|state| state.session.stream_read(stream_id)) } pub fn stream_read_commit( @@ -161,15 +175,20 @@ pub fn stream_read_commit( stream_id: StreamId, len: usize, ) -> Result<(), QlFsmError> { - Ok(fsm.session.stream_read_commit(stream_id, len)?) + let state = fsm.state.link.connected_mut_or_err()?; + Ok(state.session.stream_read_commit(stream_id, len)?) } pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Option { - fsm.session.stream_available_bytes(stream_id) + fsm.state + .link + .connected() + .and_then(|state| state.session.stream_available_bytes(stream_id)) } pub fn finish_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result<(), QlFsmError> { - Ok(fsm.session.finish_stream(stream_id)?) + let state = fsm.state.link.connected_mut_or_err()?; + Ok(state.session.finish_stream(stream_id)?) } pub fn close_stream( @@ -178,12 +197,13 @@ pub fn close_stream( target: CloseTarget, code: StreamCloseCode, ) -> Result<(), QlFsmError> { - Ok(fsm.session.close_stream(stream_id, target, code)?) + let state = fsm.state.link.connected_mut_or_err()?; + Ok(state.session.close_stream(stream_id, target, code)?) } pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { - ensure_session_open(fsm)?; - Ok(fsm.session.queue_ping()?) + let state = fsm.state.link.connected_mut_or_err()?; + Ok(state.session.queue_ping()?) } pub fn emit_peer_status(fsm: &mut QlFsm) { @@ -195,36 +215,6 @@ pub fn emit_peer_status(fsm: &mut QlFsm) { } } -pub fn reset_session(fsm: &mut QlFsm) { - let local_parity = fsm.state.peer.as_ref().map_or(StreamParity::Even, |peer| { - StreamParity::for_local(fsm.identity.xid, peer.xid) - }); - fsm.session = crate::session::SessionFsm::new( - SessionFsmConfig { - local_parity, - record_target_size: fsm.config.session_record_target_size, - record_max_size: fsm.config.session_record_max_size, - ack_delay: fsm.config.session_record_ack_delay, - retransmit_timeout: fsm.config.session_record_retransmit_timeout, - keepalive_interval: fsm.config.session_keepalive_interval, - peer_timeout: fsm.config.session_peer_timeout, - stream_send_buffer_size: fsm.config.session_stream_send_buffer_size, - stream_receive_buffer_size: fsm.config.session_stream_receive_buffer_size, - initial_peer_stream_receive_window: fsm - .state - .link - .transport() - .map(|transport| { - transport - .remote_transport_params - .initial_stream_receive_window - }) - .unwrap_or(fsm.config.session_stream_receive_buffer_size as u32), - }, - fsm.state.now.instant, - ); -} - fn forward_session_event( session_events: &mut VecDeque, event: SessionEvent, @@ -266,15 +256,6 @@ fn apply_session_closed(fsm: &mut QlFsm) { fsm.state.link = crate::state::LinkState::Idle; emit_peer_status(fsm); } - reset_session(fsm); -} - -fn ensure_session_open(fsm: &QlFsm) -> Result<(), QlFsmError> { - fsm.state.ensure_peer_bound()?; - if fsm.state.link.transport().is_none() { - return Err(QlFsmError::SessionClosed); - } - Ok(()) } pub(super) fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index b3af9682..4c42ef9b 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -3,10 +3,11 @@ mod kk; use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHandshakeRecord}; -use super::{emit_peer_status, reset_session}; +use super::emit_peer_status; use crate::{ - state::{LinkState, SessionTransport}, - QlFsm, QlFsmError, QlFsmEvent, QlSessionEvent, + session::{stream_parity::StreamParity, SessionFsm, SessionFsmConfig}, + state::{ConnectedState, LinkState, SessionTransport}, + QlFsm, QlFsmError, QlFsmEvent, }; pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { @@ -72,7 +73,6 @@ pub fn handle_timer(fsm: &mut QlFsm) { fsm.state.link = LinkState::Idle; fsm.state.handshake = None; - fail_pending_connect_session(fsm, ql_wire::SessionCloseCode::TIMEOUT); emit_peer_status(fsm); } @@ -85,30 +85,36 @@ pub fn finish_handshake( transport: SessionTransport, remote_bundle: &wire::PeerBundle, ) -> Result<(), QlFsmError> { - let initial_peer_stream_receive_window = transport - .remote_transport_params - .initial_stream_receive_window; - let new_peer = if let Some(peer) = fsm.state.peer.as_ref() { + if let Some(peer) = fsm.state.peer.as_ref() { if peer != remote_bundle { return Err(QlFsmError::InvalidPayload); } - false } else { fsm.state.peer = Some(remote_bundle.clone()); fsm.state .events .push_back(QlFsmEvent::NewPeer(remote_bundle.clone())); - true - }; - - fsm.state.link = LinkState::Connected(transport); - - if new_peer { - reset_session(fsm); - } else { - fsm.session - .set_initial_peer_stream_receive_window(initial_peer_stream_receive_window); } + + let config = &fsm.config; + let session = SessionFsm::new( + SessionFsmConfig { + local_parity: StreamParity::for_local(fsm.identity.xid, remote_bundle.xid), + record_target_size: config.session_record_target_size, + record_max_size: config.session_record_max_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, + initial_peer_stream_receive_window: transport + .remote_transport_params + .initial_stream_receive_window, + }, + fsm.state.now.instant, + ); + fsm.state.link = LinkState::Connected(ConnectedState { transport, session }); emit_peer_status(fsm); Ok(()) } @@ -116,22 +122,9 @@ pub fn finish_handshake( pub fn reset_connected_session_if_needed(fsm: &mut QlFsm) { if matches!(fsm.state.link, LinkState::Connected(_)) { fsm.state.link = LinkState::Idle; - reset_session(fsm); } } -fn fail_pending_connect_session(fsm: &mut QlFsm, code: ql_wire::SessionCloseCode) { - if !fsm.session.has_pending_stream_work() { - return; - } - reset_session(fsm); - fsm.state - .session_events - .push_back(QlSessionEvent::SessionClosed(ql_wire::SessionClose { - code, - })); -} - fn local_start_wins(local: &EphemeralPublicKey, inbound: &EphemeralPublicKey) -> bool { local.mlkem_public_key.as_bytes() <= inbound.mlkem_public_key.as_bytes() } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 870ca14f..eedfd005 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -36,7 +36,6 @@ pub use session::stream_rx::StreamReadIter; use crate::{ replay_cache::ReplayCache, - session::SessionFsm, state::{LinkState, QlFsmState}, }; @@ -156,7 +155,6 @@ pub struct QlFsm { pub config: QlFsmConfig, /// local identity and private keys pub identity: QlIdentity, - pub(crate) session: SessionFsm, pub(crate) state: QlFsmState, } @@ -166,22 +164,6 @@ impl QlFsm { Self { config, identity, - session: session::SessionFsm::new( - session::SessionFsmConfig { - local_parity: session::stream_parity::StreamParity::Even, - record_target_size: config.session_record_target_size, - record_max_size: config.session_record_max_size, - ack_delay: config.session_record_ack_delay, - retransmit_timeout: config.session_record_retransmit_timeout, - keepalive_interval: config.session_keepalive_interval, - peer_timeout: config.session_peer_timeout, - stream_send_buffer_size: config.session_stream_send_buffer_size, - stream_receive_buffer_size: config.session_stream_receive_buffer_size, - initial_peer_stream_receive_window: config - .session_stream_receive_buffer_size as u32, - }, - now.instant, - ), state: QlFsmState { replay_cache: ReplayCache::default(), next_control_id: 1, diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 157df391..f2eb84ac 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -238,13 +238,6 @@ impl SessionFsm { Ok(()) } - pub fn set_initial_peer_stream_receive_window(&mut self, window: u32) { - self.config.initial_peer_stream_receive_window = window; - for stream in self.state.streams.values_mut() { - stream.peer_max_offset = window as u64; - } - } - pub fn receive<'a, I>( &mut self, now: Instant, @@ -409,12 +402,6 @@ impl SessionFsm { .min() } - pub fn has_pending_stream_work(&self) -> bool { - self.state.streams.values().any(|stream| { - stream.pending_close.is_some() || stream.pending_window || stream.tx.has_pending() - }) - } - pub fn take_next_write( &mut self, now: Instant, diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index a5a5e417..8e239494 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -62,15 +62,6 @@ impl StreamTx { self.base_offset + self.bytes.len() as u64 } - pub fn has_pending(&self) -> bool { - self.segments - .iter() - .any(|segment| matches!(segment.state, SendState::Unsent | SendState::Lost)) - || self.final_offset.is_some_and(|final_offset| { - matches!(final_offset.state, SendState::Unsent | SendState::Lost) - }) - } - pub fn is_empty(&self) -> bool { self.bytes.is_empty() && self.segments.is_empty() && self.final_offset.is_none() } diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 0ca6a1ff..115e42d5 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -5,7 +5,10 @@ use ql_wire::{ QlHandshakeRecord, SessionKey, TransportParams, }; -use crate::{replay_cache::ReplayCache, FsmTime, PeerStatus, QlFsmEvent, QlSessionEvent}; +use crate::{ + replay_cache::ReplayCache, session::SessionFsm, FsmTime, PeerStatus, QlFsmError, QlFsmEvent, + QlSessionEvent, +}; pub struct QlFsmState { pub replay_cache: ReplayCache, @@ -42,12 +45,16 @@ impl SessionTransport { } } -#[derive(Debug, Clone)] pub enum LinkState { Idle, IkInitiator(IkInitiatorState), KkInitiator(KkInitiatorState), - Connected(SessionTransport), + Connected(ConnectedState), +} + +pub struct ConnectedState { + pub transport: SessionTransport, + pub session: SessionFsm, } #[derive(Debug, Clone)] @@ -79,13 +86,27 @@ impl LinkState { } } - pub fn transport(&self) -> Option<&SessionTransport> { + #[inline] + pub fn connected(&self) -> Option<&ConnectedState> { + match self { + Self::Connected(state) => Some(state), + _ => None, + } + } + + #[inline] + pub fn connected_mut(&mut self) -> Option<&mut ConnectedState> { match self { - Self::Connected(transport) => Some(transport), + Self::Connected(state) => Some(state), _ => None, } } + #[inline] + pub fn connected_mut_or_err(&mut self) -> Result<&mut ConnectedState, QlFsmError> { + self.connected_mut().ok_or(QlFsmError::NoSession) + } + pub fn handshake_deadline(&self) -> Option { match self { Self::Idle | Self::Connected(_) => None, @@ -93,11 +114,9 @@ impl LinkState { Self::KkInitiator(state) => Some(state.deadline), } } -} -impl QlFsmState { - pub fn ensure_peer_bound(&self) -> Result<(), crate::QlFsmError> { - self.peer.as_ref().ok_or(crate::QlFsmError::NoPeerBound)?; - Ok(()) + #[cfg(test)] + pub fn transport(&self) -> Option<&SessionTransport> { + self.connected().map(|state| &state.transport) } } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 6fd940f6..a7ef7660 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -17,7 +17,7 @@ use sha2::{Digest, Sha256}; use crate::{ session::{stream_parity::StreamParity, SessionFsm, SessionFsmConfig}, - state::{LinkState, SessionTransport}, + state::{ConnectedState, LinkState, SessionTransport}, FsmTime, OutboundWrite, QlFsm, QlFsmConfig, SessionWriteId, }; @@ -210,34 +210,40 @@ impl Harness { let a_to_b_conn = ConnectionId::from_data([0xA1; ConnectionId::SIZE]); let b_to_a_conn = ConnectionId::from_data([0xB2; ConnectionId::SIZE]); - harness.a.fsm.state.link = LinkState::Connected(SessionTransport { - tx_key: a_to_b_key.clone(), - rx_key: b_to_a_key.clone(), - tx_connection_id: a_to_b_conn, - rx_connection_id: b_to_a_conn, - remote_transport_params: TransportParams { - initial_stream_receive_window: harness - .b - .fsm - .config - .session_stream_receive_buffer_size as u32, + harness.a.fsm.state.link = LinkState::Connected(ConnectedState { + transport: SessionTransport { + tx_key: a_to_b_key.clone(), + rx_key: b_to_a_key.clone(), + tx_connection_id: a_to_b_conn, + rx_connection_id: b_to_a_conn, + remote_transport_params: TransportParams { + initial_stream_receive_window: harness + .b + .fsm + .config + .session_stream_receive_buffer_size + as u32, + }, }, + session: SessionFsm::new(session_config(&harness, true), harness.now), }); - harness.b.fsm.state.link = LinkState::Connected(SessionTransport { - tx_key: b_to_a_key, - rx_key: a_to_b_key, - tx_connection_id: b_to_a_conn, - rx_connection_id: a_to_b_conn, - remote_transport_params: TransportParams { - initial_stream_receive_window: harness - .a - .fsm - .config - .session_stream_receive_buffer_size as u32, + harness.b.fsm.state.link = LinkState::Connected(ConnectedState { + transport: SessionTransport { + tx_key: b_to_a_key, + rx_key: a_to_b_key, + tx_connection_id: b_to_a_conn, + rx_connection_id: a_to_b_conn, + remote_transport_params: TransportParams { + initial_stream_receive_window: harness + .a + .fsm + .config + .session_stream_receive_buffer_size + as u32, + }, }, + session: SessionFsm::new(session_config(&harness, false), harness.now), }); - harness.a.fsm.session = SessionFsm::new(session_config(&harness, true), harness.now); - harness.b.fsm.session = SessionFsm::new(session_config(&harness, false), harness.now); harness } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index ec50e6f1..7695db03 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ql_wire::{SessionClose, StreamId}; use super::*; -use crate::{state::LinkState, QlFsmEvent, QlSessionEvent}; +use crate::{state::LinkState, QlFsmError, QlFsmEvent, QlSessionEvent}; fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); @@ -151,67 +151,41 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { } #[test] -fn queued_stream_work_waits_for_explicit_connect_and_then_drains() { +fn disconnected_stream_operations_fail_with_no_session() { let mut harness = Harness::paired_known(QlFsmConfig::default()); + let missing = StreamId(0); - let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); - harness.a.fsm.finish_stream(stream_id).unwrap(); - - assert!(harness.next_outbound_a().is_none()); - - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); - harness.pump(); - + assert_eq!(harness.a.fsm.open_stream(), Err(QlFsmError::NoSession)); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Opened(stream_id)) + harness.a.fsm.write_stream(missing, b"queued"), + Err(QlFsmError::NoSession) ); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Readable(stream_id)) + harness.a.fsm.finish_stream(missing), + Err(QlFsmError::NoSession) ); assert_eq!( - read_stream_all(&mut harness.b.fsm, stream_id), - b"queued".to_vec() + harness.a.fsm.close_stream( + missing, + ql_wire::CloseTarget::Both, + ql_wire::StreamCloseCode(0) + ), + Err(QlFsmError::NoSession) ); + assert_eq!(harness.a.fsm.queue_ping(), Err(QlFsmError::NoSession)); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Finished(stream_id)) + harness.a.fsm.stream_read_commit(missing, 1), + Err(QlFsmError::NoSession) ); } #[test] -fn queued_stream_work_is_failed_when_handshake_times_out() { - let config = QlFsmConfig { - handshake_timeout: Duration::from_millis(50), - ..QlFsmConfig::default() - }; - let mut harness = Harness::paired_known(config); +fn disconnected_stream_read_accessors_return_none() { + let harness = Harness::paired_known(QlFsmConfig::default()); + let missing = StreamId(0); - let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"queued").unwrap(), 6); - - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); - let _first = harness.next_outbound_a().unwrap(); - harness.advance(config.handshake_timeout); - harness.a.fsm.on_timer(harness.time()); - - assert_eq!( - harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::SessionClosed(SessionClose { - code: ql_wire::SessionCloseCode::TIMEOUT - })) - ); - assert!(harness.next_outbound_a().is_none()); + assert!(harness.a.fsm.stream_read(missing).is_none()); + assert!(harness.a.fsm.stream_available_bytes(missing).is_none()); } #[test] @@ -356,7 +330,7 @@ fn session_records_contain_ack_frames_after_delivery() { } #[test] -fn queued_stream_work_uses_negotiated_initial_peer_credit_after_connect() { +fn first_stream_data_uses_negotiated_initial_peer_credit() { let mut harness = Harness::paired_known_with_configs( QlFsmConfig { session_stream_receive_buffer_size: 8, @@ -368,9 +342,6 @@ fn queued_stream_work_uses_negotiated_initial_peer_credit_after_connect() { }, ); - let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"hello").unwrap(), 5); - harness .a .fsm @@ -381,6 +352,9 @@ fn queued_stream_work_uses_negotiated_initial_peer_credit_after_connect() { let ik2 = harness.next_outbound_b().unwrap(); harness.deliver_to_a(ik2); + let stream_id = harness.a.fsm.open_stream().unwrap(); + assert_eq!(harness.a.fsm.write_stream(stream_id, b"hello").unwrap(), 5); + let data = harness.next_outbound_a().unwrap(); let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); let (_header, record) = decrypt_record(&harness.b.crypto, &data, &session_key); From c1682915ddd56c109e7734c1c0ff0de605d41dd6 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 15:34:43 -0400 Subject: [PATCH 092/304] ql-fsm: better rx bitmap --- ql-fsm/src/session/mod.rs | 37 +++++++++++------- ql-fsm/src/session/received_records.rs | 53 ++++++++++---------------- ql-fsm/src/session/tests.rs | 42 ++++++++++++++++++++ 3 files changed, 87 insertions(+), 45 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index f2eb84ac..9840b887 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -252,16 +252,23 @@ impl SessionFsm { self.state.last_activity_at = self.state.now; self.state.last_inbound_at = self.state.now; - // TODO: We record the session seq before validating its frames. If later frame - // handling fails with PROTOCOL, a subsequent outbound close can still carry an ack for - // this seq even though its stream data was rejected. - let (duplicate, out_of_order) = match self.state.received_records.insert(seq) { - ReceiveOutcome::Duplicate => (true, false), - ReceiveOutcome::New { out_of_order } => (false, out_of_order), + if self.state.session_state == SessionState::Closed { + return; + } + + let mut received_records = self.state.received_records.clone(); + let out_of_order = match received_records.insert(seq) { + ReceiveOutcome::TooOld => return, + ReceiveOutcome::Duplicate => { + self.schedule_ack(true); + return; + } + ReceiveOutcome::New { out_of_order } => out_of_order, }; - let closed = self.state.session_state == SessionState::Closed; let mut ack_eliciting = false; + let mut handled_close = false; + for frame in frames { let Ok(frame) = frame else { self.fail_session( @@ -273,10 +280,6 @@ impl SessionFsm { return; }; ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); - if duplicate || closed { - continue; - } - match frame { SessionFrame::Ping => {} SessionFrame::Ack(ack) => self.process_record_ack(&ack, &mut emit), @@ -297,13 +300,21 @@ impl SessionFsm { } SessionFrame::Close(close) => { self.handle_session_close(close, &mut emit); - return; + handled_close = true; + break; } } } + // commit after processing + self.state.received_records = received_records; + + if handled_close { + return; + } + if ack_eliciting { - self.schedule_ack(duplicate || closed || out_of_order); + self.schedule_ack(out_of_order); } } diff --git a/ql-fsm/src/session/received_records.rs b/ql-fsm/src/session/received_records.rs index 03bd08fa..1a8eeb9c 100644 --- a/ql-fsm/src/session/received_records.rs +++ b/ql-fsm/src/session/received_records.rs @@ -1,16 +1,16 @@ use ql_wire::{RecordAck, RecordSeq}; -#[derive(Debug, Default)] +#[derive(Debug, Clone, Default)] pub struct ReceivedRecords { seen: u64, base: u64, - largest: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ReceiveOutcome { New { out_of_order: bool }, Duplicate, + TooOld, } impl ReceivedRecords { @@ -19,33 +19,32 @@ impl ReceivedRecords { pub fn insert(&mut self, seq: RecordSeq) -> ReceiveOutcome { let seq = seq.0; - let Some(largest) = self.largest else { + if self.seen == 0 { self.base = seq; self.seen = 1; - self.largest = Some(seq); return ReceiveOutcome::New { out_of_order: false, }; - }; - - if largest.saturating_sub(seq) > Self::TRACKED_WINDOW { - return ReceiveOutcome::Duplicate; } - let out_of_order = seq != largest.saturating_add(1); - if seq > largest { - self.advance_base(seq.saturating_sub(Self::TRACKED_WINDOW)); - self.largest = Some(seq); + if seq < self.base { + return ReceiveOutcome::TooOld; } - let Some(bit) = self.bit_for(seq) else { - return ReceiveOutcome::Duplicate; - }; - if self.seen & bit != 0 { + let base = self.base.max(seq.saturating_sub(Self::TRACKED_WINDOW)); + let seen = self.rebased_seen(base); + let next_seen = seen | (1u64 << (seq - base)); + if next_seen == seen { return ReceiveOutcome::Duplicate; } - self.seen |= bit; + let out_of_order = seq + != self + .base + .saturating_add((u64::BITS - 1 - self.seen.leading_zeros()) as u64) + .saturating_add(1); + self.base = base; + self.seen = next_seen; ReceiveOutcome::New { out_of_order } } @@ -56,27 +55,17 @@ impl ReceivedRecords { }) } - fn bit_for(&self, seq: u64) -> Option { - if seq < self.base { - return None; - } - - let offset = seq - self.base; - (offset < Self::TRACKED_LEN).then_some(1u64 << offset) - } - - fn advance_base(&mut self, new_base: u64) { + fn rebased_seen(&self, new_base: u64) -> u64 { if new_base <= self.base { - return; + return self.seen; } let shift = new_base - self.base; if shift >= Self::TRACKED_LEN { - self.seen = 0; + 0 } else { - self.seen >>= shift; + self.seen >> shift } - self.base = new_base; } } @@ -129,7 +118,7 @@ mod tests { received.insert(RecordSeq(300)), ReceiveOutcome::New { out_of_order: true } ); - assert_eq!(received.insert(RecordSeq(0)), ReceiveOutcome::Duplicate); + assert_eq!(received.insert(RecordSeq(0)), ReceiveOutcome::TooOld); let ack = received.ack().unwrap(); assert_eq!( diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index c28a8094..921e1284 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -303,6 +303,48 @@ fn duplicate_stream_data_is_not_redelivered() { assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); } +#[test] +fn close_does_not_ack_rejected_record_seq() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionFsmConfig { + ack_delay: Duration::ZERO, + ..SessionFsmConfig::default() + }, + now, + ); + + let invalid = SessionRecord { + frames: vec![SessionFrame::StreamData(StreamData { + stream_id: StreamId(0), + offset: 0, + fin: false, + bytes: b"bad".to_vec(), + })], + }; + let events = receive_events(&mut fsm, now, RecordSeq(7), &invalid); + assert_eq!( + events, + vec![SessionEvent::SessionClosed(ql_wire::SessionClose { + code: ql_wire::SessionCloseCode::PROTOCOL, + })] + ); + + let valid_after_close = SessionRecord { + frames: vec![SessionFrame::Ping], + }; + let events = receive_events( + &mut fsm, + now + Duration::from_millis(1), + RecordSeq(8), + &valid_after_close, + ); + assert!(events.is_empty()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(matches!(outbound.frames.as_slice(), [SessionFrame::Close(_)])); +} + #[test] fn initial_peer_stream_receive_window_limits_first_send() { let now = Instant::now(); From 98db734ba08822002f4f73266a937cce662c8faf Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 16:16:55 -0400 Subject: [PATCH 093/304] ql-fsm: better lazy ack --- ql-fsm/src/session/mod.rs | 59 +++++++++++++++---------------------- ql-fsm/src/session/state.rs | 5 ++-- ql-fsm/src/session/tests.rs | 13 +++++--- 3 files changed, 35 insertions(+), 42 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 9840b887..e31725c9 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -344,7 +344,6 @@ impl SessionFsm { }; restore_tracked_record( self.state.now, - self.config.ack_delay, &mut self.state.ack_state, &mut self.state.pending_control, &mut self.state.streams, @@ -355,11 +354,6 @@ impl SessionFsm { pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { self.state.now = now; self.collect_timeouts(); - if let AckState::Delayed { due_at } = self.state.ack_state { - if due_at <= self.state.now { - self.state.ack_state = AckState::Immediate; - } - } if !self.config.peer_timeout.is_zero() && self.state.last_inbound_at + self.config.peer_timeout <= self.state.now { @@ -382,8 +376,7 @@ impl SessionFsm { pub fn next_deadline(&self) -> Option { let ack_deadline = match self.state.ack_state { AckState::Idle => None, - AckState::Immediate => Some(self.state.now), - AckState::Delayed { due_at } => Some(due_at), + AckState::Dirty { due_at } => Some(due_at), }; let retransmit_deadline = self .state @@ -441,15 +434,6 @@ impl SessionFsm { sent_at: None, }; - if self.should_send_ack() { - if let Some(ack) = self.state.received_records.ack() { - if builder.push_ack(&ack) { - outbound.ack_included = true; - self.state.ack_state = AckState::Idle; - } - } - } - if let Some(close) = self.state.pending_control.close.clone() { if builder.push_close(&close) { self.state.pending_control.close = None; @@ -468,6 +452,13 @@ impl SessionFsm { while self.push_next_stream_data(&mut builder, &mut outbound) {} + if let Some((ack, due_at)) = self.pending_ack() { + if (!builder.is_empty() || due_at <= self.state.now) && builder.push_ack(&ack) { + outbound.ack_included = true; + self.state.ack_state = AckState::Idle; + } + } + if builder.is_empty() { return None; } @@ -651,20 +642,20 @@ impl SessionFsm { fn schedule_ack(&mut self, immediate: bool) { schedule_ack( &mut self.state.ack_state, - self.state.now, - self.config.ack_delay, - immediate, + if immediate { + self.state.now + } else { + self.state.now + self.config.ack_delay + }, ); } - fn should_send_ack(&self) -> bool { - if self.state.received_records.ack().is_none() { - return false; - } + fn pending_ack(&self) -> Option<(RecordAck, Instant)> { match self.state.ack_state { - AckState::Immediate => true, - AckState::Delayed { due_at } => due_at <= self.state.now, - AckState::Idle => false, + AckState::Idle => None, + AckState::Dirty { due_at } => { + self.state.received_records.ack().map(|ack| (ack, due_at)) + } } } @@ -677,7 +668,6 @@ impl SessionFsm { }) { restore_tracked_record( self.state.now, - self.config.ack_delay, &mut self.state.ack_state, &mut self.state.pending_control, &mut self.state.streams, @@ -960,27 +950,24 @@ impl SessionFsm { } } -fn schedule_ack(ack_state: &mut AckState, now: Instant, ack_delay: Duration, immediate: bool) { +fn schedule_ack(ack_state: &mut AckState, due_at: Instant) { *ack_state = match *ack_state { - AckState::Immediate => AckState::Immediate, - _ if immediate || ack_delay.is_zero() => AckState::Immediate, - AckState::Delayed { due_at } => AckState::Delayed { due_at }, - AckState::Idle => AckState::Delayed { - due_at: now + ack_delay, + AckState::Dirty { due_at: old } => AckState::Dirty { + due_at: due_at.min(old), }, + AckState::Idle => AckState::Dirty { due_at: due_at }, }; } fn restore_tracked_record( now: Instant, - ack_delay: Duration, ack_state: &mut AckState, pending_control: &mut state::PendingSessionControl, streams: &mut IndexMap, record: TrackedRecord, ) { if record.ack_included { - schedule_ack(ack_state, now, ack_delay, true); + schedule_ack(ack_state, now); } if record.ping_included { pending_control.ping = true; diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 060a75fb..70e9c0a1 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -129,7 +129,8 @@ pub struct PendingSessionControl { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AckState { + // ack state is not dirty Idle, - Delayed { due_at: Instant }, - Immediate, + // ack is dirty. we can wait to piggy back on an outgoing record until this time + Dirty { due_at: Instant }, } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 921e1284..34b7eb63 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -378,8 +378,13 @@ fn initial_peer_stream_receive_window_limits_first_send() { assert!(events.is_empty()); let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); - assert!(matches!( - second.frames.as_slice(), - [SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.offset == 3 && frame.bytes.as_slice() == b"lo" - )); + assert!(second.frames.iter().any(|frame| { + matches!( + frame, + SessionFrame::StreamData(frame) + if frame.stream_id == stream_id + && frame.offset == 3 + && frame.bytes.as_slice() == b"lo" + ) + })); } From 14fcf78f20b8b3418d226c7f919a49f337377018 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 16:29:53 -0400 Subject: [PATCH 094/304] ql: more lazy session record builder --- ql-fsm/src/implementation/handshake/mod.rs | 1 - ql-fsm/src/lib.rs | 3 -- ql-fsm/src/session/mod.rs | 17 +------- ql-fsm/src/session/tests.rs | 42 ++++++++++++++++---- ql-fsm/src/tests/mod.rs | 1 - ql-wire/src/encrypted/builder.rs | 46 ++++++++++------------ ql-wire/src/tests.rs | 2 +- 7 files changed, 59 insertions(+), 53 deletions(-) diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 4c42ef9b..1faec712 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -100,7 +100,6 @@ pub fn finish_handshake( let session = SessionFsm::new( SessionFsmConfig { local_parity: StreamParity::for_local(fsm.identity.xid, remote_bundle.xid), - record_target_size: config.session_record_target_size, record_max_size: config.session_record_max_size, ack_delay: config.session_record_ack_delay, retransmit_timeout: config.session_record_retransmit_timeout, diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index eedfd005..c6593407 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -122,8 +122,6 @@ pub struct QlFsmConfig { pub session_keepalive_interval: Duration, /// how long to wait before declaring the peer dead pub session_peer_timeout: Duration, - /// target total wire size for one session record, including header and auth tag - pub session_record_target_size: usize, /// maximum total wire size for one session record, including header and auth tag pub session_record_max_size: usize, /// maximum bytes buffered locally for one stream send side @@ -141,7 +139,6 @@ impl Default for QlFsmConfig { session_record_retransmit_timeout: s.retransmit_timeout, session_keepalive_interval: s.keepalive_interval, session_peer_timeout: s.peer_timeout, - session_record_target_size: s.record_target_size, session_record_max_size: s.record_max_size, session_stream_send_buffer_size: s.stream_send_buffer_size, session_stream_receive_buffer_size: s.stream_receive_buffer_size, diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index e31725c9..e320ee4c 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -29,7 +29,6 @@ use self::{ #[derive(Debug, Clone, Copy)] pub struct SessionFsmConfig { pub local_parity: StreamParity, - pub record_target_size: usize, pub record_max_size: usize, pub ack_delay: Duration, pub retransmit_timeout: Duration, @@ -44,7 +43,6 @@ impl Default for SessionFsmConfig { fn default() -> Self { Self { local_parity: StreamParity::Even, - record_target_size: 4 * 1024, record_max_size: 16 * 1024, ack_delay: Duration::from_millis(5), retransmit_timeout: Duration::from_millis(150), @@ -93,13 +91,9 @@ pub struct SessionFsm { impl SessionFsm { pub fn new(mut config: SessionFsmConfig, now: Instant) -> Self { - config.record_target_size = config - .record_target_size - .max(SessionRecordBuilder::WIRE_PREFIX_LEN); config.record_max_size = config .record_max_size .max(SessionRecordBuilder::WIRE_PREFIX_LEN); - config.record_target_size = config.record_target_size.min(config.record_max_size); config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); Self { @@ -423,8 +417,7 @@ impl SessionFsm { fn build_next_record(&mut self) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; - let mut builder = - SessionRecordBuilder::new(self.config.record_max_size, self.config.record_target_size); + let mut builder = SessionRecordBuilder::new(self.config.record_max_size); let mut outbound = TrackedRecord { seq, frames: Vec::new(), @@ -602,14 +595,8 @@ impl SessionFsm { fn max_stream_data_payload(&self, builder: &SessionRecordBuilder) -> Option { let overhead = 1 + std::mem::size_of::() + StreamData::>::MIN_WIRE_SIZE; let remaining = builder.remaining_capacity(); - if remaining > overhead { + if remaining >= overhead { Some(remaining - overhead) - } else if builder.is_empty() { - Some( - self.config - .record_max_size - .saturating_sub(SessionRecordBuilder::WIRE_PREFIX_LEN), - ) } else { None } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 34b7eb63..0d2856f0 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -36,14 +36,12 @@ fn receive_events( seq: RecordSeq, record: &SessionRecord, ) -> Vec { - let mut builder = SessionRecordBuilder::new( - SessionRecordBuilder::WIRE_PREFIX_LEN + record.wire_size(), - SessionRecordBuilder::WIRE_PREFIX_LEN + record.wire_size(), - ); + let mut builder = + SessionRecordBuilder::new(SessionRecordBuilder::WIRE_PREFIX_LEN + record.wire_size()); for frame in &record.frames { assert!(builder.push_frame(frame)); } - let bytes = builder.into_plaintext(); + let bytes = builder.bytes().to_vec(); let frames = SessionRecord::parse(&bytes).unwrap(); let mut events = Vec::new(); fsm.receive(now, seq, frames, |event| events.push(event)); @@ -87,7 +85,6 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let now = Instant::now(); let mut fsm = SessionFsm::new( SessionFsmConfig { - record_target_size: 80 + SessionRecordBuilder::WIRE_PREFIX_LEN, record_max_size: 80 + SessionRecordBuilder::WIRE_PREFIX_LEN, ..SessionFsmConfig::default() }, @@ -122,6 +119,34 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { assert_eq!(stream_ids, vec![stream_id_b]); } +#[test] +fn fin_only_stream_data_fits_exact_record_limit() { + let now = Instant::now(); + let stream_data_overhead = + 1 + std::mem::size_of::() + StreamData::>::MIN_WIRE_SIZE; + let mut fsm = SessionFsm::new( + SessionFsmConfig { + record_max_size: SessionRecordBuilder::WIRE_PREFIX_LEN + stream_data_overhead, + ..SessionFsmConfig::default() + }, + now, + ); + let stream_id = fsm.open_stream().unwrap(); + + fsm.finish_stream(stream_id).unwrap(); + + let (_seq, record) = next_outbound(&mut fsm, now).unwrap(); + assert_eq!(record.frames.len(), 1); + match &record.frames[0] { + SessionFrame::StreamData(frame) => { + assert_eq!(frame.stream_id, stream_id); + assert!(frame.fin); + assert!(frame.bytes.is_empty()); + } + frame => panic!("expected stream data frame, got {frame:?}"), + } +} + #[test] fn ack_reopens_write_capacity() { let now = Instant::now(); @@ -342,7 +367,10 @@ fn close_does_not_ack_rejected_record_seq() { assert!(events.is_empty()); let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); - assert!(matches!(outbound.frames.as_slice(), [SessionFrame::Close(_)])); + assert!(matches!( + outbound.frames.as_slice(), + [SessionFrame::Close(_)] + )); } #[test] diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index a7ef7660..d5c317c4 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -346,7 +346,6 @@ fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { SessionFsmConfig { local_parity: StreamParity::for_local(local, peer), - record_target_size: config.session_record_target_size, record_max_size: config.session_record_max_size, ack_delay: config.session_record_ack_delay, retransmit_timeout: config.session_record_retransmit_timeout, diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 4b45b601..090fa6f7 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -4,8 +4,6 @@ use crate::{ByteChunks, Nonce, QlCrypto, RecordType, SessionHeader, SessionKey, #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecordBuilder { max_capacity: usize, - // todo: remove - body_start: usize, bytes: Vec, } @@ -13,17 +11,11 @@ impl SessionRecordBuilder { pub const WIRE_PREFIX_LEN: usize = 1 + 1 + SessionHeader::WIRE_SIZE + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; - pub fn new(max_capacity: usize, initial_capacity: usize) -> Self { - assert!(initial_capacity <= max_capacity); + pub fn new(max_capacity: usize) -> Self { assert!(max_capacity >= Self::WIRE_PREFIX_LEN); - - let body_start = Self::WIRE_PREFIX_LEN; - let mut bytes = Vec::with_capacity(initial_capacity); - bytes.resize(body_start, 0); Self { max_capacity, - body_start, - bytes, + bytes: Vec::new(), } } @@ -32,7 +24,7 @@ impl SessionRecordBuilder { } pub fn len(&self) -> usize { - self.bytes.len().saturating_sub(self.body_start) + self.bytes.len().saturating_sub(Self::WIRE_PREFIX_LEN) } pub fn is_empty(&self) -> bool { @@ -41,21 +33,11 @@ impl SessionRecordBuilder { pub fn remaining_capacity(&self) -> usize { self.max_capacity - .saturating_sub(self.body_start) - .saturating_sub(self.len()) + .saturating_sub(self.bytes.len().max(Self::WIRE_PREFIX_LEN)) } pub fn bytes(&self) -> &[u8] { - &self.bytes[self.body_start..] - } - - pub fn into_plaintext(self) -> Vec { - let mut bytes = self.bytes; - bytes.split_off(self.body_start) - } - - pub fn can_push_len(&self, len: usize) -> bool { - len <= self.remaining_capacity() || self.is_empty() + self.bytes.get(Self::WIRE_PREFIX_LEN..).unwrap_or_default() } pub fn push_ping(&mut self) -> bool { @@ -121,16 +103,17 @@ impl SessionRecordBuilder { header: SessionHeader, session_key: &SessionKey, ) -> Vec { + self.ensure_prefix_capacity(0); let aad = header.aad(); let nonce = Nonce::from_counter(header.seq.0); let auth = crypto.aes256_gcm_encrypt( session_key, &nonce, &aad, - &mut self.bytes[self.body_start..], + &mut self.bytes[Self::WIRE_PREFIX_LEN..], ); - let prefix = &mut self.bytes[..self.body_start]; + let prefix = &mut self.bytes[..Self::WIRE_PREFIX_LEN]; prefix[0] = QL_WIRE_VERSION; prefix[1] = RecordType::Session as u8; header.encode_into(&mut prefix[2..2 + SessionHeader::WIRE_SIZE]); @@ -142,9 +125,22 @@ impl SessionRecordBuilder { if !self.can_push_len(wire_size) { return false; } + self.ensure_prefix_capacity(wire_size); let start = self.bytes.len(); self.bytes.resize(start + wire_size, 0); encode(&mut self.bytes[start..]); true } + + fn can_push_len(&self, len: usize) -> bool { + len <= self.remaining_capacity() + } + + fn ensure_prefix_capacity(&mut self, additional_body_len: usize) { + if self.bytes.is_empty() { + self.bytes + .reserve(Self::WIRE_PREFIX_LEN + additional_body_len); + self.bytes.resize(Self::WIRE_PREFIX_LEN, 0); + } + } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 768011c4..a66869eb 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -182,7 +182,7 @@ fn encrypt_record( body: &SessionRecord, ) -> QlSessionRecord> { let wire_size = body.wire_size() + SessionRecordBuilder::WIRE_PREFIX_LEN; - let mut builder = SessionRecordBuilder::new(wire_size, wire_size); + let mut builder = SessionRecordBuilder::new(wire_size); for frame in &body.frames { let _pushed = builder.push_frame(frame); debug_assert!(_pushed); From cb499f07bbb6e7f0dbc3ee27db9c8622bbd46817 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 17:39:09 -0400 Subject: [PATCH 095/304] ql-fsm: better build session record --- ql-fsm/src/session/mod.rs | 99 +++++++++++++++------------------------ 1 file changed, 37 insertions(+), 62 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index e320ee4c..6ef6256d 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -427,23 +427,24 @@ impl SessionFsm { sent_at: None, }; - if let Some(close) = self.state.pending_control.close.clone() { + if let Some(close) = self.state.pending_control.close.take() { if builder.push_close(&close) { - self.state.pending_control.close = None; - outbound.frames.push(TrackedFrame::Close(close)); + outbound.frames.push(TrackedFrame::Close(close.clone())); + } else { + self.state.pending_control.close = Some(close); } } - while self.push_next_pending_stream_close(&mut builder, &mut outbound) {} + self.push_next_pending_stream_close(&mut builder, &mut outbound); if self.state.pending_control.ping && builder.push_ping() { self.state.pending_control.ping = false; outbound.ping_included = true; } - while self.push_next_pending_stream_window(&mut builder, &mut outbound) {} + self.push_next_pending_stream_window(&mut builder, &mut outbound); - while self.push_next_stream_data(&mut builder, &mut outbound) {} + self.push_next_stream_data(&mut builder, &mut outbound); if let Some((ack, due_at)) = self.pending_ack() { if (!builder.is_empty() || due_at <= self.state.now) && builder.push_ack(&ack) { @@ -464,52 +465,43 @@ impl SessionFsm { &mut self, builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, - ) -> bool { + ) { let len = self.state.streams.len(); if len == 0 { - return false; + return; } let start = self.state.next_stream_index % len; for offset in 0..len { let index = (start + offset) % len; - let Some((_, stream)) = self.state.streams.get_index(index) else { - continue; - }; + let stream = self.state.streams.get_index_mut(index).unwrap().1; let Some(close) = stream.pending_close.as_ref() else { continue; }; if !builder.push_stream_close(close) { - continue; + break; } - let stream = self.state.streams.get_index_mut(index).unwrap().1; - self.state.next_stream_index = (index + 1) % len; outbound.frames.push(TrackedFrame::StreamClose( stream.pending_close.take().unwrap(), )); - return true; } - - false } fn push_next_pending_stream_window( &mut self, builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, - ) -> bool { + ) { let len = self.state.streams.len(); if len == 0 { - return false; + return; } let start = self.state.next_stream_index % len; for offset in 0..len { let index = (start + offset) % len; - let Some((&stream_id, stream)) = self.state.streams.get_index(index) else { - continue; - }; + let (&stream_id, stream) = self.state.streams.get_index_mut(index).unwrap(); if !stream.pending_window { continue; } @@ -518,66 +510,59 @@ impl SessionFsm { maximum_offset: stream.recv_limit(), }; if !builder.push_stream_window(&frame) { - continue; + break; } - let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); stream.pending_window = false; stream.advertised_max_offset = frame.maximum_offset; - self.state.next_stream_index = (index + 1) % len; outbound .window_updates .push((stream_id, frame.maximum_offset)); - return true; } - - false } fn push_next_stream_data( &mut self, builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, - ) -> bool { - let Some(max_payload) = self.max_stream_data_payload(builder) else { - return false; - }; + ) { + const OVERHEAD: usize = + 1 + std::mem::size_of::() + StreamData::>::MIN_WIRE_SIZE; + let len = self.state.streams.len(); if len == 0 { - return false; + return; } let start = self.state.next_stream_index % len; + let mut next_index = start; + for offset in 0..len { - let index = (start + offset) % len; - let Some((&stream_id, stream)) = self.state.streams.get_index(index) else { - continue; + let Some(max_payload) = builder.remaining_capacity().checked_sub(OVERHEAD) else { + break; }; + + let index = (start + offset) % len; + let (&stream_id, stream) = self.state.streams.get_index_mut(index).unwrap(); if matches!(stream.outbound_state, OutboundState::Closed) { continue; } - let Some(candidate) = stream.tx.next_range(max_payload, stream.peer_max_offset) else { continue; }; - { - let frame = StreamData { - stream_id, - offset: candidate.offset, - fin: candidate.fin, - bytes: stream.tx.ranged_bytes(candidate), - }; - if !builder.push_stream_data(&frame) { - continue; - } - } + let frame = StreamData { + stream_id, + offset: candidate.offset, + fin: candidate.fin, + bytes: stream.tx.ranged_bytes(candidate), + }; + let res = builder.push_stream_data(&frame); + assert!(res, "builder has capacity"); - let (_, stream) = self.state.streams.get_index_mut(index).unwrap(); stream.tx.mark_in_flight(candidate); if candidate.fin { stream.outbound_state = OutboundState::Finished; } - self.state.next_stream_index = (index + 1) % len; outbound .frames .push(TrackedFrame::StreamData(TrackedStreamData { @@ -586,20 +571,10 @@ impl SessionFsm { len: candidate.len, fin: candidate.fin, })); - return true; + next_index = (index + 1) % len; } - false - } - - fn max_stream_data_payload(&self, builder: &SessionRecordBuilder) -> Option { - let overhead = 1 + std::mem::size_of::() + StreamData::>::MIN_WIRE_SIZE; - let remaining = builder.remaining_capacity(); - if remaining >= overhead { - Some(remaining - overhead) - } else { - None - } + self.state.next_stream_index = next_index; } fn ensure_session_open(&self) -> Result<(), StreamError> { From 0ade1e0242180d1e5c908a9380855a85279f5e57 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 4 Apr 2026 20:18:38 -0400 Subject: [PATCH 096/304] ql: design doc --- QL_V2.md | 397 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 219 insertions(+), 178 deletions(-) diff --git a/QL_V2.md b/QL_V2.md index f3e2f7a5..c8dd3ffd 100644 --- a/QL_V2.md +++ b/QL_V2.md @@ -1,205 +1,246 @@ # QuantumLink V2 Design Document -QuantumLink V2 is a peer-to-peer protocol for authenticated, encrypted sessions carrying multiplexed byte streams. +QuantumLink V2 is a peer-to-peer protocol for authenticated encrypted sessions carrying multiplexed duplex byte streams. -It replaces QLv1's one-message-at-a-time model with explicit pairing, handshake, session, and stream state. +It operates on whole QL records. Packetization, fragmentation, batching, and reassembly belong to the transport adapter, not to QLv2 itself. -QLv2 operates on complete QL records and leaves transport-specific framing, fragmentation, reassembly, and delivery behavior to platform adapters. - -## Table of contents -- [Design goals](#design-goals) -- [Non-design goals](#non-design-goals) -- [Protocol model](#protocol-model) -- [Session handshake](#handshake) -- [Session sequencing and reliability](#session-sequencing-and-reliability) -- [Keepalive and liveness](#keepalive-and-liveness) -- [Stream model](#stream-model) +The handshake is the setup phase. It authenticates the remote peer, establishes a fresh session, and derives the keys used for steady-state traffic. ## Design goals -1. [use ephemeral peer sessions for record encryption](#1-explicit-peer-sessions) -2. [include a minimal unencrypted but authenticated header](#2-minimal-authenticated-header) -3. [keep the record layer transport-agnostic](#3-transport-agnostic-record-layer) -4. [add QL-level reliability above the transport](#4-ql-level-reliability) -5. [use duplex byte streams as the application primitive](#5-duplex-byte-streams) -6. [efficient protocol wire format](#6-efficient-wire-format) -7. [provide a single shared protocol state machine across platforms](#7-shared-core-state-machine) -8. [support hardware-backed cryptography](#8-hardware-backed-cryptography) - -### 1. Explicit peer sessions -QLv2 replaces per-exchange sealing with explicit pairing, handshake, session, and stream state. This keeps peer state durable across many records, amortizes large post-quantum signatures and expensive key exchange, and keeps steady-state traffic smaller and cheaper. - -### 2. Minimal authenticated header -QLv2 keeps a small header visible on the wire while still authenticating it. This lets a host route a record to the correct local or third-party application before decryption without exposing more metadata than necessary. - -The visible record header currently includes: - -- protocol version -- record kind -- sender XID -- recipient XID - -This header is intentionally narrow, and can be extended in the future if needed. - -### 3. Transport-agnostic record layer -The core protocol only consumes and produces complete QL records. Framing, batching, fragmentation, and reassembly stay in the transport adapter so the same protocol can run over transports such as TCP, BLE, or L2CAP without rewriting core logic. - -### 4. QL-level reliability -QLv2 includes QL-level sequence numbers and acknowledgments above the transport. A transport can usually only tell us that bytes were accepted for transmission. A QL acknowledgment tells us something stronger: the peer received and decrypted the message with the session key. - -This is deliberate redundancy, not a replacement for transport reliability. It is not sufficient for a fully unreliable transport like raw UDP, but it does make QLv2 more robust on transports that should be reliable in theory yet have shown implementation-level flakiness in practice, such as Passport Prime's embedded BLE. - -### 5. Duplex byte streams -QLv2 treats duplex byte streams as the application primitive rather than building in a separate model for each interaction style. Request/response, subscriptions, progress updates, and bulk transfer can all be adapted to the same abstraction, which also gives useful behavior such as finish semantics, cancellation, and backpressure without separate protocol features. - -### 6. Efficient wire format -The wire format should stay compact, cheap to process, and independent of any one implementation language. QLv2 uses an efficient binary encoding with explicit endianness and fixed layouts, so records can be parsed consistently across platforms. - -The record sizes shows the protocol's intended split between setup and steady-state traffic. Setup records are relatively large because they carry post-quantum cryptography material, while steady-state session records are much smaller. - -| Record type | Encoded size | -| --- | ---: | -| `hello` | 6253 bytes | -| `hello_reply` | 6253 bytes | -| `confirm` | 4673 bytes | -| `pair_request empty` | 1630 bytes | -| `unpair` | 4673 bytes | -| `ready empty` | 62 bytes | -| `session ack` | 87 bytes | -| `session ping` | 87 bytes | -| `session stream empty` | 100 bytes | -| `session stream fin` | 100 bytes | -| `session stream close` | 94 bytes | -| `session close` | 89 bytes | - -Any encrypted record has the same outer wire shape: - -| Component | Size | -| --- | ---: | -| protocol version | 1 byte | -| record kind | 1 byte | -| sender XID | 16 bytes | -| recipient XID | 16 bytes | -| AEAD nonce | 12 bytes | -| AEAD auth tag | 16 bytes | -| ciphertext | N bytes | - -That gives a 62-byte minimum for any encrypted record before counting the encrypted plaintext. The AEAD keeps the ciphertext the same length as the plaintext, so after that fixed 62-byte overhead, each additional plaintext byte becomes one additional ciphertext byte. - -For session records, the encrypted plaintext always starts with a 25-byte session envelope: - -| Session envelope field | Size | -| --- | ---: | -| `seq` | 8 bytes | -| `ack.base` | 8 bytes | -| `ack.bitmap` | 8 bytes | -| session body kind discriminator | 1 byte | - -### 7. Shared core state machine -QLv2 should have one core implementation of pairing, handshake, session, retransmission, and stream behavior. Platforms should integrate that shared state machine instead of rebuilding subtle protocol logic independently. - -### 8. Hardware-backed cryptography -QLv2 separates parts of its cryptographic implementation through the `QlCrypto` trait. Each platform can provide its own source of randomness, hashing, and AEAD encryption and decryption, choosing software or hardware-backed implementations as appropriate. - -```rust -pub trait QlCrypto { - fn fill_random_bytes(&self, data: &mut [u8]); - fn hash(&self, parts: &[&[u8]]) -> [u8; 32]; - fn encrypt_with_aead(&self, /*...*/) -> [u8; EncryptedMessage::AUTH_SIZE]; - fn decrypt_with_aead(&self, /*...*/) -> bool; -} -``` - -## Non-design goals -- not a replacement for TCP, QUIC, BLE, or any other transport -- not a universal reliability layer for arbitrary raw packets -- not responsible for framing, batching, fragmentation, or reassembly on a given platform -- not responsible for how QL records map onto TCP reads/writes, BLE packets, or similar transport units -- not a general-purpose message bus above the stream layer -- not an attempt to preserve QLv1's sealed-message model in the core protocol - -## Protocol model -QLv2 has four layers of state: - -- `Pairing` establish a durable peer relationship -- `Handshake` establish a fresh encrypted session between paired peers -- `Session` carries authenticated encrypted traffic with QL-level acknowledgment and retransmission -- `Stream` multiplex many concurrent duplex byte streams inside one session +1. [Ephemeral peer sessions](#handshake): short-lived keys for encryption +2. [Forward secrecy](#security-properties): losing a long-term private key does not reveal old session data +3. [Minimal authenticated header](#record-and-frame-wire-format): keep routing visible, but authenticated +4. [QL-level reliability](#acknowledgment-and-retransmission): `ack` means received, decrypted, and accepted +5. [Duplex byte streams](#streams): avoid cross-stream head-of-line blocking and keep backpressure local +6. [Efficient wire format](#record-and-frame-wire-format): keep steady-state traffic compact +7. [Hardware-backed cryptography](#security-properties): allow platform-specific crypto implementations +8. Shared core state machine: keep implementation consistent across platforms + +## Non-goals + +QLv2 is not: + +- a packet framing format +- a generic reliability layer for arbitrary raw datagrams +- a globally ordered message bus + +## Core terms + +- `peer`: one QLv2 endpoint +- `XID`: a stable 16-byte peer identifier +- `peer bundle`: public peer information: `version`, `xid`, `capabilities`, and ML-KEM public key +- `session`: one live encrypted channel with directional keys and directional connection IDs +- `record`: one complete QLv2 wire unit +- `frame`: one logical item inside a session record +- `stream`: one duplex byte stream inside a session +- `stream origin`: the peer that opened the stream +- `origin lane`: bytes sent by the stream origin +- `return lane`: bytes sent back toward the stream origin + +## Record And Frame Wire Format + +QLv2 has two record types: + +- `handshake record`: used only during setup +- `session record`: used after the handshake completes + +Handshake records are large because they carry ML-KEM material. Session records are small and can carry multiple frames, including frames for different streams. + +Handshake records are routed by peer identity. Session records are routed by `connection_id`. + +### Handshake records + +| Record | Size | Used when | Purpose | +| --- | ---: | --- | --- | +| `IK1` | 4793 bytes | initiator already knows the responder bundle | start a handshake toward a known responder | +| `IK2` | 3203 bytes | second message of `IK` | finish the responder side of the handshake and establish the session | +| `KK1` | 3187 bytes | both peers already know each other | start a handshake between already-known peers | +| `KK2` | 3203 bytes | second message of `KK` | finish the responder side of the handshake and establish the session | + +### Session records + +`session record size = 42 + sum(frame sizes)` -`Unpair` is a peer-level signed control record outside the session. It tears down the pairing relationship on a best-effort basis and does not depend on session ordering or session establishment. - -This structure gives QLv2 a few important properties: - -- one peer relationship can span many sessions over time -- one session can carry many streams at once -- stream data from different streams can be interwoven on the same session -- ordering is preserved within a stream, not across all streams -- one blocked stream does not block unrelated streams +There is no explicit AEAD nonce on the wire. The record `seq` is used to derive the nonce. + +| Fixed part | Size | Purpose | +| --- | ---: | --- | +| version | 1 byte | protocol version | +| record type | 1 byte | identifies a session record | +| `connection_id` | 16 bytes | route the record to the current session | +| `seq` | 8 bytes | record identity for ack and retransmit | +| AEAD auth tag | 16 bytes | authenticate the encrypted body | +| fixed overhead total | 42 bytes | overhead before any frames | + +The visible session header is authenticated as AEAD AAD but is not encrypted. + +### Session frames + +| Frame | Size | Purpose | +| --- | ---: | --- | +| `Ping` | 1 byte | keep the session alive when idle | +| `Ack` | 17 bytes | acknowledge received session records | +| `StreamWindow` | 13 bytes | extend per-stream send credit | +| `StreamClose` | 10 bytes | abort one stream lane or both lanes | +| `Close` | 3 bytes | close the whole session | +| `StreamData` | `16 + payload_len` bytes | carry stream bytes and optional `fin` | + +`StreamData` is the main steady-state frame: + +`1 kind + 2 variable-length prefix + 4 stream_id + 8 offset + 1 fin + payload_len` + +Some useful minimum record sizes: + +| Record | Size | Meaning | +| --- | ---: | --- | +| `Ping` only | 43 bytes | idle keepalive | +| `Close` only | 45 bytes | session shutdown | +| empty or fin-only `StreamData` | 58 bytes | open or finish a stream lane without payload bytes | ## Handshake -The handshake authenticates both peers, derives a fresh session key, and confirms that both sides can use it. -| Message | Sender | Est. size | Purpose | -| --- | --- | ---: | --- | -| `hello` | initiator | ~6253 bytes | start the handshake, contribute fresh key material, prove initiator identity | -| `hello_reply` | responder | ~6253 bytes | contribute fresh key material, prove responder identity, bind to `hello` | -| `confirm` | initiator | ~4673 bytes | prove the initiator saw `hello_reply` and derived the same session | -| `ready` | responder | ~62 bytes | prove the responder derived the session key by encrypting under it | +QLv2 currently supports two 2-message Noise-style handshake patterns: + +- `IK`: the initiator already knows the responder bundle +- `KK`: both peers already know each other + +The handshake covers peer authentication and session establishment. There is no separate peer-level pairing record. + +The handshake does five things: + +1. authenticate which peer we are talking to +2. derive a fresh transmit key and receive key +3. derive a directional transmit `connection_id` and receive `connection_id` +4. bind transport parameters into the transcript +5. produce a `handshake_hash` for the completed exchange + +Today, first-contact identity exchange is still partly out of band. `IK` removes the need for the responder to know the initiator in advance, but the initiator still needs the responder bundle before it can start. A future pattern such as `XX` could remove that requirement. + +Each handshake carries: + +- `handshake_id`: identifies one handshake attempt +- `valid_until`: expiration time for that attempt +- transport parameters: today this is initial per-stream receive credit + +Important behavior: + +- handshake start messages are replay-checked by `handshake_id` +- expired handshake messages are rejected +- simultaneous starts are resolved deterministically +- handshake attempts time out and are dropped rather than being retransmitted in place + +Session establishment is slightly asymmetric: + +- the responder enters the connected state when it processes message 1 and constructs message 2 +- the initiator enters the connected state when it receives message 2 + +## Session Model + +After the handshake, peers exchange encrypted session records. + +Each session record has: + +- one visible `connection_id` +- one visible `seq` +- one encrypted body containing one or more frames + +One session record may carry: + +- only control frames +- only stream data +- a mixture of frames for multiple streams + +This is the core steady-state model: records are the encrypted transport unit, frames are the logical items inside them. + +## Acknowledgment And Retransmission + +`Ack` is record-level, not stream-level. + +An `Ack` means the peer: + +- received that session record +- decrypted it with the current session key +- accepted its `seq` + +Retransmission works at the frame level: + +- every emitted session record gets a fresh `seq` +- retransmit timers start only after the local transport confirms that it accepted the write +- if a record is considered lost, the FSM restores its frames +- those frames are packed into a new record with a new `seq` + +QLv2 does not resend the same logical record identity. + +Receivers track a recent record window so they can: +- reject duplicates +- send selective acks with `base_seq + bitmap` + +## Streams + +Streams are the application primitive. + +A stream has two independent lanes: + +- origin lane +- return lane + +Important properties: + +- either peer can open a stream +- stream IDs are split by parity so both peers can open streams without collision +- ordering is preserved within a stream lane +- different streams can make progress independently +- record loss on one stream does not block unrelated streams + +`StreamData` carries: + +- `stream_id` +- `offset` +- `fin` +- bytes + +`fin` is graceful completion of one lane. It says "no more bytes on this lane" without aborting the other lane. + +## Flow Control + +Flow control is per stream. -Both peers contribute fresh key material during the handshake. The signatures bind the exchange to the two peers and to the full handshake transcript rather than to isolated messages. The session key is derived from the combined exchange. `ready` is the final key confirmation step because it is encrypted under that new session key. +During the handshake, each peer advertises an initial per-stream receive window. That becomes the initial send credit the remote peer can use on each stream. -The handshake also follows a few simple rules: +`StreamWindow` extends that credit by advertising a larger maximum offset. -- each handshake message has a bounded lifetime -- duplicate handshake messages can trigger resend of the matching response -- simultaneous `hello` messages are resolved deterministically so only one side continues as the initiator +Important detail: reading bytes is not what returns credit. Committing those reads is what returns credit and causes window updates to be sent. -## Session sequencing and reliability -This layer gives the session record-level acknowledgment and retransmission, independent of any one stream. +In practice, a stream is writable only when both are true: -| Term | Meaning | -| --- | --- | -| `seq` | session-wide sequence number for one encrypted record | -| `ack.base` | all sequence numbers up to this point are acknowledged | -| `ack.bitmap` | selective acknowledgment for the next 64 sequence numbers after `ack.base` | +- local send buffering has room +- peer-advertised stream credit allows more bytes -- every encrypted session record gets a `seq` -- the sequence space is shared by all streams on the session -- receivers can acknowledge out-of-order records within the session receive window -- retransmission resends the same logical session record with the same `seq` -- a QL acknowledgment tells us that the peer received the record, decrypted it successfully under the current session key, verified it, and accepted its session sequence number +## Close And Liveness -### Keepalive and liveness -- when a session is idle, a peer may send a `ping` to show that the session is still alive -- the peer does not answer with another `ping`; it simply acknowledges the record at the normal session layer -- if inbound traffic stays silent for too long, the session is treated as dead and closed +`StreamClose` aborts a stream early. Semantically it can target: -Multiple streams can be interwoven in the same session. A missing session record can stall byte delivery on its own stream, but it does not block unrelated streams. +- the origin lane +- the return lane +- both lanes -## Stream model -QLv2 uses duplex byte streams as the application primitive. +`Close` aborts the whole session. -- each stream has independent inbound and outbound directions -- either peer can open a stream at any time -- many streams can be active on the same session -- bytes are delivered in order within a stream -- each stream chunk may carry bytes and may also mark that direction as complete -- this supports both bounded exchanges and long-lived streams +Idle sessions may send `Ping`. The peer does not answer with another ping; normal record acknowledgment is enough. -Normal completion means one side is done sending bytes on that direction while the other direction may continue. Explicit close is different. It terminates one side or both sides of the stream early and carries a close code. +Sessions also have local timers for: -By convention, higher-level protocols can treat one direction as a request and the other as a response. +- handshake timeout +- delayed ack emission +- session record retransmit timeout +- keepalive ping interval +- peer silence timeout -### Example: RPC over streams +## Security Properties -#### Unary request/response +The current handshake is ML-KEM-based and post-quantum focused. -- the caller opens a stream -- the caller writes the request bytes and marks the request direction complete -- the responder reads the request, writes the response bytes, and marks the response direction complete +Session payloads are encrypted and authenticated. The session header stays visible so the receiver can route the record, but it is still authenticated as AEAD AAD. -#### Subscription +QLv2 also provides forward secrecy in the following sense: even if an attacker later obtains a peer's long-term ML-KEM private key, they still cannot decrypt messages from earlier completed sessions. -- the caller opens a stream and writes a request body (any subscription parameters) -- the caller marks the request direction complete once the request is sent -- the responder keeps writing response updates on the response direction until the subscription ends or the job completes -- either side can explicitly close the stream early to cancel From 1062be0c2796b0d2a24cc01cfb8a259178892535 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 06:39:54 -0400 Subject: [PATCH 097/304] ql: use Origin/Return stream lanes --- ql-fsm/src/lib.rs | 2 +- ql-fsm/src/session/state.rs | 8 ++++---- ql-runtime/src/driver.rs | 14 +++++++------- ql-runtime/src/handle.rs | 4 ++-- ql-runtime/src/tests/stream.rs | 2 +- ql-wire/src/encrypted/stream_close.rs | 18 +++++++++++++----- 6 files changed, 28 insertions(+), 20 deletions(-) diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index c6593407..80e407aa 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -287,7 +287,7 @@ impl QlFsm { implementation::finish_stream(self, stream_id) } - /// closes part or all of a stream + /// closes the origin lane, return lane, or both lanes of a stream pub fn close_stream( &mut self, stream_id: StreamId, diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 70e9c0a1..d2aff325 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -92,15 +92,15 @@ pub enum StreamRole { impl StreamRole { pub fn outbound_target(self) -> CloseTarget { match self { - Self::Initiator => CloseTarget::Request, - Self::Responder => CloseTarget::Response, + Self::Initiator => CloseTarget::Origin, + Self::Responder => CloseTarget::Return, } } pub fn inbound_target(self) -> CloseTarget { match self { - Self::Initiator => CloseTarget::Response, - Self::Responder => CloseTarget::Request, + Self::Initiator => CloseTarget::Return, + Self::Responder => CloseTarget::Origin, } } } diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 137ec69e..95b27a30 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -156,15 +156,15 @@ impl DriverStreamIo { fn inbound_target(&self) -> CloseTarget { match self { - Self::Initiator { .. } => CloseTarget::Response, - Self::Responder { .. } => CloseTarget::Request, + Self::Initiator { .. } => CloseTarget::Return, + Self::Responder { .. } => CloseTarget::Origin, } } fn outbound_target(&self) -> CloseTarget { match self { - Self::Initiator { .. } => CloseTarget::Request, - Self::Responder { .. } => CloseTarget::Response, + Self::Initiator { .. } => CloseTarget::Origin, + Self::Responder { .. } => CloseTarget::Return, } } @@ -352,13 +352,13 @@ impl DriverState { stream_id, request: ByteReader::new( stream_id, - CloseTarget::Request, + CloseTarget::Origin, request_rx, self.runtime_tx.clone(), ), response: ByteWriter::new( stream_id, - CloseTarget::Response, + CloseTarget::Return, response_writer, self.runtime_tx.clone(), ), @@ -762,7 +762,7 @@ mod tests { state.drive_command( RuntimeCommand::CloseStream { stream_id, - target: CloseTarget::Request, + target: CloseTarget::Origin, code: CloseCode::CANCELLED, payload: Vec::new(), }, diff --git a/ql-runtime/src/handle.rs b/ql-runtime/src/handle.rs index f013b78c..40707ef7 100644 --- a/ql-runtime/src/handle.rs +++ b/ql-runtime/src/handle.rs @@ -279,11 +279,11 @@ impl RuntimeHandle { stream_id, request: ByteWriter::new( stream_id, - CloseTarget::Request, + CloseTarget::Origin, request_writer, self.tx.clone(), ), - response: ByteReader::new(stream_id, CloseTarget::Response, response, self.tx.clone()), + response: ByteReader::new(stream_id, CloseTarget::Return, response, self.tx.clone()), }) } diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index f66096b1..b750bf4a 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -152,7 +152,7 @@ async fn dropping_responder_closes_initiator_response() { assert!(matches!( err, QlError::StreamClosed { - target: CloseTarget::Response, + target: CloseTarget::Return, code: CloseCode::CANCELLED, payload, } if payload.is_empty() diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 81f9a2fe..742d9992 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -1,7 +1,11 @@ use super::StreamId; use crate::{codec, ByteSlice, WireError}; -/// aborts one or both directions of a stream with a close code. +/// aborts one or both lanes of a stream with a close code +/// +/// stream origin is the peer that opened the stream +/// origin lane carries bytes sent by the stream origin +/// return lane carries bytes sent back toward the stream origin #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamClose { pub stream_id: StreamId, @@ -30,11 +34,15 @@ impl codec::WireParse for StreamClose { } } +/// selects which stream lane a [`StreamClose`] applies to #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum CloseTarget { - Request = 1, - Response = 2, + /// close the lane sent by the stream origin + Origin = 1, + /// close the lane sent back toward the stream origin + Return = 2, + /// close both stream lanes Both = 3, } @@ -49,8 +57,8 @@ impl TryFrom for CloseTarget { fn try_from(value: u8) -> Result { match value { - 1 => Ok(Self::Request), - 2 => Ok(Self::Response), + 1 => Ok(Self::Origin), + 2 => Ok(Self::Return), 3 => Ok(Self::Both), _ => Err(WireError::InvalidPayload), } From 75c32454703f3712e3611a4bd3112b3c67573b43 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 06:40:16 -0400 Subject: [PATCH 098/304] ql: cleanup --- ql-fsm/src/session/stream_tx.rs | 32 ++++++++++++++-------------- ql-wire/src/encrypted/stream_data.rs | 1 - 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 8e239494..364f9457 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -2,12 +2,18 @@ use std::collections::VecDeque; use ql_wire::RangedByteChunks; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamTx { + bytes: VecDeque, + base_offset: u64, + segments: VecDeque, + final_offset: Option, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum SendState { - Unsent, - InFlight, - Lost, - Acked, +struct TrackedFinalOffset { + offset: u64, + state: SendState, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -24,9 +30,11 @@ impl SendSegment { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct TrackedFinalOffset { - offset: u64, - state: SendState, +enum SendState { + Unsent, + InFlight, + Lost, + Acked, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -36,14 +44,6 @@ pub struct StreamTxRange { pub fin: bool, } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamTx { - bytes: VecDeque, - base_offset: u64, - segments: VecDeque, - final_offset: Option, -} - impl StreamTx { pub fn new() -> Self { Self { diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index a4630ae5..33fbb7b8 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -46,7 +46,6 @@ impl StreamData { } pub fn encode_into(&self, out: &mut [u8]) { - assert_eq!(out.len(), self.wire_size()); let out = codec::write_u32(out, self.stream_id.0); let out = codec::write_u64(out, self.offset); let mut out = codec::write_bool(out, self.fin); From f02744206ae01fe70b94d441043305c510d4bcc4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 06:51:24 -0400 Subject: [PATCH 099/304] ql-wire: stream close fixed len --- ql-wire/src/encrypted/builder.rs | 93 ++++++++++++++++++++++---------- ql-wire/src/encrypted/mod.rs | 6 ++- 2 files changed, 69 insertions(+), 30 deletions(-) diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 090fa6f7..615f83e4 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -41,49 +41,57 @@ impl SessionRecordBuilder { } pub fn push_ping(&mut self) -> bool { - self.push_wire_size(1, |out| out[0] = super::SessionFrameKind::Ping as u8) + self.push_empty_frame(super::SessionFrameKind::Ping) } pub fn push_ack(&mut self, ack: &RecordAck) -> bool { - let len = 1 + RecordAck::WIRE_SIZE; - self.push_wire_size(len, |out| { - out[0] = super::SessionFrameKind::Ack as u8; - ack.encode_into(&mut out[1..]); - }) + self.push_frame_payload( + super::SessionFrameKind::Ack, + RecordAck::WIRE_SIZE, + |payload| { + ack.encode_into(payload); + }, + ) } pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { - let len = 1 + super::SIZE_LEN + frame.wire_size(); - self.push_wire_size(len, |out| { - out[0] = super::SessionFrameKind::StreamData as u8; - super::push_variable_len(&mut out[1..=super::SIZE_LEN], frame.wire_size()); - frame.encode_into(&mut out[1 + super::SIZE_LEN..]); - }) + self.push_len_prefixed_frame( + super::SessionFrameKind::StreamData, + frame.wire_size(), + |payload| { + frame.encode_into(payload); + }, + ) } pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { - let len = 1 + StreamWindow::WIRE_SIZE; - self.push_wire_size(len, |out| { - out[0] = super::SessionFrameKind::StreamWindow as u8; - frame.encode_into(&mut out[1..]); - }) + self.push_frame_payload( + super::SessionFrameKind::StreamWindow, + StreamWindow::WIRE_SIZE, + |payload| { + frame.encode_into(payload); + }, + ) } pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { - let len = 1 + super::SIZE_LEN + StreamClose::WIRE_SIZE; - self.push_wire_size(len, |out| { - out[0] = super::SessionFrameKind::StreamClose as u8; - super::push_variable_len(&mut out[1..=super::SIZE_LEN], StreamClose::WIRE_SIZE); - frame.encode_into(&mut out[1 + super::SIZE_LEN..]); - }) + self.push_frame_payload( + super::SessionFrameKind::StreamClose, + StreamClose::WIRE_SIZE, + |payload| { + frame.encode_into(payload); + }, + ) } pub fn push_close(&mut self, close: &SessionClose) -> bool { - let len = 1 + SessionClose::WIRE_SIZE; - self.push_wire_size(len, |out| { - out[0] = super::SessionFrameKind::Close as u8; - close.encode_into(&mut out[1..]); - }) + self.push_frame_payload( + super::SessionFrameKind::Close, + SessionClose::WIRE_SIZE, + |payload| { + close.encode_into(payload); + }, + ) } pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { @@ -132,6 +140,35 @@ impl SessionRecordBuilder { true } + fn push_empty_frame(&mut self, kind: super::SessionFrameKind) -> bool { + self.push_wire_size(1, |out| out[0] = kind as u8) + } + + fn push_frame_payload( + &mut self, + kind: super::SessionFrameKind, + payload_wire_size: usize, + encode_payload: impl FnOnce(&mut [u8]), + ) -> bool { + self.push_wire_size(1 + payload_wire_size, |out| { + out[0] = kind as u8; + encode_payload(&mut out[1..]); + }) + } + + fn push_len_prefixed_frame( + &mut self, + kind: super::SessionFrameKind, + payload_wire_size: usize, + encode_payload: impl FnOnce(&mut [u8]), + ) -> bool { + self.push_wire_size(1 + super::SIZE_LEN + payload_wire_size, |out| { + out[0] = kind as u8; + super::push_variable_len(&mut out[1..=super::SIZE_LEN], payload_wire_size); + encode_payload(&mut out[1 + super::SIZE_LEN..]); + }) + } + fn can_push_len(&self, len: usize) -> bool { len <= self.remaining_capacity() } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 9f9bc892..62c8dfec 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -101,7 +101,7 @@ impl SessionFrame { Self::Ack(_) => RecordAck::WIRE_SIZE, Self::StreamData(frame) => SIZE_LEN + frame.wire_size(), Self::StreamWindow(_) => StreamWindow::WIRE_SIZE, - Self::StreamClose(_) => SIZE_LEN + StreamClose::WIRE_SIZE, + Self::StreamClose(_) => StreamClose::WIRE_SIZE, Self::Close(_) => SessionClose::WIRE_SIZE, } } @@ -177,7 +177,9 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr )) } SessionFrameKind::StreamClose => { - let (frame, rest) = split_variable_frame(rest)?; + let (frame, rest) = rest + .split_at_checked(StreamClose::WIRE_SIZE) + .ok_or(WireError::InvalidPayload)?; Ok(( SessionFrame::StreamClose(StreamClose::parse_bytes(frame)?), rest, From 4a2d85c94358895c180deb05e6c0455bb0c976e2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 06:59:43 -0400 Subject: [PATCH 100/304] ql: fix clippy --- ql-fsm/src/implementation/handshake/ik.rs | 10 +-- ql-fsm/src/implementation/handshake/kk.rs | 10 +-- ql-fsm/src/implementation/handshake/mod.rs | 6 ++ ql-fsm/src/lib.rs | 2 +- ql-fsm/src/session/mod.rs | 4 +- ql-fsm/src/session/received_records.rs | 2 +- ql-fsm/src/session/state.rs | 5 +- ql-fsm/src/tests/mod.rs | 10 +-- ql-wire/src/handshake/ik.rs | 12 +-- ql-wire/src/handshake/kk.rs | 15 ++-- ql-wire/src/handshake/mod.rs | 2 +- ql-wire/src/record.rs | 2 +- ql-wire/src/tests.rs | 87 ++++++++-------------- 13 files changed, 70 insertions(+), 97 deletions(-) diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index 493f9ecb..7b372e69 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord, TransportParams}; +use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, @@ -19,9 +19,7 @@ pub fn start_initiator( crypto, fsm.identity.clone(), peer, - TransportParams { - initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size as u32, - }, + super::local_transport_params(fsm), ); let message = handshake.write_1(crypto, meta)?; @@ -62,9 +60,7 @@ pub fn handle_ik1( crypto, fsm.identity.clone(), fsm.state.peer.clone(), - TransportParams { - initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size as u32, - }, + super::local_transport_params(fsm), ); handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 3af501b0..6f1e8371 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord, TransportParams}; +use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, @@ -19,9 +19,7 @@ pub fn start_initiator( crypto, fsm.identity.clone(), peer, - TransportParams { - initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size as u32, - }, + super::local_transport_params(fsm), ); let message = handshake.write_1(crypto, meta)?; @@ -61,9 +59,7 @@ pub fn handle_kk1( crypto, fsm.identity.clone(), peer, - TransportParams { - initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size as u32, - }, + super::local_transport_params(fsm), ); handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 1faec712..c35d7a42 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -39,6 +39,12 @@ pub fn enqueue_handshake(fsm: &mut QlFsm, record: QlHandshakeRecord) { fsm.state.handshake = Some(record); } +fn local_transport_params(fsm: &QlFsm) -> wire::TransportParams { + wire::TransportParams { + initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size, + } +} + pub fn prepare_for_outbound_connect(fsm: &mut QlFsm) { fsm.state.handshake = None; reset_connected_session_if_needed(fsm); diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 80e407aa..294747bf 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -127,7 +127,7 @@ pub struct QlFsmConfig { /// maximum bytes buffered locally for one stream send side pub session_stream_send_buffer_size: usize, /// maximum bytes buffered locally for one stream receive side - pub session_stream_receive_buffer_size: usize, + pub session_stream_receive_buffer_size: u32, } impl Default for QlFsmConfig { diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 6ef6256d..9c091dbe 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -35,7 +35,7 @@ pub struct SessionFsmConfig { pub keepalive_interval: Duration, pub peer_timeout: Duration, pub stream_send_buffer_size: usize, - pub stream_receive_buffer_size: usize, + pub stream_receive_buffer_size: u32, pub initial_peer_stream_receive_window: u32, } @@ -917,7 +917,7 @@ fn schedule_ack(ack_state: &mut AckState, due_at: Instant) { AckState::Dirty { due_at: old } => AckState::Dirty { due_at: due_at.min(old), }, - AckState::Idle => AckState::Dirty { due_at: due_at }, + AckState::Idle => AckState::Dirty { due_at }, }; } diff --git a/ql-fsm/src/session/received_records.rs b/ql-fsm/src/session/received_records.rs index 1a8eeb9c..69cf0cad 100644 --- a/ql-fsm/src/session/received_records.rs +++ b/ql-fsm/src/session/received_records.rs @@ -41,7 +41,7 @@ impl ReceivedRecords { let out_of_order = seq != self .base - .saturating_add((u64::BITS - 1 - self.seen.leading_zeros()) as u64) + .saturating_add(u64::from(u64::BITS - 1 - self.seen.leading_zeros())) .saturating_add(1); self.base = base; self.seen = next_seen; diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index d2aff325..24e07c62 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -40,14 +40,15 @@ pub struct StreamState { impl StreamState { pub fn new( role: StreamRole, - receive_buffer_size: usize, + receive_buffer_size: u32, initial_peer_stream_receive_window: u32, ) -> Self { + let receive_buffer_size = receive_buffer_size as usize; Self { role, tx: StreamTx::new(), pending_close: None, - peer_max_offset: initial_peer_stream_receive_window as u64, + peer_max_offset: u64::from(initial_peer_stream_receive_window), outbound_state: OutboundState::Open, inbound_state: InboundState::Open, rx: StreamRx::new(receive_buffer_size), diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index d5c317c4..c15da72c 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -221,8 +221,7 @@ impl Harness { .b .fsm .config - .session_stream_receive_buffer_size - as u32, + .session_stream_receive_buffer_size, }, }, session: SessionFsm::new(session_config(&harness, true), harness.now), @@ -238,8 +237,7 @@ impl Harness { .a .fsm .config - .session_stream_receive_buffer_size - as u32, + .session_stream_receive_buffer_size, }, }, session: SessionFsm::new(session_config(&harness, false), harness.now), @@ -354,9 +352,9 @@ fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { stream_send_buffer_size: config.session_stream_send_buffer_size, stream_receive_buffer_size: config.session_stream_receive_buffer_size, initial_peer_stream_receive_window: if a { - harness.b.fsm.config.session_stream_receive_buffer_size as u32 + harness.b.fsm.config.session_stream_receive_buffer_size } else { - harness.a.fsm.config.session_stream_receive_buffer_size as u32 + harness.a.fsm.config.session_stream_receive_buffer_size }, } } diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 8f166ab9..1d19069f 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -200,7 +200,7 @@ impl IkHandshake { header, HandshakeKind::Ik1, &meta, - &self.local_transport_params, + self.local_transport_params, ); let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); @@ -242,7 +242,7 @@ impl IkHandshake { header, HandshakeKind::Ik2, &meta, - &self.local_transport_params, + self.local_transport_params, ); let remote_ephemeral = self .remote_ephemeral @@ -290,7 +290,7 @@ impl IkHandshake { message.header, HandshakeKind::Ik1, &message.meta, - &message.transport_params, + message.transport_params, ); self.symmetric .mix_hash(crypto, message.skem_ciphertext.as_bytes()); @@ -338,7 +338,7 @@ impl IkHandshake { message.header, HandshakeKind::Ik2, &message.meta, - &message.transport_params, + message.transport_params, ); let local_ephemeral = self .local_ephemeral @@ -366,7 +366,9 @@ impl IkHandshake { return Err(WireError::InvalidState); } let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; - let remote_transport_params = self.remote_transport_params.ok_or(WireError::InvalidState)?; + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; Ok(finalize_handshake( crypto, &self.symmetric, diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 84afb3f5..31ed5cda 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -2,8 +2,7 @@ use super::{ decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, mix_hash_routed_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, - EphemeralPublicKey, FinalizedHandshake, HandshakeHeader, Role, SymmetricState, - TransportParams, + EphemeralPublicKey, FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, }; use crate::{ codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, @@ -192,7 +191,7 @@ impl KkHandshake { header, HandshakeKind::Kk1, &meta, - &self.local_transport_params, + self.local_transport_params, ); let (skem_ciphertext, skem_secret) = crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); @@ -232,7 +231,7 @@ impl KkHandshake { header, HandshakeKind::Kk2, &meta, - &self.local_transport_params, + self.local_transport_params, ); let remote_ephemeral = self .remote_ephemeral @@ -278,7 +277,7 @@ impl KkHandshake { message.header, HandshakeKind::Kk1, &message.meta, - &message.transport_params, + message.transport_params, ); self.symmetric .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; @@ -312,7 +311,7 @@ impl KkHandshake { message.header, HandshakeKind::Kk2, &message.meta, - &message.transport_params, + message.transport_params, ); let local_ephemeral = self .local_ephemeral @@ -339,7 +338,9 @@ impl KkHandshake { if !self.is_finished() { return Err(WireError::InvalidState); } - let remote_transport_params = self.remote_transport_params.ok_or(WireError::InvalidState)?; + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; Ok(finalize_handshake( crypto, &self.symmetric, diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 241ffca1..56408933 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -312,7 +312,7 @@ fn mix_hash_routed_handshake( header: HandshakeHeader, kind: HandshakeKind, meta: &HandshakeMeta, - transport_params: &TransportParams, + transport_params: TransportParams, ) { let encoded_header = header.encode(); let encoded_meta = meta.encode(); diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 509d0f2a..a8d4124a 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -171,6 +171,6 @@ impl WireParse for QlSessionRecord { } let header = reader.parse::()?; let payload = EncryptedMessage::parse(reader.take_bytes(reader.remaining_len())?)?; - Ok(QlSessionRecord { header, payload }) + Ok(Self { header, payload }) } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index a66869eb..6a87939b 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -184,8 +184,8 @@ fn encrypt_record( let wire_size = body.wire_size() + SessionRecordBuilder::WIRE_PREFIX_LEN; let mut builder = SessionRecordBuilder::new(wire_size); for frame in &body.frames { - let _pushed = builder.push_frame(frame); - debug_assert!(_pushed); + let pushed = builder.push_frame(frame); + debug_assert!(pushed); } QlSessionRecord::parse_bytes(builder.encrypt(crypto, header, session_key).as_slice()) .unwrap() @@ -264,12 +264,8 @@ fn ik_handshake_rejects_tampered_handshake_meta() { responder.bundle(), TransportParams::default(), ); - let mut responder_state = IkHandshake::new_responder( - &crypto, - responder, - None, - TransportParams::default(), - ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, TransportParams::default()); let m1 = initiator_state .write_1(&crypto, handshake_meta(77)) @@ -293,13 +289,12 @@ fn kk_handshake_rejects_tampered_handshake_header() { let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut initiator_state = - KkHandshake::new_initiator( - &crypto, - initiator.clone(), - responder.bundle(), - TransportParams::default(), - ); + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); let mut responder_state = KkHandshake::new_responder( &crypto, responder, @@ -325,7 +320,7 @@ fn kk_handshake_rejects_tampered_handshake_header() { #[test] fn ik_handshake_rejects_tampered_transport_params() { - let crypto = TestCrypto::new(10_1); + let crypto = TestCrypto::new(101); let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); @@ -335,12 +330,8 @@ fn ik_handshake_rejects_tampered_transport_params() { responder.bundle(), handshake_transport_params(4096), ); - let mut responder_state = IkHandshake::new_responder( - &crypto, - responder, - None, - handshake_transport_params(8192), - ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, handshake_transport_params(8192)); let m1 = initiator_state .write_1(&crypto, handshake_meta(89)) @@ -370,12 +361,8 @@ fn ik_handshake_rejects_tampered_handshake_header() { responder.bundle(), TransportParams::default(), ); - let mut responder_state = IkHandshake::new_responder( - &crypto, - responder, - None, - TransportParams::default(), - ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, TransportParams::default()); let mut m1 = initiator_state .write_1(&crypto, handshake_meta(90)) @@ -430,12 +417,8 @@ fn ik_handshake_rejects_expired_message() { responder.bundle(), TransportParams::default(), ); - let mut responder_state = IkHandshake::new_responder( - &crypto, - responder, - None, - TransportParams::default(), - ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, TransportParams::default()); let m1 = initiator_state .write_1( @@ -467,12 +450,8 @@ fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { responder.bundle(), initiator_params, ); - let mut responder_state = IkHandshake::new_responder( - &crypto, - responder.clone(), - None, - responder_params, - ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder.clone(), None, responder_params); let m1 = initiator_state .write_1(&crypto, handshake_meta(11)) @@ -739,19 +718,14 @@ fn protocol_record_size_breakdown() { let initiator = make_identity(&crypto, 1); let responder = make_identity(&crypto, 2); - let mut ik_initiator = - IkHandshake::new_initiator( - &crypto, - initiator.clone(), - responder.bundle(), - TransportParams::default(), - ); - let mut ik_responder = IkHandshake::new_responder( + let mut ik_initiator = IkHandshake::new_initiator( &crypto, - responder.clone(), - None, + initiator.clone(), + responder.bundle(), TransportParams::default(), ); + let mut ik_responder = + IkHandshake::new_responder(&crypto, responder.clone(), None, TransportParams::default()); let ik1 = ik_initiator.write_1(&crypto, handshake_meta(101)).unwrap(); ik_responder.read_1(&crypto, 0, &ik1).unwrap(); @@ -762,13 +736,12 @@ fn protocol_record_size_breakdown() { let ik1 = QlHandshakeRecord::Ik1(ik1); let ik2 = QlHandshakeRecord::Ik2(ik2); - let mut kk_initiator = - KkHandshake::new_initiator( - &crypto, - initiator.clone(), - responder.bundle(), - TransportParams::default(), - ); + let mut kk_initiator = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); let mut kk_responder = KkHandshake::new_responder( &crypto, responder, From 8953465eccd41b27278f14985706a8eb8de9b000 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 07:30:36 -0400 Subject: [PATCH 101/304] ql-fsm: event callback sink --- ql-fsm/src/implementation/core.rs | 78 +++++------- ql-fsm/src/implementation/handshake/ik.rs | 11 +- ql-fsm/src/implementation/handshake/kk.rs | 11 +- ql-fsm/src/implementation/handshake/mod.rs | 36 +++--- ql-fsm/src/lib.rs | 74 +++++------ ql-fsm/src/state.rs | 9 +- ql-fsm/src/tests/handshake.rs | 135 ++++++++------------- ql-fsm/src/tests/mod.rs | 96 +++++++++++++-- ql-fsm/src/tests/session.rs | 110 ++++++++--------- 9 files changed, 294 insertions(+), 266 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index a219e0ff..39e2f7b4 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -1,36 +1,32 @@ -use std::{ - collections::VecDeque, - time::{Duration, Instant}, -}; +use std::time::{Duration, Instant}; use ql_wire::{ - self as wire, CloseTarget, QlCrypto, SessionClose, SessionCloseCode, SessionHeader, - StreamCloseCode, StreamId, WireParse, + self as wire, CloseTarget, QlCrypto, SessionCloseCode, SessionHeader, StreamCloseCode, + StreamId, WireParse, }; use crate::{ session::SessionEvent, state::LinkState, OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, - QlSessionEvent, SessionWriteId, StreamReadIter, + SessionWriteId, StreamReadIter, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { fsm.state.handshake = None; fsm.state.link = LinkState::Idle; - fsm.state.peer = Some(peer.clone()); - fsm.state.events.push_back(QlFsmEvent::NewPeer(peer)); - emit_peer_status(fsm); + fsm.state.peer = Some(peer); } pub fn receive( fsm: &mut QlFsm, mut bytes: Vec, crypto: &impl QlCrypto, + mut emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let header = wire::RecordHeader::parse_prefix(bytes.as_slice())?; match header.record_type { wire::RecordType::Handshake => { let record = wire::QlHandshakeRecord::parse_bytes(bytes.as_slice())?; - super::handle_handshake_record(fsm, crypto, &record) + super::handle_handshake_record(fsm, crypto, &record, &mut emit) } wire::RecordType::Session => { let record = wire::QlSessionRecord::parse_bytes(&mut bytes[..])?; @@ -49,37 +45,31 @@ pub fn receive( let mut session_closed = false; state .session - .receive(fsm.state.now.instant, record.header.seq, frames, { - let session_events = &mut fsm.state.session_events; - |event| { - session_closed |= forward_session_event(session_events, event); - } + .receive(fsm.state.now.instant, record.header.seq, frames, |event| { + session_closed |= forward_session_event(event, &mut emit); }); if session_closed { - apply_session_closed(fsm); + apply_session_closed(fsm, &mut emit); } Ok(()) } } } -pub fn on_timer(fsm: &mut QlFsm) { - super::handle_timer(fsm); +pub fn on_timer(fsm: &mut QlFsm, mut emit: impl FnMut(QlFsmEvent)) { + super::handle_timer(fsm, &mut emit); let Some(state) = fsm.state.link.connected_mut() else { return; }; let mut session_closed = false; - state.session.on_timer(fsm.state.now.instant, { - let session_events = &mut fsm.state.session_events; - |event| { - session_closed |= forward_session_event(session_events, event); - } + state.session.on_timer(fsm.state.now.instant, |event| { + session_closed |= forward_session_event(event, &mut emit); }); if session_closed { - apply_session_closed(fsm); + apply_session_closed(fsm, &mut emit); } } @@ -134,7 +124,7 @@ pub fn reject_session_write(fsm: &mut QlFsm, write_id: SessionWriteId) { } } -pub fn kill_session(fsm: &mut QlFsm, code: SessionCloseCode) { +pub fn kill_session(fsm: &mut QlFsm, _code: SessionCloseCode) { if fsm.state.peer.is_none() { return; } @@ -143,10 +133,6 @@ pub fn kill_session(fsm: &mut QlFsm, code: SessionCloseCode) { } fsm.state.link = crate::state::LinkState::Idle; - emit_peer_status(fsm); - fsm.state - .session_events - .push_back(QlSessionEvent::SessionClosed(SessionClose { code })); } pub fn open_stream(fsm: &mut QlFsm) -> Result { @@ -206,55 +192,49 @@ pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { Ok(state.session.queue_ping()?) } -pub fn emit_peer_status(fsm: &mut QlFsm) { - if let Some(peer) = fsm.state.peer.as_ref() { - fsm.state.events.push_back(QlFsmEvent::PeerStatusChanged { - peer: peer.xid, - status: fsm.state.link.status(), - }); +pub fn emit_peer_status(fsm: &QlFsm, emit: &mut impl FnMut(QlFsmEvent)) { + if fsm.state.peer.is_some() { + emit(QlFsmEvent::PeerStatusChanged(fsm.state.link.status())); } } -fn forward_session_event( - session_events: &mut VecDeque, - event: SessionEvent, -) -> bool { +fn forward_session_event(event: SessionEvent, emit: &mut impl FnMut(QlFsmEvent)) -> bool { match event { SessionEvent::Opened(stream_id) => { - session_events.push_back(QlSessionEvent::Opened(stream_id)); + emit(QlFsmEvent::Opened(stream_id)); false } SessionEvent::Readable(stream_id) => { - session_events.push_back(QlSessionEvent::Readable(stream_id)); + emit(QlFsmEvent::Readable(stream_id)); false } SessionEvent::Writable(stream_id) => { - session_events.push_back(QlSessionEvent::Writable(stream_id)); + emit(QlFsmEvent::Writable(stream_id)); false } SessionEvent::Finished(stream_id) => { - session_events.push_back(QlSessionEvent::Finished(stream_id)); + emit(QlFsmEvent::Finished(stream_id)); false } SessionEvent::Closed(frame) => { - session_events.push_back(QlSessionEvent::Closed(frame)); + emit(QlFsmEvent::Closed(frame)); false } SessionEvent::WritableClosed(stream_id) => { - session_events.push_back(QlSessionEvent::WritableClosed(stream_id)); + emit(QlFsmEvent::WritableClosed(stream_id)); false } SessionEvent::SessionClosed(close) => { - session_events.push_back(QlSessionEvent::SessionClosed(close)); + emit(QlFsmEvent::SessionClosed(close)); true } } } -fn apply_session_closed(fsm: &mut QlFsm) { +fn apply_session_closed(fsm: &mut QlFsm, emit: &mut impl FnMut(QlFsmEvent)) { if matches!(fsm.state.link, crate::state::LinkState::Connected(_)) { fsm.state.link = crate::state::LinkState::Idle; - emit_peer_status(fsm); + emit_peer_status(fsm, emit); } } diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index 7b372e69..02eea0d4 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -6,13 +6,14 @@ use super::{ }; use crate::{ state::{IkInitiatorState, LinkState, SessionTransport}, - QlFsm, QlFsmError, + QlFsm, QlFsmError, QlFsmEvent, }; pub fn start_initiator( fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle, + emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::IkHandshake::new_initiator( @@ -30,7 +31,7 @@ pub fn start_initiator( deadline: fsm.state.now.instant + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); - emit_peer_status(fsm); + emit_peer_status(fsm, emit); Ok(()) } @@ -38,6 +39,7 @@ pub fn handle_ik1( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Ik1, + emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { if should_ignore_inbound(fsm, message) { return Ok(()); @@ -65,7 +67,7 @@ pub fn handle_ik1( handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, &remote_bundle)?; + finish_handshake(fsm, transport, &remote_bundle, emit)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Ik2(outbound)); Ok(()) @@ -75,6 +77,7 @@ pub fn handle_ik2( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Ik2, + emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { { let LinkState::IkInitiator(state) = &mut fsm.state.link else { @@ -95,7 +98,7 @@ pub fn handle_ik2( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, &remote_bundle) + finish_handshake(fsm, transport, &remote_bundle, emit) } pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 6f1e8371..b5d877c7 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -6,13 +6,14 @@ use super::{ }; use crate::{ state::{KkInitiatorState, LinkState, SessionTransport}, - QlFsm, QlFsmError, + QlFsm, QlFsmError, QlFsmEvent, }; pub fn start_initiator( fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle, + emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::KkHandshake::new_initiator( @@ -30,7 +31,7 @@ pub fn start_initiator( deadline: fsm.state.now.instant + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); - emit_peer_status(fsm); + emit_peer_status(fsm, emit); Ok(()) } @@ -38,6 +39,7 @@ pub fn handle_kk1( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Kk1, + emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { if should_ignore_inbound(fsm, message) { return Ok(()); @@ -64,7 +66,7 @@ pub fn handle_kk1( handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, &remote_bundle)?; + finish_handshake(fsm, transport, &remote_bundle, emit)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Kk2(outbound)); Ok(()) @@ -74,6 +76,7 @@ pub fn handle_kk2( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Kk2, + emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { { let LinkState::KkInitiator(state) = &mut fsm.state.link else { @@ -94,7 +97,7 @@ pub fn handle_kk2( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, &remote_bundle) + finish_handshake(fsm, transport, &remote_bundle, emit) } pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index c35d7a42..4e9d9a0d 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -10,16 +10,24 @@ use crate::{ QlFsm, QlFsmError, QlFsmEvent, }; -pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { +pub fn handle_connect_ik( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + mut emit: impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; prepare_for_outbound_connect(fsm); - ik::start_initiator(fsm, crypto, peer) + ik::start_initiator(fsm, crypto, peer, &mut emit) } -pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { +pub fn handle_connect_kk( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + mut emit: impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; prepare_for_outbound_connect(fsm); - kk::start_initiator(fsm, crypto, peer) + kk::start_initiator(fsm, crypto, peer, &mut emit) } pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { @@ -60,16 +68,17 @@ pub fn handle_handshake_record( fsm: &mut QlFsm, crypto: &impl QlCrypto, record: &QlHandshakeRecord, + emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { match record { - QlHandshakeRecord::Ik1(message) => ik::handle_ik1(fsm, crypto, message), - QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message), - QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message), - QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message), + QlHandshakeRecord::Ik1(message) => ik::handle_ik1(fsm, crypto, message, emit), + QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message, emit), + QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message, emit), + QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message, emit), } } -pub fn handle_timer(fsm: &mut QlFsm) { +pub fn handle_timer(fsm: &mut QlFsm, emit: &mut impl FnMut(QlFsmEvent)) { let Some(deadline) = fsm.state.link.handshake_deadline() else { return; }; @@ -79,7 +88,7 @@ pub fn handle_timer(fsm: &mut QlFsm) { fsm.state.link = LinkState::Idle; fsm.state.handshake = None; - emit_peer_status(fsm); + emit_peer_status(fsm, emit); } pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { @@ -90,6 +99,7 @@ pub fn finish_handshake( fsm: &mut QlFsm, transport: SessionTransport, remote_bundle: &wire::PeerBundle, + emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { if let Some(peer) = fsm.state.peer.as_ref() { if peer != remote_bundle { @@ -97,9 +107,7 @@ pub fn finish_handshake( } } else { fsm.state.peer = Some(remote_bundle.clone()); - fsm.state - .events - .push_back(QlFsmEvent::NewPeer(remote_bundle.clone())); + emit(QlFsmEvent::NewPeer); } let config = &fsm.config; @@ -120,7 +128,7 @@ pub fn finish_handshake( fsm.state.now.instant, ); fsm.state.link = LinkState::Connected(ConnectedState { transport, session }); - emit_peer_status(fsm); + emit_peer_status(fsm, emit); Ok(()) } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 294747bf..29bbe1b5 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -10,10 +10,10 @@ //! //! outputs from `QlFsm` are //! - outbound session and handshake records from `take_next_write` -//! - peer events from `take_next_event` -//! - session events from `take_next_session_event` +//! - callback-driven `QlFsmEvent`s emitted during `connect_ik`, `connect_kk`, `receive`, and +//! `on_timer` //! -//! call `next_deadline` after handling current inputs and draining current outputs +//! call `next_deadline` after handling current inputs and any emitted outputs //! use it to decide how long the outer loop can wait before `on_timer` must run //! another input may arrive before that deadline, which is fine @@ -30,7 +30,7 @@ use std::time::{Duration, Instant}; pub use error::QlFsmError; use ql_wire::{ CloseTarget, PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, - StreamCloseCode, StreamId, XID, + StreamCloseCode, StreamId, }; pub use session::stream_rx::StreamReadIter; @@ -59,25 +59,13 @@ pub enum PeerStatus { Connected, } -/// peer-level events emitted by `QlFsm` -#[derive(Debug, Clone)] +/// events emitted by `QlFsm` +#[derive(Debug, Clone, PartialEq, Eq)] pub enum QlFsmEvent { - /// a peer was bound or replaced - NewPeer(PeerBundle), - /// the bound peer was cleared - ClearPeer, + /// a peer was learned during handshake completion + NewPeer, /// the peer changed connection state - PeerStatusChanged { - /// peer that changed state - peer: XID, - /// new connection state - status: PeerStatus, - }, -} - -/// session and stream events emitted by `QlFsm` -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlSessionEvent { + PeerStatusChanged(PeerStatus), /// a stream was opened Opened(StreamId), /// a stream has bytes ready to read @@ -90,8 +78,6 @@ pub enum QlSessionEvent { Closed(StreamClose), /// local writes on this stream are closed WritableClosed(StreamId), - /// the peer requested unpairing - Unpaired, /// the encrypted session was closed SessionClosed(SessionClose), } @@ -167,8 +153,6 @@ impl QlFsm { peer: None, handshake: None, link: LinkState::Idle, - events: Default::default(), - session_events: Default::default(), now, }, } @@ -179,16 +163,31 @@ impl QlFsm { implementation::handle_bind_peer(self, peer); } + /// returns the currently bound peer, if any + pub fn peer(&self) -> Option<&PeerBundle> { + self.state.peer.as_ref() + } + /// starts or replaces an IK handshake with the currently bound peer - pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + pub fn connect_ik( + &mut self, + now: FsmTime, + crypto: &impl QlCrypto, + emit: impl FnMut(QlFsmEvent), + ) -> Result<(), QlFsmError> { self.state.now = now; - implementation::handle_connect_ik(self, crypto) + implementation::handle_connect_ik(self, crypto, emit) } /// starts or replaces a KK handshake with the currently bound peer - pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + pub fn connect_kk( + &mut self, + now: FsmTime, + crypto: &impl QlCrypto, + emit: impl FnMut(QlFsmEvent), + ) -> Result<(), QlFsmError> { self.state.now = now; - implementation::handle_connect_kk(self, crypto) + implementation::handle_connect_kk(self, crypto, emit) } /// handles one inbound wire message @@ -197,15 +196,16 @@ impl QlFsm { now: FsmTime, bytes: Vec, crypto: &impl QlCrypto, + emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { self.state.now = now; - implementation::receive(self, bytes, crypto) + implementation::receive(self, bytes, crypto, emit) } /// advances time-based state - pub fn on_timer(&mut self, now: FsmTime) { + pub fn on_timer(&mut self, now: FsmTime, emit: impl FnMut(QlFsmEvent)) { self.state.now = now; - implementation::on_timer(self); + implementation::on_timer(self, emit); } /// returns the next timer deadline, if any @@ -248,11 +248,6 @@ impl QlFsm { implementation::kill_session(self, code); } - /// returns the next peer-level event - pub fn take_next_event(&mut self) -> Option { - self.state.events.pop_front() - } - /// opens a new outgoing stream pub fn open_stream(&mut self) -> Result { implementation::open_stream(self) @@ -301,9 +296,4 @@ impl QlFsm { pub fn queue_ping(&mut self) -> Result<(), QlFsmError> { implementation::queue_ping(self) } - - /// returns the next session or stream event - pub fn take_next_session_event(&mut self) -> Option { - self.state.session_events.pop_front() - } } diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 115e42d5..03ae1709 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -1,14 +1,11 @@ -use std::{collections::VecDeque, time::Instant}; +use std::time::Instant; use ql_wire::{ ConnectionId, EphemeralPublicKey, HandshakeId, IkHandshake, KkHandshake, PeerBundle, QlHandshakeRecord, SessionKey, TransportParams, }; -use crate::{ - replay_cache::ReplayCache, session::SessionFsm, FsmTime, PeerStatus, QlFsmError, QlFsmEvent, - QlSessionEvent, -}; +use crate::{replay_cache::ReplayCache, session::SessionFsm, FsmTime, PeerStatus, QlFsmError}; pub struct QlFsmState { pub replay_cache: ReplayCache, @@ -16,8 +13,6 @@ pub struct QlFsmState { pub peer: Option, pub handshake: Option, pub link: LinkState, - pub events: VecDeque, - pub session_events: VecDeque, pub now: FsmTime, } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 21344418..766c5cf1 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -3,17 +3,13 @@ use std::time::Duration; use ql_wire::{QlHandshakeRecord, WireParse}; use super::*; -use crate::{state::LinkState, QlFsmError}; +use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent}; #[test] fn ik_connect_round_trip_establishes_transport() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -24,11 +20,7 @@ fn ik_connect_round_trip_establishes_transport() { fn kk_connect_round_trip_establishes_transport() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness - .a - .fsm - .connect_kk(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_kk_a().unwrap(); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -48,11 +40,7 @@ fn ik_connect_learns_remote_initial_stream_receive_window() { }, ); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); harness.pump(); assert_eq!( @@ -88,27 +76,38 @@ fn connect_methods_require_bound_peer() { let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); let crypto = TestCrypto::new(9); - assert_eq!(fsm.connect_ik(time, &crypto), Err(QlFsmError::NoPeerBound)); - assert_eq!(fsm.connect_kk(time, &crypto), Err(QlFsmError::NoPeerBound)); + assert_eq!( + fsm.connect_ik(time, &crypto, |_| {}), + Err(QlFsmError::NoPeerBound) + ); + assert_eq!( + fsm.connect_kk(time, &crypto, |_| {}), + Err(QlFsmError::NoPeerBound) + ); +} + +#[test] +fn connect_ik_emits_initiator_status() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik_a().unwrap(); + + assert_eq!( + harness.drain_events_a(), + vec![QlFsmEvent::PeerStatusChanged(PeerStatus::Initiator)] + ); } #[test] fn connect_ik_replaces_in_flight_attempt_and_ignores_stale_reply() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); + harness.drain_events_a(); let first = harness.next_outbound_a().unwrap(); let first_id = handshake_id(&first); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); let second = harness.next_outbound_a().unwrap(); let second_id = handshake_id(&second); @@ -135,19 +134,11 @@ fn connect_ik_replaces_in_flight_attempt_and_ignores_stale_reply() { fn connect_kk_replaces_in_flight_attempt_and_ignores_stale_reply() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness - .a - .fsm - .connect_kk(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_kk_a().unwrap(); let first = harness.next_outbound_a().unwrap(); let first_id = handshake_id(&first); - harness - .a - .fsm - .connect_kk(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_kk_a().unwrap(); let second = harness.next_outbound_a().unwrap(); let second_id = handshake_id(&second); @@ -174,16 +165,17 @@ fn connect_kk_replaces_in_flight_attempt_and_ignores_stale_reply() { fn inbound_ik1_auto_binds_unbound_responder() { let mut harness = Harness::paired(QlFsmConfig::default(), true, false); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); harness.pump(); + let expected_peer = harness.a.fsm.identity.bundle(); + assert_eq!(harness.b.fsm.peer(), Some(&expected_peer)); assert_eq!( - harness.b.fsm.state.peer, - Some(harness.a.fsm.identity.bundle()) + harness.drain_events_b(), + vec![ + QlFsmEvent::NewPeer, + QlFsmEvent::PeerStatusChanged(PeerStatus::Connected), + ] ); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); @@ -197,20 +189,21 @@ fn handshake_timeout_drops_single_ik_attempt_without_resend() { }; let mut harness = Harness::paired_known(config); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); + harness.drain_events_a(); let first = harness.next_outbound_a().unwrap(); let first = QlHandshakeRecord::parse_bytes(first.as_slice()).unwrap(); assert!(matches!(first, ql_wire::QlHandshakeRecord::Ik1(_))); assert!(harness.next_outbound_a().is_none()); harness.advance(config.handshake_timeout); - harness.a.fsm.on_timer(harness.time()); + harness.on_timer_a(); assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); + assert_eq!( + harness.take_event_a(), + Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) + ); assert!(harness.next_outbound_a().is_none()); } @@ -222,14 +215,10 @@ fn handshake_timeout_clears_queued_kk_output() { }; let mut harness = Harness::paired_known(config); - harness - .a - .fsm - .connect_kk(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_kk_a().unwrap(); harness.advance(config.handshake_timeout); - harness.a.fsm.on_timer(harness.time()); + harness.on_timer_a(); assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert!(harness.next_outbound_a().is_none()); @@ -239,13 +228,11 @@ fn handshake_timeout_clears_queued_kk_output() { fn bind_peer_clears_queued_handshake_output() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); + harness.drain_events_a(); harness.a.fsm.bind_peer(test_identity(99).bundle()); + assert!(harness.drain_events_a().is_empty()); assert!(harness.next_outbound_a().is_none()); } @@ -253,16 +240,8 @@ fn bind_peer_clears_queued_handshake_output() { fn simultaneous_ik_connect_converges() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); - harness - .b - .fsm - .connect_ik(harness.time(), &harness.b.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); + harness.connect_ik_b().unwrap(); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -273,16 +252,8 @@ fn simultaneous_ik_connect_converges() { fn simultaneous_ik_and_kk_connect_prefers_ik() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); - harness - .b - .fsm - .connect_kk(harness.time(), &harness.b.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); + harness.connect_kk_b().unwrap(); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index c15da72c..fc108c1c 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -3,6 +3,7 @@ mod session; use std::{ cell::Cell, + collections::VecDeque, time::{Duration, Instant}, }; @@ -18,7 +19,7 @@ use sha2::{Digest, Sha256}; use crate::{ session::{stream_parity::StreamParity, SessionFsm, SessionFsmConfig}, state::{ConnectedState, LinkState, SessionTransport}, - FsmTime, OutboundWrite, QlFsm, QlFsmConfig, SessionWriteId, + FsmTime, OutboundWrite, QlFsm, QlFsmConfig, QlFsmError, QlFsmEvent, SessionWriteId, }; #[derive(Clone)] @@ -142,6 +143,7 @@ impl QlKem for TestCrypto { struct Node { fsm: QlFsm, crypto: TestCrypto, + events: VecDeque, } struct Harness { @@ -184,10 +186,12 @@ impl Harness { a: Node { fsm: QlFsm::new(config_a, identity_a.clone(), time), crypto: TestCrypto::new(1), + events: Default::default(), }, b: Node { fsm: QlFsm::new(config_b, identity_b.clone(), time), crypto: TestCrypto::new(2), + events: Default::default(), }, }; @@ -197,8 +201,6 @@ impl Harness { if know_b { harness.b.fsm.bind_peer(identity_a.bundle()); } - while harness.a.fsm.take_next_event().is_some() {} - while harness.b.fsm.take_next_event().is_some() {} harness } @@ -277,17 +279,65 @@ impl Harness { self.a.fsm.take_next_write(self.time(), &self.a.crypto) } + fn connect_ik_a(&mut self) -> Result<(), QlFsmError> { + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.a; + fsm.connect_ik(time, crypto, |event| events.push_back(event)) + } + + fn connect_ik_b(&mut self) -> Result<(), QlFsmError> { + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.b; + fsm.connect_ik(time, crypto, |event| events.push_back(event)) + } + + fn connect_kk_a(&mut self) -> Result<(), QlFsmError> { + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.a; + fsm.connect_kk(time, crypto, |event| events.push_back(event)) + } + + fn connect_kk_b(&mut self) -> Result<(), QlFsmError> { + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.b; + fsm.connect_kk(time, crypto, |event| events.push_back(event)) + } + fn deliver_to_a(&mut self, record: Vec) { - self.a - .fsm - .receive(self.time(), record, &self.a.crypto) + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.a; + fsm.receive(time, record, crypto, |event| events.push_back(event)) .unwrap(); } fn deliver_to_b(&mut self, record: Vec) { - self.b - .fsm - .receive(self.time(), record, &self.b.crypto) + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.b; + fsm.receive(time, record, crypto, |event| events.push_back(event)) .unwrap(); } @@ -299,6 +349,34 @@ impl Harness { self.a.fsm.reject_session_write(write_id); } + fn on_timer_a(&mut self) { + let time = self.time(); + let Node { fsm, events, .. } = &mut self.a; + fsm.on_timer(time, |event| events.push_back(event)); + } + + fn on_timer_b(&mut self) { + let time = self.time(); + let Node { fsm, events, .. } = &mut self.b; + fsm.on_timer(time, |event| events.push_back(event)); + } + + fn take_event_a(&mut self) -> Option { + self.a.events.pop_front() + } + + fn take_event_b(&mut self) -> Option { + self.b.events.pop_front() + } + + fn drain_events_a(&mut self) -> Vec { + self.a.events.drain(..).collect() + } + + fn drain_events_b(&mut self) -> Vec { + self.b.events.drain(..).collect() + } + fn pump(&mut self) { for _ in 0..128 { let mut progressed = false; diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 7695db03..6af5524e 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ql_wire::{SessionClose, StreamId}; use super::*; -use crate::{state::LinkState, QlFsmError, QlFsmEvent, QlSessionEvent}; +use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent}; fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); @@ -31,21 +31,18 @@ fn connected_fsms_deliver_stream_data() { harness.pump(); + assert_eq!(harness.take_event_b(), Some(QlFsmEvent::Opened(stream_id))); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Opened(stream_id)) - ); - assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Readable(stream_id)) + harness.take_event_b(), + Some(QlFsmEvent::Readable(stream_id)) ); assert_eq!( read_stream_all(&mut harness.b.fsm, stream_id), b"hello".to_vec() ); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Finished(stream_id)) + harness.take_event_b(), + Some(QlFsmEvent::Finished(stream_id)) ); } @@ -63,7 +60,7 @@ fn session_retransmit_uses_new_record_seq() { decrypt_record(&harness.b.crypto, &first, &first_transport.rx_key); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); - harness.a.fsm.on_timer(harness.time()); + harness.on_timer_a(); let retried = harness.next_outbound_a().unwrap(); let (retried_header, retried_record) = @@ -74,17 +71,14 @@ fn session_retransmit_uses_new_record_seq() { harness.deliver_to_b(retried); harness.advance(config.session_record_ack_delay); - harness.a.fsm.on_timer(harness.time()); - harness.b.fsm.on_timer(harness.time()); + harness.on_timer_a(); + harness.on_timer_b(); harness.pump(); + assert_eq!(harness.take_event_b(), Some(QlFsmEvent::Opened(stream_id))); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Opened(stream_id)) - ); - assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Readable(stream_id)) + harness.take_event_b(), + Some(QlFsmEvent::Readable(stream_id)) ); assert_eq!( read_stream_all(&mut harness.b.fsm, stream_id), @@ -92,7 +86,7 @@ fn session_retransmit_uses_new_record_seq() { ); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); - harness.a.fsm.on_timer(harness.time()); + harness.on_timer_a(); assert!(harness.next_outbound_a().is_none()); } @@ -125,24 +119,24 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { harness.pump(); assert_eq!( - harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::Opened(stream_id_b)) + harness.take_event_a(), + Some(QlFsmEvent::Opened(stream_id_b)) ); assert_eq!( - harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::Readable(stream_id_b)) + harness.take_event_a(), + Some(QlFsmEvent::Readable(stream_id_b)) ); assert_eq!( read_stream_all(&mut harness.a.fsm, stream_id_b), b"from-b".to_vec() ); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Opened(stream_id_a)) + harness.take_event_b(), + Some(QlFsmEvent::Opened(stream_id_a)) ); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Readable(stream_id_a)) + harness.take_event_b(), + Some(QlFsmEvent::Readable(stream_id_a)) ); assert_eq!( read_stream_all(&mut harness.b.fsm, stream_id_a), @@ -216,13 +210,10 @@ fn returned_session_write_is_reissued_with_new_record_seq() { harness.deliver_to_b(record); harness.pump(); + assert_eq!(harness.take_event_b(), Some(QlFsmEvent::Opened(stream_id))); assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Opened(stream_id)) - ); - assert_eq!( - harness.b.fsm.take_next_session_event(), - Some(QlSessionEvent::Readable(stream_id)) + harness.take_event_b(), + Some(QlFsmEvent::Readable(stream_id)) ); assert_eq!( read_stream_all(&mut harness.b.fsm, stream_id), @@ -245,12 +236,12 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let (first_header, first) = decrypt_record(&harness.b.crypto, &record, &session_key); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); - harness.a.fsm.on_timer(harness.time()); + harness.on_timer_a(); assert!(harness.next_write_a().is_none()); harness.confirm_write_a(id); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); - harness.a.fsm.on_timer(harness.time()); + harness.on_timer_a(); let write = harness.next_write_a().unwrap(); let record = write.record; @@ -275,13 +266,13 @@ fn ack_frame_releases_stream_capacity_and_emits_writable() { let record = harness.next_outbound_a().unwrap(); harness.deliver_to_b(record); harness.advance(config.session_record_ack_delay); - harness.a.fsm.on_timer(harness.time()); - harness.b.fsm.on_timer(harness.time()); + harness.on_timer_a(); + harness.on_timer_b(); harness.pump(); assert_eq!( - harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::Writable(stream_id)) + harness.take_event_a(), + Some(QlFsmEvent::Writable(stream_id)) ); } @@ -295,16 +286,7 @@ fn kill_session_disconnects_locally() { .kill_session(ql_wire::SessionCloseCode::CANCELLED); assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); - assert_eq!( - harness.a.fsm.take_next_session_event(), - Some(QlSessionEvent::SessionClosed(SessionClose { - code: ql_wire::SessionCloseCode::CANCELLED - })) - ); - assert!(matches!( - harness.a.fsm.take_next_event(), - Some(QlFsmEvent::PeerStatusChanged { .. }) - )); + assert!(harness.drain_events_a().is_empty()); } #[test] @@ -318,7 +300,7 @@ fn session_records_contain_ack_frames_after_delivery() { let data = harness.next_outbound_a().unwrap(); harness.deliver_to_b(data); harness.advance(config.session_record_ack_delay); - harness.b.fsm.on_timer(harness.time()); + harness.on_timer_b(); let ack = harness.next_outbound_b().unwrap(); let session_key = harness.a.fsm.state.link.transport().unwrap().rx_key.clone(); @@ -342,11 +324,7 @@ fn first_stream_data_uses_negotiated_initial_peer_credit() { }, ); - harness - .a - .fsm - .connect_ik(harness.time(), &harness.a.crypto) - .unwrap(); + harness.connect_ik_a().unwrap(); let ik1 = harness.next_outbound_a().unwrap(); harness.deliver_to_b(ik1); let ik2 = harness.next_outbound_b().unwrap(); @@ -364,3 +342,25 @@ fn first_stream_data_uses_negotiated_initial_peer_credit() { [ql_wire::SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" )); } + +#[test] +fn session_timeout_emits_close_before_disconnect() { + let config = QlFsmConfig { + session_peer_timeout: Duration::from_millis(30), + ..QlFsmConfig::default() + }; + let mut harness = Harness::connected(config); + + harness.advance(config.session_peer_timeout); + harness.on_timer_a(); + + assert_eq!( + harness.drain_events_a(), + vec![ + QlFsmEvent::SessionClosed(SessionClose { + code: ql_wire::SessionCloseCode::TIMEOUT, + }), + QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected), + ] + ); +} From 23d8c4f719abc568ae980f435c676d7cc286e9db Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 07:34:47 -0400 Subject: [PATCH 102/304] ql-fsm: clean up useless bundle clone --- ql-fsm/src/implementation/handshake/ik.rs | 4 ++-- ql-fsm/src/implementation/handshake/kk.rs | 4 ++-- ql-fsm/src/implementation/handshake/mod.rs | 9 +++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index 02eea0d4..b785da24 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -67,7 +67,7 @@ pub fn handle_ik1( handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, &remote_bundle, emit)?; + finish_handshake(fsm, transport, remote_bundle, emit)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Ik2(outbound)); Ok(()) @@ -98,7 +98,7 @@ pub fn handle_ik2( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, &remote_bundle, emit) + finish_handshake(fsm, transport, remote_bundle, emit) } pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index b5d877c7..bf19dcb9 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -66,7 +66,7 @@ pub fn handle_kk1( handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, &remote_bundle, emit)?; + finish_handshake(fsm, transport, remote_bundle, emit)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Kk2(outbound)); Ok(()) @@ -97,7 +97,7 @@ pub fn handle_kk2( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, &remote_bundle, emit) + finish_handshake(fsm, transport, remote_bundle, emit) } pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 4e9d9a0d..233ce82a 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -98,22 +98,23 @@ pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { pub fn finish_handshake( fsm: &mut QlFsm, transport: SessionTransport, - remote_bundle: &wire::PeerBundle, + remote_bundle: wire::PeerBundle, emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { + let xid = remote_bundle.xid; if let Some(peer) = fsm.state.peer.as_ref() { - if peer != remote_bundle { + if peer != &remote_bundle { return Err(QlFsmError::InvalidPayload); } } else { - fsm.state.peer = Some(remote_bundle.clone()); + fsm.state.peer = Some(remote_bundle); emit(QlFsmEvent::NewPeer); } let config = &fsm.config; let session = SessionFsm::new( SessionFsmConfig { - local_parity: StreamParity::for_local(fsm.identity.xid, remote_bundle.xid), + local_parity: StreamParity::for_local(fsm.identity.xid, xid), record_max_size: config.session_record_max_size, ack_delay: config.session_record_ack_delay, retransmit_timeout: config.session_record_retransmit_timeout, From 6911842c95e138871b341932b081b38fc24bde58 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 10:02:39 -0400 Subject: [PATCH 103/304] ql-fsm: fixes --- ql-fsm/src/implementation/core.rs | 2 +- ql-fsm/src/session/mod.rs | 56 +++++++++------ ql-fsm/src/session/stream_tx.rs | 115 ++++++++++++++++++++++++++---- ql-fsm/src/session/tests.rs | 43 ++++++++++- 4 files changed, 178 insertions(+), 38 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 39e2f7b4..40b70640 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -106,7 +106,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { - if self.handle_stream_window(&frame, &mut emit).is_err() { - return; - } - } + SessionFrame::StreamWindow(frame) => self.handle_stream_window(&frame, &mut emit), SessionFrame::StreamClose(frame) => { if self.handle_stream_close(&frame, &mut emit).is_err() { return; @@ -403,15 +399,24 @@ impl SessionFsm { pub fn take_next_write( &mut self, now: Instant, - ) -> Option<(u64, RecordSeq, SessionRecordBuilder)> { + ) -> Option<(Option, RecordSeq, SessionRecordBuilder)> { self.state.now = now; self.collect_timeouts(); let (builder, outbound) = self.build_next_record()?; - let write_id = self.state.next_write_id; - self.state.next_write_id = self.state.next_write_id.wrapping_add(1); let seq = outbound.seq; - self.state.tracked_records.insert(write_id, outbound); + + let should_track = outbound.ping_included + || !outbound.window_updates.is_empty() + || !outbound.frames.is_empty(); + + let write_id = should_track.then(|| { + let write_id = self.state.next_write_id; + self.state.next_write_id = self.state.next_write_id.wrapping_add(1); + self.state.tracked_records.insert(write_id, outbound); + write_id + }); + Some((write_id, seq, builder)) } @@ -648,6 +653,9 @@ impl SessionFsm { Entry::Occupied(entry) => entry.into_mut(), Entry::Vacant(entry) => { if !self.config.local_parity.remote().matches(stream_id) { + if self.local_stream_was_opened(stream_id) { + return Ok(()); + } self.fail_session( SessionClose { code: SessionCloseCode::PROTOCOL, @@ -715,19 +723,9 @@ impl SessionFsm { } } - fn handle_stream_window( - &mut self, - frame: &StreamWindow, - emit: &mut impl FnMut(SessionEvent), - ) -> Result<(), ()> { + fn handle_stream_window(&mut self, frame: &StreamWindow, emit: &mut impl FnMut(SessionEvent)) { let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - emit, - ); - return Err(()); + return; }; let was_full = stream.send_capacity(self.config.stream_send_buffer_size) == 0; @@ -737,7 +735,6 @@ impl SessionFsm { if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { emit(SessionEvent::Writable(frame.stream_id)); } - Ok(()) } fn handle_stream_close( @@ -749,6 +746,9 @@ impl SessionFsm { Entry::Occupied(_) => false, Entry::Vacant(entry) => { if !self.config.local_parity.remote().matches(frame.stream_id) { + if self.local_stream_was_opened(frame.stream_id) { + return Ok(()); + } self.fail_session( SessionClose { code: SessionCloseCode::PROTOCOL, @@ -824,6 +824,17 @@ impl SessionFsm { matches!(target, CloseTarget::Both) || role.outbound_target() == target } + /// Returns true if this locally-opened stream id was already reaped, so stale peer frames for it can be ignored. + fn local_stream_was_opened(&self, stream_id: StreamId) -> bool { + self.config.local_parity.matches(stream_id) + && stream_id.0 + < self + .config + .local_parity + .make_stream_id(self.state.next_stream_ordinal) + .0 + } + fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { let tracked_refs_stream = self.state.tracked_records.values().any(|record| { record.window_updates.iter().any(|(id, _)| *id == stream_id) @@ -839,6 +850,7 @@ impl SessionFsm { if !stream.tx.is_empty() || stream.pending_close.is_some() + || stream.pending_window || stream.readable_bytes() > 0 || stream.rx.buffered_end_offset() > stream.rx.start_offset() { diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 364f9457..8b514d0b 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -155,7 +155,9 @@ impl StreamTx { self.set_segment_state(range.offset, range.len, SendState::InFlight); if range.fin { if let Some(final_offset) = self.final_offset.as_mut() { - final_offset.state = SendState::InFlight; + if final_offset.state != SendState::Acked { + final_offset.state = SendState::InFlight; + } } } } @@ -164,7 +166,9 @@ impl StreamTx { self.set_segment_state(range.offset, range.len, SendState::Lost); if range.fin { if let Some(final_offset) = self.final_offset.as_mut() { - final_offset.state = SendState::Lost; + if final_offset.state != SendState::Acked { + final_offset.state = SendState::Lost; + } } } } @@ -189,26 +193,45 @@ impl StreamTx { if len == 0 { return; } + let end = offset + len as u64; let Some(index) = self .segments .iter() - .position(|segment| segment.offset == offset && segment.len >= len) + .position(|segment| segment.offset <= offset && end <= segment.end_offset()) else { return; }; - if self.segments[index].len == len { - self.segments[index].state = state; - } else { - let segment = self.segments.remove(index).unwrap(); - self.segments - .insert(index, SendSegment { offset, len, state }); + if self.segments[index].state == SendState::Acked && state != SendState::Acked { + return; + } + + let segment = self.segments.remove(index).unwrap(); + let mut insert_index = index; + + if segment.offset < offset { + self.segments.insert( + insert_index, + SendSegment { + offset: segment.offset, + len: usize::try_from(offset - segment.offset).unwrap(), + state: segment.state, + }, + ); + insert_index += 1; + } + + self.segments + .insert(insert_index, SendSegment { offset, len, state }); + insert_index += 1; + + if end < segment.end_offset() { self.segments.insert( - index + 1, + insert_index, SendSegment { - offset: offset + len as u64, - len: segment.len - len, + offset: end, + len: usize::try_from(segment.end_offset() - end).unwrap(), state: segment.state, }, ); @@ -326,4 +349,72 @@ mod tests { tx.mark_acked(range); assert!(tx.is_empty()); } + + #[test] + fn subrange_updates_split_merged_in_flight_segments() { + let mut tx = StreamTx::new(); + tx.append(b"abcdefghijkl"); + + let first = tx.next_range(4, u64::MAX).unwrap(); + tx.mark_in_flight(first); + let second = tx.next_range(4, u64::MAX).unwrap(); + tx.mark_in_flight(second); + let third = tx.next_range(4, u64::MAX).unwrap(); + tx.mark_in_flight(third); + + tx.mark_lost(StreamTxRange { + offset: 4, + len: 4, + fin: false, + }); + + assert_eq!( + tx.next_range(4, u64::MAX), + Some(StreamTxRange { + offset: 4, + len: 4, + fin: false, + }) + ); + } + + #[test] + fn acked_subrange_is_not_reopened_by_stale_timeout() { + let mut tx = StreamTx::new(); + tx.append(b"abcdefghijklmnop"); + + let first = tx.next_range(4, u64::MAX).unwrap(); + tx.mark_in_flight(first); + let second = tx.next_range(4, u64::MAX).unwrap(); + tx.mark_in_flight(second); + let third = tx.next_range(4, u64::MAX).unwrap(); + tx.mark_in_flight(third); + let fourth = tx.next_range(4, u64::MAX).unwrap(); + tx.mark_in_flight(fourth); + + tx.mark_acked(StreamTxRange { + offset: 4, + len: 4, + fin: false, + }); + tx.mark_lost(StreamTxRange { + offset: 4, + len: 4, + fin: false, + }); + tx.mark_lost(StreamTxRange { + offset: 8, + len: 4, + fin: false, + }); + + assert_eq!( + tx.next_range(4, u64::MAX), + Some(StreamTxRange { + offset: 8, + len: 4, + fin: false, + }) + ); + } } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 0d2856f0..703d4d17 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -26,7 +26,9 @@ fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option<(RecordSeq, SessionRecord)> { let (write_id, seq, builder) = fsm.take_next_write(now)?; - fsm.confirm_write(now, write_id); + if let Some(write_id) = write_id { + fsm.confirm_write(now, write_id); + } Some((seq, SessionRecord::decode(builder.bytes()).unwrap())) } @@ -206,7 +208,10 @@ fn commit_stream_read_is_what_advances_stream_window() { ] ); - let (_first_seq, first) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + let (write_id, _first_seq, builder) = + fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let first = SessionRecord::decode(builder.bytes()).unwrap(); + assert!(write_id.is_none()); assert!(matches!(first.frames.as_slice(), [SessionFrame::Ack(_)])); let read = fsm @@ -226,6 +231,38 @@ fn commit_stream_read_is_what_advances_stream_window() { )); } +#[test] +fn pure_ack_only_records_are_fire_and_forget() { + let now = Instant::now(); + let config = SessionFsmConfig { + ack_delay: Duration::ZERO, + ..SessionFsmConfig::default() + }; + let retransmit_timeout = config.retransmit_timeout; + let mut fsm = SessionFsm::new(config, now); + let stream_id = StreamId(1); + let record = SessionRecord { + frames: vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: 0, + fin: false, + bytes: b"hi".to_vec(), + })], + }; + + let _ = receive_events(&mut fsm, now, RecordSeq(7), &record); + + let (write_id, _seq, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let ack = SessionRecord::decode(builder.bytes()).unwrap(); + assert!(write_id.is_none()); + assert!(matches!(ack.frames.as_slice(), [SessionFrame::Ack(_)])); + + fsm.on_timer(now + retransmit_timeout + Duration::from_millis(1), |_| {}); + assert!(fsm + .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) + .is_none()); +} + #[test] fn inbound_stream_data_emits_opened_and_readable() { let now = Instant::now(); @@ -262,7 +299,7 @@ fn remote_stream_close_is_reliable_and_retried() { .unwrap(); let (write_id, _seq, builder) = fsm.take_next_write(now).unwrap(); - fsm.confirm_write(now, write_id); + fsm.confirm_write(now, write_id.expect("stream close should be tracked")); let first = SessionRecord::decode(builder.bytes()).unwrap(); assert!(matches!( first.frames.as_slice(), From 7a46ceb559868545cf68691198ba6ec8923fef43 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 09:05:37 -0400 Subject: [PATCH 104/304] ql-runtime: update --- ql-runtime/src/command.rs | 11 +- ql-runtime/src/driver.rs | 458 ++++++++++++++++++------------ ql-runtime/src/handle.rs | 30 +- ql-runtime/src/lib.rs | 19 +- ql-runtime/src/platform.rs | 7 +- ql-runtime/src/tests/handshake.rs | 36 +-- ql-runtime/src/tests/mod.rs | 138 ++++++--- ql-runtime/src/tests/stream.rs | 9 +- 8 files changed, 407 insertions(+), 301 deletions(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 9686e85a..8a2eff0d 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,15 +1,13 @@ use crate::{ - wire::{CloseCode, CloseTarget}, - OpenedStreamDelivery, Peer, QlError, StreamId, + wire::{CloseTarget, StreamCloseCode}, + OpenedStreamDelivery, PeerBundle, QlError, StreamId, }; pub(crate) enum RuntimeCommand { BindPeer { - peer: Peer, + peer: PeerBundle, }, - Pair, Connect, - Unpair, OpenStream { request_reader: piper::Reader, start: oneshot::Sender>, @@ -20,8 +18,7 @@ pub(crate) enum RuntimeCommand { CloseStream { stream_id: StreamId, target: CloseTarget, - code: CloseCode, - payload: Vec, + code: StreamCloseCode, }, Incoming(Vec), } diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver.rs index 95b27a30..56e6050a 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver.rs @@ -1,18 +1,19 @@ use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, future::Future, - task::Poll, + task::{Context, Poll, Waker}, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; use futures_lite::future::poll_fn; -use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, QlSessionEvent, SessionWriteId}; +use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, SessionWriteId}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId, XID}; use crate::{ command::RuntimeCommand, handle::{ByteReader, ByteWriter, InboundStream}, platform::{PlatformFuture, QlPlatform}, - CloseCode, CloseTarget, InboundEvent, OpenedStreamDelivery, QlError, Runtime, StreamId, + InboundEvent, OpenedStreamDelivery, QlError, Runtime, }; struct InFlightWrite<'a> { @@ -50,33 +51,14 @@ impl OutboundIo { *self = Self::Closed; } - fn take_pending(&mut self) -> (Option>, bool) { - let Self::Open { - reader, - finish_queued, - } = self - else { - return (None, false); - }; - - let mut drained = None; - let available = reader.len(); - if available > 0 { - let mut bytes = vec![0; available]; - let read = reader.try_drain(&mut bytes); - if read > 0 { - bytes.truncate(read); - drained = Some(bytes); - } - } - - let mut finished = false; - if reader.is_closed() && !*finish_queued { - *finish_queued = true; - finished = true; + fn open_mut(&mut self) -> Option<(&mut piper::Reader, &mut bool)> { + match self { + Self::Open { + reader, + finish_queued, + } => Some((reader, finish_queued)), + Self::Closed => None, } - - (drained, finished) } } @@ -140,6 +122,26 @@ enum DriverStreamIo { } impl DriverStreamIo { + fn new_initiator( + request: piper::Reader, + response: async_channel::Sender, + ) -> Self { + Self::Initiator { + request: OutboundIo::new(request), + response: InboundIo::new(response), + } + } + + fn new_responder( + request: async_channel::Sender, + response: piper::Reader, + ) -> Self { + Self::Responder { + request: InboundIo::new(request), + response: OutboundIo::new(response), + } + } + fn outbound_mut(&mut self) -> &mut OutboundIo { match self { Self::Initiator { request, .. } => request, @@ -170,11 +172,15 @@ impl DriverStreamIo { fn fail_all(&mut self, error: QlError) { match self { - Self::Initiator { request, response } => { + Self::Initiator { + request, response, .. + } => { request.close(); response.fail(error); } - Self::Responder { request, response } => { + Self::Responder { + request, response, .. + } => { request.fail(error); response.close(); } @@ -183,80 +189,69 @@ impl DriverStreamIo { } struct DriverState { - fsm: QlFsm, streams: HashMap, runtime_tx: async_channel::Sender, stream_send_buffer_bytes: usize, max_concurrent_message_writes: usize, + peer_xid: Option, + pending_fsm_events: VecDeque, } impl DriverState { fn drive_command<'a, P: QlPlatform>( &mut self, + fsm: &mut QlFsm, command: RuntimeCommand, platform: &'a P, in_flight: &mut Vec>, ) { match command { RuntimeCommand::BindPeer { peer } => { - self.fsm.bind_peer(peer); - self.finish_step(platform, in_flight); - } - RuntimeCommand::Pair => { - let _ = self.fsm.pair(now(), platform); - self.finish_step(platform, in_flight); + self.peer_xid = Some(peer.xid); + fsm.bind_peer(peer); + self.finish_step(fsm, platform, in_flight); } RuntimeCommand::Connect => { - let _ = self.fsm.connect(now(), platform); - self.finish_step(platform, in_flight); - } - RuntimeCommand::Unpair => { - if let Some(record) = self.fsm.unpair(now(), platform) { - in_flight.push(InFlightWrite { - session_write_id: None, - future: platform.write_message(record.encode()), - }); - } - self.finish_step(platform, in_flight); + let _ = self.with_fsm_events(fsm, platform, |fsm, emit| { + fsm.connect_ik(now(), platform, emit) + }); + self.finish_step(fsm, platform, in_flight); } RuntimeCommand::Incoming(bytes) => { - // TODO: surface these errors somehow? - let _ = self.fsm.receive(now(), bytes, platform); - self.finish_step(platform, in_flight); + let _ = self.with_fsm_events(fsm, platform, |fsm, emit| { + fsm.receive(now(), bytes, platform, emit) + }); + self.finish_step(fsm, platform, in_flight); } RuntimeCommand::OpenStream { request_reader, start, - } => match self.fsm.open_stream().map_err(QlError::from) { + } => match fsm.open_stream().map_err(QlError::from) { Ok(stream_id) => { let (response_tx, response_rx) = async_channel::unbounded(); self.streams.insert( stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::new(request_reader), - response: InboundIo::new(response_tx), - }, + DriverStreamIo::new_initiator(request_reader, response_tx), ); let _ = start.send(Ok(OpenedStreamDelivery { stream_id, response: response_rx, })); - self.poll_stream(stream_id); - self.finish_step(platform, in_flight); + self.poll_stream(fsm, stream_id); + self.finish_step(fsm, platform, in_flight); } Err(error) => { let _ = start.send(Err(error)); } }, RuntimeCommand::PollStream { stream_id } => { - self.poll_stream(stream_id); - self.finish_step(platform, in_flight); + self.poll_stream(fsm, stream_id); + self.finish_step(fsm, platform, in_flight); } RuntimeCommand::CloseStream { stream_id, target, code, - payload, } => { if let Some(stream) = self.streams.get_mut(&stream_id) { if target == CloseTarget::Both || target == stream.inbound_target() { @@ -266,15 +261,16 @@ impl DriverState { stream.outbound_mut().close(); } } - let _ = self.fsm.close_stream(stream_id, target, code, payload); + let _ = fsm.close_stream(stream_id, target, code); self.try_reap_stream(stream_id); - self.finish_step(platform, in_flight); + self.finish_step(fsm, platform, in_flight); } } } fn drive_write_completed<'a, P: QlPlatform>( &mut self, + fsm: &mut QlFsm, session_write_id: Option, result: Result<(), QlError>, platform: &'a P, @@ -282,58 +278,84 @@ impl DriverState { ) { if let Some(write_id) = session_write_id { match result { - Ok(()) => self.fsm.confirm_session_write(now(), write_id), - Err(_) => self.fsm.reject_session_write(write_id), + Ok(()) => fsm.confirm_session_write(now(), write_id), + Err(_) => fsm.reject_session_write(write_id), } } - self.finish_step(platform, in_flight); + self.finish_step(fsm, platform, in_flight); } fn finish_step<'a, P: QlPlatform>( &mut self, + fsm: &mut QlFsm, platform: &'a P, in_flight: &mut Vec>, ) { - loop { - let mut progressed = false; + while self.fill_write_slots(fsm, platform, in_flight) {} + } - progressed |= self.drain_fsm(platform); - progressed |= self.fill_write_slots(platform, in_flight); + fn with_fsm_events( + &mut self, + fsm: &mut QlFsm, + platform: &P, + run: impl FnOnce(&mut QlFsm, &mut dyn FnMut(QlFsmEvent)) -> T, + ) -> T { + let output = { + let pending = &mut self.pending_fsm_events; + let mut emit = |event| pending.push_back(event); + run(fsm, &mut emit) + }; + self.process_pending_fsm_events(fsm, platform); + output + } - if !progressed { - break; - } + fn process_pending_fsm_events(&mut self, fsm: &mut QlFsm, platform: &P) { + while let Some(event) = self.pending_fsm_events.pop_front() { + self.process_fsm_event(fsm, platform, event); } } - fn drain_fsm(&mut self, platform: &P) -> bool { - let mut progressed = false; - - while let Some(event) = self.fsm.take_next_event() { - progressed = true; - match event { - QlFsmEvent::NewPeer(peer) => platform.persist_peer(peer), - QlFsmEvent::ClearPeer => platform.clear_peer(), - QlFsmEvent::PeerStatusChanged { peer, status } => { - platform.handle_peer_status(peer, status) + fn process_fsm_event( + &mut self, + fsm: &mut QlFsm, + platform: &P, + event: QlFsmEvent, + ) { + match event { + QlFsmEvent::NewPeer => { + if let Some(peer) = fsm.peer().cloned() { + self.peer_xid = Some(peer.xid); + platform.persist_peer(peer); } } - } - - while let Some(event) = self.fsm.take_next_session_event() { - progressed = true; - match event { - QlSessionEvent::Opened(stream_id) => self.handle_opened_stream(platform, stream_id), - QlSessionEvent::Readable(stream_id) => self.handle_inbound_readable(stream_id), - QlSessionEvent::Finished(stream_id) => self.handle_inbound_finished(stream_id), - QlSessionEvent::Closed(frame) => self.handle_closed_stream(frame), - QlSessionEvent::WritableClosed(stream_id) => self.handle_writable_closed(stream_id), - QlSessionEvent::Unpaired => self.fail_all_streams(QlError::Cancelled), - QlSessionEvent::SessionClosed(_) => self.fail_all_streams(QlError::SessionClosed), + QlFsmEvent::PeerStatusChanged(status) => { + if self.peer_xid.is_none() { + self.peer_xid = fsm.peer().map(|peer| peer.xid); + } + if let Some(peer) = self.peer_xid { + platform.handle_peer_status(peer, status); + } + } + QlFsmEvent::Opened(stream_id) => { + self.handle_opened_stream(platform, stream_id); + } + QlFsmEvent::Readable(stream_id) => { + self.handle_inbound_readable(fsm, stream_id); } + QlFsmEvent::Writable(stream_id) => { + self.poll_stream(fsm, stream_id); + } + QlFsmEvent::Finished(stream_id) => { + self.handle_inbound_finished(stream_id); + } + QlFsmEvent::Closed(frame) => { + self.handle_closed_stream(frame); + } + QlFsmEvent::WritableClosed(stream_id) => { + self.handle_writable_closed(stream_id); + } + QlFsmEvent::SessionClosed(_) => self.fail_all_streams(QlError::SessionClosed), } - - progressed } fn handle_opened_stream(&mut self, platform: &P, stream_id: StreamId) { @@ -342,10 +364,7 @@ impl DriverState { self.streams.insert( stream_id, - DriverStreamIo::Responder { - request: InboundIo::new(request_tx), - response: OutboundIo::new(response_reader), - }, + DriverStreamIo::new_responder(request_tx, response_reader), ); platform.handle_inbound(InboundStream { @@ -365,36 +384,40 @@ impl DriverState { }); } - fn handle_inbound_readable(&mut self, stream_id: StreamId) { + fn handle_inbound_readable(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { loop { - let max_len = self.fsm.config.session_stream_chunk_size.max(1); - let available = match self.fsm.stream_available_bytes(stream_id) { - Ok(available) => available, - Err(_) => return, + let Some(available) = fsm.stream_available_bytes(stream_id) else { + return; }; if available == 0 { break; } - let mut bytes = vec![0; available.min(max_len)]; - let read = match self.fsm.read_stream(stream_id, &mut bytes) { - Ok(read) => read, - Err(_) => return, + let bytes = { + let Some(chunks) = fsm.stream_read(stream_id) else { + return; + }; + let mut bytes = Vec::with_capacity(available); + for chunk in chunks { + bytes.extend_from_slice(chunk); + } + bytes }; - bytes.truncate(read); + + if bytes.is_empty() { + break; + } let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; let target = stream.inbound_target(); - let should_close = stream.inbound_mut().write_or_close(bytes); - if should_close { - let _ = self - .fsm - .close_stream(stream_id, target, CloseCode::CANCELLED, Vec::new()); + if stream.inbound_mut().write_or_close(bytes.clone()) { + let _ = fsm.close_stream(stream_id, target, StreamCloseCode(0)); self.try_reap_stream(stream_id); break; } + fsm.stream_read_commit(stream_id, bytes.len()).unwrap(); } } @@ -414,11 +437,10 @@ impl DriverState { let error = QlError::StreamClosed { target: frame.target, code: frame.code, - payload: frame.payload.clone(), }; if frame.target == CloseTarget::Both || frame.target == stream.inbound_target() { - stream.inbound_mut().fail(error); + stream.inbound_mut().fail(error.clone()); } if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { stream.outbound_mut().close(); @@ -443,39 +465,74 @@ impl DriverState { fn fill_write_slots<'a, P: QlPlatform>( &mut self, + fsm: &mut QlFsm, platform: &'a P, in_flight: &mut Vec>, ) -> bool { let mut progressed = false; while in_flight.len() < self.max_concurrent_message_writes { - let Some(write) = self.fsm.take_next_write(now(), platform) else { + let Some(write) = fsm.take_next_write(now(), platform) else { break; }; progressed = true; in_flight.push(InFlightWrite { session_write_id: write.session_write_id, - future: platform.write_message(write.record.encode()), + future: platform.write_message(write.record), }); } progressed } - fn poll_stream(&mut self, stream_id: StreamId) { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let (bytes, finished) = stream.outbound_mut().take_pending(); - if let Some(bytes) = bytes { - let _ = self.fsm.write_stream(stream_id, bytes); - } - if finished { - let _ = self.fsm.finish_stream(stream_id); - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.outbound_mut().close(); + fn poll_stream(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + loop { + let mut should_finish = false; + let progressed = { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let Some((reader, finish_queued)) = stream.outbound_mut().open_mut() else { + return; + }; + + let ready = with_noop_context(|cx| reader.poll(cx)); + if matches!(ready, Poll::Pending) { + false + } else { + let bytes = reader.peek_buf(); + if bytes.is_empty() { + if reader.is_closed() && reader.len() == 0 && !*finish_queued { + *finish_queued = true; + should_finish = true; + } + false + } else { + let len = bytes.len(); + let accepted = match fsm.write_stream(stream_id, bytes) { + Ok(accepted) => accepted, + Err(_) => 0, + }; + if accepted > 0 { + reader.consume(accepted); + } + accepted > 0 && accepted == len + } + } + }; + + if should_finish { + let _ = fsm.finish_stream(stream_id); + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.outbound_mut().close(); + } + self.try_reap_stream(stream_id); + break; + } + + if !progressed { + break; } - self.try_reap_stream(stream_id); } } @@ -484,12 +541,12 @@ impl DriverState { .streams .get(&stream_id) .is_some_and(|stream| match stream { - DriverStreamIo::Initiator { request, response } => { - matches!(request, OutboundIo::Closed) && matches!(response, InboundIo::Closed) - } - DriverStreamIo::Responder { request, response } => { - matches!(request, InboundIo::Closed) && matches!(response, OutboundIo::Closed) - } + DriverStreamIo::Initiator { + request, response, .. + } => matches!(request, OutboundIo::Closed) && matches!(response, InboundIo::Closed), + DriverStreamIo::Responder { + request, response, .. + } => matches!(request, InboundIo::Closed) && matches!(response, OutboundIo::Closed), }); if should_reap { self.streams.remove(&stream_id); @@ -548,34 +605,37 @@ impl Runtime

{ let runtime_tx = tx.upgrade().expect("runtime tx"); let mut fsm = QlFsm::new(config.fsm, identity, now()); + let mut peer_xid = None; if let Some(peer) = platform.load_peer().await { + peer_xid = Some(peer.xid); fsm.bind_peer(peer); } let mut state = DriverState { - fsm, streams: HashMap::new(), runtime_tx, stream_send_buffer_bytes: config.stream_send_buffer_bytes, max_concurrent_message_writes: config.max_concurrent_message_writes, + peer_xid, + pending_fsm_events: VecDeque::new(), }; let mut in_flight = Vec::new(); loop { - state.finish_step(&platform, &mut in_flight); + state.finish_step(&mut fsm, &platform, &mut in_flight); if rx.is_closed() && in_flight.is_empty() { break; } - match next_driver_event(&rx, &platform, state.fsm.next_deadline(), &mut in_flight).await - { + match next_driver_event(&rx, &platform, fsm.next_deadline(), &mut in_flight).await { DriverEvent::Command(command) => { - state.drive_command(command, &platform, &mut in_flight) + state.drive_command(&mut fsm, command, &platform, &mut in_flight) } DriverEvent::WriteCompleted { index, result } => { let write = in_flight.swap_remove(index); state.drive_write_completed( + &mut fsm, write.session_write_id, result, &platform, @@ -583,8 +643,10 @@ impl Runtime

{ ); } DriverEvent::TimerExpired => { - state.fsm.on_timer(now()); - state.finish_step(&platform, &mut in_flight); + state.with_fsm_events(&mut fsm, &platform, |fsm, emit| { + fsm.on_timer(now(), emit) + }); + state.finish_step(&mut fsm, &platform, &mut in_flight); } DriverEvent::CommandsClosed => {} } @@ -606,47 +668,82 @@ fn unix_now_secs() -> u64 { .as_secs() } +fn with_noop_context(f: impl FnOnce(&mut Context<'_>) -> T) -> T { + let mut cx = Context::from_waker(Waker::noop()); + f(&mut cx) +} + #[cfg(test)] mod tests { - use ql_fsm::Peer; - use ql_wire::{CloseCode, StreamClose, XID}; + use ql_wire::{ + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PeerBundle, QlAead, QlHash, + QlKem, QlRandom, SessionKey, StreamClose, XID, + }; use super::*; use crate::tests::new_identity; struct NoopPlatform; - impl ql_wire::QlCrypto for NoopPlatform { + impl QlRandom for NoopPlatform { fn fill_random_bytes(&self, data: &mut [u8]) { data.fill(0); } + } - fn hash(&self, _parts: &[&[u8]]) -> [u8; 32] { + impl QlHash for NoopPlatform { + fn sha256(&self, _parts: &[&[u8]]) -> [u8; 32] { [0; 32] } + } - fn encrypt_with_aead( + impl QlAead for NoopPlatform { + fn aes256_gcm_encrypt( &self, - _key: &ql_wire::SessionKey, + _key: &SessionKey, _nonce: &ql_wire::Nonce, _aad: &[u8], _buffer: &mut [u8], - ) -> [u8; ql_wire::EncryptedMessage::AUTH_SIZE] { - [0; ql_wire::EncryptedMessage::AUTH_SIZE] + ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { + [0; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] } - fn decrypt_with_aead( + fn aes256_gcm_decrypt( &self, - _key: &ql_wire::SessionKey, + _key: &SessionKey, _nonce: &ql_wire::Nonce, _aad: &[u8], _buffer: &mut [u8], - _auth_tag: &[u8; ql_wire::EncryptedMessage::AUTH_SIZE], + _auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], ) -> bool { false } } + impl QlKem for NoopPlatform { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new([0; MlKemPrivateKey::SIZE])), + public: MlKemPublicKey::new(Box::new([0; MlKemPublicKey::SIZE])), + } + } + + fn mlkem_encapsulate(&self, _public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + ( + MlKemCiphertext::new(Box::new([0; MlKemCiphertext::SIZE])), + SessionKey::from_data([0; SessionKey::SIZE]), + ) + } + + fn mlkem_decapsulate( + &self, + _private_key: &MlKemPrivateKey, + _ciphertext: &MlKemCiphertext, + ) -> SessionKey { + SessionKey::from_data([0; SessionKey::SIZE]) + } + } + impl QlPlatform for NoopPlatform { fn write_message(&self, _message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { Box::pin(async { Ok(()) }) @@ -656,33 +753,35 @@ mod tests { Box::pin(async {}) } - fn load_peer(&self) -> PlatformFuture<'_, Option> { + fn load_peer(&self) -> PlatformFuture<'_, Option> { Box::pin(async { None }) } - fn persist_peer(&self, _peer: Peer) {} - - fn clear_peer(&self) {} + fn persist_peer(&self, _peer: PeerBundle) {} fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} fn handle_inbound(&self, _event: InboundStream) {} } - fn new_driver_state() -> DriverState { + fn new_driver_state() -> (DriverState, QlFsm) { let (runtime_tx, _runtime_rx) = async_channel::unbounded(); - DriverState { - fsm: QlFsm::new(ql_fsm::QlFsmConfig::default(), new_identity(7), now()), - streams: HashMap::new(), - runtime_tx, - stream_send_buffer_bytes: 16, - max_concurrent_message_writes: 1, - } + ( + DriverState { + streams: HashMap::new(), + runtime_tx, + stream_send_buffer_bytes: 16, + max_concurrent_message_writes: 1, + peer_xid: None, + pending_fsm_events: VecDeque::new(), + }, + QlFsm::new(ql_fsm::QlFsmConfig::default(), new_identity(7), now()), + ) } #[test] fn handle_inbound_finished_reaps_closed_initiator_stream() { - let mut state = new_driver_state(); + let (mut state, _fsm) = new_driver_state(); let stream_id = StreamId(1); let (response_tx, _response_rx) = async_channel::unbounded(); @@ -701,7 +800,7 @@ mod tests { #[test] fn handle_closed_stream_reaps_when_both_halves_close() { - let mut state = new_driver_state(); + let (mut state, _fsm) = new_driver_state(); let stream_id = StreamId(2); let (request_tx, _request_rx) = async_channel::unbounded(); let (response_reader, _response_writer) = piper::pipe(1); @@ -717,8 +816,7 @@ mod tests { state.handle_closed_stream(StreamClose { stream_id, target: CloseTarget::Both, - code: CloseCode::CANCELLED, - payload: Vec::new(), + code: StreamCloseCode(0), }); assert!(!state.streams.contains_key(&stream_id)); @@ -726,7 +824,7 @@ mod tests { #[test] fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { - let mut state = new_driver_state(); + let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(3); let (request_reader, request_writer) = piper::pipe(1); @@ -739,14 +837,14 @@ mod tests { }, ); - state.poll_stream(stream_id); + state.poll_stream(&mut fsm, stream_id); assert!(!state.streams.contains_key(&stream_id)); } #[test] fn local_close_command_reaps_when_other_half_is_already_closed() { - let mut state = new_driver_state(); + let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(4); let (request_reader, _request_writer) = piper::pipe(1); let mut in_flight = Vec::new(); @@ -760,11 +858,11 @@ mod tests { ); state.drive_command( + &mut fsm, RuntimeCommand::CloseStream { stream_id, target: CloseTarget::Origin, - code: CloseCode::CANCELLED, - payload: Vec::new(), + code: StreamCloseCode(0), }, &NoopPlatform, &mut in_flight, diff --git a/ql-runtime/src/handle.rs b/ql-runtime/src/handle.rs index 40707ef7..5b14e373 100644 --- a/ql-runtime/src/handle.rs +++ b/ql-runtime/src/handle.rs @@ -4,8 +4,8 @@ use async_channel::{Receiver, Sender}; use futures_lite::{future::poll_fn, Stream}; use crate::{ - command::RuntimeCommand, CloseCode, CloseTarget, InboundEvent, OpenedStreamDelivery, Peer, - QlError, StreamId, + command::RuntimeCommand, CloseTarget, InboundEvent, OpenedStreamDelivery, PeerBundle, QlError, + StreamCloseCode, StreamId, }; #[derive(Clone)] @@ -114,7 +114,7 @@ impl ByteReader { } } - pub async fn close(mut self, code: CloseCode, payload: Vec) -> Result<(), QlError> { + pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { if self.finished { return Ok(()); } @@ -124,7 +124,6 @@ impl ByteReader { stream_id: self.stream_id, target: self.target, code, - payload, }) .await .map_err(|_| QlError::Cancelled) @@ -139,8 +138,7 @@ impl Drop for ByteReader { let _ = self.tx.try_send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, - code: CloseCode::CANCELLED, - payload: Vec::new(), + code: StreamCloseCode(0), }); } } @@ -201,7 +199,7 @@ impl ByteWriter { self.poll_runtime() } - pub async fn close(mut self, code: CloseCode, payload: Vec) -> Result<(), QlError> { + pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { if self.writer.take().is_none() { return Ok(()); } @@ -210,7 +208,6 @@ impl ByteWriter { stream_id: self.stream_id, target: self.target, code, - payload, }) .await .map_err(|_| QlError::Cancelled) @@ -225,35 +222,22 @@ impl Drop for ByteWriter { let _ = self.tx.try_send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, - code: CloseCode::CANCELLED, - payload: Vec::new(), + code: StreamCloseCode(0), }); } } impl RuntimeHandle { - pub fn bind_peer(&self, peer: Peer) { + pub fn bind_peer(&self, peer: PeerBundle) { self.send(RuntimeCommand::BindPeer { peer }) } - pub fn pair(&self) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Pair) - .map_err(|_| QlError::Cancelled) - } - pub fn connect(&self) -> Result<(), QlError> { self.tx .send_blocking(RuntimeCommand::Connect) .map_err(|_| QlError::Cancelled) } - pub fn unpair(&self) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Unpair) - .map_err(|_| QlError::Cancelled) - } - pub fn send_incoming(&self, bytes: Vec) { self.send(RuntimeCommand::Incoming(bytes)) } diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index fcc93cc3..dcbffef6 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,6 +1,9 @@ pub use handle::{ByteReader, ByteWriter, InboundStream, OutboundStream, RuntimeHandle}; -pub use ql_fsm::{Peer, PeerStatus, QlFsmConfig, QlFsmError, SessionWriteId}; -pub use ql_wire::{self as wire, CloseCode, CloseTarget, QlIdentity, StreamId, XID}; +pub use ql_fsm::{PeerStatus, QlFsmConfig, QlFsmError, SessionWriteId}; +pub use ql_wire::{ + self as wire, CloseTarget, PeerBundle, QlIdentity, SessionCloseCode, StreamCloseCode, StreamId, + XID, +}; pub(crate) mod command; pub(crate) mod driver; @@ -20,8 +23,8 @@ use self::platform::QlPlatform; pub enum QlError { #[error("invalid payload")] InvalidPayload, - #[error("invalid signature")] - InvalidSignature, + #[error("invalid state")] + InvalidState, #[error("expired")] Expired, #[error("decryption failed")] @@ -32,6 +35,8 @@ pub enum QlError { MissingStream, #[error("stream is not writable")] NotWritable, + #[error("invalid read")] + InvalidRead, #[error("session is closed")] SessionClosed, #[error("no peer bound")] @@ -43,8 +48,7 @@ pub enum QlError { #[error("stream closed {code:?}")] StreamClosed { target: CloseTarget, - code: CloseCode, - payload: Vec, + code: StreamCloseCode, }, #[error("cancelled")] Cancelled, @@ -54,12 +58,13 @@ impl From for QlError { fn from(value: QlFsmError) -> Self { match value { QlFsmError::InvalidPayload => Self::InvalidPayload, - QlFsmError::InvalidSignature => Self::InvalidSignature, + QlFsmError::InvalidState => Self::InvalidState, QlFsmError::Expired => Self::Expired, QlFsmError::DecryptFailed => Self::DecryptFailed, QlFsmError::InvalidXid => Self::InvalidXid, QlFsmError::MissingStream => Self::MissingStream, QlFsmError::NotWritable => Self::NotWritable, + QlFsmError::InvalidRead => Self::InvalidRead, QlFsmError::SessionClosed => Self::SessionClosed, QlFsmError::NoPeerBound => Self::NoPeerBound, QlFsmError::NoSession => Self::NoSession, diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index bb690242..7bdeaebb 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -2,7 +2,7 @@ use std::{future::Future, pin::Pin, time::Duration}; use ql_wire::QlCrypto; -use crate::{Peer, PeerStatus, QlError, XID}; +use crate::{PeerBundle, PeerStatus, QlError, XID}; pub type PlatformFuture<'a, T> = Pin + 'a>>; @@ -10,9 +10,8 @@ pub trait QlPlatform: QlCrypto { fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; - fn load_peer(&self) -> PlatformFuture<'_, Option>; - fn persist_peer(&self, peer: Peer); - fn clear_peer(&self); + fn load_peer(&self) -> PlatformFuture<'_, Option>; + fn persist_peer(&self, peer: PeerBundle); fn handle_peer_status(&self, peer: XID, status: PeerStatus); fn handle_inbound(&self, event: super::InboundStream); diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 7db4cef6..8b5e5f16 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -30,11 +30,11 @@ async fn connect_round_trip_changes_peer_status() { } #[tokio::test(flavor = "current_thread")] -async fn opening_stream_auto_connects() { +async fn opening_stream_requires_connection() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); + let (platform_b, _outbound_b, _status_b, _inbound_b) = TestPlatform::new_with_inbound(2); let identity_a = new_identity(11); let identity_b = new_identity(73); @@ -44,33 +44,11 @@ async fn opening_stream_auto_connects() { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - - let responder = tokio::task::spawn_local(async move { - let stream = inbound_b.recv().await.unwrap(); - let request = read_all(stream.request).await.unwrap(); - stream.response.finish().await.unwrap(); - request - }); - - let mut stream = handle_a.open_stream().await.unwrap(); - stream.request.write_all(b"auto-connect").await.unwrap(); - stream.request.finish().await.unwrap(); - assert_eq!(stream.response.next_chunk().await.unwrap(), None); - - assert_eq!( - tokio::time::timeout(Duration::from_secs(2), responder) - .await - .unwrap() - .unwrap(), - b"auto-connect".to_vec() - ); - - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + assert!(matches!( + handle_a.open_stream().await, + Err(QlError::NoSession) + )); }) .await; } diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index dfcda3a3..47a56c32 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -11,14 +11,15 @@ use std::{ use async_channel::{Receiver, Sender}; use libcrux_aesgcm::AesGcm256Key; use ql_wire::{ - generate_ml_dsa_keypair, generate_ml_kem_keypair, EncryptedMessage, Nonce, QlCrypto, - QlIdentity, QlPayload, QlRecord, SessionKey, XID, + generate_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, + PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, RecordType, SessionKey, + WireParse, XID, }; use sha2::{Digest, Sha256}; use tokio::task::LocalSet; use crate::{ - new_runtime, platform::PlatformFuture, InboundStream, Peer, PeerStatus, QlError, QlFsmConfig, + new_runtime, platform::PlatformFuture, InboundStream, PeerStatus, QlError, QlFsmConfig, RuntimeConfig, RuntimeHandle, }; @@ -27,13 +28,11 @@ mod heartbeat; #[cfg(feature = "rpc")] mod rpc; mod stream; -mod unpair; #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum PeerStage { Disconnected, Initiator, - Responder, Connected, } @@ -76,31 +75,35 @@ impl DeterministicCrypto { } } -impl QlCrypto for DeterministicCrypto { +impl QlRandom for DeterministicCrypto { fn fill_random_bytes(&self, data: &mut [u8]) { let value = self.seed.wrapping_add(self.counter.get()); self.counter.set(self.counter.get().wrapping_add(1)); data.fill(value); } +} - fn hash(&self, parts: &[&[u8]]) -> [u8; 32] { +impl QlHash for DeterministicCrypto { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { let mut hasher = Sha256::new(); for part in parts { hasher.update(part); } hasher.finalize().into() } +} - fn encrypt_with_aead( +impl QlAead for DeterministicCrypto { + fn aes256_gcm_encrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - ) -> [u8; EncryptedMessage::AUTH_SIZE] { + ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); let plaintext = buffer.to_vec(); - let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; + let mut auth = [0u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE]; key.encrypt( buffer, (&mut auth).into(), @@ -112,13 +115,13 @@ impl QlCrypto for DeterministicCrypto { auth } - fn decrypt_with_aead( + fn aes256_gcm_decrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], ) -> bool { let key: AesGcm256Key = (*key.data()).into(); let ciphertext = buffer.to_vec(); @@ -127,6 +130,35 @@ impl QlCrypto for DeterministicCrypto { } } +impl QlKem for DeterministicCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + let data = Box::new([self.seed; MlKemPublicKey::SIZE]); + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new([self.seed; MlKemPrivateKey::SIZE])), + public: MlKemPublicKey::new(data), + } + } + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + let mut secret = [0u8; SessionKey::SIZE]; + secret.copy_from_slice(&public_key.as_bytes()[..SessionKey::SIZE]); + ( + MlKemCiphertext::new(Box::new([self.seed; MlKemCiphertext::SIZE])), + SessionKey::from_data(secret), + ) + } + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + _ciphertext: &MlKemCiphertext, + ) -> SessionKey { + let mut secret = [0u8; SessionKey::SIZE]; + secret.copy_from_slice(&private_key.as_bytes()[..SessionKey::SIZE]); + SessionKey::from_data(secret) + } +} + struct TestPlatform { outbound: Sender>, status: Sender, @@ -206,32 +238,36 @@ impl TestPlatform { } } -impl QlCrypto for TestPlatform { +impl QlRandom for TestPlatform { fn fill_random_bytes(&self, data: &mut [u8]) { let value = self .nonce_seed .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); data.fill(value); } +} - fn hash(&self, parts: &[&[u8]]) -> [u8; 32] { +impl QlHash for TestPlatform { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { let mut hasher = Sha256::new(); for part in parts { hasher.update(part); } hasher.finalize().into() } +} - fn encrypt_with_aead( +impl QlAead for TestPlatform { + fn aes256_gcm_encrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - ) -> [u8; EncryptedMessage::AUTH_SIZE] { + ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); let plaintext = buffer.to_vec(); - let mut auth = [0u8; EncryptedMessage::AUTH_SIZE]; + let mut auth = [0u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE]; key.encrypt( buffer, (&mut auth).into(), @@ -243,13 +279,13 @@ impl QlCrypto for TestPlatform { auth } - fn decrypt_with_aead( + fn aes256_gcm_decrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], buffer: &mut [u8], - auth_tag: &[u8; EncryptedMessage::AUTH_SIZE], + auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], ) -> bool { let key: AesGcm256Key = (*key.data()).into(); let ciphertext = buffer.to_vec(); @@ -258,6 +294,35 @@ impl QlCrypto for TestPlatform { } } +impl QlKem for TestPlatform { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + let byte = self.nonce_seed; + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new([byte; MlKemPrivateKey::SIZE])), + public: MlKemPublicKey::new(Box::new([byte; MlKemPublicKey::SIZE])), + } + } + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + let mut secret = [0u8; SessionKey::SIZE]; + secret.copy_from_slice(&public_key.as_bytes()[..SessionKey::SIZE]); + ( + MlKemCiphertext::new(Box::new([self.nonce_seed; MlKemCiphertext::SIZE])), + SessionKey::from_data(secret), + ) + } + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + _ciphertext: &MlKemCiphertext, + ) -> SessionKey { + let mut secret = [0u8; SessionKey::SIZE]; + secret.copy_from_slice(&private_key.as_bytes()[..SessionKey::SIZE]); + SessionKey::from_data(secret) + } +} + impl crate::platform::QlPlatform for TestPlatform { fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { let outbound = self.outbound.clone(); @@ -302,19 +367,16 @@ impl crate::platform::QlPlatform for TestPlatform { Box::pin(tokio::time::sleep(duration)) } - fn load_peer(&self) -> PlatformFuture<'_, Option> { + fn load_peer(&self) -> PlatformFuture<'_, Option> { Box::pin(async { None }) } - fn persist_peer(&self, _peer: Peer) {} - - fn clear_peer(&self) {} + fn persist_peer(&self, _peer: PeerBundle) {} fn handle_peer_status(&self, peer: XID, status: PeerStatus) { let stage = match status { PeerStatus::Disconnected => PeerStage::Disconnected, PeerStatus::Initiator => PeerStage::Initiator, - PeerStatus::Responder => PeerStage::Responder, PeerStatus::Connected => PeerStage::Connected, }; let _ = self.status.try_send(StatusEvent { peer, stage }); @@ -328,30 +390,14 @@ impl crate::platform::QlPlatform for TestPlatform { } fn is_encrypted_payload(bytes: &[u8]) -> bool { - QlRecord::decode(bytes) + RecordHeader::parse_prefix(bytes) .ok() - .is_some_and(|record| matches!(record.payload, QlPayload::Session(_))) + .is_some_and(|header| header.record_type == RecordType::Session) } pub(crate) fn new_identity(seed: u8) -> QlIdentity { let crypto = DeterministicCrypto::new(seed); - let (signing_private, signing_public) = generate_ml_dsa_keypair(&crypto); - let (encapsulation_private, encapsulation_public) = generate_ml_kem_keypair(&crypto); - QlIdentity::new( - XID([seed; XID::SIZE]), - signing_private, - signing_public, - encapsulation_private, - encapsulation_public, - ) -} - -fn peer_from_identity(identity: &QlIdentity) -> Peer { - Peer { - xid: identity.xid, - signing_key: identity.signing_public_key.clone(), - encapsulation_key: identity.encapsulation_public_key.clone(), - } + generate_identity(&crypto, XID([seed; XID::SIZE])) } fn register_peers( @@ -360,8 +406,8 @@ fn register_peers( id_a: &QlIdentity, id_b: &QlIdentity, ) { - handle_a.bind_peer(peer_from_identity(id_b)); - handle_b.bind_peer(peer_from_identity(id_a)); + handle_a.bind_peer(id_b.bundle()); + handle_b.bind_peer(id_a.bundle()); } fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { @@ -458,7 +504,7 @@ fn default_runtime_config() -> RuntimeConfig { RuntimeConfig { fsm: QlFsmConfig { handshake_timeout: Duration::from_millis(300), - session_retransmit_timeout: Duration::from_millis(30), + session_record_retransmit_timeout: Duration::from_millis(30), session_keepalive_interval: Duration::ZERO, session_peer_timeout: Duration::ZERO, ..Default::default() diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index b750bf4a..78c22fa2 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -1,7 +1,7 @@ use std::time::Duration; use super::*; -use crate::{CloseCode, CloseTarget}; +use crate::{CloseTarget, StreamCloseCode}; #[tokio::test(flavor = "current_thread")] async fn open_stream_duplex_happy_path() { @@ -153,9 +153,8 @@ async fn dropping_responder_closes_initiator_response() { err, QlError::StreamClosed { target: CloseTarget::Return, - code: CloseCode::CANCELLED, - payload, - } if payload.is_empty() + code, + } if code == StreamCloseCode(0) )); tokio::time::timeout(Duration::from_secs(2), responder) @@ -296,7 +295,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { run_local_test(async { let config = RuntimeConfig { fsm: QlFsmConfig { - session_retransmit_timeout: Duration::from_millis(20), + session_record_retransmit_timeout: Duration::from_millis(20), ..default_runtime_config().fsm }, stream_send_buffer_bytes: 4, From fe1eaf1c9cec08114a37fc5a1db2991741788712 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 11:12:07 -0400 Subject: [PATCH 105/304] ql-runtime: piper for reader --- ql-runtime/src/command.rs | 3 + ql-runtime/src/{driver.rs => driver/mod.rs} | 662 ++++++-------------- ql-runtime/src/driver/state.rs | 202 ++++++ ql-runtime/src/driver/test.rs | 199 ++++++ ql-runtime/src/handle.rs | 288 --------- ql-runtime/src/handle/mod.rs | 87 +++ ql-runtime/src/handle/reader.rs | 127 ++++ ql-runtime/src/handle/writer.rs | 105 ++++ ql-runtime/src/lib.rs | 9 +- ql-runtime/src/tests/handshake.rs | 2 +- ql-runtime/src/tests/heartbeat.rs | 10 +- ql-runtime/src/tests/mod.rs | 18 +- ql-runtime/src/tests/stream.rs | 27 +- 13 files changed, 941 insertions(+), 798 deletions(-) rename ql-runtime/src/{driver.rs => driver/mod.rs} (58%) create mode 100644 ql-runtime/src/driver/state.rs create mode 100644 ql-runtime/src/driver/test.rs delete mode 100644 ql-runtime/src/handle.rs create mode 100644 ql-runtime/src/handle/mod.rs create mode 100644 ql-runtime/src/handle/reader.rs create mode 100644 ql-runtime/src/handle/writer.rs diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 8a2eff0d..0c9442dc 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -12,6 +12,9 @@ pub(crate) enum RuntimeCommand { request_reader: piper::Reader, start: oneshot::Sender>, }, + PollInbound { + stream_id: StreamId, + }, PollStream { stream_id: StreamId, }, diff --git a/ql-runtime/src/driver.rs b/ql-runtime/src/driver/mod.rs similarity index 58% rename from ql-runtime/src/driver.rs rename to ql-runtime/src/driver/mod.rs index 56e6050a..da71e1b2 100644 --- a/ql-runtime/src/driver.rs +++ b/ql-runtime/src/driver/mod.rs @@ -1,3 +1,7 @@ +mod state; +#[cfg(test)] +mod test; + use std::{ collections::{HashMap, VecDeque}, future::Future, @@ -7,194 +11,129 @@ use std::{ use futures_lite::future::poll_fn; use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, SessionWriteId}; -use ql_wire::{CloseTarget, StreamCloseCode, StreamId, XID}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; +use self::state::*; use crate::{ command::RuntimeCommand, handle::{ByteReader, ByteWriter, InboundStream}, platform::{PlatformFuture, QlPlatform}, - InboundEvent, OpenedStreamDelivery, QlError, Runtime, + OpenedStreamDelivery, QlError, Runtime, }; -struct InFlightWrite<'a> { - session_write_id: Option, - future: PlatformFuture<'a, Result<(), QlError>>, -} - -enum DriverEvent { - Command(RuntimeCommand), - WriteCompleted { - index: usize, - result: Result<(), QlError>, - }, - TimerExpired, - CommandsClosed, -} - -enum OutboundIo { - Open { - reader: piper::Reader, - finish_queued: bool, - }, - Closed, -} - -impl OutboundIo { - fn new(reader: piper::Reader) -> Self { - Self::Open { - reader, - finish_queued: false, - } - } - - fn close(&mut self) { - *self = Self::Closed; - } - - fn open_mut(&mut self) -> Option<(&mut piper::Reader, &mut bool)> { - match self { - Self::Open { - reader, - finish_queued, - } => Some((reader, finish_queued)), - Self::Closed => None, - } - } -} - -enum InboundIo { - Open(async_channel::Sender), - Closed, -} - -impl InboundIo { - fn new(tx: async_channel::Sender) -> Self { - Self::Open(tx) - } +impl Runtime

{ + pub async fn run(self) { + let Runtime { + identity, + platform, + config, + rx, + tx, + } = self; - fn close(&mut self) { - if let Self::Open(tx) = self { - tx.close(); + let runtime_tx = tx.upgrade().expect("runtime tx"); + let mut fsm = QlFsm::new(config.fsm, identity, now()); + let mut peer_xid = None; + if let Some(peer) = platform.load_peer().await { + peer_xid = Some(peer.xid); + fsm.bind_peer(peer); } - *self = Self::Closed; - } - fn write_or_close(&mut self, bytes: Vec) -> bool { - let Self::Open(tx) = self else { - return true; + let mut state = DriverState { + streams: HashMap::new(), + runtime_tx, + stream_send_buffer_bytes: config.stream_send_buffer_bytes, + max_concurrent_message_writes: config.max_concurrent_message_writes, + peer_xid, + pending_fsm_events: VecDeque::new(), }; + let mut in_flight = Vec::new(); - if tx.try_send(InboundEvent::Data(bytes)).is_err() { - tx.close(); - *self = Self::Closed; - return true; - } + loop { + state.finish_step(&mut fsm, &platform, &mut in_flight); - false - } + if rx.is_closed() && in_flight.is_empty() { + break; + } - fn finish(&mut self) { - if let Self::Open(tx) = self { - let _ = tx.try_send(InboundEvent::Finished); - tx.close(); + match next_driver_event(&rx, &platform, fsm.next_deadline(), &mut in_flight).await { + DriverEvent::Command(command) => { + state.drive_command(&mut fsm, command, &platform, &mut in_flight) + } + DriverEvent::WriteCompleted { index, result } => { + let write = in_flight.swap_remove(index); + state.drive_write_completed( + &mut fsm, + write.session_write_id, + result, + &platform, + &mut in_flight, + ); + } + DriverEvent::TimerExpired => { + state.with_fsm_events(&mut fsm, &platform, |fsm, emit| { + fsm.on_timer(now(), emit) + }); + state.finish_step(&mut fsm, &platform, &mut in_flight); + } + DriverEvent::CommandsClosed => {} + } } - *self = Self::Closed; } +} - fn fail(&mut self, error: QlError) { - if let Self::Open(tx) = self { - let _ = tx.try_send(InboundEvent::Failed(error)); - tx.close(); - } - *self = Self::Closed; - } +struct InFlightWrite<'a> { + session_write_id: Option, + future: PlatformFuture<'a, Result<(), QlError>>, } -enum DriverStreamIo { - Initiator { - request: OutboundIo, - response: InboundIo, - }, - Responder { - request: InboundIo, - response: OutboundIo, +enum DriverEvent { + Command(RuntimeCommand), + WriteCompleted { + index: usize, + result: Result<(), QlError>, }, + TimerExpired, + CommandsClosed, } -impl DriverStreamIo { - fn new_initiator( - request: piper::Reader, - response: async_channel::Sender, - ) -> Self { - Self::Initiator { - request: OutboundIo::new(request), - response: InboundIo::new(response), - } - } - - fn new_responder( - request: async_channel::Sender, - response: piper::Reader, - ) -> Self { - Self::Responder { - request: InboundIo::new(request), - response: OutboundIo::new(response), - } - } - - fn outbound_mut(&mut self) -> &mut OutboundIo { - match self { - Self::Initiator { request, .. } => request, - Self::Responder { response, .. } => response, - } - } - - fn inbound_mut(&mut self) -> &mut InboundIo { - match self { - Self::Initiator { response, .. } => response, - Self::Responder { request, .. } => request, - } - } +async fn next_driver_event( + rx: &async_channel::Receiver, + platform: &P, + next_timer: Option, + in_flight: &mut [InFlightWrite<'_>], +) -> DriverEvent { + let mut recv_future = (!rx.is_closed()).then(|| Box::pin(rx.recv())); + let mut sleep_future = next_timer.map(|deadline| { + let timeout = deadline.saturating_duration_since(Instant::now()); + platform.sleep(timeout) + }); - fn inbound_target(&self) -> CloseTarget { - match self { - Self::Initiator { .. } => CloseTarget::Return, - Self::Responder { .. } => CloseTarget::Origin, + poll_fn(|cx| { + for (index, write) in in_flight.iter_mut().enumerate() { + if let Poll::Ready(result) = write.future.as_mut().poll(cx) { + return Poll::Ready(DriverEvent::WriteCompleted { index, result }); + } } - } - fn outbound_target(&self) -> CloseTarget { - match self { - Self::Initiator { .. } => CloseTarget::Origin, - Self::Responder { .. } => CloseTarget::Return, + if let Some(future) = sleep_future.as_mut() { + if let Poll::Ready(()) = future.as_mut().poll(cx) { + return Poll::Ready(DriverEvent::TimerExpired); + } } - } - fn fail_all(&mut self, error: QlError) { - match self { - Self::Initiator { - request, response, .. - } => { - request.close(); - response.fail(error); - } - Self::Responder { - request, response, .. - } => { - request.fail(error); - response.close(); + if let Some(future) = recv_future.as_mut() { + if let Poll::Ready(res) = future.as_mut().poll(cx) { + return Poll::Ready(match res { + Ok(command) => DriverEvent::Command(command), + Err(_) => DriverEvent::CommandsClosed, + }); } } - } -} -struct DriverState { - streams: HashMap, - runtime_tx: async_channel::Sender, - stream_send_buffer_bytes: usize, - max_concurrent_message_writes: usize, - peer_xid: Option, - pending_fsm_events: VecDeque, + Poll::Pending + }) + .await } impl DriverState { @@ -228,14 +167,26 @@ impl DriverState { start, } => match fsm.open_stream().map_err(QlError::from) { Ok(stream_id) => { - let (response_tx, response_rx) = async_channel::unbounded(); + let (response_reader, response_writer) = + piper::pipe(self.stream_send_buffer_bytes); + let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); self.streams.insert( stream_id, - DriverStreamIo::new_initiator(request_reader, response_tx), + DriverStreamIo::new_initiator( + request_reader, + response_writer, + response_terminal_tx, + ), ); let _ = start.send(Ok(OpenedStreamDelivery { stream_id, - response: response_rx, + response: ByteReader::new( + stream_id, + CloseTarget::Return, + response_reader, + response_terminal_rx, + self.runtime_tx.clone(), + ), })); self.poll_stream(fsm, stream_id); self.finish_step(fsm, platform, in_flight); @@ -244,6 +195,10 @@ impl DriverState { let _ = start.send(Err(error)); } }, + RuntimeCommand::PollInbound { stream_id } => { + self.handle_inbound_readable(fsm, stream_id); + self.finish_step(fsm, platform, in_flight); + } RuntimeCommand::PollStream { stream_id } => { self.poll_stream(fsm, stream_id); self.finish_step(fsm, platform, in_flight); @@ -346,7 +301,7 @@ impl DriverState { self.poll_stream(fsm, stream_id); } QlFsmEvent::Finished(stream_id) => { - self.handle_inbound_finished(stream_id); + self.handle_inbound_finished(fsm, stream_id); } QlFsmEvent::Closed(frame) => { self.handle_closed_stream(frame); @@ -359,12 +314,13 @@ impl DriverState { } fn handle_opened_stream(&mut self, platform: &P, stream_id: StreamId) { - let (request_tx, request_rx) = async_channel::unbounded(); + let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); + let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); let (response_reader, response_writer) = piper::pipe(self.stream_send_buffer_bytes); self.streams.insert( stream_id, - DriverStreamIo::new_responder(request_tx, response_reader), + DriverStreamIo::new_responder(request_writer, request_terminal_tx, response_reader), ); platform.handle_inbound(InboundStream { @@ -372,7 +328,8 @@ impl DriverState { request: ByteReader::new( stream_id, CloseTarget::Origin, - request_rx, + request_reader, + request_terminal_rx, self.runtime_tx.clone(), ), response: ByteWriter::new( @@ -386,45 +343,81 @@ impl DriverState { fn handle_inbound_readable(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { loop { - let Some(available) = fsm.stream_available_bytes(stream_id) else { + let Some(_) = fsm.stream_available_bytes(stream_id) else { return; }; - if available == 0 { - break; - } - - let bytes = { + let mut accepted = 0usize; + let mut blocked = false; + let mut peer_closed = false; + let target; + { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + target = stream.inbound_target(); let Some(chunks) = fsm.stream_read(stream_id) else { return; }; - let mut bytes = Vec::with_capacity(available); for chunk in chunks { - bytes.extend_from_slice(chunk); + if chunk.is_empty() { + continue; + } + match stream.inbound_mut().try_write(chunk) { + InboundWriteResult::Accepted(n) => { + accepted += n; + if n < chunk.len() { + blocked = true; + break; + } + } + InboundWriteResult::Full => { + blocked = true; + break; + } + InboundWriteResult::Closed => { + peer_closed = true; + break; + } + } } - bytes - }; - - if bytes.is_empty() { - break; } - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let target = stream.inbound_target(); - if stream.inbound_mut().write_or_close(bytes.clone()) { + if accepted > 0 { + fsm.stream_read_commit(stream_id, accepted).unwrap(); + } + if peer_closed { let _ = fsm.close_stream(stream_id, target, StreamCloseCode(0)); self.try_reap_stream(stream_id); break; } - fsm.stream_read_commit(stream_id, bytes.len()).unwrap(); + if accepted == 0 || blocked { + break; + } } + + self.finish_inbound_if_ready(fsm, stream_id); } - fn handle_inbound_finished(&mut self, stream_id: StreamId) { + fn handle_inbound_finished(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; + stream.inbound_mut().queue_finish(); + self.finish_inbound_if_ready(fsm, stream_id); + } + + fn finish_inbound_if_ready(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + if fsm.stream_available_bytes(stream_id).unwrap_or(0) != 0 { + return; + } + + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + if !stream.inbound_mut().finish_pending() { + return; + } + stream.inbound_mut().finish(); self.try_reap_stream(stream_id); } @@ -554,106 +547,6 @@ impl DriverState { } } -async fn next_driver_event( - rx: &async_channel::Receiver, - platform: &P, - next_timer: Option, - in_flight: &mut [InFlightWrite<'_>], -) -> DriverEvent { - let mut recv_future = (!rx.is_closed()).then(|| Box::pin(rx.recv())); - let mut sleep_future = next_timer.map(|deadline| { - let timeout = deadline.saturating_duration_since(Instant::now()); - platform.sleep(timeout) - }); - - poll_fn(|cx| { - for (index, write) in in_flight.iter_mut().enumerate() { - if let Poll::Ready(result) = write.future.as_mut().poll(cx) { - return Poll::Ready(DriverEvent::WriteCompleted { index, result }); - } - } - - if let Some(future) = sleep_future.as_mut() { - if let Poll::Ready(()) = future.as_mut().poll(cx) { - return Poll::Ready(DriverEvent::TimerExpired); - } - } - - if let Some(future) = recv_future.as_mut() { - if let Poll::Ready(res) = future.as_mut().poll(cx) { - return Poll::Ready(match res { - Ok(command) => DriverEvent::Command(command), - Err(_) => DriverEvent::CommandsClosed, - }); - } - } - - Poll::Pending - }) - .await -} - -impl Runtime

{ - pub async fn run(self) { - let Runtime { - identity, - platform, - config, - rx, - tx, - } = self; - - let runtime_tx = tx.upgrade().expect("runtime tx"); - let mut fsm = QlFsm::new(config.fsm, identity, now()); - let mut peer_xid = None; - if let Some(peer) = platform.load_peer().await { - peer_xid = Some(peer.xid); - fsm.bind_peer(peer); - } - - let mut state = DriverState { - streams: HashMap::new(), - runtime_tx, - stream_send_buffer_bytes: config.stream_send_buffer_bytes, - max_concurrent_message_writes: config.max_concurrent_message_writes, - peer_xid, - pending_fsm_events: VecDeque::new(), - }; - let mut in_flight = Vec::new(); - - loop { - state.finish_step(&mut fsm, &platform, &mut in_flight); - - if rx.is_closed() && in_flight.is_empty() { - break; - } - - match next_driver_event(&rx, &platform, fsm.next_deadline(), &mut in_flight).await { - DriverEvent::Command(command) => { - state.drive_command(&mut fsm, command, &platform, &mut in_flight) - } - DriverEvent::WriteCompleted { index, result } => { - let write = in_flight.swap_remove(index); - state.drive_write_completed( - &mut fsm, - write.session_write_id, - result, - &platform, - &mut in_flight, - ); - } - DriverEvent::TimerExpired => { - state.with_fsm_events(&mut fsm, &platform, |fsm, emit| { - fsm.on_timer(now(), emit) - }); - state.finish_step(&mut fsm, &platform, &mut in_flight); - } - DriverEvent::CommandsClosed => {} - } - } - } -} - fn now() -> FsmTime { FsmTime { instant: Instant::now(), @@ -672,202 +565,3 @@ fn with_noop_context(f: impl FnOnce(&mut Context<'_>) -> T) -> T { let mut cx = Context::from_waker(Waker::noop()); f(&mut cx) } - -#[cfg(test)] -mod tests { - use ql_wire::{ - MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PeerBundle, QlAead, QlHash, - QlKem, QlRandom, SessionKey, StreamClose, XID, - }; - - use super::*; - use crate::tests::new_identity; - - struct NoopPlatform; - - impl QlRandom for NoopPlatform { - fn fill_random_bytes(&self, data: &mut [u8]) { - data.fill(0); - } - } - - impl QlHash for NoopPlatform { - fn sha256(&self, _parts: &[&[u8]]) -> [u8; 32] { - [0; 32] - } - } - - impl QlAead for NoopPlatform { - fn aes256_gcm_encrypt( - &self, - _key: &SessionKey, - _nonce: &ql_wire::Nonce, - _aad: &[u8], - _buffer: &mut [u8], - ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { - [0; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] - } - - fn aes256_gcm_decrypt( - &self, - _key: &SessionKey, - _nonce: &ql_wire::Nonce, - _aad: &[u8], - _buffer: &mut [u8], - _auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool { - false - } - } - - impl QlKem for NoopPlatform { - fn mlkem_generate_keypair(&self) -> MlKemKeyPair { - MlKemKeyPair { - private: MlKemPrivateKey::new(Box::new([0; MlKemPrivateKey::SIZE])), - public: MlKemPublicKey::new(Box::new([0; MlKemPublicKey::SIZE])), - } - } - - fn mlkem_encapsulate(&self, _public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { - ( - MlKemCiphertext::new(Box::new([0; MlKemCiphertext::SIZE])), - SessionKey::from_data([0; SessionKey::SIZE]), - ) - } - - fn mlkem_decapsulate( - &self, - _private_key: &MlKemPrivateKey, - _ciphertext: &MlKemCiphertext, - ) -> SessionKey { - SessionKey::from_data([0; SessionKey::SIZE]) - } - } - - impl QlPlatform for NoopPlatform { - fn write_message(&self, _message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - Box::pin(async { Ok(()) }) - } - - fn sleep(&self, _duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(async {}) - } - - fn load_peer(&self) -> PlatformFuture<'_, Option> { - Box::pin(async { None }) - } - - fn persist_peer(&self, _peer: PeerBundle) {} - - fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} - - fn handle_inbound(&self, _event: InboundStream) {} - } - - fn new_driver_state() -> (DriverState, QlFsm) { - let (runtime_tx, _runtime_rx) = async_channel::unbounded(); - ( - DriverState { - streams: HashMap::new(), - runtime_tx, - stream_send_buffer_bytes: 16, - max_concurrent_message_writes: 1, - peer_xid: None, - pending_fsm_events: VecDeque::new(), - }, - QlFsm::new(ql_fsm::QlFsmConfig::default(), new_identity(7), now()), - ) - } - - #[test] - fn handle_inbound_finished_reaps_closed_initiator_stream() { - let (mut state, _fsm) = new_driver_state(); - let stream_id = StreamId(1); - let (response_tx, _response_rx) = async_channel::unbounded(); - - state.streams.insert( - stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::Closed, - response: InboundIo::new(response_tx), - }, - ); - - state.handle_inbound_finished(stream_id); - - assert!(!state.streams.contains_key(&stream_id)); - } - - #[test] - fn handle_closed_stream_reaps_when_both_halves_close() { - let (mut state, _fsm) = new_driver_state(); - let stream_id = StreamId(2); - let (request_tx, _request_rx) = async_channel::unbounded(); - let (response_reader, _response_writer) = piper::pipe(1); - - state.streams.insert( - stream_id, - DriverStreamIo::Responder { - request: InboundIo::new(request_tx), - response: OutboundIo::new(response_reader), - }, - ); - - state.handle_closed_stream(StreamClose { - stream_id, - target: CloseTarget::Both, - code: StreamCloseCode(0), - }); - - assert!(!state.streams.contains_key(&stream_id)); - } - - #[test] - fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { - let (mut state, mut fsm) = new_driver_state(); - let stream_id = StreamId(3); - let (request_reader, request_writer) = piper::pipe(1); - - drop(request_writer); - state.streams.insert( - stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::new(request_reader), - response: InboundIo::Closed, - }, - ); - - state.poll_stream(&mut fsm, stream_id); - - assert!(!state.streams.contains_key(&stream_id)); - } - - #[test] - fn local_close_command_reaps_when_other_half_is_already_closed() { - let (mut state, mut fsm) = new_driver_state(); - let stream_id = StreamId(4); - let (request_reader, _request_writer) = piper::pipe(1); - let mut in_flight = Vec::new(); - - state.streams.insert( - stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::new(request_reader), - response: InboundIo::Closed, - }, - ); - - state.drive_command( - &mut fsm, - RuntimeCommand::CloseStream { - stream_id, - target: CloseTarget::Origin, - code: StreamCloseCode(0), - }, - &NoopPlatform, - &mut in_flight, - ); - - assert!(!state.streams.contains_key(&stream_id)); - } -} diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs new file mode 100644 index 00000000..8b941da1 --- /dev/null +++ b/ql-runtime/src/driver/state.rs @@ -0,0 +1,202 @@ +use std::collections::{HashMap, VecDeque}; + +use ql_fsm::QlFsmEvent; +use ql_wire::{CloseTarget, StreamId, XID}; + +use crate::{command::RuntimeCommand, QlError}; + +pub struct DriverState { + pub streams: HashMap, + pub runtime_tx: async_channel::Sender, + pub stream_send_buffer_bytes: usize, + pub max_concurrent_message_writes: usize, + pub peer_xid: Option, + pub pending_fsm_events: VecDeque, +} + +pub enum DriverStreamIo { + Initiator { + request: OutboundIo, + response: InboundIo, + }, + Responder { + request: InboundIo, + response: OutboundIo, + }, +} + +impl DriverStreamIo { + pub fn new_initiator( + request: piper::Reader, + response: piper::Writer, + response_terminal: oneshot::Sender>, + ) -> Self { + Self::Initiator { + request: OutboundIo::new(request), + response: InboundIo::new(response, response_terminal), + } + } + + pub fn new_responder( + request: piper::Writer, + request_terminal: oneshot::Sender>, + response: piper::Reader, + ) -> Self { + Self::Responder { + request: InboundIo::new(request, request_terminal), + response: OutboundIo::new(response), + } + } + + pub fn outbound_mut(&mut self) -> &mut OutboundIo { + match self { + Self::Initiator { request, .. } => request, + Self::Responder { response, .. } => response, + } + } + + pub fn inbound_mut(&mut self) -> &mut InboundIo { + match self { + Self::Initiator { response, .. } => response, + Self::Responder { request, .. } => request, + } + } + + pub fn inbound_target(&self) -> CloseTarget { + match self { + Self::Initiator { .. } => CloseTarget::Return, + Self::Responder { .. } => CloseTarget::Origin, + } + } + + pub fn outbound_target(&self) -> CloseTarget { + match self { + Self::Initiator { .. } => CloseTarget::Origin, + Self::Responder { .. } => CloseTarget::Return, + } + } + + pub fn fail_all(&mut self, error: QlError) { + match self { + Self::Initiator { + request, response, .. + } => { + request.close(); + response.fail(error); + } + Self::Responder { + request, response, .. + } => { + request.fail(error); + response.close(); + } + } + } +} + +pub enum OutboundIo { + Open { + reader: piper::Reader, + finish_queued: bool, + }, + Closed, +} + +impl OutboundIo { + pub fn new(reader: piper::Reader) -> Self { + Self::Open { + reader, + finish_queued: false, + } + } + + pub fn close(&mut self) { + *self = Self::Closed; + } + + pub fn open_mut(&mut self) -> Option<(&mut piper::Reader, &mut bool)> { + match self { + Self::Open { + reader, + finish_queued, + } => Some((reader, finish_queued)), + Self::Closed => None, + } + } +} + +pub enum InboundIo { + Open { + writer: piper::Writer, + terminal: Option>>, + finish_pending: bool, + }, + Closed, +} + +pub enum InboundWriteResult { + Accepted(usize), + Full, + Closed, +} + +impl InboundIo { + pub fn new(writer: piper::Writer, terminal: oneshot::Sender>) -> Self { + Self::Open { + writer, + terminal: Some(terminal), + finish_pending: false, + } + } + + pub fn close(&mut self) { + *self = Self::Closed; + } + + pub fn try_write(&mut self, bytes: &[u8]) -> InboundWriteResult { + let Self::Open { writer, .. } = self else { + return InboundWriteResult::Closed; + }; + + let accepted = writer.try_fill(bytes); + if accepted > 0 { + return InboundWriteResult::Accepted(accepted); + } + if writer.is_closed() { + *self = Self::Closed; + return InboundWriteResult::Closed; + } + InboundWriteResult::Full + } + + pub fn finish(&mut self) { + if let Self::Open { terminal, .. } = self { + if let Some(terminal) = terminal.take() { + let _ = terminal.send(Ok(())); + } + } + *self = Self::Closed; + } + + pub fn fail(&mut self, error: QlError) { + if let Self::Open { terminal, .. } = self { + if let Some(terminal) = terminal.take() { + let _ = terminal.send(Err(error)); + } + } + *self = Self::Closed; + } + + pub fn queue_finish(&mut self) { + if let Self::Open { finish_pending, .. } = self { + *finish_pending = true; + } + } + + pub fn finish_pending(&self) -> bool { + match self { + Self::Open { finish_pending, .. } => *finish_pending, + Self::Closed => false, + } + } +} diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs new file mode 100644 index 00000000..bd8188f5 --- /dev/null +++ b/ql-runtime/src/driver/test.rs @@ -0,0 +1,199 @@ +use ql_wire::{ + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PeerBundle, QlAead, QlHash, + QlKem, QlRandom, SessionKey, StreamClose, XID, +}; + +use super::*; +use crate::tests::new_identity; + +struct NoopPlatform; + +impl QlRandom for NoopPlatform { + fn fill_random_bytes(&self, data: &mut [u8]) { + data.fill(0); + } +} + +impl QlHash for NoopPlatform { + fn sha256(&self, _parts: &[&[u8]]) -> [u8; 32] { + [0; 32] + } +} + +impl QlAead for NoopPlatform { + fn aes256_gcm_encrypt( + &self, + _key: &SessionKey, + _nonce: &ql_wire::Nonce, + _aad: &[u8], + _buffer: &mut [u8], + ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { + [0; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] + } + + fn aes256_gcm_decrypt( + &self, + _key: &SessionKey, + _nonce: &ql_wire::Nonce, + _aad: &[u8], + _buffer: &mut [u8], + _auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + false + } +} + +impl QlKem for NoopPlatform { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new([0; MlKemPrivateKey::SIZE])), + public: MlKemPublicKey::new(Box::new([0; MlKemPublicKey::SIZE])), + } + } + + fn mlkem_encapsulate(&self, _public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + ( + MlKemCiphertext::new(Box::new([0; MlKemCiphertext::SIZE])), + SessionKey::from_data([0; SessionKey::SIZE]), + ) + } + + fn mlkem_decapsulate( + &self, + _private_key: &MlKemPrivateKey, + _ciphertext: &MlKemCiphertext, + ) -> SessionKey { + SessionKey::from_data([0; SessionKey::SIZE]) + } +} + +impl QlPlatform for NoopPlatform { + fn write_message(&self, _message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + Box::pin(async { Ok(()) }) + } + + fn sleep(&self, _duration: Duration) -> PlatformFuture<'_, ()> { + Box::pin(async {}) + } + + fn load_peer(&self) -> PlatformFuture<'_, Option> { + Box::pin(async { None }) + } + + fn persist_peer(&self, _peer: PeerBundle) {} + + fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} + + fn handle_inbound(&self, _event: InboundStream) {} +} + +fn new_driver_state() -> (DriverState, QlFsm) { + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + ( + DriverState { + streams: HashMap::new(), + runtime_tx, + stream_send_buffer_bytes: 16, + max_concurrent_message_writes: 1, + peer_xid: None, + pending_fsm_events: VecDeque::new(), + }, + QlFsm::new(ql_fsm::QlFsmConfig::default(), new_identity(7), now()), + ) +} + +fn new_inbound_io(capacity: usize) -> InboundIo { + let (_reader, writer) = piper::pipe(capacity); + let (terminal_tx, _terminal_rx) = oneshot::channel(); + InboundIo::new(writer, terminal_tx) +} + +#[test] +fn handle_inbound_finished_reaps_closed_initiator_stream() { + let (mut state, mut fsm) = new_driver_state(); + let stream_id = StreamId(1); + + state.streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::Closed, + response: new_inbound_io(1), + }, + ); + + state.handle_inbound_finished(&mut fsm, stream_id); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn handle_closed_stream_reaps_when_both_halves_close() { + let (mut state, _fsm) = new_driver_state(); + let stream_id = StreamId(2); + let (response_reader, _response_writer) = piper::pipe(1); + + state.streams.insert( + stream_id, + DriverStreamIo::Responder { + request: new_inbound_io(1), + response: OutboundIo::new(response_reader), + }, + ); + + state.handle_closed_stream(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(0), + }); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { + let (mut state, mut fsm) = new_driver_state(); + let stream_id = StreamId(3); + let (request_reader, request_writer) = piper::pipe(1); + + drop(request_writer); + state.streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::new(request_reader), + response: InboundIo::Closed, + }, + ); + + state.poll_stream(&mut fsm, stream_id); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn local_close_command_reaps_when_other_half_is_already_closed() { + let (mut state, mut fsm) = new_driver_state(); + let stream_id = StreamId(4); + let (request_reader, _request_writer) = piper::pipe(1); + let mut in_flight = Vec::new(); + + state.streams.insert( + stream_id, + DriverStreamIo::Initiator { + request: OutboundIo::new(request_reader), + response: InboundIo::Closed, + }, + ); + + state.drive_command( + &mut fsm, + RuntimeCommand::CloseStream { + stream_id, + target: CloseTarget::Origin, + code: StreamCloseCode(0), + }, + &NoopPlatform, + &mut in_flight, + ); + + assert!(!state.streams.contains_key(&stream_id)); +} diff --git a/ql-runtime/src/handle.rs b/ql-runtime/src/handle.rs deleted file mode 100644 index 5b14e373..00000000 --- a/ql-runtime/src/handle.rs +++ /dev/null @@ -1,288 +0,0 @@ -use std::{pin::Pin, task::Poll}; - -use async_channel::{Receiver, Sender}; -use futures_lite::{future::poll_fn, Stream}; - -use crate::{ - command::RuntimeCommand, CloseTarget, InboundEvent, OpenedStreamDelivery, PeerBundle, QlError, - StreamCloseCode, StreamId, -}; - -#[derive(Clone)] -pub struct RuntimeHandle { - pub(crate) tx: Sender, - pub(crate) stream_send_buffer_bytes: usize, -} - -#[derive(Debug)] -pub struct OutboundStream { - pub stream_id: StreamId, - pub request: ByteWriter, - pub response: ByteReader, -} - -#[derive(Debug)] -pub struct InboundStream { - pub stream_id: StreamId, - pub request: ByteReader, - pub response: ByteWriter, -} - -pub struct ByteReader { - stream_id: StreamId, - target: CloseTarget, - rx: Receiver, - tx: Sender, - finished: bool, -} - -impl std::fmt::Debug for ByteReader { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InboundByteStream") - .field("stream_id", &self.stream_id) - .field("target", &self.target) - .field("finished", &self.finished) - .finish_non_exhaustive() - } -} - -pub struct ByteWriter { - stream_id: StreamId, - target: CloseTarget, - writer: Option, - tx: Sender, -} - -impl std::fmt::Debug for ByteWriter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("OutboundByteStream") - .field("stream_id", &self.stream_id) - .field("target", &self.target) - .field("closed", &self.writer.is_none()) - .finish_non_exhaustive() - } -} - -impl ByteReader { - pub(crate) fn new( - stream_id: StreamId, - target: CloseTarget, - rx: Receiver, - tx: Sender, - ) -> Self { - Self { - stream_id, - target, - rx, - tx, - finished: false, - } - } - - pub async fn next_chunk(&mut self) -> Result>, QlError> { - poll_fn(|cx| self.poll_next_chunk(cx)).await - } - - pub fn poll_next_chunk( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>, QlError>> { - if self.finished { - return Poll::Ready(Ok(None)); - } - - // `async_channel::Receiver` implements `Stream` and stores its listener state - // internally, so poll it directly rather than recreating a `recv()` future. - // SAFETY: `self.rx` is pinned for the duration of this call and is not moved - // before `poll_next` returns. - let mut rx = unsafe { Pin::new_unchecked(&mut self.rx) }; - match Stream::poll_next(rx.as_mut(), cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Some(InboundEvent::Data(bytes))) => Poll::Ready(Ok(Some(bytes))), - Poll::Ready(Some(InboundEvent::Finished)) => { - self.finished = true; - Poll::Ready(Ok(None)) - } - Poll::Ready(Some(InboundEvent::Failed(error))) => { - self.finished = true; - Poll::Ready(Err(error)) - } - Poll::Ready(None) => { - self.finished = true; - Poll::Ready(Err(QlError::Cancelled)) - } - } - } - - pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { - if self.finished { - return Ok(()); - } - self.finished = true; - self.tx - .send(RuntimeCommand::CloseStream { - stream_id: self.stream_id, - target: self.target, - code, - }) - .await - .map_err(|_| QlError::Cancelled) - } -} - -impl Drop for ByteReader { - fn drop(&mut self) { - if self.finished { - return; - } - let _ = self.tx.try_send(RuntimeCommand::CloseStream { - stream_id: self.stream_id, - target: self.target, - code: StreamCloseCode(0), - }); - } -} - -impl ByteWriter { - pub(crate) fn new( - stream_id: StreamId, - target: CloseTarget, - writer: piper::Writer, - tx: Sender, - ) -> Self { - Self { - stream_id, - target, - writer: Some(writer), - tx, - } - } - - fn poll_runtime(&self) -> Result<(), QlError> { - self.tx - .try_send(RuntimeCommand::PollStream { - stream_id: self.stream_id, - }) - .map_err(|_| QlError::Cancelled) - } - - pub async fn write(&mut self, bytes: &[u8]) -> Result { - if bytes.is_empty() { - return Ok(0); - } - self.poll_runtime()?; - let writer = self.writer.as_mut().expect("stream not finished or closed"); - let written = poll_fn(|cx| writer.poll_fill_bytes(cx, bytes)).await; - if written == 0 { - self.writer.take(); - return Err(QlError::Cancelled); - } - self.poll_runtime()?; - Ok(written) - } - - pub async fn write_all(&mut self, mut bytes: &[u8]) -> Result<(), QlError> { - while !bytes.is_empty() { - let written = self.write(bytes).await?; - if written == 0 { - return Err(QlError::Cancelled); - } - bytes = &bytes[written..]; - } - Ok(()) - } - - pub async fn finish(mut self) -> Result<(), QlError> { - if self.writer.take().is_none() { - return Ok(()); - } - self.poll_runtime() - } - - pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { - if self.writer.take().is_none() { - return Ok(()); - } - self.tx - .send(RuntimeCommand::CloseStream { - stream_id: self.stream_id, - target: self.target, - code, - }) - .await - .map_err(|_| QlError::Cancelled) - } -} - -impl Drop for ByteWriter { - fn drop(&mut self) { - if self.writer.take().is_none() { - return; - } - let _ = self.tx.try_send(RuntimeCommand::CloseStream { - stream_id: self.stream_id, - target: self.target, - code: StreamCloseCode(0), - }); - } -} - -impl RuntimeHandle { - pub fn bind_peer(&self, peer: PeerBundle) { - self.send(RuntimeCommand::BindPeer { peer }) - } - - pub fn connect(&self) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Connect) - .map_err(|_| QlError::Cancelled) - } - - pub fn send_incoming(&self, bytes: Vec) { - self.send(RuntimeCommand::Incoming(bytes)) - } - - pub async fn open_stream(&self) -> Result { - let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); - let (start_tx, start_rx) = oneshot::channel(); - - self.tx - .send(RuntimeCommand::OpenStream { - request_reader, - start: start_tx, - }) - .await - .map_err(|_| QlError::Cancelled)?; - - let OpenedStreamDelivery { - stream_id, - response, - } = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; - - Ok(OutboundStream { - stream_id, - request: ByteWriter::new( - stream_id, - CloseTarget::Origin, - request_writer, - self.tx.clone(), - ), - response: ByteReader::new(stream_id, CloseTarget::Return, response, self.tx.clone()), - }) - } - - #[cfg(feature = "rpc")] - pub fn rpc(&self) -> crate::rpc::RpcHandle { - crate::rpc::RpcHandle { - inner: self.clone(), - } - } -} - -impl RuntimeHandle { - #[inline] - #[track_caller] - fn send(&self, cmd: RuntimeCommand) { - self.tx.send_blocking(cmd).expect("runtime is alive") - } -} diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs new file mode 100644 index 00000000..2f847453 --- /dev/null +++ b/ql-runtime/src/handle/mod.rs @@ -0,0 +1,87 @@ +mod reader; +mod writer; + +use ql_wire::{CloseTarget, PeerBundle, StreamId}; + +pub use self::{reader::*, writer::*}; +use crate::{command::RuntimeCommand, OpenedStreamDelivery, QlError}; + +#[derive(Debug)] +pub struct OutboundStream { + pub stream_id: StreamId, + pub request: ByteWriter, + pub response: ByteReader, +} + +#[derive(Debug)] +pub struct InboundStream { + pub stream_id: StreamId, + pub request: ByteReader, + pub response: ByteWriter, +} + +#[derive(Clone)] +pub struct RuntimeHandle { + pub(crate) tx: async_channel::Sender, + pub(crate) stream_send_buffer_bytes: usize, +} + +impl RuntimeHandle { + pub fn bind_peer(&self, peer: PeerBundle) { + self.send(RuntimeCommand::BindPeer { peer }) + } + + pub fn connect(&self) -> Result<(), QlError> { + self.tx + .send_blocking(RuntimeCommand::Connect) + .map_err(|_| QlError::Cancelled) + } + + pub fn send_incoming(&self, bytes: Vec) { + self.send(RuntimeCommand::Incoming(bytes)) + } + + pub async fn open_stream(&self) -> Result { + let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); + let (start_tx, start_rx) = oneshot::channel(); + + self.tx + .send(RuntimeCommand::OpenStream { + request_reader, + start: start_tx, + }) + .await + .map_err(|_| QlError::Cancelled)?; + + let OpenedStreamDelivery { + stream_id, + response, + } = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; + + Ok(OutboundStream { + stream_id, + request: ByteWriter::new( + stream_id, + CloseTarget::Origin, + request_writer, + self.tx.clone(), + ), + response, + }) + } + + #[cfg(feature = "rpc")] + pub fn rpc(&self) -> crate::rpc::RpcHandle { + crate::rpc::RpcHandle { + inner: self.clone(), + } + } +} + +impl RuntimeHandle { + #[inline] + #[track_caller] + fn send(&self, cmd: RuntimeCommand) { + self.tx.send_blocking(cmd).expect("runtime is alive") + } +} diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs new file mode 100644 index 00000000..988e94b8 --- /dev/null +++ b/ql-runtime/src/handle/reader.rs @@ -0,0 +1,127 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; + +use crate::{command::RuntimeCommand, QlError}; + +pub struct ByteReader { + stream_id: StreamId, + target: CloseTarget, + reader: piper::Reader, + terminal: TerminalState, + tx: async_channel::Sender, +} + +enum TerminalState { + Armed(oneshot::Receiver>), + Terminal(Result<(), QlError>), + Delivered, +} + +impl std::fmt::Debug for ByteReader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InboundByteStream") + .field("stream_id", &self.stream_id) + .field("target", &self.target) + .field( + "terminal", + &matches!(self.terminal, TerminalState::Delivered), + ) + .finish_non_exhaustive() + } +} + +impl ByteReader { + pub(crate) fn new( + stream_id: StreamId, + target: CloseTarget, + reader: piper::Reader, + terminal: oneshot::Receiver>, + tx: async_channel::Sender, + ) -> Self { + Self { + stream_id, + target, + reader, + terminal: TerminalState::Armed(terminal), + tx, + } + } + + pub fn poll_fill_buf(&mut self, cx: &mut Context<'_>) -> Poll, QlError>> { + if matches!(self.terminal, TerminalState::Delivered) { + return Poll::Ready(Ok(None)); + } + + if let Poll::Ready(true) = self.reader.poll(cx) { + return Poll::Ready(Ok(Some(self.reader.peek_buf()))); + } + + if let TerminalState::Armed(terminal) = &mut self.terminal { + let result = match Pin::new(terminal).poll(cx) { + Poll::Pending => None, + Poll::Ready(Ok(result)) => Some(result), + Poll::Ready(Err(_)) => Some(Err(QlError::Cancelled)), + }; + if let Some(result) = result { + self.terminal = TerminalState::Terminal(result); + } + } + + match &self.terminal { + TerminalState::Armed(_) => Poll::Pending, + TerminalState::Terminal(Ok(())) => { + self.terminal = TerminalState::Delivered; + Poll::Ready(Ok(None)) + } + TerminalState::Terminal(Err(error)) => { + let error = error.clone(); + self.terminal = TerminalState::Delivered; + Poll::Ready(Err(error)) + } + TerminalState::Delivered => Poll::Ready(Ok(None)), + } + } + + pub fn consume(&mut self, amt: usize) { + if amt == 0 { + return; + } + self.reader.consume(amt); + let _ = self.tx.try_send(RuntimeCommand::PollInbound { + stream_id: self.stream_id, + }); + } + + pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { + if matches!(self.terminal, TerminalState::Delivered) { + return Ok(()); + } + self.terminal = TerminalState::Delivered; + self.tx + .send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code, + }) + .await + .map_err(|_| QlError::Cancelled) + } +} + +impl Drop for ByteReader { + fn drop(&mut self) { + if matches!(self.terminal, TerminalState::Delivered) { + return; + } + let _ = self.tx.try_send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code: StreamCloseCode(0), + }); + } +} diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs new file mode 100644 index 00000000..82e1786b --- /dev/null +++ b/ql-runtime/src/handle/writer.rs @@ -0,0 +1,105 @@ +use futures_lite::future::poll_fn; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; + +use crate::{command::RuntimeCommand, QlError}; + +pub struct ByteWriter { + stream_id: StreamId, + target: CloseTarget, + writer: Option, + tx: async_channel::Sender, +} + +impl std::fmt::Debug for ByteWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OutboundByteStream") + .field("stream_id", &self.stream_id) + .field("target", &self.target) + .field("closed", &self.writer.is_none()) + .finish_non_exhaustive() + } +} + +impl ByteWriter { + pub(crate) fn new( + stream_id: StreamId, + target: CloseTarget, + writer: piper::Writer, + tx: async_channel::Sender, + ) -> Self { + Self { + stream_id, + target, + writer: Some(writer), + tx, + } + } + + fn poll_runtime(&self) -> Result<(), QlError> { + self.tx + .try_send(RuntimeCommand::PollStream { + stream_id: self.stream_id, + }) + .map_err(|_| QlError::Cancelled) + } + + pub async fn write(&mut self, bytes: &[u8]) -> Result { + if bytes.is_empty() { + return Ok(0); + } + self.poll_runtime()?; + let writer = self.writer.as_mut().expect("stream not finished or closed"); + let written = poll_fn(|cx| writer.poll_fill_bytes(cx, bytes)).await; + if written == 0 { + self.writer.take(); + return Err(QlError::Cancelled); + } + self.poll_runtime()?; + Ok(written) + } + + pub async fn write_all(&mut self, mut bytes: &[u8]) -> Result<(), QlError> { + while !bytes.is_empty() { + let written = self.write(bytes).await?; + if written == 0 { + return Err(QlError::Cancelled); + } + bytes = &bytes[written..]; + } + Ok(()) + } + + pub async fn finish(mut self) -> Result<(), QlError> { + if self.writer.take().is_none() { + return Ok(()); + } + self.poll_runtime() + } + + pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { + if self.writer.take().is_none() { + return Ok(()); + } + self.tx + .send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code, + }) + .await + .map_err(|_| QlError::Cancelled) + } +} + +impl Drop for ByteWriter { + fn drop(&mut self) { + if self.writer.take().is_none() { + return; + } + let _ = self.tx.try_send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code: StreamCloseCode(0), + }); + } +} diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index dcbffef6..2fe7137b 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -97,16 +97,9 @@ impl RuntimeConfig { } } -#[derive(Debug)] -pub(crate) enum InboundEvent { - Data(Vec), - Finished, - Failed(crate::QlError), -} - pub(crate) struct OpenedStreamDelivery { pub stream_id: StreamId, - pub response: async_channel::Receiver, + pub response: crate::ByteReader, } pub struct Runtime

{ diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 8b5e5f16..3d53830c 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -116,7 +116,7 @@ async fn rejected_session_write_is_reissued() { let mut stream = handle_a.open_stream().await.unwrap(); stream.request.write_all(b"retry").await.unwrap(); stream.request.finish().await.unwrap(); - assert_eq!(stream.response.next_chunk().await.unwrap(), None); + assert_eq!(next_chunk(&mut stream.response).await.unwrap(), None); assert_eq!( tokio::time::timeout(Duration::from_secs(2), responder) diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 12231080..9859f207 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -55,10 +55,12 @@ async fn session_timeout_disconnects_and_fails_pending_open() { await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; - let result = - tokio::time::timeout(Duration::from_millis(300), pending.response.next_chunk()) - .await - .unwrap(); + let result = tokio::time::timeout( + Duration::from_millis(300), + next_chunk(&mut pending.response), + ) + .await + .unwrap(); assert!(matches!( result, Err(QlError::SessionClosed) | Err(QlError::Cancelled) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 47a56c32..d45159c2 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -5,10 +5,12 @@ use std::{ atomic::{AtomicU8, AtomicUsize, Ordering}, Arc, }, + task::Poll, time::Duration, }; use async_channel::{Receiver, Sender}; +use futures_lite::future::poll_fn; use libcrux_aesgcm::AesGcm256Key; use ql_wire::{ generate_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, @@ -494,12 +496,26 @@ async fn assert_no_status_for( async fn read_all(mut stream: crate::ByteReader) -> Result, QlError> { let mut data = Vec::new(); - while let Some(chunk) = stream.next_chunk().await? { + while let Some(chunk) = next_chunk(&mut stream).await? { data.extend_from_slice(&chunk); } Ok(data) } +async fn next_chunk(stream: &mut crate::ByteReader) -> Result>, QlError> { + poll_fn(|cx| match stream.poll_fill_buf(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(Some(buf))) => { + let (bytes, len) = (buf.to_vec(), buf.len()); + stream.consume(len); + Poll::Ready(Ok(Some(bytes))) + } + Poll::Ready(Ok(None)) => Poll::Ready(Ok(None)), + Poll::Ready(Err(error)) => Poll::Ready(Err(error)), + }) + .await +} + fn default_runtime_config() -> RuntimeConfig { RuntimeConfig { fsm: QlFsmConfig { diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 78c22fa2..17cae955 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -33,24 +33,27 @@ async fn open_stream_duplex_happy_path() { let mut request = inbound.request; let mut response = inbound.response; - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![1, 2])); + assert_eq!(next_chunk(&mut request).await.unwrap(), Some(vec![1, 2])); response.write_all(&[9]).await.unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), Some(vec![3, 4])); + assert_eq!(next_chunk(&mut request).await.unwrap(), Some(vec![3, 4])); response.write_all(&[8, 7]).await.unwrap(); - assert_eq!(request.next_chunk().await.unwrap(), None); + assert_eq!(next_chunk(&mut request).await.unwrap(), None); response.finish().await.unwrap(); }); let mut stream = handle_a.open_stream().await.unwrap(); stream.request.write_all(&[1, 2]).await.unwrap(); - assert_eq!(stream.response.next_chunk().await.unwrap(), Some(vec![9])); + assert_eq!( + next_chunk(&mut stream.response).await.unwrap(), + Some(vec![9]) + ); stream.request.write_all(&[3, 4]).await.unwrap(); stream.request.finish().await.unwrap(); assert_eq!( - stream.response.next_chunk().await.unwrap(), + next_chunk(&mut stream.response).await.unwrap(), Some(vec![8, 7]) ); - assert_eq!(stream.response.next_chunk().await.unwrap(), None); + assert_eq!(next_chunk(&mut stream.response).await.unwrap(), None); tokio::time::timeout(Duration::from_secs(2), responder) .await @@ -100,7 +103,7 @@ async fn stream_backpressure_with_small_runtime_buffer() { let mut stream = handle_a.open_stream().await.unwrap(); stream.request.write_all(&payload).await.unwrap(); stream.request.finish().await.unwrap(); - assert_eq!(stream.response.next_chunk().await.unwrap(), None); + assert_eq!(next_chunk(&mut stream.response).await.unwrap(), None); let received = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) .await @@ -148,7 +151,7 @@ async fn dropping_responder_closes_initiator_response() { let mut stream = handle_a.open_stream().await.unwrap(); stream.request.finish().await.unwrap(); - let err = stream.response.next_chunk().await.unwrap_err(); + let err = next_chunk(&mut stream.response).await.unwrap_err(); assert!(matches!( err, QlError::StreamClosed { @@ -197,7 +200,7 @@ async fn dropping_inbound_reader_cancels_remote_writer() { let stream = inbound_b.recv().await.unwrap(); let mut request = stream.request; let mut response = stream.response; - assert_eq!(request.next_chunk().await.unwrap(), None); + assert_eq!(next_chunk(&mut request).await.unwrap(), None); response.write_all(&[1, 2, 3, 4]).await.unwrap(); go_rx.recv().await.unwrap(); let err = response.write_all(&[5; 64]).await.unwrap_err(); @@ -207,7 +210,7 @@ async fn dropping_inbound_reader_cancels_remote_writer() { let mut stream = handle_a.open_stream().await.unwrap(); stream.request.finish().await.unwrap(); assert_eq!( - stream.response.next_chunk().await.unwrap(), + next_chunk(&mut stream.response).await.unwrap(), Some(vec![1, 2, 3, 4]) ); drop(stream.response); @@ -265,7 +268,7 @@ async fn max_concurrent_message_writes_is_respected() { let mut stream = handle.open_stream().await.unwrap(); stream.request.write_all(&[i; 8]).await.unwrap(); stream.request.finish().await.unwrap(); - assert_eq!(stream.response.next_chunk().await.unwrap(), None); + assert_eq!(next_chunk(&mut stream.response).await.unwrap(), None); })); } @@ -339,7 +342,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { stream.request.finish().await.unwrap(); let mut received_response = Vec::new(); - while let Some(chunk) = stream.response.next_chunk().await.unwrap() { + while let Some(chunk) = next_chunk(&mut stream.response).await.unwrap() { received_response.extend_from_slice(&chunk); } assert_eq!(received_response, expected_response); From 091ad4d037ae831e2c87caccc9be7a8b0b983d70 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 11:26:10 -0400 Subject: [PATCH 106/304] ql-runtime: single QlStream struct --- ql-runtime/src/driver/mod.rs | 10 ++-- ql-runtime/src/driver/test.rs | 2 +- ql-runtime/src/handle/mod.rs | 27 ++++------ ql-runtime/src/lib.rs | 4 +- ql-runtime/src/platform.rs | 4 +- ql-runtime/src/tests/handshake.rs | 10 ++-- ql-runtime/src/tests/heartbeat.rs | 17 +++--- ql-runtime/src/tests/mod.rs | 10 ++-- ql-runtime/src/tests/stream.rs | 89 +++++++++++++++---------------- 9 files changed, 79 insertions(+), 94 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index da71e1b2..0291591d 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -16,7 +16,7 @@ use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use self::state::*; use crate::{ command::RuntimeCommand, - handle::{ByteReader, ByteWriter, InboundStream}, + handle::{ByteReader, ByteWriter, QlStream}, platform::{PlatformFuture, QlPlatform}, OpenedStreamDelivery, QlError, Runtime, }; @@ -180,7 +180,7 @@ impl DriverState { ); let _ = start.send(Ok(OpenedStreamDelivery { stream_id, - response: ByteReader::new( + reader: ByteReader::new( stream_id, CloseTarget::Return, response_reader, @@ -323,16 +323,16 @@ impl DriverState { DriverStreamIo::new_responder(request_writer, request_terminal_tx, response_reader), ); - platform.handle_inbound(InboundStream { + platform.handle_inbound(QlStream { stream_id, - request: ByteReader::new( + reader: ByteReader::new( stream_id, CloseTarget::Origin, request_reader, request_terminal_rx, self.runtime_tx.clone(), ), - response: ByteWriter::new( + writer: ByteWriter::new( stream_id, CloseTarget::Return, response_writer, diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index bd8188f5..9e587f1c 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -84,7 +84,7 @@ impl QlPlatform for NoopPlatform { fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} - fn handle_inbound(&self, _event: InboundStream) {} + fn handle_inbound(&self, _event: QlStream) {} } fn new_driver_state() -> (DriverState, QlFsm) { diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 2f847453..eaa08b29 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -7,17 +7,10 @@ pub use self::{reader::*, writer::*}; use crate::{command::RuntimeCommand, OpenedStreamDelivery, QlError}; #[derive(Debug)] -pub struct OutboundStream { +pub struct QlStream { pub stream_id: StreamId, - pub request: ByteWriter, - pub response: ByteReader, -} - -#[derive(Debug)] -pub struct InboundStream { - pub stream_id: StreamId, - pub request: ByteReader, - pub response: ByteWriter, + pub writer: ByteWriter, + pub reader: ByteReader, } #[derive(Clone)] @@ -41,7 +34,7 @@ impl RuntimeHandle { self.send(RuntimeCommand::Incoming(bytes)) } - pub async fn open_stream(&self) -> Result { + pub async fn open_stream(&self) -> Result { let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); let (start_tx, start_rx) = oneshot::channel(); @@ -53,20 +46,18 @@ impl RuntimeHandle { .await .map_err(|_| QlError::Cancelled)?; - let OpenedStreamDelivery { - stream_id, - response, - } = start_rx.await.unwrap_or(Err(QlError::Cancelled))?; + let OpenedStreamDelivery { stream_id, reader } = + start_rx.await.unwrap_or(Err(QlError::Cancelled))?; - Ok(OutboundStream { + Ok(QlStream { stream_id, - request: ByteWriter::new( + writer: ByteWriter::new( stream_id, CloseTarget::Origin, request_writer, self.tx.clone(), ), - response, + reader, }) } diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 2fe7137b..2444c0ce 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,4 +1,4 @@ -pub use handle::{ByteReader, ByteWriter, InboundStream, OutboundStream, RuntimeHandle}; +pub use handle::{ByteReader, ByteWriter, QlStream, RuntimeHandle}; pub use ql_fsm::{PeerStatus, QlFsmConfig, QlFsmError, SessionWriteId}; pub use ql_wire::{ self as wire, CloseTarget, PeerBundle, QlIdentity, SessionCloseCode, StreamCloseCode, StreamId, @@ -99,7 +99,7 @@ impl RuntimeConfig { pub(crate) struct OpenedStreamDelivery { pub stream_id: StreamId, - pub response: crate::ByteReader, + pub reader: crate::ByteReader, } pub struct Runtime

{ diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 7bdeaebb..d232bcb1 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -2,7 +2,7 @@ use std::{future::Future, pin::Pin, time::Duration}; use ql_wire::QlCrypto; -use crate::{PeerBundle, PeerStatus, QlError, XID}; +use crate::{PeerBundle, PeerStatus, QlError, QlStream, XID}; pub type PlatformFuture<'a, T> = Pin + 'a>>; @@ -14,5 +14,5 @@ pub trait QlPlatform: QlCrypto { fn persist_peer(&self, peer: PeerBundle); fn handle_peer_status(&self, peer: XID, status: PeerStatus); - fn handle_inbound(&self, event: super::InboundStream); + fn handle_inbound(&self, event: QlStream); } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 3d53830c..69859eea 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -108,15 +108,15 @@ async fn rejected_session_write_is_reissued() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); - let request = read_all(stream.request).await.unwrap(); - stream.response.finish().await.unwrap(); + let request = read_all(stream.reader).await.unwrap(); + stream.writer.finish().await.unwrap(); request }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.request.write_all(b"retry").await.unwrap(); - stream.request.finish().await.unwrap(); - assert_eq!(next_chunk(&mut stream.response).await.unwrap(), None); + stream.writer.write_all(b"retry").await.unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); assert_eq!( tokio::time::timeout(Duration::from_secs(2), responder) diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 9859f207..1aa6e058 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -43,24 +43,21 @@ async fn session_timeout_disconnects_and_fails_pending_open() { let responder_task = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); - let _ = read_all(stream.request).await; - let response = stream.response; - let _ = response.finish().await; + let _ = read_all(stream.reader).await; + let _ = stream.writer.finish().await; }); drop_flag.store(true, Ordering::Relaxed); let mut pending = handle_a.open_stream().await.unwrap(); - pending.request.finish().await.unwrap(); + pending.writer.finish().await.unwrap(); await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; - let result = tokio::time::timeout( - Duration::from_millis(300), - next_chunk(&mut pending.response), - ) - .await - .unwrap(); + let result = + tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) + .await + .unwrap(); assert!(matches!( result, Err(QlError::SessionClosed) | Err(QlError::Cancelled) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index d45159c2..5bfa5af9 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -21,7 +21,7 @@ use sha2::{Digest, Sha256}; use tokio::task::LocalSet; use crate::{ - new_runtime, platform::PlatformFuture, InboundStream, PeerStatus, QlError, QlFsmConfig, + new_runtime, platform::PlatformFuture, PeerStatus, QlError, QlFsmConfig, QlStream, RuntimeConfig, RuntimeHandle, }; @@ -164,7 +164,7 @@ impl QlKem for DeterministicCrypto { struct TestPlatform { outbound: Sender>, status: Sender, - inbound: Option>, + inbound: Option>, nonce_seed: u8, nonce_counter: AtomicU8, encrypted_write_counter: AtomicUsize, @@ -184,7 +184,7 @@ impl TestPlatform { Self, Receiver>, Receiver, - Receiver, + Receiver, ) { let (inbound_tx, inbound_rx) = async_channel::unbounded(); let (platform, outbound_rx, status_rx) = @@ -215,7 +215,7 @@ impl TestPlatform { fn new_inner( seed: u8, - inbound: Option>, + inbound: Option>, fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, @@ -384,7 +384,7 @@ impl crate::platform::QlPlatform for TestPlatform { let _ = self.status.try_send(StatusEvent { peer, stage }); } - fn handle_inbound(&self, event: InboundStream) { + fn handle_inbound(&self, event: QlStream) { if let Some(tx) = &self.inbound { let _ = tx.try_send(event); } diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 17cae955..2d091454 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -30,30 +30,27 @@ async fn open_stream_duplex_happy_path() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let mut request = inbound.request; - let mut response = inbound.response; - - assert_eq!(next_chunk(&mut request).await.unwrap(), Some(vec![1, 2])); - response.write_all(&[9]).await.unwrap(); - assert_eq!(next_chunk(&mut request).await.unwrap(), Some(vec![3, 4])); - response.write_all(&[8, 7]).await.unwrap(); - assert_eq!(next_chunk(&mut request).await.unwrap(), None); - response.finish().await.unwrap(); + let mut writer = inbound.writer; + let mut reader = inbound.reader; + + assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![1, 2])); + writer.write_all(&[9]).await.unwrap(); + assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![3, 4])); + writer.write_all(&[8, 7]).await.unwrap(); + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + writer.finish().await.unwrap(); }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.request.write_all(&[1, 2]).await.unwrap(); + stream.writer.write_all(&[1, 2]).await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), Some(vec![9])); + stream.writer.write_all(&[3, 4]).await.unwrap(); + stream.writer.finish().await.unwrap(); assert_eq!( - next_chunk(&mut stream.response).await.unwrap(), - Some(vec![9]) - ); - stream.request.write_all(&[3, 4]).await.unwrap(); - stream.request.finish().await.unwrap(); - assert_eq!( - next_chunk(&mut stream.response).await.unwrap(), + next_chunk(&mut stream.reader).await.unwrap(), Some(vec![8, 7]) ); - assert_eq!(next_chunk(&mut stream.response).await.unwrap(), None); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); tokio::time::timeout(Duration::from_secs(2), responder) .await @@ -95,15 +92,15 @@ async fn stream_backpressure_with_small_runtime_buffer() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); - let request_data = read_all(stream.request).await.unwrap(); - stream.response.finish().await.unwrap(); + let request_data = read_all(stream.reader).await.unwrap(); + stream.writer.finish().await.unwrap(); done_tx.send(request_data).await.unwrap(); }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.request.write_all(&payload).await.unwrap(); - stream.request.finish().await.unwrap(); - assert_eq!(next_chunk(&mut stream.response).await.unwrap(), None); + stream.writer.write_all(&payload).await.unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); let received = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) .await @@ -145,13 +142,13 @@ async fn dropping_responder_closes_initiator_response() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); - drop(stream.response); + drop(stream.reader); }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.request.finish().await.unwrap(); + stream.writer.finish().await.unwrap(); - let err = next_chunk(&mut stream.response).await.unwrap_err(); + let err = next_chunk(&mut stream.reader).await.unwrap_err(); assert!(matches!( err, QlError::StreamClosed { @@ -198,22 +195,22 @@ async fn dropping_inbound_reader_cancels_remote_writer() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); - let mut request = stream.request; - let mut response = stream.response; - assert_eq!(next_chunk(&mut request).await.unwrap(), None); - response.write_all(&[1, 2, 3, 4]).await.unwrap(); + let mut writer = stream.writer; + let mut reader = stream.reader; + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + writer.write_all(&[1, 2, 3, 4]).await.unwrap(); go_rx.recv().await.unwrap(); - let err = response.write_all(&[5; 64]).await.unwrap_err(); + let err = writer.write_all(&[5; 64]).await.unwrap_err(); assert!(matches!(err, QlError::Cancelled)); }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.request.finish().await.unwrap(); + stream.writer.finish().await.unwrap(); assert_eq!( - next_chunk(&mut stream.response).await.unwrap(), + next_chunk(&mut stream.reader).await.unwrap(), Some(vec![1, 2, 3, 4]) ); - drop(stream.response); + drop(stream.reader); go_tx.send(()).await.unwrap(); tokio::time::timeout(Duration::from_secs(2), responder) @@ -256,8 +253,8 @@ async fn max_concurrent_message_writes_is_respected() { let responder = tokio::task::spawn_local(async move { for _ in 0..4 { let stream = inbound_b.recv().await.unwrap(); - let _ = read_all(stream.request).await; - let _ = stream.response.finish().await; + let _ = read_all(stream.reader).await; + let _ = stream.writer.finish().await; } }); @@ -266,9 +263,9 @@ async fn max_concurrent_message_writes_is_respected() { let handle = handle_a.clone(); tasks.push(tokio::task::spawn_local(async move { let mut stream = handle.open_stream().await.unwrap(); - stream.request.write_all(&[i; 8]).await.unwrap(); - stream.request.finish().await.unwrap(); - assert_eq!(next_chunk(&mut stream.response).await.unwrap(), None); + stream.writer.write_all(&[i; 8]).await.unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); })); } @@ -330,19 +327,19 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); - let received_request = read_all(stream.request).await.unwrap(); - let mut response = stream.response; - response.write_all(&response_payload).await.unwrap(); - response.finish().await.unwrap(); + let received_request = read_all(stream.reader).await.unwrap(); + let mut writer = stream.writer; + writer.write_all(&response_payload).await.unwrap(); + writer.finish().await.unwrap(); received_request }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.request.write_all(&request_payload).await.unwrap(); - stream.request.finish().await.unwrap(); + stream.writer.write_all(&request_payload).await.unwrap(); + stream.writer.finish().await.unwrap(); let mut received_response = Vec::new(); - while let Some(chunk) = next_chunk(&mut stream.response).await.unwrap() { + while let Some(chunk) = next_chunk(&mut stream.reader).await.unwrap() { received_response.extend_from_slice(&chunk); } assert_eq!(received_response, expected_response); From 024936c1a1e5c6b2f87a97cc1ae667ff94073abb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 11:54:41 -0400 Subject: [PATCH 107/304] ql-rpc: update --- ql-rpc/src/error.rs | 8 +- ql-runtime/src/driver/mod.rs | 79 ++++++++--------- ql-runtime/src/driver/state.rs | 6 +- ql-runtime/src/driver/test.rs | 6 +- ql-runtime/src/handle/mod.rs | 6 +- ql-runtime/src/handle/reader.rs | 2 +- ql-runtime/src/handle/writer.rs | 2 +- ql-runtime/src/rpc/mod.rs | 98 ++++++--------------- ql-runtime/src/rpc/request_with_progress.rs | 17 ++-- ql-runtime/src/rpc/subscription.rs | 11 ++- ql-runtime/src/tests/heartbeat.rs | 2 +- ql-runtime/src/tests/mod.rs | 25 +++++- ql-runtime/src/tests/rpc.rs | 25 +++--- 13 files changed, 135 insertions(+), 152 deletions(-) diff --git a/ql-rpc/src/error.rs b/ql-rpc/src/error.rs index 95b4fff7..65fdcfac 100644 --- a/ql-rpc/src/error.rs +++ b/ql-rpc/src/error.rs @@ -1,4 +1,4 @@ -use ql_wire::CloseCode; +use ql_wire::StreamCloseCode; use crate::MethodId; @@ -24,15 +24,15 @@ pub enum RpcError { } impl RpcError { - pub const fn close_code(self) -> CloseCode { + pub const fn close_code(self) -> StreamCloseCode { match self { - Self::UnexpectedMethod { .. } => CloseCode::UNKNOWN_ROUTE, + Self::UnexpectedMethod { .. } => StreamCloseCode(404), Self::Truncated | Self::LengthOverflow | Self::InvalidVersion(_) | Self::UnexpectedFrameKind(_) | Self::MissingResponse - | Self::TrailingBytes => CloseCode::INVALID_HEAD, + | Self::TrailingBytes => StreamCloseCode(400), } } } diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 0291591d..474e8004 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -13,7 +13,7 @@ use futures_lite::future::poll_fn; use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, SessionWriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use self::state::*; +use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; use crate::{ command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, @@ -22,8 +22,9 @@ use crate::{ }; impl Runtime

{ + #[allow(clippy::future_not_send)] pub async fn run(self) { - let Runtime { + let Self { identity, platform, config, @@ -58,21 +59,21 @@ impl Runtime

{ match next_driver_event(&rx, &platform, fsm.next_deadline(), &mut in_flight).await { DriverEvent::Command(command) => { - state.drive_command(&mut fsm, command, &platform, &mut in_flight) + state.drive_command(&mut fsm, command, &platform, &mut in_flight); } - DriverEvent::WriteCompleted { index, result } => { + DriverEvent::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); state.drive_write_completed( &mut fsm, write.session_write_id, - result, + success, &platform, &mut in_flight, ); } DriverEvent::TimerExpired => { state.with_fsm_events(&mut fsm, &platform, |fsm, emit| { - fsm.on_timer(now(), emit) + fsm.on_timer(now(), emit); }); state.finish_step(&mut fsm, &platform, &mut in_flight); } @@ -89,14 +90,12 @@ struct InFlightWrite<'a> { enum DriverEvent { Command(RuntimeCommand), - WriteCompleted { - index: usize, - result: Result<(), QlError>, - }, + WriteCompleted { index: usize, success: bool }, TimerExpired, CommandsClosed, } +#[allow(clippy::future_not_send)] async fn next_driver_event( rx: &async_channel::Receiver, platform: &P, @@ -112,22 +111,24 @@ async fn next_driver_event( poll_fn(|cx| { for (index, write) in in_flight.iter_mut().enumerate() { if let Poll::Ready(result) = write.future.as_mut().poll(cx) { - return Poll::Ready(DriverEvent::WriteCompleted { index, result }); + return Poll::Ready(DriverEvent::WriteCompleted { + index, + success: result.is_ok(), + }); } } if let Some(future) = sleep_future.as_mut() { - if let Poll::Ready(()) = future.as_mut().poll(cx) { + if future.as_mut().poll(cx) == Poll::Ready(()) { return Poll::Ready(DriverEvent::TimerExpired); } } if let Some(future) = recv_future.as_mut() { if let Poll::Ready(res) = future.as_mut().poll(cx) { - return Poll::Ready(match res { - Ok(command) => DriverEvent::Command(command), - Err(_) => DriverEvent::CommandsClosed, - }); + return Poll::Ready( + res.map_or_else(|_| DriverEvent::CommandsClosed, DriverEvent::Command), + ); } } @@ -224,24 +225,25 @@ impl DriverState { } fn drive_write_completed<'a, P: QlPlatform>( - &mut self, + &self, fsm: &mut QlFsm, session_write_id: Option, - result: Result<(), QlError>, + success: bool, platform: &'a P, in_flight: &mut Vec>, ) { if let Some(write_id) = session_write_id { - match result { - Ok(()) => fsm.confirm_session_write(now(), write_id), - Err(_) => fsm.reject_session_write(write_id), + if success { + fsm.confirm_session_write(now(), write_id); + } else { + fsm.reject_session_write(write_id); } } self.finish_step(fsm, platform, in_flight); } fn finish_step<'a, P: QlPlatform>( - &mut self, + &self, fsm: &mut QlFsm, platform: &'a P, in_flight: &mut Vec>, @@ -304,12 +306,12 @@ impl DriverState { self.handle_inbound_finished(fsm, stream_id); } QlFsmEvent::Closed(frame) => { - self.handle_closed_stream(frame); + self.handle_closed_stream(&frame); } QlFsmEvent::WritableClosed(stream_id) => { self.handle_writable_closed(stream_id); } - QlFsmEvent::SessionClosed(_) => self.fail_all_streams(QlError::SessionClosed), + QlFsmEvent::SessionClosed(_) => self.fail_all_streams(&QlError::SessionClosed), } } @@ -398,7 +400,7 @@ impl DriverState { self.finish_inbound_if_ready(fsm, stream_id); } - fn handle_inbound_finished(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + fn handle_inbound_finished(&mut self, fsm: &QlFsm, stream_id: StreamId) { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; @@ -406,7 +408,7 @@ impl DriverState { self.finish_inbound_if_ready(fsm, stream_id); } - fn finish_inbound_if_ready(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + fn finish_inbound_if_ready(&mut self, fsm: &QlFsm, stream_id: StreamId) { if fsm.stream_available_bytes(stream_id).unwrap_or(0) != 0 { return; } @@ -422,18 +424,16 @@ impl DriverState { self.try_reap_stream(stream_id); } - fn handle_closed_stream(&mut self, frame: ql_wire::StreamClose) { + fn handle_closed_stream(&mut self, frame: &ql_wire::StreamClose) { let Some(stream) = self.streams.get_mut(&frame.stream_id) else { return; }; - let error = QlError::StreamClosed { - target: frame.target, - code: frame.code, - }; - if frame.target == CloseTarget::Both || frame.target == stream.inbound_target() { - stream.inbound_mut().fail(error.clone()); + stream.inbound_mut().fail(QlError::StreamClosed { + target: frame.target, + code: frame.code, + }); } if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { stream.outbound_mut().close(); @@ -449,15 +449,15 @@ impl DriverState { self.try_reap_stream(stream_id); } - fn fail_all_streams(&mut self, error: QlError) { + fn fail_all_streams(&mut self, error: &QlError) { for stream in self.streams.values_mut() { - stream.fail_all(error.clone()); + stream.fail_all(error); } self.streams.clear(); } fn fill_write_slots<'a, P: QlPlatform>( - &mut self, + &self, fsm: &mut QlFsm, platform: &'a P, in_flight: &mut Vec>, @@ -495,17 +495,14 @@ impl DriverState { } else { let bytes = reader.peek_buf(); if bytes.is_empty() { - if reader.is_closed() && reader.len() == 0 && !*finish_queued { + if reader.is_closed() && reader.is_empty() && !*finish_queued { *finish_queued = true; should_finish = true; } false } else { let len = bytes.len(); - let accepted = match fsm.write_stream(stream_id, bytes) { - Ok(accepted) => accepted, - Err(_) => 0, - }; + let accepted = fsm.write_stream(stream_id, bytes).unwrap_or_default(); if accepted > 0 { reader.consume(accepted); } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 8b941da1..9dd47b4d 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -76,18 +76,18 @@ impl DriverStreamIo { } } - pub fn fail_all(&mut self, error: QlError) { + pub fn fail_all(&mut self, error: &QlError) { match self { Self::Initiator { request, response, .. } => { request.close(); - response.fail(error); + response.fail(error.clone()); } Self::Responder { request, response, .. } => { - request.fail(error); + request.fail(error.clone()); response.close(); } } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 9e587f1c..09867f3e 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -110,7 +110,7 @@ fn new_inbound_io(capacity: usize) -> InboundIo { #[test] fn handle_inbound_finished_reaps_closed_initiator_stream() { - let (mut state, mut fsm) = new_driver_state(); + let (mut state, fsm) = new_driver_state(); let stream_id = StreamId(1); state.streams.insert( @@ -121,7 +121,7 @@ fn handle_inbound_finished_reaps_closed_initiator_stream() { }, ); - state.handle_inbound_finished(&mut fsm, stream_id); + state.handle_inbound_finished(&fsm, stream_id); assert!(!state.streams.contains_key(&stream_id)); } @@ -140,7 +140,7 @@ fn handle_closed_stream_reaps_when_both_halves_close() { }, ); - state.handle_closed_stream(StreamClose { + state.handle_closed_stream(&StreamClose { stream_id, target: CloseTarget::Both, code: StreamCloseCode(0), diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index eaa08b29..5dea054d 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -21,7 +21,7 @@ pub struct RuntimeHandle { impl RuntimeHandle { pub fn bind_peer(&self, peer: PeerBundle) { - self.send(RuntimeCommand::BindPeer { peer }) + self.send(RuntimeCommand::BindPeer { peer }); } pub fn connect(&self) -> Result<(), QlError> { @@ -31,7 +31,7 @@ impl RuntimeHandle { } pub fn send_incoming(&self, bytes: Vec) { - self.send(RuntimeCommand::Incoming(bytes)) + self.send(RuntimeCommand::Incoming(bytes)); } pub async fn open_stream(&self) -> Result { @@ -73,6 +73,6 @@ impl RuntimeHandle { #[inline] #[track_caller] fn send(&self, cmd: RuntimeCommand) { - self.tx.send_blocking(cmd).expect("runtime is alive") + self.tx.send_blocking(cmd).expect("runtime is alive"); } } diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 988e94b8..a552d061 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -57,7 +57,7 @@ impl ByteReader { return Poll::Ready(Ok(None)); } - if let Poll::Ready(true) = self.reader.poll(cx) { + if self.reader.poll(cx) == Poll::Ready(true) { return Poll::Ready(Ok(Some(self.reader.peek_buf()))); } diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 82e1786b..01b7c174 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -73,7 +73,7 @@ impl ByteWriter { if self.writer.take().is_none() { return Ok(()); } - self.poll_runtime() + std::future::ready(self.poll_runtime()).await } pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index d7d72b56..cd8ad89d 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -1,10 +1,10 @@ -use std::task::{Context, Poll}; - mod error; mod request_with_progress; mod subscription; -pub use error::*; +use std::task::Poll; + +use futures_lite::future::poll_fn; use ql_rpc::{ notification::{self, Notification}, request::{self, Request as RequestRpc}, @@ -12,21 +12,15 @@ use ql_rpc::{ subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, RpcError, }; -pub use request_with_progress::*; -pub use subscription::*; -use crate::{ByteReader, OutboundStream, QlError, RuntimeHandle}; +pub use self::{error::*, request_with_progress::*, subscription::*}; +use crate::{ByteReader, QlError, RuntimeHandle}; #[derive(Clone)] pub struct RpcHandle { pub(crate) inner: RuntimeHandle, } -pub(super) enum ChunkState { - Open(ByteReader), - Closed, -} - impl RpcHandle { pub async fn event(&self, event: &M::Event) -> Result<(), RpcCallError> where @@ -34,12 +28,8 @@ impl RpcHandle { { let mut payload = Vec::new(); notification::encode_event::(event, &mut payload).map_err(RpcCallError::Codec)?; - - let response = self - .start_request(payload) - .await - .map_err(RpcCallError::Runtime)?; - let response = read_all(response).await.map_err(RpcCallError::Runtime)?; + let response = self.start_request(payload).await?; + let response = read_all(response).await?; if response.is_empty() { Ok(()) } else { @@ -56,11 +46,8 @@ impl RpcHandle { { let mut payload = Vec::new(); request::encode_request::(request, &mut payload).map_err(RpcCallError::Codec)?; - let response = self - .start_request(payload) - .await - .map_err(RpcCallError::Runtime)?; - let response = read_all(response).await.map_err(RpcCallError::Runtime)?; + let response = self.start_request(payload).await?; + let response = read_all(response).await?; request::decode_response::(&response).map_err(RpcCallError::Codec) } @@ -74,13 +61,9 @@ impl RpcHandle { let mut payload = Vec::new(); rpc_subscription::encode_request::(request, &mut payload) .map_err(RpcCallError::Codec)?; - - let response = self - .start_request(payload) - .await - .map_err(RpcCallError::Runtime)?; + let response = self.start_request(payload).await?; Ok(Subscription { - chunks: ChunkState::new(response), + stream: response, reader: Some(rpc_subscription::ResponseReader::new()), }) } @@ -95,59 +78,36 @@ impl RpcHandle { let mut payload = Vec::new(); rpc_request_with_progress::encode_request::(request, &mut payload) .map_err(RpcCallError::Codec)?; - - let response = self - .start_request(payload) - .await - .map_err(RpcCallError::Runtime)?; + let response = self.start_request(payload).await?; Ok(ProgressCall { - chunks: ChunkState::new(response), + stream: response, reader: Some(rpc_request_with_progress::ResponseReader::new()), terminal: None, }) } async fn start_request(&self, payload: Vec) -> Result { - let OutboundStream { - mut request, - response, - .. - } = self.inner.open_stream().await?; - - request.write_all(&payload).await?; - request.finish().await?; - Ok(response) - } -} - -impl ChunkState { - fn new(reader: ByteReader) -> Self { - Self::Open(reader) - } - - fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>, QlError>> { - match self { - Self::Open(reader) => match reader.poll_next_chunk(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(Some(bytes))) => Poll::Ready(Ok(Some(bytes))), - Poll::Ready(Ok(None)) => { - *self = Self::Closed; - Poll::Ready(Ok(None)) - } - Poll::Ready(Err(error)) => { - *self = Self::Closed; - Poll::Ready(Err(error)) - } - }, - Self::Closed => Poll::Ready(Ok(None)), - } + let mut stream = self.inner.open_stream().await?; + stream.writer.write_all(&payload).await?; + stream.writer.finish().await?; + Ok(stream.reader) } } async fn read_all(mut reader: ByteReader) -> Result, QlError> { let mut bytes = Vec::new(); - while let Some(chunk) = reader.next_chunk().await? { - bytes.extend_from_slice(&chunk); + while let Some(len) = poll_fn(|cx| match reader.poll_fill_buf(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(Some(chunk))) => { + bytes.extend_from_slice(chunk); + Poll::Ready(Ok(Some(chunk.len()))) + } + Poll::Ready(Ok(None)) => Poll::Ready(Ok(None)), + Poll::Ready(Err(error)) => Poll::Ready(Err(error)), + }) + .await? + { + reader.consume(len); } Ok(bytes) } diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs index 89021929..6b0066d9 100644 --- a/ql-runtime/src/rpc/request_with_progress.rs +++ b/ql-runtime/src/rpc/request_with_progress.rs @@ -10,10 +10,11 @@ use ql_rpc::{ RpcError, }; -use super::{ChunkState, RpcCallError}; +use super::RpcCallError; +use crate::ByteReader; pub struct ProgressCall { - pub(super) chunks: ChunkState, + pub(super) stream: ByteReader, pub(super) reader: Option>, pub(super) terminal: Option>>, } @@ -62,10 +63,12 @@ where } } - match this.chunks.poll_next(cx) { + match this.stream.poll_fill_buf(cx) { Poll::Ready(Ok(Some(chunk))) => { + let len = chunk.len(); let reader = this.reader.take().expect("progress reader is present"); - this.reader = Some(reader.push(&chunk)); + this.reader = Some(reader.push(chunk)); + this.stream.consume(len); } Poll::Ready(Ok(None)) => { this.reader = None; @@ -114,10 +117,12 @@ where Err(error) => return Poll::Ready(Err(error.into())), } - match this.chunks.poll_next(cx) { + match this.stream.poll_fill_buf(cx) { Poll::Ready(Ok(Some(chunk))) => { + let len = chunk.len(); let reader = this.reader.take().expect("progress reader is present"); - this.reader = Some(reader.push(&chunk)); + this.reader = Some(reader.push(chunk)); + this.stream.consume(len); } Poll::Ready(Ok(None)) => { this.reader = None; diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 2b2e51fa..831a2b92 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -9,10 +9,11 @@ use ql_rpc::{ RpcError, }; -use super::{ChunkState, RpcCallError}; +use super::RpcCallError; +use crate::ByteReader; pub struct Subscription { - pub(super) chunks: ChunkState, + pub(super) stream: ByteReader, pub(super) reader: Option>, } @@ -53,10 +54,12 @@ where Err(error) => return Poll::Ready(Some(Err(error.into()))), } - match this.chunks.poll_next(cx) { + match this.stream.poll_fill_buf(cx) { Poll::Ready(Ok(Some(chunk))) => { + let len = chunk.len(); let reader = this.reader.take().expect("subscription reader is present"); - this.reader = Some(reader.push(&chunk)); + this.reader = Some(reader.push(chunk)); + this.stream.consume(len); } Poll::Ready(Ok(None)) => { this.reader = None; diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 1aa6e058..23642858 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -60,7 +60,7 @@ async fn session_timeout_disconnects_and_fails_pending_open() { .unwrap(); assert!(matches!( result, - Err(QlError::SessionClosed) | Err(QlError::Cancelled) + Err(QlError::SessionClosed | QlError::Cancelled) )); responder_task.abort(); diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 5bfa5af9..02d0bb96 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -342,11 +342,12 @@ impl crate::platform::QlPlatform for TestPlatform { tokio::time::sleep(write_delay).await; } - let mut should_fail = false; - if is_encrypted_payload(&message) { + let should_fail = if is_encrypted_payload(&message) { let count = self.encrypted_write_counter.fetch_add(1, Ordering::Relaxed) + 1; - should_fail = fail_encrypted_write_at == Some(count); - } + fail_encrypted_write_at == Some(count) + } else { + false + }; let result = if should_fail { Err(QlError::SendFailed) @@ -454,6 +455,7 @@ fn spawn_gated_forwarder( }); } +#[allow(clippy::future_not_send)] async fn run_local_test(future: F) where F: Future, @@ -528,3 +530,18 @@ fn default_runtime_config() -> RuntimeConfig { ..Default::default() } } + +// runtime is send, though the Runtime::run future itself is not +#[test] +fn runtime_is_send() { + let config = default_runtime_config(); + let identity_a = new_identity(11); + let (platform_a, _, _) = TestPlatform::new(1); + let (runtime_a, _handle) = new_runtime(identity_a.clone(), platform_a, config); + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(runtime_a.run()) + }); +} diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 0fd48c1a..3f7be9ff 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -4,6 +4,7 @@ use bytes::Buf; use futures_lite::StreamExt; use super::*; + #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -74,7 +75,7 @@ async fn rpc_request_round_trips() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request = read_all(inbound.request).await.unwrap(); + let request = read_all(inbound.reader).await.unwrap(); let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); assert_eq!( ql_rpc::request::decode_request::(rpc_inbound.body).unwrap(), @@ -84,9 +85,9 @@ async fn rpc_request_round_trips() { let mut encoded = Vec::new(); ql_rpc::request::encode_response::(&BytesValue(b"world".to_vec()), &mut encoded) .unwrap(); - let mut response = inbound.response; - response.write_all(&encoded).await.unwrap(); - response.finish().await.unwrap(); + let mut writer = inbound.writer; + writer.write_all(&encoded).await.unwrap(); + writer.finish().await.unwrap(); }); let rpc = handle_a.rpc(); @@ -130,7 +131,7 @@ async fn rpc_subscription_streams_events() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request = read_all(inbound.request).await.unwrap(); + let request = read_all(inbound.reader).await.unwrap(); let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); assert_eq!( ql_rpc::subscription::decode_request::(rpc_inbound.body).unwrap(), @@ -144,9 +145,9 @@ async fn rpc_subscription_streams_events() { .unwrap(); ql_rpc::subscription::encode_end(&mut encoded); - let mut response = inbound.response; - response.write_all(&encoded).await.unwrap(); - response.finish().await.unwrap(); + let mut writer = inbound.writer; + writer.write_all(&encoded).await.unwrap(); + writer.finish().await.unwrap(); }); let rpc = handle_a.rpc(); @@ -198,7 +199,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request = read_all(inbound.request).await.unwrap(); + let request = read_all(inbound.reader).await.unwrap(); let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); assert_eq!( ql_rpc::request_with_progress::decode_request::(rpc_inbound.body) @@ -223,9 +224,9 @@ async fn rpc_request_with_progress_supports_progress_then_await() { ) .unwrap(); - let mut response = inbound.response; - response.write_all(&encoded).await.unwrap(); - response.finish().await.unwrap(); + let mut writer = inbound.writer; + writer.write_all(&encoded).await.unwrap(); + writer.finish().await.unwrap(); }); let rpc = handle_a.rpc(); From 23c575d2ff3c4305f4c27e6e4a4ae6018a4882f0 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 13:01:54 -0400 Subject: [PATCH 108/304] ql-runtime: remove reference leak --- ql-runtime/src/driver/mod.rs | 100 ++++++++++++++++++++------------- ql-runtime/src/driver/state.rs | 2 +- ql-runtime/src/driver/test.rs | 2 +- ql-runtime/src/tests/mod.rs | 27 ++++++++- 4 files changed, 88 insertions(+), 43 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 474e8004..e851b324 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -32,7 +32,6 @@ impl Runtime

{ tx, } = self; - let runtime_tx = tx.upgrade().expect("runtime tx"); let mut fsm = QlFsm::new(config.fsm, identity, now()); let mut peer_xid = None; if let Some(peer) = platform.load_peer().await { @@ -42,7 +41,7 @@ impl Runtime

{ let mut state = DriverState { streams: HashMap::new(), - runtime_tx, + runtime_tx: tx, stream_send_buffer_bytes: config.stream_send_buffer_bytes, max_concurrent_message_writes: config.max_concurrent_message_writes, peer_xid, @@ -166,36 +165,54 @@ impl DriverState { RuntimeCommand::OpenStream { request_reader, start, - } => match fsm.open_stream().map_err(QlError::from) { - Ok(stream_id) => { - let (response_reader, response_writer) = - piper::pipe(self.stream_send_buffer_bytes); - let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); - self.streams.insert( - stream_id, - DriverStreamIo::new_initiator( - request_reader, - response_writer, - response_terminal_tx, - ), - ); - let _ = start.send(Ok(OpenedStreamDelivery { - stream_id, - reader: ByteReader::new( + } => { + let Some(runtime_tx) = self.runtime_tx.upgrade() else { + let _ = start.send(Err(QlError::Cancelled)); + return; + }; + + match fsm.open_stream().map_err(QlError::from) { + Ok(stream_id) => { + let (response_reader, response_writer) = + piper::pipe(self.stream_send_buffer_bytes); + let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); + self.streams.insert( stream_id, - CloseTarget::Return, - response_reader, - response_terminal_rx, - self.runtime_tx.clone(), - ), - })); - self.poll_stream(fsm, stream_id); - self.finish_step(fsm, platform, in_flight); - } - Err(error) => { - let _ = start.send(Err(error)); + DriverStreamIo::new_initiator( + request_reader, + response_writer, + response_terminal_tx, + ), + ); + if start + .send(Ok(OpenedStreamDelivery { + stream_id, + reader: ByteReader::new( + stream_id, + CloseTarget::Return, + response_reader, + response_terminal_rx, + runtime_tx, + ), + })) + .is_err() + { + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.inbound_mut().close(); + stream.outbound_mut().close(); + } + let _ = + fsm.close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)); + return; + } + self.poll_stream(fsm, stream_id); + self.finish_step(fsm, platform, in_flight); + } + Err(error) => { + let _ = start.send(Err(error)); + } } - }, + } RuntimeCommand::PollInbound { stream_id } => { self.handle_inbound_readable(fsm, stream_id); self.finish_step(fsm, platform, in_flight); @@ -294,7 +311,7 @@ impl DriverState { } } QlFsmEvent::Opened(stream_id) => { - self.handle_opened_stream(platform, stream_id); + self.handle_opened_stream(fsm, platform, stream_id); } QlFsmEvent::Readable(stream_id) => { self.handle_inbound_readable(fsm, stream_id); @@ -315,7 +332,17 @@ impl DriverState { } } - fn handle_opened_stream(&mut self, platform: &P, stream_id: StreamId) { + fn handle_opened_stream( + &mut self, + fsm: &mut QlFsm, + platform: &P, + stream_id: StreamId, + ) { + let Some(runtime_tx) = self.runtime_tx.upgrade() else { + let _ = fsm.close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)); + return; + }; + let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); let (response_reader, response_writer) = piper::pipe(self.stream_send_buffer_bytes); @@ -332,14 +359,9 @@ impl DriverState { CloseTarget::Origin, request_reader, request_terminal_rx, - self.runtime_tx.clone(), - ), - writer: ByteWriter::new( - stream_id, - CloseTarget::Return, - response_writer, - self.runtime_tx.clone(), + runtime_tx.clone(), ), + writer: ByteWriter::new(stream_id, CloseTarget::Return, response_writer, runtime_tx), }); } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 9dd47b4d..3ce0c350 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -7,7 +7,7 @@ use crate::{command::RuntimeCommand, QlError}; pub struct DriverState { pub streams: HashMap, - pub runtime_tx: async_channel::Sender, + pub runtime_tx: async_channel::WeakSender, pub stream_send_buffer_bytes: usize, pub max_concurrent_message_writes: usize, pub peer_xid: Option, diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 09867f3e..d8e45100 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -92,7 +92,7 @@ fn new_driver_state() -> (DriverState, QlFsm) { ( DriverState { streams: HashMap::new(), - runtime_tx, + runtime_tx: runtime_tx.downgrade(), stream_send_buffer_bytes: 16, max_concurrent_message_writes: 1, peer_xid: None, diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 02d0bb96..f7afa2d0 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -537,11 +537,34 @@ fn runtime_is_send() { let config = default_runtime_config(); let identity_a = new_identity(11); let (platform_a, _, _) = TestPlatform::new(1); - let (runtime_a, _handle) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_a, _handle) = new_runtime(identity_a, platform_a, config); std::thread::spawn(move || { tokio::runtime::Builder::new_current_thread() .build() .unwrap() - .block_on(runtime_a.run()) + .block_on(runtime_a.run()); }); } + +#[test] +fn runtime_exits_when_last_handle_drops() { + let config = default_runtime_config(); + let identity = new_identity(11); + let (platform, _, _) = TestPlatform::new(1); + let (runtime, handle) = new_runtime(identity, platform, config); + let (done_tx, done_rx) = oneshot::channel(); + + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(runtime.run()); + done_tx.send(()).unwrap(); + }); + + drop(handle); + + done_rx + .recv_timeout(Duration::from_secs(1)) + .expect("runtime should stop once the last sender is dropped"); +} From b65d223fbecdababcf51e581da30571afd20e1b6 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 13:07:01 -0400 Subject: [PATCH 109/304] ql-runtime: handle constructor --- ql-runtime/src/handle/mod.rs | 14 ++++++++++++-- ql-runtime/src/lib.rs | 5 +---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 5dea054d..2fcf33ef 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -15,8 +15,8 @@ pub struct QlStream { #[derive(Clone)] pub struct RuntimeHandle { - pub(crate) tx: async_channel::Sender, - pub(crate) stream_send_buffer_bytes: usize, + tx: async_channel::Sender, + stream_send_buffer_bytes: usize, } impl RuntimeHandle { @@ -70,6 +70,16 @@ impl RuntimeHandle { } impl RuntimeHandle { + pub(crate) fn new( + tx: async_channel::Sender, + stream_send_buffer_bytes: usize, + ) -> Self { + Self { + tx, + stream_send_buffer_bytes, + } + } + #[inline] #[track_caller] fn send(&self, cmd: RuntimeCommand) { diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 2444c0ce..885b0978 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -128,9 +128,6 @@ where rx, tx: tx.downgrade(), }, - RuntimeHandle { - tx, - stream_send_buffer_bytes: config.stream_send_buffer_bytes, - }, + RuntimeHandle::new(tx, config.stream_send_buffer_bytes), ) } From 5d4bd07abf6368aeb250849e6838f6e8f1870cb6 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 13:10:22 -0400 Subject: [PATCH 110/304] ql-runtime: infallible connect --- ql-runtime/src/handle/mod.rs | 6 ++---- ql-runtime/src/tests/handshake.rs | 6 +++--- ql-runtime/src/tests/heartbeat.rs | 2 +- ql-runtime/src/tests/rpc.rs | 6 +++--- ql-runtime/src/tests/stream.rs | 12 ++++++------ 5 files changed, 15 insertions(+), 17 deletions(-) diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 2fcf33ef..0d23cd52 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -24,10 +24,8 @@ impl RuntimeHandle { self.send(RuntimeCommand::BindPeer { peer }); } - pub fn connect(&self) -> Result<(), QlError> { - self.tx - .send_blocking(RuntimeCommand::Connect) - .map_err(|_| QlError::Cancelled) + pub fn connect(&self) { + self.send(RuntimeCommand::Connect) } pub fn send_incoming(&self, bytes: Vec) { diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 69859eea..cf899853 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -21,7 +21,7 @@ async fn connect_round_trip_changes_peer_status() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; @@ -75,7 +75,7 @@ async fn handshake_timeout_disconnects() { tokio::task::spawn_local(async move { runtime_b.run().await }); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; }) @@ -101,7 +101,7 @@ async fn rejected_session_write_is_reissued() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 23642858..d57c6e8f 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -36,7 +36,7 @@ async fn session_timeout_disconnects_and_fails_pending_open() { spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 3f7be9ff..e5c52fdb 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -68,7 +68,7 @@ async fn rpc_request_round_trips() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; @@ -124,7 +124,7 @@ async fn rpc_subscription_streams_events() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; @@ -192,7 +192,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 2d091454..420832a5 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -22,7 +22,7 @@ async fn open_stream_duplex_happy_path() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; @@ -85,7 +85,7 @@ async fn stream_backpressure_with_small_runtime_buffer() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; @@ -135,7 +135,7 @@ async fn dropping_responder_closes_initiator_response() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; @@ -188,7 +188,7 @@ async fn dropping_inbound_reader_cancels_remote_writer() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; @@ -245,7 +245,7 @@ async fn max_concurrent_message_writes_is_respected() { spawn_forwarder(outbound_b, handle_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; @@ -320,7 +320,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { spawn_drop_every_nth_encrypted_forwarder(outbound_b, handle_a.clone(), 3); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); + handle_a.connect(); await_status(&status_a, identity_b.xid, PeerStage::Connected).await; await_status(&status_b, identity_a.xid, PeerStage::Connected).await; From f016d883d24f1fe595d5d6c50e0d6d9dd3298ad0 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 13:25:08 -0400 Subject: [PATCH 111/304] ql-runtime: lib.rs cleanup --- ql-runtime/src/lib.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 885b0978..b54e0dac 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,9 +1,4 @@ -pub use handle::{ByteReader, ByteWriter, QlStream, RuntimeHandle}; -pub use ql_fsm::{PeerStatus, QlFsmConfig, QlFsmError, SessionWriteId}; -pub use ql_wire::{ - self as wire, CloseTarget, PeerBundle, QlIdentity, SessionCloseCode, StreamCloseCode, StreamId, - XID, -}; +pub use self::{handle::*, platform::*}; pub(crate) mod command; pub(crate) mod driver; @@ -15,10 +10,10 @@ pub mod rpc; #[cfg(test)] mod tests; +use ql_fsm::{QlFsmConfig, QlFsmError}; +use ql_wire::QlIdentity; use thiserror::Error; -use self::platform::QlPlatform; - #[derive(Debug, Clone, PartialEq, Eq, Error)] pub enum QlError { #[error("invalid payload")] @@ -47,8 +42,8 @@ pub enum QlError { SendFailed, #[error("stream closed {code:?}")] StreamClosed { - target: CloseTarget, - code: StreamCloseCode, + target: ql_wire::CloseTarget, + code: ql_wire::StreamCloseCode, }, #[error("cancelled")] Cancelled, @@ -98,8 +93,8 @@ impl RuntimeConfig { } pub(crate) struct OpenedStreamDelivery { - pub stream_id: StreamId, - pub reader: crate::ByteReader, + pub stream_id: ql_wire::StreamId, + pub reader: ByteReader, } pub struct Runtime

{ From 7d4b92d9ae91dd24d5dfae929cf3a1c1b96a9f3c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 13:27:15 -0400 Subject: [PATCH 112/304] ql: remove thiserror dep --- Cargo.lock | 3 -- ql-fsm/Cargo.toml | 1 - ql-fsm/src/error.rs | 35 ++++++++++++++-------- ql-fsm/src/session/mod.rs | 20 +++++++++---- ql-rpc/Cargo.toml | 1 - ql-rpc/src/error.rs | 27 ++++++++++++----- ql-runtime/Cargo.toml | 1 - ql-runtime/src/error.rs | 63 +++++++++++++++++++++++++++++++++++++++ ql-runtime/src/lib.rs | 59 ++---------------------------------- 9 files changed, 122 insertions(+), 88 deletions(-) create mode 100644 ql-runtime/src/error.rs diff --git a/Cargo.lock b/Cargo.lock index 6c47fec1..46804b72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2132,7 +2132,6 @@ dependencies = [ "libcrux-ml-kem", "ql-wire", "sha2", - "thiserror", ] [[package]] @@ -2141,7 +2140,6 @@ version = "0.1.0" dependencies = [ "bytes", "ql-wire", - "thiserror", ] [[package]] @@ -2158,7 +2156,6 @@ dependencies = [ "ql-rpc", "ql-wire", "sha2", - "thiserror", "tokio", ] diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml index 89c68339..1b5319e9 100644 --- a/ql-fsm/Cargo.toml +++ b/ql-fsm/Cargo.toml @@ -8,7 +8,6 @@ license = "Proprietary" [dependencies] indexmap = "2" ql-wire = { path = "../ql-wire" } -thiserror = { version = "2" } [dev-dependencies] libcrux-aesgcm = "0.0.7" diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index c9114e02..0d66677f 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -1,34 +1,43 @@ use ql_wire::WireError; -use thiserror::Error; use crate::session::StreamError; -#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum QlFsmError { - #[error("invalid payload")] InvalidPayload, - #[error("invalid state")] InvalidState, - #[error("expired")] Expired, - #[error("decryption failed")] DecryptFailed, - #[error("invalid xid")] InvalidXid, - #[error("missing stream")] MissingStream, - #[error("stream is not writable")] NotWritable, - #[error("invalid read commit")] InvalidRead, - #[error("session is closed")] SessionClosed, - #[error("no peer bound")] NoPeerBound, - #[error("no active session")] NoSession, } +impl std::fmt::Display for QlFsmError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let message = match self { + Self::InvalidPayload => "invalid payload", + Self::InvalidState => "invalid state", + Self::Expired => "expired", + Self::DecryptFailed => "decryption failed", + Self::InvalidXid => "invalid xid", + Self::MissingStream => "missing stream", + Self::NotWritable => "stream is not writable", + Self::InvalidRead => "invalid read commit", + Self::SessionClosed => "session is closed", + Self::NoPeerBound => "no peer bound", + Self::NoSession => "no active session", + }; + f.write_str(message) + } +} + +impl std::error::Error for QlFsmError {} + impl From for QlFsmError { fn from(value: WireError) -> Self { match value { diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 59cfbb93..b9ad8eab 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -72,18 +72,28 @@ pub enum SessionState { Closed, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamError { - #[error("missing stream")] MissingStream, - #[error("stream is not writable")] NotWritable, - #[error("invalid read commit")] InvalidRead, - #[error("session is closed")] SessionClosed, } +impl std::fmt::Display for StreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let message = match self { + Self::MissingStream => "missing stream", + Self::NotWritable => "stream is not writable", + Self::InvalidRead => "invalid read commit", + Self::SessionClosed => "session is closed", + }; + f.write_str(message) + } +} + +impl std::error::Error for StreamError {} + pub struct SessionFsm { config: SessionFsmConfig, state: SessionFsmState, diff --git a/ql-rpc/Cargo.toml b/ql-rpc/Cargo.toml index a7df20c9..8fe76b8c 100644 --- a/ql-rpc/Cargo.toml +++ b/ql-rpc/Cargo.toml @@ -8,4 +8,3 @@ license = "Proprietary" [dependencies] bytes = { version = "1" } ql-wire = { path = "../ql-wire" } -thiserror = { version = "2" } diff --git a/ql-rpc/src/error.rs b/ql-rpc/src/error.rs index 65fdcfac..8675c493 100644 --- a/ql-rpc/src/error.rs +++ b/ql-rpc/src/error.rs @@ -2,27 +2,38 @@ use ql_wire::StreamCloseCode; use crate::MethodId; -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RpcError { - #[error("truncated rpc payload")] Truncated, - #[error("rpc payload length overflow")] LengthOverflow, - #[error("invalid rpc version {0}")] InvalidVersion(u8), - #[error("unexpected rpc method {actual:?}, expected {expected:?}")] UnexpectedMethod { expected: MethodId, actual: MethodId, }, - #[error("unexpected rpc frame kind {0}")] UnexpectedFrameKind(u8), - #[error("missing terminal rpc response")] MissingResponse, - #[error("trailing rpc bytes")] TrailingBytes, } +impl std::fmt::Display for RpcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Truncated => f.write_str("truncated rpc payload"), + Self::LengthOverflow => f.write_str("rpc payload length overflow"), + Self::InvalidVersion(version) => write!(f, "invalid rpc version {version}"), + Self::UnexpectedMethod { expected, actual } => { + write!(f, "unexpected rpc method {actual:?}, expected {expected:?}") + } + Self::UnexpectedFrameKind(kind) => write!(f, "unexpected rpc frame kind {kind}"), + Self::MissingResponse => f.write_str("missing terminal rpc response"), + Self::TrailingBytes => f.write_str("trailing rpc bytes"), + } + } +} + +impl std::error::Error for RpcError {} + impl RpcError { pub const fn close_code(self) -> StreamCloseCode { match self { diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index be9e018e..a34bacc8 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -17,7 +17,6 @@ piper = { version = "0.2.4" } ql-fsm = { path = "../ql-fsm" } ql-rpc = { path = "../ql-rpc", optional = true } ql-wire = { path = "../ql-wire" } -thiserror = { version = "2" } [dev-dependencies] bytes = "1" diff --git a/ql-runtime/src/error.rs b/ql-runtime/src/error.rs new file mode 100644 index 00000000..b16c7e2e --- /dev/null +++ b/ql-runtime/src/error.rs @@ -0,0 +1,63 @@ +use ql_fsm::QlFsmError; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlError { + InvalidPayload, + InvalidState, + Expired, + DecryptFailed, + InvalidXid, + MissingStream, + NotWritable, + InvalidRead, + SessionClosed, + NoPeerBound, + NoSession, + SendFailed, + StreamClosed { + target: ql_wire::CloseTarget, + code: ql_wire::StreamCloseCode, + }, + Cancelled, +} + +impl std::fmt::Display for QlError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidPayload => f.write_str("invalid payload"), + Self::InvalidState => f.write_str("invalid state"), + Self::Expired => f.write_str("expired"), + Self::DecryptFailed => f.write_str("decryption failed"), + Self::InvalidXid => f.write_str("invalid xid"), + Self::MissingStream => f.write_str("missing stream"), + Self::NotWritable => f.write_str("stream is not writable"), + Self::InvalidRead => f.write_str("invalid read"), + Self::SessionClosed => f.write_str("session is closed"), + Self::NoPeerBound => f.write_str("no peer bound"), + Self::NoSession => f.write_str("no active session"), + Self::SendFailed => f.write_str("send failed"), + Self::StreamClosed { code, .. } => write!(f, "stream closed {code:?}"), + Self::Cancelled => f.write_str("cancelled"), + } + } +} + +impl std::error::Error for QlError {} + +impl From for QlError { + fn from(value: QlFsmError) -> Self { + match value { + QlFsmError::InvalidPayload => Self::InvalidPayload, + QlFsmError::InvalidState => Self::InvalidState, + QlFsmError::Expired => Self::Expired, + QlFsmError::DecryptFailed => Self::DecryptFailed, + QlFsmError::InvalidXid => Self::InvalidXid, + QlFsmError::MissingStream => Self::MissingStream, + QlFsmError::NotWritable => Self::NotWritable, + QlFsmError::InvalidRead => Self::InvalidRead, + QlFsmError::SessionClosed => Self::SessionClosed, + QlFsmError::NoPeerBound => Self::NoPeerBound, + QlFsmError::NoSession => Self::NoSession, + } + } +} diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index b54e0dac..e2cc4bdd 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,7 +1,8 @@ -pub use self::{handle::*, platform::*}; +pub use self::{error::QlError, handle::*, platform::*}; pub(crate) mod command; pub(crate) mod driver; +mod error; pub mod handle; pub mod platform; #[cfg(feature = "rpc")] @@ -10,62 +11,8 @@ pub mod rpc; #[cfg(test)] mod tests; -use ql_fsm::{QlFsmConfig, QlFsmError}; +use ql_fsm::QlFsmConfig; use ql_wire::QlIdentity; -use thiserror::Error; - -#[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum QlError { - #[error("invalid payload")] - InvalidPayload, - #[error("invalid state")] - InvalidState, - #[error("expired")] - Expired, - #[error("decryption failed")] - DecryptFailed, - #[error("invalid xid")] - InvalidXid, - #[error("missing stream")] - MissingStream, - #[error("stream is not writable")] - NotWritable, - #[error("invalid read")] - InvalidRead, - #[error("session is closed")] - SessionClosed, - #[error("no peer bound")] - NoPeerBound, - #[error("no active session")] - NoSession, - #[error("send failed")] - SendFailed, - #[error("stream closed {code:?}")] - StreamClosed { - target: ql_wire::CloseTarget, - code: ql_wire::StreamCloseCode, - }, - #[error("cancelled")] - Cancelled, -} - -impl From for QlError { - fn from(value: QlFsmError) -> Self { - match value { - QlFsmError::InvalidPayload => Self::InvalidPayload, - QlFsmError::InvalidState => Self::InvalidState, - QlFsmError::Expired => Self::Expired, - QlFsmError::DecryptFailed => Self::DecryptFailed, - QlFsmError::InvalidXid => Self::InvalidXid, - QlFsmError::MissingStream => Self::MissingStream, - QlFsmError::NotWritable => Self::NotWritable, - QlFsmError::InvalidRead => Self::InvalidRead, - QlFsmError::SessionClosed => Self::SessionClosed, - QlFsmError::NoPeerBound => Self::NoPeerBound, - QlFsmError::NoSession => Self::NoSession, - } - } -} #[derive(Debug, Clone, Copy)] pub struct RuntimeConfig { From 5d14c2fb68d98d65aa33c7e4bab8da34b4305b44 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 13:33:40 -0400 Subject: [PATCH 113/304] ql-runtime: cleanup --- ql-runtime/src/command.rs | 7 +++---- ql-runtime/src/platform.rs | 5 +++-- ql-runtime/src/tests/handshake.rs | 12 ++++++------ ql-runtime/src/tests/heartbeat.rs | 6 +++--- ql-runtime/src/tests/mod.rs | 31 ++++++++++--------------------- ql-runtime/src/tests/rpc.rs | 12 ++++++------ ql-runtime/src/tests/stream.rs | 27 ++++++++++++++------------- 7 files changed, 45 insertions(+), 55 deletions(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 0c9442dc..c807818f 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,7 +1,6 @@ -use crate::{ - wire::{CloseTarget, StreamCloseCode}, - OpenedStreamDelivery, PeerBundle, QlError, StreamId, -}; +use ql_wire::{CloseTarget, PeerBundle, StreamCloseCode, StreamId}; + +use crate::{OpenedStreamDelivery, QlError}; pub(crate) enum RuntimeCommand { BindPeer { diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index d232bcb1..36ce6e30 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -1,8 +1,9 @@ use std::{future::Future, pin::Pin, time::Duration}; -use ql_wire::QlCrypto; +use ql_fsm::PeerStatus; +use ql_wire::{PeerBundle, QlCrypto, XID}; -use crate::{PeerBundle, PeerStatus, QlError, QlStream, XID}; +use crate::{QlError, QlStream}; pub type PlatformFuture<'a, T> = Pin + 'a>>; diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index cf899853..dca270fb 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -23,8 +23,8 @@ async fn connect_round_trip_changes_peer_status() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; }) .await; } @@ -77,7 +77,7 @@ async fn handshake_timeout_disconnects() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Disconnected).await; }) .await; } @@ -103,8 +103,8 @@ async fn rejected_session_write_is_reissued() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -129,7 +129,7 @@ async fn rejected_session_write_is_reissued() { assert_no_status_for( &status_a, identity_b.xid, - PeerStage::Disconnected, + PeerStatus::Disconnected, Duration::from_millis(150), ) .await; diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index d57c6e8f..a0f01eb2 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -38,8 +38,8 @@ async fn session_timeout_disconnects_and_fails_pending_open() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder_task = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -52,7 +52,7 @@ async fn session_timeout_disconnects_and_fails_pending_open() { let mut pending = handle_a.open_stream().await.unwrap(); pending.writer.finish().await.unwrap(); - await_status(&status_a, identity_b.xid, PeerStage::Disconnected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Disconnected).await; let result = tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index f7afa2d0..8705a0fe 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -12,6 +12,7 @@ use std::{ use async_channel::{Receiver, Sender}; use futures_lite::future::poll_fn; use libcrux_aesgcm::AesGcm256Key; +use ql_fsm::PeerStatus; use ql_wire::{ generate_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, RecordType, SessionKey, @@ -21,8 +22,8 @@ use sha2::{Digest, Sha256}; use tokio::task::LocalSet; use crate::{ - new_runtime, platform::PlatformFuture, PeerStatus, QlError, QlFsmConfig, QlStream, - RuntimeConfig, RuntimeHandle, + new_runtime, platform::PlatformFuture, QlError, QlFsmConfig, QlStream, RuntimeConfig, + RuntimeHandle, }; mod handshake; @@ -31,17 +32,10 @@ mod heartbeat; mod rpc; mod stream; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum PeerStage { - Disconnected, - Initiator, - Connected, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct StatusEvent { peer: XID, - stage: PeerStage, + status: PeerStatus, } #[derive(Debug, Clone)] @@ -377,12 +371,7 @@ impl crate::platform::QlPlatform for TestPlatform { fn persist_peer(&self, _peer: PeerBundle) {} fn handle_peer_status(&self, peer: XID, status: PeerStatus) { - let stage = match status { - PeerStatus::Disconnected => PeerStage::Disconnected, - PeerStatus::Initiator => PeerStage::Initiator, - PeerStatus::Connected => PeerStage::Connected, - }; - let _ = self.status.try_send(StatusEvent { peer, stage }); + let _ = self.status.try_send(StatusEvent { peer, status }); } fn handle_inbound(&self, event: QlStream) { @@ -464,11 +453,11 @@ where local.run_until(future).await; } -async fn await_status(receiver: &Receiver, peer: XID, stage: PeerStage) { +async fn await_status(receiver: &Receiver, peer: XID, stage: PeerStatus) { tokio::time::timeout(Duration::from_secs(2), async { loop { if let Ok(event) = receiver.recv().await { - if event.peer == peer && event.stage == stage { + if event.peer == peer && event.status == stage { return; } } @@ -481,19 +470,19 @@ async fn await_status(receiver: &Receiver, peer: XID, stage: PeerSt async fn assert_no_status_for( receiver: &Receiver, peer: XID, - stage: PeerStage, + status: PeerStatus, window: Duration, ) { let res = tokio::time::timeout(window, async { loop { let event = receiver.recv().await.unwrap(); - if event.peer == peer && event.stage == stage { + if event.peer == peer && event.status == status { return; } } }) .await; - assert!(res.is_err(), "unexpected status event: {stage:?}"); + assert!(res.is_err(), "unexpected status event: {status:?}"); } async fn read_all(mut stream: crate::ByteReader) -> Result, QlError> { diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index e5c52fdb..d0d1ff1f 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -70,8 +70,8 @@ async fn rpc_request_round_trips() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -126,8 +126,8 @@ async fn rpc_subscription_streams_events() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -194,8 +194,8 @@ async fn rpc_request_with_progress_supports_progress_then_await() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 420832a5..7a1337c9 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -1,7 +1,8 @@ use std::time::Duration; +use ql_wire::{CloseTarget, StreamCloseCode}; + use super::*; -use crate::{CloseTarget, StreamCloseCode}; #[tokio::test(flavor = "current_thread")] async fn open_stream_duplex_happy_path() { @@ -24,8 +25,8 @@ async fn open_stream_duplex_happy_path() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -87,8 +88,8 @@ async fn stream_backpressure_with_small_runtime_buffer() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -137,8 +138,8 @@ async fn dropping_responder_closes_initiator_response() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -190,8 +191,8 @@ async fn dropping_inbound_reader_cancels_remote_writer() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -247,8 +248,8 @@ async fn max_concurrent_message_writes_is_respected() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { for _ in 0..4 { @@ -322,8 +323,8 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); From e40cac93fae0e5ffb559780129d4c5083f621511 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 15:09:44 -0400 Subject: [PATCH 114/304] ql: more reasonable buffer sizes --- ql-fsm/src/session/mod.rs | 6 +++--- ql-runtime/src/lib.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index b9ad8eab..1297c75a 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -43,13 +43,13 @@ impl Default for SessionFsmConfig { fn default() -> Self { Self { local_parity: StreamParity::Even, - record_max_size: 16 * 1024, + record_max_size: 8 * 1024, ack_delay: Duration::from_millis(5), retransmit_timeout: Duration::from_millis(150), keepalive_interval: Duration::from_secs(10), peer_timeout: Duration::from_secs(30), - stream_send_buffer_size: 64 * 1024, - stream_receive_buffer_size: 64 * 1024, + stream_send_buffer_size: 16 * 1024, + stream_receive_buffer_size: 16 * 1024, initial_peer_stream_receive_window: 16 * 1024, } } diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index e2cc4bdd..3d8c0889 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -25,7 +25,7 @@ impl Default for RuntimeConfig { fn default() -> Self { Self { fsm: QlFsmConfig::default(), - stream_send_buffer_bytes: 64 * 1024, + stream_send_buffer_bytes: 16 * 1024, max_concurrent_message_writes: 4, } } From 71bafcc6a8c079076c0f0a75a78746291f77dd4c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 15:14:42 -0400 Subject: [PATCH 115/304] ql-runtime: get rid of openedstreamdelivery struct --- ql-runtime/src/command.rs | 4 ++-- ql-runtime/src/driver/mod.rs | 23 +++++++++-------------- ql-runtime/src/handle/mod.rs | 18 +++++++----------- ql-runtime/src/lib.rs | 5 ----- 4 files changed, 18 insertions(+), 32 deletions(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index c807818f..261cef1f 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,6 +1,6 @@ use ql_wire::{CloseTarget, PeerBundle, StreamCloseCode, StreamId}; -use crate::{OpenedStreamDelivery, QlError}; +use crate::{ByteReader, QlError}; pub(crate) enum RuntimeCommand { BindPeer { @@ -9,7 +9,7 @@ pub(crate) enum RuntimeCommand { Connect, OpenStream { request_reader: piper::Reader, - start: oneshot::Sender>, + start: oneshot::Sender>, }, PollInbound { stream_id: StreamId, diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index e851b324..b0c94d1c 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -18,7 +18,7 @@ use crate::{ command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, platform::{PlatformFuture, QlPlatform}, - OpenedStreamDelivery, QlError, Runtime, + QlError, Runtime, }; impl Runtime

{ @@ -184,19 +184,14 @@ impl DriverState { response_terminal_tx, ), ); - if start - .send(Ok(OpenedStreamDelivery { - stream_id, - reader: ByteReader::new( - stream_id, - CloseTarget::Return, - response_reader, - response_terminal_rx, - runtime_tx, - ), - })) - .is_err() - { + let reader = ByteReader::new( + stream_id, + CloseTarget::Return, + response_reader, + response_terminal_rx, + runtime_tx, + ); + if start.send(Ok((stream_id, reader))).is_err() { if let Some(stream) = self.streams.get_mut(&stream_id) { stream.inbound_mut().close(); stream.outbound_mut().close(); diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 0d23cd52..a54ba8cc 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -4,7 +4,7 @@ mod writer; use ql_wire::{CloseTarget, PeerBundle, StreamId}; pub use self::{reader::*, writer::*}; -use crate::{command::RuntimeCommand, OpenedStreamDelivery, QlError}; +use crate::{command::RuntimeCommand, QlError}; #[derive(Debug)] pub struct QlStream { @@ -36,16 +36,12 @@ impl RuntimeHandle { let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); let (start_tx, start_rx) = oneshot::channel(); - self.tx - .send(RuntimeCommand::OpenStream { - request_reader, - start: start_tx, - }) - .await - .map_err(|_| QlError::Cancelled)?; - - let OpenedStreamDelivery { stream_id, reader } = - start_rx.await.unwrap_or(Err(QlError::Cancelled))?; + self.send(RuntimeCommand::OpenStream { + request_reader, + start: start_tx, + }); + // runtime cannot be shutdown while we have a handle + let (stream_id, reader) = start_rx.await.unwrap()?; Ok(QlStream { stream_id, diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 3d8c0889..a7180f41 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -39,11 +39,6 @@ impl RuntimeConfig { } } -pub(crate) struct OpenedStreamDelivery { - pub stream_id: ql_wire::StreamId, - pub reader: ByteReader, -} - pub struct Runtime

{ identity: QlIdentity, platform: P, From 96cb3b566aa88b69e2092506ab1d64d33eae058e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 18:02:35 -0400 Subject: [PATCH 116/304] ql-runtime: use try_send instead of send_blocking --- ql-runtime/src/handle/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index a54ba8cc..bb860042 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -40,6 +40,7 @@ impl RuntimeHandle { request_reader, start: start_tx, }); + // runtime cannot be shutdown while we have a handle let (stream_id, reader) = start_rx.await.unwrap()?; @@ -77,6 +78,6 @@ impl RuntimeHandle { #[inline] #[track_caller] fn send(&self, cmd: RuntimeCommand) { - self.tx.send_blocking(cmd).expect("runtime is alive"); + self.tx.try_send(cmd).expect("runtime is alive"); } } From a95774811a9a460f9ca41702118fe455a2e286ab Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 19:26:26 -0400 Subject: [PATCH 117/304] ql-fsm: proptest --- Cargo.lock | 129 +++- ql-fsm/Cargo.toml | 1 + ql-fsm/src/tests/mod.rs | 1 + ql-fsm/src/tests/proptest.rs | 1084 ++++++++++++++++++++++++++++++++++ 4 files changed, 1213 insertions(+), 2 deletions(-) create mode 100644 ql-fsm/src/tests/proptest.rs diff --git a/Cargo.lock b/Cargo.lock index 46804b72..91c105fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -311,6 +311,21 @@ dependencies = [ "thiserror", ] +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitcoin-io" version = "0.1.3" @@ -344,9 +359,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.3" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" [[package]] name = "blake2" @@ -853,6 +868,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "event-listener" version = "5.4.1" @@ -938,6 +963,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -1561,6 +1592,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.0" @@ -2079,6 +2116,25 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags", + "num-traits", + "rand 0.9.2", + "rand_chacha 0.9.0", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + [[package]] name = "provenance-mark" version = "0.16.0" @@ -2130,6 +2186,7 @@ dependencies = [ "indexmap", "libcrux-aesgcm", "libcrux-ml-kem", + "proptest", "ql-wire", "sha2", ] @@ -2177,6 +2234,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.40" @@ -2260,6 +2323,15 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" @@ -2392,12 +2464,37 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustversion" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "rusty-fork" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.20" @@ -2687,6 +2784,19 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -2794,6 +2904,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -2866,6 +2982,15 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml index 1b5319e9..14b6c050 100644 --- a/ql-fsm/Cargo.toml +++ b/ql-fsm/Cargo.toml @@ -12,4 +12,5 @@ ql-wire = { path = "../ql-wire" } [dev-dependencies] libcrux-aesgcm = "0.0.7" libcrux-ml-kem = "0.0.7" +proptest = "1.6" sha2 = "0.10" diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index fc108c1c..d8e0d4ed 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -1,4 +1,5 @@ mod handshake; +mod proptest; mod session; use std::{ diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs new file mode 100644 index 00000000..c7671a89 --- /dev/null +++ b/ql-fsm/src/tests/proptest.rs @@ -0,0 +1,1084 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + time::Duration, +}; + +use ::proptest::{collection::vec, prelude::*, test_runner::TestCaseResult}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; + +use super::*; +use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent, SessionWriteId}; + +const SLOT_COUNT: usize = 4; + +#[derive(Clone, Copy, Debug)] +enum Side { + A, + B, +} + +#[derive(Clone, Debug)] +enum Action { + ConnectIkA, + ConnectIkB, + ConnectKkA, + ConnectKkB, + AdvanceMs(u8), + OnTimerA, + OnTimerB, + OnTimerBoth, + Pump, + TakeNextAToB, + TakeNextBToA, + ConfirmTakenAToB(usize), + ConfirmTakenBToA(usize), + RejectTakenAToB(usize), + RejectTakenBToA(usize), + CaptureNextAToB, + CaptureNextBToA, + DeliverNextAToB, + DeliverNextBToA, + DropNextAToB, + DropNextBToA, + DeliverQueuedAToB(usize), + DeliverQueuedBToA(usize), + DuplicateQueuedAToB(usize), + DuplicateQueuedBToA(usize), + DropQueuedAToB(usize), + DropQueuedBToA(usize), + OpenStreamA(usize), + OpenStreamB(usize), + WriteA { slot: usize, bytes: Vec }, + WriteB { slot: usize, bytes: Vec }, + FinishA(usize), + FinishB(usize), + CloseA(usize), + CloseB(usize), +} + +#[derive(Clone, Debug)] +struct TakenWrite { + record: Vec, + write_id: Option, +} + +#[derive(Default)] +struct SideEventState { + opened: BTreeSet, + finished: BTreeSet, + writable_closed: BTreeSet, + closed: BTreeSet, + peer_statuses: Vec, + last_peer_status: Option, + session_epoch: usize, + session_closed_epoch: Option, +} + +impl SideEventState { + fn note_peer_status(&mut self, status: PeerStatus) { + if status == PeerStatus::Connected && self.last_peer_status != Some(PeerStatus::Connected) { + self.session_epoch = self.session_epoch.saturating_add(1); + } + self.peer_statuses.push(status); + self.last_peer_status = Some(status); + } +} + +struct Runner { + harness: Harness, + slots_a: [Option; SLOT_COUNT], + slots_b: [Option; SLOT_COUNT], + taken_a_to_b: Vec, + taken_b_to_a: Vec, + pending_a_to_b: Vec>, + pending_b_to_a: Vec>, + receive_errors: Vec<(Side, QlFsmError)>, + events_a: SideEventState, + events_b: SideEventState, + known_streams: BTreeSet, + expected_at_a: BTreeMap>, + expected_at_b: BTreeMap>, + received_at_a: BTreeMap>, + received_at_b: BTreeMap>, + finished_by_a: BTreeSet, + finished_by_b: BTreeSet, + closed_by_a: BTreeSet, + closed_by_b: BTreeSet, +} + +impl Runner { + fn handshake() -> Self { + let config = QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + session_record_ack_delay: Duration::from_millis(5), + session_record_retransmit_timeout: Duration::from_millis(15), + session_peer_timeout: Duration::from_millis(80), + ..QlFsmConfig::default() + }; + + Self { + harness: Harness::paired_known(config), + slots_a: [None; SLOT_COUNT], + slots_b: [None; SLOT_COUNT], + taken_a_to_b: Vec::new(), + taken_b_to_a: Vec::new(), + pending_a_to_b: Vec::new(), + pending_b_to_a: Vec::new(), + receive_errors: Vec::new(), + events_a: SideEventState::default(), + events_b: SideEventState::default(), + known_streams: BTreeSet::new(), + expected_at_a: BTreeMap::new(), + expected_at_b: BTreeMap::new(), + received_at_a: BTreeMap::new(), + received_at_b: BTreeMap::new(), + finished_by_a: BTreeSet::new(), + finished_by_b: BTreeSet::new(), + closed_by_a: BTreeSet::new(), + closed_by_b: BTreeSet::new(), + } + } + + fn connected() -> Self { + let config = QlFsmConfig { + session_record_ack_delay: Duration::from_millis(5), + session_record_retransmit_timeout: Duration::from_millis(15), + session_peer_timeout: Duration::from_secs(5), + ..QlFsmConfig::default() + }; + + Self { + harness: Harness::connected(config), + slots_a: [None; SLOT_COUNT], + slots_b: [None; SLOT_COUNT], + taken_a_to_b: Vec::new(), + taken_b_to_a: Vec::new(), + pending_a_to_b: Vec::new(), + pending_b_to_a: Vec::new(), + receive_errors: Vec::new(), + events_a: SideEventState { + last_peer_status: Some(PeerStatus::Connected), + session_epoch: 1, + ..SideEventState::default() + }, + events_b: SideEventState { + last_peer_status: Some(PeerStatus::Connected), + session_epoch: 1, + ..SideEventState::default() + }, + known_streams: BTreeSet::new(), + expected_at_a: BTreeMap::new(), + expected_at_b: BTreeMap::new(), + received_at_a: BTreeMap::new(), + received_at_b: BTreeMap::new(), + finished_by_a: BTreeSet::new(), + finished_by_b: BTreeSet::new(), + closed_by_a: BTreeSet::new(), + closed_by_b: BTreeSet::new(), + } + } + + fn run(&mut self, actions: &[Action]) -> TestCaseResult { + for action in actions { + self.apply(action); + self.observe_and_assert()?; + } + + self.cleanup()?; + self.observe_and_assert()?; + self.assert_terminal_semantics()?; + self.assert_quiesced() + } + + fn apply(&mut self, action: &Action) { + match action { + Action::ConnectIkA => { + let _ = self.harness.connect_ik_a(); + } + Action::ConnectIkB => { + let _ = self.harness.connect_ik_b(); + } + Action::ConnectKkA => { + let _ = self.harness.connect_kk_a(); + } + Action::ConnectKkB => { + let _ = self.harness.connect_kk_b(); + } + Action::AdvanceMs(ms) => { + self.harness + .advance(Duration::from_millis(u64::from(*ms) + 1)); + } + Action::OnTimerA => self.harness.on_timer_a(), + Action::OnTimerB => self.harness.on_timer_b(), + Action::OnTimerBoth => { + self.harness.on_timer_a(); + self.harness.on_timer_b(); + } + Action::Pump => self.capture_all_outbound(), + Action::TakeNextAToB => { + if let Some(write) = take_unconfirmed_outbound_a(&mut self.harness) { + self.taken_a_to_b.push(write); + } + } + Action::TakeNextBToA => { + if let Some(write) = take_unconfirmed_outbound_b(&mut self.harness) { + self.taken_b_to_a.push(write); + } + } + Action::ConfirmTakenAToB(index) => { + if let Some(write) = take_taken(&mut self.taken_a_to_b, *index) { + confirm_taken_a(&mut self.harness, &write); + self.pending_a_to_b.push(write.record); + } + } + Action::ConfirmTakenBToA(index) => { + if let Some(write) = take_taken(&mut self.taken_b_to_a, *index) { + confirm_taken_b(&mut self.harness, &write); + self.pending_b_to_a.push(write.record); + } + } + Action::RejectTakenAToB(index) => { + if let Some(write) = take_taken(&mut self.taken_a_to_b, *index) { + reject_taken_a(&mut self.harness, &write); + } + } + Action::RejectTakenBToA(index) => { + if let Some(write) = take_taken(&mut self.taken_b_to_a, *index) { + reject_taken_b(&mut self.harness, &write); + } + } + Action::CaptureNextAToB => { + if let Some(record) = take_confirmed_outbound_a(&mut self.harness) { + self.pending_a_to_b.push(record); + } + } + Action::CaptureNextBToA => { + if let Some(record) = take_confirmed_outbound_b(&mut self.harness) { + self.pending_b_to_a.push(record); + } + } + Action::DeliverNextAToB => { + if let Some(record) = take_confirmed_outbound_a(&mut self.harness) { + self.deliver_to_b(record); + } + } + Action::DeliverNextBToA => { + if let Some(record) = take_confirmed_outbound_b(&mut self.harness) { + self.deliver_to_a(record); + } + } + Action::DropNextAToB => { + let _ = take_confirmed_outbound_a(&mut self.harness); + } + Action::DropNextBToA => { + let _ = take_confirmed_outbound_b(&mut self.harness); + } + Action::DeliverQueuedAToB(index) => { + if let Some(record) = take_pending(&mut self.pending_a_to_b, *index) { + self.deliver_to_b(record); + } + } + Action::DeliverQueuedBToA(index) => { + if let Some(record) = take_pending(&mut self.pending_b_to_a, *index) { + self.deliver_to_a(record); + } + } + Action::DuplicateQueuedAToB(index) => { + if let Some(record) = peek_pending(&self.pending_a_to_b, *index) { + self.deliver_to_b(record); + } + } + Action::DuplicateQueuedBToA(index) => { + if let Some(record) = peek_pending(&self.pending_b_to_a, *index) { + self.deliver_to_a(record); + } + } + Action::DropQueuedAToB(index) => { + let _ = take_pending(&mut self.pending_a_to_b, *index); + } + Action::DropQueuedBToA(index) => { + let _ = take_pending(&mut self.pending_b_to_a, *index); + } + Action::OpenStreamA(slot) => { + if let Ok(stream_id) = self.harness.a.fsm.open_stream() { + self.slots_a[*slot] = Some(stream_id); + self.known_streams.insert(stream_id); + } + } + Action::OpenStreamB(slot) => { + if let Ok(stream_id) = self.harness.b.fsm.open_stream() { + self.slots_b[*slot] = Some(stream_id); + self.known_streams.insert(stream_id); + } + } + Action::WriteA { slot, bytes } => { + if let Some(stream_id) = self.slots_a[*slot] { + if let Ok(accepted) = self.harness.a.fsm.write_stream(stream_id, bytes) { + self.expected_at_b + .entry(stream_id) + .or_default() + .extend_from_slice(&bytes[..accepted]); + } + } + } + Action::WriteB { slot, bytes } => { + if let Some(stream_id) = self.slots_b[*slot] { + if let Ok(accepted) = self.harness.b.fsm.write_stream(stream_id, bytes) { + self.expected_at_a + .entry(stream_id) + .or_default() + .extend_from_slice(&bytes[..accepted]); + } + } + } + Action::FinishA(slot) => { + if let Some(stream_id) = self.slots_a[*slot] { + if self.harness.a.fsm.finish_stream(stream_id).is_ok() { + self.finished_by_a.insert(stream_id); + } + } + } + Action::FinishB(slot) => { + if let Some(stream_id) = self.slots_b[*slot] { + if self.harness.b.fsm.finish_stream(stream_id).is_ok() { + self.finished_by_b.insert(stream_id); + } + } + } + Action::CloseA(slot) => { + if let Some(stream_id) = self.slots_a[*slot] { + if self + .harness + .a + .fsm + .close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)) + .is_ok() + { + self.closed_by_a.insert(stream_id); + self.slots_a[*slot] = None; + } + } + } + Action::CloseB(slot) => { + if let Some(stream_id) = self.slots_b[*slot] { + if self + .harness + .b + .fsm + .close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)) + .is_ok() + { + self.closed_by_b.insert(stream_id); + self.slots_b[*slot] = None; + } + } + } + } + } + + fn observe_and_assert(&mut self) -> TestCaseResult { + self.drain_reads_a(); + self.drain_reads_b(); + let events_a = self.harness.drain_events_a(); + let events_b = self.harness.drain_events_b(); + self.process_events(Side::A, events_a)?; + self.process_events(Side::B, events_b)?; + self.assert_prefix_invariants()?; + self.assert_legal_link_state()?; + self.assert_receive_errors() + } + + fn cleanup(&mut self) -> TestCaseResult { + let tick = self + .harness + .a + .fsm + .config + .session_record_retransmit_timeout + .max(self.harness.a.fsm.config.session_record_ack_delay) + + Duration::from_millis(1); + + self.reject_all_taken(); + + for _ in 0..12 { + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.observe_and_assert()?; + self.harness.advance(tick); + self.harness.on_timer_a(); + self.harness.on_timer_b(); + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.observe_and_assert()?; + self.reject_all_taken(); + } + + Ok(()) + } + + fn drain_reads_a(&mut self) { + for stream_id in self.known_streams.iter().copied().collect::>() { + let appended = drain_stream(&mut self.harness.a.fsm, stream_id); + if !appended.is_empty() { + self.received_at_a + .entry(stream_id) + .or_default() + .extend_from_slice(&appended); + } + } + } + + fn drain_reads_b(&mut self) { + for stream_id in self.known_streams.iter().copied().collect::>() { + let appended = drain_stream(&mut self.harness.b.fsm, stream_id); + if !appended.is_empty() { + self.received_at_b + .entry(stream_id) + .or_default() + .extend_from_slice(&appended); + } + } + } + + fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { + for event in events { + match event { + QlFsmEvent::NewPeer => {} + QlFsmEvent::PeerStatusChanged(status) => { + self.events_mut(side).note_peer_status(status); + } + QlFsmEvent::Opened(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted Opened for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events_mut(side).opened.insert(stream_id), + "side {side:?} emitted duplicate Opened for {stream_id:?}" + ); + } + QlFsmEvent::Readable(stream_id) | QlFsmEvent::Writable(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted readiness for unknown stream {stream_id:?}" + ); + } + QlFsmEvent::Finished(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted Finished for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events_mut(side).finished.insert(stream_id), + "side {side:?} emitted duplicate Finished for {stream_id:?}" + ); + prop_assert!( + !self.events(side).closed.contains(&stream_id), + "side {side:?} emitted Finished after Closed for {stream_id:?}" + ); + } + QlFsmEvent::Closed(frame) => { + prop_assert!( + self.known_streams.contains(&frame.stream_id), + "side {side:?} emitted Closed for unknown stream {:?}", + frame.stream_id + ); + prop_assert!( + self.events_mut(side).closed.insert(frame.stream_id), + "side {side:?} emitted duplicate Closed for {:?}", + frame.stream_id + ); + } + QlFsmEvent::WritableClosed(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted WritableClosed for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events_mut(side).writable_closed.insert(stream_id), + "side {side:?} emitted duplicate WritableClosed for {stream_id:?}" + ); + } + QlFsmEvent::SessionClosed(_) => { + let state = self.events_mut(side); + prop_assert!( + state.session_epoch > 0, + "side {side:?} emitted SessionClosed without a connected session" + ); + prop_assert!( + state.session_closed_epoch != Some(state.session_epoch), + "side {side:?} emitted duplicate SessionClosed in session epoch {}", + state.session_epoch + ); + state.session_closed_epoch = Some(state.session_epoch); + } + } + } + + Ok(()) + } + + fn assert_prefix_invariants(&self) -> TestCaseResult { + for (stream_id, received) in &self.received_at_a { + let expected = self + .expected_at_a + .get(stream_id) + .map(Vec::as_slice) + .unwrap_or(&[]); + prop_assert!( + expected.starts_with(received), + "side A observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" + ); + } + + for (stream_id, received) in &self.received_at_b { + let expected = self + .expected_at_b + .get(stream_id) + .map(Vec::as_slice) + .unwrap_or(&[]); + prop_assert!( + expected.starts_with(received), + "side B observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" + ); + } + + Ok(()) + } + + fn assert_legal_link_state(&self) -> TestCaseResult { + let a_connected = matches!(self.harness.a.fsm.state.link, LinkState::Connected(_)); + let b_connected = matches!(self.harness.b.fsm.state.link, LinkState::Connected(_)); + + prop_assert!( + !a_connected || self.harness.a.fsm.peer().is_some(), + "side A reached Connected without a bound peer" + ); + prop_assert!( + !b_connected || self.harness.b.fsm.peer().is_some(), + "side B reached Connected without a bound peer" + ); + + Ok(()) + } + + fn assert_receive_errors(&self) -> TestCaseResult { + for (side, error) in &self.receive_errors { + prop_assert!( + matches!( + error, + QlFsmError::NoSession + | QlFsmError::InvalidState + | QlFsmError::Expired + | QlFsmError::InvalidPayload + | QlFsmError::DecryptFailed + ), + "unexpected receive error on side {side:?}: {error:?}" + ); + } + + Ok(()) + } + + fn assert_terminal_semantics(&self) -> TestCaseResult { + for stream_id in &self.events_a.finished { + if self.inbound_aborted(Side::A, stream_id) { + continue; + } + let expected = self + .expected_at_a + .get(stream_id) + .map(Vec::as_slice) + .unwrap_or(&[]); + let received = self + .received_at_a + .get(stream_id) + .map(Vec::as_slice) + .unwrap_or(&[]); + prop_assert_eq!( + received, + expected, + "side A finished {:?} without receiving all expected bytes", + stream_id + ); + } + + for stream_id in &self.events_b.finished { + if self.inbound_aborted(Side::B, stream_id) { + continue; + } + let expected = self + .expected_at_b + .get(stream_id) + .map(Vec::as_slice) + .unwrap_or(&[]); + let received = self + .received_at_b + .get(stream_id) + .map(Vec::as_slice) + .unwrap_or(&[]); + prop_assert_eq!( + received, + expected, + "side B finished {:?} without receiving all expected bytes", + stream_id + ); + } + + let a_connected = matches!(self.harness.a.fsm.state.link, LinkState::Connected(_)); + let b_connected = matches!(self.harness.b.fsm.state.link, LinkState::Connected(_)); + + for stream_id in &self.finished_by_a { + prop_assert!( + self.events_b.finished.contains(stream_id) + || self.events_b.closed.contains(stream_id) + || !b_connected, + "side A finished {stream_id:?} but side B saw neither Finished nor Closed" + ); + } + + for stream_id in &self.finished_by_b { + prop_assert!( + self.events_a.finished.contains(stream_id) + || self.events_a.closed.contains(stream_id) + || !a_connected, + "side B finished {stream_id:?} but side A saw neither Finished nor Closed" + ); + } + + for stream_id in &self.closed_by_a { + prop_assert!( + self.events_b.closed.contains(stream_id) || !b_connected, + "side A closed {stream_id:?} but side B saw no Closed event" + ); + } + + for stream_id in &self.closed_by_b { + prop_assert!( + self.events_a.closed.contains(stream_id) || !a_connected, + "side B closed {stream_id:?} but side A saw no Closed event" + ); + } + + Ok(()) + } + + fn assert_no_stream_events(&self) -> TestCaseResult { + prop_assert!( + self.known_streams.is_empty() + && self.events_a.opened.is_empty() + && self.events_b.opened.is_empty() + && self.events_a.finished.is_empty() + && self.events_b.finished.is_empty() + && self.events_a.closed.is_empty() + && self.events_b.closed.is_empty() + && self.events_a.writable_closed.is_empty() + && self.events_b.writable_closed.is_empty(), + "handshake-only property observed stream activity" + ); + Ok(()) + } + + fn assert_no_taken_writes(&self) -> TestCaseResult { + prop_assert!( + self.taken_a_to_b.is_empty() && self.taken_b_to_a.is_empty(), + "cleanup left taken writes queued" + ); + Ok(()) + } + + fn assert_quiesced(&mut self) -> TestCaseResult { + self.reject_all_taken(); + + for _ in 0..8 { + self.capture_all_outbound(); + if self.pending_a_to_b.is_empty() && self.pending_b_to_a.is_empty() { + break; + } + self.flush_pending_in_order(); + self.observe_and_assert()?; + } + + self.capture_all_outbound(); + prop_assert!( + self.pending_a_to_b.is_empty() + && self.pending_b_to_a.is_empty() + && self.taken_a_to_b.is_empty() + && self.taken_b_to_a.is_empty(), + "cleanup did not quiesce: taken_a={} taken_b={} pending_a={} pending_b={}", + self.taken_a_to_b.len(), + self.taken_b_to_a.len(), + self.pending_a_to_b.len(), + self.pending_b_to_a.len() + ); + + Ok(()) + } + + fn capture_all_outbound(&mut self) { + while let Some(record) = take_confirmed_outbound_a(&mut self.harness) { + self.pending_a_to_b.push(record); + } + + while let Some(record) = take_confirmed_outbound_b(&mut self.harness) { + self.pending_b_to_a.push(record); + } + } + + fn flush_pending_in_order(&mut self) { + while let Some(record) = pop_front_pending(&mut self.pending_a_to_b) { + self.deliver_to_b(record); + } + + while let Some(record) = pop_front_pending(&mut self.pending_b_to_a) { + self.deliver_to_a(record); + } + } + + fn reject_all_taken(&mut self) { + while let Some(write) = self.taken_a_to_b.pop() { + reject_taken_a(&mut self.harness, &write); + } + + while let Some(write) = self.taken_b_to_a.pop() { + reject_taken_b(&mut self.harness, &write); + } + } + + fn deliver_to_a(&mut self, record: Vec) { + if let Err(error) = deliver_to_a(&mut self.harness, record) { + self.receive_errors.push((Side::A, error)); + } + } + + fn deliver_to_b(&mut self, record: Vec) { + if let Err(error) = deliver_to_b(&mut self.harness, record) { + self.receive_errors.push((Side::B, error)); + } + } + + fn events_mut(&mut self, side: Side) -> &mut SideEventState { + match side { + Side::A => &mut self.events_a, + Side::B => &mut self.events_b, + } + } + + fn events(&self, side: Side) -> &SideEventState { + match side { + Side::A => &self.events_a, + Side::B => &self.events_b, + } + } + + fn inbound_aborted(&self, side: Side, stream_id: &StreamId) -> bool { + self.events(side).closed.contains(stream_id) + || match side { + Side::A => self.closed_by_a.contains(stream_id), + Side::B => self.closed_by_b.contains(stream_id), + } + } +} + +fn take_unconfirmed_outbound_a(harness: &mut Harness) -> Option { + let time = harness.time(); + let Node { fsm, crypto, .. } = &mut harness.a; + let write = fsm.take_next_write(time, crypto)?; + Some(TakenWrite { + record: write.record, + write_id: write.session_write_id, + }) +} + +fn take_unconfirmed_outbound_b(harness: &mut Harness) -> Option { + let time = harness.time(); + let Node { fsm, crypto, .. } = &mut harness.b; + let write = fsm.take_next_write(time, crypto)?; + Some(TakenWrite { + record: write.record, + write_id: write.session_write_id, + }) +} + +fn take_confirmed_outbound_a(harness: &mut Harness) -> Option> { + let write = take_unconfirmed_outbound_a(harness)?; + confirm_taken_a(harness, &write); + Some(write.record) +} + +fn take_confirmed_outbound_b(harness: &mut Harness) -> Option> { + let write = take_unconfirmed_outbound_b(harness)?; + confirm_taken_b(harness, &write); + Some(write.record) +} + +fn confirm_taken_a(harness: &mut Harness, write: &TakenWrite) { + if let Some(write_id) = write.write_id { + harness.a.fsm.confirm_session_write(harness.time(), write_id); + } +} + +fn confirm_taken_b(harness: &mut Harness, write: &TakenWrite) { + if let Some(write_id) = write.write_id { + harness.b.fsm.confirm_session_write(harness.time(), write_id); + } +} + +fn reject_taken_a(harness: &mut Harness, write: &TakenWrite) { + if let Some(write_id) = write.write_id { + harness.a.fsm.reject_session_write(write_id); + } +} + +fn reject_taken_b(harness: &mut Harness, write: &TakenWrite) { + if let Some(write_id) = write.write_id { + harness.b.fsm.reject_session_write(write_id); + } +} + +fn deliver_to_a(harness: &mut Harness, record: Vec) -> Result<(), QlFsmError> { + let time = harness.time(); + let Node { + fsm, + crypto, + events, + } = &mut harness.a; + fsm.receive(time, record, crypto, |event| events.push_back(event)) +} + +fn deliver_to_b(harness: &mut Harness, record: Vec) -> Result<(), QlFsmError> { + let time = harness.time(); + let Node { + fsm, + crypto, + events, + } = &mut harness.b; + fsm.receive(time, record, crypto, |event| events.push_back(event)) +} + +fn take_pending(pending: &mut Vec>, index: usize) -> Option> { + if pending.is_empty() { + return None; + } + + Some(pending.remove(index % pending.len())) +} + +fn peek_pending(pending: &[Vec], index: usize) -> Option> { + if pending.is_empty() { + return None; + } + + Some(pending[index % pending.len()].clone()) +} + +fn pop_front_pending(pending: &mut Vec>) -> Option> { + if pending.is_empty() { + None + } else { + Some(pending.remove(0)) + } +} + +fn take_taken(taken: &mut Vec, index: usize) -> Option { + if taken.is_empty() { + return None; + } + + Some(taken.remove(index % taken.len())) +} + +fn drain_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { + let mut out = Vec::new(); + + loop { + let Some(chunks) = fsm.stream_read(stream_id) else { + break; + }; + + let mut read = 0usize; + for chunk in chunks { + out.extend_from_slice(chunk); + read += chunk.len(); + } + + if read == 0 { + break; + } + + fsm.stream_read_commit(stream_id, read).unwrap(); + } + + out +} + +fn handshake_action_strategy() -> impl Strategy { + let queue_index = 0usize..6; + prop_oneof![ + Just(Action::ConnectIkA), + Just(Action::ConnectIkB), + Just(Action::ConnectKkA), + Just(Action::ConnectKkB), + (0u8..40).prop_map(Action::AdvanceMs), + Just(Action::OnTimerA), + Just(Action::OnTimerB), + Just(Action::OnTimerBoth), + Just(Action::Pump), + Just(Action::TakeNextAToB), + Just(Action::TakeNextBToA), + queue_index.clone().prop_map(Action::ConfirmTakenAToB), + queue_index.clone().prop_map(Action::ConfirmTakenBToA), + queue_index.clone().prop_map(Action::RejectTakenAToB), + queue_index.clone().prop_map(Action::RejectTakenBToA), + Just(Action::CaptureNextAToB), + Just(Action::CaptureNextBToA), + Just(Action::DeliverNextAToB), + Just(Action::DeliverNextBToA), + Just(Action::DropNextAToB), + Just(Action::DropNextBToA), + queue_index.clone().prop_map(Action::DeliverQueuedAToB), + queue_index.clone().prop_map(Action::DeliverQueuedBToA), + queue_index.clone().prop_map(Action::DuplicateQueuedAToB), + queue_index.clone().prop_map(Action::DuplicateQueuedBToA), + queue_index.clone().prop_map(Action::DropQueuedAToB), + queue_index.prop_map(Action::DropQueuedBToA), + ] +} + +fn connected_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..24); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + (0u8..30).prop_map(Action::AdvanceMs), + Just(Action::OnTimerA), + Just(Action::OnTimerB), + Just(Action::OnTimerBoth), + Just(Action::Pump), + Just(Action::TakeNextAToB), + Just(Action::TakeNextBToA), + queue_index.clone().prop_map(Action::ConfirmTakenAToB), + queue_index.clone().prop_map(Action::ConfirmTakenBToA), + queue_index.clone().prop_map(Action::RejectTakenAToB), + queue_index.clone().prop_map(Action::RejectTakenBToA), + Just(Action::CaptureNextAToB), + Just(Action::CaptureNextBToA), + Just(Action::DeliverNextAToB), + Just(Action::DeliverNextBToA), + Just(Action::DropNextAToB), + Just(Action::DropNextBToA), + queue_index.clone().prop_map(Action::DeliverQueuedAToB), + queue_index.clone().prop_map(Action::DeliverQueuedBToA), + queue_index.clone().prop_map(Action::DuplicateQueuedAToB), + queue_index.clone().prop_map(Action::DuplicateQueuedBToA), + queue_index.clone().prop_map(Action::DropQueuedAToB), + queue_index.clone().prop_map(Action::DropQueuedBToA), + slot.clone().prop_map(Action::OpenStreamA), + slot.clone().prop_map(Action::OpenStreamB), + (slot.clone(), bytes.clone()).prop_map(|(slot, bytes)| Action::WriteA { slot, bytes }), + (slot.clone(), bytes).prop_map(|(slot, bytes)| Action::WriteB { slot, bytes }), + slot.clone().prop_map(Action::FinishA), + slot.clone().prop_map(Action::FinishB), + slot.clone().prop_map(Action::CloseA), + slot.prop_map(Action::CloseB), + ] +} + +fn write_tracking_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..16); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + slot.clone().prop_map(Action::OpenStreamA), + slot.clone().prop_map(Action::OpenStreamB), + (slot.clone(), bytes.clone()).prop_map(|(slot, bytes)| Action::WriteA { slot, bytes }), + (slot.clone(), bytes).prop_map(|(slot, bytes)| Action::WriteB { slot, bytes }), + Just(Action::TakeNextAToB), + Just(Action::TakeNextBToA), + queue_index.clone().prop_map(Action::ConfirmTakenAToB), + queue_index.clone().prop_map(Action::ConfirmTakenBToA), + queue_index.clone().prop_map(Action::RejectTakenAToB), + queue_index.clone().prop_map(Action::RejectTakenBToA), + queue_index.clone().prop_map(Action::DeliverQueuedAToB), + queue_index.clone().prop_map(Action::DeliverQueuedBToA), + queue_index.clone().prop_map(Action::DuplicateQueuedAToB), + queue_index.clone().prop_map(Action::DuplicateQueuedBToA), + queue_index.clone().prop_map(Action::DropQueuedAToB), + queue_index.clone().prop_map(Action::DropQueuedBToA), + Just(Action::Pump), + Just(Action::OnTimerA), + Just(Action::OnTimerB), + Just(Action::OnTimerBoth), + (0u8..20).prop_map(Action::AdvanceMs), + ] +} + +fn terminal_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..16); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + slot.clone().prop_map(Action::OpenStreamA), + slot.clone().prop_map(Action::OpenStreamB), + (slot.clone(), bytes.clone()).prop_map(|(slot, bytes)| Action::WriteA { slot, bytes }), + (slot.clone(), bytes).prop_map(|(slot, bytes)| Action::WriteB { slot, bytes }), + slot.clone().prop_map(Action::FinishA), + slot.clone().prop_map(Action::FinishB), + slot.clone().prop_map(Action::CloseA), + slot.clone().prop_map(Action::CloseB), + Just(Action::TakeNextAToB), + Just(Action::TakeNextBToA), + queue_index.clone().prop_map(Action::ConfirmTakenAToB), + queue_index.clone().prop_map(Action::ConfirmTakenBToA), + queue_index.clone().prop_map(Action::RejectTakenAToB), + queue_index.clone().prop_map(Action::RejectTakenBToA), + queue_index.clone().prop_map(Action::DeliverQueuedAToB), + queue_index.clone().prop_map(Action::DeliverQueuedBToA), + queue_index.clone().prop_map(Action::DuplicateQueuedAToB), + queue_index.clone().prop_map(Action::DuplicateQueuedBToA), + queue_index.clone().prop_map(Action::DropQueuedAToB), + queue_index.clone().prop_map(Action::DropQueuedBToA), + Just(Action::Pump), + Just(Action::OnTimerA), + Just(Action::OnTimerB), + Just(Action::OnTimerBoth), + (0u8..20).prop_map(Action::AdvanceMs), + ] +} + +proptest! { + #![proptest_config(ProptestConfig { + cases: 24, + max_shrink_iters: 10_000, + .. ProptestConfig::default() + })] + + #[test] + fn randomized_handshake_actions_quiesce(actions in vec(handshake_action_strategy(), 1..64)) { + let mut runner = Runner::handshake(); + runner.run(&actions)?; + runner.assert_no_stream_events()?; + } + + #[test] + fn randomized_stream_actions_preserve_integrity(actions in vec(connected_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + } + + #[test] + fn randomized_write_tracking_actions_quiesce(actions in vec(write_tracking_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + runner.assert_no_taken_writes()?; + } + + #[test] + fn randomized_terminal_actions_preserve_terminal_semantics(actions in vec(terminal_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + runner.assert_terminal_semantics()?; + } +} From 2b91c0a4d3b1def3fe97700a9ebbd3d64de1ff07 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 19:37:19 -0400 Subject: [PATCH 118/304] ql-fsm: introduce remotestreamhistory --- ql-fsm/src/session/mod.rs | 115 ++++++++++++++------ ql-fsm/src/session/remote_stream_history.rs | 50 +++++++++ ql-fsm/src/session/state.rs | 5 +- ql-fsm/src/session/tests.rs | 85 +++++++++++++++ 4 files changed, 217 insertions(+), 38 deletions(-) create mode 100644 ql-fsm/src/session/remote_stream_history.rs diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 1297c75a..d2c31590 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod received_records; +pub(crate) mod remote_stream_history; pub(crate) mod state; pub(crate) mod stream_parity; pub(crate) mod stream_rx; @@ -19,6 +20,7 @@ use ql_wire::{ use self::{ received_records::{ReceiveOutcome, ReceivedRecords}, + remote_stream_history::RemoteStreamHistory, state::{AckState, InboundState, OutboundState, SessionFsmState, StreamRole, StreamState}, stream_parity::StreamParity, stream_rx::{StreamReadIter, StreamRxError}, @@ -122,6 +124,7 @@ impl SessionFsm { pending_control: Default::default(), streams: Default::default(), next_stream_index: 0, + remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), }, } } @@ -662,18 +665,25 @@ impl SessionFsm { let stream = match self.state.streams.entry(stream_id) { Entry::Occupied(entry) => entry.into_mut(), Entry::Vacant(entry) => { - if !self.config.local_parity.remote().matches(stream_id) { - if self.local_stream_was_opened(stream_id) { - return Ok(()); + match classify_missing_stream( + self.config.local_parity, + self.state.next_stream_ordinal, + stream_id, + &mut self.state.remote_stream_history, + ) { + MissingStreamAction::Create => {} + MissingStreamAction::Ignore => return Ok(()), + MissingStreamAction::FailProtocol => { + self.fail_session( + SessionClose { + code: SessionCloseCode::PROTOCOL, + }, + emit, + ); + return Err(()); } - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - emit, - ); - return Err(()); } + emit(SessionEvent::Opened(stream_id)); entry.insert(StreamState::new( StreamRole::Responder, @@ -752,31 +762,38 @@ impl SessionFsm { frame: &StreamClose, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { - let created = match self.state.streams.entry(frame.stream_id) { - Entry::Occupied(_) => false, + let mut created = false; + let stream = match self.state.streams.entry(frame.stream_id) { + Entry::Occupied(entry) => entry.into_mut(), Entry::Vacant(entry) => { - if !self.config.local_parity.remote().matches(frame.stream_id) { - if self.local_stream_was_opened(frame.stream_id) { - return Ok(()); + match classify_missing_stream( + self.config.local_parity, + self.state.next_stream_ordinal, + frame.stream_id, + &mut self.state.remote_stream_history, + ) { + MissingStreamAction::Create => {} + MissingStreamAction::Ignore => return Ok(()), + MissingStreamAction::FailProtocol => { + self.fail_session( + SessionClose { + code: SessionCloseCode::PROTOCOL, + }, + emit, + ); + return Err(()); } - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - emit, - ); - return Err(()); } + + created = true; entry.insert(StreamState::new( StreamRole::Responder, self.config.stream_receive_buffer_size, self.config.initial_peer_stream_receive_window, - )); - true + )) } }; - let stream = self.state.streams.get_mut(&frame.stream_id).unwrap(); if created { emit(SessionEvent::Opened(frame.stream_id)); } @@ -834,17 +851,6 @@ impl SessionFsm { matches!(target, CloseTarget::Both) || role.outbound_target() == target } - /// Returns true if this locally-opened stream id was already reaped, so stale peer frames for it can be ignored. - fn local_stream_was_opened(&self, stream_id: StreamId) -> bool { - self.config.local_parity.matches(stream_id) - && stream_id.0 - < self - .config - .local_parity - .make_stream_id(self.state.next_stream_ordinal) - .0 - } - fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { let tracked_refs_stream = self.state.tracked_records.values().any(|record| { record.window_updates.iter().any(|(id, _)| *id == stream_id) @@ -943,6 +949,43 @@ fn schedule_ack(ack_state: &mut AckState, due_at: Instant) { }; } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MissingStreamAction { + Create, + Ignore, + FailProtocol, +} + +fn classify_missing_stream( + local_parity: StreamParity, + next_stream_ordinal: u32, + stream_id: StreamId, + remote_stream_history: &mut RemoteStreamHistory, +) -> MissingStreamAction { + if !local_parity.remote().matches(stream_id) { + return if local_stream_was_opened(local_parity, next_stream_ordinal, stream_id) { + MissingStreamAction::Ignore + } else { + MissingStreamAction::FailProtocol + }; + } + + if remote_stream_history.observe(stream_id) { + MissingStreamAction::Ignore + } else { + MissingStreamAction::Create + } +} + +fn local_stream_was_opened( + local_parity: StreamParity, + next_stream_ordinal: u32, + stream_id: StreamId, +) -> bool { + local_parity.matches(stream_id) + && stream_id.0 < local_parity.make_stream_id(next_stream_ordinal).0 +} + fn restore_tracked_record( now: Instant, ack_state: &mut AckState, diff --git a/ql-fsm/src/session/remote_stream_history.rs b/ql-fsm/src/session/remote_stream_history.rs new file mode 100644 index 00000000..7d20c0cf --- /dev/null +++ b/ql-fsm/src/session/remote_stream_history.rs @@ -0,0 +1,50 @@ +use std::collections::BTreeSet; + +use ql_wire::StreamId; + +use super::stream_parity::StreamParity; + +#[derive(Debug)] +pub struct RemoteStreamHistory { + parity: StreamParity, + seen_prefix_end: u32, + seen_sparse: BTreeSet, +} + +impl RemoteStreamHistory { + pub fn new(parity: StreamParity) -> Self { + Self { + parity, + seen_prefix_end: 0, + seen_sparse: BTreeSet::new(), + } + } + + /// returns true when this remote stream id was already observed before + /// panics if stream_id is wrong stream parity + pub fn observe(&mut self, stream_id: StreamId) -> bool { + let ordinal = self + .stream_ordinal(stream_id) + .expect("remote stream history used with wrong stream parity"); + if ordinal < self.seen_prefix_end { + return true; + } + if ordinal > self.seen_prefix_end { + return !self.seen_sparse.insert(ordinal); + } + + self.seen_prefix_end = self.seen_prefix_end.saturating_add(1); + while self.seen_sparse.remove(&self.seen_prefix_end) { + self.seen_prefix_end = self.seen_prefix_end.saturating_add(1); + } + false + } + + fn stream_ordinal(&self, stream_id: StreamId) -> Option { + let delta = stream_id.0.checked_sub(self.parity.first_stream_id())?; + if delta % 2 != 0 { + return None; + } + Some(delta / 2) + } +} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 24e07c62..54f4d584 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -4,8 +4,8 @@ use indexmap::IndexMap; use ql_wire::{CloseTarget, RecordSeq, SessionClose, StreamClose, StreamId}; use super::{ - received_records::ReceivedRecords, stream_rx::StreamRx, stream_tx::StreamTx, - tracked::TrackedRecord, SessionState, + received_records::ReceivedRecords, remote_stream_history::RemoteStreamHistory, + stream_rx::StreamRx, stream_tx::StreamTx, tracked::TrackedRecord, SessionState, }; pub struct SessionFsmState { @@ -22,6 +22,7 @@ pub struct SessionFsmState { pub pending_control: PendingSessionControl, pub streams: IndexMap, pub next_stream_index: usize, + pub remote_stream_history: RemoteStreamHistory, } #[derive(Debug)] diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 703d4d17..b884001f 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -365,6 +365,91 @@ fn duplicate_stream_data_is_not_redelivered() { assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); } +#[test] +fn duplicate_remote_close_after_reap_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let close = StreamClose { + stream_id: StreamId(1), + target: CloseTarget::Both, + code: StreamCloseCode(9), + }; + let record = SessionRecord { + frames: vec![SessionFrame::StreamClose(close.clone())], + }; + + let first = receive_events(&mut fsm, now, RecordSeq(1), &record); + assert_eq!( + first, + vec![ + SessionEvent::Opened(close.stream_id), + SessionEvent::Closed(close.clone()), + SessionEvent::WritableClosed(close.stream_id), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), RecordSeq(2), &record); + assert!(second.is_empty()); +} + +#[test] +fn duplicate_finished_remote_data_after_reap_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = StreamId(1); + let record = SessionRecord { + frames: vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: 0, + fin: true, + bytes: b"hello".to_vec(), + })], + }; + + let first = receive_events(&mut fsm, now, RecordSeq(1), &record); + assert_eq!( + first, + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Readable(stream_id), + SessionEvent::Finished(stream_id), + ] + ); + assert_eq!(read_stream_all(&mut fsm, stream_id), b"hello".to_vec()); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), RecordSeq(2), &record); + assert!(second.is_empty()); +} + +#[test] +fn out_of_order_remote_stream_first_observations_still_open_once_each() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let close3 = SessionRecord { + frames: vec![SessionFrame::StreamClose(StreamClose { + stream_id: StreamId(3), + target: CloseTarget::Both, + code: StreamCloseCode(1), + })], + }; + let close1 = SessionRecord { + frames: vec![SessionFrame::StreamClose(StreamClose { + stream_id: StreamId(1), + target: CloseTarget::Both, + code: StreamCloseCode(2), + })], + }; + + let first = receive_events(&mut fsm, now, RecordSeq(1), &close3); + assert!(first.contains(&SessionEvent::Opened(StreamId(3)))); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), RecordSeq(2), &close1); + assert!(second.contains(&SessionEvent::Opened(StreamId(1)))); + + let third = receive_events(&mut fsm, now + Duration::from_millis(2), RecordSeq(3), &close3); + assert!(third.is_empty()); +} + #[test] fn close_does_not_ack_rejected_record_seq() { let now = Instant::now(); From 8f7e9147f0baf05c521990bff0f337760ad193ea Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 20:46:51 -0400 Subject: [PATCH 119/304] ql: varint --- ql-fsm/src/implementation/core.rs | 10 +- ql-fsm/src/session/mod.rs | 53 ++++--- ql-fsm/src/session/received_records.rs | 36 ++--- ql-fsm/src/session/remote_stream_history.rs | 6 +- ql-fsm/src/session/stream_parity.rs | 6 +- ql-fsm/src/session/tests.rs | 148 +++++++++----------- ql-fsm/src/tests/session.rs | 8 +- ql-runtime/src/driver/test.rs | 8 +- ql-wire/src/codec.rs | 44 +++++- ql-wire/src/encrypted/ack.rs | 26 ++-- ql-wire/src/encrypted/builder.rs | 82 +++++++---- ql-wire/src/encrypted/mod.rs | 91 ++++++------ ql-wire/src/encrypted/stream_close.rs | 9 +- ql-wire/src/encrypted/stream_data.rs | 21 +-- ql-wire/src/encrypted/stream_window.rs | 16 ++- ql-wire/src/header.rs | 48 +++++-- ql-wire/src/lib.rs | 2 + ql-wire/src/record.rs | 4 +- ql-wire/src/tests.rs | 80 ++++++++--- ql-wire/src/varint.rs | 126 +++++++++++++++++ 20 files changed, 551 insertions(+), 273 deletions(-) create mode 100644 ql-wire/src/varint.rs diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 40b70640..38a6e13d 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -1,8 +1,7 @@ use std::time::{Duration, Instant}; use ql_wire::{ - self as wire, CloseTarget, QlCrypto, SessionCloseCode, SessionHeader, StreamCloseCode, - StreamId, WireParse, + self as wire, CloseTarget, QlCrypto, SessionCloseCode, StreamCloseCode, StreamId, WireParse, }; use crate::{ @@ -95,13 +94,10 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Self { config.record_max_size = config .record_max_size - .max(SessionRecordBuilder::WIRE_PREFIX_LEN); + .max(SessionRecordBuilder::MIN_CAPACITY); config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); Self { @@ -116,7 +116,7 @@ impl SessionFsm { last_inbound_at: now, session_state: SessionState::Open, next_stream_ordinal: 0, - next_record_seq: RecordSeq(0), + next_record_seq: RecordSeq::from_u32(0), next_write_id: 0, tracked_records: Default::default(), received_records: ReceivedRecords::default(), @@ -409,15 +409,11 @@ impl SessionFsm { .min() } - pub fn take_next_write( - &mut self, - now: Instant, - ) -> Option<(Option, RecordSeq, SessionRecordBuilder)> { + pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { self.state.now = now; self.collect_timeouts(); let (builder, outbound) = self.build_next_record()?; - let seq = outbound.seq; let should_track = outbound.ping_included || !outbound.window_updates.is_empty() @@ -430,12 +426,12 @@ impl SessionFsm { write_id }); - Some((write_id, seq, builder)) + Some((write_id, builder)) } fn build_next_record(&mut self) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; - let mut builder = SessionRecordBuilder::new(self.config.record_max_size); + let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); let mut outbound = TrackedRecord { seq, frames: Vec::new(), @@ -475,7 +471,11 @@ impl SessionFsm { return None; } - self.state.next_record_seq = RecordSeq(self.state.next_record_seq.0.saturating_add(1)); + self.state.next_record_seq = seq + .into_inner() + .checked_add(1) + .and_then(|next| RecordSeq::from_u64(next).ok()) + .expect("record sequence overflow"); Some((builder, outbound)) } @@ -525,17 +525,17 @@ impl SessionFsm { } let frame = StreamWindow { stream_id, - maximum_offset: stream.recv_limit(), + maximum_offset: VarInt::from_u64(stream.recv_limit()).unwrap(), }; if !builder.push_stream_window(&frame) { break; } stream.pending_window = false; - stream.advertised_max_offset = frame.maximum_offset; + stream.advertised_max_offset = frame.maximum_offset.into_inner(); outbound .window_updates - .push((stream_id, frame.maximum_offset)); + .push((stream_id, frame.maximum_offset.into_inner())); } } @@ -544,8 +544,7 @@ impl SessionFsm { builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, ) { - const OVERHEAD: usize = - 1 + std::mem::size_of::() + StreamData::>::MIN_WIRE_SIZE; + const OVERHEAD: usize = 1 + VarInt::MAX_SIZE + StreamData::>::MIN_WIRE_SIZE; let len = self.state.streams.len(); if len == 0 { @@ -568,9 +567,11 @@ impl SessionFsm { let Some(candidate) = stream.tx.next_range(max_payload, stream.peer_max_offset) else { continue; }; + let offset = + VarInt::from_u64(candidate.offset).expect("stream offsets must fit ql-wire varint"); let frame = StreamData { stream_id, - offset: candidate.offset, + offset, fin: candidate.fin, bytes: stream.tx.ranged_bytes(candidate), }; @@ -609,7 +610,7 @@ impl SessionFsm { let tracked_records = &mut self.state.tracked_records; let streams = &mut self.state.streams; for (_, record) in tracked_records.extract_if(.., |_, record| { - record.sent_at.is_some() && ack.contains(record.seq.0) + record.sent_at.is_some() && ack.contains(record.seq.into_inner()) }) { for frame in &record.frames { acknowledge_tracked_frame(streams, stream_send_buffer_size, frame, emit); @@ -693,11 +694,13 @@ impl SessionFsm { } }; + let frame_offset = frame.offset.into_inner(); match stream.inbound_state { InboundState::Open => {} InboundState::Discarding => return Ok(()), InboundState::Finished | InboundState::Closed(_) => { - if frame.offset + frame.bytes.len() as u64 <= stream.rx.start_offset() { + if frame_offset.saturating_add(frame.bytes.len() as u64) <= stream.rx.start_offset() + { return Ok(()); } self.fail_session( @@ -711,7 +714,7 @@ impl SessionFsm { } let was_readable = stream.readable_bytes() > 0; - let insert = stream.rx.insert(frame.offset, frame.fin, frame.bytes); + let insert = stream.rx.insert(frame_offset, frame.fin, frame.bytes); match insert { Ok(outcome) => { if !was_readable && outcome.newly_readable_bytes > 0 { @@ -749,8 +752,9 @@ impl SessionFsm { }; let was_full = stream.send_capacity(self.config.stream_send_buffer_size) == 0; - if frame.maximum_offset > stream.peer_max_offset { - stream.peer_max_offset = frame.maximum_offset; + let maximum_offset = frame.maximum_offset.into_inner(); + if maximum_offset > stream.peer_max_offset { + stream.peer_max_offset = maximum_offset; } if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { emit(SessionEvent::Writable(frame.stream_id)); @@ -983,7 +987,10 @@ fn local_stream_was_opened( stream_id: StreamId, ) -> bool { local_parity.matches(stream_id) - && stream_id.0 < local_parity.make_stream_id(next_stream_ordinal).0 + && stream_id.into_inner() + < local_parity + .make_stream_id(next_stream_ordinal) + .into_inner() } fn restore_tracked_record( diff --git a/ql-fsm/src/session/received_records.rs b/ql-fsm/src/session/received_records.rs index 69cf0cad..fe0a58f5 100644 --- a/ql-fsm/src/session/received_records.rs +++ b/ql-fsm/src/session/received_records.rs @@ -18,7 +18,7 @@ impl ReceivedRecords { const TRACKED_WINDOW: u64 = Self::TRACKED_LEN - 1; pub fn insert(&mut self, seq: RecordSeq) -> ReceiveOutcome { - let seq = seq.0; + let seq = seq.into_inner(); if self.seen == 0 { self.base = seq; self.seen = 1; @@ -50,7 +50,7 @@ impl ReceivedRecords { pub fn ack(&self) -> Option { (self.seen != 0).then_some(RecordAck { - base_seq: RecordSeq(self.base), + base_seq: RecordSeq::from_u64(self.base).expect("tracked record seq must fit varint"), bits: self.seen, }) } @@ -75,22 +75,26 @@ mod tests { use super::{ReceiveOutcome, ReceivedRecords}; + fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() + } + #[test] fn inserts_pack_contiguous_bits() { let mut received = ReceivedRecords::default(); assert_eq!( - received.insert(RecordSeq(10)), + received.insert(seq(10)), ReceiveOutcome::New { out_of_order: false } ); assert_eq!( - received.insert(RecordSeq(12)), + received.insert(seq(12)), ReceiveOutcome::New { out_of_order: true } ); assert_eq!( - received.insert(RecordSeq(11)), + received.insert(seq(11)), ReceiveOutcome::New { out_of_order: true } ); @@ -98,7 +102,7 @@ mod tests { assert_eq!( ack, RecordAck { - base_seq: RecordSeq(10), + base_seq: seq(10), bits: 0b111, } ); @@ -109,22 +113,22 @@ mod tests { let mut received = ReceivedRecords::default(); assert_eq!( - received.insert(RecordSeq(0)), + received.insert(seq(0)), ReceiveOutcome::New { out_of_order: false } ); assert_eq!( - received.insert(RecordSeq(300)), + received.insert(seq(300)), ReceiveOutcome::New { out_of_order: true } ); - assert_eq!(received.insert(RecordSeq(0)), ReceiveOutcome::TooOld); + assert_eq!(received.insert(seq(0)), ReceiveOutcome::TooOld); let ack = received.ack().unwrap(); assert_eq!( ack, RecordAck { - base_seq: RecordSeq(237), + base_seq: seq(237), bits: 1u64 << 63, } ); @@ -135,12 +139,12 @@ mod tests { let mut received = ReceivedRecords::default(); assert_eq!( - received.insert(RecordSeq(7)), + received.insert(seq(7)), ReceiveOutcome::New { out_of_order: false } ); - assert_eq!(received.insert(RecordSeq(7)), ReceiveOutcome::Duplicate); + assert_eq!(received.insert(seq(7)), ReceiveOutcome::Duplicate); } #[test] @@ -148,17 +152,17 @@ mod tests { let mut received = ReceivedRecords::default(); assert_eq!( - received.insert(RecordSeq(10)), + received.insert(seq(10)), ReceiveOutcome::New { out_of_order: false } ); assert_eq!( - received.insert(RecordSeq(12)), + received.insert(seq(12)), ReceiveOutcome::New { out_of_order: true } ); assert_eq!( - received.insert(RecordSeq(70)), + received.insert(seq(70)), ReceiveOutcome::New { out_of_order: true } ); @@ -166,7 +170,7 @@ mod tests { assert_eq!( ack, RecordAck { - base_seq: RecordSeq(10), + base_seq: seq(10), bits: (1u64 << 0) | (1u64 << 2) | (1u64 << 60), } ); diff --git a/ql-fsm/src/session/remote_stream_history.rs b/ql-fsm/src/session/remote_stream_history.rs index 7d20c0cf..f851d0a5 100644 --- a/ql-fsm/src/session/remote_stream_history.rs +++ b/ql-fsm/src/session/remote_stream_history.rs @@ -41,10 +41,12 @@ impl RemoteStreamHistory { } fn stream_ordinal(&self, stream_id: StreamId) -> Option { - let delta = stream_id.0.checked_sub(self.parity.first_stream_id())?; + let delta = stream_id + .into_inner() + .checked_sub(u64::from(self.parity.first_stream_id()))?; if delta % 2 != 0 { return None; } - Some(delta / 2) + u32::try_from(delta / 2).ok() } } diff --git a/ql-fsm/src/session/stream_parity.rs b/ql-fsm/src/session/stream_parity.rs index 8b95ad51..1fb498e0 100644 --- a/ql-fsm/src/session/stream_parity.rs +++ b/ql-fsm/src/session/stream_parity.rs @@ -23,8 +23,8 @@ impl StreamParity { pub const fn matches(self, stream_id: StreamId) -> bool { match self { - Self::Even => stream_id.0 % 2 == 0, - Self::Odd => stream_id.0 % 2 == 1, + Self::Even => stream_id.into_inner() % 2 == 0, + Self::Odd => stream_id.into_inner() % 2 == 1, } } @@ -36,7 +36,7 @@ impl StreamParity { } pub fn make_stream_id(self, ordinal: u32) -> StreamId { - StreamId( + StreamId::from_u32( self.first_stream_id() .saturating_add(ordinal.saturating_mul(2)), ) diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index b884001f..239bf017 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -2,12 +2,24 @@ use std::time::{Duration, Instant}; use ql_wire::{ CloseTarget, RecordAck, RecordSeq, SessionFrame, SessionRecord, SessionRecordBuilder, - StreamClose, StreamCloseCode, StreamData, StreamId, XID, + StreamClose, StreamCloseCode, StreamData, StreamId, VarInt, XID, }; use super::{SessionEvent, SessionFsm, SessionFsmConfig}; use crate::session::stream_parity::StreamParity; +fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() +} + +fn stream_id(value: u64) -> StreamId { + StreamId::from_u64(value).unwrap() +} + +fn offset(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() +} + fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); loop { @@ -25,11 +37,14 @@ fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { } fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option<(RecordSeq, SessionRecord)> { - let (write_id, seq, builder) = fsm.take_next_write(now)?; + let (write_id, builder) = fsm.take_next_write(now)?; if let Some(write_id) = write_id { fsm.confirm_write(now, write_id); } - Some((seq, SessionRecord::decode(builder.bytes()).unwrap())) + Some(( + builder.seq(), + SessionRecord::decode(builder.bytes()).unwrap(), + )) } fn receive_events( @@ -38,8 +53,7 @@ fn receive_events( seq: RecordSeq, record: &SessionRecord, ) -> Vec { - let mut builder = - SessionRecordBuilder::new(SessionRecordBuilder::WIRE_PREFIX_LEN + record.wire_size()); + let mut builder = SessionRecordBuilder::new(seq, usize::MAX); for frame in &record.frames { assert!(builder.push_frame(frame)); } @@ -62,8 +76,8 @@ fn outbound_record_seq_increments_monotonically() { assert_eq!(fsm.write_stream(stream_id, b"two").unwrap(), 3); let (second_seq, _) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); - assert_eq!(first_seq, RecordSeq(0)); - assert_eq!(second_seq, RecordSeq(1)); + assert_eq!(first_seq, seq(0)); + assert_eq!(second_seq, seq(1)); } #[test] @@ -87,7 +101,7 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let now = Instant::now(); let mut fsm = SessionFsm::new( SessionFsmConfig { - record_max_size: 80 + SessionRecordBuilder::WIRE_PREFIX_LEN, + record_max_size: 80 + SessionRecordBuilder::MIN_CAPACITY, ..SessionFsmConfig::default() }, now, @@ -121,34 +135,6 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { assert_eq!(stream_ids, vec![stream_id_b]); } -#[test] -fn fin_only_stream_data_fits_exact_record_limit() { - let now = Instant::now(); - let stream_data_overhead = - 1 + std::mem::size_of::() + StreamData::>::MIN_WIRE_SIZE; - let mut fsm = SessionFsm::new( - SessionFsmConfig { - record_max_size: SessionRecordBuilder::WIRE_PREFIX_LEN + stream_data_overhead, - ..SessionFsmConfig::default() - }, - now, - ); - let stream_id = fsm.open_stream().unwrap(); - - fsm.finish_stream(stream_id).unwrap(); - - let (_seq, record) = next_outbound(&mut fsm, now).unwrap(); - assert_eq!(record.frames.len(), 1); - match &record.frames[0] { - SessionFrame::StreamData(frame) => { - assert_eq!(frame.stream_id, stream_id); - assert!(frame.fin); - assert!(frame.bytes.is_empty()); - } - frame => panic!("expected stream data frame, got {frame:?}"), - } -} - #[test] fn ack_reopens_write_capacity() { let now = Instant::now(); @@ -162,14 +148,14 @@ fn ack_reopens_write_capacity() { let stream_id = fsm.open_stream().unwrap(); assert_eq!(fsm.write_stream(stream_id, b"abcd").unwrap(), 4); - let (seq, _record) = next_outbound(&mut fsm, now).unwrap(); + let (record_seq, _record) = next_outbound(&mut fsm, now).unwrap(); let mut events = Vec::new(); fsm.receive( now + Duration::from_millis(1), - RecordSeq(9), + seq(9), std::iter::once(Ok(SessionFrame::Ack(RecordAck { - base_seq: seq, + base_seq: record_seq, bits: 1u64, }))), |event| events.push(event), @@ -190,16 +176,16 @@ fn commit_stream_read_is_what_advances_stream_window() { }, now, ); - let stream_id = StreamId(1); + let stream_id = stream_id(1); let data = SessionRecord { frames: vec![SessionFrame::StreamData(StreamData { stream_id, - offset: 0, + offset: offset(0), fin: false, bytes: b"hi".to_vec(), })], }; - let events = receive_events(&mut fsm, now, RecordSeq(7), &data); + let events = receive_events(&mut fsm, now, seq(7), &data); assert_eq!( events, vec![ @@ -208,8 +194,7 @@ fn commit_stream_read_is_what_advances_stream_window() { ] ); - let (write_id, _first_seq, builder) = - fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); let first = SessionRecord::decode(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(first.frames.as_slice(), [SessionFrame::Ack(_)])); @@ -240,19 +225,19 @@ fn pure_ack_only_records_are_fire_and_forget() { }; let retransmit_timeout = config.retransmit_timeout; let mut fsm = SessionFsm::new(config, now); - let stream_id = StreamId(1); + let stream_id = stream_id(1); let record = SessionRecord { frames: vec![SessionFrame::StreamData(StreamData { stream_id, - offset: 0, + offset: offset(0), fin: false, bytes: b"hi".to_vec(), })], }; - let _ = receive_events(&mut fsm, now, RecordSeq(7), &record); + let _ = receive_events(&mut fsm, now, seq(7), &record); - let (write_id, _seq, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); let ack = SessionRecord::decode(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(ack.frames.as_slice(), [SessionFrame::Ack(_)])); @@ -267,17 +252,17 @@ fn pure_ack_only_records_are_fire_and_forget() { fn inbound_stream_data_emits_opened_and_readable() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = ql_wire::StreamId(1); + let stream_id = stream_id(1); let record = SessionRecord { frames: vec![SessionFrame::StreamData(ql_wire::StreamData { stream_id, - offset: 0, + offset: offset(0), fin: true, bytes: b"hello".to_vec(), })], }; - let events = receive_events(&mut fsm, now, RecordSeq(0), &record); + let events = receive_events(&mut fsm, now, seq(0), &record); assert_eq!( events, vec![ @@ -298,7 +283,7 @@ fn remote_stream_close_is_reliable_and_retried() { fsm.close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)) .unwrap(); - let (write_id, _seq, builder) = fsm.take_next_write(now).unwrap(); + let (write_id, builder) = fsm.take_next_write(now).unwrap(); fsm.confirm_write(now, write_id.expect("stream close should be tracked")); let first = SessionRecord::decode(builder.bytes()).unwrap(); assert!(matches!( @@ -337,30 +322,25 @@ fn stream_ids_follow_even_odd_xid_ordering() { .open_stream() .unwrap(); - assert_eq!(even_id.0 % 2, 0); - assert_eq!(odd_id.0 % 2, 1); + assert_eq!(even_id.into_inner() % 2, 0); + assert_eq!(odd_id.into_inner() % 2, 1); } #[test] fn duplicate_stream_data_is_not_redelivered() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = StreamId(1); + let stream_id = stream_id(1); let record = SessionRecord { frames: vec![SessionFrame::StreamData(StreamData { stream_id, - offset: 0, + offset: offset(0), fin: false, bytes: b"hi".to_vec(), })], }; - let _ = receive_events(&mut fsm, now, RecordSeq(1), &record); - let _ = receive_events( - &mut fsm, - now + Duration::from_millis(1), - RecordSeq(2), - &record, - ); + let _ = receive_events(&mut fsm, now, seq(1), &record); + let _ = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); } @@ -370,7 +350,7 @@ fn duplicate_remote_close_after_reap_is_ignored() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let close = StreamClose { - stream_id: StreamId(1), + stream_id: stream_id(1), target: CloseTarget::Both, code: StreamCloseCode(9), }; @@ -378,7 +358,7 @@ fn duplicate_remote_close_after_reap_is_ignored() { frames: vec![SessionFrame::StreamClose(close.clone())], }; - let first = receive_events(&mut fsm, now, RecordSeq(1), &record); + let first = receive_events(&mut fsm, now, seq(1), &record); assert_eq!( first, vec![ @@ -388,7 +368,7 @@ fn duplicate_remote_close_after_reap_is_ignored() { ] ); - let second = receive_events(&mut fsm, now + Duration::from_millis(1), RecordSeq(2), &record); + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); assert!(second.is_empty()); } @@ -396,17 +376,17 @@ fn duplicate_remote_close_after_reap_is_ignored() { fn duplicate_finished_remote_data_after_reap_is_ignored() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = StreamId(1); + let stream_id = stream_id(1); let record = SessionRecord { frames: vec![SessionFrame::StreamData(StreamData { stream_id, - offset: 0, + offset: offset(0), fin: true, bytes: b"hello".to_vec(), })], }; - let first = receive_events(&mut fsm, now, RecordSeq(1), &record); + let first = receive_events(&mut fsm, now, seq(1), &record); assert_eq!( first, vec![ @@ -417,7 +397,7 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { ); assert_eq!(read_stream_all(&mut fsm, stream_id), b"hello".to_vec()); - let second = receive_events(&mut fsm, now + Duration::from_millis(1), RecordSeq(2), &record); + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); assert!(second.is_empty()); } @@ -427,26 +407,26 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let close3 = SessionRecord { frames: vec![SessionFrame::StreamClose(StreamClose { - stream_id: StreamId(3), + stream_id: stream_id(3), target: CloseTarget::Both, code: StreamCloseCode(1), })], }; let close1 = SessionRecord { frames: vec![SessionFrame::StreamClose(StreamClose { - stream_id: StreamId(1), + stream_id: stream_id(1), target: CloseTarget::Both, code: StreamCloseCode(2), })], }; - let first = receive_events(&mut fsm, now, RecordSeq(1), &close3); - assert!(first.contains(&SessionEvent::Opened(StreamId(3)))); + let first = receive_events(&mut fsm, now, seq(1), &close3); + assert!(first.contains(&SessionEvent::Opened(stream_id(3)))); - let second = receive_events(&mut fsm, now + Duration::from_millis(1), RecordSeq(2), &close1); - assert!(second.contains(&SessionEvent::Opened(StreamId(1)))); + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &close1); + assert!(second.contains(&SessionEvent::Opened(stream_id(1)))); - let third = receive_events(&mut fsm, now + Duration::from_millis(2), RecordSeq(3), &close3); + let third = receive_events(&mut fsm, now + Duration::from_millis(2), seq(3), &close3); assert!(third.is_empty()); } @@ -463,13 +443,13 @@ fn close_does_not_ack_rejected_record_seq() { let invalid = SessionRecord { frames: vec![SessionFrame::StreamData(StreamData { - stream_id: StreamId(0), - offset: 0, + stream_id: stream_id(0), + offset: offset(0), fin: false, bytes: b"bad".to_vec(), })], }; - let events = receive_events(&mut fsm, now, RecordSeq(7), &invalid); + let events = receive_events(&mut fsm, now, seq(7), &invalid); assert_eq!( events, vec![SessionEvent::SessionClosed(ql_wire::SessionClose { @@ -483,7 +463,7 @@ fn close_does_not_ack_rejected_record_seq() { let events = receive_events( &mut fsm, now + Duration::from_millis(1), - RecordSeq(8), + seq(8), &valid_after_close, ); assert!(events.is_empty()); @@ -517,11 +497,11 @@ fn initial_peer_stream_receive_window_limits_first_send() { let events = receive_events( &mut fsm, now + Duration::from_millis(1), - RecordSeq(9), + seq(9), &SessionRecord { frames: vec![SessionFrame::StreamWindow(ql_wire::StreamWindow { stream_id, - maximum_offset: 5, + maximum_offset: offset(5), })], }, ); @@ -533,7 +513,7 @@ fn initial_peer_stream_receive_window_limits_first_send() { frame, SessionFrame::StreamData(frame) if frame.stream_id == stream_id - && frame.offset == 3 + && frame.offset == offset(3) && frame.bytes.as_slice() == b"lo" ) })); diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 6af5524e..2d7e6964 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -5,6 +5,10 @@ use ql_wire::{SessionClose, StreamId}; use super::*; use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent}; +fn stream_id(value: u32) -> StreamId { + StreamId::from_u32(value) +} + fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); loop { @@ -147,7 +151,7 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { #[test] fn disconnected_stream_operations_fail_with_no_session() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - let missing = StreamId(0); + let missing = stream_id(0); assert_eq!(harness.a.fsm.open_stream(), Err(QlFsmError::NoSession)); assert_eq!( @@ -176,7 +180,7 @@ fn disconnected_stream_operations_fail_with_no_session() { #[test] fn disconnected_stream_read_accessors_return_none() { let harness = Harness::paired_known(QlFsmConfig::default()); - let missing = StreamId(0); + let missing = stream_id(0); assert!(harness.a.fsm.stream_read(missing).is_none()); assert!(harness.a.fsm.stream_available_bytes(missing).is_none()); diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index d8e45100..dff10bac 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -111,7 +111,7 @@ fn new_inbound_io(capacity: usize) -> InboundIo { #[test] fn handle_inbound_finished_reaps_closed_initiator_stream() { let (mut state, fsm) = new_driver_state(); - let stream_id = StreamId(1); + let stream_id = StreamId(1u32.into()); state.streams.insert( stream_id, @@ -129,7 +129,7 @@ fn handle_inbound_finished_reaps_closed_initiator_stream() { #[test] fn handle_closed_stream_reaps_when_both_halves_close() { let (mut state, _fsm) = new_driver_state(); - let stream_id = StreamId(2); + let stream_id = StreamId(1u32.into()); let (response_reader, _response_writer) = piper::pipe(1); state.streams.insert( @@ -152,7 +152,7 @@ fn handle_closed_stream_reaps_when_both_halves_close() { #[test] fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { let (mut state, mut fsm) = new_driver_state(); - let stream_id = StreamId(3); + let stream_id = StreamId(1u32.into()); let (request_reader, request_writer) = piper::pipe(1); drop(request_writer); @@ -172,7 +172,7 @@ fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { #[test] fn local_close_command_reaps_when_other_half_is_already_closed() { let (mut state, mut fsm) = new_driver_state(); - let stream_id = StreamId(4); + let stream_id = StreamId(1u32.into()); let (request_reader, _request_writer) = piper::pipe(1); let mut in_flight = Vec::new(); diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index cb9af122..d2538aad 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -1,4 +1,4 @@ -use crate::{ByteSlice, WireError}; +use crate::{ByteSlice, VarInt, WireError}; pub fn write_u8(out: &mut [u8], value: u8) -> &mut [u8] { let (head, rest) = out.split_at_mut(1); @@ -24,6 +24,17 @@ pub fn write_u64(out: &mut [u8], value: u64) -> &mut [u8] { rest } +pub fn write_varint(out: &mut [u8], value: VarInt) -> &mut [u8] { + let x = value.into_inner(); + match value.size() { + 1 => write_u8(out, x as u8), + 2 => write_bytes(out, &(((0b01u16 << 14) | (x as u16)).to_be_bytes())), + 4 => write_bytes(out, &(((0b10u32 << 30) | (x as u32)).to_be_bytes())), + 8 => write_bytes(out, &(((0b11u64 << 62) | x).to_be_bytes())), + _ => unreachable!("malformed varint"), + } +} + pub fn write_bool(out: &mut [u8], value: bool) -> &mut [u8] { write_u8(out, u8::from(value)) } @@ -125,6 +136,37 @@ impl Reader { Ok(u64::from_le_bytes(self.take_array()?)) } + pub fn take_varint(&mut self) -> Result { + let first = self.take_u8()?; + let tag = first >> 6; + let first = first & 0b0011_1111; + let value = match tag { + 0b00 => u64::from(first), + 0b01 => { + let mut buf = [0; 2]; + buf[0] = first; + buf[1] = self.take_u8()?; + u64::from(u16::from_be_bytes(buf)) + } + 0b10 => { + let mut buf = [0; 4]; + buf[0] = first; + buf[1..].copy_from_slice(&self.take_array::<3>()?); + u64::from(u32::from_be_bytes(buf)) + } + 0b11 => { + let mut buf = [0; 8]; + buf[0] = first; + buf[1..].copy_from_slice(&self.take_array::<7>()?); + u64::from_be_bytes(buf) + } + _ => unreachable!(), + }; + + // SAFETY: the decoded value is guaranteed to fit in the 62-bit varint range. + Ok(unsafe { VarInt::from_u64_unchecked(value) }) + } + pub fn take_bool(&mut self) -> Result { match self.take_u8()? { 0 => Ok(false), diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index c7794332..02e934ff 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -8,14 +8,13 @@ pub struct RecordAck { impl RecordAck { pub const BITMAP_BITS: usize = u64::BITS as usize; - pub const WIRE_SIZE: usize = size_of::() + size_of::(); pub fn contains(&self, seq: u64) -> bool { - if seq < self.base_seq.0 { + if seq < self.base_seq.into_inner() { return false; } - let offset = seq - self.base_seq.0; + let offset = seq - self.base_seq.into_inner(); if offset >= Self::BITMAP_BITS as u64 { return false; } @@ -23,9 +22,13 @@ impl RecordAck { (self.bits & (1u64 << offset)) != 0 } + pub fn wire_size(&self) -> usize { + self.base_seq.encoded_len() + size_of::() + } + pub fn encode_into(&self, out: &mut [u8]) { - assert_eq!(out.len(), Self::WIRE_SIZE); - let out = codec::write_u64(out, self.base_seq.0); + assert!(out.len() >= self.wire_size()); + let out = codec::write_varint(out, self.base_seq.0); let _ = codec::write_u64(out, self.bits); } } @@ -33,7 +36,7 @@ impl RecordAck { impl codec::WireParse for RecordAck { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - base_seq: RecordSeq(reader.take_u64()?), + base_seq: RecordSeq(reader.take_varint()?), bits: reader.take_u64()?, }) } @@ -47,19 +50,19 @@ mod tests { #[test] fn encode_decode_round_trip() { let ack = RecordAck { - base_seq: RecordSeq(42), + base_seq: RecordSeq::from_u32(42), bits: (1u64 << 0) | (1u64 << 17) | (1u64 << 63), }; - let mut encoded = [0; RecordAck::WIRE_SIZE]; + let mut encoded = vec![0; ack.wire_size()]; ack.encode_into(&mut encoded); - assert_eq!(RecordAck::parse_bytes(&encoded[..]).unwrap(), ack); + assert_eq!(RecordAck::parse_bytes(encoded.as_slice()).unwrap(), ack); } #[test] fn contains_matches_bit_membership() { let ack = RecordAck { - base_seq: RecordSeq(100), + base_seq: RecordSeq::from_u32(100), bits: (1u64 << 0) | (1u64 << 5) | (1u64 << 63), }; @@ -77,8 +80,9 @@ mod tests { RecordAck::parse_bytes(&[][..]), Err(WireError::InvalidPayload) ); + let encoded = vec![0; RecordSeq::from_u32(0).encoded_len() + size_of::()]; assert_eq!( - RecordAck::parse_bytes(&[0; RecordAck::WIRE_SIZE - 1][..]), + RecordAck::parse_bytes(&encoded[..encoded.len() - 1]), Err(WireError::InvalidPayload) ); } diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 615f83e4..ac5c8495 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,30 +1,50 @@ use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; -use crate::{ByteChunks, Nonce, QlCrypto, RecordType, SessionHeader, SessionKey, QL_WIRE_VERSION}; +use crate::{ + codec, ByteChunks, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, + SessionKey, VarInt, QL_WIRE_VERSION, +}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecordBuilder { + seq: RecordSeq, + prefix_len: usize, max_capacity: usize, bytes: Vec, } impl SessionRecordBuilder { - pub const WIRE_PREFIX_LEN: usize = - 1 + 1 + SessionHeader::WIRE_SIZE + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; - - pub fn new(max_capacity: usize) -> Self { - assert!(max_capacity >= Self::WIRE_PREFIX_LEN); + pub const MIN_CAPACITY: usize = 1 + + 1 + + ConnectionId::SIZE + + RecordSeq::MAX_ENCODED_LEN + + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn new(seq: RecordSeq, max_capacity: usize) -> Self { + let prefix_len = + 1 + 1 + ConnectionId::SIZE + seq.encoded_len() + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; + assert!(max_capacity >= prefix_len); Self { + seq, + prefix_len, max_capacity, bytes: Vec::new(), } } + pub fn seq(&self) -> RecordSeq { + self.seq + } + + pub fn prefix_len(&self) -> usize { + self.prefix_len + } + pub fn max_capacity(&self) -> usize { self.max_capacity } pub fn len(&self) -> usize { - self.bytes.len().saturating_sub(Self::WIRE_PREFIX_LEN) + self.bytes.len().saturating_sub(self.prefix_len) } pub fn is_empty(&self) -> bool { @@ -33,11 +53,11 @@ impl SessionRecordBuilder { pub fn remaining_capacity(&self) -> usize { self.max_capacity - .saturating_sub(self.bytes.len().max(Self::WIRE_PREFIX_LEN)) + .saturating_sub(self.bytes.len().max(self.prefix_len)) } pub fn bytes(&self) -> &[u8] { - self.bytes.get(Self::WIRE_PREFIX_LEN..).unwrap_or_default() + self.bytes.get(self.prefix_len..).unwrap_or_default() } pub fn push_ping(&mut self) -> bool { @@ -45,13 +65,9 @@ impl SessionRecordBuilder { } pub fn push_ack(&mut self, ack: &RecordAck) -> bool { - self.push_frame_payload( - super::SessionFrameKind::Ack, - RecordAck::WIRE_SIZE, - |payload| { - ack.encode_into(payload); - }, - ) + self.push_frame_payload(super::SessionFrameKind::Ack, ack.wire_size(), |payload| { + ack.encode_into(payload); + }) } pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { @@ -67,7 +83,7 @@ impl SessionRecordBuilder { pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { self.push_frame_payload( super::SessionFrameKind::StreamWindow, - StreamWindow::WIRE_SIZE, + frame.wire_size(), |payload| { frame.encode_into(payload); }, @@ -77,7 +93,7 @@ impl SessionRecordBuilder { pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { self.push_frame_payload( super::SessionFrameKind::StreamClose, - StreamClose::WIRE_SIZE, + frame.wire_size(), |payload| { frame.encode_into(payload); }, @@ -108,24 +124,28 @@ impl SessionRecordBuilder { pub fn encrypt( mut self, crypto: &impl QlCrypto, - header: SessionHeader, + connection_id: ConnectionId, session_key: &SessionKey, ) -> Vec { self.ensure_prefix_capacity(0); + let header = SessionHeader { + connection_id, + seq: self.seq, + }; let aad = header.aad(); - let nonce = Nonce::from_counter(header.seq.0); + let nonce = Nonce::from_counter(self.seq.into_inner()); let auth = crypto.aes256_gcm_encrypt( session_key, &nonce, &aad, - &mut self.bytes[Self::WIRE_PREFIX_LEN..], + &mut self.bytes[self.prefix_len..], ); - let prefix = &mut self.bytes[..Self::WIRE_PREFIX_LEN]; + let prefix = &mut self.bytes[..self.prefix_len]; prefix[0] = QL_WIRE_VERSION; prefix[1] = RecordType::Session as u8; - header.encode_into(&mut prefix[2..2 + SessionHeader::WIRE_SIZE]); - prefix[2 + SessionHeader::WIRE_SIZE..].copy_from_slice(&auth); + let auth_out = header.encode_into(&mut prefix[2..]); + auth_out[..crate::ENCRYPTED_MESSAGE_AUTH_SIZE].copy_from_slice(&auth); self.bytes } @@ -162,10 +182,13 @@ impl SessionRecordBuilder { payload_wire_size: usize, encode_payload: impl FnOnce(&mut [u8]), ) -> bool { - self.push_wire_size(1 + super::SIZE_LEN + payload_wire_size, |out| { + let Ok(prefix_len) = VarInt::try_from(payload_wire_size) else { + return false; + }; + self.push_wire_size(1 + prefix_len.size() + payload_wire_size, |out| { out[0] = kind as u8; - super::push_variable_len(&mut out[1..=super::SIZE_LEN], payload_wire_size); - encode_payload(&mut out[1 + super::SIZE_LEN..]); + let payload = codec::write_varint(&mut out[1..], prefix_len); + encode_payload(payload); }) } @@ -175,9 +198,8 @@ impl SessionRecordBuilder { fn ensure_prefix_capacity(&mut self, additional_body_len: usize) { if self.bytes.is_empty() { - self.bytes - .reserve(Self::WIRE_PREFIX_LEN + additional_body_len); - self.bytes.resize(Self::WIRE_PREFIX_LEN, 0); + self.bytes.reserve(self.prefix_len + additional_body_len); + self.bytes.resize(self.prefix_len, 0); } } } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 62c8dfec..abeb55c4 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,6 +1,6 @@ use crate::{ codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, - SessionHeader, SessionKey, WireError, WireParse, + SessionHeader, SessionKey, VarInt, VarIntBoundsExceeded, WireError, WireParse, }; mod ack; @@ -20,7 +20,27 @@ pub use stream_window::*; // todo: should use even/odd based on xid ordering #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct StreamId(pub u32); +pub struct StreamId(pub VarInt); + +impl StreamId { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn from_u32(value: u32) -> Self { + Self(VarInt::from_u32(value)) + } + + pub fn from_u64(value: u64) -> Result { + Ok(Self(VarInt::from_u64(value)?)) + } + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } + + pub const fn encoded_len(self) -> usize { + self.0.size() + } +} #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecord { @@ -40,8 +60,6 @@ pub enum SessionFrame { pub type SessionFrameVec = SessionFrame>; pub type StreamDataVec = StreamData>; -pub(crate) const SIZE_LEN: usize = size_of::(); - #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub(crate) enum SessionFrameKind { @@ -98,10 +116,15 @@ impl SessionFrame { pub fn wire_size(&self) -> usize { 1 + match self { Self::Ping => 0, - Self::Ack(_) => RecordAck::WIRE_SIZE, - Self::StreamData(frame) => SIZE_LEN + frame.wire_size(), - Self::StreamWindow(_) => StreamWindow::WIRE_SIZE, - Self::StreamClose(_) => StreamClose::WIRE_SIZE, + Self::Ack(frame) => frame.wire_size(), + Self::StreamData(frame) => { + VarInt::try_from(frame.wire_size()) + .unwrap_or(VarInt::MAX) + .size() + + frame.wire_size() + } + Self::StreamWindow(frame) => frame.wire_size(), + Self::StreamClose(frame) => frame.wire_size(), Self::Close(_) => SessionClose::WIRE_SIZE, } } @@ -149,7 +172,7 @@ pub fn decrypt_record>( session_key: &SessionKey, ) -> Result { let aad = header.aad(); - let nonce = Nonce::from_counter(header.seq.0); + let nonce = Nonce::from_counter(header.seq.into_inner()); encrypted.decrypt_in_place(crypto, session_key, &nonce, &aad) } @@ -158,52 +181,42 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr match SessionFrameKind::try_from(kind)? { SessionFrameKind::Ping => Ok((SessionFrame::Ping, rest)), SessionFrameKind::Ack => { - let (frame, rest) = rest - .split_at_checked(RecordAck::WIRE_SIZE) - .ok_or(WireError::InvalidPayload)?; - Ok((SessionFrame::Ack(RecordAck::parse_bytes(frame)?), rest)) + let (frame, rest) = parse_inline_frame::(rest)?; + Ok((SessionFrame::Ack(frame), rest)) } SessionFrameKind::StreamData => { let (frame, rest) = split_variable_frame(rest)?; Ok((SessionFrame::StreamData(StreamData::parse(frame)?), rest)) } SessionFrameKind::StreamWindow => { - let (frame, rest) = rest - .split_at_checked(StreamWindow::WIRE_SIZE) - .ok_or(WireError::InvalidPayload)?; - Ok(( - SessionFrame::StreamWindow(StreamWindow::parse_bytes(frame)?), - rest, - )) + let (frame, rest) = parse_inline_frame::(rest)?; + Ok((SessionFrame::StreamWindow(frame), rest)) } SessionFrameKind::StreamClose => { - let (frame, rest) = rest - .split_at_checked(StreamClose::WIRE_SIZE) - .ok_or(WireError::InvalidPayload)?; - Ok(( - SessionFrame::StreamClose(StreamClose::parse_bytes(frame)?), - rest, - )) + let (frame, rest) = parse_inline_frame::(rest)?; + Ok((SessionFrame::StreamClose(frame), rest)) } SessionFrameKind::Close => { - let (frame, rest) = rest - .split_at_checked(SessionClose::WIRE_SIZE) - .ok_or(WireError::InvalidPayload)?; - Ok((SessionFrame::Close(SessionClose::parse_bytes(frame)?), rest)) + let (frame, rest) = parse_inline_frame::(rest)?; + Ok((SessionFrame::Close(frame), rest)) } } } -fn push_variable_len(out: &mut [u8], len: usize) { - let len = u16::try_from(len).expect("session frame exceeds u16"); - let _ = codec::write_u16(out, len); +fn parse_inline_frame(bytes: &[u8]) -> Result<(T, &[u8]), WireError> +where + T: for<'a> WireParse<&'a [u8]>, +{ + let mut reader = codec::Reader::new(bytes); + let frame = reader.parse::()?; + let consumed = bytes.len() - reader.remaining_len(); + Ok((frame, &bytes[consumed..])) } fn split_variable_frame(bytes: &[u8]) -> Result<(&[u8], &[u8]), WireError> { - if bytes.len() < SIZE_LEN { - return Err(WireError::InvalidPayload); - } - let len = u16::from_le_bytes([bytes[0], bytes[1]]) as usize; - let bytes = &bytes[SIZE_LEN..]; + let mut reader = codec::Reader::new(bytes); + let len = usize::try_from(reader.take_varint()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + let bytes = &bytes[bytes.len() - reader.remaining_len()..]; bytes.split_at_checked(len).ok_or(WireError::InvalidPayload) } diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 742d9992..f8796228 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -14,11 +14,12 @@ pub struct StreamClose { } impl StreamClose { - pub const WIRE_SIZE: usize = - size_of::() + size_of::() + size_of::(); + pub fn wire_size(&self) -> usize { + self.stream_id.encoded_len() + size_of::() + size_of::() + } pub fn encode_into(&self, out: &mut [u8]) { - let out = codec::write_u32(out, self.stream_id.0); + let out = codec::write_varint(out, self.stream_id.0); let out = codec::write_u8(out, self.target.to_wire()); let _ = codec::write_u16(out, self.code.0); } @@ -27,7 +28,7 @@ impl StreamClose { impl codec::WireParse for StreamClose { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - stream_id: StreamId(reader.take_u32()?), + stream_id: StreamId(reader.take_varint()?), target: CloseTarget::try_from(reader.take_u8()?)?, code: StreamCloseCode(reader.take_u16()?), }) diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 33fbb7b8..0f66c4bb 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -1,25 +1,26 @@ use super::StreamId; -use crate::{codec, ByteChunks, ByteSlice, WireError}; +use crate::{codec, ByteChunks, ByteSlice, VarInt, WireError}; /// carries bytes for a stream and may finish that sending direction. #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamData { pub stream_id: StreamId, - pub offset: u64, + pub offset: VarInt, pub fin: bool, pub bytes: B, } impl StreamData { - pub const MIN_WIRE_SIZE: usize = size_of::() + size_of::() + size_of::(); + /// Conservative constant overhead for callers that still budget with a fixed header size. + pub const MIN_WIRE_SIZE: usize = StreamId::MAX_ENCODED_LEN + VarInt::MAX_SIZE + size_of::(); } impl StreamData { pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); Ok(Self { - stream_id: StreamId(reader.take_u32()?), - offset: reader.take_u64()?, + stream_id: StreamId(reader.take_varint()?), + offset: reader.take_varint()?, fin: reader.take_bool()?, bytes: reader.take_rest(), }) @@ -41,13 +42,17 @@ impl StreamData { } impl StreamData { + pub fn header_len(&self) -> usize { + self.stream_id.encoded_len() + self.offset.size() + size_of::() + } + pub fn wire_size(&self) -> usize { - Self::MIN_WIRE_SIZE + self.bytes.len() + self.header_len() + self.bytes.len() } pub fn encode_into(&self, out: &mut [u8]) { - let out = codec::write_u32(out, self.stream_id.0); - let out = codec::write_u64(out, self.offset); + let out = codec::write_varint(out, self.stream_id.0); + let out = codec::write_varint(out, self.offset); let mut out = codec::write_bool(out, self.fin); for chunk in self.bytes.chunks() { out = codec::write_bytes(out, chunk); diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index 224e9e43..91c8a15a 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -1,27 +1,29 @@ use super::StreamId; -use crate::{codec, ByteSlice, WireError}; +use crate::{codec, ByteSlice, VarInt, WireError}; /// advertises the highest byte offset the peer may send on a stream. #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamWindow { pub stream_id: StreamId, - pub maximum_offset: u64, + pub maximum_offset: VarInt, } impl StreamWindow { - pub const WIRE_SIZE: usize = size_of::() + size_of::(); + pub fn wire_size(&self) -> usize { + self.stream_id.encoded_len() + self.maximum_offset.size() + } pub fn encode_into(&self, out: &mut [u8]) { - let out = codec::write_u32(out, self.stream_id.0); - let _ = codec::write_u64(out, self.maximum_offset); + let out = codec::write_varint(out, self.stream_id.0); + let _ = codec::write_varint(out, self.maximum_offset); } } impl codec::WireParse for StreamWindow { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - stream_id: StreamId(reader.take_u32()?), - maximum_offset: reader.take_u64()?, + stream_id: StreamId(reader.take_varint()?), + maximum_offset: reader.take_varint()?, }) } } diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index 253f2e21..6db1e24b 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -1,4 +1,4 @@ -use crate::{codec, ByteSlice, QL_WIRE_VERSION}; +use crate::{codec, ByteSlice, VarInt, VarIntBoundsExceeded, QL_WIRE_VERSION}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SessionHeader { @@ -8,7 +8,27 @@ pub struct SessionHeader { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct RecordSeq(pub u64); +pub struct RecordSeq(pub VarInt); + +impl RecordSeq { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn from_u32(value: u32) -> Self { + Self(VarInt::from_u32(value)) + } + + pub fn from_u64(value: u64) -> Result { + Ok(Self(VarInt::from_u64(value)?)) + } + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } + + pub const fn encoded_len(self) -> usize { + self.0.size() + } +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] @@ -27,20 +47,24 @@ impl ConnectionId { } impl SessionHeader { - pub const WIRE_SIZE: usize = ConnectionId::SIZE + size_of::(); + pub const MAX_ENCODED_LEN: usize = ConnectionId::SIZE + RecordSeq::MAX_ENCODED_LEN; const AAD_DOMAIN: &[u8] = b"ql-wire:session-aad:v1"; const AAD_RECORD_KIND_SESSION: u8 = 1; - pub fn encode(&self) -> [u8; Self::WIRE_SIZE] { - let mut out = [0; Self::WIRE_SIZE]; - self.encode_into(&mut out); + pub fn encoded_len(&self) -> usize { + ConnectionId::SIZE + self.seq.encoded_len() + } + + pub fn encode(&self) -> Vec { + let mut out = vec![0; self.encoded_len()]; + let _ = self.encode_into(&mut out); out } - pub fn encode_into(&self, out: &mut [u8]) { - assert_eq!(out.len(), Self::WIRE_SIZE); + pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + assert!(out.len() >= self.encoded_len()); let out = codec::write_bytes(out, self.connection_id.as_bytes()); - let _ = codec::write_u64(out, self.seq.0); + codec::write_varint(out, self.seq.0) } pub fn aad(&self) -> Vec { @@ -48,13 +72,13 @@ impl SessionHeader { + size_of::() + size_of::() + ConnectionId::SIZE - + size_of::(); + + self.seq.encoded_len(); let mut aad = vec![0; aad_len]; let out = codec::write_bytes(&mut aad, Self::AAD_DOMAIN); let out = codec::write_u8(out, QL_WIRE_VERSION); let out = codec::write_u8(out, Self::AAD_RECORD_KIND_SESSION); let out = codec::write_bytes(out, self.connection_id.as_bytes()); - let _ = codec::write_u64(out, self.seq.0); + let _ = codec::write_varint(out, self.seq.0); aad } } @@ -63,7 +87,7 @@ impl codec::WireParse for SessionHeader { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { connection_id: ConnectionId::from_data(reader.take_array()?), - seq: RecordSeq(reader.take_u64()?), + seq: RecordSeq(reader.take_varint()?), }) } } diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 36568e24..4d6077e4 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -16,6 +16,7 @@ mod identity; mod nonce; mod pq; mod record; +mod varint; mod xid; pub use bytes::*; @@ -30,6 +31,7 @@ pub use identity::*; pub use nonce::*; pub use pq::*; pub use record::*; +pub use varint::*; pub use xid::*; pub const QL_WIRE_VERSION: u8 = 1; diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index a8d4124a..df169ace 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -139,13 +139,13 @@ impl> QlSessionRecord { pub fn encode(&self) -> Vec { let mut out = vec![ 0; - 2 + SessionHeader::WIRE_SIZE + 2 + self.header.encoded_len() + EncryptedMessage::<&[u8]>::HEADER_LEN + self.payload.ciphertext.as_ref().len() ]; let rest = codec::write_u8(&mut out, QL_WIRE_VERSION); let rest = codec::write_u8(rest, RecordType::Session as u8); - let rest = codec::write_bytes(rest, &self.header.encode()); + let rest = self.header.encode_into(rest); let _ = self.payload.encode_into(rest); out } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 6a87939b..46549e06 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -151,6 +151,18 @@ fn xid(byte: u8) -> XID { XID([byte; XID::SIZE]) } +fn varint(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() +} + +fn record_seq(value: u64) -> RecordSeq { + RecordSeq(varint(value)) +} + +fn stream_id(value: u64) -> StreamId { + StreamId(varint(value)) +} + fn handshake_meta(id: u32) -> HandshakeMeta { HandshakeMeta { handshake_id: HandshakeId(id), @@ -181,15 +193,18 @@ fn encrypt_record( session_key: &SessionKey, body: &SessionRecord, ) -> QlSessionRecord> { - let wire_size = body.wire_size() + SessionRecordBuilder::WIRE_PREFIX_LEN; - let mut builder = SessionRecordBuilder::new(wire_size); + let mut builder = SessionRecordBuilder::new(header.seq, usize::MAX); for frame in &body.frames { let pushed = builder.push_frame(frame); debug_assert!(pushed); } - QlSessionRecord::parse_bytes(builder.encrypt(crypto, header, session_key).as_slice()) - .unwrap() - .into_owned() + QlSessionRecord::parse_bytes( + builder + .encrypt(crypto, header.connection_id, session_key) + .as_slice(), + ) + .unwrap() + .into_owned() } #[test] @@ -634,13 +649,13 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let crypto = TestCrypto::new(40); let header = SessionHeader { connection_id: ConnectionId::from_data([0x44; ConnectionId::SIZE]), - seq: RecordSeq(11), + seq: record_seq(11), }; let body = SessionRecord { frames: vec![ SessionFrame::Ping, SessionFrame::Ack(RecordAck { - base_seq: RecordSeq(12), + base_seq: record_seq(12), bits: (1u64 << 0) | (1u64 << 1) | (1u64 << 8) @@ -649,17 +664,17 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { | (1u64 << 11), }), SessionFrame::StreamWindow(StreamWindow { - stream_id: StreamId(9), - maximum_offset: 65_536, + stream_id: stream_id(9), + maximum_offset: varint(65_536), }), SessionFrame::StreamData(StreamData { - stream_id: StreamId(9), - offset: 1024, + stream_id: stream_id(9), + offset: varint(1024), bytes: b"hello".to_vec(), fin: true, }), SessionFrame::StreamClose(StreamClose { - stream_id: StreamId(9), + stream_id: stream_id(9), target: CloseTarget::Both, code: StreamCloseCode(0), }), @@ -700,7 +715,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let wrong_seq_header = SessionHeader { connection_id: header.connection_id, - seq: RecordSeq(header.seq.0 + 1), + seq: record_seq(header.seq.into_inner() + 1), }; assert_eq!( encrypted::decrypt_record(&crypto, &wrong_seq_header, encrypted, &session_key), @@ -708,6 +723,35 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { ); } +#[test] +fn session_varint_fields_expand_at_expected_boundaries() { + let short_header = SessionHeader { + connection_id: ConnectionId::from_data([0x11; ConnectionId::SIZE]), + seq: record_seq(63), + }; + let long_header = SessionHeader { + connection_id: ConnectionId::from_data([0x11; ConnectionId::SIZE]), + seq: record_seq(64), + }; + + assert_eq!(short_header.encode().len(), ConnectionId::SIZE + 1); + assert_eq!(long_header.encode().len(), ConnectionId::SIZE + 2); + + let frame = StreamData { + stream_id: stream_id(64), + offset: varint(16_384), + fin: true, + bytes: b"abc".to_vec(), + }; + let mut encoded = vec![0; frame.wire_size()]; + frame.encode_into(&mut encoded); + + assert_eq!( + StreamData::parse(encoded.as_slice()).unwrap().into_owned(), + frame + ); +} + #[test] fn protocol_record_size_breakdown() { fn print_size(label: &str, size: usize) { @@ -763,7 +807,7 @@ fn protocol_record_size_breakdown() { &crypto, SessionHeader { connection_id: session.tx_connection_id, - seq: RecordSeq(1), + seq: record_seq(1), }, &session.tx_key, &SessionRecord { @@ -774,13 +818,13 @@ fn protocol_record_size_breakdown() { &crypto, SessionHeader { connection_id: session.tx_connection_id, - seq: RecordSeq(2), + seq: record_seq(2), }, &session.tx_key, &SessionRecord { frames: vec![SessionFrame::StreamData(StreamData { - stream_id: StreamId(1), - offset: 0, + stream_id: stream_id(1), + offset: varint(0), fin: false, bytes: Vec::new(), })], @@ -790,7 +834,7 @@ fn protocol_record_size_breakdown() { &crypto, SessionHeader { connection_id: session.tx_connection_id, - seq: RecordSeq(3), + seq: record_seq(3), }, &session.tx_key, &SessionRecord { diff --git a/ql-wire/src/varint.rs b/ql-wire/src/varint.rs new file mode 100644 index 00000000..0fc06d03 --- /dev/null +++ b/ql-wire/src/varint.rs @@ -0,0 +1,126 @@ +use core::fmt; + +/// An integer less than 2^62 encoded with QUIC variable-length integer rules. +#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct VarInt(pub(crate) u64); + +impl VarInt { + /// The largest representable value. + pub const MAX: Self = Self((1u64 << 62) - 1); + /// The largest encoded value length. + pub const MAX_SIZE: usize = 8; + pub const MIN_SIZE: usize = 1; + + /// Construct a `VarInt` infallibly from a `u32`. + pub const fn from_u32(x: u32) -> Self { + Self(x as u64) + } + + /// Construct a `VarInt` from a `u64`. + pub fn from_u64(x: u64) -> Result { + if x < (1u64 << 62) { + Ok(Self(x)) + } else { + Err(VarIntBoundsExceeded) + } + } + + /// Create a `VarInt` without checking the bounds. + /// + /// # Safety + /// + /// `x` must be less than 2^62. + pub const unsafe fn from_u64_unchecked(x: u64) -> Self { + Self(x) + } + + /// Extract the inner integer value. + pub const fn into_inner(self) -> u64 { + self.0 + } + + /// Return the number of bytes required to encode this value. + pub const fn size(self) -> usize { + let x = self.0; + if x < (1u64 << 6) { + 1 + } else if x < (1u64 << 14) { + 2 + } else if x < (1u64 << 30) { + 4 + } else { + 8 + } + } +} + +impl From for u64 { + fn from(value: VarInt) -> Self { + value.0 + } +} + +impl From for VarInt { + fn from(value: u8) -> Self { + Self(value.into()) + } +} + +impl From for VarInt { + fn from(value: u16) -> Self { + Self(value.into()) + } +} + +impl From for VarInt { + fn from(value: u32) -> Self { + Self(value.into()) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: u64) -> Result { + Self::from_u64(value) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: u128) -> Result { + Self::from_u64(value.try_into().map_err(|_| VarIntBoundsExceeded)?) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: usize) -> Result { + Self::from_u64(value as u64) + } +} + +impl fmt::Debug for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct VarIntBoundsExceeded; + +impl fmt::Display for VarIntBoundsExceeded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("value too large for varint encoding") + } +} + +impl std::error::Error for VarIntBoundsExceeded {} From 0c7816b73d49723fda3216509d35a60899718e34 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 5 Apr 2026 23:38:18 -0400 Subject: [PATCH 120/304] ql-wire: use WireParse trait more --- ql-wire/src/codec.rs | 130 ++++++++++++---------- ql-wire/src/encrypted/ack.rs | 4 +- ql-wire/src/encrypted/close.rs | 22 ++-- ql-wire/src/encrypted/mod.rs | 8 +- ql-wire/src/encrypted/stream_close.rs | 18 ++- ql-wire/src/encrypted/stream_data.rs | 6 +- ql-wire/src/encrypted/stream_window.rs | 4 +- ql-wire/src/encrypted_message.rs | 2 +- ql-wire/src/handshake/ik.rs | 8 +- ql-wire/src/handshake/kk.rs | 6 +- ql-wire/src/handshake/meta.rs | 10 +- ql-wire/src/handshake/mod.rs | 18 ++- ql-wire/src/handshake/transport_params.rs | 2 +- ql-wire/src/header.rs | 20 +++- ql-wire/src/identity.rs | 8 +- ql-wire/src/pq.rs | 20 ++++ ql-wire/src/record.rs | 18 ++- ql-wire/src/xid.rs | 8 ++ 18 files changed, 211 insertions(+), 101 deletions(-) diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index d2538aad..b115f99d 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -64,80 +64,56 @@ pub trait WireParse: Sized { } } -pub struct Reader { - remaining: Option, -} - -impl Reader { - pub fn new(bytes: B) -> Self { - Self { - remaining: Some(bytes), - } - } - - pub fn is_empty(&self) -> bool { - self.remaining.as_ref().unwrap().is_empty() - } - - pub fn remaining_len(&self) -> usize { - self.remaining.as_ref().unwrap().len() - } - - pub fn take_bytes(&mut self, len: usize) -> Result { - let remaining = self.remaining.take().unwrap(); - match remaining.split_at(len) { - Ok((head, tail)) => { - self.remaining = Some(tail); - Ok(head) - } - Err(remaining) => { - self.remaining = Some(remaining); - Err(WireError::InvalidPayload) - } - } - } - - pub fn take_rest(mut self) -> B { - self.remaining.take().unwrap() - } - - pub fn take_array(&mut self) -> Result<[u8; N], WireError> { - let bytes = self.take_bytes(N)?; +impl WireParse for [u8; N] { + fn parse(reader: &mut Reader) -> Result { + let bytes = reader.take_bytes(N)?; let mut out = [0u8; N]; out.copy_from_slice(&bytes); Ok(out) } +} - pub fn take_boxed_array(&mut self) -> Result, WireError> { - let bytes = self.take_bytes(N)?; +impl WireParse for Box<[u8; N]> { + fn parse(reader: &mut Reader) -> Result { + let bytes = reader.take_bytes(N)?; let mut out = Box::<[u8; N]>::new_uninit(); let src = bytes.as_ptr(); let dst = out.as_mut_ptr().cast::(); - // SAFETY: `take_bytes(N)` guarantees the source has exactly `N` bytes + // SAFETY: `take_bytes(N)` guarantees the source has exactly `N` bytes. unsafe { std::ptr::copy_nonoverlapping(src, dst, N); Ok(out.assume_init()) } } +} - pub fn take_u8(&mut self) -> Result { - Ok(self.take_bytes(1)?[0]) +impl WireParse for u8 { + fn parse(reader: &mut Reader) -> Result { + Ok(reader.take_bytes(1)?[0]) } +} - pub fn take_u16(&mut self) -> Result { - Ok(u16::from_le_bytes(self.take_array()?)) +impl WireParse for u16 { + fn parse(reader: &mut Reader) -> Result { + Ok(u16::from_le_bytes(reader.parse()?)) } +} - pub fn take_u32(&mut self) -> Result { - Ok(u32::from_le_bytes(self.take_array()?)) +impl WireParse for u32 { + fn parse(reader: &mut Reader) -> Result { + Ok(u32::from_le_bytes(reader.parse()?)) } +} - pub fn take_u64(&mut self) -> Result { - Ok(u64::from_le_bytes(self.take_array()?)) +impl WireParse for u64 { + fn parse(reader: &mut Reader) -> Result { + Ok(u64::from_le_bytes(reader.parse()?)) } +} - pub fn take_varint(&mut self) -> Result { - let first = self.take_u8()?; +impl WireParse for VarInt { + fn parse(reader: &mut Reader) -> Result { + let first = reader.parse::()?; let tag = first >> 6; let first = first & 0b0011_1111; let value = match tag { @@ -145,19 +121,19 @@ impl Reader { 0b01 => { let mut buf = [0; 2]; buf[0] = first; - buf[1] = self.take_u8()?; + buf[1] = reader.parse()?; u64::from(u16::from_be_bytes(buf)) } 0b10 => { let mut buf = [0; 4]; buf[0] = first; - buf[1..].copy_from_slice(&self.take_array::<3>()?); + buf[1..].copy_from_slice(&reader.parse::<[u8; 3]>()?); u64::from(u32::from_be_bytes(buf)) } 0b11 => { let mut buf = [0; 8]; buf[0] = first; - buf[1..].copy_from_slice(&self.take_array::<7>()?); + buf[1..].copy_from_slice(&reader.parse::<[u8; 7]>()?); u64::from_be_bytes(buf) } _ => unreachable!(), @@ -166,14 +142,54 @@ impl Reader { // SAFETY: the decoded value is guaranteed to fit in the 62-bit varint range. Ok(unsafe { VarInt::from_u64_unchecked(value) }) } +} - pub fn take_bool(&mut self) -> Result { - match self.take_u8()? { +impl WireParse for bool { + fn parse(reader: &mut Reader) -> Result { + match reader.parse::()? { 0 => Ok(false), 1 => Ok(true), _ => Err(WireError::InvalidPayload), } } +} + +pub struct Reader { + remaining: Option, +} + +impl Reader { + pub fn new(bytes: B) -> Self { + Self { + remaining: Some(bytes), + } + } + + pub fn is_empty(&self) -> bool { + self.remaining.as_ref().unwrap().is_empty() + } + + pub fn remaining_len(&self) -> usize { + self.remaining.as_ref().unwrap().len() + } + + pub fn take_bytes(&mut self, len: usize) -> Result { + let remaining = self.remaining.take().unwrap(); + match remaining.split_at(len) { + Ok((head, tail)) => { + self.remaining = Some(tail); + Ok(head) + } + Err(remaining) => { + self.remaining = Some(remaining); + Err(WireError::InvalidPayload) + } + } + } + + pub fn take_rest(mut self) -> B { + self.remaining.take().unwrap() + } #[inline] pub fn parse(&mut self) -> Result diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 02e934ff..07a6d995 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -36,8 +36,8 @@ impl RecordAck { impl codec::WireParse for RecordAck { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - base_seq: RecordSeq(reader.take_varint()?), - bits: reader.take_u64()?, + base_seq: reader.parse()?, + bits: reader.parse()?, }) } } diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index 51653643..85ae85d1 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -14,14 +14,6 @@ impl SessionClose { } } -impl codec::WireParse for SessionClose { - fn parse(reader: &mut Reader) -> Result { - Ok(Self { - code: SessionCloseCode(reader.take_u16()?), - }) - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] pub struct SessionCloseCode(pub u16); @@ -31,3 +23,17 @@ impl SessionCloseCode { pub const PROTOCOL: Self = Self(1); pub const TIMEOUT: Self = Self(2); } + +impl codec::WireParse for SessionCloseCode { + fn parse(reader: &mut Reader) -> Result { + Ok(Self(reader.parse()?)) + } +} + +impl codec::WireParse for SessionClose { + fn parse(reader: &mut Reader) -> Result { + Ok(Self { + code: reader.parse()?, + }) + } +} diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index abeb55c4..b8d8235f 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -42,6 +42,12 @@ impl StreamId { } } +impl codec::WireParse for StreamId { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.parse()?)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecord { pub frames: Vec, @@ -215,7 +221,7 @@ where fn split_variable_frame(bytes: &[u8]) -> Result<(&[u8], &[u8]), WireError> { let mut reader = codec::Reader::new(bytes); - let len = usize::try_from(reader.take_varint()?.into_inner()) + let len = usize::try_from(reader.parse::()?.into_inner()) .map_err(|_| WireError::InvalidPayload)?; let bytes = &bytes[bytes.len() - reader.remaining_len()..]; bytes.split_at_checked(len).ok_or(WireError::InvalidPayload) diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index f8796228..ef5ac08a 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -28,9 +28,9 @@ impl StreamClose { impl codec::WireParse for StreamClose { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - stream_id: StreamId(reader.take_varint()?), - target: CloseTarget::try_from(reader.take_u8()?)?, - code: StreamCloseCode(reader.take_u16()?), + stream_id: reader.parse()?, + target: reader.parse()?, + code: reader.parse()?, }) } } @@ -66,6 +66,18 @@ impl TryFrom for CloseTarget { } } +impl codec::WireParse for CloseTarget { + fn parse(reader: &mut codec::Reader) -> Result { + reader.parse::()?.try_into() + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] pub struct StreamCloseCode(pub u16); + +impl codec::WireParse for StreamCloseCode { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.parse()?)) + } +} diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 0f66c4bb..0b25d0f7 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -19,9 +19,9 @@ impl StreamData { pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); Ok(Self { - stream_id: StreamId(reader.take_varint()?), - offset: reader.take_varint()?, - fin: reader.take_bool()?, + stream_id: reader.parse()?, + offset: reader.parse()?, + fin: reader.parse()?, bytes: reader.take_rest(), }) } diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index 91c8a15a..070626a7 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -22,8 +22,8 @@ impl StreamWindow { impl codec::WireParse for StreamWindow { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - stream_id: StreamId(reader.take_varint()?), - maximum_offset: reader.take_varint()?, + stream_id: reader.parse()?, + maximum_offset: reader.parse()?, }) } } diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index 886b50b3..bc97cf7e 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -27,7 +27,7 @@ impl EncryptedMessage { pub fn parse(bytes: B) -> Result { let mut reader = codec::Reader::new(bytes); Ok(Self { - auth: reader.take_array()?, + auth: reader.parse()?, ciphertext: reader.take_rest(), }) } diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 1d19069f..4d7ab980 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -44,9 +44,9 @@ impl codec::WireParse for Ik1 { header: reader.parse()?, meta: reader.parse()?, transport_params: reader.parse()?, - skem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), + skem_ciphertext: reader.parse()?, ephemeral: reader.parse()?, - static_bundle: EncryptedPeerBundle::new(reader.take_boxed_array()?), + static_bundle: reader.parse()?, }) } } @@ -82,8 +82,8 @@ impl codec::WireParse for Ik2 { header: reader.parse()?, meta: reader.parse()?, transport_params: reader.parse()?, - ekem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), - skem_ciphertext: EncryptedMlKemCiphertext::new(reader.take_boxed_array()?), + ekem_ciphertext: reader.parse()?, + skem_ciphertext: reader.parse()?, }) } } diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 31ed5cda..534a0148 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -40,7 +40,7 @@ impl codec::WireParse for Kk1 { header: reader.parse()?, meta: reader.parse()?, transport_params: reader.parse()?, - skem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), + skem_ciphertext: reader.parse()?, ephemeral: reader.parse()?, }) } @@ -77,8 +77,8 @@ impl codec::WireParse for Kk2 { header: reader.parse()?, meta: reader.parse()?, transport_params: reader.parse()?, - ekem_ciphertext: MlKemCiphertext::new(reader.take_boxed_array()?), - skem_ciphertext: EncryptedMlKemCiphertext::new(reader.take_boxed_array()?), + ekem_ciphertext: reader.parse()?, + skem_ciphertext: reader.parse()?, }) } } diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs index f74697ad..f26f3cf8 100644 --- a/ql-wire/src/handshake/meta.rs +++ b/ql-wire/src/handshake/meta.rs @@ -10,6 +10,12 @@ pub struct HandshakeMeta { pub valid_until: u64, } +impl codec::WireParse for HandshakeId { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.parse()?)) + } +} + impl HandshakeMeta { pub const WIRE_SIZE: usize = size_of::() + size_of::(); @@ -36,8 +42,8 @@ impl HandshakeMeta { impl codec::WireParse for HandshakeMeta { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - handshake_id: HandshakeId(reader.take_u32()?), - valid_until: reader.take_u64()?, + handshake_id: reader.parse()?, + valid_until: reader.parse()?, }) } } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 56408933..37dd50c3 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -44,8 +44,8 @@ impl HandshakeHeader { impl codec::WireParse for HandshakeHeader { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - sender: XID(reader.take_array()?), - recipient: XID(reader.take_array()?), + sender: reader.parse()?, + recipient: reader.parse()?, }) } } @@ -66,7 +66,7 @@ impl EphemeralPublicKey { impl codec::WireParse for EphemeralPublicKey { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - mlkem_public_key: MlKemPublicKey::new(reader.take_boxed_array()?), + mlkem_public_key: reader.parse()?, }) } } @@ -86,6 +86,12 @@ impl EncryptedMlKemCiphertext { } } +impl codec::WireParse for EncryptedMlKemCiphertext { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.parse()?)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct EncryptedPeerBundle(Box<[u8; Self::WIRE_SIZE]>); @@ -101,6 +107,12 @@ impl EncryptedPeerBundle { } } +impl codec::WireParse for EncryptedPeerBundle { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.parse()?)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct FinalizedHandshake { pub tx_key: SessionKey, diff --git a/ql-wire/src/handshake/transport_params.rs b/ql-wire/src/handshake/transport_params.rs index 2acf9608..08e580d2 100644 --- a/ql-wire/src/handshake/transport_params.rs +++ b/ql-wire/src/handshake/transport_params.rs @@ -32,7 +32,7 @@ impl Default for TransportParams { impl codec::WireParse for TransportParams { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - initial_stream_receive_window: reader.take_u32()?, + initial_stream_receive_window: reader.parse()?, }) } } diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index 6db1e24b..760761e8 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -1,4 +1,4 @@ -use crate::{codec, ByteSlice, VarInt, VarIntBoundsExceeded, QL_WIRE_VERSION}; +use crate::{codec, ByteSlice, VarInt, VarIntBoundsExceeded, WireError, QL_WIRE_VERSION}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SessionHeader { @@ -46,6 +46,18 @@ impl ConnectionId { } } +impl codec::WireParse for RecordSeq { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.parse()?)) + } +} + +impl codec::WireParse for ConnectionId { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self::from_data(reader.parse()?)) + } +} + impl SessionHeader { pub const MAX_ENCODED_LEN: usize = ConnectionId::SIZE + RecordSeq::MAX_ENCODED_LEN; const AAD_DOMAIN: &[u8] = b"ql-wire:session-aad:v1"; @@ -84,10 +96,10 @@ impl SessionHeader { } impl codec::WireParse for SessionHeader { - fn parse(reader: &mut codec::Reader) -> Result { + fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - connection_id: ConnectionId::from_data(reader.take_array()?), - seq: RecordSeq(reader.take_varint()?), + connection_id: reader.parse()?, + seq: reader.parse()?, }) } } diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 8533a054..1328bf5a 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -32,10 +32,10 @@ impl PeerBundle { impl codec::WireParse for PeerBundle { fn parse(reader: &mut codec::Reader) -> Result { Ok(Self { - version: reader.take_u16()?, - xid: XID(reader.take_array()?), - capabilities: reader.take_u32()?, - mlkem_public_key: MlKemPublicKey::new(reader.take_boxed_array()?), + version: reader.parse()?, + xid: reader.parse()?, + capabilities: reader.parse()?, + mlkem_public_key: reader.parse()?, }) } } diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index 7000e406..ce87bfb6 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -1,3 +1,5 @@ +use crate::{codec, ByteSlice, WireError}; + pub const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; // ql-wire fixes the protocol to ML-KEM-1024 on the wire, but the host @@ -39,6 +41,12 @@ impl Drop for SessionKey { } } +impl codec::WireParse for SessionKey { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self::from_data(reader.parse()?)) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MlKemPublicKey(Box<[u8; MlKemPublicKey::SIZE]>); @@ -60,6 +68,12 @@ impl Drop for MlKemPublicKey { } } +impl codec::WireParse for MlKemPublicKey { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.parse()?)) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MlKemPrivateKey(Box<[u8; MlKemPrivateKey::SIZE]>); @@ -102,6 +116,12 @@ impl Drop for MlKemCiphertext { } } +impl codec::WireParse for MlKemCiphertext { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.parse()?)) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MlKemKeyPair { pub private: MlKemPrivateKey, diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index df169ace..42afc52c 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -44,10 +44,16 @@ impl TryFrom for RecordType { } } +impl WireParse for RecordType { + fn parse(reader: &mut codec::Reader) -> Result { + reader.parse::()?.try_into() + } +} + impl WireParse for RecordHeader { fn parse(reader: &mut codec::Reader) -> Result { - let version = reader.take_u8()?; - let record_type = RecordType::try_from(reader.take_u8()?)?; + let version = reader.parse()?; + let record_type = reader.parse()?; Ok(Self { version, record_type, @@ -78,6 +84,12 @@ impl TryFrom for HandshakeKind { } } +impl WireParse for HandshakeKind { + fn parse(reader: &mut codec::Reader) -> Result { + reader.parse::()?.try_into() + } +} + impl QlHandshakeRecord { pub fn kind(&self) -> HandshakeKind { match self { @@ -125,7 +137,7 @@ impl WireParse for QlHandshakeRecord { if header.record_type != RecordType::Handshake { return Err(WireError::InvalidPayload); } - let kind = HandshakeKind::try_from(reader.take_u8()?)?; + let kind = reader.parse::()?; match kind { HandshakeKind::Ik1 => Ok(Self::Ik1(reader.parse()?)), HandshakeKind::Ik2 => Ok(Self::Ik2(reader.parse()?)), diff --git a/ql-wire/src/xid.rs b/ql-wire/src/xid.rs index 040b3127..60f1f06f 100644 --- a/ql-wire/src/xid.rs +++ b/ql-wire/src/xid.rs @@ -1,3 +1,5 @@ +use crate::{codec, ByteSlice, WireError}; + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] pub struct XID(pub [u8; Self::SIZE]); @@ -5,3 +7,9 @@ pub struct XID(pub [u8; Self::SIZE]); impl XID { pub const SIZE: usize = 16; } + +impl codec::WireParse for XID { + fn parse(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.parse()?)) + } +} From 069eec7f15088afd0e31bd71e202c704ea562d5e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 07:16:44 -0400 Subject: [PATCH 121/304] ql-wire: encode trait --- Cargo.lock | 1 + ql-wire/Cargo.toml | 3 + ql-wire/src/codec.rs | 217 ++++++++++++++-------- ql-wire/src/encrypted/ack.rs | 32 ++-- ql-wire/src/encrypted/builder.rs | 82 +++----- ql-wire/src/encrypted/close.rs | 38 ++-- ql-wire/src/encrypted/mod.rs | 109 ++++++++--- ql-wire/src/encrypted/stream_close.rs | 59 ++++-- ql-wire/src/encrypted/stream_data.rs | 29 +-- ql-wire/src/encrypted/stream_window.rs | 22 +-- ql-wire/src/encrypted_message.rs | 34 ++-- ql-wire/src/handshake/ik.rs | 78 ++++---- ql-wire/src/handshake/kk.rs | 74 ++++---- ql-wire/src/handshake/meta.rs | 40 ++-- ql-wire/src/handshake/mod.rs | 93 ++++++---- ql-wire/src/handshake/transport_params.rs | 20 +- ql-wire/src/header.rs | 88 +++++---- ql-wire/src/identity.rs | 33 ++-- ql-wire/src/pq.rs | 50 ++++- ql-wire/src/record.rs | 138 ++++++++------ ql-wire/src/tests.rs | 51 ++--- ql-wire/src/xid.rs | 18 +- 22 files changed, 791 insertions(+), 518 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 91c105fc..c9bc20e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2220,6 +2220,7 @@ dependencies = [ name = "ql-wire" version = "0.1.0" dependencies = [ + "bytes", "libcrux-aesgcm", "libcrux-ml-kem", "sha2", diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml index 42db6996..4e713826 100644 --- a/ql-wire/Cargo.toml +++ b/ql-wire/Cargo.toml @@ -5,6 +5,9 @@ edition = "2021" description = "Quantum Link wire format types and crypto helpers" license = "Proprietary" +[dependencies] +bytes = "1" + [dev-dependencies] libcrux-aesgcm = "0.0.7" libcrux-ml-kem = "0.0.7" diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index b115f99d..7062c53a 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -1,61 +1,31 @@ -use crate::{ByteSlice, VarInt, WireError}; - -pub fn write_u8(out: &mut [u8], value: u8) -> &mut [u8] { - let (head, rest) = out.split_at_mut(1); - head[0] = value; - rest -} +use ::bytes::BufMut; -pub fn write_u16(out: &mut [u8], value: u16) -> &mut [u8] { - let (head, rest) = out.split_at_mut(size_of::()); - head.copy_from_slice(&value.to_le_bytes()); - rest -} +use crate::{ByteSlice, VarInt, WireError}; -pub fn write_u32(out: &mut [u8], value: u32) -> &mut [u8] { - let (head, rest) = out.split_at_mut(size_of::()); - head.copy_from_slice(&value.to_le_bytes()); - rest -} +pub trait WireEncode { + fn encoded_len(&self) -> usize; -pub fn write_u64(out: &mut [u8], value: u64) -> &mut [u8] { - let (head, rest) = out.split_at_mut(size_of::()); - head.copy_from_slice(&value.to_le_bytes()); - rest -} + fn encode(&self, out: &mut W); -pub fn write_varint(out: &mut [u8], value: VarInt) -> &mut [u8] { - let x = value.into_inner(); - match value.size() { - 1 => write_u8(out, x as u8), - 2 => write_bytes(out, &(((0b01u16 << 14) | (x as u16)).to_be_bytes())), - 4 => write_bytes(out, &(((0b10u32 << 30) | (x as u32)).to_be_bytes())), - 8 => write_bytes(out, &(((0b11u64 << 62) | x).to_be_bytes())), - _ => unreachable!("malformed varint"), + fn encode_vec(&self) -> Vec { + let mut out = Vec::with_capacity(self.encoded_len()); + self.encode(&mut out); + debug_assert_eq!(out.len(), self.encoded_len()); + out } } -pub fn write_bool(out: &mut [u8], value: bool) -> &mut [u8] { - write_u8(out, u8::from(value)) -} +pub trait WireDecode: Sized { + fn decode(reader: &mut Reader) -> Result; -pub fn write_bytes<'a>(out: &'a mut [u8], bytes: &[u8]) -> &'a mut [u8] { - let (head, rest) = out.split_at_mut(bytes.len()); - head.copy_from_slice(bytes); - rest -} - -pub trait WireParse: Sized { - fn parse(reader: &mut Reader) -> Result; - - fn parse_prefix(bytes: B) -> Result { + fn decode_bytes(bytes: B) -> Result { let mut reader = Reader::new(bytes); - Self::parse(&mut reader) + Self::decode(&mut reader) } - fn parse_bytes(bytes: B) -> Result { + fn decode_exact(bytes: B) -> Result { let mut reader = Reader::new(bytes); - let value = Self::parse(&mut reader)?; + let value = Self::decode(&mut reader)?; if reader.is_empty() { Ok(value) } else { @@ -64,8 +34,8 @@ pub trait WireParse: Sized { } } -impl WireParse for [u8; N] { - fn parse(reader: &mut Reader) -> Result { +impl WireDecode for [u8; N] { + fn decode(reader: &mut Reader) -> Result { let bytes = reader.take_bytes(N)?; let mut out = [0u8; N]; out.copy_from_slice(&bytes); @@ -73,8 +43,18 @@ impl WireParse for [u8; N] { } } -impl WireParse for Box<[u8; N]> { - fn parse(reader: &mut Reader) -> Result { +impl WireEncode for [u8; N] { + fn encoded_len(&self) -> usize { + N + } + + fn encode(&self, out: &mut W) { + out.put_slice(self); + } +} + +impl WireDecode for Box<[u8; N]> { + fn decode(reader: &mut Reader) -> Result { let bytes = reader.take_bytes(N)?; let mut out = Box::<[u8; N]>::new_uninit(); let src = bytes.as_ptr(); @@ -87,33 +67,93 @@ impl WireParse for Box<[u8; N]> { } } -impl WireParse for u8 { - fn parse(reader: &mut Reader) -> Result { +impl WireEncode for Box<[u8; N]> { + fn encoded_len(&self) -> usize { + N + } + + fn encode(&self, out: &mut W) { + out.put_slice(self.as_ref()); + } +} + +impl WireEncode for [u8] { + fn encoded_len(&self) -> usize { + self.len() + } + + fn encode(&self, out: &mut W) { + out.put_slice(self); + } +} + +impl WireDecode for u8 { + fn decode(reader: &mut Reader) -> Result { Ok(reader.take_bytes(1)?[0]) } } -impl WireParse for u16 { - fn parse(reader: &mut Reader) -> Result { - Ok(u16::from_le_bytes(reader.parse()?)) +impl WireEncode for u8 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self); } } -impl WireParse for u32 { - fn parse(reader: &mut Reader) -> Result { - Ok(u32::from_le_bytes(reader.parse()?)) +impl WireDecode for u16 { + fn decode(reader: &mut Reader) -> Result { + Ok(u16::from_be_bytes(reader.decode()?)) } } -impl WireParse for u64 { - fn parse(reader: &mut Reader) -> Result { - Ok(u64::from_le_bytes(reader.parse()?)) +impl WireEncode for u16 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u16(*self); + } +} + +impl WireDecode for u32 { + fn decode(reader: &mut Reader) -> Result { + Ok(u32::from_be_bytes(reader.decode()?)) + } +} + +impl WireEncode for u32 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u32(*self); + } +} + +impl WireDecode for u64 { + fn decode(reader: &mut Reader) -> Result { + Ok(u64::from_be_bytes(reader.decode()?)) + } +} + +impl WireEncode for u64 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u64(*self); } } -impl WireParse for VarInt { - fn parse(reader: &mut Reader) -> Result { - let first = reader.parse::()?; +impl WireDecode for VarInt { + fn decode(reader: &mut Reader) -> Result { + let first = reader.decode::()?; let tag = first >> 6; let first = first & 0b0011_1111; let value = match tag { @@ -121,19 +161,19 @@ impl WireParse for VarInt { 0b01 => { let mut buf = [0; 2]; buf[0] = first; - buf[1] = reader.parse()?; + buf[1] = reader.decode()?; u64::from(u16::from_be_bytes(buf)) } 0b10 => { let mut buf = [0; 4]; buf[0] = first; - buf[1..].copy_from_slice(&reader.parse::<[u8; 3]>()?); + buf[1..].copy_from_slice(&reader.decode::<[u8; 3]>()?); u64::from(u32::from_be_bytes(buf)) } 0b11 => { let mut buf = [0; 8]; buf[0] = first; - buf[1..].copy_from_slice(&reader.parse::<[u8; 7]>()?); + buf[1..].copy_from_slice(&reader.decode::<[u8; 7]>()?); u64::from_be_bytes(buf) } _ => unreachable!(), @@ -144,9 +184,26 @@ impl WireParse for VarInt { } } -impl WireParse for bool { - fn parse(reader: &mut Reader) -> Result { - match reader.parse::()? { +impl WireEncode for VarInt { + fn encoded_len(&self) -> usize { + self.size() + } + + fn encode(&self, out: &mut W) { + let x = self.into_inner(); + match self.size() { + 1 => out.put_u8(x as u8), + 2 => out.put_u16((0b01 << 14) | x as u16), + 4 => out.put_u32((0b10 << 30) | x as u32), + 8 => out.put_u64((0b11 << 62) | x), + _ => unreachable!("malformed varint"), + } + } +} + +impl WireDecode for bool { + fn decode(reader: &mut Reader) -> Result { + match reader.decode::()? { 0 => Ok(false), 1 => Ok(true), _ => Err(WireError::InvalidPayload), @@ -154,6 +211,16 @@ impl WireParse for bool { } } +impl WireEncode for bool { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(u8::from(*self)); + } +} + pub struct Reader { remaining: Option, } @@ -187,15 +254,15 @@ impl Reader { } } - pub fn take_rest(mut self) -> B { - self.remaining.take().unwrap() + pub fn take_rest(&mut self) -> B { + self.take_bytes(self.remaining_len()).unwrap() } #[inline] - pub fn parse(&mut self) -> Result + pub fn decode(&mut self) -> Result where - T: WireParse, + T: WireDecode, { - T::parse(self) + T::decode(self) } } diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 07a6d995..0e7600e1 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -1,4 +1,4 @@ -use crate::{codec, ByteSlice, RecordSeq, WireError}; +use crate::{codec, ByteSlice, RecordSeq, WireEncode, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RecordAck { @@ -21,23 +21,24 @@ impl RecordAck { (self.bits & (1u64 << offset)) != 0 } +} - pub fn wire_size(&self) -> usize { +impl WireEncode for RecordAck { + fn encoded_len(&self) -> usize { self.base_seq.encoded_len() + size_of::() } - pub fn encode_into(&self, out: &mut [u8]) { - assert!(out.len() >= self.wire_size()); - let out = codec::write_varint(out, self.base_seq.0); - let _ = codec::write_u64(out, self.bits); + fn encode(&self, out: &mut W) { + self.base_seq.encode(out); + self.bits.encode(out); } } -impl codec::WireParse for RecordAck { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for RecordAck { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - base_seq: reader.parse()?, - bits: reader.parse()?, + base_seq: reader.decode()?, + bits: reader.decode()?, }) } } @@ -45,7 +46,7 @@ impl codec::WireParse for RecordAck { #[cfg(test)] mod tests { use super::RecordAck; - use crate::{RecordSeq, WireError, WireParse}; + use crate::{RecordSeq, WireEncode, WireError, WireDecode}; #[test] fn encode_decode_round_trip() { @@ -53,10 +54,9 @@ mod tests { base_seq: RecordSeq::from_u32(42), bits: (1u64 << 0) | (1u64 << 17) | (1u64 << 63), }; - let mut encoded = vec![0; ack.wire_size()]; - ack.encode_into(&mut encoded); + let encoded = ack.encode_vec(); - assert_eq!(RecordAck::parse_bytes(encoded.as_slice()).unwrap(), ack); + assert_eq!(RecordAck::decode_exact(encoded.as_slice()).unwrap(), ack); } #[test] @@ -77,12 +77,12 @@ mod tests { #[test] fn decode_rejects_truncated_payload() { assert_eq!( - RecordAck::parse_bytes(&[][..]), + RecordAck::decode_exact(&[][..]), Err(WireError::InvalidPayload) ); let encoded = vec![0; RecordSeq::from_u32(0).encoded_len() + size_of::()]; assert_eq!( - RecordAck::parse_bytes(&encoded[..encoded.len() - 1]), + RecordAck::decode_exact(&encoded[..encoded.len() - 1]), Err(WireError::InvalidPayload) ); } diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index ac5c8495..bca27a31 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,7 +1,9 @@ +use ::bytes::BufMut; + use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; use crate::{ - codec, ByteChunks, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, - SessionKey, VarInt, QL_WIRE_VERSION, + ByteChunks, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, SessionKey, + VarInt, WireEncode, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -65,49 +67,23 @@ impl SessionRecordBuilder { } pub fn push_ack(&mut self, ack: &RecordAck) -> bool { - self.push_frame_payload(super::SessionFrameKind::Ack, ack.wire_size(), |payload| { - ack.encode_into(payload); - }) + self.push_frame_payload(super::SessionFrameKind::Ack, ack) } pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { - self.push_len_prefixed_frame( - super::SessionFrameKind::StreamData, - frame.wire_size(), - |payload| { - frame.encode_into(payload); - }, - ) + self.push_len_prefixed_frame(super::SessionFrameKind::StreamData, frame) } pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { - self.push_frame_payload( - super::SessionFrameKind::StreamWindow, - frame.wire_size(), - |payload| { - frame.encode_into(payload); - }, - ) + self.push_frame_payload(super::SessionFrameKind::StreamWindow, frame) } pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { - self.push_frame_payload( - super::SessionFrameKind::StreamClose, - frame.wire_size(), - |payload| { - frame.encode_into(payload); - }, - ) + self.push_frame_payload(super::SessionFrameKind::StreamClose, frame) } pub fn push_close(&mut self, close: &SessionClose) -> bool { - self.push_frame_payload( - super::SessionFrameKind::Close, - SessionClose::WIRE_SIZE, - |payload| { - close.encode_into(payload); - }, - ) + self.push_frame_payload(super::SessionFrameKind::Close, close) } pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { @@ -141,54 +117,56 @@ impl SessionRecordBuilder { &mut self.bytes[self.prefix_len..], ); - let prefix = &mut self.bytes[..self.prefix_len]; + let mut prefix = &mut self.bytes[..self.prefix_len]; prefix[0] = QL_WIRE_VERSION; prefix[1] = RecordType::Session as u8; - let auth_out = header.encode_into(&mut prefix[2..]); - auth_out[..crate::ENCRYPTED_MESSAGE_AUTH_SIZE].copy_from_slice(&auth); + prefix = &mut prefix[2..]; + header.encode(&mut prefix); + auth.encode(&mut prefix); + debug_assert!(prefix.is_empty()); self.bytes } - fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut [u8])) -> bool { + fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut Vec)) -> bool { if !self.can_push_len(wire_size) { return false; } self.ensure_prefix_capacity(wire_size); let start = self.bytes.len(); - self.bytes.resize(start + wire_size, 0); - encode(&mut self.bytes[start..]); + encode(&mut self.bytes); + debug_assert_eq!(self.bytes.len(), start + wire_size); true } fn push_empty_frame(&mut self, kind: super::SessionFrameKind) -> bool { - self.push_wire_size(1, |out| out[0] = kind as u8) + self.push_wire_size(1, |out| out.put_u8(kind as u8)) } - fn push_frame_payload( + fn push_frame_payload( &mut self, kind: super::SessionFrameKind, - payload_wire_size: usize, - encode_payload: impl FnOnce(&mut [u8]), + payload: &T, ) -> bool { + let payload_wire_size = payload.encoded_len(); self.push_wire_size(1 + payload_wire_size, |out| { - out[0] = kind as u8; - encode_payload(&mut out[1..]); + out.put_u8(kind as u8); + payload.encode(out); }) } - fn push_len_prefixed_frame( + fn push_len_prefixed_frame( &mut self, kind: super::SessionFrameKind, - payload_wire_size: usize, - encode_payload: impl FnOnce(&mut [u8]), + payload: &T, ) -> bool { + let payload_wire_size = payload.encoded_len(); let Ok(prefix_len) = VarInt::try_from(payload_wire_size) else { return false; }; - self.push_wire_size(1 + prefix_len.size() + payload_wire_size, |out| { - out[0] = kind as u8; - let payload = codec::write_varint(&mut out[1..], prefix_len); - encode_payload(payload); + self.push_wire_size(1 + prefix_len.encoded_len() + payload_wire_size, |out| { + out.put_u8(kind as u8); + prefix_len.encode(out); + payload.encode(out); }) } diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs index 85ae85d1..e0860d7a 100644 --- a/ql-wire/src/encrypted/close.rs +++ b/ql-wire/src/encrypted/close.rs @@ -1,4 +1,4 @@ -use crate::{codec, codec::Reader, ByteSlice, WireError}; +use crate::{codec, codec::Reader, ByteSlice, WireEncode, WireError}; /// closes the whole session immediately with a close code. #[derive(Debug, Clone, PartialEq, Eq)] @@ -8,10 +8,6 @@ pub struct SessionClose { impl SessionClose { pub const WIRE_SIZE: usize = size_of::(); - - pub fn encode_into(&self, out: &mut [u8]) { - let _ = codec::write_u16(out, self.code.0); - } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -24,16 +20,36 @@ impl SessionCloseCode { pub const TIMEOUT: Self = Self(2); } -impl codec::WireParse for SessionCloseCode { - fn parse(reader: &mut Reader) -> Result { - Ok(Self(reader.parse()?)) +impl WireEncode for SessionCloseCode { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for SessionCloseCode { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) } } -impl codec::WireParse for SessionClose { - fn parse(reader: &mut Reader) -> Result { +impl codec::WireDecode for SessionClose { + fn decode(reader: &mut Reader) -> Result { Ok(Self { - code: reader.parse()?, + code: reader.decode()?, }) } } + +impl WireEncode for SessionClose { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.code.encode(out); + } +} diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index b8d8235f..de42ec45 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,6 +1,6 @@ use crate::{ codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, - SessionHeader, SessionKey, VarInt, VarIntBoundsExceeded, WireError, WireParse, + SessionHeader, SessionKey, VarInt, VarIntBoundsExceeded, WireDecode, WireEncode, WireError, }; mod ack; @@ -36,15 +36,21 @@ impl StreamId { pub const fn into_inner(self) -> u64 { self.0.into_inner() } +} - pub const fn encoded_len(self) -> usize { +impl WireEncode for StreamId { + fn encoded_len(&self) -> usize { self.0.size() } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } } -impl codec::WireParse for StreamId { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self(reader.parse()?)) +impl codec::WireDecode for StreamId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) } } @@ -109,29 +115,17 @@ impl SessionRecord { .collect::, _>>()?; Ok(Self { frames }) } - - pub fn wire_size(&self) -> usize { - self.frames - .iter() - .map(SessionFrame::wire_size) - .sum::() - } } -impl SessionFrame { - pub fn wire_size(&self) -> usize { - 1 + match self { - Self::Ping => 0, - Self::Ack(frame) => frame.wire_size(), - Self::StreamData(frame) => { - VarInt::try_from(frame.wire_size()) - .unwrap_or(VarInt::MAX) - .size() - + frame.wire_size() - } - Self::StreamWindow(frame) => frame.wire_size(), - Self::StreamClose(frame) => frame.wire_size(), - Self::Close(_) => SessionClose::WIRE_SIZE, +impl SessionFrame { + fn kind(&self) -> SessionFrameKind { + match self { + Self::Ping => SessionFrameKind::Ping, + Self::Ack(_) => SessionFrameKind::Ack, + Self::StreamData(_) => SessionFrameKind::StreamData, + Self::StreamWindow(_) => SessionFrameKind::StreamWindow, + Self::StreamClose(_) => SessionFrameKind::StreamClose, + Self::Close(_) => SessionFrameKind::Close, } } } @@ -149,6 +143,58 @@ impl SessionFrame { } } +impl WireEncode for SessionFrame { + fn encoded_len(&self) -> usize { + 1 + match self { + Self::Ping => 0, + Self::Ack(frame) => frame.encoded_len(), + Self::StreamData(frame) => { + let payload_len = frame.encoded_len(); + VarInt::try_from(payload_len) + .unwrap_or(VarInt::MAX) + .encoded_len() + + payload_len + } + Self::StreamWindow(frame) => frame.encoded_len(), + Self::StreamClose(frame) => frame.encoded_len(), + Self::Close(frame) => frame.encoded_len(), + } + } + + fn encode(&self, out: &mut W) { + out.put_u8(self.kind() as u8); + match self { + Self::Ping => {} + Self::Ack(frame) => frame.encode(out), + Self::StreamData(frame) => { + let payload_len = frame.encoded_len(); + let payload_len = VarInt::try_from(payload_len) + .expect("stream data frame length must fit ql-wire varint"); + payload_len.encode(out); + frame.encode(out); + } + Self::StreamWindow(frame) => frame.encode(out), + Self::StreamClose(frame) => frame.encode(out), + Self::Close(frame) => frame.encode(out), + } + } +} + +impl WireEncode for SessionRecord { + fn encoded_len(&self) -> usize { + self.frames + .iter() + .map(WireEncode::encoded_len) + .sum::() + } + + fn encode(&self, out: &mut W) { + for frame in &self.frames { + frame.encode(out); + } + } +} + impl<'a> Iterator for SessionFrameIter<'a> { type Item = Result, WireError>; @@ -192,7 +238,10 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr } SessionFrameKind::StreamData => { let (frame, rest) = split_variable_frame(rest)?; - Ok((SessionFrame::StreamData(StreamData::parse(frame)?), rest)) + Ok(( + SessionFrame::StreamData(StreamData::decode_exact(frame)?), + rest, + )) } SessionFrameKind::StreamWindow => { let (frame, rest) = parse_inline_frame::(rest)?; @@ -211,17 +260,17 @@ fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireEr fn parse_inline_frame(bytes: &[u8]) -> Result<(T, &[u8]), WireError> where - T: for<'a> WireParse<&'a [u8]>, + T: for<'a> WireDecode<&'a [u8]>, { let mut reader = codec::Reader::new(bytes); - let frame = reader.parse::()?; + let frame = reader.decode::()?; let consumed = bytes.len() - reader.remaining_len(); Ok((frame, &bytes[consumed..])) } fn split_variable_frame(bytes: &[u8]) -> Result<(&[u8], &[u8]), WireError> { let mut reader = codec::Reader::new(bytes); - let len = usize::try_from(reader.parse::()?.into_inner()) + let len = usize::try_from(reader.decode::()?.into_inner()) .map_err(|_| WireError::InvalidPayload)?; let bytes = &bytes[bytes.len() - reader.remaining_len()..]; bytes.split_at_checked(len).ok_or(WireError::InvalidPayload) diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index ef5ac08a..5c4a06a3 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -1,5 +1,5 @@ use super::StreamId; -use crate::{codec, ByteSlice, WireError}; +use crate::{codec, ByteSlice, WireEncode, WireError}; /// aborts one or both lanes of a stream with a close code /// @@ -14,23 +14,26 @@ pub struct StreamClose { } impl StreamClose { - pub fn wire_size(&self) -> usize { - self.stream_id.encoded_len() + size_of::() + size_of::() +} + +impl WireEncode for StreamClose { + fn encoded_len(&self) -> usize { + self.stream_id.encoded_len() + self.target.encoded_len() + self.code.encoded_len() } - pub fn encode_into(&self, out: &mut [u8]) { - let out = codec::write_varint(out, self.stream_id.0); - let out = codec::write_u8(out, self.target.to_wire()); - let _ = codec::write_u16(out, self.code.0); + fn encode(&self, out: &mut W) { + self.stream_id.encode(out); + self.target.encode(out); + self.code.encode(out); } } -impl codec::WireParse for StreamClose { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for StreamClose { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - stream_id: reader.parse()?, - target: reader.parse()?, - code: reader.parse()?, + stream_id: reader.decode()?, + target: reader.decode()?, + code: reader.decode()?, }) } } @@ -53,6 +56,16 @@ impl CloseTarget { } } +impl WireEncode for CloseTarget { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.to_wire().encode(out); + } +} + impl TryFrom for CloseTarget { type Error = WireError; @@ -66,9 +79,9 @@ impl TryFrom for CloseTarget { } } -impl codec::WireParse for CloseTarget { - fn parse(reader: &mut codec::Reader) -> Result { - reader.parse::()?.try_into() +impl codec::WireDecode for CloseTarget { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() } } @@ -76,8 +89,18 @@ impl codec::WireParse for CloseTarget { #[repr(transparent)] pub struct StreamCloseCode(pub u16); -impl codec::WireParse for StreamCloseCode { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self(reader.parse()?)) +impl codec::WireDecode for StreamCloseCode { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl WireEncode for StreamCloseCode { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); } } diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 0b25d0f7..962ae600 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -1,5 +1,5 @@ use super::StreamId; -use crate::{codec, ByteChunks, ByteSlice, VarInt, WireError}; +use crate::{codec, ByteChunks, ByteSlice, VarInt, WireDecode, WireEncode, WireError}; /// carries bytes for a stream and may finish that sending direction. #[derive(Debug, Clone, PartialEq, Eq)] @@ -15,13 +15,12 @@ impl StreamData { pub const MIN_WIRE_SIZE: usize = StreamId::MAX_ENCODED_LEN + VarInt::MAX_SIZE + size_of::(); } -impl StreamData { - pub fn parse(bytes: B) -> Result { - let mut reader = codec::Reader::new(bytes); +impl WireDecode for StreamData { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - stream_id: reader.parse()?, - offset: reader.parse()?, - fin: reader.parse()?, + stream_id: reader.decode()?, + offset: reader.decode()?, + fin: reader.decode()?, bytes: reader.take_rest(), }) } @@ -43,19 +42,21 @@ impl StreamData { impl StreamData { pub fn header_len(&self) -> usize { - self.stream_id.encoded_len() + self.offset.size() + size_of::() + self.stream_id.encoded_len() + self.offset.encoded_len() + size_of::() } +} - pub fn wire_size(&self) -> usize { +impl WireEncode for StreamData { + fn encoded_len(&self) -> usize { self.header_len() + self.bytes.len() } - pub fn encode_into(&self, out: &mut [u8]) { - let out = codec::write_varint(out, self.stream_id.0); - let out = codec::write_varint(out, self.offset); - let mut out = codec::write_bool(out, self.fin); + fn encode(&self, out: &mut W) { + self.stream_id.encode(out); + self.offset.encode(out); + self.fin.encode(out); for chunk in self.bytes.chunks() { - out = codec::write_bytes(out, chunk); + chunk.encode(out); } } } diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs index 070626a7..6a2274f9 100644 --- a/ql-wire/src/encrypted/stream_window.rs +++ b/ql-wire/src/encrypted/stream_window.rs @@ -1,5 +1,5 @@ use super::StreamId; -use crate::{codec, ByteSlice, VarInt, WireError}; +use crate::{codec, ByteSlice, VarInt, WireEncode, WireError}; /// advertises the highest byte offset the peer may send on a stream. #[derive(Debug, Clone, PartialEq, Eq)] @@ -8,22 +8,22 @@ pub struct StreamWindow { pub maximum_offset: VarInt, } -impl StreamWindow { - pub fn wire_size(&self) -> usize { - self.stream_id.encoded_len() + self.maximum_offset.size() +impl WireEncode for StreamWindow { + fn encoded_len(&self) -> usize { + self.stream_id.encoded_len() + self.maximum_offset.encoded_len() } - pub fn encode_into(&self, out: &mut [u8]) { - let out = codec::write_varint(out, self.stream_id.0); - let _ = codec::write_varint(out, self.maximum_offset); + fn encode(&self, out: &mut W) { + self.stream_id.encode(out); + self.maximum_offset.encode(out); } } -impl codec::WireParse for StreamWindow { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for StreamWindow { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - stream_id: reader.parse()?, - maximum_offset: reader.parse()?, + stream_id: reader.decode()?, + maximum_offset: reader.decode()?, }) } } diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index bc97cf7e..293b5773 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -1,5 +1,6 @@ use crate::{ - codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, + codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireEncode, WireError, WireDecode, + ENCRYPTED_MESSAGE_AUTH_SIZE, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -23,28 +24,16 @@ impl EncryptedMessage { } } -impl EncryptedMessage { - pub fn parse(bytes: B) -> Result { - let mut reader = codec::Reader::new(bytes); +impl WireDecode for EncryptedMessage { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - auth: reader.parse()?, + auth: reader.decode()?, ciphertext: reader.take_rest(), }) } } impl> EncryptedMessage { - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - let out = codec::write_bytes(out, &self.auth); - codec::write_bytes(out, self.ciphertext.as_ref()) - } - - pub fn encode(&self) -> Vec { - let mut out = vec![0; Self::HEADER_LEN + self.ciphertext.as_ref().len()]; - let _ = self.encode_into(&mut out); - out - } - pub fn decrypt( &self, crypto: &impl QlCrypto, @@ -60,6 +49,17 @@ impl> EncryptedMessage { } } +impl> WireEncode for EncryptedMessage { + fn encoded_len(&self) -> usize { + Self::HEADER_LEN + self.ciphertext.as_ref().len() + } + + fn encode(&self, out: &mut W) { + self.auth.encode(out); + self.ciphertext.as_ref().encode(out); + } +} + impl> EncryptedMessage { pub fn decrypt_in_place( mut self, @@ -92,6 +92,6 @@ impl EncryptedMessage> { } pub fn decode(bytes: &[u8]) -> Result { - Ok(EncryptedMessage::parse(bytes)?.into_owned()) + Ok(EncryptedMessage::decode_exact(bytes)?.into_owned()) } } diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 4d7ab980..10ba0843 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -7,7 +7,7 @@ use super::{ }; use crate::{ codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, - QlIdentity, WireError, + QlIdentity, WireEncode, WireError, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -27,30 +27,36 @@ impl Ik1 { + MlKemCiphertext::SIZE + EphemeralPublicKey::WIRE_SIZE + EncryptedPeerBundle::WIRE_SIZE; - - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - let out = self.header.encode_into(out); - let out = self.meta.encode_into(out); - let out = self.transport_params.encode_into(out); - let out = codec::write_bytes(out, self.skem_ciphertext.as_bytes()); - let out = self.ephemeral.encode_into(out); - codec::write_bytes(out, self.static_bundle.as_bytes()) - } } -impl codec::WireParse for Ik1 { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for Ik1 { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - header: reader.parse()?, - meta: reader.parse()?, - transport_params: reader.parse()?, - skem_ciphertext: reader.parse()?, - ephemeral: reader.parse()?, - static_bundle: reader.parse()?, + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + ephemeral: reader.decode()?, + static_bundle: reader.decode()?, }) } } +impl WireEncode for Ik1 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.ephemeral.encode(out); + self.static_bundle.encode(out); + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Ik2 { pub header: HandshakeHeader, @@ -66,28 +72,34 @@ impl Ik2 { + TransportParams::WIRE_SIZE + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::WIRE_SIZE; - - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - let out = self.header.encode_into(out); - let out = self.meta.encode_into(out); - let out = self.transport_params.encode_into(out); - let out = codec::write_bytes(out, self.ekem_ciphertext.as_bytes()); - codec::write_bytes(out, self.skem_ciphertext.as_bytes()) - } } -impl codec::WireParse for Ik2 { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for Ik2 { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - header: reader.parse()?, - meta: reader.parse()?, - transport_params: reader.parse()?, - ekem_ciphertext: reader.parse()?, - skem_ciphertext: reader.parse()?, + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + skem_ciphertext: reader.decode()?, }) } } +impl WireEncode for Ik2 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.skem_ciphertext.encode(out); + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum IkStep { Send1, diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 534a0148..9cc17ba4 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -6,7 +6,7 @@ use super::{ }; use crate::{ codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, - QlIdentity, WireError, + QlIdentity, WireEncode, WireError, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -24,28 +24,34 @@ impl Kk1 { + TransportParams::WIRE_SIZE + MlKemCiphertext::SIZE + EphemeralPublicKey::WIRE_SIZE; - - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - let out = self.header.encode_into(out); - let out = self.meta.encode_into(out); - let out = self.transport_params.encode_into(out); - let out = codec::write_bytes(out, self.skem_ciphertext.as_bytes()); - self.ephemeral.encode_into(out) - } } -impl codec::WireParse for Kk1 { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for Kk1 { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - header: reader.parse()?, - meta: reader.parse()?, - transport_params: reader.parse()?, - skem_ciphertext: reader.parse()?, - ephemeral: reader.parse()?, + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + ephemeral: reader.decode()?, }) } } +impl WireEncode for Kk1 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.ephemeral.encode(out); + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Kk2 { pub header: HandshakeHeader, @@ -61,28 +67,34 @@ impl Kk2 { + TransportParams::WIRE_SIZE + MlKemCiphertext::SIZE + EncryptedMlKemCiphertext::WIRE_SIZE; - - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - let out = self.header.encode_into(out); - let out = self.meta.encode_into(out); - let out = self.transport_params.encode_into(out); - let out = codec::write_bytes(out, self.ekem_ciphertext.as_bytes()); - codec::write_bytes(out, self.skem_ciphertext.as_bytes()) - } } -impl codec::WireParse for Kk2 { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for Kk2 { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - header: reader.parse()?, - meta: reader.parse()?, - transport_params: reader.parse()?, - ekem_ciphertext: reader.parse()?, - skem_ciphertext: reader.parse()?, + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + skem_ciphertext: reader.decode()?, }) } } +impl WireEncode for Kk2 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.skem_ciphertext.encode(out); + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum KkStep { Send1, diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs index f26f3cf8..52e3e870 100644 --- a/ql-wire/src/handshake/meta.rs +++ b/ql-wire/src/handshake/meta.rs @@ -1,4 +1,4 @@ -use crate::{codec, ByteSlice, WireError}; +use crate::{codec, ByteSlice, WireEncode, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] @@ -10,9 +10,19 @@ pub struct HandshakeMeta { pub valid_until: u64, } -impl codec::WireParse for HandshakeId { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self(reader.parse()?)) +impl codec::WireDecode for HandshakeId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl WireEncode for HandshakeId { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); } } @@ -26,24 +36,24 @@ impl HandshakeMeta { Ok(()) } } +} - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - let out = codec::write_u32(out, self.handshake_id.0); - codec::write_u64(out, self.valid_until) +impl WireEncode for HandshakeMeta { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE } - pub fn encode(&self) -> [u8; Self::WIRE_SIZE] { - let mut out = [0; Self::WIRE_SIZE]; - let _ = self.encode_into(&mut out); - out + fn encode(&self, out: &mut W) { + self.handshake_id.encode(out); + self.valid_until.encode(out); } } -impl codec::WireParse for HandshakeMeta { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for HandshakeMeta { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - handshake_id: reader.parse()?, - valid_until: reader.parse()?, + handshake_id: reader.decode()?, + valid_until: reader.decode()?, }) } } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 37dd50c3..13e61bf9 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,7 +1,7 @@ use crate::{ codec, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, - Nonce, PeerBundle, QlCrypto, SessionKey, WireError, WireParse, ENCRYPTED_MESSAGE_AUTH_SIZE, - XID, + Nonce, PeerBundle, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, + ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; mod ik; @@ -28,24 +28,24 @@ pub struct HandshakeHeader { impl HandshakeHeader { pub const WIRE_SIZE: usize = XID::SIZE * 2; +} - pub fn encode(&self) -> [u8; Self::WIRE_SIZE] { - let mut out = [0; Self::WIRE_SIZE]; - let _ = self.encode_into(&mut out); - out +impl WireEncode for HandshakeHeader { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE } - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - let out = codec::write_bytes(out, &self.sender.0); - codec::write_bytes(out, &self.recipient.0) + fn encode(&self, out: &mut W) { + self.sender.encode(out); + self.recipient.encode(out); } } -impl codec::WireParse for HandshakeHeader { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for HandshakeHeader { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - sender: reader.parse()?, - recipient: reader.parse()?, + sender: reader.decode()?, + recipient: reader.decode()?, }) } } @@ -57,16 +57,22 @@ pub struct EphemeralPublicKey { impl EphemeralPublicKey { pub const WIRE_SIZE: usize = MlKemPublicKey::SIZE; +} + +impl WireEncode for EphemeralPublicKey { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - codec::write_bytes(out, self.mlkem_public_key.as_bytes()) + fn encode(&self, out: &mut W) { + self.mlkem_public_key.encode(out); } } -impl codec::WireParse for EphemeralPublicKey { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for EphemeralPublicKey { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - mlkem_public_key: reader.parse()?, + mlkem_public_key: reader.decode()?, }) } } @@ -86,9 +92,19 @@ impl EncryptedMlKemCiphertext { } } -impl codec::WireParse for EncryptedMlKemCiphertext { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self::new(reader.parse()?)) +impl WireEncode for EncryptedMlKemCiphertext { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); + } +} + +impl codec::WireDecode for EncryptedMlKemCiphertext { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) } } @@ -107,9 +123,19 @@ impl EncryptedPeerBundle { } } -impl codec::WireParse for EncryptedPeerBundle { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self::new(reader.parse()?)) +impl WireEncode for EncryptedPeerBundle { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); + } +} + +impl codec::WireDecode for EncryptedPeerBundle { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) } } @@ -293,14 +319,14 @@ fn init_kk_symmetric( responder_bundle: &PeerBundle, ) -> SymmetricState { let mut symmetric = SymmetricState::new(crypto, PROTOCOL_KK); - symmetric.mix_hash(crypto, &initiator_bundle.encode()); - symmetric.mix_hash(crypto, &responder_bundle.encode()); + symmetric.mix_hash(crypto, &initiator_bundle.encode_vec()); + symmetric.mix_hash(crypto, &responder_bundle.encode_vec()); symmetric } fn init_ik_symmetric(crypto: &impl QlCrypto, responder_bundle: &PeerBundle) -> SymmetricState { let mut symmetric = SymmetricState::new(crypto, PROTOCOL_IK); - symmetric.mix_hash(crypto, &responder_bundle.encode()); + symmetric.mix_hash(crypto, &responder_bundle.encode_vec()); symmetric } @@ -326,14 +352,11 @@ fn mix_hash_routed_handshake( meta: &HandshakeMeta, transport_params: TransportParams, ) { - let encoded_header = header.encode(); - let encoded_meta = meta.encode(); - let encoded_transport_params = transport_params.encode(); symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); - symmetric.mix_hash(crypto, &encoded_header); + symmetric.mix_hash(crypto, &header.encode_vec()); symmetric.mix_hash(crypto, &[kind as u8]); - symmetric.mix_hash(crypto, &encoded_meta); - symmetric.mix_hash(crypto, &encoded_transport_params); + symmetric.mix_hash(crypto, &meta.encode_vec()); + symmetric.mix_hash(crypto, &transport_params.encode_vec()); } fn initialize_handshake_meta( @@ -365,7 +388,7 @@ fn encrypt_peer_bundle( symmetric: &mut SymmetricState, bundle: &PeerBundle, ) -> Result { - let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode())?; + let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode_vec())?; let out: Box<[u8; EncryptedPeerBundle::WIRE_SIZE]> = ciphertext.try_into().map_err(|_| WireError::InvalidState)?; Ok(EncryptedPeerBundle::new(out)) @@ -377,7 +400,7 @@ fn decrypt_peer_bundle( bundle: &EncryptedPeerBundle, ) -> Result { let plaintext = symmetric.decrypt_and_hash(crypto, bundle.as_bytes())?; - PeerBundle::parse_bytes(plaintext.as_slice()) + PeerBundle::decode_exact(plaintext.as_slice()) } fn encrypt_mlkem_ciphertext( diff --git a/ql-wire/src/handshake/transport_params.rs b/ql-wire/src/handshake/transport_params.rs index 08e580d2..bfd0d427 100644 --- a/ql-wire/src/handshake/transport_params.rs +++ b/ql-wire/src/handshake/transport_params.rs @@ -1,4 +1,4 @@ -use crate::{codec, ByteSlice, WireError}; +use crate::{codec, ByteSlice, WireEncode, WireError}; /// Session parameters advertised in the handshake #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -9,15 +9,15 @@ pub struct TransportParams { impl TransportParams { pub const WIRE_SIZE: usize = size_of::(); +} - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - codec::write_u32(out, self.initial_stream_receive_window) +impl WireEncode for TransportParams { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE } - pub fn encode(&self) -> [u8; Self::WIRE_SIZE] { - let mut out = [0; Self::WIRE_SIZE]; - let _ = self.encode_into(&mut out); - out + fn encode(&self, out: &mut W) { + self.initial_stream_receive_window.encode(out); } } @@ -29,10 +29,10 @@ impl Default for TransportParams { } } -impl codec::WireParse for TransportParams { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for TransportParams { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - initial_stream_receive_window: reader.parse()?, + initial_stream_receive_window: reader.decode()?, }) } } diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs index 760761e8..88764c0a 100644 --- a/ql-wire/src/header.rs +++ b/ql-wire/src/header.rs @@ -1,4 +1,8 @@ -use crate::{codec, ByteSlice, VarInt, VarIntBoundsExceeded, WireError, QL_WIRE_VERSION}; +use ::bytes::BufMut; + +use crate::{ + codec, ByteSlice, VarInt, VarIntBoundsExceeded, WireEncode, WireError, QL_WIRE_VERSION, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SessionHeader { @@ -24,10 +28,6 @@ impl RecordSeq { pub const fn into_inner(self) -> u64 { self.0.into_inner() } - - pub const fn encoded_len(self) -> usize { - self.0.size() - } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -46,38 +46,42 @@ impl ConnectionId { } } -impl codec::WireParse for RecordSeq { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self(reader.parse()?)) +impl codec::WireDecode for RecordSeq { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) } } -impl codec::WireParse for ConnectionId { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self::from_data(reader.parse()?)) +impl WireEncode for RecordSeq { + fn encoded_len(&self) -> usize { + self.0.size() } -} -impl SessionHeader { - pub const MAX_ENCODED_LEN: usize = ConnectionId::SIZE + RecordSeq::MAX_ENCODED_LEN; - const AAD_DOMAIN: &[u8] = b"ql-wire:session-aad:v1"; - const AAD_RECORD_KIND_SESSION: u8 = 1; + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} - pub fn encoded_len(&self) -> usize { - ConnectionId::SIZE + self.seq.encoded_len() +impl codec::WireDecode for ConnectionId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::from_data(reader.decode()?)) } +} - pub fn encode(&self) -> Vec { - let mut out = vec![0; self.encoded_len()]; - let _ = self.encode_into(&mut out); - out +impl WireEncode for ConnectionId { + fn encoded_len(&self) -> usize { + Self::SIZE } - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - assert!(out.len() >= self.encoded_len()); - let out = codec::write_bytes(out, self.connection_id.as_bytes()); - codec::write_varint(out, self.seq.0) + fn encode(&self, out: &mut W) { + self.0.encode(out); } +} + +impl SessionHeader { + pub const MAX_ENCODED_LEN: usize = ConnectionId::SIZE + RecordSeq::MAX_ENCODED_LEN; + const AAD_DOMAIN: &[u8] = b"ql-wire:session-aad:v1"; + const AAD_RECORD_KIND_SESSION: u8 = 1; pub fn aad(&self) -> Vec { let aad_len = Self::AAD_DOMAIN.len() @@ -85,21 +89,33 @@ impl SessionHeader { + size_of::() + ConnectionId::SIZE + self.seq.encoded_len(); - let mut aad = vec![0; aad_len]; - let out = codec::write_bytes(&mut aad, Self::AAD_DOMAIN); - let out = codec::write_u8(out, QL_WIRE_VERSION); - let out = codec::write_u8(out, Self::AAD_RECORD_KIND_SESSION); - let out = codec::write_bytes(out, self.connection_id.as_bytes()); - let _ = codec::write_varint(out, self.seq.0); + let mut aad = Vec::with_capacity(aad_len); + aad.put_slice(Self::AAD_DOMAIN); + aad.put_u8(QL_WIRE_VERSION); + aad.put_u8(Self::AAD_RECORD_KIND_SESSION); + self.connection_id.encode(&mut aad); + self.seq.encode(&mut aad); + debug_assert_eq!(aad.len(), aad_len); aad } } -impl codec::WireParse for SessionHeader { - fn parse(reader: &mut codec::Reader) -> Result { +impl WireEncode for SessionHeader { + fn encoded_len(&self) -> usize { + ConnectionId::SIZE + self.seq.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.connection_id.encode(out); + self.seq.encode(out); + } +} + +impl codec::WireDecode for SessionHeader { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - connection_id: reader.parse()?, - seq: reader.parse()?, + connection_id: reader.decode()?, + seq: reader.decode()?, }) } } diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 1328bf5a..72602f70 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -1,5 +1,6 @@ use crate::{ - codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, WireError, XID, + codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, WireEncode, + WireError, XID, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -14,28 +15,28 @@ impl PeerBundle { pub const VERSION: u16 = 1; pub const WIRE_SIZE: usize = size_of::() + XID::SIZE + size_of::() + MlKemPublicKey::SIZE; +} - pub fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { - let out = codec::write_u16(out, self.version); - let out = codec::write_bytes(out, &self.xid.0); - let out = codec::write_u32(out, self.capabilities); - codec::write_bytes(out, self.mlkem_public_key.as_bytes()) +impl WireEncode for PeerBundle { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE } - pub fn encode(&self) -> Vec { - let mut out = vec![0; Self::WIRE_SIZE]; - let _ = self.encode_into(&mut out); - out + fn encode(&self, out: &mut W) { + self.version.encode(out); + self.xid.encode(out); + self.capabilities.encode(out); + self.mlkem_public_key.encode(out); } } -impl codec::WireParse for PeerBundle { - fn parse(reader: &mut codec::Reader) -> Result { +impl codec::WireDecode for PeerBundle { + fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - version: reader.parse()?, - xid: reader.parse()?, - capabilities: reader.parse()?, - mlkem_public_key: reader.parse()?, + version: reader.decode()?, + xid: reader.decode()?, + capabilities: reader.decode()?, + mlkem_public_key: reader.decode()?, }) } } diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs index ce87bfb6..327ef7c4 100644 --- a/ql-wire/src/pq.rs +++ b/ql-wire/src/pq.rs @@ -1,4 +1,4 @@ -use crate::{codec, ByteSlice, WireError}; +use crate::{codec, ByteSlice, WireEncode, WireError}; pub const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; @@ -41,9 +41,19 @@ impl Drop for SessionKey { } } -impl codec::WireParse for SessionKey { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self::from_data(reader.parse()?)) +impl WireEncode for SessionKey { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for SessionKey { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::from_data(reader.decode()?)) } } @@ -68,9 +78,19 @@ impl Drop for MlKemPublicKey { } } -impl codec::WireParse for MlKemPublicKey { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self::new(reader.parse()?)) +impl codec::WireDecode for MlKemPublicKey { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) + } +} + +impl WireEncode for MlKemPublicKey { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); } } @@ -116,9 +136,19 @@ impl Drop for MlKemCiphertext { } } -impl codec::WireParse for MlKemCiphertext { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self::new(reader.parse()?)) +impl codec::WireDecode for MlKemCiphertext { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) + } +} + +impl WireEncode for MlKemCiphertext { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); } } diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 42afc52c..002e8e7f 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -2,7 +2,7 @@ use crate::{ codec, encrypted_message::EncryptedMessage, handshake::{Ik1, Ik2, Kk1, Kk2}, - ByteSlice, SessionHeader, WireError, WireParse, QL_WIRE_VERSION, + ByteSlice, SessionHeader, WireEncode, WireError, WireDecode, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -44,16 +44,26 @@ impl TryFrom for RecordType { } } -impl WireParse for RecordType { - fn parse(reader: &mut codec::Reader) -> Result { - reader.parse::()?.try_into() +impl WireDecode for RecordType { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() } } -impl WireParse for RecordHeader { - fn parse(reader: &mut codec::Reader) -> Result { - let version = reader.parse()?; - let record_type = reader.parse()?; +impl WireEncode for RecordType { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self as u8); + } +} + +impl WireDecode for RecordHeader { + fn decode(reader: &mut codec::Reader) -> Result { + let version = reader.decode()?; + let record_type = reader.decode()?; Ok(Self { version, record_type, @@ -84,9 +94,19 @@ impl TryFrom for HandshakeKind { } } -impl WireParse for HandshakeKind { - fn parse(reader: &mut codec::Reader) -> Result { - reader.parse::()?.try_into() +impl WireDecode for HandshakeKind { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +impl WireEncode for HandshakeKind { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self as u8); } } @@ -99,67 +119,66 @@ impl QlHandshakeRecord { Self::Kk2(_) => HandshakeKind::Kk2, } } +} - fn wire_size(&self) -> usize { - match self { - Self::Ik1(_) => Ik1::WIRE_SIZE, - Self::Ik2(_) => Ik2::WIRE_SIZE, - Self::Kk1(_) => Kk1::WIRE_SIZE, - Self::Kk2(_) => Kk2::WIRE_SIZE, - } +impl WireEncode for QlHandshakeRecord { + fn encoded_len(&self) -> usize { + RecordType::Handshake.encoded_len() + + HandshakeKind::Ik1.encoded_len() + + size_of::() + + match self { + Self::Ik1(message) => message.encoded_len(), + Self::Ik2(message) => message.encoded_len(), + Self::Kk1(message) => message.encoded_len(), + Self::Kk2(message) => message.encoded_len(), + } } - fn encode_into<'a>(&self, out: &'a mut [u8]) -> &'a mut [u8] { + fn encode(&self, out: &mut W) { + out.put_u8(QL_WIRE_VERSION); + RecordType::Handshake.encode(out); + self.kind().encode(out); match self { - Self::Ik1(message) => message.encode_into(out), - Self::Ik2(message) => message.encode_into(out), - Self::Kk1(message) => message.encode_into(out), - Self::Kk2(message) => message.encode_into(out), + Self::Ik1(message) => message.encode(out), + Self::Ik2(message) => message.encode(out), + Self::Kk1(message) => message.encode(out), + Self::Kk2(message) => message.encode(out), } } - - pub fn encode(&self) -> Vec { - let mut out = vec![0; 3 + self.wire_size()]; - let rest = codec::write_u8(&mut out, QL_WIRE_VERSION); - let rest = codec::write_u8(rest, RecordType::Handshake as u8); - let rest = codec::write_u8(rest, self.kind() as u8); - let _ = self.encode_into(rest); - out - } } -impl WireParse for QlHandshakeRecord { - fn parse(reader: &mut codec::Reader) -> Result { - let header = reader.parse::()?; +impl WireDecode for QlHandshakeRecord { + fn decode(reader: &mut codec::Reader) -> Result { + let header = reader.decode::()?; if header.version != QL_WIRE_VERSION { return Err(WireError::InvalidPayload); } if header.record_type != RecordType::Handshake { return Err(WireError::InvalidPayload); } - let kind = reader.parse::()?; + let kind = reader.decode::()?; match kind { - HandshakeKind::Ik1 => Ok(Self::Ik1(reader.parse()?)), - HandshakeKind::Ik2 => Ok(Self::Ik2(reader.parse()?)), - HandshakeKind::Kk1 => Ok(Self::Kk1(reader.parse()?)), - HandshakeKind::Kk2 => Ok(Self::Kk2(reader.parse()?)), + HandshakeKind::Ik1 => Ok(Self::Ik1(reader.decode()?)), + HandshakeKind::Ik2 => Ok(Self::Ik2(reader.decode()?)), + HandshakeKind::Kk1 => Ok(Self::Kk1(reader.decode()?)), + HandshakeKind::Kk2 => Ok(Self::Kk2(reader.decode()?)), } } } -impl> QlSessionRecord { - pub fn encode(&self) -> Vec { - let mut out = vec![ - 0; - 2 + self.header.encoded_len() - + EncryptedMessage::<&[u8]>::HEADER_LEN - + self.payload.ciphertext.as_ref().len() - ]; - let rest = codec::write_u8(&mut out, QL_WIRE_VERSION); - let rest = codec::write_u8(rest, RecordType::Session as u8); - let rest = self.header.encode_into(rest); - let _ = self.payload.encode_into(rest); - out +impl> WireEncode for QlSessionRecord { + fn encoded_len(&self) -> usize { + size_of::() + + RecordType::Session.encoded_len() + + self.header.encoded_len() + + self.payload.encoded_len() + } + + fn encode(&self, out: &mut W) { + out.put_u8(QL_WIRE_VERSION); + RecordType::Session.encode(out); + self.header.encode(out); + self.payload.encode(out); } } @@ -172,17 +191,18 @@ impl QlSessionRecord { } } -impl WireParse for QlSessionRecord { - fn parse(reader: &mut codec::Reader) -> Result { - let header = reader.parse::()?; +impl WireDecode for QlSessionRecord { + fn decode(reader: &mut codec::Reader) -> Result { + let header = reader.decode::()?; if header.version != QL_WIRE_VERSION { return Err(WireError::InvalidPayload); } if header.record_type != RecordType::Session { return Err(WireError::InvalidPayload); } - let header = reader.parse::()?; - let payload = EncryptedMessage::parse(reader.take_bytes(reader.remaining_len())?)?; - Ok(Self { header, payload }) + Ok(Self { + header: reader.decode()?, + payload: reader.decode()?, + }) } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 46549e06..1fdba3f2 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -198,7 +198,7 @@ fn encrypt_record( let pushed = builder.push_frame(frame); debug_assert!(pushed); } - QlSessionRecord::parse_bytes( + QlSessionRecord::decode_exact( builder .encrypt(crypto, header.connection_id, session_key) .as_slice(), @@ -213,8 +213,8 @@ fn peer_bundle_round_trip() { let identity = make_identity(&crypto, 7).with_capabilities(0x55aa_33cc); let bundle = identity.bundle(); - let encoded = bundle.encode(); - let decoded = PeerBundle::parse_bytes(encoded.as_slice()).unwrap(); + let encoded = bundle.encode_vec(); + let decoded = PeerBundle::decode_exact(encoded.as_slice()).unwrap(); assert_eq!(decoded, bundle); } @@ -231,16 +231,16 @@ fn handshake_record_round_trip_supports_ik_and_kk() { }, static_bundle: EncryptedPeerBundle::new(Box::new([13; EncryptedPeerBundle::WIRE_SIZE])), }); - let ik_encoded = ik.encode(); + let ik_encoded = ik.encode_vec(); assert_eq!( - RecordHeader::parse_prefix(ik_encoded.as_slice()).unwrap(), + RecordHeader::decode_bytes(ik_encoded.as_slice()).unwrap(), RecordHeader { version: QL_WIRE_VERSION, record_type: RecordType::Handshake, } ); assert_eq!( - QlHandshakeRecord::parse_bytes(ik_encoded.as_slice()).unwrap(), + QlHandshakeRecord::decode_exact(ik_encoded.as_slice()).unwrap(), ik ); @@ -253,16 +253,16 @@ fn handshake_record_round_trip_supports_ik_and_kk() { mlkem_public_key: MlKemPublicKey::new(Box::new([15; MlKemPublicKey::SIZE])), }, }); - let kk_encoded = kk.encode(); + let kk_encoded = kk.encode_vec(); assert_eq!( - RecordHeader::parse_prefix(kk_encoded.as_slice()).unwrap(), + RecordHeader::decode_bytes(kk_encoded.as_slice()).unwrap(), RecordHeader { version: QL_WIRE_VERSION, record_type: RecordType::Handshake, } ); assert_eq!( - QlHandshakeRecord::parse_bytes(kk_encoded.as_slice()).unwrap(), + QlHandshakeRecord::decode_exact(kk_encoded.as_slice()).unwrap(), kk ); } @@ -686,15 +686,15 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let session_key = SessionKey::from_data([7; SessionKey::SIZE]); let record = encrypt_record(&crypto, header, &session_key, &body); - let bytes = record.encode(); + let bytes = record.encode_vec(); assert_eq!( - RecordHeader::parse_prefix(bytes.as_slice()).unwrap(), + RecordHeader::decode_bytes(bytes.as_slice()).unwrap(), RecordHeader { version: QL_WIRE_VERSION, record_type: RecordType::Session, } ); - let decoded = QlSessionRecord::parse_bytes(bytes.as_slice()) + let decoded = QlSessionRecord::decode_exact(bytes.as_slice()) .unwrap() .into_owned(); assert_eq!(decoded.header, header); @@ -734,8 +734,8 @@ fn session_varint_fields_expand_at_expected_boundaries() { seq: record_seq(64), }; - assert_eq!(short_header.encode().len(), ConnectionId::SIZE + 1); - assert_eq!(long_header.encode().len(), ConnectionId::SIZE + 2); + assert_eq!(short_header.encode_vec().len(), ConnectionId::SIZE + 1); + assert_eq!(long_header.encode_vec().len(), ConnectionId::SIZE + 2); let frame = StreamData { stream_id: stream_id(64), @@ -743,11 +743,12 @@ fn session_varint_fields_expand_at_expected_boundaries() { fin: true, bytes: b"abc".to_vec(), }; - let mut encoded = vec![0; frame.wire_size()]; - frame.encode_into(&mut encoded); + let encoded = frame.encode_vec(); assert_eq!( - StreamData::parse(encoded.as_slice()).unwrap().into_owned(), + StreamData::decode_exact(encoded.as_slice()) + .unwrap() + .into_owned(), frame ); } @@ -844,17 +845,17 @@ fn protocol_record_size_breakdown() { }, ); - print_size("ql-wire peer bundle", initiator.bundle().encode().len()); + print_size("ql-wire peer bundle", initiator.bundle().encode_vec().len()); print_size("ql-wire mlkem public key", MlKemPublicKey::SIZE); print_size("ql-wire mlkem ciphertext", MlKemCiphertext::SIZE); - print_size("ql-wire pq ik1", ik1.encode().len()); - print_size("ql-wire pq ik2", ik2.encode().len()); - print_size("ql-wire pq kk1", kk1.encode().len()); - print_size("ql-wire pq kk2", kk2.encode().len()); - print_size("ql-wire session ping", session_ping.encode().len()); + print_size("ql-wire pq ik1", ik1.encode_vec().len()); + print_size("ql-wire pq ik2", ik2.encode_vec().len()); + print_size("ql-wire pq kk1", kk1.encode_vec().len()); + print_size("ql-wire pq kk2", kk2.encode_vec().len()); + print_size("ql-wire session ping", session_ping.encode_vec().len()); print_size( "ql-wire session stream empty", - session_stream_empty.encode().len(), + session_stream_empty.encode_vec().len(), ); - print_size("ql-wire session close", session_close.encode().len()); + print_size("ql-wire session close", session_close.encode_vec().len()); } diff --git a/ql-wire/src/xid.rs b/ql-wire/src/xid.rs index 60f1f06f..f7500af6 100644 --- a/ql-wire/src/xid.rs +++ b/ql-wire/src/xid.rs @@ -1,4 +1,4 @@ -use crate::{codec, ByteSlice, WireError}; +use crate::{codec, ByteSlice, WireEncode, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] @@ -8,8 +8,18 @@ impl XID { pub const SIZE: usize = 16; } -impl codec::WireParse for XID { - fn parse(reader: &mut codec::Reader) -> Result { - Ok(Self(reader.parse()?)) +impl WireEncode for XID { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for XID { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) } } From 5952d1efb12dcd2c148d50b269b35316e8b0b379 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 11:37:13 -0400 Subject: [PATCH 122/304] ql-fsm + ql-runtime update --- ql-fsm/src/implementation/core.rs | 11 ++++++----- ql-fsm/src/tests/handshake.rs | 6 +++--- ql-fsm/src/tests/mod.rs | 4 ++-- ql-runtime/src/tests/mod.rs | 4 ++-- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 38a6e13d..46d19a53 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -1,7 +1,8 @@ use std::time::{Duration, Instant}; use ql_wire::{ - self as wire, CloseTarget, QlCrypto, SessionCloseCode, StreamCloseCode, StreamId, WireParse, + self as wire, CloseTarget, QlCrypto, SessionCloseCode, StreamCloseCode, StreamId, WireDecode, + WireEncode, }; use crate::{ @@ -21,14 +22,14 @@ pub fn receive( crypto: &impl QlCrypto, mut emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { - let header = wire::RecordHeader::parse_prefix(bytes.as_slice())?; + let header = wire::RecordHeader::decode_bytes(bytes.as_slice())?; match header.record_type { wire::RecordType::Handshake => { - let record = wire::QlHandshakeRecord::parse_bytes(bytes.as_slice())?; + let record = wire::QlHandshakeRecord::decode_exact(bytes.as_slice())?; super::handle_handshake_record(fsm, crypto, &record, &mut emit) } wire::RecordType::Session => { - let record = wire::QlSessionRecord::parse_bytes(&mut bytes[..])?; + let record = wire::QlSessionRecord::decode_exact(&mut bytes[..])?; let state = fsm.state.link.connected_mut_or_err()?; if record.header.connection_id != state.transport.rx_connection_id { return Err(QlFsmError::InvalidPayload); @@ -88,7 +89,7 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { if let Some(record) = fsm.state.handshake.take() { return Some(OutboundWrite { - record: record.encode(), + record: record.encode_vec(), session_write_id: None, }); } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 766c5cf1..7ae59efb 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use ql_wire::{QlHandshakeRecord, WireParse}; +use ql_wire::{QlHandshakeRecord, WireDecode}; use super::*; use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent}; @@ -192,7 +192,7 @@ fn handshake_timeout_drops_single_ik_attempt_without_resend() { harness.connect_ik_a().unwrap(); harness.drain_events_a(); let first = harness.next_outbound_a().unwrap(); - let first = QlHandshakeRecord::parse_bytes(first.as_slice()).unwrap(); + let first = QlHandshakeRecord::decode_exact(first.as_slice()).unwrap(); assert!(matches!(first, ql_wire::QlHandshakeRecord::Ik1(_))); assert!(harness.next_outbound_a().is_none()); @@ -261,7 +261,7 @@ fn simultaneous_ik_and_kk_connect_prefers_ik() { } fn handshake_id(record: &[u8]) -> ql_wire::HandshakeId { - let record = QlHandshakeRecord::parse_bytes(record).unwrap(); + let record = QlHandshakeRecord::decode_exact(record).unwrap(); match record { ql_wire::QlHandshakeRecord::Ik1(message) => message.meta.handshake_id, ql_wire::QlHandshakeRecord::Ik2(message) => message.meta.handshake_id, diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index d8e0d4ed..253877d7 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -13,7 +13,7 @@ use libcrux_ml_kem::mlkem1024; use ql_wire::{ self, generate_identity, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, SessionKey, - TransportParams, WireParse, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + TransportParams, WireDecode, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; use sha2::{Digest, Sha256}; @@ -443,7 +443,7 @@ fn decrypt_record( record: &[u8], session_key: &SessionKey, ) -> (ql_wire::SessionHeader, ql_wire::SessionRecord) { - let record = ql_wire::QlSessionRecord::parse_bytes(record) + let record = ql_wire::QlSessionRecord::decode_exact(record) .unwrap() .into_owned(); let plaintext = diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 8705a0fe..f463fb4b 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -16,7 +16,7 @@ use ql_fsm::PeerStatus; use ql_wire::{ generate_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, RecordType, SessionKey, - WireParse, XID, + WireDecode, XID, }; use sha2::{Digest, Sha256}; use tokio::task::LocalSet; @@ -382,7 +382,7 @@ impl crate::platform::QlPlatform for TestPlatform { } fn is_encrypted_payload(bytes: &[u8]) -> bool { - RecordHeader::parse_prefix(bytes) + RecordHeader::decode_bytes(bytes) .ok() .is_some_and(|header| header.record_type == RecordType::Session) } From b2d6f0531bc046598292921050cf98b1ffbc727a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 12:29:56 -0400 Subject: [PATCH 123/304] ql-fsm: range_set --- ql-fsm/src/session/mod.rs | 10 +- ql-fsm/src/session/range_set.rs | 149 ++++++++++++++ ql-fsm/src/session/stream_rx.rs | 298 ++++++---------------------- ql-fsm/src/session/stream_tx.rs | 333 ++++++++++++++------------------ 4 files changed, 354 insertions(+), 436 deletions(-) create mode 100644 ql-fsm/src/session/range_set.rs diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index c6b94c74..6213ecd8 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod received_records; +pub(crate) mod range_set; pub(crate) mod remote_stream_history; pub(crate) mod state; pub(crate) mod stream_parity; @@ -564,7 +565,8 @@ impl SessionFsm { if matches!(stream.outbound_state, OutboundState::Closed) { continue; } - let Some(candidate) = stream.tx.next_range(max_payload, stream.peer_max_offset) else { + let Some(candidate) = stream.tx.poll_transmit(max_payload, stream.peer_max_offset) + else { continue; }; let offset = @@ -578,7 +580,6 @@ impl SessionFsm { let res = builder.push_stream_data(&frame); assert!(res, "builder has capacity"); - stream.tx.mark_in_flight(candidate); if candidate.fin { stream.outbound_state = OutboundState::Finished; } @@ -732,7 +733,6 @@ impl SessionFsm { | StreamRxError::InconsistentFinalOffset | StreamRxError::FinalOffsetBeforeBufferedData | StreamRxError::BeyondFinalOffset - | StreamRxError::TooManyMissingRanges | StreamRxError::OffsetOverflow, ) => { self.fail_session( @@ -1043,7 +1043,7 @@ fn restore_stream_data(streams: &mut IndexMap, frame: Tra if matches!(stream.outbound_state, OutboundState::Closed) { return; } - stream.tx.mark_lost(StreamTxRange { + stream.tx.retransmit(StreamTxRange { offset: frame.offset, len: frame.len, fin: frame.fin, @@ -1066,7 +1066,7 @@ fn acknowledge_tracked_frame( let stream_id = frame.stream_id; if let Some(stream) = streams.get_mut(&stream_id) { let was_full = stream.send_capacity(stream_send_buffer_size) == 0; - stream.tx.mark_acked(StreamTxRange { + stream.tx.ack(StreamTxRange { offset: frame.offset, len: frame.len, fin: frame.fin, diff --git a/ql-fsm/src/session/range_set.rs b/ql-fsm/src/session/range_set.rs new file mode 100644 index 00000000..ac39f23f --- /dev/null +++ b/ql-fsm/src/session/range_set.rs @@ -0,0 +1,149 @@ +use std::{ + cmp, + collections::BTreeMap, + ops::{ + Bound::{Excluded, Included}, + Range, + }, +}; + +/// A set of `u64` values optimized for long runs and random insert/delete. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct RangeSet(BTreeMap); + +impl RangeSet { + pub fn new() -> Self { + Self::default() + } + + pub fn insert(&mut self, mut x: Range) -> bool { + if x.is_empty() { + return false; + } + + if let Some((start, end)) = self.before(x.start) { + if end >= x.end { + return false; + } else if end >= x.start { + self.0.remove(&start); + x.start = start; + } + } + + while let Some((next_start, next_end)) = self.after(x.start) { + if next_start > x.end { + break; + } + self.0.remove(&next_start); + x.end = cmp::max(next_end, x.end); + } + + self.0.insert(x.start, x.end); + true + } + + pub fn remove(&mut self, x: Range) -> bool { + if x.is_empty() { + return false; + } + + let before = match self.before(x.start) { + Some((start, end)) if end > x.start => { + self.0.remove(&start); + if start < x.start { + self.0.insert(start, x.start); + } + if end > x.end { + self.0.insert(x.end, end); + } + if end >= x.end { + return true; + } + true + } + Some(_) | None => false, + }; + + let mut after = false; + while let Some((start, end)) = self.after(x.start) { + if start >= x.end { + break; + } + after = true; + self.0.remove(&start); + if end > x.end { + self.0.insert(x.end, end); + break; + } + } + + before || after + } + + pub fn min(&self) -> Option { + self.0.first_key_value().map(|(&start, _)| start) + } + + pub fn iter(&self) -> Iter<'_> { + Iter(self.0.iter()) + } + + pub fn peek_min(&self) -> Option> { + let (&start, &end) = self.0.iter().next()?; + Some(start..end) + } + + pub fn pop_min(&mut self) -> Option> { + let result = self.peek_min()?; + self.0.remove(&result.start); + Some(result) + } + + /// find closest range to `x` that begins at or before it + fn before(&self, x: u64) -> Option<(u64, u64)> { + self.0 + .range((Included(0), Included(x))) + .next_back() + .map(|(&start, &end)| (start, end)) + } + + /// find the closest range to `x` that begins after it + fn after(&self, x: u64) -> Option<(u64, u64)> { + self.0 + .range((Excluded(x), Included(u64::MAX))) + .next() + .map(|(&start, &end)| (start, end)) + } +} + +pub struct Iter<'a>(std::collections::btree_map::Iter<'a, u64, u64>); + +impl Iterator for Iter<'_> { + type Item = Range; + + fn next(&mut self) -> Option { + self.0.next().map(|(&start, &end)| start..end) + } +} + +#[cfg(test)] +mod tests { + use super::RangeSet; + + #[test] + fn insert_merges_overlaps() { + let mut set = RangeSet::new(); + assert!(set.insert(10..20)); + assert!(set.insert(30..40)); + assert!(set.insert(15..35)); + assert_eq!(set.iter().collect::>(), vec![10..40]); + } + + #[test] + fn remove_splits_ranges() { + let mut set = RangeSet::new(); + set.insert(10..40); + assert!(set.remove(20..30)); + assert_eq!(set.iter().collect::>(), vec![10..20, 30..40]); + } +} diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index 112e08a1..78386229 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -1,21 +1,17 @@ use std::collections::VecDeque; +use super::range_set::RangeSet; + /// reassembles one stream direction from out-of-order byte ranges. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct StreamRx { +pub struct StreamRx { start_offset: u64, bytes: VecDeque, - missing: MissingRanges, + missing: RangeSet, final_offset: Option, max_buffered: usize, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct MissingRange { - pub start: u64, - pub end: u64, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct InsertOutcome { pub newly_readable_bytes: usize, @@ -29,7 +25,6 @@ pub enum StreamRxError { InconsistentFinalOffset, FinalOffsetBeforeBufferedData, BeyondFinalOffset, - TooManyMissingRanges, } #[derive(Debug, Clone)] @@ -38,7 +33,7 @@ pub struct StreamReadIter<'a> { back: Option<&'a [u8]>, } -impl StreamRx { +impl StreamRx { pub fn new(max_buffered: usize) -> Self { Self::with_start_offset(0, max_buffered) } @@ -47,7 +42,7 @@ impl StreamRx { Self { start_offset, bytes: VecDeque::new(), - missing: MissingRanges::new(), + missing: RangeSet::new(), final_offset: None, max_buffered, } @@ -70,7 +65,7 @@ impl StreamRx { return 0; } - match self.missing.first() { + match self.missing.peek_min() { Some(range) if range.start <= self.start_offset => 0, Some(range) => usize::try_from(range.start - self.start_offset) .expect("readable prefix exceeds usize"), @@ -141,11 +136,11 @@ impl StreamRx { } self.ensure_within_window(end)?; - self.ensure_buffered(end)?; + self.ensure_buffered(end); #[cfg(test)] self.assert_valid_overlap(effective_offset, effective_bytes); self.write_bytes(effective_offset, effective_bytes); - self.subtract_missing_range(effective_offset, end)?; + self.missing.remove(effective_offset..end); Ok(self.insert_outcome(was_complete, old_readable)) } @@ -194,49 +189,25 @@ impl StreamRx { Ok(()) } - fn ensure_buffered(&mut self, end: u64) -> Result<(), StreamRxError> { + fn ensure_buffered(&mut self, end: u64) { let buffered_end = self.buffered_end_offset(); if end <= buffered_end { - return Ok(()); + return; } let additional = usize::try_from(end - buffered_end).expect("buffer growth exceeds usize"); self.bytes.resize(self.bytes.len() + additional, 0); - self.push_missing_range(MissingRange { - start: buffered_end, - end, - }) - } - - fn push_missing_range(&mut self, range: MissingRange) -> Result<(), StreamRxError> { - if range.start >= range.end { - return Ok(()); - } - - if let Some(last) = self.missing.last_mut() { - if last.end >= range.start { - last.end = last.end.max(range.end); - return Ok(()); - } - } - - self.missing.push(range) + self.missing.insert(buffered_end..end); } #[cfg(test)] fn assert_valid_overlap(&self, offset: u64, bytes: &[u8]) { - let mut gap_index = self.first_gap_index_after(offset); - for (index, byte) in bytes.iter().copied().enumerate() { let absolute = offset + index as u64; - - while gap_index < self.missing.len() && self.missing[gap_index].end <= absolute { - gap_index += 1; - } - - let is_missing = gap_index < self.missing.len() - && self.missing[gap_index].start <= absolute - && absolute < self.missing[gap_index].end; + let is_missing = self + .missing + .iter() + .any(|range| range.start <= absolute && absolute < range.end); if is_missing { continue; } @@ -268,75 +239,6 @@ impl StreamRx { back[..bytes.len() - front_len].copy_from_slice(&bytes[front_len..]); } } - - fn subtract_missing_range(&mut self, start: u64, end: u64) -> Result<(), StreamRxError> { - let first = self.first_gap_index_after(start); - if first == self.missing.len() || self.missing[first].start >= end { - return Ok(()); - } - - let mut last_exclusive = first; - while last_exclusive < self.missing.len() && self.missing[last_exclusive].start < end { - last_exclusive += 1; - } - - let last = last_exclusive - 1; - let keep_left = self.missing[first].start < start; - let keep_right = self.missing[last].end > end; - - if first == last { - let original = self.missing[first]; - match (keep_left, keep_right) { - (true, true) => { - self.missing[first].end = start; - self.missing.insert( - first + 1, - MissingRange { - start: end, - end: original.end, - }, - )?; - } - (true, false) => { - self.missing[first].end = start; - } - (false, true) => { - self.missing[first].start = end; - } - (false, false) => { - self.missing.remove(first); - } - } - return Ok(()); - } - - match (keep_left, keep_right) { - (true, true) => { - self.missing[first].end = start; - self.missing[last].start = end; - self.missing.drain(first + 1..last); - } - (true, false) => { - self.missing[first].end = start; - self.missing.drain(first + 1..last_exclusive); - } - (false, true) => { - self.missing[last].start = end; - self.missing.drain(first..last); - } - (false, false) => { - self.missing.drain(first..last_exclusive); - } - } - - Ok(()) - } - - fn first_gap_index_after(&self, offset: u64) -> usize { - self.missing - .as_slice() - .partition_point(|range| range.end <= offset) - } } impl<'a> Iterator for StreamReadIter<'a> { @@ -359,108 +261,9 @@ impl<'a> Iterator for StreamReadIter<'a> { } } -#[derive(Debug, Clone, PartialEq, Eq)] -struct MissingRanges { - ranges: [MissingRange; N], - len: usize, -} - -impl MissingRanges { - fn new() -> Self { - Self { - ranges: [MissingRange { start: 0, end: 0 }; N], - len: 0, - } - } - - fn as_slice(&self) -> &[MissingRange] { - &self.ranges[..self.len] - } - - fn is_empty(&self) -> bool { - self.len == 0 - } - - fn len(&self) -> usize { - self.len - } - - fn first(&self) -> Option<&MissingRange> { - self.as_slice().first() - } - - fn last_mut(&mut self) -> Option<&mut MissingRange> { - if self.len == 0 { - None - } else { - Some(&mut self.ranges[self.len - 1]) - } - } - - fn push(&mut self, range: MissingRange) -> Result<(), StreamRxError> { - if self.len == N { - return Err(StreamRxError::TooManyMissingRanges); - } - self.ranges[self.len] = range; - self.len += 1; - Ok(()) - } - - fn insert(&mut self, index: usize, range: MissingRange) -> Result<(), StreamRxError> { - if self.len == N { - return Err(StreamRxError::TooManyMissingRanges); - } - for i in (index..self.len).rev() { - self.ranges[i + 1] = self.ranges[i]; - } - self.ranges[index] = range; - self.len += 1; - Ok(()) - } - - fn remove(&mut self, index: usize) -> MissingRange { - let removed = self.ranges[index]; - for i in index + 1..self.len { - self.ranges[i - 1] = self.ranges[i]; - } - self.len -= 1; - self.ranges[self.len] = MissingRange { start: 0, end: 0 }; - removed - } - - fn drain(&mut self, range: std::ops::Range) { - let count = range.end - range.start; - if count == 0 { - return; - } - - for i in range.end..self.len { - self.ranges[i - count] = self.ranges[i]; - } - for i in self.len - count..self.len { - self.ranges[i] = MissingRange { start: 0, end: 0 }; - } - self.len -= count; - } -} - -impl std::ops::Index for MissingRanges { - type Output = MissingRange; - - fn index(&self, index: usize) -> &Self::Output { - &self.as_slice()[index] - } -} - -impl std::ops::IndexMut for MissingRanges { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - &mut self.ranges[index] - } -} - #[cfg(test)] mod tests { - use super::{InsertOutcome, MissingRange, StreamRx, StreamRxError}; + use super::{InsertOutcome, StreamRx, StreamRxError}; pub fn copy_readable(rx: &StreamRx) -> Vec { let readable = rx.readable_len(); @@ -473,7 +276,7 @@ mod tests { #[test] fn contiguous_insert_becomes_readable_and_complete() { - let mut rx = StreamRx::<8>::new(64); + let mut rx = StreamRx::new(64); let outcome = rx.insert(0, true, b"hello").unwrap(); @@ -493,7 +296,7 @@ mod tests { #[test] fn out_of_order_insert_tracks_missing_ranges_until_gap_is_filled() { - let mut rx = StreamRx::<8>::new(64); + let mut rx = StreamRx::new(64); let first = rx.insert(5, true, b" world").unwrap(); assert_eq!( @@ -503,7 +306,7 @@ mod tests { became_complete: false, } ); - assert_eq!(rx.missing.as_slice(), &[MissingRange { start: 0, end: 5 }]); + assert_eq!(rx.missing.iter().collect::>(), vec![0..5]); assert_eq!(rx.readable_len(), 0); let second = rx.insert(0, false, b"hello").unwrap(); @@ -521,7 +324,7 @@ mod tests { #[test] fn duplicate_insert_is_ignored_if_bytes_match() { - let mut rx = StreamRx::<8>::new(64); + let mut rx = StreamRx::new(64); rx.insert(0, false, b"hello").unwrap(); let duplicate = rx.insert(0, false, b"hello").unwrap(); @@ -539,7 +342,7 @@ mod tests { #[test] #[should_panic(expected = "conflicting overlap at stream offset 3")] fn conflicting_overlap_panics_in_test_builds() { - let mut rx = StreamRx::<8>::new(64); + let mut rx = StreamRx::new(64); rx.insert(0, false, b"abcdef").unwrap(); rx.insert(3, false, b"xyz").unwrap(); @@ -547,7 +350,7 @@ mod tests { #[test] fn consume_advances_start_offset_and_trims_old_prefix() { - let mut rx = StreamRx::<8>::new(64); + let mut rx = StreamRx::new(64); rx.insert(0, false, b"abcd").unwrap(); rx.consume(2); @@ -567,32 +370,15 @@ mod tests { assert!(rx.is_complete()); } - #[test] - fn insert_rejects_when_missing_range_budget_is_exhausted() { - let mut rx = StreamRx::<2>::new(64); - - rx.insert(1, false, b"a").unwrap(); - rx.insert(3, false, b"b").unwrap(); - let error = rx.insert(5, false, b"c").unwrap_err(); - - assert_eq!(error, StreamRxError::TooManyMissingRanges); - } - #[test] fn insert_can_fill_multiple_gaps_without_rebuilding_state() { - let mut rx = StreamRx::<8>::new(64); + let mut rx = StreamRx::new(64); rx.insert(0, false, b"ab").unwrap(); rx.insert(4, false, b"ef").unwrap(); rx.insert(8, true, b"ij").unwrap(); - assert_eq!( - rx.missing.as_slice(), - &[ - MissingRange { start: 2, end: 4 }, - MissingRange { start: 6, end: 8 }, - ] - ); + assert_eq!(rx.missing.iter().collect::>(), vec![2..4, 6..8]); let outcome = rx.insert(2, false, b"cdefgh").unwrap(); @@ -608,4 +394,38 @@ mod tests { assert_eq!(copy_readable(&rx), b"abcdefghij"); assert!(rx.is_complete()); } + + #[test] + fn heavily_fragmented_inserts_stay_valid() { + let mut rx = StreamRx::new(64); + + rx.insert(1, false, b"b").unwrap(); + rx.insert(3, false, b"d").unwrap(); + rx.insert(5, false, b"f").unwrap(); + rx.insert(7, false, b"h").unwrap(); + rx.insert(9, true, b"j").unwrap(); + + assert_eq!( + rx.missing.iter().collect::>(), + vec![0..1, 2..3, 4..5, 6..7, 8..9] + ); + + let outcome = rx.insert(0, false, b"abcdefghi").unwrap(); + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 10, + became_complete: true, + } + ); + assert_eq!(copy_readable(&rx), b"abcdefghij"); + assert!(rx.is_complete()); + } + + #[test] + fn out_of_window_insert_is_rejected() { + let mut rx = StreamRx::new(4); + let error = rx.insert(5, false, b"a").unwrap_err(); + assert_eq!(error, StreamRxError::OutOfWindow); + } } diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 8b514d0b..67e34e7a 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -1,12 +1,16 @@ -use std::collections::VecDeque; +use std::{collections::VecDeque, ops::Range}; use ql_wire::RangedByteChunks; +use super::range_set::RangeSet; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamTx { bytes: VecDeque, base_offset: u64, - segments: VecDeque, + unsent: u64, + acked: RangeSet, + retransmits: RangeSet, final_offset: Option, } @@ -16,23 +20,10 @@ struct TrackedFinalOffset { state: SendState, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -struct SendSegment { - offset: u64, - len: usize, - state: SendState, -} - -impl SendSegment { - fn end_offset(&self) -> u64 { - self.offset + self.len as u64 - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum SendState { Unsent, - InFlight, + Sent, Lost, Acked, } @@ -49,7 +40,9 @@ impl StreamTx { Self { bytes: VecDeque::new(), base_offset: 0, - segments: VecDeque::new(), + unsent: 0, + acked: RangeSet::new(), + retransmits: RangeSet::new(), final_offset: None, } } @@ -63,7 +56,7 @@ impl StreamTx { } pub fn is_empty(&self) -> bool { - self.bytes.is_empty() && self.segments.is_empty() && self.final_offset.is_none() + self.bytes.is_empty() && self.final_offset.is_none() } pub fn append(&mut self, bytes: &[u8]) { @@ -71,20 +64,7 @@ impl StreamTx { return; } - let start = self.end_offset(); self.bytes.extend(bytes); - if let Some(last) = self.segments.back_mut() { - if last.state == SendState::Unsent && last.end_offset() == start { - last.len += bytes.len(); - return; - } - } - - self.segments.push_back(SendSegment { - offset: start, - len: bytes.len(), - state: SendState::Unsent, - }); } pub fn queue_fin(&mut self) { @@ -94,47 +74,50 @@ impl StreamTx { }); } - pub fn next_range(&self, max_payload: usize, peer_max_offset: u64) -> Option { - let mut unsent = None; - for segment in &self.segments { - if !matches!(segment.state, SendState::Lost | SendState::Unsent) { - continue; - } - - let credit_remaining = peer_max_offset.saturating_sub(segment.offset); - let credit_remaining = usize::try_from(credit_remaining).unwrap_or(usize::MAX); - let len = segment.len.min(max_payload).min(credit_remaining); - if len == 0 { - continue; - } - - let fin = self.final_offset.is_some_and(|final_offset| { - matches!(final_offset.state, SendState::Lost | SendState::Unsent) - && final_offset.offset == segment.offset + len as u64 - }); - let range = StreamTxRange { - offset: segment.offset, - len, - fin, - }; - - if segment.state == SendState::Lost { - return Some(range); - } - if unsent.is_none() { - unsent = Some(range); + pub fn poll_transmit( + &mut self, + max_payload: usize, + peer_max_offset: u64, + ) -> Option { + if let Some(range) = self.retransmits.peek_min() { + let end = range + .end + .min(range.start.saturating_add(max_payload as u64)) + .min(peer_max_offset); + if end > range.start { + let range = self.retransmits.pop_min().unwrap(); + if end < range.end { + self.retransmits.insert(end..range.end); + } + return Some(StreamTxRange { + offset: range.start, + len: usize::try_from(end - range.start).unwrap(), + fin: self.poll_fin(end), + }); } } - if let Some(range) = unsent { - return Some(range); + if self.unsent < self.end_offset() { + let end = self + .end_offset() + .min(self.unsent.saturating_add(max_payload as u64)) + .min(peer_max_offset); + if end > self.unsent { + let start = self.unsent; + self.unsent = end; + return Some(StreamTxRange { + offset: start, + len: usize::try_from(end - start).unwrap(), + fin: self.poll_fin(end), + }); + } } let final_offset = self.final_offset.filter(|final_offset| { matches!(final_offset.state, SendState::Lost | SendState::Unsent) && final_offset.offset <= peer_max_offset })?; - + self.final_offset.as_mut().unwrap().state = SendState::Sent; Some(StreamTxRange { offset: final_offset.offset, len: 0, @@ -151,121 +134,113 @@ impl StreamTx { } } - pub fn mark_in_flight(&mut self, range: StreamTxRange) { - self.set_segment_state(range.offset, range.len, SendState::InFlight); - if range.fin { - if let Some(final_offset) = self.final_offset.as_mut() { - if final_offset.state != SendState::Acked { - final_offset.state = SendState::InFlight; - } - } + pub fn retransmit(&mut self, range: StreamTxRange) { + if let Some(range) = self.clamp_sent_range(range.offset, range.len) { + Self::insert_not_acked(&self.acked, &mut self.retransmits, range); } - } - - pub fn mark_lost(&mut self, range: StreamTxRange) { - self.set_segment_state(range.offset, range.len, SendState::Lost); if range.fin { - if let Some(final_offset) = self.final_offset.as_mut() { - if final_offset.state != SendState::Acked { - final_offset.state = SendState::Lost; - } - } + self.mark_fin_lost(); } } - pub fn mark_acked(&mut self, range: StreamTxRange) { - self.set_segment_state(range.offset, range.len, SendState::Acked); + pub fn ack(&mut self, range: StreamTxRange) { + if let Some(range) = self.clamp_buffered_range(range.offset, range.len) { + self.acked.insert(range.clone()); + self.retransmits.remove(range); + self.trim_acked_prefix(); + } if range.fin { if let Some(final_offset) = self.final_offset.as_mut() { final_offset.state = SendState::Acked; } } - self.trim_acked_prefix(); + self.trim_acked_fin(); } pub fn clear(&mut self) { self.bytes.clear(); - self.segments.clear(); + self.unsent = self.base_offset; + self.acked = RangeSet::new(); + self.retransmits = RangeSet::new(); self.final_offset = None; } - fn set_segment_state(&mut self, offset: u64, len: usize, state: SendState) { + fn clamp_buffered_range(&self, offset: u64, len: usize) -> Option> { if len == 0 { - return; + return None; } - let end = offset + len as u64; - - let Some(index) = self - .segments - .iter() - .position(|segment| segment.offset <= offset && end <= segment.end_offset()) - else { - return; - }; + let start = offset.max(self.base_offset); + let end = offset.saturating_add(len as u64).min(self.end_offset()); + (start < end).then_some(start..end) + } - if self.segments[index].state == SendState::Acked && state != SendState::Acked { - return; + fn clamp_sent_range(&self, offset: u64, len: usize) -> Option> { + if len == 0 { + return None; } + let start = offset.max(self.base_offset); + let end = offset.saturating_add(len as u64).min(self.unsent); + (start < end).then_some(start..end) + } - let segment = self.segments.remove(index).unwrap(); - let mut insert_index = index; - - if segment.offset < offset { - self.segments.insert( - insert_index, - SendSegment { - offset: segment.offset, - len: usize::try_from(offset - segment.offset).unwrap(), - state: segment.state, - }, - ); - insert_index += 1; + fn insert_not_acked(acked_set: &RangeSet, target: &mut RangeSet, range: Range) { + let mut cursor = range.start; + for acked in acked_set.iter() { + if acked.end <= cursor { + continue; + } + if acked.start >= range.end { + break; + } + if cursor < acked.start { + target.insert(cursor..acked.start.min(range.end)); + } + cursor = cursor.max(acked.end); + if cursor >= range.end { + break; + } } - - self.segments - .insert(insert_index, SendSegment { offset, len, state }); - insert_index += 1; - - if end < segment.end_offset() { - self.segments.insert( - insert_index, - SendSegment { - offset: end, - len: usize::try_from(segment.end_offset() - end).unwrap(), - state: segment.state, - }, - ); + if cursor < range.end { + target.insert(cursor..range.end); } + } - self.merge_adjacent_segments(); + fn poll_fin(&mut self, offset: u64) -> bool { + let Some(final_offset) = self.final_offset.as_mut() else { + return false; + }; + if matches!(final_offset.state, SendState::Lost | SendState::Unsent) + && final_offset.offset == offset + { + final_offset.state = SendState::Sent; + true + } else { + false + } } - fn merge_adjacent_segments(&mut self) { - let mut index = 1; - while index < self.segments.len() { - let prev = self.segments[index - 1]; - let next = self.segments[index]; - if prev.state == next.state && prev.end_offset() == next.offset { - self.segments[index - 1].len += next.len; - self.segments.remove(index); - } else { - index += 1; + fn mark_fin_lost(&mut self) { + if let Some(final_offset) = self.final_offset.as_mut() { + if final_offset.state != SendState::Acked { + final_offset.state = SendState::Lost; } } } fn trim_acked_prefix(&mut self) { - while matches!( - self.segments.front(), - Some(segment) if segment.state == SendState::Acked - ) { - let len = self.segments.pop_front().unwrap().len; + while self.acked.min() == Some(self.base_offset) { + let prefix = self.acked.pop_min().unwrap(); + let len = usize::try_from(prefix.end - prefix.start).unwrap(); self.bytes.drain(..len); - self.base_offset = self.base_offset.saturating_add(len as u64); + self.base_offset = prefix.end; } + } + fn trim_acked_fin(&mut self) { if self.final_offset.is_some_and(|final_offset| { - final_offset.state == SendState::Acked && final_offset.offset == self.base_offset + final_offset.state == SendState::Acked + && final_offset.offset == self.base_offset + && self.bytes.is_empty() }) { self.final_offset = None; } @@ -283,7 +258,7 @@ mod tests { tx.append(b"de"); assert_eq!( - tx.next_range(8, u64::MAX), + tx.poll_transmit(8, u64::MAX), Some(StreamTxRange { offset: 0, len: 5, @@ -297,12 +272,11 @@ mod tests { let mut tx = StreamTx::new(); tx.append(b"abcdef"); - let first = tx.next_range(3, u64::MAX).unwrap(); - tx.mark_in_flight(first); - tx.mark_lost(first); + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); assert_eq!( - tx.next_range(3, u64::MAX), + tx.poll_transmit(3, u64::MAX), Some(StreamTxRange { offset: 0, len: 3, @@ -316,12 +290,11 @@ mod tests { let mut tx = StreamTx::new(); tx.append(b"abcdef"); - let first = tx.next_range(3, u64::MAX).unwrap(); - tx.mark_in_flight(first); - tx.mark_acked(first); + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.ack(first); assert_eq!( - tx.next_range(3, u64::MAX), + tx.poll_transmit(3, u64::MAX), Some(StreamTxRange { offset: 3, len: 3, @@ -335,7 +308,7 @@ mod tests { let mut tx = StreamTx::new(); tx.queue_fin(); - let range = tx.next_range(16, u64::MAX).unwrap(); + let range = tx.poll_transmit(16, u64::MAX).unwrap(); assert_eq!( range, StreamTxRange { @@ -345,8 +318,7 @@ mod tests { } ); - tx.mark_in_flight(range); - tx.mark_acked(range); + tx.ack(range); assert!(tx.is_empty()); } @@ -355,21 +327,14 @@ mod tests { let mut tx = StreamTx::new(); tx.append(b"abcdefghijkl"); - let first = tx.next_range(4, u64::MAX).unwrap(); - tx.mark_in_flight(first); - let second = tx.next_range(4, u64::MAX).unwrap(); - tx.mark_in_flight(second); - let third = tx.next_range(4, u64::MAX).unwrap(); - tx.mark_in_flight(third); - - tx.mark_lost(StreamTxRange { - offset: 4, - len: 4, - fin: false, - }); + let _first = tx.poll_transmit(4, u64::MAX).unwrap(); + let second = tx.poll_transmit(4, u64::MAX).unwrap(); + let _third = tx.poll_transmit(4, u64::MAX).unwrap(); + + tx.retransmit(second); assert_eq!( - tx.next_range(4, u64::MAX), + tx.poll_transmit(4, u64::MAX), Some(StreamTxRange { offset: 4, len: 4, @@ -383,33 +348,17 @@ mod tests { let mut tx = StreamTx::new(); tx.append(b"abcdefghijklmnop"); - let first = tx.next_range(4, u64::MAX).unwrap(); - tx.mark_in_flight(first); - let second = tx.next_range(4, u64::MAX).unwrap(); - tx.mark_in_flight(second); - let third = tx.next_range(4, u64::MAX).unwrap(); - tx.mark_in_flight(third); - let fourth = tx.next_range(4, u64::MAX).unwrap(); - tx.mark_in_flight(fourth); - - tx.mark_acked(StreamTxRange { - offset: 4, - len: 4, - fin: false, - }); - tx.mark_lost(StreamTxRange { - offset: 4, - len: 4, - fin: false, - }); - tx.mark_lost(StreamTxRange { - offset: 8, - len: 4, - fin: false, - }); + let _first = tx.poll_transmit(4, u64::MAX).unwrap(); + let second = tx.poll_transmit(4, u64::MAX).unwrap(); + let third = tx.poll_transmit(4, u64::MAX).unwrap(); + let _fourth = tx.poll_transmit(4, u64::MAX).unwrap(); + + tx.ack(second); + tx.retransmit(second); + tx.retransmit(third); assert_eq!( - tx.next_range(4, u64::MAX), + tx.poll_transmit(4, u64::MAX), Some(StreamTxRange { offset: 8, len: 4, From d4a33361c61abb7d63e92347d31ce0c56d164b9b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 13:10:35 -0400 Subject: [PATCH 124/304] ql-fsm: stream_tx with byte chunks --- Cargo.lock | 1 + ql-fsm/Cargo.toml | 1 + ql-fsm/src/implementation/core.rs | 3 +- ql-fsm/src/lib.rs | 9 ++- ql-fsm/src/session/mod.rs | 12 ++- ql-fsm/src/session/range_set.rs | 4 + ql-fsm/src/session/stream_tx.rs | 125 +++++++++++++++++++++++++----- ql-fsm/src/session/tests.rs | 24 +++--- ql-fsm/src/tests/proptest.rs | 7 +- ql-fsm/src/tests/session.rs | 32 +++++--- ql-runtime/src/driver/mod.rs | 4 +- ql-wire/src/bytes.rs | 95 +---------------------- 12 files changed, 173 insertions(+), 144 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c9bc20e7..d719520e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2183,6 +2183,7 @@ dependencies = [ name = "ql-fsm" version = "0.1.0" dependencies = [ + "bytes", "indexmap", "libcrux-aesgcm", "libcrux-ml-kem", diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml index 14b6c050..45ccca28 100644 --- a/ql-fsm/Cargo.toml +++ b/ql-fsm/Cargo.toml @@ -6,6 +6,7 @@ description = "Quantum Link synchronous finite state machine" license = "Proprietary" [dependencies] +bytes = "1" indexmap = "2" ql-wire = { path = "../ql-wire" } diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 46d19a53..c6919f32 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -1,5 +1,6 @@ use std::time::{Duration, Instant}; +use bytes::Bytes; use ql_wire::{ self as wire, CloseTarget, QlCrypto, SessionCloseCode, StreamCloseCode, StreamId, WireDecode, WireEncode, @@ -140,7 +141,7 @@ pub fn open_stream(fsm: &mut QlFsm) -> Result { pub fn write_stream( fsm: &mut QlFsm, stream_id: StreamId, - bytes: &[u8], + bytes: &mut Bytes, ) -> Result { let state = fsm.state.link.connected_mut_or_err()?; Ok(state.session.write_stream(stream_id, bytes)?) diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 29bbe1b5..47015351 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -27,6 +27,7 @@ mod tests; use std::time::{Duration, Instant}; +pub use bytes::Bytes; pub use error::QlFsmError; use ql_wire::{ CloseTarget, PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, @@ -253,8 +254,12 @@ impl QlFsm { implementation::open_stream(self) } - /// queues bytes for an open stream and returns the accepted count - pub fn write_stream(&mut self, stream_id: StreamId, bytes: &[u8]) -> Result { + /// queues owned bytes for an open stream and returns the accepted count + pub fn write_stream( + &mut self, + stream_id: StreamId, + bytes: &mut Bytes, + ) -> Result { implementation::write_stream(self, stream_id, bytes) } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 6213ecd8..9170307d 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,5 +1,5 @@ -pub(crate) mod received_records; pub(crate) mod range_set; +pub(crate) mod received_records; pub(crate) mod remote_stream_history; pub(crate) mod state; pub(crate) mod stream_parity; @@ -12,6 +12,7 @@ mod tests; use std::time::{Duration, Instant}; +use bytes::Bytes; use indexmap::{map::Entry, IndexMap}; use ql_wire::{ CloseTarget, RecordAck, RecordSeq, SessionClose, SessionCloseCode, SessionFrame, @@ -151,8 +152,11 @@ impl SessionFsm { pub fn write_stream( &mut self, stream_id: StreamId, - bytes: &[u8], + bytes: &mut Bytes, ) -> Result { + // TODO: consider a `BytesSource` abstraction here so callers can provide + // different chunk sources while preserving partial-accept semantics and deferring any + // required copying until capacity is known self.ensure_session_open()?; let stream = self .state @@ -166,7 +170,9 @@ impl SessionFsm { let accepted = bytes .len() .min(stream.send_capacity(self.config.stream_send_buffer_size)); - stream.tx.append(&bytes[..accepted]); + if accepted > 0 { + stream.tx.append(bytes.split_to(accepted)); + } Ok(accepted) } diff --git a/ql-fsm/src/session/range_set.rs b/ql-fsm/src/session/range_set.rs index ac39f23f..4ef6bdbd 100644 --- a/ql-fsm/src/session/range_set.rs +++ b/ql-fsm/src/session/range_set.rs @@ -84,6 +84,10 @@ impl RangeSet { self.0.first_key_value().map(|(&start, _)| start) } + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + pub fn iter(&self) -> Iter<'_> { Iter(self.0.iter()) } diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 67e34e7a..81ea6575 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -1,12 +1,14 @@ use std::{collections::VecDeque, ops::Range}; -use ql_wire::RangedByteChunks; +use bytes::{Buf, Bytes}; +use ql_wire::ByteChunks; use super::range_set::RangeSet; #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamTx { - bytes: VecDeque, + chunks: VecDeque, + buffered_len: usize, base_offset: u64, unsent: u64, acked: RangeSet, @@ -35,10 +37,74 @@ pub struct StreamTxRange { pub fin: bool, } +#[derive(Debug, Clone, Copy)] +pub struct StreamTxBytes<'a> { + inner: &'a VecDeque, + offset: usize, + len: usize, +} + +pub struct StreamTxBytesIter<'a> { + inner: std::collections::vec_deque::Iter<'a, Bytes>, + skip: usize, + remaining: usize, +} + +impl ByteChunks for StreamTxBytes<'_> { + type Chunks<'a> + = StreamTxBytesIter<'a> + where + Self: 'a; + + fn len(&self) -> usize { + self.inner + .iter() + .map(Bytes::len) + .sum::() + .saturating_sub(self.offset) + .min(self.len) + } + + fn chunks(&self) -> Self::Chunks<'_> { + StreamTxBytesIter { + inner: self.inner.iter(), + skip: self.offset, + remaining: self.len(), + } + } +} + +impl<'a> Iterator for StreamTxBytesIter<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + while self.remaining > 0 { + let chunk = self.inner.next()?; + if self.skip >= chunk.len() { + self.skip -= chunk.len(); + continue; + } + + let chunk = &chunk[self.skip..]; + self.skip = 0; + if chunk.is_empty() { + continue; + } + + let len = chunk.len().min(self.remaining); + self.remaining -= len; + return Some(&chunk[..len]); + } + + None + } +} + impl StreamTx { pub fn new() -> Self { Self { - bytes: VecDeque::new(), + chunks: VecDeque::new(), + buffered_len: 0, base_offset: 0, unsent: 0, acked: RangeSet::new(), @@ -48,23 +114,24 @@ impl StreamTx { } pub fn buffered_len(&self) -> usize { - self.bytes.len() + self.buffered_len } pub fn end_offset(&self) -> u64 { - self.base_offset + self.bytes.len() as u64 + self.base_offset + self.buffered_len as u64 } pub fn is_empty(&self) -> bool { - self.bytes.is_empty() && self.final_offset.is_none() + self.buffered_len == 0 && self.final_offset.is_none() } - pub fn append(&mut self, bytes: &[u8]) { + pub fn append(&mut self, bytes: Bytes) { if bytes.is_empty() { return; } - self.bytes.extend(bytes); + self.buffered_len += bytes.len(); + self.chunks.push_back(bytes); } pub fn queue_fin(&mut self) { @@ -125,10 +192,10 @@ impl StreamTx { }) } - pub fn ranged_bytes(&self, range: StreamTxRange) -> RangedByteChunks<&VecDeque> { + pub fn ranged_bytes(&self, range: StreamTxRange) -> StreamTxBytes<'_> { let offset = usize::try_from(range.offset - self.base_offset).unwrap(); - RangedByteChunks { - inner: &self.bytes, + StreamTxBytes { + inner: &self.chunks, offset, len: range.len, } @@ -158,7 +225,8 @@ impl StreamTx { } pub fn clear(&mut self) { - self.bytes.clear(); + self.chunks.clear(); + self.buffered_len = 0; self.unsent = self.base_offset; self.acked = RangeSet::new(); self.retransmits = RangeSet::new(); @@ -230,8 +298,21 @@ impl StreamTx { fn trim_acked_prefix(&mut self) { while self.acked.min() == Some(self.base_offset) { let prefix = self.acked.pop_min().unwrap(); - let len = usize::try_from(prefix.end - prefix.start).unwrap(); - self.bytes.drain(..len); + let mut to_advance = usize::try_from(prefix.end - prefix.start).unwrap(); + self.buffered_len -= to_advance; + while to_advance > 0 { + let front = self + .chunks + .front_mut() + .expect("expected buffered chunks for acked prefix"); + if front.len() <= to_advance { + to_advance -= front.len(); + self.chunks.pop_front(); + } else { + front.advance(to_advance); + to_advance = 0; + } + } self.base_offset = prefix.end; } } @@ -240,7 +321,7 @@ impl StreamTx { if self.final_offset.is_some_and(|final_offset| { final_offset.state == SendState::Acked && final_offset.offset == self.base_offset - && self.bytes.is_empty() + && self.buffered_len == 0 }) { self.final_offset = None; } @@ -249,13 +330,15 @@ impl StreamTx { #[cfg(test)] mod tests { + use bytes::Bytes; + use super::{StreamTx, StreamTxRange}; #[test] fn append_tracks_unsent_tail() { let mut tx = StreamTx::new(); - tx.append(b"abc"); - tx.append(b"de"); + tx.append(Bytes::from_static(b"abc")); + tx.append(Bytes::from_static(b"de")); assert_eq!( tx.poll_transmit(8, u64::MAX), @@ -270,7 +353,7 @@ mod tests { #[test] fn lost_range_is_selected_before_unsent_tail() { let mut tx = StreamTx::new(); - tx.append(b"abcdef"); + tx.append(Bytes::from_static(b"abcdef")); let first = tx.poll_transmit(3, u64::MAX).unwrap(); tx.retransmit(first); @@ -288,7 +371,7 @@ mod tests { #[test] fn acked_prefix_is_trimmed() { let mut tx = StreamTx::new(); - tx.append(b"abcdef"); + tx.append(Bytes::from_static(b"abcdef")); let first = tx.poll_transmit(3, u64::MAX).unwrap(); tx.ack(first); @@ -325,7 +408,7 @@ mod tests { #[test] fn subrange_updates_split_merged_in_flight_segments() { let mut tx = StreamTx::new(); - tx.append(b"abcdefghijkl"); + tx.append(Bytes::from_static(b"abcdefghijkl")); let _first = tx.poll_transmit(4, u64::MAX).unwrap(); let second = tx.poll_transmit(4, u64::MAX).unwrap(); @@ -346,7 +429,7 @@ mod tests { #[test] fn acked_subrange_is_not_reopened_by_stale_timeout() { let mut tx = StreamTx::new(); - tx.append(b"abcdefghijklmnop"); + tx.append(Bytes::from_static(b"abcdefghijklmnop")); let _first = tx.poll_transmit(4, u64::MAX).unwrap(); let second = tx.poll_transmit(4, u64::MAX).unwrap(); diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 239bf017..ccd9be3e 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -1,5 +1,6 @@ use std::time::{Duration, Instant}; +use bytes::Bytes; use ql_wire::{ CloseTarget, RecordAck, RecordSeq, SessionFrame, SessionRecord, SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamId, VarInt, XID, @@ -20,6 +21,11 @@ fn offset(value: u64) -> VarInt { VarInt::from_u64(value).unwrap() } +fn write_stream_bytes(fsm: &mut SessionFsm, stream_id: StreamId, bytes: &[u8]) -> usize { + let mut bytes = Bytes::copy_from_slice(bytes); + fsm.write_stream(stream_id, &mut bytes).unwrap() +} + fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); loop { @@ -70,10 +76,10 @@ fn outbound_record_seq_increments_monotonically() { let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); - assert_eq!(fsm.write_stream(stream_id, b"one").unwrap(), 3); + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"one"), 3); let (first_seq, _) = next_outbound(&mut fsm, now).unwrap(); - assert_eq!(fsm.write_stream(stream_id, b"two").unwrap(), 3); + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"two"), 3); let (second_seq, _) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); assert_eq!(first_seq, seq(0)); @@ -86,7 +92,7 @@ fn retransmit_uses_new_record_seq() { let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = fsm.open_stream().unwrap(); - assert_eq!(fsm.write_stream(stream_id, b"retry").unwrap(), 5); + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"retry"), 5); let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); fsm.on_timer(now + Duration::from_millis(200), |_| {}); @@ -111,8 +117,8 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let payload_a = vec![b'a'; 40]; let payload_b = vec![b'b'; 40]; - assert_eq!(fsm.write_stream(stream_id_a, &payload_a).unwrap(), 40); - assert_eq!(fsm.write_stream(stream_id_b, &payload_b).unwrap(), 40); + assert_eq!(write_stream_bytes(&mut fsm, stream_id_a, &payload_a), 40); + assert_eq!(write_stream_bytes(&mut fsm, stream_id_b, &payload_b), 40); let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); let (second_seq, _second) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); @@ -121,7 +127,7 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { |frame| matches!(frame, SessionFrame::StreamData(frame) if frame.stream_id == stream_id_a) )); - assert_eq!(fsm.write_stream(stream_id_b, b"b-2").unwrap(), 3); + assert_eq!(write_stream_bytes(&mut fsm, stream_id_b, b"b-2"), 3); let (_third_seq, third) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); let stream_ids: Vec<_> = third @@ -147,7 +153,7 @@ fn ack_reopens_write_capacity() { ); let stream_id = fsm.open_stream().unwrap(); - assert_eq!(fsm.write_stream(stream_id, b"abcd").unwrap(), 4); + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"abcd"), 4); let (record_seq, _record) = next_outbound(&mut fsm, now).unwrap(); let mut events = Vec::new(); @@ -162,7 +168,7 @@ fn ack_reopens_write_capacity() { ); assert!(events.contains(&SessionEvent::Writable(stream_id))); - assert_eq!(fsm.write_stream(stream_id, b"z").unwrap(), 1); + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"z"), 1); } #[test] @@ -487,7 +493,7 @@ fn initial_peer_stream_receive_window_limits_first_send() { ); let stream_id = fsm.open_stream().unwrap(); - assert_eq!(fsm.write_stream(stream_id, b"hello").unwrap(), 5); + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"hello"), 5); let (_first_seq, first) = next_outbound(&mut fsm, now).unwrap(); assert!(matches!( first.frames.as_slice(), diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index c7671a89..55b1064b 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -3,6 +3,7 @@ use std::{ time::Duration, }; +use bytes::Bytes; use ::proptest::{collection::vec, prelude::*, test_runner::TestCaseResult}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; @@ -313,7 +314,8 @@ impl Runner { } Action::WriteA { slot, bytes } => { if let Some(stream_id) = self.slots_a[*slot] { - if let Ok(accepted) = self.harness.a.fsm.write_stream(stream_id, bytes) { + let mut chunk = Bytes::from(bytes.clone()); + if let Ok(accepted) = self.harness.a.fsm.write_stream(stream_id, &mut chunk) { self.expected_at_b .entry(stream_id) .or_default() @@ -323,7 +325,8 @@ impl Runner { } Action::WriteB { slot, bytes } => { if let Some(stream_id) = self.slots_b[*slot] { - if let Ok(accepted) = self.harness.b.fsm.write_stream(stream_id, bytes) { + let mut chunk = Bytes::from(bytes.clone()); + if let Ok(accepted) = self.harness.b.fsm.write_stream(stream_id, &mut chunk) { self.expected_at_a .entry(stream_id) .or_default() diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 2d7e6964..f9c281d9 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use bytes::Bytes; use ql_wire::{SessionClose, StreamId}; use super::*; @@ -9,6 +10,15 @@ fn stream_id(value: u32) -> StreamId { StreamId::from_u32(value) } +fn write_stream_bytes( + fsm: &mut QlFsm, + stream_id: StreamId, + bytes: &[u8], +) -> Result { + let mut bytes = Bytes::copy_from_slice(bytes); + fsm.write_stream(stream_id, &mut bytes) +} + fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); loop { @@ -30,7 +40,7 @@ fn connected_fsms_deliver_stream_data() { let mut harness = Harness::connected(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"hello").unwrap(), 5); + assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), 5); harness.a.fsm.finish_stream(stream_id).unwrap(); harness.pump(); @@ -56,7 +66,7 @@ fn session_retransmit_uses_new_record_seq() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"retry").unwrap(), 5); + assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5); let first = harness.next_outbound_a().unwrap(); let first_transport = harness.b.fsm.state.link.transport().unwrap().clone(); @@ -112,11 +122,11 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { ); assert_eq!( - harness.a.fsm.write_stream(stream_id_a, b"from-a").unwrap(), + write_stream_bytes(&mut harness.a.fsm, stream_id_a, b"from-a").unwrap(), 6 ); assert_eq!( - harness.b.fsm.write_stream(stream_id_b, b"from-b").unwrap(), + write_stream_bytes(&mut harness.b.fsm, stream_id_b, b"from-b").unwrap(), 6 ); @@ -155,7 +165,7 @@ fn disconnected_stream_operations_fail_with_no_session() { assert_eq!(harness.a.fsm.open_stream(), Err(QlFsmError::NoSession)); assert_eq!( - harness.a.fsm.write_stream(missing, b"queued"), + write_stream_bytes(&mut harness.a.fsm, missing, b"queued"), Err(QlFsmError::NoSession) ); assert_eq!( @@ -191,7 +201,7 @@ fn returned_session_write_is_reissued_with_new_record_seq() { let mut harness = Harness::connected(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"retry").unwrap(), 5); + assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5); let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); @@ -231,7 +241,7 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"retry").unwrap(), 5); + assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5); let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); @@ -264,8 +274,8 @@ fn ack_frame_releases_stream_capacity_and_emits_writable() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"abcd").unwrap(), 4); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"z").unwrap(), 0); + assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"abcd").unwrap(), 4); + assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"z").unwrap(), 0); let record = harness.next_outbound_a().unwrap(); harness.deliver_to_b(record); @@ -299,7 +309,7 @@ fn session_records_contain_ack_frames_after_delivery() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"x").unwrap(), 1); + assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"x").unwrap(), 1); let data = harness.next_outbound_a().unwrap(); harness.deliver_to_b(data); @@ -335,7 +345,7 @@ fn first_stream_data_uses_negotiated_initial_peer_credit() { harness.deliver_to_a(ik2); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(harness.a.fsm.write_stream(stream_id, b"hello").unwrap(), 5); + assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), 5); let data = harness.next_outbound_a().unwrap(); let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index b0c94d1c..65bfe203 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -519,7 +519,9 @@ impl DriverState { false } else { let len = bytes.len(); - let accepted = fsm.write_stream(stream_id, bytes).unwrap_or_default(); + let mut bytes = ql_fsm::Bytes::copy_from_slice(bytes); + let accepted = + fsm.write_stream(stream_id, &mut bytes).unwrap_or_default(); if accepted > 0 { reader.consume(accepted); } diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index 1a1294f0..21a6c57e 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -145,72 +145,11 @@ impl ByteSlice for &mut [u8] { } } -#[derive(Debug, Clone, Copy)] -pub struct RangedByteChunks { - pub inner: T, - pub offset: usize, - pub len: usize, -} - -pub struct RangedByteChunksIter { - inner: I, - skip: usize, - remaining: usize, -} - -impl<'a, I> Iterator for RangedByteChunksIter -where - I: Iterator, -{ - type Item = &'a [u8]; - - fn next(&mut self) -> Option { - while self.remaining > 0 { - let chunk = self.inner.next()?; - if self.skip >= chunk.len() { - self.skip -= chunk.len(); - continue; - } - - let chunk = &chunk[self.skip..]; - self.skip = 0; - if chunk.is_empty() { - continue; - } - - let len = chunk.len().min(self.remaining); - self.remaining -= len; - return Some(&chunk[..len]); - } - - None - } -} - -impl ByteChunks for RangedByteChunks { - type Chunks<'a> - = RangedByteChunksIter> - where - Self: 'a; - - fn len(&self) -> usize { - self.inner.len().saturating_sub(self.offset).min(self.len) - } - - fn chunks(&self) -> Self::Chunks<'_> { - RangedByteChunksIter { - inner: self.inner.chunks(), - skip: self.offset, - remaining: self.len(), - } - } -} - #[cfg(test)] mod tests { use std::collections::VecDeque; - use super::{ByteChunks, ByteSlice, ByteSliceMut, RangedByteChunks}; + use super::{ByteChunks, ByteSlice, ByteSliceMut}; #[test] fn shared_slice_split_at() { @@ -262,36 +201,4 @@ mod tests { assert_eq!(chunks.concat(), b"cdefgh"); assert!(!chunks.is_empty()); } - - #[test] - fn ranged_byte_chunks_slice_middle() { - let bytes: &[u8] = b"abcdef"; - let ranged = RangedByteChunks { - inner: bytes, - offset: 2, - len: 3, - }; - - let chunks = ranged.chunks().collect::>(); - assert_eq!(ranged.len(), 3); - assert_eq!(chunks, vec![b"cde".as_slice()]); - } - - #[test] - fn ranged_byte_chunks_borrowed_vec_deque_middle() { - let mut bytes = VecDeque::with_capacity(8); - bytes.extend(b"abcd".iter().copied()); - bytes.drain(..2); - bytes.extend(b"efgh".iter().copied()); - - let ranged = RangedByteChunks { - inner: &bytes, - offset: 1, - len: 4, - }; - - let chunks = ranged.chunks().collect::>(); - assert_eq!(ranged.len(), 4); - assert_eq!(chunks.concat(), b"defg"); - } } From 943770e42168a0be95f42005d7cccb10dfd59c8f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 16:53:52 -0400 Subject: [PATCH 125/304] ql: use bytes for stream_rx --- ql-fsm/src/implementation/core.rs | 28 +-- ql-fsm/src/session/mod.rs | 8 +- ql-fsm/src/session/range_set.rs | 4 - ql-fsm/src/session/stream_rx.rs | 293 +++++++++++++++++------------- ql-fsm/src/session/tests.rs | 4 +- ql-wire/src/bytes.rs | 28 +++ ql-wire/src/encrypted/mod.rs | 86 ++++----- 7 files changed, 255 insertions(+), 196 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index c6919f32..8329b538 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -30,23 +30,29 @@ pub fn receive( super::handle_handshake_record(fsm, crypto, &record, &mut emit) } wire::RecordType::Session => { - let record = wire::QlSessionRecord::decode_exact(&mut bytes[..])?; let state = fsm.state.link.connected_mut_or_err()?; - if record.header.connection_id != state.transport.rx_connection_id { - return Err(QlFsmError::InvalidPayload); - } - let plaintext = wire::decrypt_record( - crypto, - &record.header, - record.payload, - &state.transport.rx_key, - )?; + let bytes_ptr = bytes.as_ptr() as usize; + let (seq, start, len) = { + let record = wire::QlSessionRecord::decode_exact(&mut bytes[..])?; + if record.header.connection_id != state.transport.rx_connection_id { + return Err(QlFsmError::InvalidPayload); + } + let plaintext = wire::decrypt_record( + crypto, + &record.header, + record.payload, + &state.transport.rx_key, + )?; + let start = plaintext.as_ptr() as usize - bytes_ptr; + (record.header.seq, start, plaintext.len()) + }; + let plaintext = Bytes::from(bytes).slice(start..start + len); let frames = wire::SessionRecord::parse(plaintext)?; let mut session_closed = false; state .session - .receive(fsm.state.now.instant, record.header.seq, frames, |event| { + .receive(fsm.state.now.instant, seq, frames, |event| { session_closed |= forward_session_event(event, &mut emit); }); diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 9170307d..e74fb5a3 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -252,14 +252,14 @@ impl SessionFsm { Ok(()) } - pub fn receive<'a, I>( + pub(crate) fn receive( &mut self, now: Instant, seq: RecordSeq, frames: I, mut emit: impl FnMut(SessionEvent), ) where - I: IntoIterator, WireError>>, + I: IntoIterator, WireError>>, { self.state.now = now; self.collect_timeouts(); @@ -298,7 +298,7 @@ impl SessionFsm { SessionFrame::Ping => {} SessionFrame::Ack(ack) => self.process_record_ack(&ack, &mut emit), SessionFrame::StreamData(frame) => { - if self.handle_stream_data(&frame, &mut emit).is_err() { + if self.handle_stream_data(frame, &mut emit).is_err() { return; } } @@ -666,7 +666,7 @@ impl SessionFsm { fn handle_stream_data( &mut self, - frame: &StreamData<&[u8]>, + frame: StreamData, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { let stream_id = frame.stream_id; diff --git a/ql-fsm/src/session/range_set.rs b/ql-fsm/src/session/range_set.rs index 4ef6bdbd..ac39f23f 100644 --- a/ql-fsm/src/session/range_set.rs +++ b/ql-fsm/src/session/range_set.rs @@ -84,10 +84,6 @@ impl RangeSet { self.0.first_key_value().map(|(&start, _)| start) } - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - pub fn iter(&self) -> Iter<'_> { Iter(self.0.iter()) } diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index 78386229..ee58b593 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -1,13 +1,12 @@ -use std::collections::VecDeque; +use std::collections::{btree_map, BTreeMap}; -use super::range_set::RangeSet; +use bytes::{Buf, Bytes}; /// reassembles one stream direction from out-of-order byte ranges. #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamRx { start_offset: u64, - bytes: VecDeque, - missing: RangeSet, + chunks: BTreeMap, final_offset: Option, max_buffered: usize, } @@ -27,12 +26,6 @@ pub enum StreamRxError { BeyondFinalOffset, } -#[derive(Debug, Clone)] -pub struct StreamReadIter<'a> { - front: Option<&'a [u8]>, - back: Option<&'a [u8]>, -} - impl StreamRx { pub fn new(max_buffered: usize) -> Self { Self::with_start_offset(0, max_buffered) @@ -41,8 +34,7 @@ impl StreamRx { pub fn with_start_offset(start_offset: u64, max_buffered: usize) -> Self { Self { start_offset, - bytes: VecDeque::new(), - missing: RangeSet::new(), + chunks: BTreeMap::new(), final_offset: None, max_buffered, } @@ -53,7 +45,10 @@ impl StreamRx { } pub fn buffered_end_offset(&self) -> u64 { - self.start_offset + self.bytes.len() as u64 + self.chunks + .last_key_value() + .map(|(&offset, bytes)| offset + bytes.len() as u64) + .unwrap_or(self.start_offset) } pub fn max_buffered(&self) -> usize { @@ -61,51 +56,40 @@ impl StreamRx { } pub fn readable_len(&self) -> usize { - if self.bytes.is_empty() { - return 0; - } + let mut cursor = self.start_offset; + for (&offset, bytes) in self.chunks.range(self.start_offset..) { + if offset > cursor { + break; + } - match self.missing.peek_min() { - Some(range) if range.start <= self.start_offset => 0, - Some(range) => usize::try_from(range.start - self.start_offset) - .expect("readable prefix exceeds usize"), - None => self.bytes.len(), + let end = offset + bytes.len() as u64; + if end > cursor { + cursor = end; + } } + + usize::try_from(cursor - self.start_offset).expect("readable prefix exceeds usize") } pub fn bytes(&self) -> StreamReadIter<'_> { - let readable = self.readable_len(); - if readable == 0 { - return StreamReadIter { - front: None, - back: None, - }; - } - - let (front, back) = self.bytes.as_slices(); - if readable <= front.len() { - StreamReadIter { - front: Some(&front[..readable]), - back: None, - } - } else { - StreamReadIter { - front: Some(front), - back: Some(&back[..readable - front.len()]), - } + StreamReadIter { + inner: self.chunks.range(self.start_offset..), + cursor: self.start_offset, + remaining: self.readable_len(), } } pub fn is_complete(&self) -> bool { - matches!(self.final_offset, Some(final_offset) if final_offset == self.buffered_end_offset()) - && self.missing.is_empty() + matches!(self.final_offset, Some(final_offset) + if final_offset == self.buffered_end_offset() + && final_offset == self.start_offset + self.readable_len() as u64) } pub fn insert( &mut self, offset: u64, fin: bool, - bytes: &[u8], + mut bytes: Bytes, ) -> Result { let end = offset .checked_add(bytes.len() as u64) @@ -130,17 +114,16 @@ impl StreamRx { let effective_offset = offset.max(self.start_offset); let trim_front = usize::try_from(effective_offset - offset).expect("front trim exceeds usize"); - let effective_bytes = &bytes[trim_front..]; - if effective_bytes.is_empty() { + bytes.advance(trim_front); + if bytes.is_empty() { return Ok(self.insert_outcome(was_complete, old_readable)); } - self.ensure_within_window(end)?; - self.ensure_buffered(end); + let effective_end = effective_offset + bytes.len() as u64; + self.ensure_within_window(effective_end)?; #[cfg(test)] - self.assert_valid_overlap(effective_offset, effective_bytes); - self.write_bytes(effective_offset, effective_bytes); - self.missing.remove(effective_offset..end); + self.assert_valid_overlap(effective_offset, &bytes); + self.insert_chunk(effective_offset, bytes); Ok(self.insert_outcome(was_complete, old_readable)) } @@ -152,8 +135,22 @@ impl StreamRx { return; } - self.bytes.drain(..len); - self.start_offset = self.start_offset.saturating_add(len as u64); + let new_start = self.start_offset.saturating_add(len as u64); + while let Some((&offset, bytes)) = self.chunks.first_key_value() { + let end = offset + bytes.len() as u64; + if end <= new_start { + self.chunks.pop_first(); + continue; + } + if offset < new_start { + let (offset, mut bytes) = self.chunks.pop_first().unwrap(); + bytes.advance(usize::try_from(new_start - offset).expect("trim exceeds usize")); + self.chunks.insert(new_start, bytes); + } + break; + } + + self.start_offset = new_start; } fn insert_outcome(&self, was_complete: bool, old_readable: usize) -> InsertOutcome { @@ -189,72 +186,123 @@ impl StreamRx { Ok(()) } - fn ensure_buffered(&mut self, end: u64) { - let buffered_end = self.buffered_end_offset(); - if end <= buffered_end { + fn insert_chunk(&mut self, mut offset: u64, mut bytes: Bytes) { + if bytes.is_empty() { return; } - let additional = usize::try_from(end - buffered_end).expect("buffer growth exceeds usize"); - self.bytes.resize(self.bytes.len() + additional, 0); - self.missing.insert(buffered_end..end); - } + if let Some((&existing_offset, existing)) = self.chunks.range(..offset).next_back() { + let existing_end = existing_offset + existing.len() as u64; + if existing_end > offset { + let overlap = + usize::try_from((existing_end - offset).min(bytes.len() as u64)).unwrap(); + bytes.advance(overlap); + offset += overlap as u64; + } + } - #[cfg(test)] - fn assert_valid_overlap(&self, offset: u64, bytes: &[u8]) { - for (index, byte) in bytes.iter().copied().enumerate() { - let absolute = offset + index as u64; - let is_missing = self - .missing - .iter() - .any(|range| range.start <= absolute && absolute < range.end); - if is_missing { - continue; + if bytes.is_empty() { + return; + } + + let end = offset + bytes.len() as u64; + let overlapping = self + .chunks + .range(offset..end) + .map(|(&chunk_offset, _)| chunk_offset) + .collect::>(); + + for chunk_offset in overlapping { + let chunk_end = chunk_offset + self.chunks[&chunk_offset].len() as u64; + + if chunk_offset > offset { + let len = usize::try_from(chunk_offset - offset).expect("gap exceeds usize"); + self.chunks.insert(offset, bytes.slice(..len)); + bytes.advance(len); + offset = chunk_offset; } - let index = - usize::try_from(absolute - self.start_offset).expect("read index exceeds usize"); + let overlap = usize::try_from((chunk_end - offset).min(bytes.len() as u64)).unwrap(); + bytes.advance(overlap); + offset += overlap as u64; - assert_eq!( - self.bytes[index], byte, - "conflicting overlap at stream offset {absolute}" - ); + if bytes.is_empty() { + return; + } } + + self.chunks.insert(offset, bytes); } - fn write_bytes(&mut self, offset: u64, bytes: &[u8]) { - let start = usize::try_from(offset - self.start_offset).expect("write index exceeds usize"); - let (front, back) = self.bytes.as_mut_slices(); + #[cfg(test)] + fn assert_valid_overlap(&self, offset: u64, bytes: &Bytes) { + if let Some((&existing_offset, existing)) = self.chunks.range(..offset).next_back() { + self.assert_overlap_chunk(offset, bytes, existing_offset, existing); + } + let end = offset + bytes.len() as u64; + for (&existing_offset, existing) in self.chunks.range(offset..end) { + self.assert_overlap_chunk(offset, bytes, existing_offset, existing); + } + } - if start >= front.len() { - let start = start - front.len(); - back[start..start + bytes.len()].copy_from_slice(bytes); + #[cfg(test)] + fn assert_overlap_chunk( + &self, + offset: u64, + bytes: &Bytes, + existing_offset: u64, + existing: &Bytes, + ) { + let end = offset + bytes.len() as u64; + let existing_end = existing_offset + existing.len() as u64; + let overlap_start = offset.max(existing_offset); + let overlap_end = end.min(existing_end); + if overlap_start >= overlap_end { return; } - let front_len = (front.len() - start).min(bytes.len()); - front[start..start + front_len].copy_from_slice(&bytes[..front_len]); + let start = usize::try_from(overlap_start - offset).expect("overlap start exceeds usize"); + let existing_start = usize::try_from(overlap_start - existing_offset) + .expect("existing overlap start exceeds usize"); + let len = usize::try_from(overlap_end - overlap_start).expect("overlap exceeds usize"); - if front_len < bytes.len() { - back[..bytes.len() - front_len].copy_from_slice(&bytes[front_len..]); - } + assert_eq!( + &bytes[start..start + len], + &existing[existing_start..existing_start + len], + "conflicting overlap at stream offset {overlap_start}" + ); } } +#[derive(Debug, Clone)] +pub struct StreamReadIter<'a> { + inner: btree_map::Range<'a, u64, Bytes>, + cursor: u64, + remaining: usize, +} + impl<'a> Iterator for StreamReadIter<'a> { type Item = &'a [u8]; fn next(&mut self) -> Option { - if let Some(front) = self.front.take() { - if !front.is_empty() { - return Some(front); + while self.remaining > 0 { + let (&offset, bytes) = self.inner.next()?; + if offset > self.cursor { + self.remaining = 0; + return None; } - } - if let Some(back) = self.back.take() { - if !back.is_empty() { - return Some(back); + let skip = usize::try_from(self.cursor.saturating_sub(offset)) + .expect("read cursor exceeds usize"); + if skip >= bytes.len() { + continue; } + + let chunk = &bytes[skip..]; + let len = chunk.len().min(self.remaining); + self.remaining -= len; + self.cursor += len as u64; + return Some(&chunk[..len]); } None @@ -263,6 +311,8 @@ impl<'a> Iterator for StreamReadIter<'a> { #[cfg(test)] mod tests { + use bytes::Bytes; + use super::{InsertOutcome, StreamRx, StreamRxError}; pub fn copy_readable(rx: &StreamRx) -> Vec { @@ -274,11 +324,15 @@ mod tests { out } + fn bytes(bytes: &'static [u8]) -> Bytes { + Bytes::from_static(bytes) + } + #[test] fn contiguous_insert_becomes_readable_and_complete() { let mut rx = StreamRx::new(64); - let outcome = rx.insert(0, true, b"hello").unwrap(); + let outcome = rx.insert(0, true, bytes(b"hello")).unwrap(); assert_eq!( outcome, @@ -291,14 +345,13 @@ mod tests { assert_eq!(copy_readable(&rx), b"hello"); assert_eq!(rx.final_offset, Some(5)); assert!(rx.is_complete()); - assert!(rx.missing.is_empty()); } #[test] - fn out_of_order_insert_tracks_missing_ranges_until_gap_is_filled() { + fn out_of_order_insert_tracks_gap_until_prefix_is_filled() { let mut rx = StreamRx::new(64); - let first = rx.insert(5, true, b" world").unwrap(); + let first = rx.insert(5, true, bytes(b" world")).unwrap(); assert_eq!( first, InsertOutcome { @@ -306,10 +359,9 @@ mod tests { became_complete: false, } ); - assert_eq!(rx.missing.iter().collect::>(), vec![0..5]); assert_eq!(rx.readable_len(), 0); - let second = rx.insert(0, false, b"hello").unwrap(); + let second = rx.insert(0, false, bytes(b"hello")).unwrap(); assert_eq!( second, InsertOutcome { @@ -318,7 +370,6 @@ mod tests { } ); assert_eq!(copy_readable(&rx), b"hello world"); - assert!(rx.missing.is_empty()); assert!(rx.is_complete()); } @@ -326,8 +377,8 @@ mod tests { fn duplicate_insert_is_ignored_if_bytes_match() { let mut rx = StreamRx::new(64); - rx.insert(0, false, b"hello").unwrap(); - let duplicate = rx.insert(0, false, b"hello").unwrap(); + rx.insert(0, false, bytes(b"hello")).unwrap(); + let duplicate = rx.insert(0, false, bytes(b"hello")).unwrap(); assert_eq!( duplicate, @@ -344,20 +395,20 @@ mod tests { fn conflicting_overlap_panics_in_test_builds() { let mut rx = StreamRx::new(64); - rx.insert(0, false, b"abcdef").unwrap(); - rx.insert(3, false, b"xyz").unwrap(); + rx.insert(0, false, bytes(b"abcdef")).unwrap(); + rx.insert(3, false, bytes(b"xyz")).unwrap(); } #[test] fn consume_advances_start_offset_and_trims_old_prefix() { let mut rx = StreamRx::new(64); - rx.insert(0, false, b"abcd").unwrap(); + rx.insert(0, false, bytes(b"abcd")).unwrap(); rx.consume(2); assert_eq!(rx.start_offset(), 2); assert_eq!(copy_readable(&rx), b"cd"); - let outcome = rx.insert(1, true, b"bcde").unwrap(); + let outcome = rx.insert(1, true, bytes(b"bcde")).unwrap(); assert_eq!( outcome, InsertOutcome { @@ -374,13 +425,11 @@ mod tests { fn insert_can_fill_multiple_gaps_without_rebuilding_state() { let mut rx = StreamRx::new(64); - rx.insert(0, false, b"ab").unwrap(); - rx.insert(4, false, b"ef").unwrap(); - rx.insert(8, true, b"ij").unwrap(); - - assert_eq!(rx.missing.iter().collect::>(), vec![2..4, 6..8]); + rx.insert(0, false, bytes(b"ab")).unwrap(); + rx.insert(4, false, bytes(b"ef")).unwrap(); + rx.insert(8, true, bytes(b"ij")).unwrap(); - let outcome = rx.insert(2, false, b"cdefgh").unwrap(); + let outcome = rx.insert(2, false, bytes(b"cdefgh")).unwrap(); assert_eq!( outcome, @@ -389,7 +438,6 @@ mod tests { became_complete: true, } ); - assert!(rx.missing.is_empty()); assert_eq!(copy_readable(&rx), b"abcdefghij"); assert!(rx.is_complete()); @@ -399,18 +447,13 @@ mod tests { fn heavily_fragmented_inserts_stay_valid() { let mut rx = StreamRx::new(64); - rx.insert(1, false, b"b").unwrap(); - rx.insert(3, false, b"d").unwrap(); - rx.insert(5, false, b"f").unwrap(); - rx.insert(7, false, b"h").unwrap(); - rx.insert(9, true, b"j").unwrap(); - - assert_eq!( - rx.missing.iter().collect::>(), - vec![0..1, 2..3, 4..5, 6..7, 8..9] - ); + rx.insert(1, false, bytes(b"b")).unwrap(); + rx.insert(3, false, bytes(b"d")).unwrap(); + rx.insert(5, false, bytes(b"f")).unwrap(); + rx.insert(7, false, bytes(b"h")).unwrap(); + rx.insert(9, true, bytes(b"j")).unwrap(); - let outcome = rx.insert(0, false, b"abcdefghi").unwrap(); + let outcome = rx.insert(0, false, bytes(b"abcdefghi")).unwrap(); assert_eq!( outcome, InsertOutcome { @@ -425,7 +468,7 @@ mod tests { #[test] fn out_of_window_insert_is_rejected() { let mut rx = StreamRx::new(4); - let error = rx.insert(5, false, b"a").unwrap_err(); + let error = rx.insert(5, false, bytes(b"a")).unwrap_err(); assert_eq!(error, StreamRxError::OutOfWindow); } } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index ccd9be3e..e06b4837 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -63,8 +63,8 @@ fn receive_events( for frame in &record.frames { assert!(builder.push_frame(frame)); } - let bytes = builder.bytes().to_vec(); - let frames = SessionRecord::parse(&bytes).unwrap(); + let bytes = Bytes::from(builder.bytes().to_vec()); + let frames = SessionRecord::parse(bytes).unwrap(); let mut events = Vec::new(); fsm.receive(now, seq, frames, |event| events.push(event)); events diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index 21a6c57e..7a6f4e25 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -4,6 +4,8 @@ use core::{ }; use std::collections::VecDeque; +use bytes::Bytes; + /// A mutable or immutable byte slice owner used by the wire parser. pub trait ByteSlice: Deref + Sized { /// Splits the current byte view at `mid`. @@ -107,6 +109,21 @@ impl ByteChunks for Vec { } } +impl ByteChunks for Bytes { + type Chunks<'a> + = Once<&'a [u8]> + where + Self: 'a; + + fn len(&self) -> usize { + Bytes::len(self) + } + + fn chunks(&self) -> Self::Chunks<'_> { + once(self.as_ref()) + } +} + impl ByteChunks for VecDeque { type Chunks<'a> = Chain, Once<&'a [u8]>> @@ -145,6 +162,17 @@ impl ByteSlice for &mut [u8] { } } +impl ByteSlice for Bytes { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok((self.slice(..mid), self.slice(mid..))) + } else { + Err(self) + } + } +} + #[cfg(test)] mod tests { use std::collections::VecDeque; diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index de42ec45..e458986e 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -2,6 +2,7 @@ use crate::{ codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, SessionHeader, SessionKey, VarInt, VarIntBoundsExceeded, WireDecode, WireEncode, WireError, }; +use bytes::Bytes; mod ack; mod builder; @@ -71,6 +72,8 @@ pub enum SessionFrame { pub type SessionFrameVec = SessionFrame>; pub type StreamDataVec = StreamData>; +pub type SessionFrameBytes = SessionFrame; +pub type StreamDataBytes = StreamData; #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] @@ -83,8 +86,8 @@ pub(crate) enum SessionFrameKind { Close = 6, } -pub struct SessionFrameIter<'a> { - remaining: &'a [u8], +pub struct SessionFrameIter { + remaining: Option, } impl TryFrom for SessionFrameKind { @@ -103,9 +106,17 @@ impl TryFrom for SessionFrameKind { } } +impl codec::WireDecode for SessionFrameKind { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + impl SessionRecord { - pub fn parse(bytes: &[u8]) -> Result, WireError> { - Ok(SessionFrameIter { remaining: bytes }) + pub fn parse(bytes: B) -> Result, WireError> { + Ok(SessionFrameIter { + remaining: Some(bytes), + }) } pub fn decode(bytes: &[u8]) -> Result { @@ -195,22 +206,22 @@ impl WireEncode for SessionRecord { } } -impl<'a> Iterator for SessionFrameIter<'a> { - type Item = Result, WireError>; +impl Iterator for SessionFrameIter { + type Item = Result, WireError>; fn next(&mut self) -> Option { - if self.remaining.is_empty() { + let remaining = self.remaining.take()?; + if remaining.is_empty() { return None; } - let parsed = parse_next_frame(self.remaining); + let parsed = parse_next_frame(remaining); match parsed { Ok((frame, rest)) => { - self.remaining = rest; + self.remaining = Some(rest); Some(Ok(frame)) } Err(error) => { - self.remaining = &[]; Some(Err(error)) } } @@ -228,50 +239,25 @@ pub fn decrypt_record>( encrypted.decrypt_in_place(crypto, session_key, &nonce, &aad) } -fn parse_next_frame(bytes: &[u8]) -> Result<(SessionFrame<&[u8]>, &[u8]), WireError> { - let (&kind, rest) = bytes.split_first().ok_or(WireError::InvalidPayload)?; - match SessionFrameKind::try_from(kind)? { - SessionFrameKind::Ping => Ok((SessionFrame::Ping, rest)), - SessionFrameKind::Ack => { - let (frame, rest) = parse_inline_frame::(rest)?; - Ok((SessionFrame::Ack(frame), rest)) - } +fn parse_next_frame(bytes: B) -> Result<(SessionFrame, B), WireError> { + let mut reader = codec::Reader::new(bytes); + let kind = reader.decode::()?; + let frame = match kind { + SessionFrameKind::Ping => SessionFrame::Ping, + SessionFrameKind::Ack => SessionFrame::Ack(reader.decode::()?), SessionFrameKind::StreamData => { - let (frame, rest) = split_variable_frame(rest)?; - Ok(( - SessionFrame::StreamData(StreamData::decode_exact(frame)?), - rest, - )) + let len = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + let frame = reader.take_bytes(len)?; + SessionFrame::StreamData(StreamData::decode_exact(frame)?) } SessionFrameKind::StreamWindow => { - let (frame, rest) = parse_inline_frame::(rest)?; - Ok((SessionFrame::StreamWindow(frame), rest)) + SessionFrame::StreamWindow(reader.decode::()?) } SessionFrameKind::StreamClose => { - let (frame, rest) = parse_inline_frame::(rest)?; - Ok((SessionFrame::StreamClose(frame), rest)) - } - SessionFrameKind::Close => { - let (frame, rest) = parse_inline_frame::(rest)?; - Ok((SessionFrame::Close(frame), rest)) + SessionFrame::StreamClose(reader.decode::()?) } - } -} - -fn parse_inline_frame(bytes: &[u8]) -> Result<(T, &[u8]), WireError> -where - T: for<'a> WireDecode<&'a [u8]>, -{ - let mut reader = codec::Reader::new(bytes); - let frame = reader.decode::()?; - let consumed = bytes.len() - reader.remaining_len(); - Ok((frame, &bytes[consumed..])) -} - -fn split_variable_frame(bytes: &[u8]) -> Result<(&[u8], &[u8]), WireError> { - let mut reader = codec::Reader::new(bytes); - let len = usize::try_from(reader.decode::()?.into_inner()) - .map_err(|_| WireError::InvalidPayload)?; - let bytes = &bytes[bytes.len() - reader.remaining_len()..]; - bytes.split_at_checked(len).ok_or(WireError::InvalidPayload) + SessionFrameKind::Close => SessionFrame::Close(reader.decode::()?), + }; + Ok((frame, reader.take_rest())) } From be21010028f18f2748423d7ae900b145ad302e1b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 18:25:48 -0400 Subject: [PATCH 126/304] ql: prevent duplicate decode and less pointer arithmetic --- ql-fsm/src/implementation/core.rs | 28 ++++--- ql-fsm/src/tests/handshake.rs | 6 +- ql-fsm/src/tests/mod.rs | 17 +++-- ql-wire/src/record.rs | 118 +++++++++++++++++------------- ql-wire/src/tests.rs | 33 ++++----- 5 files changed, 114 insertions(+), 88 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 8329b538..33decd1d 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -3,7 +3,6 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use ql_wire::{ self as wire, CloseTarget, QlCrypto, SessionCloseCode, StreamCloseCode, StreamId, WireDecode, - WireEncode, }; use crate::{ @@ -23,30 +22,36 @@ pub fn receive( crypto: &impl QlCrypto, mut emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { - let header = wire::RecordHeader::decode_bytes(bytes.as_slice())?; + let mut reader = wire::Reader::new(bytes.as_mut_slice()); + let header = wire::RecordHeader::decode(&mut reader)?; + + if header.version != wire::QL_WIRE_VERSION { + return Err(QlFsmError::InvalidPayload); + } + match header.record_type { wire::RecordType::Handshake => { - let record = wire::QlHandshakeRecord::decode_exact(bytes.as_slice())?; + let record = wire::QlHandshakeRecord::decode(&mut reader)?; super::handle_handshake_record(fsm, crypto, &record, &mut emit) } wire::RecordType::Session => { let state = fsm.state.link.connected_mut_or_err()?; - let bytes_ptr = bytes.as_ptr() as usize; - let (seq, start, len) = { - let record = wire::QlSessionRecord::decode_exact(&mut bytes[..])?; + let (decrypt_len, seq) = { + let record = wire::QlSessionRecord::decode(&mut reader)?; if record.header.connection_id != state.transport.rx_connection_id { return Err(QlFsmError::InvalidPayload); } - let plaintext = wire::decrypt_record( + let payload = wire::decrypt_record( crypto, &record.header, record.payload, &state.transport.rx_key, )?; - let start = plaintext.as_ptr() as usize - bytes_ptr; - (record.header.seq, start, plaintext.len()) + (payload.len(), record.header.seq) }; - let plaintext = Bytes::from(bytes).slice(start..start + len); + + let len = bytes.len(); + let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); let frames = wire::SessionRecord::parse(plaintext)?; let mut session_closed = false; @@ -95,8 +100,9 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { if let Some(record) = fsm.state.handshake.take() { + let record = wire::encode_record_vec(ql_wire::RecordType::Handshake, &record); return Some(OutboundWrite { - record: record.encode_vec(), + record, session_write_id: None, }); } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 7ae59efb..c8cbbbc0 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use ql_wire::{QlHandshakeRecord, WireDecode}; +use ql_wire::QlHandshakeRecord; use super::*; use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent}; @@ -192,7 +192,7 @@ fn handshake_timeout_drops_single_ik_attempt_without_resend() { harness.connect_ik_a().unwrap(); harness.drain_events_a(); let first = harness.next_outbound_a().unwrap(); - let first = QlHandshakeRecord::decode_exact(first.as_slice()).unwrap(); + let (_, first) = ql_wire::decode_record::(first.as_slice()).unwrap(); assert!(matches!(first, ql_wire::QlHandshakeRecord::Ik1(_))); assert!(harness.next_outbound_a().is_none()); @@ -261,7 +261,7 @@ fn simultaneous_ik_and_kk_connect_prefers_ik() { } fn handshake_id(record: &[u8]) -> ql_wire::HandshakeId { - let record = QlHandshakeRecord::decode_exact(record).unwrap(); + let (_, record) = ql_wire::decode_record(record).unwrap(); match record { ql_wire::QlHandshakeRecord::Ik1(message) => message.meta.handshake_id, ql_wire::QlHandshakeRecord::Ik2(message) => message.meta.handshake_id, diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 253877d7..40fbaf2a 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -13,7 +13,7 @@ use libcrux_ml_kem::mlkem1024; use ql_wire::{ self, generate_identity, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, SessionKey, - TransportParams, WireDecode, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + TransportParams, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; use sha2::{Digest, Sha256}; @@ -443,12 +443,15 @@ fn decrypt_record( record: &[u8], session_key: &SessionKey, ) -> (ql_wire::SessionHeader, ql_wire::SessionRecord) { - let record = ql_wire::QlSessionRecord::decode_exact(record) - .unwrap() - .into_owned(); - let plaintext = - ql_wire::decrypt_record(crypto, &record.header, record.payload.clone(), session_key) - .unwrap(); + let (_header, record) = + ql_wire::decode_record::, _>(record).unwrap(); + let plaintext = ql_wire::decrypt_record( + crypto, + &record.header, + record.payload.into_owned(), + session_key, + ) + .unwrap(); ( record.header, ql_wire::SessionRecord::decode(&plaintext).unwrap(), diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 002e8e7f..191d5e4b 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -2,28 +2,35 @@ use crate::{ codec, encrypted_message::EncryptedMessage, handshake::{Ik1, Ik2, Kk1, Kk2}, - ByteSlice, SessionHeader, WireEncode, WireError, WireDecode, QL_WIRE_VERSION, + ByteSlice, SessionHeader, WireDecode, WireEncode, WireError, QL_WIRE_VERSION, }; -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct QlSessionRecord { - pub header: SessionHeader, - pub payload: EncryptedMessage, +pub fn encode_record(out: &mut W, record_type: RecordType, body: &T) +where + W: bytes::BufMut + ?Sized, + T: WireEncode + ?Sized, +{ + RecordHeader { + version: QL_WIRE_VERSION, + record_type, + } + .encode(out); + body.encode(out); } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlHandshakeRecord { - Ik1(Ik1), - Ik2(Ik2), - Kk1(Kk1), - Kk2(Kk2), +pub fn encode_record_vec(record_type: RecordType, body: &T) -> Vec { + let mut out = Vec::with_capacity(RecordHeader::WIRE_SIZE + body.encoded_len()); + encode_record(&mut out, record_type, body); + out } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum RecordType { - Handshake = 1, - Session = 2, +pub fn decode_record(bytes: B) -> Result<(RecordHeader, T), WireError> +where + T: WireDecode, + B: ByteSlice, +{ + let mut reader = codec::Reader::new(bytes); + Ok((reader.decode()?, reader.decode()?)) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -32,6 +39,37 @@ pub struct RecordHeader { pub record_type: RecordType, } +impl RecordHeader { + pub const WIRE_SIZE: usize = size_of::() + size_of::(); +} + +impl WireDecode for RecordHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + version: reader.decode()?, + record_type: reader.decode()?, + }) + } +} + +impl WireEncode for RecordHeader { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + out.put_u8(self.version); + self.record_type.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum RecordType { + Handshake = 1, + Session = 2, +} + impl TryFrom for RecordType { type Error = WireError; @@ -60,15 +98,12 @@ impl WireEncode for RecordType { } } -impl WireDecode for RecordHeader { - fn decode(reader: &mut codec::Reader) -> Result { - let version = reader.decode()?; - let record_type = reader.decode()?; - Ok(Self { - version, - record_type, - }) - } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlHandshakeRecord { + Ik1(Ik1), + Ik2(Ik2), + Kk1(Kk1), + Kk2(Kk2), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -123,9 +158,7 @@ impl QlHandshakeRecord { impl WireEncode for QlHandshakeRecord { fn encoded_len(&self) -> usize { - RecordType::Handshake.encoded_len() - + HandshakeKind::Ik1.encoded_len() - + size_of::() + self.kind().encoded_len() + match self { Self::Ik1(message) => message.encoded_len(), Self::Ik2(message) => message.encoded_len(), @@ -135,8 +168,6 @@ impl WireEncode for QlHandshakeRecord { } fn encode(&self, out: &mut W) { - out.put_u8(QL_WIRE_VERSION); - RecordType::Handshake.encode(out); self.kind().encode(out); match self { Self::Ik1(message) => message.encode(out), @@ -149,13 +180,6 @@ impl WireEncode for QlHandshakeRecord { impl WireDecode for QlHandshakeRecord { fn decode(reader: &mut codec::Reader) -> Result { - let header = reader.decode::()?; - if header.version != QL_WIRE_VERSION { - return Err(WireError::InvalidPayload); - } - if header.record_type != RecordType::Handshake { - return Err(WireError::InvalidPayload); - } let kind = reader.decode::()?; match kind { HandshakeKind::Ik1 => Ok(Self::Ik1(reader.decode()?)), @@ -166,17 +190,18 @@ impl WireDecode for QlHandshakeRecord { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QlSessionRecord { + pub header: SessionHeader, + pub payload: EncryptedMessage, +} + impl> WireEncode for QlSessionRecord { fn encoded_len(&self) -> usize { - size_of::() - + RecordType::Session.encoded_len() - + self.header.encoded_len() - + self.payload.encoded_len() + self.header.encoded_len() + self.payload.encoded_len() } fn encode(&self, out: &mut W) { - out.put_u8(QL_WIRE_VERSION); - RecordType::Session.encode(out); self.header.encode(out); self.payload.encode(out); } @@ -193,13 +218,6 @@ impl QlSessionRecord { impl WireDecode for QlSessionRecord { fn decode(reader: &mut codec::Reader) -> Result { - let header = reader.decode::()?; - if header.version != QL_WIRE_VERSION { - return Err(WireError::InvalidPayload); - } - if header.record_type != RecordType::Session { - return Err(WireError::InvalidPayload); - } Ok(Self { header: reader.decode()?, payload: reader.decode()?, diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 1fdba3f2..ad868b71 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -10,6 +10,15 @@ struct TestCrypto { counter: AtomicU64, } +fn decode_handshake_record(bytes: &[u8]) -> QlHandshakeRecord { + decode_record(bytes).unwrap().1 +} + +fn decode_session_record(bytes: &[u8]) -> QlSessionRecord> { + let (_, record) = decode_record::, _>(bytes).unwrap(); + record.into_owned() +} + impl TestCrypto { fn new(seed: u64) -> Self { Self { @@ -198,13 +207,11 @@ fn encrypt_record( let pushed = builder.push_frame(frame); debug_assert!(pushed); } - QlSessionRecord::decode_exact( + decode_session_record( builder .encrypt(crypto, header.connection_id, session_key) .as_slice(), ) - .unwrap() - .into_owned() } #[test] @@ -231,7 +238,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { }, static_bundle: EncryptedPeerBundle::new(Box::new([13; EncryptedPeerBundle::WIRE_SIZE])), }); - let ik_encoded = ik.encode_vec(); + let ik_encoded = encode_record_vec(RecordType::Handshake, &ik); assert_eq!( RecordHeader::decode_bytes(ik_encoded.as_slice()).unwrap(), RecordHeader { @@ -239,10 +246,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { record_type: RecordType::Handshake, } ); - assert_eq!( - QlHandshakeRecord::decode_exact(ik_encoded.as_slice()).unwrap(), - ik - ); + assert_eq!(decode_handshake_record(ik_encoded.as_slice()), ik); let kk = QlHandshakeRecord::Kk1(Kk1 { header: handshake_header(1, 2), @@ -253,7 +257,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { mlkem_public_key: MlKemPublicKey::new(Box::new([15; MlKemPublicKey::SIZE])), }, }); - let kk_encoded = kk.encode_vec(); + let kk_encoded = encode_record_vec(RecordType::Handshake, &kk); assert_eq!( RecordHeader::decode_bytes(kk_encoded.as_slice()).unwrap(), RecordHeader { @@ -261,10 +265,7 @@ fn handshake_record_round_trip_supports_ik_and_kk() { record_type: RecordType::Handshake, } ); - assert_eq!( - QlHandshakeRecord::decode_exact(kk_encoded.as_slice()).unwrap(), - kk - ); + assert_eq!(decode_handshake_record(kk_encoded.as_slice()), kk); } #[test] @@ -686,7 +687,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let session_key = SessionKey::from_data([7; SessionKey::SIZE]); let record = encrypt_record(&crypto, header, &session_key, &body); - let bytes = record.encode_vec(); + let bytes = encode_record_vec(RecordType::Session, &record); assert_eq!( RecordHeader::decode_bytes(bytes.as_slice()).unwrap(), RecordHeader { @@ -694,9 +695,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { record_type: RecordType::Session, } ); - let decoded = QlSessionRecord::decode_exact(bytes.as_slice()) - .unwrap() - .into_owned(); + let decoded = decode_session_record(bytes.as_slice()); assert_eq!(decoded.header, header); let encrypted = decoded.payload; From 3e65fd886236777f42caae23a27448e75ee87657 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 18:35:05 -0400 Subject: [PATCH 127/304] ql-fsm: use RangeSet for RemoteStreamHistory --- ql-fsm/src/session/remote_stream_history.rs | 49 ++++++++++++--------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/ql-fsm/src/session/remote_stream_history.rs b/ql-fsm/src/session/remote_stream_history.rs index f851d0a5..9b15ecce 100644 --- a/ql-fsm/src/session/remote_stream_history.rs +++ b/ql-fsm/src/session/remote_stream_history.rs @@ -1,22 +1,18 @@ -use std::collections::BTreeSet; - use ql_wire::StreamId; -use super::stream_parity::StreamParity; +use super::{range_set::RangeSet, stream_parity::StreamParity}; #[derive(Debug)] pub struct RemoteStreamHistory { parity: StreamParity, - seen_prefix_end: u32, - seen_sparse: BTreeSet, + seen: RangeSet, } impl RemoteStreamHistory { pub fn new(parity: StreamParity) -> Self { Self { parity, - seen_prefix_end: 0, - seen_sparse: BTreeSet::new(), + seen: RangeSet::new(), } } @@ -26,27 +22,38 @@ impl RemoteStreamHistory { let ordinal = self .stream_ordinal(stream_id) .expect("remote stream history used with wrong stream parity"); - if ordinal < self.seen_prefix_end { - return true; - } - if ordinal > self.seen_prefix_end { - return !self.seen_sparse.insert(ordinal); - } - - self.seen_prefix_end = self.seen_prefix_end.saturating_add(1); - while self.seen_sparse.remove(&self.seen_prefix_end) { - self.seen_prefix_end = self.seen_prefix_end.saturating_add(1); - } - false + !self.seen.insert(ordinal..ordinal + 1) } - fn stream_ordinal(&self, stream_id: StreamId) -> Option { + fn stream_ordinal(&self, stream_id: StreamId) -> Option { let delta = stream_id .into_inner() .checked_sub(u64::from(self.parity.first_stream_id()))?; if delta % 2 != 0 { return None; } - u32::try_from(delta / 2).ok() + Some(delta / 2) + } +} + +#[cfg(test)] +mod tests { + use super::RemoteStreamHistory; + use crate::session::stream_parity::StreamParity; + + #[test] + fn observe() { + let parity = StreamParity::Even; + let mut history = RemoteStreamHistory::new(parity); + + assert!(!history.observe(parity.make_stream_id(2))); + assert!(!history.observe(parity.make_stream_id(5))); + assert!(!history.observe(parity.make_stream_id(0))); + assert!(!history.observe(parity.make_stream_id(4))); + assert!(history.observe(parity.make_stream_id(2))); + assert!(!history.observe(parity.make_stream_id(1))); + assert!(history.observe(parity.make_stream_id(5))); + assert!(!history.observe(parity.make_stream_id(3))); + assert!(history.observe(parity.make_stream_id(0))); } } From ebd2d8a2e60a9d5a65da40a9c9cdb6c350da4e94 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 18:53:33 -0400 Subject: [PATCH 128/304] ql: cleanup & fix clippy --- ql-fsm/src/session/remote_stream_history.rs | 3 +- ql-fsm/src/session/stream_rx.rs | 55 +---------------- ql-fsm/src/tests/proptest.rs | 65 +++++++++++---------- ql-fsm/src/tests/session.rs | 40 ++++++++++--- ql-runtime/src/driver/mod.rs | 3 +- ql-wire/src/bytes.rs | 2 +- ql-wire/src/codec.rs | 62 ++------------------ ql-wire/src/encrypted/ack.rs | 2 +- ql-wire/src/encrypted/mod.rs | 18 ++---- ql-wire/src/encrypted/stream_close.rs | 3 +- ql-wire/src/encrypted_message.rs | 2 +- ql-wire/src/varint.rs | 55 +++++++++++++++++ 12 files changed, 141 insertions(+), 169 deletions(-) diff --git a/ql-fsm/src/session/remote_stream_history.rs b/ql-fsm/src/session/remote_stream_history.rs index 9b15ecce..76c1e8bb 100644 --- a/ql-fsm/src/session/remote_stream_history.rs +++ b/ql-fsm/src/session/remote_stream_history.rs @@ -17,7 +17,8 @@ impl RemoteStreamHistory { } /// returns true when this remote stream id was already observed before - /// panics if stream_id is wrong stream parity + /// panics if `stream_id` is wrong stream parity + #[allow(clippy::range_plus_one)] pub fn observe(&mut self, stream_id: StreamId) -> bool { let ordinal = self .stream_ordinal(stream_id) diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index ee58b593..077f1204 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -47,8 +47,9 @@ impl StreamRx { pub fn buffered_end_offset(&self) -> u64 { self.chunks .last_key_value() - .map(|(&offset, bytes)| offset + bytes.len() as u64) - .unwrap_or(self.start_offset) + .map_or(self.start_offset, |(&offset, bytes)| { + offset + bytes.len() as u64 + }) } pub fn max_buffered(&self) -> usize { @@ -121,8 +122,6 @@ impl StreamRx { let effective_end = effective_offset + bytes.len() as u64; self.ensure_within_window(effective_end)?; - #[cfg(test)] - self.assert_valid_overlap(effective_offset, &bytes); self.insert_chunk(effective_offset, bytes); Ok(self.insert_outcome(was_complete, old_readable)) @@ -233,45 +232,6 @@ impl StreamRx { self.chunks.insert(offset, bytes); } - - #[cfg(test)] - fn assert_valid_overlap(&self, offset: u64, bytes: &Bytes) { - if let Some((&existing_offset, existing)) = self.chunks.range(..offset).next_back() { - self.assert_overlap_chunk(offset, bytes, existing_offset, existing); - } - let end = offset + bytes.len() as u64; - for (&existing_offset, existing) in self.chunks.range(offset..end) { - self.assert_overlap_chunk(offset, bytes, existing_offset, existing); - } - } - - #[cfg(test)] - fn assert_overlap_chunk( - &self, - offset: u64, - bytes: &Bytes, - existing_offset: u64, - existing: &Bytes, - ) { - let end = offset + bytes.len() as u64; - let existing_end = existing_offset + existing.len() as u64; - let overlap_start = offset.max(existing_offset); - let overlap_end = end.min(existing_end); - if overlap_start >= overlap_end { - return; - } - - let start = usize::try_from(overlap_start - offset).expect("overlap start exceeds usize"); - let existing_start = usize::try_from(overlap_start - existing_offset) - .expect("existing overlap start exceeds usize"); - let len = usize::try_from(overlap_end - overlap_start).expect("overlap exceeds usize"); - - assert_eq!( - &bytes[start..start + len], - &existing[existing_start..existing_start + len], - "conflicting overlap at stream offset {overlap_start}" - ); - } } #[derive(Debug, Clone)] @@ -390,15 +350,6 @@ mod tests { assert_eq!(copy_readable(&rx), b"hello"); } - #[test] - #[should_panic(expected = "conflicting overlap at stream offset 3")] - fn conflicting_overlap_panics_in_test_builds() { - let mut rx = StreamRx::new(64); - - rx.insert(0, false, bytes(b"abcdef")).unwrap(); - rx.insert(3, false, bytes(b"xyz")).unwrap(); - } - #[test] fn consume_advances_start_offset_and_trims_old_prefix() { let mut rx = StreamRx::new(64); diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 55b1064b..92102bbc 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -3,8 +3,10 @@ use std::{ time::Duration, }; +extern crate proptest as proptest_crate; + use bytes::Bytes; -use ::proptest::{collection::vec, prelude::*, test_runner::TestCaseResult}; +use proptest_crate::{collection::vec, prelude::*, test_runner::TestCaseResult}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use super::*; @@ -191,6 +193,7 @@ impl Runner { self.assert_quiesced() } + #[allow(clippy::cognitive_complexity, clippy::too_many_lines)] fn apply(&mut self, action: &Action) { match action { Action::ConnectIkA => { @@ -314,7 +317,7 @@ impl Runner { } Action::WriteA { slot, bytes } => { if let Some(stream_id) = self.slots_a[*slot] { - let mut chunk = Bytes::from(bytes.clone()); + let mut chunk = Bytes::copy_from_slice(bytes); if let Ok(accepted) = self.harness.a.fsm.write_stream(stream_id, &mut chunk) { self.expected_at_b .entry(stream_id) @@ -325,7 +328,7 @@ impl Runner { } Action::WriteB { slot, bytes } => { if let Some(stream_id) = self.slots_b[*slot] { - let mut chunk = Bytes::from(bytes.clone()); + let mut chunk = Bytes::copy_from_slice(bytes); if let Ok(accepted) = self.harness.b.fsm.write_stream(stream_id, &mut chunk) { self.expected_at_a .entry(stream_id) @@ -422,7 +425,7 @@ impl Runner { } fn drain_reads_a(&mut self) { - for stream_id in self.known_streams.iter().copied().collect::>() { + for stream_id in self.known_streams.clone() { let appended = drain_stream(&mut self.harness.a.fsm, stream_id); if !appended.is_empty() { self.received_at_a @@ -434,7 +437,7 @@ impl Runner { } fn drain_reads_b(&mut self) { - for stream_id in self.known_streams.iter().copied().collect::>() { + for stream_id in self.known_streams.clone() { let appended = drain_stream(&mut self.harness.b.fsm, stream_id); if !appended.is_empty() { self.received_at_b @@ -528,8 +531,7 @@ impl Runner { let expected = self .expected_at_a .get(stream_id) - .map(Vec::as_slice) - .unwrap_or(&[]); + .map_or(&[][..], Vec::as_slice); prop_assert!( expected.starts_with(received), "side A observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" @@ -540,8 +542,7 @@ impl Runner { let expected = self .expected_at_b .get(stream_id) - .map(Vec::as_slice) - .unwrap_or(&[]); + .map_or(&[][..], Vec::as_slice); prop_assert!( expected.starts_with(received), "side B observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" @@ -587,19 +588,17 @@ impl Runner { fn assert_terminal_semantics(&self) -> TestCaseResult { for stream_id in &self.events_a.finished { - if self.inbound_aborted(Side::A, stream_id) { + if self.inbound_aborted(Side::A, *stream_id) { continue; } let expected = self .expected_at_a .get(stream_id) - .map(Vec::as_slice) - .unwrap_or(&[]); + .map_or(&[][..], Vec::as_slice); let received = self .received_at_a .get(stream_id) - .map(Vec::as_slice) - .unwrap_or(&[]); + .map_or(&[][..], Vec::as_slice); prop_assert_eq!( received, expected, @@ -609,19 +608,17 @@ impl Runner { } for stream_id in &self.events_b.finished { - if self.inbound_aborted(Side::B, stream_id) { + if self.inbound_aborted(Side::B, *stream_id) { continue; } let expected = self .expected_at_b .get(stream_id) - .map(Vec::as_slice) - .unwrap_or(&[]); + .map_or(&[][..], Vec::as_slice); let received = self .received_at_b .get(stream_id) - .map(Vec::as_slice) - .unwrap_or(&[]); + .map_or(&[][..], Vec::as_slice); prop_assert_eq!( received, expected, @@ -776,11 +773,11 @@ impl Runner { } } - fn inbound_aborted(&self, side: Side, stream_id: &StreamId) -> bool { - self.events(side).closed.contains(stream_id) + fn inbound_aborted(&self, side: Side, stream_id: StreamId) -> bool { + self.events(side).closed.contains(&stream_id) || match side { - Side::A => self.closed_by_a.contains(stream_id), - Side::B => self.closed_by_b.contains(stream_id), + Side::A => self.closed_by_a.contains(&stream_id), + Side::B => self.closed_by_b.contains(&stream_id), } } } @@ -819,13 +816,19 @@ fn take_confirmed_outbound_b(harness: &mut Harness) -> Option> { fn confirm_taken_a(harness: &mut Harness, write: &TakenWrite) { if let Some(write_id) = write.write_id { - harness.a.fsm.confirm_session_write(harness.time(), write_id); + harness + .a + .fsm + .confirm_session_write(harness.time(), write_id); } } fn confirm_taken_b(harness: &mut Harness, write: &TakenWrite) { if let Some(write_id) = write.write_id { - harness.b.fsm.confirm_session_write(harness.time(), write_id); + harness + .b + .fsm + .confirm_session_write(harness.time(), write_id); } } @@ -977,7 +980,7 @@ fn connected_action_strategy() -> impl Strategy { queue_index.clone().prop_map(Action::DuplicateQueuedAToB), queue_index.clone().prop_map(Action::DuplicateQueuedBToA), queue_index.clone().prop_map(Action::DropQueuedAToB), - queue_index.clone().prop_map(Action::DropQueuedBToA), + queue_index.prop_map(Action::DropQueuedBToA), slot.clone().prop_map(Action::OpenStreamA), slot.clone().prop_map(Action::OpenStreamB), (slot.clone(), bytes.clone()).prop_map(|(slot, bytes)| Action::WriteA { slot, bytes }), @@ -997,7 +1000,7 @@ fn write_tracking_action_strategy() -> impl Strategy { slot.clone().prop_map(Action::OpenStreamA), slot.clone().prop_map(Action::OpenStreamB), (slot.clone(), bytes.clone()).prop_map(|(slot, bytes)| Action::WriteA { slot, bytes }), - (slot.clone(), bytes).prop_map(|(slot, bytes)| Action::WriteB { slot, bytes }), + (slot, bytes).prop_map(|(slot, bytes)| Action::WriteB { slot, bytes }), Just(Action::TakeNextAToB), Just(Action::TakeNextBToA), queue_index.clone().prop_map(Action::ConfirmTakenAToB), @@ -1009,7 +1012,7 @@ fn write_tracking_action_strategy() -> impl Strategy { queue_index.clone().prop_map(Action::DuplicateQueuedAToB), queue_index.clone().prop_map(Action::DuplicateQueuedBToA), queue_index.clone().prop_map(Action::DropQueuedAToB), - queue_index.clone().prop_map(Action::DropQueuedBToA), + queue_index.prop_map(Action::DropQueuedBToA), Just(Action::Pump), Just(Action::OnTimerA), Just(Action::OnTimerB), @@ -1030,7 +1033,7 @@ fn terminal_action_strategy() -> impl Strategy { slot.clone().prop_map(Action::FinishA), slot.clone().prop_map(Action::FinishB), slot.clone().prop_map(Action::CloseA), - slot.clone().prop_map(Action::CloseB), + slot.prop_map(Action::CloseB), Just(Action::TakeNextAToB), Just(Action::TakeNextBToA), queue_index.clone().prop_map(Action::ConfirmTakenAToB), @@ -1042,7 +1045,7 @@ fn terminal_action_strategy() -> impl Strategy { queue_index.clone().prop_map(Action::DuplicateQueuedAToB), queue_index.clone().prop_map(Action::DuplicateQueuedBToA), queue_index.clone().prop_map(Action::DropQueuedAToB), - queue_index.clone().prop_map(Action::DropQueuedBToA), + queue_index.prop_map(Action::DropQueuedBToA), Just(Action::Pump), Just(Action::OnTimerA), Just(Action::OnTimerB), @@ -1051,7 +1054,7 @@ fn terminal_action_strategy() -> impl Strategy { ] } -proptest! { +proptest_crate::proptest! { #![proptest_config(ProptestConfig { cases: 24, max_shrink_iters: 10_000, diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index f9c281d9..15ebe956 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -40,7 +40,10 @@ fn connected_fsms_deliver_stream_data() { let mut harness = Harness::connected(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), 5); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), + 5 + ); harness.a.fsm.finish_stream(stream_id).unwrap(); harness.pump(); @@ -66,7 +69,10 @@ fn session_retransmit_uses_new_record_seq() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); let first = harness.next_outbound_a().unwrap(); let first_transport = harness.b.fsm.state.link.transport().unwrap().clone(); @@ -201,7 +207,10 @@ fn returned_session_write_is_reissued_with_new_record_seq() { let mut harness = Harness::connected(QlFsmConfig::default()); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); @@ -241,7 +250,10 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); let write = harness.next_write_a().unwrap(); let id = write.session_write_id.expect("expected session write"); @@ -274,8 +286,14 @@ fn ack_frame_releases_stream_capacity_and_emits_writable() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"abcd").unwrap(), 4); - assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"z").unwrap(), 0); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"abcd").unwrap(), + 4 + ); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"z").unwrap(), + 0 + ); let record = harness.next_outbound_a().unwrap(); harness.deliver_to_b(record); @@ -309,7 +327,10 @@ fn session_records_contain_ack_frames_after_delivery() { let mut harness = Harness::connected(config); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"x").unwrap(), 1); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"x").unwrap(), + 1 + ); let data = harness.next_outbound_a().unwrap(); harness.deliver_to_b(data); @@ -345,7 +366,10 @@ fn first_stream_data_uses_negotiated_initial_peer_credit() { harness.deliver_to_a(ik2); let stream_id = harness.a.fsm.open_stream().unwrap(); - assert_eq!(write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), 5); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), + 5 + ); let data = harness.next_outbound_a().unwrap(); let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 65bfe203..d8a25ae7 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -520,8 +520,7 @@ impl DriverState { } else { let len = bytes.len(); let mut bytes = ql_fsm::Bytes::copy_from_slice(bytes); - let accepted = - fsm.write_stream(stream_id, &mut bytes).unwrap_or_default(); + let accepted = fsm.write_stream(stream_id, &mut bytes).unwrap_or_default(); if accepted > 0 { reader.consume(accepted); } diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index 7a6f4e25..c8243e12 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -116,7 +116,7 @@ impl ByteChunks for Bytes { Self: 'a; fn len(&self) -> usize { - Bytes::len(self) + Self::len(self) } fn chunks(&self) -> Self::Chunks<'_> { diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index 7062c53a..27e06819 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -1,6 +1,6 @@ -use ::bytes::BufMut; +use bytes::BufMut; -use crate::{ByteSlice, VarInt, WireError}; +use crate::{ByteSlice, WireError}; pub trait WireEncode { fn encoded_len(&self) -> usize; @@ -56,7 +56,7 @@ impl WireEncode for [u8; N] { impl WireDecode for Box<[u8; N]> { fn decode(reader: &mut Reader) -> Result { let bytes = reader.take_bytes(N)?; - let mut out = Box::<[u8; N]>::new_uninit(); + let mut out = Self::new_uninit(); let src = bytes.as_ptr(); let dst = out.as_mut_ptr().cast::(); // SAFETY: `take_bytes(N)` guarantees the source has exactly `N` bytes. @@ -105,7 +105,7 @@ impl WireEncode for u8 { impl WireDecode for u16 { fn decode(reader: &mut Reader) -> Result { - Ok(u16::from_be_bytes(reader.decode()?)) + Ok(Self::from_be_bytes(reader.decode()?)) } } @@ -121,7 +121,7 @@ impl WireEncode for u16 { impl WireDecode for u32 { fn decode(reader: &mut Reader) -> Result { - Ok(u32::from_be_bytes(reader.decode()?)) + Ok(Self::from_be_bytes(reader.decode()?)) } } @@ -137,7 +137,7 @@ impl WireEncode for u32 { impl WireDecode for u64 { fn decode(reader: &mut Reader) -> Result { - Ok(u64::from_be_bytes(reader.decode()?)) + Ok(Self::from_be_bytes(reader.decode()?)) } } @@ -151,56 +151,6 @@ impl WireEncode for u64 { } } -impl WireDecode for VarInt { - fn decode(reader: &mut Reader) -> Result { - let first = reader.decode::()?; - let tag = first >> 6; - let first = first & 0b0011_1111; - let value = match tag { - 0b00 => u64::from(first), - 0b01 => { - let mut buf = [0; 2]; - buf[0] = first; - buf[1] = reader.decode()?; - u64::from(u16::from_be_bytes(buf)) - } - 0b10 => { - let mut buf = [0; 4]; - buf[0] = first; - buf[1..].copy_from_slice(&reader.decode::<[u8; 3]>()?); - u64::from(u32::from_be_bytes(buf)) - } - 0b11 => { - let mut buf = [0; 8]; - buf[0] = first; - buf[1..].copy_from_slice(&reader.decode::<[u8; 7]>()?); - u64::from_be_bytes(buf) - } - _ => unreachable!(), - }; - - // SAFETY: the decoded value is guaranteed to fit in the 62-bit varint range. - Ok(unsafe { VarInt::from_u64_unchecked(value) }) - } -} - -impl WireEncode for VarInt { - fn encoded_len(&self) -> usize { - self.size() - } - - fn encode(&self, out: &mut W) { - let x = self.into_inner(); - match self.size() { - 1 => out.put_u8(x as u8), - 2 => out.put_u16((0b01 << 14) | x as u16), - 4 => out.put_u32((0b10 << 30) | x as u32), - 8 => out.put_u64((0b11 << 62) | x), - _ => unreachable!("malformed varint"), - } - } -} - impl WireDecode for bool { fn decode(reader: &mut Reader) -> Result { match reader.decode::()? { diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 0e7600e1..adceb0d7 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -46,7 +46,7 @@ impl codec::WireDecode for RecordAck { #[cfg(test)] mod tests { use super::RecordAck; - use crate::{RecordSeq, WireEncode, WireError, WireDecode}; + use crate::{RecordSeq, WireDecode, WireEncode, WireError}; #[test] fn encode_decode_round_trip() { diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index e458986e..b3f7f2db 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -2,7 +2,6 @@ use crate::{ codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, SessionHeader, SessionKey, VarInt, VarIntBoundsExceeded, WireDecode, WireEncode, WireError, }; -use bytes::Bytes; mod ack; mod builder; @@ -57,7 +56,7 @@ impl codec::WireDecode for StreamId { #[derive(Debug, Clone, PartialEq, Eq)] pub struct SessionRecord { - pub frames: Vec, + pub frames: Vec>>, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -70,11 +69,6 @@ pub enum SessionFrame { Close(SessionClose), } -pub type SessionFrameVec = SessionFrame>; -pub type StreamDataVec = StreamData>; -pub type SessionFrameBytes = SessionFrame; -pub type StreamDataBytes = StreamData; - #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub(crate) enum SessionFrameKind { @@ -142,7 +136,7 @@ impl SessionFrame { } impl SessionFrame { - pub fn into_owned(self) -> SessionFrameVec { + pub fn into_owned(self) -> SessionFrame> { match self { Self::Ping => SessionFrame::Ping, Self::Ack(frame) => SessionFrame::Ack(frame), @@ -221,9 +215,7 @@ impl Iterator for SessionFrameIter { self.remaining = Some(rest); Some(Ok(frame)) } - Err(error) => { - Some(Err(error)) - } + Err(error) => Some(Err(error)), } } } @@ -254,9 +246,7 @@ fn parse_next_frame(bytes: B) -> Result<(SessionFrame, B), Wire SessionFrameKind::StreamWindow => { SessionFrame::StreamWindow(reader.decode::()?) } - SessionFrameKind::StreamClose => { - SessionFrame::StreamClose(reader.decode::()?) - } + SessionFrameKind::StreamClose => SessionFrame::StreamClose(reader.decode::()?), SessionFrameKind::Close => SessionFrame::Close(reader.decode::()?), }; Ok((frame, reader.take_rest())) diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 5c4a06a3..2fc03eb9 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -13,8 +13,7 @@ pub struct StreamClose { pub code: StreamCloseCode, } -impl StreamClose { -} +impl StreamClose {} impl WireEncode for StreamClose { fn encoded_len(&self) -> usize { diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs index 293b5773..9e11d3d0 100644 --- a/ql-wire/src/encrypted_message.rs +++ b/ql-wire/src/encrypted_message.rs @@ -1,5 +1,5 @@ use crate::{ - codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireEncode, WireError, WireDecode, + codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, }; diff --git a/ql-wire/src/varint.rs b/ql-wire/src/varint.rs index 0fc06d03..7a39bd16 100644 --- a/ql-wire/src/varint.rs +++ b/ql-wire/src/varint.rs @@ -1,5 +1,9 @@ use core::fmt; +use bytes::BufMut; + +use crate::{ByteSlice, Reader, WireDecode, WireEncode, WireError}; + /// An integer less than 2^62 encoded with QUIC variable-length integer rules. #[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct VarInt(pub(crate) u64); @@ -124,3 +128,54 @@ impl fmt::Display for VarIntBoundsExceeded { } impl std::error::Error for VarIntBoundsExceeded {} + +impl WireDecode for VarInt { + fn decode(reader: &mut Reader) -> Result { + let first = reader.decode::()?; + let tag = first >> 6; + let first = first & 0b0011_1111; + let value = match tag { + 0b00 => u64::from(first), + 0b01 => { + let mut buf = [0; 2]; + buf[0] = first; + buf[1] = reader.decode()?; + u64::from(u16::from_be_bytes(buf)) + } + 0b10 => { + let mut buf = [0; 4]; + buf[0] = first; + buf[1..].copy_from_slice(&reader.take_bytes(3)?); + u64::from(u32::from_be_bytes(buf)) + } + 0b11 => { + let mut buf = [0; 8]; + buf[0] = first; + buf[1..].copy_from_slice(&reader.take_bytes(7)?); + u64::from_be_bytes(buf) + } + _ => unreachable!(), + }; + + // SAFETY: the decoded value is guaranteed to fit in the 62-bit varint range. + Ok(unsafe { Self::from_u64_unchecked(value) }) + } +} + +impl WireEncode for VarInt { + fn encoded_len(&self) -> usize { + self.size() + } + + #[allow(clippy::cast_possible_truncation)] + fn encode(&self, out: &mut W) { + let x = self.into_inner(); + match self.size() { + 1 => out.put_u8(x as u8), + 2 => out.put_u16((0b01 << 14) | x as u16), + 4 => out.put_u32((0b10 << 30) | x as u32), + 8 => out.put_u64((0b11 << 62) | x), + _ => unreachable!("malformed varint"), + } + } +} From 3d6e754a445497cb3a9625757c55f65c348e9bf6 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 19:36:12 -0400 Subject: [PATCH 129/304] ql: remove SessionRecord struct + get rid of useless helpers on StreamId --- ql-fsm/src/implementation/core.rs | 2 +- ql-fsm/src/session/stream_parity.rs | 4 +- ql-fsm/src/session/tests.rs | 175 ++++++++++------------- ql-fsm/src/tests/mod.rs | 4 +- ql-fsm/src/tests/session.rs | 14 +- ql-wire/src/encrypted/builder.rs | 2 +- ql-wire/src/encrypted/mod.rs | 210 ++++++++++------------------ ql-wire/src/encrypted/stream_id.rs | 29 ++++ ql-wire/src/tests.rs | 94 ++++++------- 9 files changed, 234 insertions(+), 300 deletions(-) create mode 100644 ql-wire/src/encrypted/stream_id.rs diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 33decd1d..b7347553 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -52,7 +52,7 @@ pub fn receive( let len = bytes.len(); let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); - let frames = wire::SessionRecord::parse(plaintext)?; + let frames = wire::parse_session_frames(plaintext); let mut session_closed = false; state diff --git a/ql-fsm/src/session/stream_parity.rs b/ql-fsm/src/session/stream_parity.rs index 1fb498e0..87c9ef33 100644 --- a/ql-fsm/src/session/stream_parity.rs +++ b/ql-fsm/src/session/stream_parity.rs @@ -36,9 +36,9 @@ impl StreamParity { } pub fn make_stream_id(self, ordinal: u32) -> StreamId { - StreamId::from_u32( + StreamId(ql_wire::VarInt::from_u32( self.first_stream_id() .saturating_add(ordinal.saturating_mul(2)), - ) + )) } } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index e06b4837..ed6636f6 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -2,8 +2,8 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use ql_wire::{ - CloseTarget, RecordAck, RecordSeq, SessionFrame, SessionRecord, SessionRecordBuilder, - StreamClose, StreamCloseCode, StreamData, StreamId, VarInt, XID, + decode_session_frames, parse_session_frames, CloseTarget, RecordAck, RecordSeq, SessionFrame, + SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamId, VarInt, XID, }; use super::{SessionEvent, SessionFsm, SessionFsmConfig}; @@ -14,7 +14,7 @@ fn seq(value: u64) -> RecordSeq { } fn stream_id(value: u64) -> StreamId { - StreamId::from_u64(value).unwrap() + StreamId(VarInt::from_u64(value).unwrap()) } fn offset(value: u64) -> VarInt { @@ -42,14 +42,17 @@ fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { out } -fn next_outbound(fsm: &mut SessionFsm, now: Instant) -> Option<(RecordSeq, SessionRecord)> { +fn next_outbound( + fsm: &mut SessionFsm, + now: Instant, +) -> Option<(RecordSeq, Vec>>)> { let (write_id, builder) = fsm.take_next_write(now)?; if let Some(write_id) = write_id { fsm.confirm_write(now, write_id); } Some(( builder.seq(), - SessionRecord::decode(builder.bytes()).unwrap(), + decode_session_frames(builder.bytes()).unwrap(), )) } @@ -57,14 +60,14 @@ fn receive_events( fsm: &mut SessionFsm, now: Instant, seq: RecordSeq, - record: &SessionRecord, + record: &[SessionFrame>], ) -> Vec { let mut builder = SessionRecordBuilder::new(seq, usize::MAX); - for frame in &record.frames { + for frame in record { assert!(builder.push_frame(frame)); } let bytes = Bytes::from(builder.bytes().to_vec()); - let frames = SessionRecord::parse(bytes).unwrap(); + let frames = parse_session_frames(bytes); let mut events = Vec::new(); fsm.receive(now, seq, frames, |event| events.push(event)); events @@ -99,7 +102,7 @@ fn retransmit_uses_new_record_seq() { let (retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); assert_ne!(first_seq, retried_seq); - assert_eq!(first.frames, retried.frames); + assert_eq!(first, retried); } #[test] @@ -123,7 +126,7 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); let (second_seq, _second) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); assert_ne!(first_seq, second_seq); - assert!(first.frames.iter().any( + assert!(first.iter().any( |frame| matches!(frame, SessionFrame::StreamData(frame) if frame.stream_id == stream_id_a) )); @@ -131,7 +134,6 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let (_third_seq, third) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); let stream_ids: Vec<_> = third - .frames .iter() .filter_map(|frame| match frame { SessionFrame::StreamData(frame) => Some(frame.stream_id), @@ -183,14 +185,12 @@ fn commit_stream_read_is_what_advances_stream_window() { now, ); let stream_id = stream_id(1); - let data = SessionRecord { - frames: vec![SessionFrame::StreamData(StreamData { - stream_id, - offset: offset(0), - fin: false, - bytes: b"hi".to_vec(), - })], - }; + let data = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + fin: false, + bytes: b"hi".to_vec(), + })]; let events = receive_events(&mut fsm, now, seq(7), &data); assert_eq!( events, @@ -201,9 +201,9 @@ fn commit_stream_read_is_what_advances_stream_window() { ); let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); - let first = SessionRecord::decode(builder.bytes()).unwrap(); + let first = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); - assert!(matches!(first.frames.as_slice(), [SessionFrame::Ack(_)])); + assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); let read = fsm .stream_read(stream_id) @@ -217,7 +217,7 @@ fn commit_stream_read_is_what_advances_stream_window() { fsm.stream_read_commit(stream_id, 2).unwrap(); let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); assert!(matches!( - second.frames.as_slice(), + second.as_slice(), [SessionFrame::StreamWindow(window)] if window.stream_id == stream_id )); } @@ -232,21 +232,19 @@ fn pure_ack_only_records_are_fire_and_forget() { let retransmit_timeout = config.retransmit_timeout; let mut fsm = SessionFsm::new(config, now); let stream_id = stream_id(1); - let record = SessionRecord { - frames: vec![SessionFrame::StreamData(StreamData { - stream_id, - offset: offset(0), - fin: false, - bytes: b"hi".to_vec(), - })], - }; + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + fin: false, + bytes: b"hi".to_vec(), + })]; let _ = receive_events(&mut fsm, now, seq(7), &record); let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); - let ack = SessionRecord::decode(builder.bytes()).unwrap(); + let ack = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); - assert!(matches!(ack.frames.as_slice(), [SessionFrame::Ack(_)])); + assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); fsm.on_timer(now + retransmit_timeout + Duration::from_millis(1), |_| {}); assert!(fsm @@ -259,14 +257,12 @@ fn inbound_stream_data_emits_opened_and_readable() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = stream_id(1); - let record = SessionRecord { - frames: vec![SessionFrame::StreamData(ql_wire::StreamData { - stream_id, - offset: offset(0), - fin: true, - bytes: b"hello".to_vec(), - })], - }; + let record = vec![SessionFrame::StreamData(ql_wire::StreamData { + stream_id, + offset: offset(0), + fin: true, + bytes: b"hello".to_vec(), + })]; let events = receive_events(&mut fsm, now, seq(0), &record); assert_eq!( @@ -291,16 +287,16 @@ fn remote_stream_close_is_reliable_and_retried() { let (write_id, builder) = fsm.take_next_write(now).unwrap(); fsm.confirm_write(now, write_id.expect("stream close should be tracked")); - let first = SessionRecord::decode(builder.bytes()).unwrap(); + let first = decode_session_frames(builder.bytes()).unwrap(); assert!(matches!( - first.frames.as_slice(), + first.as_slice(), [SessionFrame::StreamClose(StreamClose { stream_id: id, .. })] if *id == stream_id )); fsm.on_timer(now + Duration::from_millis(200), |_| {}); let (_retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); - assert_eq!(first.frames, retried.frames); + assert_eq!(first, retried); } #[test] @@ -337,14 +333,12 @@ fn duplicate_stream_data_is_not_redelivered() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = stream_id(1); - let record = SessionRecord { - frames: vec![SessionFrame::StreamData(StreamData { - stream_id, - offset: offset(0), - fin: false, - bytes: b"hi".to_vec(), - })], - }; + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + fin: false, + bytes: b"hi".to_vec(), + })]; let _ = receive_events(&mut fsm, now, seq(1), &record); let _ = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); @@ -360,9 +354,7 @@ fn duplicate_remote_close_after_reap_is_ignored() { target: CloseTarget::Both, code: StreamCloseCode(9), }; - let record = SessionRecord { - frames: vec![SessionFrame::StreamClose(close.clone())], - }; + let record = vec![SessionFrame::StreamClose(close.clone())]; let first = receive_events(&mut fsm, now, seq(1), &record); assert_eq!( @@ -383,14 +375,12 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); let stream_id = stream_id(1); - let record = SessionRecord { - frames: vec![SessionFrame::StreamData(StreamData { - stream_id, - offset: offset(0), - fin: true, - bytes: b"hello".to_vec(), - })], - }; + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + fin: true, + bytes: b"hello".to_vec(), + })]; let first = receive_events(&mut fsm, now, seq(1), &record); assert_eq!( @@ -411,20 +401,16 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { fn out_of_order_remote_stream_first_observations_still_open_once_each() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let close3 = SessionRecord { - frames: vec![SessionFrame::StreamClose(StreamClose { - stream_id: stream_id(3), - target: CloseTarget::Both, - code: StreamCloseCode(1), - })], - }; - let close1 = SessionRecord { - frames: vec![SessionFrame::StreamClose(StreamClose { - stream_id: stream_id(1), - target: CloseTarget::Both, - code: StreamCloseCode(2), - })], - }; + let close3 = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: StreamCloseCode(1), + })]; + let close1 = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: StreamCloseCode(2), + })]; let first = receive_events(&mut fsm, now, seq(1), &close3); assert!(first.contains(&SessionEvent::Opened(stream_id(3)))); @@ -447,14 +433,12 @@ fn close_does_not_ack_rejected_record_seq() { now, ); - let invalid = SessionRecord { - frames: vec![SessionFrame::StreamData(StreamData { - stream_id: stream_id(0), - offset: offset(0), - fin: false, - bytes: b"bad".to_vec(), - })], - }; + let invalid = vec![SessionFrame::StreamData(StreamData { + stream_id: stream_id(0), + offset: offset(0), + fin: false, + bytes: b"bad".to_vec(), + })]; let events = receive_events(&mut fsm, now, seq(7), &invalid); assert_eq!( events, @@ -463,9 +447,7 @@ fn close_does_not_ack_rejected_record_seq() { })] ); - let valid_after_close = SessionRecord { - frames: vec![SessionFrame::Ping], - }; + let valid_after_close = vec![SessionFrame::Ping]; let events = receive_events( &mut fsm, now + Duration::from_millis(1), @@ -475,10 +457,7 @@ fn close_does_not_ack_rejected_record_seq() { assert!(events.is_empty()); let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); - assert!(matches!( - outbound.frames.as_slice(), - [SessionFrame::Close(_)] - )); + assert!(matches!(outbound.as_slice(), [SessionFrame::Close(_)])); } #[test] @@ -496,7 +475,7 @@ fn initial_peer_stream_receive_window_limits_first_send() { assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"hello"), 5); let (_first_seq, first) = next_outbound(&mut fsm, now).unwrap(); assert!(matches!( - first.frames.as_slice(), + first.as_slice(), [SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" )); @@ -504,17 +483,15 @@ fn initial_peer_stream_receive_window_limits_first_send() { &mut fsm, now + Duration::from_millis(1), seq(9), - &SessionRecord { - frames: vec![SessionFrame::StreamWindow(ql_wire::StreamWindow { - stream_id, - maximum_offset: offset(5), - })], - }, + &[SessionFrame::StreamWindow(ql_wire::StreamWindow { + stream_id, + maximum_offset: offset(5), + })], ); assert!(events.is_empty()); let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); - assert!(second.frames.iter().any(|frame| { + assert!(second.iter().any(|frame| { matches!( frame, SessionFrame::StreamData(frame) diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 40fbaf2a..6eb55ba7 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -442,7 +442,7 @@ fn decrypt_record( crypto: &impl QlCrypto, record: &[u8], session_key: &SessionKey, -) -> (ql_wire::SessionHeader, ql_wire::SessionRecord) { +) -> (ql_wire::SessionHeader, Vec>>) { let (_header, record) = ql_wire::decode_record::, _>(record).unwrap(); let plaintext = ql_wire::decrypt_record( @@ -454,7 +454,7 @@ fn decrypt_record( .unwrap(); ( record.header, - ql_wire::SessionRecord::decode(&plaintext).unwrap(), + ql_wire::decode_session_frames(&plaintext).unwrap(), ) } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 15ebe956..e221a78a 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -1,13 +1,13 @@ use std::time::Duration; use bytes::Bytes; -use ql_wire::{SessionClose, StreamId}; +use ql_wire::{SessionClose, StreamId, VarInt}; use super::*; use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent}; fn stream_id(value: u32) -> StreamId { - StreamId::from_u32(value) + StreamId(VarInt::from_u32(value)) } fn write_stream_bytes( @@ -87,7 +87,7 @@ fn session_retransmit_uses_new_record_seq() { decrypt_record(&harness.b.crypto, &retried, &first_transport.rx_key); assert_ne!(retried_header.seq, first_header.seq); - assert_eq!(retried_record.frames, first_record.frames); + assert_eq!(retried_record, first_record); harness.deliver_to_b(retried); harness.advance(config.session_record_ack_delay); @@ -227,7 +227,7 @@ fn returned_session_write_is_reissued_with_new_record_seq() { assert_ne!(reissued_id, id); assert_ne!(reissued_header.seq, first_header.seq); - assert_eq!(reissued.frames, first.frames); + assert_eq!(reissued, first); harness.confirm_write_a(reissued_id); harness.deliver_to_b(record); @@ -274,7 +274,7 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let (retried_header, retried) = decrypt_record(&harness.b.crypto, &record, &session_key); assert_ne!(retried_header.seq, first_header.seq); - assert_eq!(retried.frames, first.frames); + assert_eq!(retried, first); } #[test] @@ -341,7 +341,7 @@ fn session_records_contain_ack_frames_after_delivery() { let session_key = harness.a.fsm.state.link.transport().unwrap().rx_key.clone(); let (_ack_header, ack_record) = decrypt_record(&harness.a.crypto, &ack, &session_key); assert!(matches!( - ack_record.frames.as_slice(), + ack_record.as_slice(), [ql_wire::SessionFrame::Ack(_)] )); } @@ -376,7 +376,7 @@ fn first_stream_data_uses_negotiated_initial_peer_credit() { let (_header, record) = decrypt_record(&harness.b.crypto, &data, &session_key); assert!(matches!( - record.frames.as_slice(), + record.as_slice(), [ql_wire::SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" )); } diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index bca27a31..3d926ece 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,4 +1,4 @@ -use ::bytes::BufMut; +use bytes::BufMut; use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; use crate::{ diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index b3f7f2db..dd9ccd3b 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,6 +1,6 @@ use crate::{ - codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, - SessionHeader, SessionKey, VarInt, VarIntBoundsExceeded, WireDecode, WireEncode, WireError, + codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, Reader, + SessionHeader, SessionKey, VarInt, WireDecode, WireEncode, WireError, }; mod ack; @@ -8,6 +8,7 @@ mod builder; mod close; mod stream_close; mod stream_data; +mod stream_id; mod stream_window; pub use ack::*; @@ -15,50 +16,9 @@ pub use builder::*; pub use close::*; pub use stream_close::*; pub use stream_data::*; +pub use stream_id::*; pub use stream_window::*; -// todo: should use even/odd based on xid ordering -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[repr(transparent)] -pub struct StreamId(pub VarInt); - -impl StreamId { - pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; - - pub const fn from_u32(value: u32) -> Self { - Self(VarInt::from_u32(value)) - } - - pub fn from_u64(value: u64) -> Result { - Ok(Self(VarInt::from_u64(value)?)) - } - - pub const fn into_inner(self) -> u64 { - self.0.into_inner() - } -} - -impl WireEncode for StreamId { - fn encoded_len(&self) -> usize { - self.0.size() - } - - fn encode(&self, out: &mut W) { - self.0.encode(out); - } -} - -impl codec::WireDecode for StreamId { - fn decode(reader: &mut codec::Reader) -> Result { - Ok(Self(reader.decode()?)) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct SessionRecord { - pub frames: Vec>>, -} - #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionFrame { Ping, @@ -69,56 +29,27 @@ pub enum SessionFrame { Close(SessionClose), } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub(crate) enum SessionFrameKind { - Ping = 1, - Ack = 2, - StreamData = 3, - StreamWindow = 4, - StreamClose = 5, - Close = 6, -} - -pub struct SessionFrameIter { - remaining: Option, -} - -impl TryFrom for SessionFrameKind { - type Error = WireError; - - fn try_from(value: u8) -> Result { - match value { - 1 => Ok(Self::Ping), - 2 => Ok(Self::Ack), - 3 => Ok(Self::StreamData), - 4 => Ok(Self::StreamWindow), - 5 => Ok(Self::StreamClose), - 6 => Ok(Self::Close), - _ => Err(WireError::InvalidPayload), - } - } -} - -impl codec::WireDecode for SessionFrameKind { - fn decode(reader: &mut codec::Reader) -> Result { - reader.decode::()?.try_into() - } -} - -impl SessionRecord { - pub fn parse(bytes: B) -> Result, WireError> { - Ok(SessionFrameIter { - remaining: Some(bytes), - }) - } - - pub fn decode(bytes: &[u8]) -> Result { - let frames = Self::parse(bytes)?; - let frames = frames - .map(|frame| frame.map(SessionFrame::into_owned)) - .collect::, _>>()?; - Ok(Self { frames }) +impl WireDecode for SessionFrame { + fn decode(reader: &mut Reader) -> Result { + let kind = reader.decode::()?; + let frame = match kind { + SessionFrameKind::Ping => SessionFrame::Ping, + SessionFrameKind::Ack => SessionFrame::Ack(reader.decode::()?), + SessionFrameKind::StreamData => { + let len = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + let frame = reader.take_bytes(len)?; + SessionFrame::StreamData(StreamData::decode_exact(frame)?) + } + SessionFrameKind::StreamWindow => { + SessionFrame::StreamWindow(reader.decode::()?) + } + SessionFrameKind::StreamClose => { + SessionFrame::StreamClose(reader.decode::()?) + } + SessionFrameKind::Close => SessionFrame::Close(reader.decode::()?), + }; + Ok(frame) } } @@ -185,37 +116,63 @@ impl WireEncode for SessionFrame { } } -impl WireEncode for SessionRecord { - fn encoded_len(&self) -> usize { - self.frames - .iter() - .map(WireEncode::encoded_len) - .sum::() - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub(crate) enum SessionFrameKind { + Ping = 1, + Ack = 2, + StreamData = 3, + StreamWindow = 4, + StreamClose = 5, + Close = 6, +} - fn encode(&self, out: &mut W) { - for frame in &self.frames { - frame.encode(out); +impl TryFrom for SessionFrameKind { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Ping), + 2 => Ok(Self::Ack), + 3 => Ok(Self::StreamData), + 4 => Ok(Self::StreamWindow), + 5 => Ok(Self::StreamClose), + 6 => Ok(Self::Close), + _ => Err(WireError::InvalidPayload), } } } +impl codec::WireDecode for SessionFrameKind { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +pub fn parse_session_frames(bytes: B) -> SessionFrameIter { + SessionFrameIter { + reader: Reader::new(bytes), + } +} + +pub fn decode_session_frames(bytes: &[u8]) -> Result>>, WireError> { + parse_session_frames(bytes) + .map(|frame| frame.map(SessionFrame::into_owned)) + .collect() +} + +pub struct SessionFrameIter { + reader: Reader, +} + impl Iterator for SessionFrameIter { type Item = Result, WireError>; fn next(&mut self) -> Option { - let remaining = self.remaining.take()?; - if remaining.is_empty() { - return None; - } - - let parsed = parse_next_frame(remaining); - match parsed { - Ok((frame, rest)) => { - self.remaining = Some(rest); - Some(Ok(frame)) - } - Err(error) => Some(Err(error)), + if self.reader.is_empty() { + None + } else { + Some(self.reader.decode::>()) } } } @@ -230,24 +187,3 @@ pub fn decrypt_record>( let nonce = Nonce::from_counter(header.seq.into_inner()); encrypted.decrypt_in_place(crypto, session_key, &nonce, &aad) } - -fn parse_next_frame(bytes: B) -> Result<(SessionFrame, B), WireError> { - let mut reader = codec::Reader::new(bytes); - let kind = reader.decode::()?; - let frame = match kind { - SessionFrameKind::Ping => SessionFrame::Ping, - SessionFrameKind::Ack => SessionFrame::Ack(reader.decode::()?), - SessionFrameKind::StreamData => { - let len = usize::try_from(reader.decode::()?.into_inner()) - .map_err(|_| WireError::InvalidPayload)?; - let frame = reader.take_bytes(len)?; - SessionFrame::StreamData(StreamData::decode_exact(frame)?) - } - SessionFrameKind::StreamWindow => { - SessionFrame::StreamWindow(reader.decode::()?) - } - SessionFrameKind::StreamClose => SessionFrame::StreamClose(reader.decode::()?), - SessionFrameKind::Close => SessionFrame::Close(reader.decode::()?), - }; - Ok((frame, reader.take_rest())) -} diff --git a/ql-wire/src/encrypted/stream_id.rs b/ql-wire/src/encrypted/stream_id.rs new file mode 100644 index 00000000..fdbf564d --- /dev/null +++ b/ql-wire/src/encrypted/stream_id.rs @@ -0,0 +1,29 @@ +use crate::{ByteSlice, Reader, VarInt, WireDecode, WireEncode, WireError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct StreamId(pub VarInt); + +impl StreamId { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } +} + +impl WireEncode for StreamId { + fn encoded_len(&self) -> usize { + self.0.size() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl WireDecode for StreamId { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index ad868b71..5d1a3fd1 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -200,10 +200,10 @@ fn encrypt_record( crypto: &impl QlCrypto, header: SessionHeader, session_key: &SessionKey, - body: &SessionRecord, + body: &[SessionFrame>], ) -> QlSessionRecord> { let mut builder = SessionRecordBuilder::new(header.seq, usize::MAX); - for frame in &body.frames { + for frame in body { let pushed = builder.push_frame(frame); debug_assert!(pushed); } @@ -652,38 +652,36 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { connection_id: ConnectionId::from_data([0x44; ConnectionId::SIZE]), seq: record_seq(11), }; - let body = SessionRecord { - frames: vec![ - SessionFrame::Ping, - SessionFrame::Ack(RecordAck { - base_seq: record_seq(12), - bits: (1u64 << 0) - | (1u64 << 1) - | (1u64 << 8) - | (1u64 << 9) - | (1u64 << 10) - | (1u64 << 11), - }), - SessionFrame::StreamWindow(StreamWindow { - stream_id: stream_id(9), - maximum_offset: varint(65_536), - }), - SessionFrame::StreamData(StreamData { - stream_id: stream_id(9), - offset: varint(1024), - bytes: b"hello".to_vec(), - fin: true, - }), - SessionFrame::StreamClose(StreamClose { - stream_id: stream_id(9), - target: CloseTarget::Both, - code: StreamCloseCode(0), - }), - SessionFrame::Close(SessionClose { - code: SessionCloseCode::TIMEOUT, - }), - ], - }; + let body = vec![ + SessionFrame::Ping, + SessionFrame::Ack(RecordAck { + base_seq: record_seq(12), + bits: (1u64 << 0) + | (1u64 << 1) + | (1u64 << 8) + | (1u64 << 9) + | (1u64 << 10) + | (1u64 << 11), + }), + SessionFrame::StreamWindow(StreamWindow { + stream_id: stream_id(9), + maximum_offset: varint(65_536), + }), + SessionFrame::StreamData(StreamData { + stream_id: stream_id(9), + offset: varint(1024), + bytes: b"hello".to_vec(), + fin: true, + }), + SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(9), + target: CloseTarget::Both, + code: StreamCloseCode(0), + }), + SessionFrame::Close(SessionClose { + code: SessionCloseCode::TIMEOUT, + }), + ]; let session_key = SessionKey::from_data([7; SessionKey::SIZE]); let record = encrypt_record(&crypto, header, &session_key, &body); @@ -701,7 +699,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let decrypted = encrypted::decrypt_record(&crypto, &header, encrypted.clone(), &session_key).unwrap(); - assert_eq!(SessionRecord::decode(&decrypted).unwrap(), body); + assert_eq!(decode_session_frames(&decrypted).unwrap(), body); let wrong_header = SessionHeader { connection_id: ConnectionId::from_data([0x99; ConnectionId::SIZE]), @@ -810,9 +808,7 @@ fn protocol_record_size_breakdown() { seq: record_seq(1), }, &session.tx_key, - &SessionRecord { - frames: vec![SessionFrame::Ping], - }, + &[SessionFrame::Ping], ); let session_stream_empty = encrypt_record( &crypto, @@ -821,14 +817,12 @@ fn protocol_record_size_breakdown() { seq: record_seq(2), }, &session.tx_key, - &SessionRecord { - frames: vec![SessionFrame::StreamData(StreamData { - stream_id: stream_id(1), - offset: varint(0), - fin: false, - bytes: Vec::new(), - })], - }, + &[SessionFrame::StreamData(StreamData { + stream_id: stream_id(1), + offset: varint(0), + fin: false, + bytes: Vec::new(), + })], ); let session_close = encrypt_record( &crypto, @@ -837,11 +831,9 @@ fn protocol_record_size_breakdown() { seq: record_seq(3), }, &session.tx_key, - &SessionRecord { - frames: vec![SessionFrame::Close(SessionClose { - code: SessionCloseCode::PROTOCOL, - })], - }, + &[SessionFrame::Close(SessionClose { + code: SessionCloseCode::PROTOCOL, + })], ); print_size("ql-wire peer bundle", initiator.bundle().encode_vec().len()); From 12b8333608bb7cb12682a98528a80b6c040519e9 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 20:39:56 -0400 Subject: [PATCH 130/304] ql-fsm: return bytes views from stream read --- ql-fsm/src/lib.rs | 2 +- ql-fsm/src/session/stream_rx.rs | 7 +++---- ql-runtime/src/driver/mod.rs | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 47015351..430e38df 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -263,7 +263,7 @@ impl QlFsm { implementation::write_stream(self, stream_id, bytes) } - /// returns the readable stream bytes as borrowed chunks without consuming them + /// returns the readable stream bytes as owned `Bytes` views without consuming them pub fn stream_read(&self, stream_id: StreamId) -> Option> { implementation::stream_read(self, stream_id) } diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index 077f1204..1367d9a2 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -242,7 +242,7 @@ pub struct StreamReadIter<'a> { } impl<'a> Iterator for StreamReadIter<'a> { - type Item = &'a [u8]; + type Item = Bytes; fn next(&mut self) -> Option { while self.remaining > 0 { @@ -258,11 +258,10 @@ impl<'a> Iterator for StreamReadIter<'a> { continue; } - let chunk = &bytes[skip..]; - let len = chunk.len().min(self.remaining); + let len = (bytes.len() - skip).min(self.remaining); self.remaining -= len; self.cursor += len as u64; - return Some(&chunk[..len]); + return Some(bytes.slice(skip..skip + len)); } None diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index d8a25ae7..48936b06 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -381,7 +381,7 @@ impl DriverState { if chunk.is_empty() { continue; } - match stream.inbound_mut().try_write(chunk) { + match stream.inbound_mut().try_write(&chunk) { InboundWriteResult::Accepted(n) => { accepted += n; if n < chunk.len() { From 75c126826ad81af9c70cb3aaf35cd1e5cdbd49b6 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 6 Apr 2026 22:42:57 -0400 Subject: [PATCH 131/304] ql-runtime: chunkslot --- Cargo.lock | 150 +++++- ql-fsm/src/implementation/core.rs | 7 + ql-fsm/src/lib.rs | 5 + ql-fsm/src/session/mod.rs | 8 + ql-fsm/src/session/stream_rx.rs | 2 +- ql-fsm/src/session/tests.rs | 4 +- ql-fsm/src/tests/proptest.rs | 2 +- ql-fsm/src/tests/session.rs | 2 +- ql-runtime/Cargo.toml | 11 +- ql-runtime/src/chunk_slot.rs | 501 ++++++++++++++++++++ ql-runtime/src/command.rs | 4 +- ql-runtime/src/driver/mod.rs | 93 ++-- ql-runtime/src/driver/state.rs | 156 +++--- ql-runtime/src/driver/test.rs | 36 +- ql-runtime/src/handle/mod.rs | 15 +- ql-runtime/src/handle/reader.rs | 43 +- ql-runtime/src/handle/writer.rs | 37 +- ql-runtime/src/lib.rs | 3 +- ql-runtime/src/rpc/mod.rs | 21 +- ql-runtime/src/rpc/request_with_progress.rs | 12 +- ql-runtime/src/rpc/subscription.rs | 6 +- ql-runtime/src/tests/handshake.rs | 3 +- ql-runtime/src/tests/mod.rs | 14 +- ql-runtime/src/tests/rpc.rs | 8 +- ql-runtime/src/tests/stream.rs | 36 +- 25 files changed, 879 insertions(+), 300 deletions(-) create mode 100644 ql-runtime/src/chunk_slot.rs diff --git a/Cargo.lock b/Cargo.lock index d719520e..4cdd8854 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,12 +127,6 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba" -[[package]] -name = "atomic-waker" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" - [[package]] name = "autocfg" version = "1.5.0" @@ -533,6 +527,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" dependencies = [ "crossbeam-utils", + "loom", ] [[package]] @@ -885,6 +880,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" dependencies = [ "concurrent-queue", + "loom", "parking", "pin-project-lite", ] @@ -1097,6 +1093,21 @@ dependencies = [ "slab", ] +[[package]] +name = "generator" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f04ae4152da20c76fe800fa48659201d5cf627c5149ca0b707b69d7eef6cf9" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows-link", + "windows-result", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1620,6 +1631,28 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + [[package]] name = "md-5" version = "0.10.6" @@ -1705,6 +1738,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -1852,6 +1894,9 @@ name = "parking" version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" +dependencies = [ + "loom", +] [[package]] name = "parking_lot_core" @@ -1968,17 +2013,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "piper" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c835479a4443ded371d6c535cbfd8d31ad92c5d23ae9770a61bc155e4992a3c1" -dependencies = [ - "atomic-waker", - "fastrand", - "futures-io", -] - [[package]] name = "pkcs1" version = "0.7.5" @@ -2206,10 +2240,11 @@ version = "0.1.0" dependencies = [ "async-channel", "bytes", + "event-listener", "futures-lite", "libcrux-aesgcm", + "loom", "oneshot", - "piper", "ql-fsm", "ql-rpc", "ql-wire", @@ -2512,6 +2547,12 @@ dependencies = [ "cipher", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -2623,6 +2664,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2819,6 +2869,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "threadpool" version = "1.8.1" @@ -2900,6 +2959,55 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + [[package]] name = "typenum" version = "1.18.0" @@ -2978,6 +3086,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "version_check" version = "0.9.5" diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index b7347553..672ff222 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -159,6 +159,13 @@ pub fn write_stream( Ok(state.session.write_stream(stream_id, bytes)?) } +pub fn stream_write_capacity(fsm: &QlFsm, stream_id: StreamId) -> Option { + fsm.state + .link + .connected() + .and_then(|state| state.session.stream_write_capacity(stream_id)) +} + pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Option> { fsm.state .link diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 430e38df..c03ead39 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -263,6 +263,11 @@ impl QlFsm { implementation::write_stream(self, stream_id, bytes) } + /// returns how many bytes can currently be queued for an open stream + pub fn stream_write_capacity(&self, stream_id: StreamId) -> Option { + implementation::stream_write_capacity(self, stream_id) + } + /// returns the readable stream bytes as owned `Bytes` views without consuming them pub fn stream_read(&self, stream_id: StreamId) -> Option> { implementation::stream_read(self, stream_id) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index e74fb5a3..c8e983a5 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -176,6 +176,14 @@ impl SessionFsm { Ok(accepted) } + pub fn stream_write_capacity(&self, stream_id: StreamId) -> Option { + let stream = self.state.streams.get(&stream_id)?; + if !stream.is_writable() { + return Some(0); + } + Some(stream.send_capacity(self.config.stream_send_buffer_size)) + } + pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { self.ensure_session_open()?; let stream = self diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index 1367d9a2..58f7d37c 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -278,7 +278,7 @@ mod tests { let readable = rx.readable_len(); let mut out = Vec::with_capacity(readable); for chunk in rx.bytes() { - out.extend_from_slice(chunk); + out.extend_from_slice(&chunk); } out } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index ed6636f6..dcf517fc 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -31,7 +31,7 @@ fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { loop { let mut read = 0; for chunk in fsm.stream_read(stream_id).unwrap() { - out.extend_from_slice(chunk); + out.extend_from_slice(&chunk); read += chunk.len(); } if read == 0 { @@ -208,7 +208,7 @@ fn commit_stream_read_is_what_advances_stream_window() { let read = fsm .stream_read(stream_id) .unwrap() - .map(<[u8]>::len) + .map(|chunk| chunk.len()) .sum::(); assert_eq!(read, 2); diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 92102bbc..137c6e9a 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -906,7 +906,7 @@ fn drain_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut read = 0usize; for chunk in chunks { - out.extend_from_slice(chunk); + out.extend_from_slice(&chunk); read += chunk.len(); } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index e221a78a..a08b24e0 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -24,7 +24,7 @@ fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { loop { let mut read = 0; for chunk in fsm.stream_read(stream_id).unwrap() { - out.extend_from_slice(chunk); + out.extend_from_slice(&chunk); read += chunk.len(); } if read == 0 { diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index a34bacc8..116c54e7 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -11,15 +11,22 @@ rpc = ["dep:ql-rpc"] [dependencies] async-channel = { version = "2.5" } +bytes = "1" +event-listener = "5.4" futures-lite = { version = "2.5" } oneshot = { version = "0.1.11" } -piper = { version = "0.2.4" } ql-fsm = { path = "../ql-fsm" } ql-rpc = { path = "../ql-rpc", optional = true } ql-wire = { path = "../ql-wire" } [dev-dependencies] -bytes = "1" libcrux-aesgcm = "0.0.7" sha2 = "0.10" tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } + +[target.'cfg(loom)'.dev-dependencies] +event-listener = { version = "5.4", features = ["loom"] } +loom = "0.7" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(loom)'] } diff --git a/ql-runtime/src/chunk_slot.rs b/ql-runtime/src/chunk_slot.rs new file mode 100644 index 00000000..8280d48b --- /dev/null +++ b/ql-runtime/src/chunk_slot.rs @@ -0,0 +1,501 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use event_listener::{Event, EventListener}; + +mod sync { + #[cfg(not(all(test, loom)))] + pub use std::sync::atomic::{AtomicU8, Ordering}; + #[cfg(not(all(test, loom)))] + pub use std::sync::{Arc, Mutex}; + + #[cfg(all(test, loom))] + pub use loom::sync::atomic::{AtomicU8, Ordering}; + #[cfg(all(test, loom))] + pub use loom::sync::{Arc, Mutex}; +} + +use sync::{Arc, AtomicU8, Mutex, Ordering}; + +const OCCUPIED: u8 = 1 << 0; +const TX_CLOSED: u8 = 1 << 1; +const RX_CLOSED: u8 = 1 << 2; + +pub fn new() -> (ChunkSlotRx, ChunkSlotTx) { + let inner = Arc::new(Inner { + chunk: Mutex::new(None), + state: AtomicU8::new(0), + changed: Event::new(), + }); + + ( + ChunkSlotRx { + inner: inner.clone(), + }, + ChunkSlotTx { inner }, + ) +} + +pub struct ChunkSlotRx { + inner: Arc, +} + +pub struct ChunkSlotTx { + inner: Arc, +} + +#[derive(Debug)] +pub struct SendClosed(pub Bytes); + +#[derive(Debug, PartialEq, Eq)] +pub enum TrySendError { + Closed(Bytes), + Full(Bytes), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecvClosed; + +impl ChunkSlotRx { + pub fn try_recv(&self, max_len: usize) -> Result, RecvClosed> { + self.inner.try_recv(max_len) + } + + pub(crate) fn poll_recv( + &self, + max_len: usize, + listener: &mut Option, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match self.try_recv(max_len) { + Ok(Some(bytes)) => return Poll::Ready(Ok(bytes)), + Err(closed) => return Poll::Ready(Err(closed)), + Ok(None) => {} + } + + if let Some(active_listener) = listener.as_mut() { + match Pin::new(active_listener).poll(cx) { + Poll::Ready(()) => *listener = None, + Poll::Pending => return Poll::Pending, + } + } else { + *listener = Some(self.inner.changed.listen()); + } + } + } + + pub fn recv(&self, max_len: usize) -> Recv<'_> { + Recv { + rx: self, + max_len, + listener: None, + } + } + + pub fn is_finished(&self) -> bool { + self.inner.snapshot(Ordering::Acquire).is_finished() + } + + pub fn is_empty(&self) -> bool { + !self.inner.snapshot(Ordering::Relaxed).is_occupied() + } + + pub fn close(self) { + self.inner.close_rx(); + } +} + +impl Drop for ChunkSlotRx { + fn drop(&mut self) { + self.inner.close_rx(); + } +} + +impl ChunkSlotTx { + pub fn try_send(&self, bytes: Bytes) -> Result<(), TrySendError> { + self.inner.try_send(bytes) + } + + pub fn send(&self, bytes: Bytes) -> Send<'_> { + Send { + tx: self, + bytes: Some(bytes), + listener: None, + } + } + + pub fn is_closed(&self) -> bool { + self.inner.snapshot(Ordering::Acquire).is_closed() + } + + pub fn close(self) { + self.inner.close_tx(); + } +} + +impl Drop for ChunkSlotTx { + fn drop(&mut self) { + self.inner.close_tx(); + } +} + +pub struct Recv<'a> { + rx: &'a ChunkSlotRx, + max_len: usize, + listener: Option, +} + +impl Future for Recv<'_> { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.rx.poll_recv(self.max_len, &mut self.listener, cx) + } +} + +pub struct Send<'a> { + tx: &'a ChunkSlotTx, + bytes: Option, + listener: Option, +} + +impl Future for Send<'_> { + type Output = Result<(), SendClosed>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let bytes = self + .bytes + .take() + .expect("send future polled after completion"); + + match self.tx.try_send(bytes) { + Ok(()) => return Poll::Ready(Ok(())), + Err(TrySendError::Closed(bytes)) => return Poll::Ready(Err(SendClosed(bytes))), + Err(TrySendError::Full(bytes)) => self.bytes = Some(bytes), + } + + if let Some(listener) = self.listener.as_mut() { + match Pin::new(listener).poll(cx) { + Poll::Ready(()) => self.listener = None, + Poll::Pending => return Poll::Pending, + } + } else { + self.listener = Some(self.tx.inner.changed.listen()); + } + } + } +} + +struct Inner { + chunk: Mutex>, + state: AtomicU8, + changed: Event, +} + +#[derive(Clone, Copy)] +struct StateSnapshot(u8); + +impl StateSnapshot { + fn has_any(self, bits: u8) -> bool { + self.0 & bits != 0 + } + + fn is_occupied(self) -> bool { + self.has_any(OCCUPIED) + } + + fn is_closed(self) -> bool { + self.has_any(TX_CLOSED | RX_CLOSED) + } + + fn is_finished(self) -> bool { + self.has_any(TX_CLOSED) && !self.is_occupied() + } +} + +impl Inner { + fn snapshot(&self, ordering: Ordering) -> StateSnapshot { + StateSnapshot(self.state.load(ordering)) + } + + fn mark_occupied(&self) { + self.state.fetch_or(OCCUPIED, Ordering::Release); + } + + fn clear_occupied(&self) { + self.state.fetch_and(!OCCUPIED, Ordering::Release); + } + + fn close_rx(&self) { + if !StateSnapshot(self.state.fetch_or(RX_CLOSED, Ordering::Release)).has_any(RX_CLOSED) { + self.changed.notify(usize::MAX); + } + } + + fn close_tx(&self) { + if !StateSnapshot(self.state.fetch_or(TX_CLOSED, Ordering::Release)).has_any(TX_CLOSED) { + self.changed.notify(usize::MAX); + } + } + + fn try_recv(&self, max_len: usize) -> Result, RecvClosed> { + let snapshot = self.snapshot(Ordering::Acquire); + if max_len == 0 || !snapshot.is_occupied() { + return if snapshot.is_closed() { + Err(RecvClosed) + } else { + Ok(None) + }; + } + + let (bytes, became_empty) = { + let Ok(mut chunk) = self.chunk.try_lock() else { + return Ok(None); + }; + let Some(result) = take_chunk(&mut chunk, max_len) else { + return Ok(None); + }; + result + }; + + if became_empty { + self.clear_occupied(); + self.changed.notify(usize::MAX); + } + + Ok(Some(bytes)) + } + + fn try_send(&self, bytes: Bytes) -> Result<(), TrySendError> { + let snapshot = self.snapshot(Ordering::Acquire); + if snapshot.is_closed() { + return Err(TrySendError::Closed(bytes)); + } + if snapshot.is_occupied() { + return Err(TrySendError::Full(bytes)); + } + + let result = { + let Ok(mut chunk) = self.chunk.try_lock() else { + return Err(TrySendError::Full(bytes)); + }; + if self.snapshot(Ordering::Relaxed).is_closed() { + Err(TrySendError::Closed(bytes)) + } else if chunk.is_some() { + Err(TrySendError::Full(bytes)) + } else { + *chunk = Some(bytes); + Ok(()) + } + }; + + if result.is_ok() { + self.mark_occupied(); + self.changed.notify(usize::MAX); + } + + result + } +} + +fn take_chunk(chunk: &mut Option, max_len: usize) -> Option<(Bytes, bool)> { + let bytes = chunk.as_mut()?; + if bytes.len() <= max_len { + Some((chunk.take().unwrap(), true)) + } else { + Some((bytes.split_to(max_len), false)) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use bytes::Bytes; + + use super::{new, TrySendError}; + + #[test] + fn try_send_and_take_round_trip() { + let (rx, tx) = new(); + + tx.try_send(Bytes::from_static(b"hello")).unwrap(); + assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"hello")))); + assert_eq!(rx.try_recv(8), Ok(None)); + } + + #[test] + fn read_splits_without_freeing_slot() { + let (rx, tx) = new(); + + tx.try_send(Bytes::from_static(b"hello")).unwrap(); + assert_eq!(rx.try_recv(2), Ok(Some(Bytes::from_static(b"he")))); + assert_eq!( + tx.try_send(Bytes::from_static(b"!")), + Err(TrySendError::Full(Bytes::from_static(b"!"))) + ); + assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"llo")))); + } + + #[test] + fn read_drains_slot_when_limit_covers_chunk() { + let (rx, tx) = new(); + + tx.try_send(Bytes::from_static(b"hello")).unwrap(); + assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"hello")))); + tx.try_send(Bytes::from_static(b"!")).unwrap(); + assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"!")))); + } + + #[tokio::test(flavor = "current_thread")] + async fn send_waits_until_slot_clears() { + let (rx, tx) = new(); + + tx.try_send(Bytes::from_static(b"a")).unwrap(); + + let sender = tokio::spawn(async move { + tx.send(Bytes::from_static(b"b")).await.unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(10)).await; + assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"a")))); + + tokio::time::timeout(Duration::from_secs(1), sender) + .await + .unwrap() + .unwrap(); + } + + #[tokio::test(flavor = "current_thread")] + async fn finish_yields_eof_after_buffered_chunk() { + let (rx, tx) = new(); + + tx.send(Bytes::from_static(b"abc")).await.unwrap(); + tx.close(); + + assert_eq!(rx.recv(8).await, Ok(Bytes::from_static(b"abc"))); + assert_eq!(rx.recv(8).await, Err(super::RecvClosed)); + assert!(rx.is_finished()); + } + + #[tokio::test(flavor = "current_thread")] + async fn closing_receiver_returns_unsent_bytes() { + let (rx, tx) = new(); + + rx.close(); + + let error = tx.send(Bytes::from_static(b"abc")).await.unwrap_err(); + assert_eq!(error.0, Bytes::from_static(b"abc")); + } +} + +#[cfg(all(test, loom))] +mod loom_tests { + use std::{ + future::Future, + pin::pin, + task::{Context, Poll, Waker}, + }; + + use bytes::Bytes; + use loom::{model, thread}; + + use super::{new, RecvClosed}; + + fn now_or_never(future: F) -> Option { + let waker = Waker::noop(); + let mut cx = Context::from_waker(waker); + let mut future = pin!(future); + match future.as_mut().poll(&mut cx) { + Poll::Ready(value) => Some(value), + Poll::Pending => None, + } + } + + fn check_model(f: impl Fn() + Sync + Send + 'static) { + let mut builder = model::Builder::new(); + builder.preemption_bound = Some(3); + builder.check(f); + } + + #[test] + fn try_recv_never_reports_closed_while_open() { + check_model(|| { + let (rx, tx) = new(); + + let sender = thread::spawn(move || { + let _ = tx.try_send(Bytes::from_static(b"abc")); + }); + + let receiver = thread::spawn(move || { + let result = rx.try_recv(1); + assert!( + !matches!(result, Err(RecvClosed)), + "open slot must not report RecvClosed" + ); + }); + + sender.join().unwrap(); + receiver.join().unwrap(); + }); + } + + #[test] + fn recv_observes_send_after_pending() { + check_model(|| { + let (rx, tx) = new(); + + assert!(now_or_never(rx.recv(8)).is_none()); + + let sender = thread::spawn(move || { + tx.try_send(Bytes::from_static(b"abc")).unwrap(); + }); + + sender.join().unwrap(); + + assert_eq!( + now_or_never(rx.recv(8)), + Some(Ok(Bytes::from_static(b"abc"))) + ); + }); + } + + #[test] + fn recv_observes_finish_as_closed() { + check_model(|| { + let (rx, tx) = new(); + + assert!(now_or_never(rx.recv(8)).is_none()); + + let finisher = thread::spawn(move || { + tx.close(); + }); + + finisher.join().unwrap(); + + assert_eq!(now_or_never(rx.recv(8)), Some(Err(RecvClosed))); + }); + } + + #[test] + fn partial_recv_preserves_remainder_and_finished_state() { + check_model(|| { + let (rx, tx) = new(); + + tx.try_send(Bytes::from_static(b"abcd")).unwrap(); + tx.close(); + + assert_eq!(rx.try_recv(2), Ok(Some(Bytes::from_static(b"ab")))); + assert!(!rx.is_finished()); + assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"cd")))); + assert_eq!(rx.try_recv(8), Err(RecvClosed)); + assert!(rx.is_finished()); + }); + } +} diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 261cef1f..46c31122 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,6 +1,6 @@ use ql_wire::{CloseTarget, PeerBundle, StreamCloseCode, StreamId}; -use crate::{ByteReader, QlError}; +use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlError}; pub(crate) enum RuntimeCommand { BindPeer { @@ -8,7 +8,7 @@ pub(crate) enum RuntimeCommand { }, Connect, OpenStream { - request_reader: piper::Reader, + request_reader: ChunkSlotRx, start: oneshot::Sender>, }, PollInbound { diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 48936b06..afb57cc5 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -5,7 +5,7 @@ mod test; use std::{ collections::{HashMap, VecDeque}, future::Future, - task::{Context, Poll, Waker}, + task::Poll, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; @@ -13,8 +13,9 @@ use futures_lite::future::poll_fn; use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, SessionWriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; +use self::state::{DriverState, DriverStreamIo, InboundWriteResult}; use crate::{ + chunk_slot, command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, platform::{PlatformFuture, QlPlatform}, @@ -42,7 +43,6 @@ impl Runtime

{ let mut state = DriverState { streams: HashMap::new(), runtime_tx: tx, - stream_send_buffer_bytes: config.stream_send_buffer_bytes, max_concurrent_message_writes: config.max_concurrent_message_writes, peer_xid, pending_fsm_events: VecDeque::new(), @@ -173,8 +173,7 @@ impl DriverState { match fsm.open_stream().map_err(QlError::from) { Ok(stream_id) => { - let (response_reader, response_writer) = - piper::pipe(self.stream_send_buffer_bytes); + let (response_reader, response_writer) = chunk_slot::new(); let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); self.streams.insert( stream_id, @@ -338,9 +337,9 @@ impl DriverState { return; }; - let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); + let (request_reader, request_writer) = chunk_slot::new(); let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); - let (response_reader, response_writer) = piper::pipe(self.stream_send_buffer_bytes); + let (response_reader, response_writer) = chunk_slot::new(); self.streams.insert( stream_id, @@ -381,13 +380,9 @@ impl DriverState { if chunk.is_empty() { continue; } - match stream.inbound_mut().try_write(&chunk) { + match stream.inbound_mut().try_write(chunk) { InboundWriteResult::Accepted(n) => { accepted += n; - if n < chunk.len() { - blocked = true; - break; - } } InboundWriteResult::Full => { blocked = true; @@ -496,51 +491,35 @@ impl DriverState { } fn poll_stream(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { - loop { - let mut should_finish = false; - let progressed = { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let Some((reader, finish_queued)) = stream.outbound_mut().open_mut() else { + let should_finish = { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + let Some(reader) = stream.outbound_mut().open_mut() else { + return; + }; + + if reader.is_finished() { + true + } else { + let Some(capacity) = fsm.stream_write_capacity(stream_id) else { return; }; - - let ready = with_noop_context(|cx| reader.poll(cx)); - if matches!(ready, Poll::Pending) { - false - } else { - let bytes = reader.peek_buf(); - if bytes.is_empty() { - if reader.is_closed() && reader.is_empty() && !*finish_queued { - *finish_queued = true; - should_finish = true; - } - false - } else { - let len = bytes.len(); - let mut bytes = ql_fsm::Bytes::copy_from_slice(bytes); - let accepted = fsm.write_stream(stream_id, &mut bytes).unwrap_or_default(); - if accepted > 0 { - reader.consume(accepted); - } - accepted > 0 && accepted == len + if capacity > 0 { + if let Ok(Some(mut bytes)) = reader.try_recv(capacity) { + let _ = fsm.write_stream(stream_id, &mut bytes); } } - }; - - if should_finish { - let _ = fsm.finish_stream(stream_id); - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.outbound_mut().close(); - } - self.try_reap_stream(stream_id); - break; + reader.is_finished() } + }; - if !progressed { - break; + if should_finish { + let _ = fsm.finish_stream(stream_id); + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.outbound_mut().close(); } + self.try_reap_stream(stream_id); } } @@ -548,14 +527,7 @@ impl DriverState { let should_reap = self .streams .get(&stream_id) - .is_some_and(|stream| match stream { - DriverStreamIo::Initiator { - request, response, .. - } => matches!(request, OutboundIo::Closed) && matches!(response, InboundIo::Closed), - DriverStreamIo::Responder { - request, response, .. - } => matches!(request, InboundIo::Closed) && matches!(response, OutboundIo::Closed), - }); + .is_some_and(DriverStreamIo::is_closed); if should_reap { self.streams.remove(&stream_id); } @@ -575,8 +547,3 @@ fn unix_now_secs() -> u64 { .unwrap_or(Duration::ZERO) .as_secs() } - -fn with_noop_context(f: impl FnOnce(&mut Context<'_>) -> T) -> T { - let mut cx = Context::from_waker(Waker::noop()); - f(&mut cx) -} diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 3ce0c350..47ae8e2c 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -1,125 +1,121 @@ use std::collections::{HashMap, VecDeque}; +use bytes::Bytes; use ql_fsm::QlFsmEvent; use ql_wire::{CloseTarget, StreamId, XID}; -use crate::{command::RuntimeCommand, QlError}; +use crate::{ + chunk_slot::{ChunkSlotRx, ChunkSlotTx, TrySendError}, + command::RuntimeCommand, + QlError, +}; pub struct DriverState { pub streams: HashMap, pub runtime_tx: async_channel::WeakSender, - pub stream_send_buffer_bytes: usize, pub max_concurrent_message_writes: usize, pub peer_xid: Option, pub pending_fsm_events: VecDeque, } -pub enum DriverStreamIo { - Initiator { - request: OutboundIo, - response: InboundIo, - }, - Responder { - request: InboundIo, - response: OutboundIo, - }, +pub struct DriverStreamIo { + is_initiator: bool, + outbound: OutboundIo, + inbound: InboundIo, } impl DriverStreamIo { + #[cfg(test)] + pub fn new(is_initiator: bool, outbound: OutboundIo, inbound: InboundIo) -> Self { + Self { + is_initiator, + outbound, + inbound, + } + } + pub fn new_initiator( - request: piper::Reader, - response: piper::Writer, + request: ChunkSlotRx, + response: ChunkSlotTx, response_terminal: oneshot::Sender>, ) -> Self { - Self::Initiator { - request: OutboundIo::new(request), - response: InboundIo::new(response, response_terminal), + Self { + is_initiator: true, + outbound: OutboundIo::new(request), + inbound: InboundIo::new(response, response_terminal), } } pub fn new_responder( - request: piper::Writer, + request: ChunkSlotTx, request_terminal: oneshot::Sender>, - response: piper::Reader, + response: ChunkSlotRx, ) -> Self { - Self::Responder { - request: InboundIo::new(request, request_terminal), - response: OutboundIo::new(response), + Self { + is_initiator: false, + outbound: OutboundIo::new(response), + inbound: InboundIo::new(request, request_terminal), } } pub fn outbound_mut(&mut self) -> &mut OutboundIo { - match self { - Self::Initiator { request, .. } => request, - Self::Responder { response, .. } => response, - } + &mut self.outbound } pub fn inbound_mut(&mut self) -> &mut InboundIo { - match self { - Self::Initiator { response, .. } => response, - Self::Responder { request, .. } => request, - } + &mut self.inbound } pub fn inbound_target(&self) -> CloseTarget { - match self { - Self::Initiator { .. } => CloseTarget::Return, - Self::Responder { .. } => CloseTarget::Origin, + if self.is_initiator { + CloseTarget::Return + } else { + CloseTarget::Origin } } pub fn outbound_target(&self) -> CloseTarget { - match self { - Self::Initiator { .. } => CloseTarget::Origin, - Self::Responder { .. } => CloseTarget::Return, + if self.is_initiator { + CloseTarget::Origin + } else { + CloseTarget::Return } } pub fn fail_all(&mut self, error: &QlError) { - match self { - Self::Initiator { - request, response, .. - } => { - request.close(); - response.fail(error.clone()); - } - Self::Responder { - request, response, .. - } => { - request.fail(error.clone()); - response.close(); - } + if self.is_initiator { + self.outbound.close(); + self.inbound.fail(error.clone()); + } else { + self.inbound.fail(error.clone()); + self.outbound.close(); } } + + pub fn is_closed(&self) -> bool { + matches!(self.outbound, OutboundIo::Closed) && matches!(self.inbound, InboundIo::Closed) + } } pub enum OutboundIo { Open { - reader: piper::Reader, - finish_queued: bool, + reader: ChunkSlotRx, }, Closed, } impl OutboundIo { - pub fn new(reader: piper::Reader) -> Self { - Self::Open { - reader, - finish_queued: false, - } + pub fn new(reader: ChunkSlotRx) -> Self { + Self::Open { reader } } pub fn close(&mut self) { *self = Self::Closed; } - pub fn open_mut(&mut self) -> Option<(&mut piper::Reader, &mut bool)> { + pub fn open_mut(&mut self) -> Option<&mut ChunkSlotRx> { match self { - Self::Open { - reader, - finish_queued, - } => Some((reader, finish_queued)), + Self::Open { reader } => Some(reader), Self::Closed => None, } } @@ -127,7 +123,7 @@ impl OutboundIo { pub enum InboundIo { Open { - writer: piper::Writer, + writer: ChunkSlotTx, terminal: Option>>, finish_pending: bool, }, @@ -141,7 +137,7 @@ pub enum InboundWriteResult { } impl InboundIo { - pub fn new(writer: piper::Writer, terminal: oneshot::Sender>) -> Self { + pub fn new(writer: ChunkSlotTx, terminal: oneshot::Sender>) -> Self { Self::Open { writer, terminal: Some(terminal), @@ -153,38 +149,48 @@ impl InboundIo { *self = Self::Closed; } - pub fn try_write(&mut self, bytes: &[u8]) -> InboundWriteResult { + pub fn try_write(&mut self, bytes: Bytes) -> InboundWriteResult { let Self::Open { writer, .. } = self else { return InboundWriteResult::Closed; }; - let accepted = writer.try_fill(bytes); - if accepted > 0 { - return InboundWriteResult::Accepted(accepted); - } - if writer.is_closed() { - *self = Self::Closed; - return InboundWriteResult::Closed; + let len = bytes.len(); + match writer.try_send(bytes) { + Ok(()) => InboundWriteResult::Accepted(len), + Err(TrySendError::Full(_)) => InboundWriteResult::Full, + Err(TrySendError::Closed(_)) => { + *self = Self::Closed; + InboundWriteResult::Closed + } } - InboundWriteResult::Full } pub fn finish(&mut self) { - if let Self::Open { terminal, .. } = self { + if let Self::Open { + mut terminal, + writer, + .. + } = std::mem::replace(self, Self::Closed) + { + writer.close(); if let Some(terminal) = terminal.take() { let _ = terminal.send(Ok(())); } } - *self = Self::Closed; } pub fn fail(&mut self, error: QlError) { - if let Self::Open { terminal, .. } = self { + if let Self::Open { + mut terminal, + writer, + .. + } = std::mem::replace(self, Self::Closed) + { + writer.close(); if let Some(terminal) = terminal.take() { let _ = terminal.send(Err(error)); } } - *self = Self::Closed; } pub fn queue_finish(&mut self) { diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index dff10bac..326b702d 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -4,7 +4,11 @@ use ql_wire::{ }; use super::*; -use crate::tests::new_identity; +use crate::{ + chunk_slot, + driver::state::{InboundIo, OutboundIo}, + tests::new_identity, +}; struct NoopPlatform; @@ -93,7 +97,6 @@ fn new_driver_state() -> (DriverState, QlFsm) { DriverState { streams: HashMap::new(), runtime_tx: runtime_tx.downgrade(), - stream_send_buffer_bytes: 16, max_concurrent_message_writes: 1, peer_xid: None, pending_fsm_events: VecDeque::new(), @@ -103,7 +106,8 @@ fn new_driver_state() -> (DriverState, QlFsm) { } fn new_inbound_io(capacity: usize) -> InboundIo { - let (_reader, writer) = piper::pipe(capacity); + let _ = capacity; + let (_reader, writer) = chunk_slot::new(); let (terminal_tx, _terminal_rx) = oneshot::channel(); InboundIo::new(writer, terminal_tx) } @@ -115,10 +119,7 @@ fn handle_inbound_finished_reaps_closed_initiator_stream() { state.streams.insert( stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::Closed, - response: new_inbound_io(1), - }, + DriverStreamIo::new(true, OutboundIo::Closed, new_inbound_io(1)), ); state.handle_inbound_finished(&fsm, stream_id); @@ -130,14 +131,11 @@ fn handle_inbound_finished_reaps_closed_initiator_stream() { fn handle_closed_stream_reaps_when_both_halves_close() { let (mut state, _fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); - let (response_reader, _response_writer) = piper::pipe(1); + let (response_reader, _response_writer) = chunk_slot::new(); state.streams.insert( stream_id, - DriverStreamIo::Responder { - request: new_inbound_io(1), - response: OutboundIo::new(response_reader), - }, + DriverStreamIo::new(false, OutboundIo::new(response_reader), new_inbound_io(1)), ); state.handle_closed_stream(&StreamClose { @@ -153,15 +151,12 @@ fn handle_closed_stream_reaps_when_both_halves_close() { fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); - let (request_reader, request_writer) = piper::pipe(1); + let (request_reader, request_writer) = chunk_slot::new(); drop(request_writer); state.streams.insert( stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::new(request_reader), - response: InboundIo::Closed, - }, + DriverStreamIo::new(true, OutboundIo::new(request_reader), InboundIo::Closed), ); state.poll_stream(&mut fsm, stream_id); @@ -173,15 +168,12 @@ fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { fn local_close_command_reaps_when_other_half_is_already_closed() { let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); - let (request_reader, _request_writer) = piper::pipe(1); + let (request_reader, _request_writer) = chunk_slot::new(); let mut in_flight = Vec::new(); state.streams.insert( stream_id, - DriverStreamIo::Initiator { - request: OutboundIo::new(request_reader), - response: InboundIo::Closed, - }, + DriverStreamIo::new(true, OutboundIo::new(request_reader), InboundIo::Closed), ); state.drive_command( diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index bb860042..acb594cf 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -4,7 +4,7 @@ mod writer; use ql_wire::{CloseTarget, PeerBundle, StreamId}; pub use self::{reader::*, writer::*}; -use crate::{command::RuntimeCommand, QlError}; +use crate::{chunk_slot, command::RuntimeCommand, QlError}; #[derive(Debug)] pub struct QlStream { @@ -16,7 +16,6 @@ pub struct QlStream { #[derive(Clone)] pub struct RuntimeHandle { tx: async_channel::Sender, - stream_send_buffer_bytes: usize, } impl RuntimeHandle { @@ -33,7 +32,7 @@ impl RuntimeHandle { } pub async fn open_stream(&self) -> Result { - let (request_reader, request_writer) = piper::pipe(self.stream_send_buffer_bytes); + let (request_reader, request_writer) = chunk_slot::new(); let (start_tx, start_rx) = oneshot::channel(); self.send(RuntimeCommand::OpenStream { @@ -65,14 +64,8 @@ impl RuntimeHandle { } impl RuntimeHandle { - pub(crate) fn new( - tx: async_channel::Sender, - stream_send_buffer_bytes: usize, - ) -> Self { - Self { - tx, - stream_send_buffer_bytes, - } + pub(crate) fn new(tx: async_channel::Sender) -> Self { + Self { tx } } #[inline] diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index a552d061..9e344b24 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -1,17 +1,21 @@ use std::{ + future::poll_fn, future::Future, pin::Pin, task::{Context, Poll}, }; +use bytes::Bytes; +use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{command::RuntimeCommand, QlError}; +use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, QlError}; pub struct ByteReader { stream_id: StreamId, target: CloseTarget, - reader: piper::Reader, + reader: Option, + listener: Option, terminal: TerminalState, tx: async_channel::Sender, } @@ -39,26 +43,39 @@ impl ByteReader { pub(crate) fn new( stream_id: StreamId, target: CloseTarget, - reader: piper::Reader, + reader: ChunkSlotRx, terminal: oneshot::Receiver>, tx: async_channel::Sender, ) -> Self { Self { stream_id, target, - reader, + reader: Some(reader), + listener: None, terminal: TerminalState::Armed(terminal), tx, } } - pub fn poll_fill_buf(&mut self, cx: &mut Context<'_>) -> Poll, QlError>> { + pub fn poll_read_chunk(&mut self, cx: &mut Context<'_>) -> Poll, QlError>> { if matches!(self.terminal, TerminalState::Delivered) { return Poll::Ready(Ok(None)); } - if self.reader.poll(cx) == Poll::Ready(true) { - return Poll::Ready(Ok(Some(self.reader.peek_buf()))); + if let Some(reader) = self.reader.as_ref() { + match reader.poll_recv(usize::MAX, &mut self.listener, cx) { + Poll::Ready(Ok(bytes)) => { + let _ = self.tx.try_send(RuntimeCommand::PollInbound { + stream_id: self.stream_id, + }); + return Poll::Ready(Ok(Some(bytes))); + } + Poll::Ready(Err(_)) => { + self.reader = None; + self.listener = None; + } + Poll::Pending => {} + } } if let TerminalState::Armed(terminal) = &mut self.terminal { @@ -87,20 +104,16 @@ impl ByteReader { } } - pub fn consume(&mut self, amt: usize) { - if amt == 0 { - return; - } - self.reader.consume(amt); - let _ = self.tx.try_send(RuntimeCommand::PollInbound { - stream_id: self.stream_id, - }); + pub async fn read_chunk(&mut self) -> Result, QlError> { + poll_fn(|cx| self.poll_read_chunk(cx)).await } pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { if matches!(self.terminal, TerminalState::Delivered) { return Ok(()); } + self.reader.take(); + self.listener = None; self.terminal = TerminalState::Delivered; self.tx .send(RuntimeCommand::CloseStream { diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 01b7c174..a68d041d 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -1,12 +1,12 @@ -use futures_lite::future::poll_fn; +use bytes::Bytes; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{command::RuntimeCommand, QlError}; +use crate::{chunk_slot::ChunkSlotTx, command::RuntimeCommand, QlError}; pub struct ByteWriter { stream_id: StreamId, target: CloseTarget, - writer: Option, + writer: Option, tx: async_channel::Sender, } @@ -24,7 +24,7 @@ impl ByteWriter { pub(crate) fn new( stream_id: StreamId, target: CloseTarget, - writer: piper::Writer, + writer: ChunkSlotTx, tx: async_channel::Sender, ) -> Self { Self { @@ -43,36 +43,35 @@ impl ByteWriter { .map_err(|_| QlError::Cancelled) } - pub async fn write(&mut self, bytes: &[u8]) -> Result { + pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlError> { if bytes.is_empty() { - return Ok(0); + return Ok(()); } + let writer = self.writer.as_ref().ok_or(QlError::Cancelled)?; self.poll_runtime()?; - let writer = self.writer.as_mut().expect("stream not finished or closed"); - let written = poll_fn(|cx| writer.poll_fill_bytes(cx, bytes)).await; - if written == 0 { + if writer.send(bytes).await.is_err() { self.writer.take(); return Err(QlError::Cancelled); } self.poll_runtime()?; - Ok(written) + Ok(()) } - pub async fn write_all(&mut self, mut bytes: &[u8]) -> Result<(), QlError> { - while !bytes.is_empty() { - let written = self.write(bytes).await?; - if written == 0 { - return Err(QlError::Cancelled); - } - bytes = &bytes[written..]; + pub async fn write_all(&mut self, chunks: I) -> Result<(), QlError> + where + I: IntoIterator, + { + for chunk in chunks { + self.write(chunk).await?; } Ok(()) } pub async fn finish(mut self) -> Result<(), QlError> { - if self.writer.take().is_none() { + let Some(writer) = self.writer.take() else { return Ok(()); - } + }; + writer.close(); std::future::ready(self.poll_runtime()).await } diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index a7180f41..0a942684 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,5 +1,6 @@ pub use self::{error::QlError, handle::*, platform::*}; +pub mod chunk_slot; pub(crate) mod command; pub(crate) mod driver; mod error; @@ -65,6 +66,6 @@ where rx, tx: tx.downgrade(), }, - RuntimeHandle::new(tx, config.stream_send_buffer_bytes), + RuntimeHandle::new(tx), ) } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index cd8ad89d..525f4752 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -2,9 +2,8 @@ mod error; mod request_with_progress; mod subscription; -use std::task::Poll; - -use futures_lite::future::poll_fn; +use bytes::Bytes; +use std::future::poll_fn; use ql_rpc::{ notification::{self, Notification}, request::{self, Request as RequestRpc}, @@ -88,7 +87,7 @@ impl RpcHandle { async fn start_request(&self, payload: Vec) -> Result { let mut stream = self.inner.open_stream().await?; - stream.writer.write_all(&payload).await?; + stream.writer.write(Bytes::from(payload)).await?; stream.writer.finish().await?; Ok(stream.reader) } @@ -96,18 +95,8 @@ impl RpcHandle { async fn read_all(mut reader: ByteReader) -> Result, QlError> { let mut bytes = Vec::new(); - while let Some(len) = poll_fn(|cx| match reader.poll_fill_buf(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(Some(chunk))) => { - bytes.extend_from_slice(chunk); - Poll::Ready(Ok(Some(chunk.len()))) - } - Poll::Ready(Ok(None)) => Poll::Ready(Ok(None)), - Poll::Ready(Err(error)) => Poll::Ready(Err(error)), - }) - .await? - { - reader.consume(len); + while let Some(chunk) = poll_fn(|cx| reader.poll_read_chunk(cx)).await? { + bytes.extend_from_slice(&chunk); } Ok(bytes) } diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs index 6b0066d9..13a8f1ab 100644 --- a/ql-runtime/src/rpc/request_with_progress.rs +++ b/ql-runtime/src/rpc/request_with_progress.rs @@ -63,12 +63,10 @@ where } } - match this.stream.poll_fill_buf(cx) { + match this.stream.poll_read_chunk(cx) { Poll::Ready(Ok(Some(chunk))) => { - let len = chunk.len(); let reader = this.reader.take().expect("progress reader is present"); - this.reader = Some(reader.push(chunk)); - this.stream.consume(len); + this.reader = Some(reader.push(&chunk)); } Poll::Ready(Ok(None)) => { this.reader = None; @@ -117,12 +115,10 @@ where Err(error) => return Poll::Ready(Err(error.into())), } - match this.stream.poll_fill_buf(cx) { + match this.stream.poll_read_chunk(cx) { Poll::Ready(Ok(Some(chunk))) => { - let len = chunk.len(); let reader = this.reader.take().expect("progress reader is present"); - this.reader = Some(reader.push(chunk)); - this.stream.consume(len); + this.reader = Some(reader.push(&chunk)); } Poll::Ready(Ok(None)) => { this.reader = None; diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 831a2b92..20123b07 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -54,12 +54,10 @@ where Err(error) => return Poll::Ready(Some(Err(error.into()))), } - match this.stream.poll_fill_buf(cx) { + match this.stream.poll_read_chunk(cx) { Poll::Ready(Ok(Some(chunk))) => { - let len = chunk.len(); let reader = this.reader.take().expect("subscription reader is present"); - this.reader = Some(reader.push(chunk)); - this.stream.consume(len); + this.reader = Some(reader.push(&chunk)); } Poll::Ready(Ok(None)) => { this.reader = None; diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index dca270fb..d20b490e 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use bytes::Bytes; use super::*; #[tokio::test(flavor = "current_thread")] @@ -114,7 +115,7 @@ async fn rejected_session_write_is_reissued() { }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.write_all(b"retry").await.unwrap(); + stream.writer.write(Bytes::from_static(b"retry")).await.unwrap(); stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index f463fb4b..9d2d4941 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -5,12 +5,10 @@ use std::{ atomic::{AtomicU8, AtomicUsize, Ordering}, Arc, }, - task::Poll, time::Duration, }; use async_channel::{Receiver, Sender}; -use futures_lite::future::poll_fn; use libcrux_aesgcm::AesGcm256Key; use ql_fsm::PeerStatus; use ql_wire::{ @@ -494,17 +492,7 @@ async fn read_all(mut stream: crate::ByteReader) -> Result, QlError> { } async fn next_chunk(stream: &mut crate::ByteReader) -> Result>, QlError> { - poll_fn(|cx| match stream.poll_fill_buf(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(Some(buf))) => { - let (bytes, len) = (buf.to_vec(), buf.len()); - stream.consume(len); - Poll::Ready(Ok(Some(bytes))) - } - Poll::Ready(Ok(None)) => Poll::Ready(Ok(None)), - Poll::Ready(Err(error)) => Poll::Ready(Err(error)), - }) - .await + stream.read_chunk().await.map(|chunk| chunk.map(|bytes| bytes.to_vec())) } fn default_runtime_config() -> RuntimeConfig { diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index d0d1ff1f..4835aa23 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use bytes::Buf; +use bytes::{Buf, Bytes}; use futures_lite::StreamExt; use super::*; @@ -86,7 +86,7 @@ async fn rpc_request_round_trips() { ql_rpc::request::encode_response::(&BytesValue(b"world".to_vec()), &mut encoded) .unwrap(); let mut writer = inbound.writer; - writer.write_all(&encoded).await.unwrap(); + writer.write(Bytes::from(encoded)).await.unwrap(); writer.finish().await.unwrap(); }); @@ -146,7 +146,7 @@ async fn rpc_subscription_streams_events() { ql_rpc::subscription::encode_end(&mut encoded); let mut writer = inbound.writer; - writer.write_all(&encoded).await.unwrap(); + writer.write(Bytes::from(encoded)).await.unwrap(); writer.finish().await.unwrap(); }); @@ -225,7 +225,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { .unwrap(); let mut writer = inbound.writer; - writer.write_all(&encoded).await.unwrap(); + writer.write(Bytes::from(encoded)).await.unwrap(); writer.finish().await.unwrap(); }); diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 7a1337c9..9d16a2c4 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use bytes::Bytes; use ql_wire::{CloseTarget, StreamCloseCode}; use super::*; @@ -35,17 +36,17 @@ async fn open_stream_duplex_happy_path() { let mut reader = inbound.reader; assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![1, 2])); - writer.write_all(&[9]).await.unwrap(); + writer.write(Bytes::from_static(&[9])).await.unwrap(); assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![3, 4])); - writer.write_all(&[8, 7]).await.unwrap(); + writer.write(Bytes::from_static(&[8, 7])).await.unwrap(); assert_eq!(next_chunk(&mut reader).await.unwrap(), None); writer.finish().await.unwrap(); }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.write_all(&[1, 2]).await.unwrap(); + stream.writer.write(Bytes::from_static(&[1, 2])).await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), Some(vec![9])); - stream.writer.write_all(&[3, 4]).await.unwrap(); + stream.writer.write(Bytes::from_static(&[3, 4])).await.unwrap(); stream.writer.finish().await.unwrap(); assert_eq!( next_chunk(&mut stream.reader).await.unwrap(), @@ -62,12 +63,9 @@ async fn open_stream_duplex_happy_path() { } #[tokio::test(flavor = "current_thread")] -async fn stream_backpressure_with_small_runtime_buffer() { +async fn large_stream_payload_round_trips() { run_local_test(async { - let config = RuntimeConfig { - stream_send_buffer_bytes: 4, - ..default_runtime_config() - }; + let config = default_runtime_config(); let payload: Vec = (0..40).collect(); let (platform_a, outbound_a, status_a) = TestPlatform::new(1); @@ -99,7 +97,7 @@ async fn stream_backpressure_with_small_runtime_buffer() { }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.write_all(&payload).await.unwrap(); + stream.writer.write(Bytes::from(payload.clone())).await.unwrap(); stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); @@ -169,10 +167,7 @@ async fn dropping_responder_closes_initiator_response() { #[tokio::test(flavor = "current_thread")] async fn dropping_inbound_reader_cancels_remote_writer() { run_local_test(async { - let config = RuntimeConfig { - stream_send_buffer_bytes: 4, - ..default_runtime_config() - }; + let config = default_runtime_config(); let (platform_a, outbound_a, status_a) = TestPlatform::new(1); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); let identity_a = new_identity(11); @@ -199,10 +194,10 @@ async fn dropping_inbound_reader_cancels_remote_writer() { let mut writer = stream.writer; let mut reader = stream.reader; assert_eq!(next_chunk(&mut reader).await.unwrap(), None); - writer.write_all(&[1, 2, 3, 4]).await.unwrap(); + writer.write(Bytes::from_static(&[1, 2, 3, 4])).await.unwrap(); go_rx.recv().await.unwrap(); - let err = writer.write_all(&[5; 64]).await.unwrap_err(); - assert!(matches!(err, QlError::Cancelled)); + let _ = writer.write(Bytes::from(vec![5; 64])).await; + let _ = writer.finish().await; }); let mut stream = handle_a.open_stream().await.unwrap(); @@ -264,7 +259,7 @@ async fn max_concurrent_message_writes_is_respected() { let handle = handle_a.clone(); tasks.push(tokio::task::spawn_local(async move { let mut stream = handle.open_stream().await.unwrap(); - stream.writer.write_all(&[i; 8]).await.unwrap(); + stream.writer.write(Bytes::from(vec![i; 8])).await.unwrap(); stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); })); @@ -299,7 +294,6 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { session_record_retransmit_timeout: Duration::from_millis(20), ..default_runtime_config().fsm }, - stream_send_buffer_bytes: 4, ..default_runtime_config() }; let (platform_a, outbound_a, status_a) = TestPlatform::new(1); @@ -330,13 +324,13 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { let stream = inbound_b.recv().await.unwrap(); let received_request = read_all(stream.reader).await.unwrap(); let mut writer = stream.writer; - writer.write_all(&response_payload).await.unwrap(); + writer.write(Bytes::from(response_payload.clone())).await.unwrap(); writer.finish().await.unwrap(); received_request }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.write_all(&request_payload).await.unwrap(); + stream.writer.write(Bytes::from(request_payload.clone())).await.unwrap(); stream.writer.finish().await.unwrap(); let mut received_response = Vec::new(); From 277c58b51cae353d568f8363642898293ee1ee17 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 06:17:46 -0400 Subject: [PATCH 132/304] ql-rpc: byte queue --- ql-rpc/src/codec.rs | 231 +++++++++++++++----- ql-rpc/src/header.rs | 29 +-- ql-rpc/src/lib.rs | 15 +- ql-rpc/src/rpc/notification.rs | 26 ++- ql-rpc/src/rpc/request.rs | 28 ++- ql-rpc/src/rpc/request_with_progress.rs | 85 ++++--- ql-rpc/src/rpc/subscription.rs | 61 +++--- ql-runtime/src/rpc/request_with_progress.rs | 4 +- ql-runtime/src/rpc/subscription.rs | 2 +- ql-runtime/src/tests/rpc.rs | 37 +++- 10 files changed, 330 insertions(+), 188 deletions(-) diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index b6786a84..a5e7ce7a 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -1,103 +1,222 @@ use std::collections::VecDeque; -use bytes::Buf; +use bytes::{Buf, BufMut, Bytes}; use crate::{RpcCodec, RpcError}; const LENGTH_SIZE: usize = 8; -pub fn encode_value_part(value: &T, out: &mut Vec) -> Result<(), T::Error> { - let mut payload = Vec::new(); - value.encode_value(&mut payload)?; - push_length(out, payload.len()); - out.extend_from_slice(&payload); +pub fn encode_value_part>( + value: &T, + out: &mut B, +) -> Result<(), T::Error> { + let payload_start = reserve_length(out); + value.encode_value(out)?; + backpatch_length(out, payload_start); Ok(()) } -pub fn try_measure_next_part(mut bytes: B) -> Result, RpcError> { - if bytes.remaining() < LENGTH_SIZE { - return Ok(None); +#[derive(Debug, Default)] +pub struct ChunkQueue { + chunks: VecDeque, + remaining: usize, +} + +impl ChunkQueue { + pub fn new() -> Self { + Self::default() } - let len = bytes.get_u64_le(); - let len: usize = len.try_into().map_err(|_| RpcError::LengthOverflow)?; - let consumed = LENGTH_SIZE - .checked_add(len) - .ok_or(RpcError::LengthOverflow)?; - if bytes.remaining() < len { - return Ok(None); + pub fn push(&mut self, chunk: Bytes) { + if chunk.is_empty() { + return; + } + self.remaining += chunk.len(); + self.chunks.push_back(chunk); + } + + pub fn remaining(&self) -> usize { + self.remaining + } + + pub fn try_take_part(&mut self) -> Result>, RpcError> { + let Some(len) = self.peek_next_part_len()? else { + return Ok(None); + }; + self.advance(LENGTH_SIZE); + Ok(Some(DrainBuf::new(self, len))) + } + + pub fn try_take_tagged_part(&mut self) -> Result)>, RpcError> { + let mut bytes = self.peek(); + let Ok(kind) = bytes.try_get_u8() else { + return Ok(None); + }; + let Some(len) = read_next_part_len(&mut bytes)? else { + return Ok(None); + }; + + self.advance(1 + LENGTH_SIZE); + Ok(Some((kind, DrainBuf::new(self, len)))) + } + + fn peek_next_part_len(&self) -> Result, RpcError> { + let mut bytes = self.peek(); + read_next_part_len(&mut bytes) + } + + fn peek(&self) -> ChunkQueuePeek<'_> { + ChunkQueuePeek { + chunks: &self.chunks, + chunk_index: 0, + chunk_offset: 0, + remaining: self.remaining, + } + } + + fn front_chunk(&self, limit: usize) -> &[u8] { + let Some(chunk) = self.chunks.front() else { + return &[]; + }; + &chunk[..chunk.len().min(limit)] + } + + pub(crate) fn advance_inner(&mut self, mut cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.remaining -= cnt; + while cnt > 0 { + let front = self.chunks.front_mut().expect("buffered data present"); + let consumed = cnt.min(front.len()); + front.advance(consumed); + cnt -= consumed; + if front.is_empty() { + self.chunks.pop_front(); + } + } } - Ok(Some((consumed, len))) } -pub fn try_measure_next_tagged_part( - mut bytes: B, -) -> Result, RpcError> { - if !bytes.has_remaining() { - return Ok(None); +struct ChunkQueuePeek<'a> { + chunks: &'a VecDeque, + chunk_index: usize, + chunk_offset: usize, + remaining: usize, +} + +impl Buf for ChunkQueuePeek<'_> { + fn remaining(&self) -> usize { + self.remaining } - let kind = bytes.get_u8(); - let Some((consumed, len)) = try_measure_next_part(bytes)? else { - return Ok(None); - }; + fn chunk(&self) -> &[u8] { + if self.remaining == 0 { + return &[]; + } - Ok(Some((kind, 1 + consumed, len))) + let Some(chunk) = self.chunks.get(self.chunk_index) else { + return &[]; + }; + &chunk[self.chunk_offset..] + } + + fn advance(&mut self, mut cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.remaining -= cnt; + + while cnt > 0 { + let chunk = self.chunks.get(self.chunk_index).expect("buffered data present"); + let available = chunk.len() - self.chunk_offset; + let step = cnt.min(available); + self.chunk_offset += step; + cnt -= step; + if self.chunk_offset == chunk.len() { + self.chunk_index += 1; + self.chunk_offset = 0; + } + } + } +} + +impl Buf for ChunkQueue { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.front_chunk(self.remaining) + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.advance_inner(cnt); + } } pub struct DrainBuf<'a> { - bytes: &'a mut VecDeque, - offset: usize, - len: usize, + bytes: &'a mut ChunkQueue, + remaining: usize, } impl<'a> DrainBuf<'a> { - pub fn new(bytes: &'a mut VecDeque, len: usize) -> Self { - debug_assert!(bytes.len() >= len); - Self { - bytes, - offset: 0, - len, - } + pub fn new(bytes: &'a mut ChunkQueue, len: usize) -> Self { + debug_assert!(bytes.remaining() >= len); + Self { bytes, remaining: len } } } impl Buf for DrainBuf<'_> { fn remaining(&self) -> usize { - self.len - self.offset + self.remaining } fn chunk(&self) -> &[u8] { - if self.remaining() == 0 { - return &[]; - } - - let (first, second) = self.bytes.as_slices(); - if self.offset < first.len() { - let start = self.offset; - let end = (start + self.remaining()).min(first.len()); - &first[start..end] - } else { - let start = self.offset - first.len(); - let end = (start + self.remaining()).min(second.len()); - &second[start..end] - } + self.bytes.front_chunk(self.remaining) } fn advance(&mut self, cnt: usize) { assert!(cnt <= self.remaining(), "advanced past payload boundary"); - self.offset += cnt; + self.bytes.advance_inner(cnt); + self.remaining -= cnt; } } impl Drop for DrainBuf<'_> { fn drop(&mut self) { - self.bytes.drain(..self.len); + if self.remaining > 0 { + self.bytes.advance_inner(self.remaining); + self.remaining = 0; + } + } +} + +fn read_next_part_len(bytes: &mut B) -> Result, RpcError> { + let Ok(len) = bytes.try_get_u64_le() else { + return Ok(None); + }; + let len: usize = len.try_into().map_err(|_| RpcError::LengthOverflow)?; + if bytes.remaining() < len { + return Ok(None); } + Ok(Some(len)) } -pub fn push_length(out: &mut Vec, len: usize) { +pub fn push_length(out: &mut B, len: usize) { let len = u64::try_from(len).expect("rpc payload exceeds u64 length framing"); - out.extend_from_slice(&len.to_le_bytes()); + out.put_u64_le(len); +} + +pub fn reserve_length>(out: &mut B) -> usize { + let start = out.as_mut().len(); + out.put_u64_le(0); + start +} + +pub fn backpatch_length + ?Sized>(out: &mut B, start: usize) { + let out = out.as_mut(); + let payload_start = start + LENGTH_SIZE; + let payload_len = out.len() - payload_start; + let payload_len = + u64::try_from(payload_len).expect("rpc payload exceeds u64 length framing"); + out[start..payload_start].copy_from_slice(&payload_len.to_le_bytes()); } diff --git a/ql-rpc/src/header.rs b/ql-rpc/src/header.rs index cd7fa83f..bb6cafac 100644 --- a/ql-rpc/src/header.rs +++ b/ql-rpc/src/header.rs @@ -1,4 +1,6 @@ -use crate::{MethodId, RpcError, RPC_VERSION}; +use bytes::{Buf, BufMut}; + +use crate::{MethodId, RpcCodec, RpcError, RPC_VERSION}; const HEADER_SIZE: usize = 1 + 8; @@ -17,25 +19,24 @@ impl RpcHeader { method, } } +} - pub fn encode_into(&self, out: &mut Vec) { - out.push(self.version); - out.extend_from_slice(&self.method.0.to_le_bytes()); - } +impl RpcCodec for RpcHeader { + type Error = RpcError; - pub fn decode(bytes: &[u8]) -> Result<(Self, &[u8]), RpcError> { - if bytes.len() < Self::WIRE_SIZE { - return Err(RpcError::Truncated); - } + fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + out.put_u8(self.version); + out.put_u64_le(self.method.0); + Ok(()) + } - let version = bytes[0]; + fn decode_value(bytes: &mut B) -> Result { + let version = bytes.try_get_u8().map_err(|_| RpcError::Truncated)?; if version != RPC_VERSION { return Err(RpcError::InvalidVersion(version)); } - let method = MethodId(u64::from_le_bytes( - bytes[1..Self::WIRE_SIZE].try_into().unwrap(), - )); - Ok((Self { version, method }, &bytes[Self::WIRE_SIZE..])) + let method = MethodId(bytes.try_get_u64_le().map_err(|_| RpcError::Truncated)?); + Ok(Self { version, method }) } } diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index e8252136..ed67d311 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -1,6 +1,6 @@ //! quantum link rpc protocol traits and framing helpers. -use bytes::Buf; +use bytes::{Buf, BufMut}; pub(crate) mod codec; mod error; @@ -19,17 +19,6 @@ pub struct MethodId(pub u64); pub trait RpcCodec: Sized { type Error; - fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error>; + fn encode_value(&self, out: &mut B) -> Result<(), Self::Error>; fn decode_value(bytes: &mut B) -> Result; } - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct Inbound<'a> { - pub header: header::RpcHeader, - pub body: &'a [u8], -} - -pub fn parse_inbound(bytes: &[u8]) -> Result, RpcError> { - let (header, body) = header::RpcHeader::decode(bytes)?; - Ok(Inbound { header, body }) -} diff --git a/ql-rpc/src/rpc/notification.rs b/ql-rpc/src/rpc/notification.rs index e288e043..fb9b93f8 100644 --- a/ql-rpc/src/rpc/notification.rs +++ b/ql-rpc/src/rpc/notification.rs @@ -1,3 +1,5 @@ +use bytes::BufMut; + use crate::{MethodId, RpcCodec}; pub trait Notification { @@ -6,8 +8,13 @@ pub trait Notification { type Event: RpcCodec; } -pub fn encode_event(event: &M::Event, out: &mut Vec) -> Result<(), M::Error> { - crate::header::RpcHeader::new(M::METHOD).encode_into(out); +pub fn encode_event( + event: &M::Event, + out: &mut impl BufMut, +) -> Result<(), M::Error> { + crate::header::RpcHeader::new(M::METHOD) + .encode_value(out) + .expect("rpc header encoding cannot fail"); event.encode_value(out) } @@ -17,10 +24,10 @@ pub fn decode_event(mut body: &[u8]) -> Result); @@ -28,8 +35,8 @@ mod tests { impl RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { - out.extend_from_slice(&self.0); + fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + out.put_slice(&self.0); Ok(()) } @@ -51,10 +58,11 @@ mod tests { let mut encoded = Vec::new(); encode_event::(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); - let inbound = parse_inbound(&encoded).unwrap(); - assert_eq!(inbound.header.method, Notify::METHOD); + let mut body = encoded.as_slice(); + let header = RpcHeader::decode_value(&mut body).unwrap(); + assert_eq!(header.method, Notify::METHOD); assert_eq!( - decode_event::(inbound.body).unwrap(), + decode_event::(body).unwrap(), BytesValue(b"hello".to_vec()) ); } diff --git a/ql-rpc/src/rpc/request.rs b/ql-rpc/src/rpc/request.rs index e7f4f5fb..6c9a0d22 100644 --- a/ql-rpc/src/rpc/request.rs +++ b/ql-rpc/src/rpc/request.rs @@ -1,3 +1,5 @@ +use bytes::BufMut; + use crate::{MethodId, RpcCodec}; pub trait Request { @@ -7,8 +9,13 @@ pub trait Request { type Response: RpcCodec; } -pub fn encode_request(request: &M::Request, out: &mut Vec) -> Result<(), M::Error> { - crate::header::RpcHeader::new(M::METHOD).encode_into(out); +pub fn encode_request( + request: &M::Request, + out: &mut impl BufMut, +) -> Result<(), M::Error> { + crate::header::RpcHeader::new(M::METHOD) + .encode_value(out) + .expect("rpc header encoding cannot fail"); request.encode_value(out) } @@ -19,7 +26,7 @@ pub fn decode_request(body: &[u8]) -> Result { pub fn encode_response( response: &M::Response, - out: &mut Vec, + out: &mut impl BufMut, ) -> Result<(), M::Error> { response.encode_value(out) } @@ -31,10 +38,10 @@ pub fn decode_response(bytes: &[u8]) -> Result); @@ -42,8 +49,8 @@ mod tests { impl RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { - out.extend_from_slice(&self.0); + fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + out.put_slice(&self.0); Ok(()) } @@ -66,10 +73,11 @@ mod tests { let mut encoded = Vec::new(); encode_request::(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); - let inbound = parse_inbound(&encoded).unwrap(); - assert_eq!(inbound.header.method, Echo::METHOD); + let mut body = encoded.as_slice(); + let header = RpcHeader::decode_value(&mut body).unwrap(); + assert_eq!(header.method, Echo::METHOD); assert_eq!( - decode_request::(inbound.body).unwrap(), + decode_request::(body).unwrap(), BytesValue(b"hello".to_vec()) ); } diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index 159c78db..719686a2 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -1,11 +1,9 @@ -use std::{collections::VecDeque, marker::PhantomData}; +use std::marker::PhantomData; -use bytes::Buf; +use bytes::{BufMut, Bytes}; use crate::{codec, MethodId, RpcCodec, RpcCodecError, RpcError}; -const FRAME_HEADER_SIZE: usize = 1 + core::mem::size_of::(); - pub trait RequestWithProgress { const METHOD: MethodId; type Error; @@ -24,7 +22,7 @@ pub enum ReadStep { } pub struct ResponseReader { - bytes: VecDeque, + bytes: codec::ChunkQueue, marker: PhantomData M>, } @@ -37,42 +35,40 @@ impl Default for ResponseReader { impl ResponseReader { pub fn new() -> Self { Self { - bytes: VecDeque::new(), + bytes: codec::ChunkQueue::new(), marker: PhantomData, } } - pub fn push(mut self, chunk: &[u8]) -> Self { - self.bytes.extend(chunk); + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); self } pub fn advance(self) -> Result, RpcCodecError> { let mut this = self; - let (first, second) = this.bytes.as_slices(); - let Some((kind, consumed, payload_len)) = - codec::try_measure_next_tagged_part(first.chain(second)).map_err(RpcCodecError::Rpc)? + let Some((kind, mut body)) = + this.bytes.try_take_tagged_part().map_err(RpcCodecError::Rpc)? else { return Ok(ReadStep::NeedMore(this)); }; match kind { x if x == FrameKind::Progress as u8 => { - this.bytes.drain(..FRAME_HEADER_SIZE); let value = { - let mut body = codec::DrainBuf::new(&mut this.bytes, payload_len); - M::Progress::decode_value(&mut body).map_err(RpcCodecError::Codec)? + let value = + M::Progress::decode_value(&mut body).map_err(RpcCodecError::Codec)?; + drop(body); + value }; Ok(ReadStep::Progress { value, next: this }) } x if x == FrameKind::Response as u8 => { - let has_trailing = this.bytes.len() > consumed; - this.bytes.drain(..FRAME_HEADER_SIZE); - let mut body = codec::DrainBuf::new(&mut this.bytes, payload_len); let response = M::Response::decode_value(&mut body).map_err(RpcCodecError::Codec)?; - if has_trailing { + drop(body); + if this.bytes.remaining() > 0 { Err(RpcCodecError::Rpc(RpcError::TrailingBytes)) } else { Ok(ReadStep::Response(response)) @@ -92,9 +88,11 @@ enum FrameKind { pub fn encode_request( request: &M::Request, - out: &mut Vec, + out: &mut impl BufMut, ) -> Result<(), M::Error> { - crate::header::RpcHeader::new(M::METHOD).encode_into(out); + crate::header::RpcHeader::new(M::METHOD) + .encode_value(out) + .expect("rpc header encoding cannot fail"); request.encode_value(out) } @@ -104,40 +102,39 @@ pub fn decode_request(mut body: &[u8]) -> Result( progress: &M::Progress, - out: &mut Vec, + out: &mut (impl BufMut + AsMut<[u8]>), ) -> Result<(), M::Error> { encode_tagged_value_part(FrameKind::Progress, progress, out) } pub fn encode_response( response: &M::Response, - out: &mut Vec, + out: &mut (impl BufMut + AsMut<[u8]>), ) -> Result<(), M::Error> { encode_tagged_value_part(FrameKind::Response, response, out) } -fn encode_tagged_value_part( +fn encode_tagged_value_part>( kind: FrameKind, value: &T, - out: &mut Vec, + out: &mut B, ) -> Result<(), T::Error> { - let mut payload = Vec::new(); - value.encode_value(&mut payload)?; - out.push(kind as u8); - codec::push_length(out, payload.len()); - out.extend_from_slice(&payload); + out.put_u8(kind as u8); + let payload_start = codec::reserve_length(out); + value.encode_value(out)?; + codec::backpatch_length(out, payload_start); Ok(()) } #[cfg(test)] mod tests { - use bytes::Buf; + use bytes::{Buf, BufMut, Bytes}; use super::{ decode_request, encode_progress, encode_request, encode_response, ReadStep, RequestWithProgress, ResponseReader, }; - use crate::{parse_inbound, MethodId, RpcCodec, RpcCodecError, RpcError}; + use crate::{header::RpcHeader, MethodId, RpcCodec, RpcCodecError, RpcError}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -145,8 +142,8 @@ mod tests { impl RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { - out.extend_from_slice(&self.0); + fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + out.put_slice(&self.0); Ok(()) } @@ -170,10 +167,11 @@ mod tests { let mut encoded = Vec::new(); encode_request::(&BytesValue(b"watch".to_vec()), &mut encoded).unwrap(); - let inbound = parse_inbound(&encoded).unwrap(); - assert_eq!(inbound.header.method, Watch::METHOD); + let mut body = encoded.as_slice(); + let header = RpcHeader::decode_value(&mut body).unwrap(); + assert_eq!(header.method, Watch::METHOD); assert_eq!( - decode_request::(inbound.body).unwrap(), + decode_request::(body).unwrap(), BytesValue(b"watch".to_vec()) ); } @@ -184,7 +182,7 @@ mod tests { encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); let reader = match ResponseReader::::new() - .push(&encoded) + .push(Bytes::from(encoded)) .advance() .unwrap() { @@ -209,7 +207,7 @@ mod tests { encode_progress::(&BytesValue(b"late".to_vec()), &mut encoded).unwrap(); let reader = match ResponseReader::::new() - .push(&encoded) + .push(Bytes::from(encoded)) .advance() .unwrap() { @@ -228,12 +226,13 @@ mod tests { encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); - let reader = ResponseReader::::new().push(&encoded[..4]); + let encoded = Bytes::from(encoded); + let reader = ResponseReader::::new().push(encoded.slice(..4)); let reader = match reader.advance().unwrap() { ReadStep::NeedMore(next) => next, _ => unreachable!(), }; - let reader = reader.push(&encoded[4..encoded.len() - 2]); + let reader = reader.push(encoded.slice(4..encoded.len() - 2)); let reader = match reader.advance().unwrap() { ReadStep::Progress { value: BytesValue(bytes), @@ -248,7 +247,7 @@ mod tests { ReadStep::NeedMore(next) => next, _ => unreachable!(), }; - let reader = reader.push(&encoded[encoded.len() - 2..]); + let reader = reader.push(encoded.slice(encoded.len() - 2..)); match reader.advance().unwrap() { ReadStep::Response(value) => assert_eq!(value, BytesValue(b"done".to_vec())), _ => unreachable!(), @@ -262,7 +261,7 @@ mod tests { encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); let reader = match ResponseReader::::new() - .push(&encoded) + .push(Bytes::from(encoded)) .advance() .unwrap() { @@ -284,7 +283,7 @@ mod tests { encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); match ResponseReader::::new() - .push(&encoded) + .push(Bytes::from(encoded)) .advance() .unwrap() { diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 442ef156..7e8fb674 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -1,11 +1,9 @@ -use std::{collections::VecDeque, marker::PhantomData}; +use std::marker::PhantomData; -use bytes::Buf; +use bytes::{Buf, BufMut, Bytes}; use crate::{codec, MethodId, RpcCodec, RpcCodecError, RpcError}; -const ITEM_HEADER_SIZE: usize = core::mem::size_of::(); - pub trait Subscription { const METHOD: MethodId; type Error; @@ -23,7 +21,7 @@ pub enum ReadStep { } pub struct ResponseReader { - bytes: VecDeque, + bytes: codec::ChunkQueue, marker: PhantomData M>, } @@ -36,36 +34,35 @@ impl Default for ResponseReader { impl ResponseReader { pub fn new() -> Self { Self { - bytes: VecDeque::new(), + bytes: codec::ChunkQueue::new(), marker: PhantomData, } } - pub fn push(mut self, chunk: &[u8]) -> Self { - self.bytes.extend(chunk); + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); self } pub fn advance(self) -> Result, RpcCodecError> { let mut this = self; - let (first, second) = this.bytes.as_slices(); - let Some((consumed, payload_len)) = - codec::try_measure_next_part(first.chain(second)).map_err(RpcCodecError::Rpc)? + let Some(mut body) = this.bytes.try_take_part().map_err(RpcCodecError::Rpc)? else { return Ok(ReadStep::NeedMore(this)); }; - if payload_len == 0 { - if this.bytes.len() == consumed { + if body.remaining() == 0 { + drop(body); + if this.bytes.remaining() == 0 { return Ok(ReadStep::End); } return Err(RpcCodecError::Rpc(RpcError::TrailingBytes)); } - this.bytes.drain(..ITEM_HEADER_SIZE); let item = { - let mut body = codec::DrainBuf::new(&mut this.bytes, payload_len); - M::Event::decode_value(&mut body).map_err(RpcCodecError::Codec)? + let item = M::Event::decode_value(&mut body).map_err(RpcCodecError::Codec)?; + drop(body); + item }; Ok(ReadStep::Item { value: item, @@ -76,9 +73,11 @@ impl ResponseReader { pub fn encode_request( request: &M::Request, - out: &mut Vec, + out: &mut impl BufMut, ) -> Result<(), M::Error> { - crate::header::RpcHeader::new(M::METHOD).encode_into(out); + crate::header::RpcHeader::new(M::METHOD) + .encode_value(out) + .expect("rpc header encoding cannot fail"); request.encode_value(out) } @@ -88,24 +87,24 @@ pub fn decode_request(mut body: &[u8]) -> Result( item: &M::Event, - out: &mut Vec, + out: &mut (impl BufMut + AsMut<[u8]>), ) -> Result<(), ::Error> { codec::encode_value_part(item, out) } -pub fn encode_end(out: &mut Vec) { +pub fn encode_end(out: &mut impl BufMut) { codec::push_length(out, 0); } #[cfg(test)] mod tests { - use bytes::Buf; + use bytes::{Buf, BufMut, Bytes}; use super::{ decode_request, encode_end, encode_item, encode_request, ReadStep, ResponseReader, Subscription, }; - use crate::{parse_inbound, MethodId, RpcCodec}; + use crate::{header::RpcHeader, MethodId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -113,8 +112,8 @@ mod tests { impl RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { - out.extend_from_slice(&self.0); + fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + out.put_slice(&self.0); Ok(()) } @@ -137,10 +136,11 @@ mod tests { let mut encoded = Vec::new(); encode_request::(&BytesValue(b"watch".to_vec()), &mut encoded).unwrap(); - let inbound = parse_inbound(&encoded).unwrap(); - assert_eq!(inbound.header.method, Feed::METHOD); + let mut body = encoded.as_slice(); + let header = RpcHeader::decode_value(&mut body).unwrap(); + assert_eq!(header.method, Feed::METHOD); assert_eq!( - decode_request::(inbound.body).unwrap(), + decode_request::(body).unwrap(), BytesValue(b"watch".to_vec()) ); } @@ -153,7 +153,7 @@ mod tests { encode_end(&mut encoded); let reader = match ResponseReader::::new() - .push(&encoded) + .push(Bytes::from(encoded)) .advance() .unwrap() { @@ -182,8 +182,9 @@ mod tests { encode_item::(&BytesValue(b"two".to_vec()), &mut encoded).unwrap(); encode_end(&mut encoded); + let all = Bytes::from(encoded); let reader = match ResponseReader::::new() - .push(&encoded[..5]) + .push(all.slice(..5)) .advance() .unwrap() { @@ -191,7 +192,7 @@ mod tests { _ => unreachable!(), }; - let reader = match reader.push(&encoded[5..]).advance().unwrap() { + let reader = match reader.push(all.slice(5..)).advance().unwrap() { ReadStep::Item { value, next } => { assert_eq!(value, BytesValue(b"one".to_vec())); next diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs index 13a8f1ab..239e7f72 100644 --- a/ql-runtime/src/rpc/request_with_progress.rs +++ b/ql-runtime/src/rpc/request_with_progress.rs @@ -66,7 +66,7 @@ where match this.stream.poll_read_chunk(cx) { Poll::Ready(Ok(Some(chunk))) => { let reader = this.reader.take().expect("progress reader is present"); - this.reader = Some(reader.push(&chunk)); + this.reader = Some(reader.push(chunk)); } Poll::Ready(Ok(None)) => { this.reader = None; @@ -118,7 +118,7 @@ where match this.stream.poll_read_chunk(cx) { Poll::Ready(Ok(Some(chunk))) => { let reader = this.reader.take().expect("progress reader is present"); - this.reader = Some(reader.push(&chunk)); + this.reader = Some(reader.push(chunk)); } Poll::Ready(Ok(None)) => { this.reader = None; diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 20123b07..5c0fde1b 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -57,7 +57,7 @@ where match this.stream.poll_read_chunk(cx) { Poll::Ready(Ok(Some(chunk))) => { let reader = this.reader.take().expect("subscription reader is present"); - this.reader = Some(reader.push(&chunk)); + this.reader = Some(reader.push(chunk)); } Poll::Ready(Ok(None)) => { this.reader = None; diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 4835aa23..a64c3bc0 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use bytes::{Buf, Bytes}; +use bytes::{Buf, BufMut, Bytes}; use futures_lite::StreamExt; use super::*; @@ -11,8 +11,8 @@ struct BytesValue(Vec); impl ql_rpc::RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut Vec) -> Result<(), Self::Error> { - out.extend_from_slice(&self.0); + fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + out.put_slice(&self.0); Ok(()) } @@ -76,9 +76,15 @@ async fn rpc_request_round_trips() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); let request = read_all(inbound.reader).await.unwrap(); - let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); + let mut body = request.as_slice(); + let header = + ::decode_value(&mut body).unwrap(); assert_eq!( - ql_rpc::request::decode_request::(rpc_inbound.body).unwrap(), + header.method, + ::METHOD + ); + assert_eq!( + ql_rpc::request::decode_request::(body).unwrap(), BytesValue(b"hello".to_vec()) ); @@ -132,9 +138,15 @@ async fn rpc_subscription_streams_events() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); let request = read_all(inbound.reader).await.unwrap(); - let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); + let mut body = request.as_slice(); + let header = + ::decode_value(&mut body).unwrap(); + assert_eq!( + header.method, + ::METHOD + ); assert_eq!( - ql_rpc::subscription::decode_request::(rpc_inbound.body).unwrap(), + ql_rpc::subscription::decode_request::(body).unwrap(), BytesValue(b"watch".to_vec()) ); @@ -200,10 +212,15 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); let request = read_all(inbound.reader).await.unwrap(); - let rpc_inbound = ql_rpc::parse_inbound(&request).unwrap(); + let mut body = request.as_slice(); + let header = + ::decode_value(&mut body).unwrap(); + assert_eq!( + header.method, + ::METHOD + ); assert_eq!( - ql_rpc::request_with_progress::decode_request::(rpc_inbound.body) - .unwrap(), + ql_rpc::request_with_progress::decode_request::(body).unwrap(), BytesValue(b"logo".to_vec()) ); From 048b5c4e9caf85dcc312810413398ad2b7725650 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 07:47:27 -0400 Subject: [PATCH 133/304] ql: update design doc --- QL_V2.md | 82 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/QL_V2.md b/QL_V2.md index c8dd3ffd..8d91361c 100644 --- a/QL_V2.md +++ b/QL_V2.md @@ -1,11 +1,9 @@ -# QuantumLink V2 Design Document +# QuantumLink V2 QuantumLink V2 is a peer-to-peer protocol for authenticated encrypted sessions carrying multiplexed duplex byte streams. It operates on whole QL records. Packetization, fragmentation, batching, and reassembly belong to the transport adapter, not to QLv2 itself. -The handshake is the setup phase. It authenticates the remote peer, establishes a fresh session, and derives the keys used for steady-state traffic. - ## Design goals 1. [Ephemeral peer sessions](#handshake): short-lived keys for encryption 2. [Forward secrecy](#security-properties): losing a long-term private key does not reveal old session data @@ -46,7 +44,20 @@ QLv2 has two record types: Handshake records are large because they carry ML-KEM material. Session records are small and can carry multiple frames, including frames for different streams. -Handshake records are routed by peer identity. Session records are routed by `connection_id`. +Handshake records are routed by peer identity via visible `sender` and `recipient` XIDs. Session records are routed by `connection_id`. + +QLv2 uses QUIC-style variable-length integers for several steady-state fields. A varint is 1, 2, 4, or 8 bytes and can represent values in the range `0..2^62-1`. This keeps small values compact while allowing very large record and stream number spaces. + +Today, varints are used for: + +- session record `seq` +- `Ack.base_seq` +- `StreamData` frame length +- `StreamData.stream_id` +- `StreamData.offset` +- `StreamWindow.stream_id` +- `StreamWindow.maximum_offset` +- `StreamClose.stream_id` ### Handshake records @@ -59,7 +70,7 @@ Handshake records are routed by peer identity. Session records are routed by `co ### Session records -`session record size = 42 + sum(frame sizes)` +`session record size = 35..42 + sum(frame sizes)` There is no explicit AEAD nonce on the wire. The record `seq` is used to derive the nonce. @@ -68,9 +79,9 @@ There is no explicit AEAD nonce on the wire. The record `seq` is used to derive | version | 1 byte | protocol version | | record type | 1 byte | identifies a session record | | `connection_id` | 16 bytes | route the record to the current session | -| `seq` | 8 bytes | record identity for ack and retransmit | +| `seq` | 1..8 bytes | varint record identity for ack and retransmit | | AEAD auth tag | 16 bytes | authenticate the encrypted body | -| fixed overhead total | 42 bytes | overhead before any frames | +| fixed overhead total | 35..42 bytes | overhead before any frames | The visible session header is authenticated as AEAD AAD but is not encrypted. @@ -79,23 +90,23 @@ The visible session header is authenticated as AEAD AAD but is not encrypted. | Frame | Size | Purpose | | --- | ---: | --- | | `Ping` | 1 byte | keep the session alive when idle | -| `Ack` | 17 bytes | acknowledge received session records | -| `StreamWindow` | 13 bytes | extend per-stream send credit | -| `StreamClose` | 10 bytes | abort one stream lane or both lanes | +| `Ack` | `10..17` bytes | acknowledge received session records | +| `StreamWindow` | `3..17` bytes | extend per-stream send credit | +| `StreamClose` | `5..12` bytes | abort one stream lane or both lanes | | `Close` | 3 bytes | close the whole session | -| `StreamData` | `16 + payload_len` bytes | carry stream bytes and optional `fin` | +| `StreamData` | `4..26 + payload_len` bytes | carry stream bytes and optional `fin` | `StreamData` is the main steady-state frame: -`1 kind + 2 variable-length prefix + 4 stream_id + 8 offset + 1 fin + payload_len` +`1 kind + varint(frame_len) + varint(stream_id) + varint(offset) + 1 fin + payload_len` -Some useful minimum record sizes: +Some useful minimum sizes for single-frame records: | Record | Size | Meaning | | --- | ---: | --- | -| `Ping` only | 43 bytes | idle keepalive | -| `Close` only | 45 bytes | session shutdown | -| empty or fin-only `StreamData` | 58 bytes | open or finish a stream lane without payload bytes | +| `Ping` only | 36 bytes | idle keepalive | +| `Close` only | 38 bytes | session shutdown | +| empty `StreamData` | 40 bytes | open or finish a stream lane without payload bytes | ## Handshake @@ -114,7 +125,7 @@ The handshake does five things: 4. bind transport parameters into the transcript 5. produce a `handshake_hash` for the completed exchange -Today, first-contact identity exchange is still partly out of band. `IK` removes the need for the responder to know the initiator in advance, but the initiator still needs the responder bundle before it can start. A future pattern such as `XX` could remove that requirement. +First-contact identity exchange is still partly out of band. `IK` removes the need for the responder to know the initiator in advance, but the initiator still needs the responder bundle before it can start. Each handshake carries: @@ -122,11 +133,11 @@ Each handshake carries: - `valid_until`: expiration time for that attempt - transport parameters: today this is initial per-stream receive credit -Important behavior: +Handshake rules: - handshake start messages are replay-checked by `handshake_id` - expired handshake messages are rejected -- simultaneous starts are resolved deterministically +- simultaneous starts are resolved deterministically: `IK` beats `KK`; otherwise the initial ephemeral key breaks ties - handshake attempts time out and are dropped rather than being retransmitted in place Session establishment is slightly asymmetric: @@ -171,6 +182,31 @@ Retransmission works at the frame level: QLv2 does not resend the same logical record identity. +There is no explicit `Nack` frame. Loss is inferred either from timeout or from later selective `Ack` state that makes it clear a record was not accepted. + +Example: + +`seq = 10` + +| Frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=0 bytes="hello"` | + +The sender receives more bytes for that stream before `seq = 10` is acked: + +| Pending new frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=5 bytes=" world"` | + +If `seq = 10` is considered lost, its frame is restored and packed again with a new record sequence: + +`seq = 11` + +| Frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=0 bytes="hello"` | +| `StreamData` | `stream_id=4 offset=5 bytes=" world"` | + Receivers track a recent record window so they can: - reject duplicates - send selective acks with `base_seq + bitmap` @@ -187,11 +223,14 @@ A stream has two independent lanes: Important properties: - either peer can open a stream -- stream IDs are split by parity so both peers can open streams without collision +- stream IDs are split by parity derived from XID ordering, so both peers can open streams without collision +- stream IDs increase monotonically within each parity namespace and must not repeat within a session - ordering is preserved within a stream lane - different streams can make progress independently - record loss on one stream does not block unrelated streams +A stream opens implicitly on the first valid `StreamData` or `StreamClose` for that remote stream ID. There is no separate open frame. + `StreamData` carries: - `stream_id` @@ -209,8 +248,6 @@ During the handshake, each peer advertises an initial per-stream receive window. `StreamWindow` extends that credit by advertising a larger maximum offset. -Important detail: reading bytes is not what returns credit. Committing those reads is what returns credit and causes window updates to be sent. - In practice, a stream is writable only when both are true: - local send buffering has room @@ -243,4 +280,3 @@ The current handshake is ML-KEM-based and post-quantum focused. Session payloads are encrypted and authenticated. The session header stays visible so the receiver can route the record, but it is still authenticated as AEAD AAD. QLv2 also provides forward secrecy in the following sense: even if an attacker later obtains a peer's long-term ML-KEM private key, they still cannot decrypt messages from earlier completed sessions. - From 3264de2c0bf15114e9e542714570b659da204719 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 07:51:49 -0400 Subject: [PATCH 134/304] ql-fsm: add todo --- ql-fsm/src/session/stream_tx.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 81ea6575..793fea3c 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -146,6 +146,9 @@ impl StreamTx { max_payload: usize, peer_max_offset: u64, ) -> Option { + // TODO: coalesce a lost range with contiguous unsent tail bytes when they fit in the same + // transmit budget. That would let a repacked record send one larger StreamData frame + // instead of retransmitting the lost prefix first and the new tail later. if let Some(range) = self.retransmits.peek_min() { let end = range .end From 45767b0077af3edd096b292d8023fe9ac2512232 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 08:03:51 -0400 Subject: [PATCH 135/304] ql: fmt --- ql-rpc/src/codec.rs | 14 ++++++---- ql-rpc/src/rpc/request_with_progress.rs | 6 +++-- ql-rpc/src/rpc/subscription.rs | 3 +-- ql-runtime/src/driver/state.rs | 4 +-- ql-runtime/src/handle/reader.rs | 8 +++--- ql-runtime/src/rpc/mod.rs | 3 ++- ql-runtime/src/tests/handshake.rs | 7 ++++- ql-runtime/src/tests/mod.rs | 5 +++- ql-runtime/src/tests/rpc.rs | 5 +--- ql-runtime/src/tests/stream.rs | 34 ++++++++++++++++++++----- 10 files changed, 61 insertions(+), 28 deletions(-) diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index a5e7ce7a..ae896561 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -94,7 +94,6 @@ impl ChunkQueue { } } } - } struct ChunkQueuePeek<'a> { @@ -125,7 +124,10 @@ impl Buf for ChunkQueuePeek<'_> { self.remaining -= cnt; while cnt > 0 { - let chunk = self.chunks.get(self.chunk_index).expect("buffered data present"); + let chunk = self + .chunks + .get(self.chunk_index) + .expect("buffered data present"); let available = chunk.len() - self.chunk_offset; let step = cnt.min(available); self.chunk_offset += step; @@ -161,7 +163,10 @@ pub struct DrainBuf<'a> { impl<'a> DrainBuf<'a> { pub fn new(bytes: &'a mut ChunkQueue, len: usize) -> Self { debug_assert!(bytes.remaining() >= len); - Self { bytes, remaining: len } + Self { + bytes, + remaining: len, + } } } @@ -216,7 +221,6 @@ pub fn backpatch_length + ?Sized>(out: &mut B, start: usize) { let out = out.as_mut(); let payload_start = start + LENGTH_SIZE; let payload_len = out.len() - payload_start; - let payload_len = - u64::try_from(payload_len).expect("rpc payload exceeds u64 length framing"); + let payload_len = u64::try_from(payload_len).expect("rpc payload exceeds u64 length framing"); out[start..payload_start].copy_from_slice(&payload_len.to_le_bytes()); } diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index 719686a2..55d576d7 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -48,8 +48,10 @@ impl ResponseReader { pub fn advance(self) -> Result, RpcCodecError> { let mut this = self; - let Some((kind, mut body)) = - this.bytes.try_take_tagged_part().map_err(RpcCodecError::Rpc)? + let Some((kind, mut body)) = this + .bytes + .try_take_tagged_part() + .map_err(RpcCodecError::Rpc)? else { return Ok(ReadStep::NeedMore(this)); }; diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 7e8fb674..1cbbf4b8 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -46,8 +46,7 @@ impl ResponseReader { pub fn advance(self) -> Result, RpcCodecError> { let mut this = self; - let Some(mut body) = this.bytes.try_take_part().map_err(RpcCodecError::Rpc)? - else { + let Some(mut body) = this.bytes.try_take_part().map_err(RpcCodecError::Rpc)? else { return Ok(ReadStep::NeedMore(this)); }; diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 47ae8e2c..a54e152c 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -98,9 +98,7 @@ impl DriverStreamIo { } pub enum OutboundIo { - Open { - reader: ChunkSlotRx, - }, + Open { reader: ChunkSlotRx }, Closed, } diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 9e344b24..7c807cba 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -1,6 +1,5 @@ use std::{ - future::poll_fn, - future::Future, + future::{poll_fn, Future}, pin::Pin, task::{Context, Poll}, }; @@ -57,7 +56,10 @@ impl ByteReader { } } - pub fn poll_read_chunk(&mut self, cx: &mut Context<'_>) -> Poll, QlError>> { + pub fn poll_read_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, QlError>> { if matches!(self.terminal, TerminalState::Delivered) { return Poll::Ready(Ok(None)); } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 525f4752..d392a687 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -2,8 +2,9 @@ mod error; mod request_with_progress; mod subscription; -use bytes::Bytes; use std::future::poll_fn; + +use bytes::Bytes; use ql_rpc::{ notification::{self, Notification}, request::{self, Request as RequestRpc}, diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index d20b490e..e0021e42 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -1,6 +1,7 @@ use std::time::Duration; use bytes::Bytes; + use super::*; #[tokio::test(flavor = "current_thread")] @@ -115,7 +116,11 @@ async fn rejected_session_write_is_reissued() { }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.write(Bytes::from_static(b"retry")).await.unwrap(); + stream + .writer + .write(Bytes::from_static(b"retry")) + .await + .unwrap(); stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 9d2d4941..06bcd55e 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -492,7 +492,10 @@ async fn read_all(mut stream: crate::ByteReader) -> Result, QlError> { } async fn next_chunk(stream: &mut crate::ByteReader) -> Result>, QlError> { - stream.read_chunk().await.map(|chunk| chunk.map(|bytes| bytes.to_vec())) + stream + .read_chunk() + .await + .map(|chunk| chunk.map(|bytes| bytes.to_vec())) } fn default_runtime_config() -> RuntimeConfig { diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index a64c3bc0..171aef2a 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -79,10 +79,7 @@ async fn rpc_request_round_trips() { let mut body = request.as_slice(); let header = ::decode_value(&mut body).unwrap(); - assert_eq!( - header.method, - ::METHOD - ); + assert_eq!(header.method, ::METHOD); assert_eq!( ql_rpc::request::decode_request::(body).unwrap(), BytesValue(b"hello".to_vec()) diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 9d16a2c4..ee3366ed 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -44,9 +44,17 @@ async fn open_stream_duplex_happy_path() { }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.write(Bytes::from_static(&[1, 2])).await.unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2])) + .await + .unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), Some(vec![9])); - stream.writer.write(Bytes::from_static(&[3, 4])).await.unwrap(); + stream + .writer + .write(Bytes::from_static(&[3, 4])) + .await + .unwrap(); stream.writer.finish().await.unwrap(); assert_eq!( next_chunk(&mut stream.reader).await.unwrap(), @@ -97,7 +105,11 @@ async fn large_stream_payload_round_trips() { }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.write(Bytes::from(payload.clone())).await.unwrap(); + stream + .writer + .write(Bytes::from(payload.clone())) + .await + .unwrap(); stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); @@ -194,7 +206,10 @@ async fn dropping_inbound_reader_cancels_remote_writer() { let mut writer = stream.writer; let mut reader = stream.reader; assert_eq!(next_chunk(&mut reader).await.unwrap(), None); - writer.write(Bytes::from_static(&[1, 2, 3, 4])).await.unwrap(); + writer + .write(Bytes::from_static(&[1, 2, 3, 4])) + .await + .unwrap(); go_rx.recv().await.unwrap(); let _ = writer.write(Bytes::from(vec![5; 64])).await; let _ = writer.finish().await; @@ -324,13 +339,20 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { let stream = inbound_b.recv().await.unwrap(); let received_request = read_all(stream.reader).await.unwrap(); let mut writer = stream.writer; - writer.write(Bytes::from(response_payload.clone())).await.unwrap(); + writer + .write(Bytes::from(response_payload.clone())) + .await + .unwrap(); writer.finish().await.unwrap(); received_request }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.write(Bytes::from(request_payload.clone())).await.unwrap(); + stream + .writer + .write(Bytes::from(request_payload.clone())) + .await + .unwrap(); stream.writer.finish().await.unwrap(); let mut received_response = Vec::new(); From 94745d7577943ade018b05dcf1751cef549e3aa8 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 08:13:48 -0400 Subject: [PATCH 136/304] ql: clippy --- ql-fsm/src/session/stream_rx.rs | 2 +- ql-runtime/src/handle/mod.rs | 2 +- ql-wire/src/encrypted/mod.rs | 16 ++++++---------- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index 58f7d37c..c5012018 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -241,7 +241,7 @@ pub struct StreamReadIter<'a> { remaining: usize, } -impl<'a> Iterator for StreamReadIter<'a> { +impl Iterator for StreamReadIter<'_> { type Item = Bytes; fn next(&mut self) -> Option { diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index acb594cf..466c8484 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -24,7 +24,7 @@ impl RuntimeHandle { } pub fn connect(&self) { - self.send(RuntimeCommand::Connect) + self.send(RuntimeCommand::Connect); } pub fn send_incoming(&self, bytes: Vec) { diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index dd9ccd3b..ec6d3bba 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -33,21 +33,17 @@ impl WireDecode for SessionFrame { fn decode(reader: &mut Reader) -> Result { let kind = reader.decode::()?; let frame = match kind { - SessionFrameKind::Ping => SessionFrame::Ping, - SessionFrameKind::Ack => SessionFrame::Ack(reader.decode::()?), + SessionFrameKind::Ping => Self::Ping, + SessionFrameKind::Ack => Self::Ack(reader.decode::()?), SessionFrameKind::StreamData => { let len = usize::try_from(reader.decode::()?.into_inner()) .map_err(|_| WireError::InvalidPayload)?; let frame = reader.take_bytes(len)?; - SessionFrame::StreamData(StreamData::decode_exact(frame)?) + Self::StreamData(StreamData::decode_exact(frame)?) } - SessionFrameKind::StreamWindow => { - SessionFrame::StreamWindow(reader.decode::()?) - } - SessionFrameKind::StreamClose => { - SessionFrame::StreamClose(reader.decode::()?) - } - SessionFrameKind::Close => SessionFrame::Close(reader.decode::()?), + SessionFrameKind::StreamWindow => Self::StreamWindow(reader.decode::()?), + SessionFrameKind::StreamClose => Self::StreamClose(reader.decode::()?), + SessionFrameKind::Close => Self::Close(reader.decode::()?), }; Ok(frame) } From 1d20851db05f5c050fe4cff9b8843bfd9273a1dc Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 08:30:31 -0400 Subject: [PATCH 137/304] ql-runtime: use RuntimeHandle internally --- ql-runtime/src/driver/mod.rs | 13 +++++-- ql-runtime/src/handle/mod.rs | 9 +---- ql-runtime/src/handle/reader.rs | 29 +++++++-------- ql-runtime/src/handle/writer.rs | 61 +++++++++++-------------------- ql-runtime/src/rpc/mod.rs | 2 +- ql-runtime/src/tests/handshake.rs | 4 +- ql-runtime/src/tests/heartbeat.rs | 4 +- ql-runtime/src/tests/rpc.rs | 6 +-- ql-runtime/src/tests/stream.rs | 22 +++++------ 9 files changed, 65 insertions(+), 85 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index afb57cc5..794a3a42 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -19,7 +19,7 @@ use crate::{ command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, platform::{PlatformFuture, QlPlatform}, - QlError, Runtime, + QlError, Runtime, RuntimeHandle, }; impl Runtime

{ @@ -188,7 +188,7 @@ impl DriverState { CloseTarget::Return, response_reader, response_terminal_rx, - runtime_tx, + RuntimeHandle::new(runtime_tx), ); if start.send(Ok((stream_id, reader))).is_err() { if let Some(stream) = self.streams.get_mut(&stream_id) { @@ -353,9 +353,14 @@ impl DriverState { CloseTarget::Origin, request_reader, request_terminal_rx, - runtime_tx.clone(), + RuntimeHandle::new(runtime_tx.clone()), + ), + writer: ByteWriter::new( + stream_id, + CloseTarget::Return, + response_writer, + RuntimeHandle::new(runtime_tx), ), - writer: ByteWriter::new(stream_id, CloseTarget::Return, response_writer, runtime_tx), }); } diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 466c8484..ea944197 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -45,12 +45,7 @@ impl RuntimeHandle { Ok(QlStream { stream_id, - writer: ByteWriter::new( - stream_id, - CloseTarget::Origin, - request_writer, - self.tx.clone(), - ), + writer: ByteWriter::new(stream_id, CloseTarget::Origin, request_writer, self.clone()), reader, }) } @@ -70,7 +65,7 @@ impl RuntimeHandle { #[inline] #[track_caller] - fn send(&self, cmd: RuntimeCommand) { + pub(crate) fn send(&self, cmd: RuntimeCommand) { self.tx.try_send(cmd).expect("runtime is alive"); } } diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 7c807cba..6e62ff07 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -8,7 +8,7 @@ use bytes::Bytes; use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, QlError}; +use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, QlError, RuntimeHandle}; pub struct ByteReader { stream_id: StreamId, @@ -16,7 +16,7 @@ pub struct ByteReader { reader: Option, listener: Option, terminal: TerminalState, - tx: async_channel::Sender, + handle: RuntimeHandle, } enum TerminalState { @@ -44,7 +44,7 @@ impl ByteReader { target: CloseTarget, reader: ChunkSlotRx, terminal: oneshot::Receiver>, - tx: async_channel::Sender, + handle: RuntimeHandle, ) -> Self { Self { stream_id, @@ -52,7 +52,7 @@ impl ByteReader { reader: Some(reader), listener: None, terminal: TerminalState::Armed(terminal), - tx, + handle, } } @@ -67,7 +67,7 @@ impl ByteReader { if let Some(reader) = self.reader.as_ref() { match reader.poll_recv(usize::MAX, &mut self.listener, cx) { Poll::Ready(Ok(bytes)) => { - let _ = self.tx.try_send(RuntimeCommand::PollInbound { + self.handle.send(RuntimeCommand::PollInbound { stream_id: self.stream_id, }); return Poll::Ready(Ok(Some(bytes))); @@ -110,21 +110,18 @@ impl ByteReader { poll_fn(|cx| self.poll_read_chunk(cx)).await } - pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { + pub fn close(mut self, code: StreamCloseCode) { if matches!(self.terminal, TerminalState::Delivered) { - return Ok(()); + return; } self.reader.take(); self.listener = None; self.terminal = TerminalState::Delivered; - self.tx - .send(RuntimeCommand::CloseStream { - stream_id: self.stream_id, - target: self.target, - code, - }) - .await - .map_err(|_| QlError::Cancelled) + self.handle.send(RuntimeCommand::CloseStream { + stream_id: self.stream_id, + target: self.target, + code, + }); } } @@ -133,7 +130,7 @@ impl Drop for ByteReader { if matches!(self.terminal, TerminalState::Delivered) { return; } - let _ = self.tx.try_send(RuntimeCommand::CloseStream { + self.handle.send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, code: StreamCloseCode(0), diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index a68d041d..0156a260 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -1,13 +1,13 @@ use bytes::Bytes; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotTx, command::RuntimeCommand, QlError}; +use crate::{chunk_slot::ChunkSlotTx, command::RuntimeCommand, QlError, RuntimeHandle}; pub struct ByteWriter { stream_id: StreamId, target: CloseTarget, writer: Option, - tx: async_channel::Sender, + handle: RuntimeHandle, } impl std::fmt::Debug for ByteWriter { @@ -25,22 +25,20 @@ impl ByteWriter { stream_id: StreamId, target: CloseTarget, writer: ChunkSlotTx, - tx: async_channel::Sender, + handle: RuntimeHandle, ) -> Self { Self { stream_id, target, writer: Some(writer), - tx, + handle, } } - fn poll_runtime(&self) -> Result<(), QlError> { - self.tx - .try_send(RuntimeCommand::PollStream { - stream_id: self.stream_id, - }) - .map_err(|_| QlError::Cancelled) + fn poll_runtime(&self) { + self.handle.send(RuntimeCommand::PollStream { + stream_id: self.stream_id, + }); } pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlError> { @@ -48,57 +46,42 @@ impl ByteWriter { return Ok(()); } let writer = self.writer.as_ref().ok_or(QlError::Cancelled)?; - self.poll_runtime()?; if writer.send(bytes).await.is_err() { self.writer.take(); return Err(QlError::Cancelled); } - self.poll_runtime()?; - Ok(()) - } - - pub async fn write_all(&mut self, chunks: I) -> Result<(), QlError> - where - I: IntoIterator, - { - for chunk in chunks { - self.write(chunk).await?; - } + self.poll_runtime(); Ok(()) } - pub async fn finish(mut self) -> Result<(), QlError> { + pub fn finish(mut self) { let Some(writer) = self.writer.take() else { - return Ok(()); + return; }; writer.close(); - std::future::ready(self.poll_runtime()).await + self.poll_runtime(); } - pub async fn close(mut self, code: StreamCloseCode) -> Result<(), QlError> { - if self.writer.take().is_none() { - return Ok(()); - } - self.tx - .send(RuntimeCommand::CloseStream { - stream_id: self.stream_id, - target: self.target, - code, - }) - .await - .map_err(|_| QlError::Cancelled) + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); } } impl Drop for ByteWriter { fn drop(&mut self) { + self.close_inner(StreamCloseCode(0)); + } +} + +impl ByteWriter { + fn close_inner(&mut self, code: StreamCloseCode) { if self.writer.take().is_none() { return; } - let _ = self.tx.try_send(RuntimeCommand::CloseStream { + self.handle.send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, - code: StreamCloseCode(0), + code, }); } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index d392a687..35923db2 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -89,7 +89,7 @@ impl RpcHandle { async fn start_request(&self, payload: Vec) -> Result { let mut stream = self.inner.open_stream().await?; stream.writer.write(Bytes::from(payload)).await?; - stream.writer.finish().await?; + stream.writer.finish(); Ok(stream.reader) } } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index e0021e42..648c85bf 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -111,7 +111,7 @@ async fn rejected_session_write_is_reissued() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); let request = read_all(stream.reader).await.unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); request }); @@ -121,7 +121,7 @@ async fn rejected_session_write_is_reissued() { .write(Bytes::from_static(b"retry")) .await .unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); assert_eq!( diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index a0f01eb2..21a31cc4 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -44,13 +44,13 @@ async fn session_timeout_disconnects_and_fails_pending_open() { let responder_task = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); let _ = read_all(stream.reader).await; - let _ = stream.writer.finish().await; + stream.writer.finish(); }); drop_flag.store(true, Ordering::Relaxed); let mut pending = handle_a.open_stream().await.unwrap(); - pending.writer.finish().await.unwrap(); + pending.writer.finish(); await_status(&status_a, identity_b.xid, PeerStatus::Disconnected).await; diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 171aef2a..0912992f 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -90,7 +90,7 @@ async fn rpc_request_round_trips() { .unwrap(); let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish().await.unwrap(); + writer.finish(); }); let rpc = handle_a.rpc(); @@ -156,7 +156,7 @@ async fn rpc_subscription_streams_events() { let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish().await.unwrap(); + writer.finish(); }); let rpc = handle_a.rpc(); @@ -240,7 +240,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish().await.unwrap(); + writer.finish(); }); let rpc = handle_a.rpc(); diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index ee3366ed..ff61882b 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -40,7 +40,7 @@ async fn open_stream_duplex_happy_path() { assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![3, 4])); writer.write(Bytes::from_static(&[8, 7])).await.unwrap(); assert_eq!(next_chunk(&mut reader).await.unwrap(), None); - writer.finish().await.unwrap(); + writer.finish(); }); let mut stream = handle_a.open_stream().await.unwrap(); @@ -55,7 +55,7 @@ async fn open_stream_duplex_happy_path() { .write(Bytes::from_static(&[3, 4])) .await .unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); assert_eq!( next_chunk(&mut stream.reader).await.unwrap(), Some(vec![8, 7]) @@ -100,7 +100,7 @@ async fn large_stream_payload_round_trips() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); let request_data = read_all(stream.reader).await.unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); done_tx.send(request_data).await.unwrap(); }); @@ -110,7 +110,7 @@ async fn large_stream_payload_round_trips() { .write(Bytes::from(payload.clone())) .await .unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); let received = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) @@ -157,7 +157,7 @@ async fn dropping_responder_closes_initiator_response() { }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); let err = next_chunk(&mut stream.reader).await.unwrap_err(); assert!(matches!( @@ -212,11 +212,11 @@ async fn dropping_inbound_reader_cancels_remote_writer() { .unwrap(); go_rx.recv().await.unwrap(); let _ = writer.write(Bytes::from(vec![5; 64])).await; - let _ = writer.finish().await; + writer.finish(); }); let mut stream = handle_a.open_stream().await.unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); assert_eq!( next_chunk(&mut stream.reader).await.unwrap(), Some(vec![1, 2, 3, 4]) @@ -265,7 +265,7 @@ async fn max_concurrent_message_writes_is_respected() { for _ in 0..4 { let stream = inbound_b.recv().await.unwrap(); let _ = read_all(stream.reader).await; - let _ = stream.writer.finish().await; + stream.writer.finish(); } }); @@ -275,7 +275,7 @@ async fn max_concurrent_message_writes_is_respected() { tasks.push(tokio::task::spawn_local(async move { let mut stream = handle.open_stream().await.unwrap(); stream.writer.write(Bytes::from(vec![i; 8])).await.unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); })); } @@ -343,7 +343,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { .write(Bytes::from(response_payload.clone())) .await .unwrap(); - writer.finish().await.unwrap(); + writer.finish(); received_request }); @@ -353,7 +353,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { .write(Bytes::from(request_payload.clone())) .await .unwrap(); - stream.writer.finish().await.unwrap(); + stream.writer.finish(); let mut received_response = Vec::new(); while let Some(chunk) = next_chunk(&mut stream.reader).await.unwrap() { From f5c2787fe8d7fba2ff6d4c84d8e2412fb651ff0c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 08:54:02 -0400 Subject: [PATCH 138/304] ql-runtime: surface read chunk with len --- ql-runtime/src/handle/reader.rs | 19 ++++++++-- ql-runtime/src/tests/mod.rs | 11 +++++- ql-runtime/src/tests/stream.rs | 67 +++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 5 deletions(-) diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 6e62ff07..d07596a3 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -56,8 +56,9 @@ impl ByteReader { } } - pub fn poll_read_chunk( + pub fn poll_read( &mut self, + max_len: usize, cx: &mut Context<'_>, ) -> Poll, QlError>> { if matches!(self.terminal, TerminalState::Delivered) { @@ -65,7 +66,7 @@ impl ByteReader { } if let Some(reader) = self.reader.as_ref() { - match reader.poll_recv(usize::MAX, &mut self.listener, cx) { + match reader.poll_recv(max_len, &mut self.listener, cx) { Poll::Ready(Ok(bytes)) => { self.handle.send(RuntimeCommand::PollInbound { stream_id: self.stream_id, @@ -106,8 +107,20 @@ impl ByteReader { } } + pub fn poll_read_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, QlError>> { + self.poll_read(usize::MAX, cx) + } + + /// Returns `Ok(None)` on clean EOF, `Ok(Some(_))` for data, and `Err(_)` for stream failure. + pub async fn read(&mut self, max_len: usize) -> Result, QlError> { + poll_fn(|cx| self.poll_read(max_len, cx)).await + } + pub async fn read_chunk(&mut self) -> Result, QlError> { - poll_fn(|cx| self.poll_read_chunk(cx)).await + self.read(usize::MAX).await } pub fn close(mut self, code: StreamCloseCode) { diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 06bcd55e..daa0dbf8 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -491,13 +491,20 @@ async fn read_all(mut stream: crate::ByteReader) -> Result, QlError> { Ok(data) } -async fn next_chunk(stream: &mut crate::ByteReader) -> Result>, QlError> { +async fn next_chunk_max( + stream: &mut crate::ByteReader, + max_len: usize, +) -> Result>, QlError> { stream - .read_chunk() + .read(max_len) .await .map(|chunk| chunk.map(|bytes| bytes.to_vec())) } +async fn next_chunk(stream: &mut crate::ByteReader) -> Result>, QlError> { + next_chunk_max(stream, usize::MAX).await +} + fn default_runtime_config() -> RuntimeConfig { RuntimeConfig { fsm: QlFsmConfig { diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index ff61882b..c1d00748 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -70,6 +70,73 @@ async fn open_stream_duplex_happy_path() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn reader_exposes_bounded_chunk_reads() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); + let identity_a = new_identity(11); + let identity_b = new_identity(73); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let mut reader = inbound.reader; + + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![1, 2]) + ); + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![3, 4]) + ); + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![5, 6]) + ); + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + + inbound.writer.finish(); + }); + + let mut stream = handle_a.open_stream().await.unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2, 3, 4])) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[5, 6])) + .await + .unwrap(); + stream.writer.finish(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn large_stream_payload_round_trips() { run_local_test(async { From 616bc80a31a47059c620303b8369b12fdbf49088 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 09:23:24 -0400 Subject: [PATCH 139/304] ql-runtime: streamerror --- ql-fsm/src/implementation/core.rs | 4 +- ql-fsm/src/lib.rs | 2 +- ql-fsm/src/session/mod.rs | 4 +- ql-fsm/src/session/tests.rs | 2 +- ql-fsm/src/tests/proptest.rs | 3 +- ql-runtime/src/command.rs | 3 +- ql-runtime/src/driver/mod.rs | 44 +++++++++++------- ql-runtime/src/driver/state.rs | 50 ++++++++++++++------- ql-runtime/src/driver/test.rs | 23 ++++++++-- ql-runtime/src/error.rs | 34 +++++++++++--- ql-runtime/src/handle/mod.rs | 10 ++++- ql-runtime/src/handle/reader.rs | 20 +++++---- ql-runtime/src/handle/writer.rs | 31 +++++++++++-- ql-runtime/src/lib.rs | 6 ++- ql-runtime/src/rpc/mod.rs | 5 ++- ql-runtime/src/rpc/request_with_progress.rs | 4 +- ql-runtime/src/rpc/subscription.rs | 2 +- ql-runtime/src/tests/heartbeat.rs | 6 +-- ql-runtime/src/tests/mod.rs | 10 ++--- ql-runtime/src/tests/stream.rs | 8 ++-- 20 files changed, 188 insertions(+), 83 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 672ff222..6ebb8709 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -237,8 +237,8 @@ fn forward_session_event(event: SessionEvent, emit: &mut impl FnMut(QlFsmEvent)) emit(QlFsmEvent::Closed(frame)); false } - SessionEvent::WritableClosed(stream_id) => { - emit(QlFsmEvent::WritableClosed(stream_id)); + SessionEvent::WritableClosed(frame) => { + emit(QlFsmEvent::WritableClosed(frame)); false } SessionEvent::SessionClosed(close) => { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index c03ead39..cbe60da2 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -78,7 +78,7 @@ pub enum QlFsmEvent { /// a stream was closed Closed(StreamClose), /// local writes on this stream are closed - WritableClosed(StreamId), + WritableClosed(StreamClose), /// the encrypted session was closed SessionClosed(SessionClose), } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index c8e983a5..f814e8df 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -66,7 +66,7 @@ pub enum SessionEvent { Writable(StreamId), Finished(StreamId), Closed(StreamClose), - WritableClosed(StreamId), + WritableClosed(StreamClose), SessionClosed(SessionClose), } @@ -832,7 +832,7 @@ impl SessionFsm { stream.outbound_state = OutboundState::Closed; stream.tx.clear(); stream.pending_close = None; - emit(SessionEvent::WritableClosed(frame.stream_id)); + emit(SessionEvent::WritableClosed(frame.clone())); } self.try_reap_stream(frame.stream_id); Ok(()) diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index dcf517fc..ac45f737 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -362,7 +362,7 @@ fn duplicate_remote_close_after_reap_is_ignored() { vec![ SessionEvent::Opened(close.stream_id), SessionEvent::Closed(close.clone()), - SessionEvent::WritableClosed(close.stream_id), + SessionEvent::WritableClosed(close.clone()), ] ); diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 137c6e9a..a63ad501 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -497,7 +497,8 @@ impl Runner { frame.stream_id ); } - QlFsmEvent::WritableClosed(stream_id) => { + QlFsmEvent::WritableClosed(frame) => { + let stream_id = frame.stream_id; prop_assert!( self.known_streams.contains(&stream_id), "side {side:?} emitted WritableClosed for unknown stream {stream_id:?}" diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 46c31122..530b52fa 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,6 +1,6 @@ use ql_wire::{CloseTarget, PeerBundle, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlError}; +use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlError, QlStreamError}; pub(crate) enum RuntimeCommand { BindPeer { @@ -9,6 +9,7 @@ pub(crate) enum RuntimeCommand { Connect, OpenStream { request_reader: ChunkSlotRx, + request_terminal: oneshot::Sender, start: oneshot::Sender>, }, PollInbound { diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 794a3a42..5ce9af7e 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -19,7 +19,7 @@ use crate::{ command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, platform::{PlatformFuture, QlPlatform}, - QlError, Runtime, RuntimeHandle, + QlError, QlStreamError, Runtime, RuntimeHandle, }; impl Runtime

{ @@ -164,6 +164,7 @@ impl DriverState { } RuntimeCommand::OpenStream { request_reader, + request_terminal, start, } => { let Some(runtime_tx) = self.runtime_tx.upgrade() else { @@ -179,6 +180,7 @@ impl DriverState { stream_id, DriverStreamIo::new_initiator( request_reader, + request_terminal, response_writer, response_terminal_tx, ), @@ -319,10 +321,10 @@ impl DriverState { QlFsmEvent::Closed(frame) => { self.handle_closed_stream(&frame); } - QlFsmEvent::WritableClosed(stream_id) => { - self.handle_writable_closed(stream_id); + QlFsmEvent::WritableClosed(frame) => { + self.handle_writable_closed(&frame); } - QlFsmEvent::SessionClosed(_) => self.fail_all_streams(&QlError::SessionClosed), + QlFsmEvent::SessionClosed(_) => self.fail_all_streams(), } } @@ -340,10 +342,16 @@ impl DriverState { let (request_reader, request_writer) = chunk_slot::new(); let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); let (response_reader, response_writer) = chunk_slot::new(); + let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); self.streams.insert( stream_id, - DriverStreamIo::new_responder(request_writer, request_terminal_tx, response_reader), + DriverStreamIo::new_responder( + request_writer, + request_terminal_tx, + response_reader, + response_terminal_tx, + ), ); platform.handle_inbound(QlStream { @@ -359,6 +367,7 @@ impl DriverState { stream_id, CloseTarget::Return, response_writer, + response_terminal_rx, RuntimeHandle::new(runtime_tx), ), }); @@ -447,28 +456,31 @@ impl DriverState { }; if frame.target == CloseTarget::Both || frame.target == stream.inbound_target() { - stream.inbound_mut().fail(QlError::StreamClosed { - target: frame.target, - code: frame.code, - }); + stream + .inbound_mut() + .fail(QlStreamError::StreamClosed { code: frame.code }); } if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { - stream.outbound_mut().close(); + stream + .outbound_mut() + .fail(QlStreamError::StreamClosed { code: frame.code }); } self.try_reap_stream(frame.stream_id); } - fn handle_writable_closed(&mut self, stream_id: StreamId) { - let Some(stream) = self.streams.get_mut(&stream_id) else { + fn handle_writable_closed(&mut self, frame: &ql_wire::StreamClose) { + let Some(stream) = self.streams.get_mut(&frame.stream_id) else { return; }; - stream.outbound_mut().close(); - self.try_reap_stream(stream_id); + stream + .outbound_mut() + .fail(QlStreamError::StreamClosed { code: frame.code }); + self.try_reap_stream(frame.stream_id); } - fn fail_all_streams(&mut self, error: &QlError) { + fn fail_all_streams(&mut self) { for stream in self.streams.values_mut() { - stream.fail_all(error); + stream.fail_all(); } self.streams.clear(); } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index a54e152c..be1369c8 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -7,7 +7,7 @@ use ql_wire::{CloseTarget, StreamId, XID}; use crate::{ chunk_slot::{ChunkSlotRx, ChunkSlotTx, TrySendError}, command::RuntimeCommand, - QlError, + QlStreamError, }; pub struct DriverState { @@ -36,24 +36,26 @@ impl DriverStreamIo { pub fn new_initiator( request: ChunkSlotRx, + request_terminal: oneshot::Sender, response: ChunkSlotTx, - response_terminal: oneshot::Sender>, + response_terminal: oneshot::Sender>, ) -> Self { Self { is_initiator: true, - outbound: OutboundIo::new(request), + outbound: OutboundIo::new(request, request_terminal), inbound: InboundIo::new(response, response_terminal), } } pub fn new_responder( request: ChunkSlotTx, - request_terminal: oneshot::Sender>, + request_terminal: oneshot::Sender>, response: ChunkSlotRx, + response_terminal: oneshot::Sender, ) -> Self { Self { is_initiator: false, - outbound: OutboundIo::new(response), + outbound: OutboundIo::new(response, response_terminal), inbound: InboundIo::new(request, request_terminal), } } @@ -82,13 +84,13 @@ impl DriverStreamIo { } } - pub fn fail_all(&mut self, error: &QlError) { + pub fn fail_all(&mut self) { if self.is_initiator { - self.outbound.close(); - self.inbound.fail(error.clone()); + self.outbound.fail(QlStreamError::SessionClosed); + self.inbound.fail(QlStreamError::SessionClosed); } else { - self.inbound.fail(error.clone()); - self.outbound.close(); + self.inbound.fail(QlStreamError::SessionClosed); + self.outbound.fail(QlStreamError::SessionClosed); } } @@ -98,22 +100,36 @@ impl DriverStreamIo { } pub enum OutboundIo { - Open { reader: ChunkSlotRx }, + Open { + reader: ChunkSlotRx, + terminal: Option>, + }, Closed, } impl OutboundIo { - pub fn new(reader: ChunkSlotRx) -> Self { - Self::Open { reader } + pub fn new(reader: ChunkSlotRx, terminal: oneshot::Sender) -> Self { + Self::Open { + reader, + terminal: Some(terminal), + } } pub fn close(&mut self) { *self = Self::Closed; } + pub fn fail(&mut self, error: QlStreamError) { + if let Self::Open { mut terminal, .. } = std::mem::replace(self, Self::Closed) { + if let Some(terminal) = terminal.take() { + let _ = terminal.send(error); + } + } + } + pub fn open_mut(&mut self) -> Option<&mut ChunkSlotRx> { match self { - Self::Open { reader } => Some(reader), + Self::Open { reader, .. } => Some(reader), Self::Closed => None, } } @@ -122,7 +138,7 @@ impl OutboundIo { pub enum InboundIo { Open { writer: ChunkSlotTx, - terminal: Option>>, + terminal: Option>>, finish_pending: bool, }, Closed, @@ -135,7 +151,7 @@ pub enum InboundWriteResult { } impl InboundIo { - pub fn new(writer: ChunkSlotTx, terminal: oneshot::Sender>) -> Self { + pub fn new(writer: ChunkSlotTx, terminal: oneshot::Sender>) -> Self { Self::Open { writer, terminal: Some(terminal), @@ -177,7 +193,7 @@ impl InboundIo { } } - pub fn fail(&mut self, error: QlError) { + pub fn fail(&mut self, error: QlStreamError) { if let Self::Open { mut terminal, writer, diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 326b702d..b96d6f46 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -112,6 +112,12 @@ fn new_inbound_io(capacity: usize) -> InboundIo { InboundIo::new(writer, terminal_tx) } +fn new_outbound_io() -> OutboundIo { + let (reader, _writer) = chunk_slot::new(); + let (terminal_tx, _terminal_rx) = oneshot::channel(); + OutboundIo::new(reader, terminal_tx) +} + #[test] fn handle_inbound_finished_reaps_closed_initiator_stream() { let (mut state, fsm) = new_driver_state(); @@ -131,11 +137,10 @@ fn handle_inbound_finished_reaps_closed_initiator_stream() { fn handle_closed_stream_reaps_when_both_halves_close() { let (mut state, _fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); - let (response_reader, _response_writer) = chunk_slot::new(); state.streams.insert( stream_id, - DriverStreamIo::new(false, OutboundIo::new(response_reader), new_inbound_io(1)), + DriverStreamIo::new(false, new_outbound_io(), new_inbound_io(1)), ); state.handle_closed_stream(&StreamClose { @@ -152,11 +157,16 @@ fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); let (request_reader, request_writer) = chunk_slot::new(); + let (request_terminal_tx, _request_terminal_rx) = oneshot::channel(); drop(request_writer); state.streams.insert( stream_id, - DriverStreamIo::new(true, OutboundIo::new(request_reader), InboundIo::Closed), + DriverStreamIo::new( + true, + OutboundIo::new(request_reader, request_terminal_tx), + InboundIo::Closed, + ), ); state.poll_stream(&mut fsm, stream_id); @@ -169,11 +179,16 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); let (request_reader, _request_writer) = chunk_slot::new(); + let (request_terminal_tx, _request_terminal_rx) = oneshot::channel(); let mut in_flight = Vec::new(); state.streams.insert( stream_id, - DriverStreamIo::new(true, OutboundIo::new(request_reader), InboundIo::Closed), + DriverStreamIo::new( + true, + OutboundIo::new(request_reader, request_terminal_tx), + InboundIo::Closed, + ), ); state.drive_command( diff --git a/ql-runtime/src/error.rs b/ql-runtime/src/error.rs index b16c7e2e..f04f7732 100644 --- a/ql-runtime/src/error.rs +++ b/ql-runtime/src/error.rs @@ -1,4 +1,5 @@ use ql_fsm::QlFsmError; +use ql_wire::StreamCloseCode; #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlError { @@ -14,10 +15,7 @@ pub enum QlError { NoPeerBound, NoSession, SendFailed, - StreamClosed { - target: ql_wire::CloseTarget, - code: ql_wire::StreamCloseCode, - }, + StreamClosed { code: ql_wire::StreamCloseCode }, Cancelled, } @@ -36,7 +34,7 @@ impl std::fmt::Display for QlError { Self::NoPeerBound => f.write_str("no peer bound"), Self::NoSession => f.write_str("no active session"), Self::SendFailed => f.write_str("send failed"), - Self::StreamClosed { code, .. } => write!(f, "stream closed {code:?}"), + Self::StreamClosed { code } => write!(f, "stream closed {code:?}"), Self::Cancelled => f.write_str("cancelled"), } } @@ -44,6 +42,15 @@ impl std::fmt::Display for QlError { impl std::error::Error for QlError {} +impl From for QlError { + fn from(value: QlStreamError) -> Self { + match value { + QlStreamError::StreamClosed { code } => Self::StreamClosed { code }, + QlStreamError::SessionClosed => Self::SessionClosed, + } + } +} + impl From for QlError { fn from(value: QlFsmError) -> Self { match value { @@ -61,3 +68,20 @@ impl From for QlError { } } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlStreamError { + StreamClosed { code: StreamCloseCode }, + SessionClosed, +} + +impl std::fmt::Display for QlStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::StreamClosed { code } => write!(f, "stream closed {code:?}"), + Self::SessionClosed => f.write_str("session is closed"), + } + } +} + +impl std::error::Error for QlStreamError {} diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index ea944197..a4f581cb 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -33,10 +33,12 @@ impl RuntimeHandle { pub async fn open_stream(&self) -> Result { let (request_reader, request_writer) = chunk_slot::new(); + let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); let (start_tx, start_rx) = oneshot::channel(); self.send(RuntimeCommand::OpenStream { request_reader, + request_terminal: request_terminal_tx, start: start_tx, }); @@ -45,7 +47,13 @@ impl RuntimeHandle { Ok(QlStream { stream_id, - writer: ByteWriter::new(stream_id, CloseTarget::Origin, request_writer, self.clone()), + writer: ByteWriter::new( + stream_id, + CloseTarget::Origin, + request_writer, + request_terminal_rx, + self.clone(), + ), reader, }) } diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index d07596a3..cec47ae5 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -8,7 +8,7 @@ use bytes::Bytes; use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, QlError, RuntimeHandle}; +use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, QlStreamError, RuntimeHandle}; pub struct ByteReader { stream_id: StreamId, @@ -20,8 +20,8 @@ pub struct ByteReader { } enum TerminalState { - Armed(oneshot::Receiver>), - Terminal(Result<(), QlError>), + Armed(oneshot::Receiver>), + Terminal(Result<(), QlStreamError>), Delivered, } @@ -43,7 +43,7 @@ impl ByteReader { stream_id: StreamId, target: CloseTarget, reader: ChunkSlotRx, - terminal: oneshot::Receiver>, + terminal: oneshot::Receiver>, handle: RuntimeHandle, ) -> Self { Self { @@ -60,7 +60,7 @@ impl ByteReader { &mut self, max_len: usize, cx: &mut Context<'_>, - ) -> Poll, QlError>> { + ) -> Poll, QlStreamError>> { if matches!(self.terminal, TerminalState::Delivered) { return Poll::Ready(Ok(None)); } @@ -85,7 +85,9 @@ impl ByteReader { let result = match Pin::new(terminal).poll(cx) { Poll::Pending => None, Poll::Ready(Ok(result)) => Some(result), - Poll::Ready(Err(_)) => Some(Err(QlError::Cancelled)), + Poll::Ready(Err(_)) => { + panic!("byte reader terminal dropped before sending a terminal state") + } }; if let Some(result) = result { self.terminal = TerminalState::Terminal(result); @@ -110,16 +112,16 @@ impl ByteReader { pub fn poll_read_chunk( &mut self, cx: &mut Context<'_>, - ) -> Poll, QlError>> { + ) -> Poll, QlStreamError>> { self.poll_read(usize::MAX, cx) } /// Returns `Ok(None)` on clean EOF, `Ok(Some(_))` for data, and `Err(_)` for stream failure. - pub async fn read(&mut self, max_len: usize) -> Result, QlError> { + pub async fn read(&mut self, max_len: usize) -> Result, QlStreamError> { poll_fn(|cx| self.poll_read(max_len, cx)).await } - pub async fn read_chunk(&mut self) -> Result, QlError> { + pub async fn read_chunk(&mut self) -> Result, QlStreamError> { self.read(usize::MAX).await } diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 0156a260..c0b8b333 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -1,15 +1,21 @@ use bytes::Bytes; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotTx, command::RuntimeCommand, QlError, RuntimeHandle}; +use crate::{chunk_slot::ChunkSlotTx, command::RuntimeCommand, QlStreamError, RuntimeHandle}; pub struct ByteWriter { stream_id: StreamId, target: CloseTarget, writer: Option, + terminal: WriteTerminalState, handle: RuntimeHandle, } +enum WriteTerminalState { + Armed(oneshot::Receiver), + Terminal(QlStreamError), +} + impl std::fmt::Debug for ByteWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OutboundByteStream") @@ -25,12 +31,14 @@ impl ByteWriter { stream_id: StreamId, target: CloseTarget, writer: ChunkSlotTx, + terminal: oneshot::Receiver, handle: RuntimeHandle, ) -> Self { Self { stream_id, target, writer: Some(writer), + terminal: WriteTerminalState::Armed(terminal), handle, } } @@ -41,14 +49,16 @@ impl ByteWriter { }); } - pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlError> { + pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { if bytes.is_empty() { return Ok(()); } - let writer = self.writer.as_ref().ok_or(QlError::Cancelled)?; + let Some(writer) = self.writer.as_ref() else { + return Err(self.terminal_error().await); + }; if writer.send(bytes).await.is_err() { self.writer.take(); - return Err(QlError::Cancelled); + return Err(self.terminal_error().await); } self.poll_runtime(); Ok(()) @@ -74,6 +84,19 @@ impl Drop for ByteWriter { } impl ByteWriter { + async fn terminal_error(&mut self) -> QlStreamError { + match &mut self.terminal { + WriteTerminalState::Terminal(error) => error.clone(), + WriteTerminalState::Armed(receiver) => { + let error = receiver + .await + .expect("byte writer terminal dropped before sending a terminal state"); + self.terminal = WriteTerminalState::Terminal(error.clone()); + error + } + } + } + fn close_inner(&mut self, code: StreamCloseCode) { if self.writer.take().is_none() { return; diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 0a942684..bbd841da 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,4 +1,8 @@ -pub use self::{error::QlError, handle::*, platform::*}; +pub use self::{ + error::{QlError, QlStreamError}, + handle::*, + platform::*, +}; pub mod chunk_slot; pub(crate) mod command; diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 35923db2..209a0492 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -96,7 +96,10 @@ impl RpcHandle { async fn read_all(mut reader: ByteReader) -> Result, QlError> { let mut bytes = Vec::new(); - while let Some(chunk) = poll_fn(|cx| reader.poll_read_chunk(cx)).await? { + while let Some(chunk) = poll_fn(|cx| reader.poll_read_chunk(cx)) + .await + .map_err(QlError::from)? + { bytes.extend_from_slice(&chunk); } Ok(bytes) diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs index 239e7f72..ca768de9 100644 --- a/ql-runtime/src/rpc/request_with_progress.rs +++ b/ql-runtime/src/rpc/request_with_progress.rs @@ -75,7 +75,7 @@ where } Poll::Ready(Err(error)) => { this.reader = None; - this.terminal = Some(Err(RpcCallError::Runtime(error))); + this.terminal = Some(Err(RpcCallError::Runtime(error.into()))); return Poll::Ready(None); } Poll::Pending => return Poll::Pending, @@ -126,7 +126,7 @@ where } Poll::Ready(Err(error)) => { this.reader = None; - return Poll::Ready(Err(RpcCallError::Runtime(error))); + return Poll::Ready(Err(RpcCallError::Runtime(error.into()))); } Poll::Pending => return Poll::Pending, } diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 5c0fde1b..13f1596b 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -65,7 +65,7 @@ where } Poll::Ready(Err(error)) => { this.reader = None; - return Poll::Ready(Some(Err(RpcCallError::Runtime(error)))); + return Poll::Ready(Some(Err(RpcCallError::Runtime(error.into())))); } Poll::Pending => return Poll::Pending, } diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 21a31cc4..71de3b59 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -7,6 +7,7 @@ use std::{ }; use super::*; +use crate::QlStreamError; #[tokio::test(flavor = "current_thread")] async fn session_timeout_disconnects_and_fails_pending_open() { @@ -58,10 +59,7 @@ async fn session_timeout_disconnects_and_fails_pending_open() { tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) .await .unwrap(); - assert!(matches!( - result, - Err(QlError::SessionClosed | QlError::Cancelled) - )); + assert!(matches!(result, Err(QlStreamError::SessionClosed))); responder_task.abort(); }) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index daa0dbf8..32e2b7b5 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -20,8 +20,8 @@ use sha2::{Digest, Sha256}; use tokio::task::LocalSet; use crate::{ - new_runtime, platform::PlatformFuture, QlError, QlFsmConfig, QlStream, RuntimeConfig, - RuntimeHandle, + new_runtime, platform::PlatformFuture, QlError, QlFsmConfig, QlStream, QlStreamError, + RuntimeConfig, RuntimeHandle, }; mod handshake; @@ -483,7 +483,7 @@ async fn assert_no_status_for( assert!(res.is_err(), "unexpected status event: {status:?}"); } -async fn read_all(mut stream: crate::ByteReader) -> Result, QlError> { +async fn read_all(mut stream: crate::ByteReader) -> Result, QlStreamError> { let mut data = Vec::new(); while let Some(chunk) = next_chunk(&mut stream).await? { data.extend_from_slice(&chunk); @@ -494,14 +494,14 @@ async fn read_all(mut stream: crate::ByteReader) -> Result, QlError> { async fn next_chunk_max( stream: &mut crate::ByteReader, max_len: usize, -) -> Result>, QlError> { +) -> Result>, crate::QlStreamError> { stream .read(max_len) .await .map(|chunk| chunk.map(|bytes| bytes.to_vec())) } -async fn next_chunk(stream: &mut crate::ByteReader) -> Result>, QlError> { +async fn next_chunk(stream: &mut crate::ByteReader) -> Result>, QlStreamError> { next_chunk_max(stream, usize::MAX).await } diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index c1d00748..ffb5fe53 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -1,9 +1,10 @@ use std::time::Duration; use bytes::Bytes; -use ql_wire::{CloseTarget, StreamCloseCode}; +use ql_wire::StreamCloseCode; use super::*; +use crate::QlStreamError; #[tokio::test(flavor = "current_thread")] async fn open_stream_duplex_happy_path() { @@ -229,10 +230,7 @@ async fn dropping_responder_closes_initiator_response() { let err = next_chunk(&mut stream.reader).await.unwrap_err(); assert!(matches!( err, - QlError::StreamClosed { - target: CloseTarget::Return, - code, - } if code == StreamCloseCode(0) + QlStreamError::StreamClosed { code } if code == StreamCloseCode(0) )); tokio::time::timeout(Duration::from_secs(2), responder) From c90417b71eb368f2342df73cdbbc95c480039f3d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 09:34:09 -0400 Subject: [PATCH 140/304] ql-runtime: get rid of repeated slot writes --- ql-runtime/src/driver/mod.rs | 34 ++++------------------------------ 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 5ce9af7e..40a48c67 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -50,7 +50,7 @@ impl Runtime

{ let mut in_flight = Vec::new(); loop { - state.finish_step(&mut fsm, &platform, &mut in_flight); + while state.fill_write_slots(&mut fsm, &platform, &mut in_flight) {} if rx.is_closed() && in_flight.is_empty() { break; @@ -62,19 +62,12 @@ impl Runtime

{ } DriverEvent::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); - state.drive_write_completed( - &mut fsm, - write.session_write_id, - success, - &platform, - &mut in_flight, - ); + state.drive_write_completed(&mut fsm, write.session_write_id, success); } DriverEvent::TimerExpired => { state.with_fsm_events(&mut fsm, &platform, |fsm, emit| { fsm.on_timer(now(), emit); }); - state.finish_step(&mut fsm, &platform, &mut in_flight); } DriverEvent::CommandsClosed => {} } @@ -142,25 +135,22 @@ impl DriverState { fsm: &mut QlFsm, command: RuntimeCommand, platform: &'a P, - in_flight: &mut Vec>, + _in_flight: &mut Vec>, ) { match command { RuntimeCommand::BindPeer { peer } => { self.peer_xid = Some(peer.xid); fsm.bind_peer(peer); - self.finish_step(fsm, platform, in_flight); } RuntimeCommand::Connect => { let _ = self.with_fsm_events(fsm, platform, |fsm, emit| { fsm.connect_ik(now(), platform, emit) }); - self.finish_step(fsm, platform, in_flight); } RuntimeCommand::Incoming(bytes) => { let _ = self.with_fsm_events(fsm, platform, |fsm, emit| { fsm.receive(now(), bytes, platform, emit) }); - self.finish_step(fsm, platform, in_flight); } RuntimeCommand::OpenStream { request_reader, @@ -202,7 +192,6 @@ impl DriverState { return; } self.poll_stream(fsm, stream_id); - self.finish_step(fsm, platform, in_flight); } Err(error) => { let _ = start.send(Err(error)); @@ -211,11 +200,9 @@ impl DriverState { } RuntimeCommand::PollInbound { stream_id } => { self.handle_inbound_readable(fsm, stream_id); - self.finish_step(fsm, platform, in_flight); } RuntimeCommand::PollStream { stream_id } => { self.poll_stream(fsm, stream_id); - self.finish_step(fsm, platform, in_flight); } RuntimeCommand::CloseStream { stream_id, @@ -232,18 +219,15 @@ impl DriverState { } let _ = fsm.close_stream(stream_id, target, code); self.try_reap_stream(stream_id); - self.finish_step(fsm, platform, in_flight); } } } - fn drive_write_completed<'a, P: QlPlatform>( + fn drive_write_completed( &self, fsm: &mut QlFsm, session_write_id: Option, success: bool, - platform: &'a P, - in_flight: &mut Vec>, ) { if let Some(write_id) = session_write_id { if success { @@ -252,16 +236,6 @@ impl DriverState { fsm.reject_session_write(write_id); } } - self.finish_step(fsm, platform, in_flight); - } - - fn finish_step<'a, P: QlPlatform>( - &self, - fsm: &mut QlFsm, - platform: &'a P, - in_flight: &mut Vec>, - ) { - while self.fill_write_slots(fsm, platform, in_flight) {} } fn with_fsm_events( From 1f6210e98abc671b496aa18235dffe93657dbd3a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 09:48:51 -0400 Subject: [PATCH 141/304] ql-runtime: clean up DriverStreamIo --- ql-runtime/src/driver/mod.rs | 52 ++++----- ql-runtime/src/driver/state.rs | 202 ++++++++++++--------------------- ql-runtime/src/driver/test.rs | 12 +- 3 files changed, 101 insertions(+), 165 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 40a48c67..0b31c99c 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -13,7 +13,7 @@ use futures_lite::future::poll_fn; use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, SessionWriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use self::state::{DriverState, DriverStreamIo, InboundWriteResult}; +use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; use crate::{ chunk_slot, command::RuntimeCommand, @@ -168,11 +168,10 @@ impl DriverState { let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); self.streams.insert( stream_id, - DriverStreamIo::new_initiator( - request_reader, - request_terminal, - response_writer, - response_terminal_tx, + DriverStreamIo::new( + true, + Some(OutboundIo::new(request_reader, request_terminal)), + Some(InboundIo::new(response_writer, response_terminal_tx)), ), ); let reader = ByteReader::new( @@ -184,8 +183,8 @@ impl DriverState { ); if start.send(Ok((stream_id, reader))).is_err() { if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.inbound_mut().close(); - stream.outbound_mut().close(); + stream.inbound_close(); + stream.outbound_close(); } let _ = fsm.close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)); @@ -211,10 +210,10 @@ impl DriverState { } => { if let Some(stream) = self.streams.get_mut(&stream_id) { if target == CloseTarget::Both || target == stream.inbound_target() { - stream.inbound_mut().close(); + stream.inbound_close(); } if target == CloseTarget::Both || target == stream.outbound_target() { - stream.outbound_mut().close(); + stream.outbound_close(); } } let _ = fsm.close_stream(stream_id, target, code); @@ -320,11 +319,10 @@ impl DriverState { self.streams.insert( stream_id, - DriverStreamIo::new_responder( - request_writer, - request_terminal_tx, - response_reader, - response_terminal_tx, + DriverStreamIo::new( + false, + Some(OutboundIo::new(response_reader, response_terminal_tx)), + Some(InboundIo::new(request_writer, request_terminal_tx)), ), ); @@ -368,7 +366,7 @@ impl DriverState { if chunk.is_empty() { continue; } - match stream.inbound_mut().try_write(chunk) { + match stream.inbound_try_write(chunk) { InboundWriteResult::Accepted(n) => { accepted += n; } @@ -404,7 +402,7 @@ impl DriverState { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - stream.inbound_mut().queue_finish(); + stream.inbound_queue_finish(); self.finish_inbound_if_ready(fsm, stream_id); } @@ -416,11 +414,11 @@ impl DriverState { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - if !stream.inbound_mut().finish_pending() { + if !stream.inbound_finish_pending() { return; } - stream.inbound_mut().finish(); + stream.inbound_finish(); self.try_reap_stream(stream_id); } @@ -430,14 +428,10 @@ impl DriverState { }; if frame.target == CloseTarget::Both || frame.target == stream.inbound_target() { - stream - .inbound_mut() - .fail(QlStreamError::StreamClosed { code: frame.code }); + stream.inbound_fail(QlStreamError::StreamClosed { code: frame.code }); } if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { - stream - .outbound_mut() - .fail(QlStreamError::StreamClosed { code: frame.code }); + stream.outbound_fail(QlStreamError::StreamClosed { code: frame.code }); } self.try_reap_stream(frame.stream_id); } @@ -446,9 +440,7 @@ impl DriverState { let Some(stream) = self.streams.get_mut(&frame.stream_id) else { return; }; - stream - .outbound_mut() - .fail(QlStreamError::StreamClosed { code: frame.code }); + stream.outbound_fail(QlStreamError::StreamClosed { code: frame.code }); self.try_reap_stream(frame.stream_id); } @@ -486,7 +478,7 @@ impl DriverState { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - let Some(reader) = stream.outbound_mut().open_mut() else { + let Some(reader) = stream.outbound_reader_mut() else { return; }; @@ -508,7 +500,7 @@ impl DriverState { if should_finish { let _ = fsm.finish_stream(stream_id); if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.outbound_mut().close(); + stream.outbound_close(); } self.try_reap_stream(stream_id); } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index be1369c8..4f89c210 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -20,13 +20,16 @@ pub struct DriverState { pub struct DriverStreamIo { is_initiator: bool, - outbound: OutboundIo, - inbound: InboundIo, + outbound: Option, + inbound: Option, } impl DriverStreamIo { - #[cfg(test)] - pub fn new(is_initiator: bool, outbound: OutboundIo, inbound: InboundIo) -> Self { + pub fn new( + is_initiator: bool, + outbound: Option, + inbound: Option, + ) -> Self { Self { is_initiator, outbound, @@ -34,40 +37,6 @@ impl DriverStreamIo { } } - pub fn new_initiator( - request: ChunkSlotRx, - request_terminal: oneshot::Sender, - response: ChunkSlotTx, - response_terminal: oneshot::Sender>, - ) -> Self { - Self { - is_initiator: true, - outbound: OutboundIo::new(request, request_terminal), - inbound: InboundIo::new(response, response_terminal), - } - } - - pub fn new_responder( - request: ChunkSlotTx, - request_terminal: oneshot::Sender>, - response: ChunkSlotRx, - response_terminal: oneshot::Sender, - ) -> Self { - Self { - is_initiator: false, - outbound: OutboundIo::new(response, response_terminal), - inbound: InboundIo::new(request, request_terminal), - } - } - - pub fn outbound_mut(&mut self) -> &mut OutboundIo { - &mut self.outbound - } - - pub fn inbound_mut(&mut self) -> &mut InboundIo { - &mut self.inbound - } - pub fn inbound_target(&self) -> CloseTarget { if self.is_initiator { CloseTarget::Return @@ -85,138 +54,113 @@ impl DriverStreamIo { } pub fn fail_all(&mut self) { - if self.is_initiator { - self.outbound.fail(QlStreamError::SessionClosed); - self.inbound.fail(QlStreamError::SessionClosed); - } else { - self.inbound.fail(QlStreamError::SessionClosed); - self.outbound.fail(QlStreamError::SessionClosed); - } + self.inbound_fail(QlStreamError::SessionClosed); + self.outbound_fail(QlStreamError::SessionClosed); } pub fn is_closed(&self) -> bool { - matches!(self.outbound, OutboundIo::Closed) && matches!(self.inbound, InboundIo::Closed) + self.outbound.is_none() && self.inbound.is_none() } -} -pub enum OutboundIo { - Open { - reader: ChunkSlotRx, - terminal: Option>, - }, - Closed, -} - -impl OutboundIo { - pub fn new(reader: ChunkSlotRx, terminal: oneshot::Sender) -> Self { - Self::Open { - reader, - terminal: Some(terminal), - } + pub fn outbound_close(&mut self) { + self.outbound = None; } - pub fn close(&mut self) { - *self = Self::Closed; - } - - pub fn fail(&mut self, error: QlStreamError) { - if let Self::Open { mut terminal, .. } = std::mem::replace(self, Self::Closed) { - if let Some(terminal) = terminal.take() { + pub fn outbound_fail(&mut self, error: QlStreamError) { + if let Some(mut outbound) = self.outbound.take() { + if let Some(terminal) = outbound.terminal.take() { let _ = terminal.send(error); } } } - pub fn open_mut(&mut self) -> Option<&mut ChunkSlotRx> { - match self { - Self::Open { reader, .. } => Some(reader), - Self::Closed => None, - } - } -} - -pub enum InboundIo { - Open { - writer: ChunkSlotTx, - terminal: Option>>, - finish_pending: bool, - }, - Closed, -} - -pub enum InboundWriteResult { - Accepted(usize), - Full, - Closed, -} - -impl InboundIo { - pub fn new(writer: ChunkSlotTx, terminal: oneshot::Sender>) -> Self { - Self::Open { - writer, - terminal: Some(terminal), - finish_pending: false, - } + pub fn outbound_reader_mut(&mut self) -> Option<&mut ChunkSlotRx> { + self.outbound.as_mut().map(|outbound| &mut outbound.reader) } - pub fn close(&mut self) { - *self = Self::Closed; + pub fn inbound_close(&mut self) { + self.inbound = None; } - pub fn try_write(&mut self, bytes: Bytes) -> InboundWriteResult { - let Self::Open { writer, .. } = self else { + pub fn inbound_try_write(&mut self, bytes: Bytes) -> InboundWriteResult { + let Some(inbound) = self.inbound.as_mut() else { return InboundWriteResult::Closed; }; let len = bytes.len(); - match writer.try_send(bytes) { + match inbound.writer.try_send(bytes) { Ok(()) => InboundWriteResult::Accepted(len), Err(TrySendError::Full(_)) => InboundWriteResult::Full, Err(TrySendError::Closed(_)) => { - *self = Self::Closed; + self.inbound = None; InboundWriteResult::Closed } } } - pub fn finish(&mut self) { - if let Self::Open { - mut terminal, - writer, - .. - } = std::mem::replace(self, Self::Closed) - { - writer.close(); - if let Some(terminal) = terminal.take() { + pub fn inbound_finish(&mut self) { + if let Some(mut inbound) = self.inbound.take() { + inbound.writer.close(); + if let Some(terminal) = inbound.terminal.take() { let _ = terminal.send(Ok(())); } } } - pub fn fail(&mut self, error: QlStreamError) { - if let Self::Open { - mut terminal, - writer, - .. - } = std::mem::replace(self, Self::Closed) - { - writer.close(); - if let Some(terminal) = terminal.take() { + pub fn inbound_fail(&mut self, error: QlStreamError) { + if let Some(mut inbound) = self.inbound.take() { + inbound.writer.close(); + if let Some(terminal) = inbound.terminal.take() { let _ = terminal.send(Err(error)); } } } - pub fn queue_finish(&mut self) { - if let Self::Open { finish_pending, .. } = self { - *finish_pending = true; + pub fn inbound_queue_finish(&mut self) { + if let Some(inbound) = self.inbound.as_mut() { + inbound.finish_pending = true; } } - pub fn finish_pending(&self) -> bool { - match self { - Self::Open { finish_pending, .. } => *finish_pending, - Self::Closed => false, + pub fn inbound_finish_pending(&self) -> bool { + self.inbound + .as_ref() + .is_some_and(|inbound| inbound.finish_pending) + } +} + +pub struct OutboundIo { + reader: ChunkSlotRx, + terminal: Option>, +} + +impl OutboundIo { + pub fn new(reader: ChunkSlotRx, terminal: oneshot::Sender) -> Self { + Self { + reader, + terminal: Some(terminal), + } + } +} + +pub struct InboundIo { + writer: ChunkSlotTx, + terminal: Option>>, + finish_pending: bool, +} + +pub enum InboundWriteResult { + Accepted(usize), + Full, + Closed, +} + +impl InboundIo { + pub fn new(writer: ChunkSlotTx, terminal: oneshot::Sender>) -> Self { + Self { + writer, + terminal: Some(terminal), + finish_pending: false, } } } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index b96d6f46..005125ef 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -125,7 +125,7 @@ fn handle_inbound_finished_reaps_closed_initiator_stream() { state.streams.insert( stream_id, - DriverStreamIo::new(true, OutboundIo::Closed, new_inbound_io(1)), + DriverStreamIo::new(true, None, Some(new_inbound_io(1))), ); state.handle_inbound_finished(&fsm, stream_id); @@ -140,7 +140,7 @@ fn handle_closed_stream_reaps_when_both_halves_close() { state.streams.insert( stream_id, - DriverStreamIo::new(false, new_outbound_io(), new_inbound_io(1)), + DriverStreamIo::new(false, Some(new_outbound_io()), Some(new_inbound_io(1))), ); state.handle_closed_stream(&StreamClose { @@ -164,8 +164,8 @@ fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { stream_id, DriverStreamIo::new( true, - OutboundIo::new(request_reader, request_terminal_tx), - InboundIo::Closed, + Some(OutboundIo::new(request_reader, request_terminal_tx)), + None, ), ); @@ -186,8 +186,8 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { stream_id, DriverStreamIo::new( true, - OutboundIo::new(request_reader, request_terminal_tx), - InboundIo::Closed, + Some(OutboundIo::new(request_reader, request_terminal_tx)), + None, ), ); From 2f09e6026461c283661ea749951764e56a83cec4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 10:11:11 -0400 Subject: [PATCH 142/304] ql-runtime: efficient timer --- ql-runtime/src/driver/mod.rs | 38 ++++++++++++------------------- ql-runtime/src/driver/test.rs | 20 ++++++++++++---- ql-runtime/src/platform.rs | 16 +++++++++++-- ql-runtime/src/tests/mod.rs | 43 +++++++++++++++++++++++++++++++---- 4 files changed, 83 insertions(+), 34 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 0b31c99c..5670280a 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -18,7 +18,7 @@ use crate::{ chunk_slot, command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, - platform::{PlatformFuture, QlPlatform}, + platform::{PlatformFuture, QlPlatform, QlTimer}, QlError, QlStreamError, Runtime, RuntimeHandle, }; @@ -48,6 +48,7 @@ impl Runtime

{ pending_fsm_events: VecDeque::new(), }; let mut in_flight = Vec::new(); + let mut timer = platform.timer(); loop { while state.fill_write_slots(&mut fsm, &platform, &mut in_flight) {} @@ -56,13 +57,15 @@ impl Runtime

{ break; } - match next_driver_event(&rx, &platform, fsm.next_deadline(), &mut in_flight).await { + timer.set_deadline(fsm.next_deadline()); + + match next_driver_event(&rx, &mut timer, &mut in_flight).await { DriverEvent::Command(command) => { - state.drive_command(&mut fsm, command, &platform, &mut in_flight); + state.drive_command(&mut fsm, command, &platform); } DriverEvent::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); - state.drive_write_completed(&mut fsm, write.session_write_id, success); + DriverState::drive_write_completed(&mut fsm, write.session_write_id, success); } DriverEvent::TimerExpired => { state.with_fsm_events(&mut fsm, &platform, |fsm, emit| { @@ -88,17 +91,12 @@ enum DriverEvent { } #[allow(clippy::future_not_send)] -async fn next_driver_event( +async fn next_driver_event( rx: &async_channel::Receiver, - platform: &P, - next_timer: Option, + timer: &mut T, in_flight: &mut [InFlightWrite<'_>], ) -> DriverEvent { let mut recv_future = (!rx.is_closed()).then(|| Box::pin(rx.recv())); - let mut sleep_future = next_timer.map(|deadline| { - let timeout = deadline.saturating_duration_since(Instant::now()); - platform.sleep(timeout) - }); poll_fn(|cx| { for (index, write) in in_flight.iter_mut().enumerate() { @@ -110,10 +108,8 @@ async fn next_driver_event( } } - if let Some(future) = sleep_future.as_mut() { - if future.as_mut().poll(cx) == Poll::Ready(()) { - return Poll::Ready(DriverEvent::TimerExpired); - } + if timer.poll_wait(cx) == Poll::Ready(()) { + return Poll::Ready(DriverEvent::TimerExpired); } if let Some(future) = recv_future.as_mut() { @@ -130,12 +126,11 @@ async fn next_driver_event( } impl DriverState { - fn drive_command<'a, P: QlPlatform>( + fn drive_command( &mut self, fsm: &mut QlFsm, command: RuntimeCommand, - platform: &'a P, - _in_flight: &mut Vec>, + platform: &P, ) { match command { RuntimeCommand::BindPeer { peer } => { @@ -222,12 +217,7 @@ impl DriverState { } } - fn drive_write_completed( - &self, - fsm: &mut QlFsm, - session_write_id: Option, - success: bool, - ) { + fn drive_write_completed(fsm: &mut QlFsm, session_write_id: Option, success: bool) { if let Some(write_id) = session_write_id { if success { fsm.confirm_session_write(now(), write_id); diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 005125ef..9f3e8567 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -1,3 +1,5 @@ +use std::task::{Context, Poll}; + use ql_wire::{ MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PeerBundle, QlAead, QlHash, QlKem, QlRandom, SessionKey, StreamClose, XID, @@ -12,6 +14,8 @@ use crate::{ struct NoopPlatform; +struct NoopTimer; + impl QlRandom for NoopPlatform { fn fill_random_bytes(&self, data: &mut [u8]) { data.fill(0); @@ -71,13 +75,23 @@ impl QlKem for NoopPlatform { } } +impl crate::platform::QlTimer for NoopTimer { + fn set_deadline(&mut self, _deadline: Option) {} + + fn poll_wait(&mut self, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Pending + } +} + impl QlPlatform for NoopPlatform { + type Timer = NoopTimer; + fn write_message(&self, _message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { Box::pin(async { Ok(()) }) } - fn sleep(&self, _duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(async {}) + fn timer(&self) -> Self::Timer { + NoopTimer } fn load_peer(&self) -> PlatformFuture<'_, Option> { @@ -180,7 +194,6 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { let stream_id = StreamId(1u32.into()); let (request_reader, _request_writer) = chunk_slot::new(); let (request_terminal_tx, _request_terminal_rx) = oneshot::channel(); - let mut in_flight = Vec::new(); state.streams.insert( stream_id, @@ -199,7 +212,6 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { code: StreamCloseCode(0), }, &NoopPlatform, - &mut in_flight, ); assert!(!state.streams.contains_key(&stream_id)); diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 36ce6e30..f903358c 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -1,4 +1,9 @@ -use std::{future::Future, pin::Pin, time::Duration}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Instant, +}; use ql_fsm::PeerStatus; use ql_wire::{PeerBundle, QlCrypto, XID}; @@ -7,9 +12,16 @@ use crate::{QlError, QlStream}; pub type PlatformFuture<'a, T> = Pin + 'a>>; +pub trait QlTimer { + fn set_deadline(&mut self, deadline: Option); + fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<()>; +} + pub trait QlPlatform: QlCrypto { + type Timer: QlTimer; + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()>; + fn timer(&self) -> Self::Timer; fn load_peer(&self) -> PlatformFuture<'_, Option>; fn persist_peer(&self, peer: PeerBundle); diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 32e2b7b5..2b5d4254 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -1,10 +1,12 @@ use std::{ cell::Cell, future::Future, + pin::Pin, sync::{ atomic::{AtomicU8, AtomicUsize, Ordering}, Arc, }, + task::{Context, Poll}, time::Duration, }; @@ -17,10 +19,12 @@ use ql_wire::{ WireDecode, XID, }; use sha2::{Digest, Sha256}; -use tokio::task::LocalSet; +use tokio::{task::LocalSet, time::Sleep}; use crate::{ - new_runtime, platform::PlatformFuture, QlError, QlFsmConfig, QlStream, QlStreamError, + new_runtime, + platform::{PlatformFuture, QlTimer}, + QlError, QlFsmConfig, QlStream, QlStreamError, RuntimeConfig, RuntimeHandle, }; @@ -232,6 +236,29 @@ impl TestPlatform { } } +struct TokioTimer { + sleep: Pin>, +} + +impl TokioTimer { + fn new() -> Self { + Self { + sleep: Box::pin(tokio::time::sleep_until(parked_deadline())), + } + } +} + +impl QlTimer for TokioTimer { + fn set_deadline(&mut self, deadline: Option) { + let deadline = deadline.map_or_else(parked_deadline, tokio::time::Instant::from_std); + self.sleep.as_mut().reset(deadline); + } + + fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.sleep.as_mut().poll(cx) + } +} + impl QlRandom for TestPlatform { fn fill_random_bytes(&self, data: &mut [u8]) { let value = self @@ -318,6 +345,8 @@ impl QlKem for TestPlatform { } impl crate::platform::QlPlatform for TestPlatform { + type Timer = TokioTimer; + fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { let outbound = self.outbound.clone(); let write_delay = self.write_delay; @@ -358,8 +387,8 @@ impl crate::platform::QlPlatform for TestPlatform { }) } - fn sleep(&self, duration: Duration) -> PlatformFuture<'_, ()> { - Box::pin(tokio::time::sleep(duration)) + fn timer(&self) -> Self::Timer { + TokioTimer::new() } fn load_peer(&self) -> PlatformFuture<'_, Option> { @@ -379,6 +408,10 @@ impl crate::platform::QlPlatform for TestPlatform { } } +fn parked_deadline() -> tokio::time::Instant { + tokio::time::Instant::now() + Duration::from_secs(60 * 60 * 24 * 365 * 100) +} + fn is_encrypted_payload(bytes: &[u8]) -> bool { RecordHeader::decode_bytes(bytes) .ok() @@ -527,6 +560,7 @@ fn runtime_is_send() { let (runtime_a, _handle) = new_runtime(identity_a, platform_a, config); std::thread::spawn(move || { tokio::runtime::Builder::new_current_thread() + .enable_time() .build() .unwrap() .block_on(runtime_a.run()); @@ -543,6 +577,7 @@ fn runtime_exits_when_last_handle_drops() { std::thread::spawn(move || { tokio::runtime::Builder::new_current_thread() + .enable_time() .build() .unwrap() .block_on(runtime.run()); From bf3d2eec810d61e02ffef334f0d48dea369ee34e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 10:20:52 -0400 Subject: [PATCH 143/304] ql-runtime: non-boxed write --- ql-runtime/src/driver/mod.rs | 23 ++++++++++++++--------- ql-runtime/src/driver/test.rs | 6 ++++-- ql-runtime/src/platform.rs | 5 ++++- ql-runtime/src/tests/mod.rs | 3 ++- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 5670280a..235bbf03 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -5,6 +5,7 @@ mod test; use std::{ collections::{HashMap, VecDeque}, future::Future, + pin::Pin, task::Poll, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; @@ -18,7 +19,7 @@ use crate::{ chunk_slot, command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, - platform::{PlatformFuture, QlPlatform, QlTimer}, + platform::{QlPlatform, QlTimer}, QlError, QlStreamError, Runtime, RuntimeHandle, }; @@ -78,9 +79,9 @@ impl Runtime

{ } } -struct InFlightWrite<'a> { +struct InFlightWrite { session_write_id: Option, - future: PlatformFuture<'a, Result<(), QlError>>, + future: F, } enum DriverEvent { @@ -91,16 +92,20 @@ enum DriverEvent { } #[allow(clippy::future_not_send)] -async fn next_driver_event( +async fn next_driver_event( rx: &async_channel::Receiver, timer: &mut T, - in_flight: &mut [InFlightWrite<'_>], -) -> DriverEvent { + in_flight: &mut [InFlightWrite], +) -> DriverEvent +where + T: QlTimer, + F: Future> + Unpin, +{ let mut recv_future = (!rx.is_closed()).then(|| Box::pin(rx.recv())); poll_fn(|cx| { for (index, write) in in_flight.iter_mut().enumerate() { - if let Poll::Ready(result) = write.future.as_mut().poll(cx) { + if let Poll::Ready(result) = Pin::new(&mut write.future).poll(cx) { return Poll::Ready(DriverEvent::WriteCompleted { index, success: result.is_ok(), @@ -441,11 +446,11 @@ impl DriverState { self.streams.clear(); } - fn fill_write_slots<'a, P: QlPlatform>( + fn fill_write_slots<'a, P: QlPlatform + 'a>( &self, fsm: &mut QlFsm, platform: &'a P, - in_flight: &mut Vec>, + in_flight: &mut Vec>>, ) -> bool { let mut progressed = false; diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 9f3e8567..62322d85 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -9,6 +9,7 @@ use super::*; use crate::{ chunk_slot, driver::state::{InboundIo, OutboundIo}, + platform::PlatformFuture, tests::new_identity, }; @@ -85,9 +86,10 @@ impl crate::platform::QlTimer for NoopTimer { impl QlPlatform for NoopPlatform { type Timer = NoopTimer; + type WriteMessageFut<'a> = std::future::Ready>; - fn write_message(&self, _message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { - Box::pin(async { Ok(()) }) + fn write_message(&self, _message: Vec) -> Self::WriteMessageFut<'_> { + std::future::ready(Ok(())) } fn timer(&self) -> Self::Timer { diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index f903358c..9b91af70 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -19,8 +19,11 @@ pub trait QlTimer { pub trait QlPlatform: QlCrypto { type Timer: QlTimer; + type WriteMessageFut<'a>: Future> + Unpin + 'a + where + Self: 'a; - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>>; + fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_>; fn timer(&self) -> Self::Timer; fn load_peer(&self) -> PlatformFuture<'_, Option>; diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 2b5d4254..ae479610 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -346,8 +346,9 @@ impl QlKem for TestPlatform { impl crate::platform::QlPlatform for TestPlatform { type Timer = TokioTimer; + type WriteMessageFut<'a> = PlatformFuture<'a, Result<(), QlError>>; - fn write_message(&self, message: Vec) -> PlatformFuture<'_, Result<(), QlError>> { + fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_> { let outbound = self.outbound.clone(); let write_delay = self.write_delay; let fail_encrypted_write_at = self.fail_encrypted_write_at; From 1ca26e380f94414a05e53e5b68bb676ee0c50337 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 11:15:06 -0400 Subject: [PATCH 144/304] ql-runtime: remove loop --- ql-runtime/src/driver/mod.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 235bbf03..b74bd6f3 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -52,7 +52,7 @@ impl Runtime

{ let mut timer = platform.timer(); loop { - while state.fill_write_slots(&mut fsm, &platform, &mut in_flight) {} + state.fill_write_slots(&mut fsm, &platform, &mut in_flight); if rx.is_closed() && in_flight.is_empty() { break; @@ -222,7 +222,11 @@ impl DriverState { } } - fn drive_write_completed(fsm: &mut QlFsm, session_write_id: Option, success: bool) { + fn drive_write_completed( + fsm: &mut QlFsm, + session_write_id: Option, + success: bool, + ) { if let Some(write_id) = session_write_id { if success { fsm.confirm_session_write(now(), write_id); @@ -451,21 +455,16 @@ impl DriverState { fsm: &mut QlFsm, platform: &'a P, in_flight: &mut Vec>>, - ) -> bool { - let mut progressed = false; - + ) { while in_flight.len() < self.max_concurrent_message_writes { let Some(write) = fsm.take_next_write(now(), platform) else { break; }; - progressed = true; in_flight.push(InFlightWrite { session_write_id: write.session_write_id, future: platform.write_message(write.record), }); } - - progressed } fn poll_stream(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { From 45fbade6eb137c02c3e6509e5bc2ddc7c0f27daa Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 11:21:36 -0400 Subject: [PATCH 145/304] ql: stream-writer to avoid double lookups --- ql-fsm/src/implementation/core.rs | 17 ++--------- ql-fsm/src/lib.rs | 17 +++-------- ql-fsm/src/session/mod.rs | 48 ++++++++++++++++--------------- ql-fsm/src/session/tests.rs | 3 +- ql-fsm/src/tests/proptest.rs | 6 ++-- ql-fsm/src/tests/session.rs | 3 +- ql-runtime/src/driver/mod.rs | 5 ++-- 7 files changed, 43 insertions(+), 56 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 6ebb8709..bb796ee6 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -7,7 +7,7 @@ use ql_wire::{ use crate::{ session::SessionEvent, state::LinkState, OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, - SessionWriteId, StreamReadIter, + SessionWriteId, StreamReadIter, StreamWriter, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { @@ -150,20 +150,9 @@ pub fn open_stream(fsm: &mut QlFsm) -> Result { Ok(state.session.open_stream()?) } -pub fn write_stream( - fsm: &mut QlFsm, - stream_id: StreamId, - bytes: &mut Bytes, -) -> Result { +pub fn write_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, QlFsmError> { let state = fsm.state.link.connected_mut_or_err()?; - Ok(state.session.write_stream(stream_id, bytes)?) -} - -pub fn stream_write_capacity(fsm: &QlFsm, stream_id: StreamId) -> Option { - fsm.state - .link - .connected() - .and_then(|state| state.session.stream_write_capacity(stream_id)) + Ok(state.session.write_stream(stream_id)?) } pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Option> { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index cbe60da2..335be9af 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -33,7 +33,7 @@ use ql_wire::{ CloseTarget, PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, StreamCloseCode, StreamId, }; -pub use session::stream_rx::StreamReadIter; +pub use session::{stream_rx::StreamReadIter, StreamWriter}; use crate::{ replay_cache::ReplayCache, @@ -254,18 +254,9 @@ impl QlFsm { implementation::open_stream(self) } - /// queues owned bytes for an open stream and returns the accepted count - pub fn write_stream( - &mut self, - stream_id: StreamId, - bytes: &mut Bytes, - ) -> Result { - implementation::write_stream(self, stream_id, bytes) - } - - /// returns how many bytes can currently be queued for an open stream - pub fn stream_write_capacity(&self, stream_id: StreamId) -> Option { - implementation::stream_write_capacity(self, stream_id) + /// returns a writer for an open stream + pub fn write_stream(&mut self, stream_id: StreamId) -> Result, QlFsmError> { + implementation::write_stream(self, stream_id) } /// returns the readable stream bytes as owned `Bytes` views without consuming them diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index f814e8df..afdb9554 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -98,6 +98,25 @@ impl std::fmt::Display for StreamError { impl std::error::Error for StreamError {} +pub struct StreamWriter<'a> { + stream: &'a mut StreamState, + send_buffer_size: usize, +} + +impl StreamWriter<'_> { + pub fn capacity(&self) -> usize { + self.stream.send_capacity(self.send_buffer_size) + } + + pub fn write(&mut self, bytes: &mut Bytes) -> usize { + let accepted = bytes.len().min(self.capacity()); + if accepted > 0 { + self.stream.tx.append(bytes.split_to(accepted)); + } + accepted + } +} + pub struct SessionFsm { config: SessionFsmConfig, state: SessionFsmState, @@ -149,15 +168,9 @@ impl SessionFsm { Ok(stream_id) } - pub fn write_stream( - &mut self, - stream_id: StreamId, - bytes: &mut Bytes, - ) -> Result { - // TODO: consider a `BytesSource` abstraction here so callers can provide - // different chunk sources while preserving partial-accept semantics and deferring any - // required copying until capacity is known + pub fn write_stream(&mut self, stream_id: StreamId) -> Result, StreamError> { self.ensure_session_open()?; + let send_buffer_size = self.config.stream_send_buffer_size; let stream = self .state .streams @@ -167,21 +180,10 @@ impl SessionFsm { return Err(StreamError::NotWritable); } - let accepted = bytes - .len() - .min(stream.send_capacity(self.config.stream_send_buffer_size)); - if accepted > 0 { - stream.tx.append(bytes.split_to(accepted)); - } - Ok(accepted) - } - - pub fn stream_write_capacity(&self, stream_id: StreamId) -> Option { - let stream = self.state.streams.get(&stream_id)?; - if !stream.is_writable() { - return Some(0); - } - Some(stream.send_capacity(self.config.stream_send_buffer_size)) + Ok(StreamWriter { + stream, + send_buffer_size, + }) } pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index ac45f737..15956504 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -23,7 +23,8 @@ fn offset(value: u64) -> VarInt { fn write_stream_bytes(fsm: &mut SessionFsm, stream_id: StreamId, bytes: &[u8]) -> usize { let mut bytes = Bytes::copy_from_slice(bytes); - fsm.write_stream(stream_id, &mut bytes).unwrap() + let mut writer = fsm.write_stream(stream_id).unwrap(); + writer.write(&mut bytes) } fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index a63ad501..4aef024a 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -318,7 +318,8 @@ impl Runner { Action::WriteA { slot, bytes } => { if let Some(stream_id) = self.slots_a[*slot] { let mut chunk = Bytes::copy_from_slice(bytes); - if let Ok(accepted) = self.harness.a.fsm.write_stream(stream_id, &mut chunk) { + if let Ok(mut writer) = self.harness.a.fsm.write_stream(stream_id) { + let accepted = writer.write(&mut chunk); self.expected_at_b .entry(stream_id) .or_default() @@ -329,7 +330,8 @@ impl Runner { Action::WriteB { slot, bytes } => { if let Some(stream_id) = self.slots_b[*slot] { let mut chunk = Bytes::copy_from_slice(bytes); - if let Ok(accepted) = self.harness.b.fsm.write_stream(stream_id, &mut chunk) { + if let Ok(mut writer) = self.harness.b.fsm.write_stream(stream_id) { + let accepted = writer.write(&mut chunk); self.expected_at_a .entry(stream_id) .or_default() diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index a08b24e0..5698aa9a 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -16,7 +16,8 @@ fn write_stream_bytes( bytes: &[u8], ) -> Result { let mut bytes = Bytes::copy_from_slice(bytes); - fsm.write_stream(stream_id, &mut bytes) + let mut writer = fsm.write_stream(stream_id)?; + Ok(writer.write(&mut bytes)) } fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index b74bd6f3..6b3d70f7 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -479,12 +479,13 @@ impl DriverState { if reader.is_finished() { true } else { - let Some(capacity) = fsm.stream_write_capacity(stream_id) else { + let Ok(mut writer) = fsm.write_stream(stream_id) else { return; }; + let capacity = writer.capacity(); if capacity > 0 { if let Ok(Some(mut bytes)) = reader.try_recv(capacity) { - let _ = fsm.write_stream(stream_id, &mut bytes); + let _ = writer.write(&mut bytes); } } reader.is_finished() From 918e721a0127dceacd1924274e91669c461d221f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 11:25:28 -0400 Subject: [PATCH 146/304] ql-runtime: write future bool output instead of result --- ql-runtime/src/driver/mod.rs | 9 +++------ ql-runtime/src/driver/test.rs | 4 ++-- ql-runtime/src/platform.rs | 4 ++-- ql-runtime/src/tests/mod.rs | 16 ++++++---------- 4 files changed, 13 insertions(+), 20 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 6b3d70f7..59f1df66 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -99,17 +99,14 @@ async fn next_driver_event( ) -> DriverEvent where T: QlTimer, - F: Future> + Unpin, + F: Future + Unpin, { let mut recv_future = (!rx.is_closed()).then(|| Box::pin(rx.recv())); poll_fn(|cx| { for (index, write) in in_flight.iter_mut().enumerate() { - if let Poll::Ready(result) = Pin::new(&mut write.future).poll(cx) { - return Poll::Ready(DriverEvent::WriteCompleted { - index, - success: result.is_ok(), - }); + if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { + return Poll::Ready(DriverEvent::WriteCompleted { index, success }); } } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 62322d85..a459aadb 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -86,10 +86,10 @@ impl crate::platform::QlTimer for NoopTimer { impl QlPlatform for NoopPlatform { type Timer = NoopTimer; - type WriteMessageFut<'a> = std::future::Ready>; + type WriteMessageFut<'a> = std::future::Ready; fn write_message(&self, _message: Vec) -> Self::WriteMessageFut<'_> { - std::future::ready(Ok(())) + std::future::ready(true) } fn timer(&self) -> Self::Timer { diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 9b91af70..411627c0 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -8,7 +8,7 @@ use std::{ use ql_fsm::PeerStatus; use ql_wire::{PeerBundle, QlCrypto, XID}; -use crate::{QlError, QlStream}; +use crate::QlStream; pub type PlatformFuture<'a, T> = Pin + 'a>>; @@ -19,7 +19,7 @@ pub trait QlTimer { pub trait QlPlatform: QlCrypto { type Timer: QlTimer; - type WriteMessageFut<'a>: Future> + Unpin + 'a + type WriteMessageFut<'a>: Future + Unpin + 'a where Self: 'a; diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index ae479610..e3ae9c55 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -24,8 +24,7 @@ use tokio::{task::LocalSet, time::Sleep}; use crate::{ new_runtime, platform::{PlatformFuture, QlTimer}, - QlError, QlFsmConfig, QlStream, QlStreamError, - RuntimeConfig, RuntimeHandle, + QlError, QlFsmConfig, QlStream, QlStreamError, RuntimeConfig, RuntimeHandle, }; mod handshake; @@ -346,7 +345,7 @@ impl QlKem for TestPlatform { impl crate::platform::QlPlatform for TestPlatform { type Timer = TokioTimer; - type WriteMessageFut<'a> = PlatformFuture<'a, Result<(), QlError>>; + type WriteMessageFut<'a> = PlatformFuture<'a, bool>; fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_> { let outbound = self.outbound.clone(); @@ -371,20 +370,17 @@ impl crate::platform::QlPlatform for TestPlatform { false }; - let result = if should_fail { - Err(QlError::SendFailed) + let success = if should_fail { + false } else { - outbound - .send(message) - .await - .map_err(|_| QlError::InvalidPayload) + outbound.send(message).await.is_ok() }; if let Some(stats) = write_stats.as_ref() { stats.active.fetch_sub(1, Ordering::Relaxed); } - result + success }) } From 85eb5743fa0e0a2fc3643f5e5148dccc6a2671a2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 11:40:39 -0400 Subject: [PATCH 147/304] ql-fsm: cleanup --- ql-fsm/src/implementation/core.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index bb796ee6..3da9de71 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -156,10 +156,8 @@ pub fn write_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result Option> { - fsm.state - .link - .connected() - .and_then(|state| state.session.stream_read(stream_id)) + let state = fsm.state.link.connected()?; + state.session.stream_read(stream_id) } pub fn stream_read_commit( From 9e7731de47aaee59dd7b5049a56f2f06a63ca572 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 12:12:15 -0400 Subject: [PATCH 148/304] ql: more granular errors --- ql-fsm/src/error.rs | 69 +++++++++++++------ ql-fsm/src/implementation/core.rs | 34 +++++----- ql-fsm/src/lib.rs | 14 ++-- ql-fsm/src/session/mod.rs | 31 ++------- ql-fsm/src/session/tests.rs | 2 +- ql-fsm/src/state.rs | 6 +- ql-fsm/src/tests/session.rs | 16 ++--- ql-runtime/src/command.rs | 5 +- ql-runtime/src/driver/mod.rs | 6 +- ql-runtime/src/driver/state.rs | 4 +- ql-runtime/src/error.rs | 73 +-------------------- ql-runtime/src/handle/mod.rs | 5 +- ql-runtime/src/lib.rs | 3 +- ql-runtime/src/rpc/error.rs | 29 ++++++-- ql-runtime/src/rpc/mod.rs | 11 ++-- ql-runtime/src/rpc/request_with_progress.rs | 4 +- ql-runtime/src/rpc/subscription.rs | 2 +- ql-runtime/src/tests/handshake.rs | 2 +- ql-runtime/src/tests/heartbeat.rs | 2 +- ql-runtime/src/tests/mod.rs | 2 +- 20 files changed, 139 insertions(+), 181 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 0d66677f..2e2e929d 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -1,6 +1,9 @@ -use ql_wire::WireError; +use std::{ + error::Error, + fmt::{Display, Formatter}, +}; -use crate::session::StreamError; +use ql_wire::WireError; #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlFsmError { @@ -9,26 +12,18 @@ pub enum QlFsmError { Expired, DecryptFailed, InvalidXid, - MissingStream, - NotWritable, - InvalidRead, - SessionClosed, NoPeerBound, NoSession, } -impl std::fmt::Display for QlFsmError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Display for QlFsmError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let message = match self { Self::InvalidPayload => "invalid payload", Self::InvalidState => "invalid state", Self::Expired => "expired", Self::DecryptFailed => "decryption failed", Self::InvalidXid => "invalid xid", - Self::MissingStream => "missing stream", - Self::NotWritable => "stream is not writable", - Self::InvalidRead => "invalid read commit", - Self::SessionClosed => "session is closed", Self::NoPeerBound => "no peer bound", Self::NoSession => "no active session", }; @@ -49,13 +44,47 @@ impl From for QlFsmError { } } -impl From for QlFsmError { - fn from(value: StreamError) -> Self { - match value { - StreamError::MissingStream => Self::MissingStream, - StreamError::NotWritable => Self::NotWritable, - StreamError::InvalidRead => Self::InvalidRead, - StreamError::SessionClosed => Self::SessionClosed, - } +impl From for QlFsmError { + fn from(_: NoSessionError) -> Self { + Self::NoSession + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NoSessionError; + +impl Display for NoSessionError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "no session") + } +} + +impl Error for NoSessionError {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamError { + MissingStream, + NotWritable, + InvalidRead, + NoSession, +} + +impl Display for StreamError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let message = match self { + Self::MissingStream => "missing stream", + Self::NotWritable => "stream is not writable", + Self::InvalidRead => "invalid read commit", + Self::NoSession => "no session", + }; + f.write_str(message) + } +} + +impl Error for StreamError {} + +impl From for StreamError { + fn from(_: NoSessionError) -> Self { + Self::NoSession } } diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 3da9de71..eaf82551 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -6,8 +6,8 @@ use ql_wire::{ }; use crate::{ - session::SessionEvent, state::LinkState, OutboundWrite, QlFsm, QlFsmError, QlFsmEvent, - SessionWriteId, StreamReadIter, StreamWriter, + session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, QlFsmError, + QlFsmEvent, SessionWriteId, StreamError, StreamReadIter, StreamWriter, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { @@ -35,7 +35,11 @@ pub fn receive( super::handle_handshake_record(fsm, crypto, &record, &mut emit) } wire::RecordType::Session => { - let state = fsm.state.link.connected_mut_or_err()?; + let state = fsm + .state + .link + .connected_mut() + .ok_or(QlFsmError::NoSession)?; let (decrypt_len, seq) = { let record = wire::QlSessionRecord::decode(&mut reader)?; if record.header.connection_id != state.transport.rx_connection_id { @@ -145,14 +149,14 @@ pub fn kill_session(fsm: &mut QlFsm, _code: SessionCloseCode) { fsm.state.link = crate::state::LinkState::Idle; } -pub fn open_stream(fsm: &mut QlFsm) -> Result { +pub fn open_stream(fsm: &mut QlFsm) -> Result { let state = fsm.state.link.connected_mut_or_err()?; - Ok(state.session.open_stream()?) + state.session.open_stream() } -pub fn write_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, QlFsmError> { +pub fn write_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { let state = fsm.state.link.connected_mut_or_err()?; - Ok(state.session.write_stream(stream_id)?) + state.session.write_stream(stream_id) } pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Option> { @@ -164,9 +168,9 @@ pub fn stream_read_commit( fsm: &mut QlFsm, stream_id: StreamId, len: usize, -) -> Result<(), QlFsmError> { +) -> Result<(), StreamError> { let state = fsm.state.link.connected_mut_or_err()?; - Ok(state.session.stream_read_commit(stream_id, len)?) + state.session.stream_read_commit(stream_id, len) } pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Option { @@ -176,9 +180,9 @@ pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Option .and_then(|state| state.session.stream_available_bytes(stream_id)) } -pub fn finish_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result<(), QlFsmError> { +pub fn finish_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result<(), StreamError> { let state = fsm.state.link.connected_mut_or_err()?; - Ok(state.session.finish_stream(stream_id)?) + state.session.finish_stream(stream_id) } pub fn close_stream( @@ -186,14 +190,14 @@ pub fn close_stream( stream_id: StreamId, target: CloseTarget, code: StreamCloseCode, -) -> Result<(), QlFsmError> { +) -> Result<(), StreamError> { let state = fsm.state.link.connected_mut_or_err()?; - Ok(state.session.close_stream(stream_id, target, code)?) + state.session.close_stream(stream_id, target, code) } -pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), QlFsmError> { +pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { let state = fsm.state.link.connected_mut_or_err()?; - Ok(state.session.queue_ping()?) + state.session.queue_ping() } pub fn emit_peer_status(fsm: &QlFsm, emit: &mut impl FnMut(QlFsmEvent)) { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 335be9af..05b3ef20 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -28,7 +28,7 @@ mod tests; use std::time::{Duration, Instant}; pub use bytes::Bytes; -pub use error::QlFsmError; +pub use error::*; use ql_wire::{ CloseTarget, PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, StreamCloseCode, StreamId, @@ -250,12 +250,12 @@ impl QlFsm { } /// opens a new outgoing stream - pub fn open_stream(&mut self) -> Result { + pub fn open_stream(&mut self) -> Result { implementation::open_stream(self) } /// returns a writer for an open stream - pub fn write_stream(&mut self, stream_id: StreamId) -> Result, QlFsmError> { + pub fn write_stream(&mut self, stream_id: StreamId) -> Result, StreamError> { implementation::write_stream(self, stream_id) } @@ -269,7 +269,7 @@ impl QlFsm { &mut self, stream_id: StreamId, len: usize, - ) -> Result<(), QlFsmError> { + ) -> Result<(), StreamError> { implementation::stream_read_commit(self, stream_id, len) } @@ -279,7 +279,7 @@ impl QlFsm { } /// marks the local write side as finished - pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), QlFsmError> { + pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { implementation::finish_stream(self, stream_id) } @@ -289,12 +289,12 @@ impl QlFsm { stream_id: StreamId, target: CloseTarget, code: StreamCloseCode, - ) -> Result<(), QlFsmError> { + ) -> Result<(), StreamError> { implementation::close_stream(self, stream_id, target, code) } /// queues a ping on the active session - pub fn queue_ping(&mut self) -> Result<(), QlFsmError> { + pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { implementation::queue_ping(self) } } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index afdb9554..d1deb89c 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -29,6 +29,7 @@ use self::{ stream_tx::StreamTxRange, tracked::{TrackedFrame, TrackedRecord, TrackedStreamData}, }; +use crate::{NoSessionError, StreamError}; #[derive(Debug, Clone, Copy)] pub struct SessionFsmConfig { @@ -76,28 +77,6 @@ pub enum SessionState { Closed, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamError { - MissingStream, - NotWritable, - InvalidRead, - SessionClosed, -} - -impl std::fmt::Display for StreamError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let message = match self { - Self::MissingStream => "missing stream", - Self::NotWritable => "stream is not writable", - Self::InvalidRead => "invalid read commit", - Self::SessionClosed => "session is closed", - }; - f.write_str(message) - } -} - -impl std::error::Error for StreamError {} - pub struct StreamWriter<'a> { stream: &'a mut StreamState, send_buffer_size: usize, @@ -150,7 +129,7 @@ impl SessionFsm { } } - pub fn open_stream(&mut self) -> Result { + pub fn open_stream(&mut self) -> Result { self.ensure_session_open()?; let stream_id = self .config @@ -256,7 +235,7 @@ impl SessionFsm { Some(stream.readable_bytes()) } - pub fn queue_ping(&mut self) -> Result<(), StreamError> { + pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { self.ensure_session_open()?; self.state.pending_control.ping = true; Ok(()) @@ -613,9 +592,9 @@ impl SessionFsm { self.state.next_stream_index = next_index; } - fn ensure_session_open(&self) -> Result<(), StreamError> { + fn ensure_session_open(&self) -> Result<(), NoSessionError> { if self.state.session_state == SessionState::Closed { - Err(StreamError::SessionClosed) + Err(NoSessionError) } else { Ok(()) } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 15956504..293150cb 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -363,7 +363,7 @@ fn duplicate_remote_close_after_reap_is_ignored() { vec![ SessionEvent::Opened(close.stream_id), SessionEvent::Closed(close.clone()), - SessionEvent::WritableClosed(close.clone()), + SessionEvent::WritableClosed(close), ] ); diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 03ae1709..c4b4fe96 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -5,7 +5,7 @@ use ql_wire::{ QlHandshakeRecord, SessionKey, TransportParams, }; -use crate::{replay_cache::ReplayCache, session::SessionFsm, FsmTime, PeerStatus, QlFsmError}; +use crate::{replay_cache::ReplayCache, session::SessionFsm, FsmTime, NoSessionError, PeerStatus}; pub struct QlFsmState { pub replay_cache: ReplayCache, @@ -98,8 +98,8 @@ impl LinkState { } #[inline] - pub fn connected_mut_or_err(&mut self) -> Result<&mut ConnectedState, QlFsmError> { - self.connected_mut().ok_or(QlFsmError::NoSession) + pub fn connected_mut_or_err(&mut self) -> Result<&mut ConnectedState, NoSessionError> { + self.connected_mut().ok_or(NoSessionError) } pub fn handshake_deadline(&self) -> Option { diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 5698aa9a..c43924e5 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -4,7 +4,7 @@ use bytes::Bytes; use ql_wire::{SessionClose, StreamId, VarInt}; use super::*; -use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent}; +use crate::{state::LinkState, NoSessionError, PeerStatus, QlFsmEvent, StreamError}; fn stream_id(value: u32) -> StreamId { StreamId(VarInt::from_u32(value)) @@ -14,7 +14,7 @@ fn write_stream_bytes( fsm: &mut QlFsm, stream_id: StreamId, bytes: &[u8], -) -> Result { +) -> Result { let mut bytes = Bytes::copy_from_slice(bytes); let mut writer = fsm.write_stream(stream_id)?; Ok(writer.write(&mut bytes)) @@ -170,14 +170,14 @@ fn disconnected_stream_operations_fail_with_no_session() { let mut harness = Harness::paired_known(QlFsmConfig::default()); let missing = stream_id(0); - assert_eq!(harness.a.fsm.open_stream(), Err(QlFsmError::NoSession)); + assert_eq!(harness.a.fsm.open_stream(), Err(NoSessionError)); assert_eq!( write_stream_bytes(&mut harness.a.fsm, missing, b"queued"), - Err(QlFsmError::NoSession) + Err(StreamError::NoSession) ); assert_eq!( harness.a.fsm.finish_stream(missing), - Err(QlFsmError::NoSession) + Err(StreamError::NoSession) ); assert_eq!( harness.a.fsm.close_stream( @@ -185,12 +185,12 @@ fn disconnected_stream_operations_fail_with_no_session() { ql_wire::CloseTarget::Both, ql_wire::StreamCloseCode(0) ), - Err(QlFsmError::NoSession) + Err(StreamError::NoSession) ); - assert_eq!(harness.a.fsm.queue_ping(), Err(QlFsmError::NoSession)); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); assert_eq!( harness.a.fsm.stream_read_commit(missing, 1), - Err(QlFsmError::NoSession) + Err(StreamError::NoSession) ); } diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 530b52fa..019610f8 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,6 +1,7 @@ +use ql_fsm::NoSessionError; use ql_wire::{CloseTarget, PeerBundle, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlError, QlStreamError}; +use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlStreamError}; pub(crate) enum RuntimeCommand { BindPeer { @@ -10,7 +11,7 @@ pub(crate) enum RuntimeCommand { OpenStream { request_reader: ChunkSlotRx, request_terminal: oneshot::Sender, - start: oneshot::Sender>, + start: oneshot::Sender>, }, PollInbound { stream_id: StreamId, diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 59f1df66..24df30ff 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -20,7 +20,7 @@ use crate::{ command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, platform::{QlPlatform, QlTimer}, - QlError, QlStreamError, Runtime, RuntimeHandle, + QlStreamError, Runtime, RuntimeHandle, }; impl Runtime

{ @@ -155,11 +155,11 @@ impl DriverState { start, } => { let Some(runtime_tx) = self.runtime_tx.upgrade() else { - let _ = start.send(Err(QlError::Cancelled)); + let _ = start.send(Err(ql_fsm::NoSessionError)); return; }; - match fsm.open_stream().map_err(QlError::from) { + match fsm.open_stream() { Ok(stream_id) => { let (response_reader, response_writer) = chunk_slot::new(); let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 4f89c210..0a522874 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -54,8 +54,8 @@ impl DriverStreamIo { } pub fn fail_all(&mut self) { - self.inbound_fail(QlStreamError::SessionClosed); - self.outbound_fail(QlStreamError::SessionClosed); + self.inbound_fail(QlStreamError::NoSession); + self.outbound_fail(QlStreamError::NoSession); } pub fn is_closed(&self) -> bool { diff --git a/ql-runtime/src/error.rs b/ql-runtime/src/error.rs index f04f7732..5b74bcf8 100644 --- a/ql-runtime/src/error.rs +++ b/ql-runtime/src/error.rs @@ -1,85 +1,16 @@ -use ql_fsm::QlFsmError; use ql_wire::StreamCloseCode; -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlError { - InvalidPayload, - InvalidState, - Expired, - DecryptFailed, - InvalidXid, - MissingStream, - NotWritable, - InvalidRead, - SessionClosed, - NoPeerBound, - NoSession, - SendFailed, - StreamClosed { code: ql_wire::StreamCloseCode }, - Cancelled, -} - -impl std::fmt::Display for QlError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::InvalidPayload => f.write_str("invalid payload"), - Self::InvalidState => f.write_str("invalid state"), - Self::Expired => f.write_str("expired"), - Self::DecryptFailed => f.write_str("decryption failed"), - Self::InvalidXid => f.write_str("invalid xid"), - Self::MissingStream => f.write_str("missing stream"), - Self::NotWritable => f.write_str("stream is not writable"), - Self::InvalidRead => f.write_str("invalid read"), - Self::SessionClosed => f.write_str("session is closed"), - Self::NoPeerBound => f.write_str("no peer bound"), - Self::NoSession => f.write_str("no active session"), - Self::SendFailed => f.write_str("send failed"), - Self::StreamClosed { code } => write!(f, "stream closed {code:?}"), - Self::Cancelled => f.write_str("cancelled"), - } - } -} - -impl std::error::Error for QlError {} - -impl From for QlError { - fn from(value: QlStreamError) -> Self { - match value { - QlStreamError::StreamClosed { code } => Self::StreamClosed { code }, - QlStreamError::SessionClosed => Self::SessionClosed, - } - } -} - -impl From for QlError { - fn from(value: QlFsmError) -> Self { - match value { - QlFsmError::InvalidPayload => Self::InvalidPayload, - QlFsmError::InvalidState => Self::InvalidState, - QlFsmError::Expired => Self::Expired, - QlFsmError::DecryptFailed => Self::DecryptFailed, - QlFsmError::InvalidXid => Self::InvalidXid, - QlFsmError::MissingStream => Self::MissingStream, - QlFsmError::NotWritable => Self::NotWritable, - QlFsmError::InvalidRead => Self::InvalidRead, - QlFsmError::SessionClosed => Self::SessionClosed, - QlFsmError::NoPeerBound => Self::NoPeerBound, - QlFsmError::NoSession => Self::NoSession, - } - } -} - #[derive(Debug, Clone, PartialEq, Eq)] pub enum QlStreamError { StreamClosed { code: StreamCloseCode }, - SessionClosed, + NoSession, } impl std::fmt::Display for QlStreamError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::StreamClosed { code } => write!(f, "stream closed {code:?}"), - Self::SessionClosed => f.write_str("session is closed"), + Self::NoSession => f.write_str("no session"), } } } diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index a4f581cb..5d38b985 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -1,10 +1,11 @@ mod reader; mod writer; +use ql_fsm::NoSessionError; use ql_wire::{CloseTarget, PeerBundle, StreamId}; pub use self::{reader::*, writer::*}; -use crate::{chunk_slot, command::RuntimeCommand, QlError}; +use crate::{chunk_slot, command::RuntimeCommand}; #[derive(Debug)] pub struct QlStream { @@ -31,7 +32,7 @@ impl RuntimeHandle { self.send(RuntimeCommand::Incoming(bytes)); } - pub async fn open_stream(&self) -> Result { + pub async fn open_stream(&self) -> Result { let (request_reader, request_writer) = chunk_slot::new(); let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); let (start_tx, start_rx) = oneshot::channel(); diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index bbd841da..f1fa7d85 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,8 +1,9 @@ pub use self::{ - error::{QlError, QlStreamError}, + error::QlStreamError, handle::*, platform::*, }; +pub use ql_fsm::NoSessionError; pub mod chunk_slot; pub(crate) mod command; diff --git a/ql-runtime/src/rpc/error.rs b/ql-runtime/src/rpc/error.rs index f81ad652..c82a30ee 100644 --- a/ql-runtime/src/rpc/error.rs +++ b/ql-runtime/src/rpc/error.rs @@ -1,15 +1,28 @@ -use crate::QlError; +use ql_fsm::NoSessionError; +use ql_wire::StreamCloseCode; + +use crate::QlStreamError; #[derive(Debug)] pub enum RpcCallError { - Runtime(QlError), + NoSession, + StreamClosed(StreamCloseCode), Rpc(ql_rpc::RpcError), Codec(E), } -impl From for RpcCallError { - fn from(error: QlError) -> Self { - Self::Runtime(error) +impl From for RpcCallError { + fn from(_: NoSessionError) -> Self { + Self::NoSession + } +} + +impl From for RpcCallError { + fn from(error: QlStreamError) -> Self { + match error { + QlStreamError::StreamClosed { code } => Self::StreamClosed(code), + QlStreamError::NoSession => Self::NoSession, + } } } @@ -34,7 +47,8 @@ where { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Runtime(error) => write!(f, "{error}"), + Self::NoSession => write!(f, "no session"), + Self::StreamClosed(code) => write!(f, "stream closed {code:?}"), Self::Rpc(error) => write!(f, "{error}"), Self::Codec(error) => write!(f, "{error}"), } @@ -47,9 +61,10 @@ where { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - Self::Runtime(error) => Some(error), Self::Rpc(error) => Some(error), Self::Codec(error) => Some(error), + RpcCallError::NoSession => None, + RpcCallError::StreamClosed(_) => None, } } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 209a0492..48b06d23 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -14,7 +14,7 @@ use ql_rpc::{ }; pub use self::{error::*, request_with_progress::*, subscription::*}; -use crate::{ByteReader, QlError, RuntimeHandle}; +use crate::{ByteReader, QlStreamError, RuntimeHandle}; #[derive(Clone)] pub struct RpcHandle { @@ -86,7 +86,7 @@ impl RpcHandle { }) } - async fn start_request(&self, payload: Vec) -> Result { + async fn start_request(&self, payload: Vec) -> Result> { let mut stream = self.inner.open_stream().await?; stream.writer.write(Bytes::from(payload)).await?; stream.writer.finish(); @@ -94,12 +94,9 @@ impl RpcHandle { } } -async fn read_all(mut reader: ByteReader) -> Result, QlError> { +async fn read_all(mut reader: ByteReader) -> Result, QlStreamError> { let mut bytes = Vec::new(); - while let Some(chunk) = poll_fn(|cx| reader.poll_read_chunk(cx)) - .await - .map_err(QlError::from)? - { + while let Some(chunk) = poll_fn(|cx| reader.poll_read_chunk(cx)).await? { bytes.extend_from_slice(&chunk); } Ok(bytes) diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs index ca768de9..cda35908 100644 --- a/ql-runtime/src/rpc/request_with_progress.rs +++ b/ql-runtime/src/rpc/request_with_progress.rs @@ -75,7 +75,7 @@ where } Poll::Ready(Err(error)) => { this.reader = None; - this.terminal = Some(Err(RpcCallError::Runtime(error.into()))); + this.terminal = Some(Err(error.into())); return Poll::Ready(None); } Poll::Pending => return Poll::Pending, @@ -126,7 +126,7 @@ where } Poll::Ready(Err(error)) => { this.reader = None; - return Poll::Ready(Err(RpcCallError::Runtime(error.into()))); + return Poll::Ready(Err(error.into())); } Poll::Pending => return Poll::Pending, } diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 13f1596b..10648172 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -65,7 +65,7 @@ where } Poll::Ready(Err(error)) => { this.reader = None; - return Poll::Ready(Some(Err(RpcCallError::Runtime(error.into())))); + return Poll::Ready(Some(Err(error.into()))); } Poll::Pending => return Poll::Pending, } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 648c85bf..ac250e52 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -49,7 +49,7 @@ async fn opening_stream_requires_connection() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); assert!(matches!( handle_a.open_stream().await, - Err(QlError::NoSession) + Err(NoSessionError) )); }) .await; diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 71de3b59..412c9393 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -59,7 +59,7 @@ async fn session_timeout_disconnects_and_fails_pending_open() { tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) .await .unwrap(); - assert!(matches!(result, Err(QlStreamError::SessionClosed))); + assert!(matches!(result, Err(QlStreamError::NoSession))); responder_task.abort(); }) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index e3ae9c55..61602f60 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -24,7 +24,7 @@ use tokio::{task::LocalSet, time::Sleep}; use crate::{ new_runtime, platform::{PlatformFuture, QlTimer}, - QlError, QlFsmConfig, QlStream, QlStreamError, RuntimeConfig, RuntimeHandle, + NoSessionError, QlFsmConfig, QlStream, QlStreamError, RuntimeConfig, RuntimeHandle, }; mod handshake; From 7e647a3560079e3163a642583615c5d96673e5fa Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 12:46:26 -0400 Subject: [PATCH 149/304] ql-runtime: remove driver xid --- ql-runtime/src/driver/mod.rs | 10 +--------- ql-runtime/src/driver/state.rs | 3 +-- ql-runtime/src/driver/test.rs | 1 - 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 24df30ff..87661ce8 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -35,9 +35,7 @@ impl Runtime

{ } = self; let mut fsm = QlFsm::new(config.fsm, identity, now()); - let mut peer_xid = None; if let Some(peer) = platform.load_peer().await { - peer_xid = Some(peer.xid); fsm.bind_peer(peer); } @@ -45,7 +43,6 @@ impl Runtime

{ streams: HashMap::new(), runtime_tx: tx, max_concurrent_message_writes: config.max_concurrent_message_writes, - peer_xid, pending_fsm_events: VecDeque::new(), }; let mut in_flight = Vec::new(); @@ -136,7 +133,6 @@ impl DriverState { ) { match command { RuntimeCommand::BindPeer { peer } => { - self.peer_xid = Some(peer.xid); fsm.bind_peer(peer); } RuntimeCommand::Connect => { @@ -263,15 +259,11 @@ impl DriverState { match event { QlFsmEvent::NewPeer => { if let Some(peer) = fsm.peer().cloned() { - self.peer_xid = Some(peer.xid); platform.persist_peer(peer); } } QlFsmEvent::PeerStatusChanged(status) => { - if self.peer_xid.is_none() { - self.peer_xid = fsm.peer().map(|peer| peer.xid); - } - if let Some(peer) = self.peer_xid { + if let Some(peer) = fsm.peer().map(|peer| peer.xid) { platform.handle_peer_status(peer, status); } } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 0a522874..0279d66f 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, VecDeque}; use bytes::Bytes; use ql_fsm::QlFsmEvent; -use ql_wire::{CloseTarget, StreamId, XID}; +use ql_wire::{CloseTarget, StreamId}; use crate::{ chunk_slot::{ChunkSlotRx, ChunkSlotTx, TrySendError}, @@ -14,7 +14,6 @@ pub struct DriverState { pub streams: HashMap, pub runtime_tx: async_channel::WeakSender, pub max_concurrent_message_writes: usize, - pub peer_xid: Option, pub pending_fsm_events: VecDeque, } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index a459aadb..68198132 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -114,7 +114,6 @@ fn new_driver_state() -> (DriverState, QlFsm) { streams: HashMap::new(), runtime_tx: runtime_tx.downgrade(), max_concurrent_message_writes: 1, - peer_xid: None, pending_fsm_events: VecDeque::new(), }, QlFsm::new(ql_fsm::QlFsmConfig::default(), new_identity(7), now()), From d5929299b12d855f3c25e985954c07ca636ffe08 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 13:01:00 -0400 Subject: [PATCH 150/304] ql: finish stream on writer --- ql-fsm/src/implementation/core.rs | 5 --- ql-fsm/src/lib.rs | 5 --- ql-fsm/src/session/mod.rs | 20 +++-------- ql-fsm/src/tests/proptest.rs | 6 ++-- ql-fsm/src/tests/session.rs | 8 +++-- ql-runtime/src/driver/mod.rs | 58 +++++++++++++++++-------------- 6 files changed, 47 insertions(+), 55 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index eaf82551..f3d92fa9 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -180,11 +180,6 @@ pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Option .and_then(|state| state.session.stream_available_bytes(stream_id)) } -pub fn finish_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result<(), StreamError> { - let state = fsm.state.link.connected_mut_or_err()?; - state.session.finish_stream(stream_id) -} - pub fn close_stream( fsm: &mut QlFsm, stream_id: StreamId, diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 05b3ef20..e6de3fb8 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -278,11 +278,6 @@ impl QlFsm { implementation::stream_available_bytes(self, stream_id) } - /// marks the local write side as finished - pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { - implementation::finish_stream(self, stream_id) - } - /// closes the origin lane, return lane, or both lanes of a stream pub fn close_stream( &mut self, diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index d1deb89c..7b9c9f6a 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -94,6 +94,11 @@ impl StreamWriter<'_> { } accepted } + + pub fn finish(self) { + self.stream.tx.queue_fin(); + self.stream.outbound_state = OutboundState::FinQueued; + } } pub struct SessionFsm { @@ -165,21 +170,6 @@ impl SessionFsm { }) } - pub fn finish_stream(&mut self, stream_id: StreamId) -> Result<(), StreamError> { - self.ensure_session_open()?; - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; - if !stream.is_writable() { - return Err(StreamError::NotWritable); - } - stream.tx.queue_fin(); - stream.outbound_state = OutboundState::FinQueued; - Ok(()) - } - pub fn close_stream( &mut self, stream_id: StreamId, diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 4aef024a..c6fd67f7 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -341,14 +341,16 @@ impl Runner { } Action::FinishA(slot) => { if let Some(stream_id) = self.slots_a[*slot] { - if self.harness.a.fsm.finish_stream(stream_id).is_ok() { + if let Ok(writer) = self.harness.a.fsm.write_stream(stream_id) { + writer.finish(); self.finished_by_a.insert(stream_id); } } } Action::FinishB(slot) => { if let Some(stream_id) = self.slots_b[*slot] { - if self.harness.b.fsm.finish_stream(stream_id).is_ok() { + if let Ok(writer) = self.harness.b.fsm.write_stream(stream_id) { + writer.finish(); self.finished_by_b.insert(stream_id); } } diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index c43924e5..8ccbeb50 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -45,7 +45,7 @@ fn connected_fsms_deliver_stream_data() { write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), 5 ); - harness.a.fsm.finish_stream(stream_id).unwrap(); + harness.a.fsm.write_stream(stream_id).unwrap().finish(); harness.pump(); @@ -176,7 +176,11 @@ fn disconnected_stream_operations_fail_with_no_session() { Err(StreamError::NoSession) ); assert_eq!( - harness.a.fsm.finish_stream(missing), + harness + .a + .fsm + .write_stream(missing) + .map(crate::session::StreamWriter::finish), Err(StreamError::NoSession) ); assert_eq!( diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 87661ce8..0f42d67e 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -3,7 +3,7 @@ mod state; mod test; use std::{ - collections::{HashMap, VecDeque}, + collections::{hash_map::Entry, HashMap, VecDeque}, future::Future, pin::Pin, task::Poll, @@ -457,36 +457,42 @@ impl DriverState { } fn poll_stream(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { - let should_finish = { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - let Some(reader) = stream.outbound_reader_mut() else { - return; - }; + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { + return; + }; + let stream = entry.get_mut(); + let Some(reader) = stream.outbound_reader_mut() else { + return; + }; - if reader.is_finished() { - true - } else { - let Ok(mut writer) = fsm.write_stream(stream_id) else { - return; - }; - let capacity = writer.capacity(); - if capacity > 0 { - if let Ok(Some(mut bytes)) = reader.try_recv(capacity) { - let _ = writer.write(&mut bytes); - } - } - reader.is_finished() + if reader.is_finished() { + if let Ok(writer) = fsm.write_stream(stream_id) { + writer.finish(); + } + stream.outbound_close(); + if stream.is_closed() { + entry.remove(); } + return; + } + + let Ok(mut writer) = fsm.write_stream(stream_id) else { + return; }; - if should_finish { - let _ = fsm.finish_stream(stream_id); - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.outbound_close(); + let capacity = writer.capacity(); + if capacity > 0 { + if let Ok(Some(mut bytes)) = reader.try_recv(capacity) { + let _ = writer.write(&mut bytes); + } + } + + if reader.is_finished() { + writer.finish(); + stream.outbound_close(); + if stream.is_closed() { + entry.remove(); } - self.try_reap_stream(stream_id); } } From ac504941d3e080ec7d5581ccdaeb48ec58e23edb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 13:25:08 -0400 Subject: [PATCH 151/304] ql-fsm: cleanup --- ql-fsm/src/session/mod.rs | 42 +++++------------------------ ql-fsm/src/session/state.rs | 8 +++++- ql-fsm/src/session/stream_writer.rs | 34 +++++++++++++++++++++++ 3 files changed, 48 insertions(+), 36 deletions(-) create mode 100644 ql-fsm/src/session/stream_writer.rs diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 7b9c9f6a..f3f7453c 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod state; pub(crate) mod stream_parity; pub(crate) mod stream_rx; pub(crate) mod stream_tx; +mod stream_writer; pub(crate) mod tracked; #[cfg(test)] @@ -20,10 +21,14 @@ use ql_wire::{ WireError, }; +pub use self::stream_writer::StreamWriter; use self::{ received_records::{ReceiveOutcome, ReceivedRecords}, remote_stream_history::RemoteStreamHistory, - state::{AckState, InboundState, OutboundState, SessionFsmState, StreamRole, StreamState}, + state::{ + AckState, InboundState, OutboundState, SessionFsmState, SessionState, StreamRole, + StreamState, + }, stream_parity::StreamParity, stream_rx::{StreamReadIter, StreamRxError}, stream_tx::StreamTxRange, @@ -71,36 +76,6 @@ pub enum SessionEvent { SessionClosed(SessionClose), } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SessionState { - Open, - Closed, -} - -pub struct StreamWriter<'a> { - stream: &'a mut StreamState, - send_buffer_size: usize, -} - -impl StreamWriter<'_> { - pub fn capacity(&self) -> usize { - self.stream.send_capacity(self.send_buffer_size) - } - - pub fn write(&mut self, bytes: &mut Bytes) -> usize { - let accepted = bytes.len().min(self.capacity()); - if accepted > 0 { - self.stream.tx.append(bytes.split_to(accepted)); - } - accepted - } - - pub fn finish(self) { - self.stream.tx.queue_fin(); - self.stream.outbound_state = OutboundState::FinQueued; - } -} - pub struct SessionFsm { config: SessionFsmConfig, state: SessionFsmState, @@ -164,10 +139,7 @@ impl SessionFsm { return Err(StreamError::NotWritable); } - Ok(StreamWriter { - stream, - send_buffer_size, - }) + Ok(StreamWriter::new(stream, send_buffer_size)) } pub fn close_stream( diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 54f4d584..c8f6dd6b 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -5,7 +5,7 @@ use ql_wire::{CloseTarget, RecordSeq, SessionClose, StreamClose, StreamId}; use super::{ received_records::ReceivedRecords, remote_stream_history::RemoteStreamHistory, - stream_rx::StreamRx, stream_tx::StreamTx, tracked::TrackedRecord, SessionState, + stream_rx::StreamRx, stream_tx::StreamTx, tracked::TrackedRecord, }; pub struct SessionFsmState { @@ -25,6 +25,12 @@ pub struct SessionFsmState { pub remote_stream_history: RemoteStreamHistory, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionState { + Open, + Closed, +} + #[derive(Debug)] pub struct StreamState { pub role: StreamRole, diff --git a/ql-fsm/src/session/stream_writer.rs b/ql-fsm/src/session/stream_writer.rs new file mode 100644 index 00000000..1a12e471 --- /dev/null +++ b/ql-fsm/src/session/stream_writer.rs @@ -0,0 +1,34 @@ +use bytes::Bytes; + +use super::state::{OutboundState, StreamState}; + +pub struct StreamWriter<'a> { + stream: &'a mut StreamState, + send_buffer_size: usize, +} + +impl<'a> StreamWriter<'a> { + pub(super) fn new(stream: &'a mut StreamState, send_buffer_size: usize) -> Self { + Self { + stream, + send_buffer_size, + } + } + + pub fn capacity(&self) -> usize { + self.stream.send_capacity(self.send_buffer_size) + } + + pub fn write(&mut self, bytes: &mut Bytes) -> usize { + let accepted = bytes.len().min(self.capacity()); + if accepted > 0 { + self.stream.tx.append(bytes.split_to(accepted)); + } + accepted + } + + pub fn finish(self) { + self.stream.tx.queue_fin(); + self.stream.outbound_state = OutboundState::FinQueued; + } +} From 2a9177256afb0007ee3c6d279c409bed66de01bb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 7 Apr 2026 14:08:04 -0400 Subject: [PATCH 152/304] ql-fsm: stream ops --- ql-fsm/src/error.rs | 13 ++- ql-fsm/src/implementation/core.rs | 41 +------ ql-fsm/src/lib.rs | 44 ++------ ql-fsm/src/session/mod.rs | 109 +++++-------------- ql-fsm/src/session/state.rs | 6 +- ql-fsm/src/session/stream_ops.rs | 131 ++++++++++++++++++++++ ql-fsm/src/session/stream_writer.rs | 34 ------ ql-fsm/src/session/tests.rs | 52 ++++----- ql-fsm/src/tests/proptest.rs | 79 +++++++------- ql-fsm/src/tests/session.rs | 85 ++++++++++----- ql-runtime/src/driver/mod.rs | 163 ++++++++++++++-------------- ql-runtime/src/driver/test.rs | 4 +- 12 files changed, 390 insertions(+), 371 deletions(-) create mode 100644 ql-fsm/src/session/stream_ops.rs delete mode 100644 ql-fsm/src/session/stream_writer.rs diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 2e2e929d..b5706503 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -65,7 +65,6 @@ impl Error for NoSessionError {} pub enum StreamError { MissingStream, NotWritable, - InvalidRead, NoSession, } @@ -74,7 +73,6 @@ impl Display for StreamError { let message = match self { Self::MissingStream => "missing stream", Self::NotWritable => "stream is not writable", - Self::InvalidRead => "invalid read commit", Self::NoSession => "no session", }; f.write_str(message) @@ -88,3 +86,14 @@ impl From for StreamError { Self::NoSession } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CommitReadError; + +impl Display for CommitReadError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "invalid read commit") + } +} + +impl Error for CommitReadError {} diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index f3d92fa9..5db7bd51 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -2,12 +2,12 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use ql_wire::{ - self as wire, CloseTarget, QlCrypto, SessionCloseCode, StreamCloseCode, StreamId, WireDecode, + self as wire, QlCrypto, SessionCloseCode, StreamId, WireDecode, }; use crate::{ session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, QlFsmError, - QlFsmEvent, SessionWriteId, StreamError, StreamReadIter, StreamWriter, + QlFsmEvent, SessionWriteId, StreamError, StreamOps, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { @@ -149,45 +149,14 @@ pub fn kill_session(fsm: &mut QlFsm, _code: SessionCloseCode) { fsm.state.link = crate::state::LinkState::Idle; } -pub fn open_stream(fsm: &mut QlFsm) -> Result { +pub fn open_stream(fsm: &mut QlFsm) -> Result, NoSessionError> { let state = fsm.state.link.connected_mut_or_err()?; state.session.open_stream() } -pub fn write_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { +pub fn stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { let state = fsm.state.link.connected_mut_or_err()?; - state.session.write_stream(stream_id) -} - -pub fn stream_read(fsm: &QlFsm, stream_id: StreamId) -> Option> { - let state = fsm.state.link.connected()?; - state.session.stream_read(stream_id) -} - -pub fn stream_read_commit( - fsm: &mut QlFsm, - stream_id: StreamId, - len: usize, -) -> Result<(), StreamError> { - let state = fsm.state.link.connected_mut_or_err()?; - state.session.stream_read_commit(stream_id, len) -} - -pub fn stream_available_bytes(fsm: &QlFsm, stream_id: StreamId) -> Option { - fsm.state - .link - .connected() - .and_then(|state| state.session.stream_available_bytes(stream_id)) -} - -pub fn close_stream( - fsm: &mut QlFsm, - stream_id: StreamId, - target: CloseTarget, - code: StreamCloseCode, -) -> Result<(), StreamError> { - let state = fsm.state.link.connected_mut_or_err()?; - state.session.close_stream(stream_id, target, code) + state.session.stream(stream_id) } pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index e6de3fb8..cc7bf309 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -3,7 +3,7 @@ //! a caller drives `QlFsm` inside its own event loop //! //! inputs to that loop usually include -//! - app actions like `bind_peer`, `connect_ik`, `connect_kk`, `open_stream`, or `write_stream` +//! - app actions like `bind_peer`, `connect_ik`, `connect_kk`, `open_stream`, or `stream` //! - inbound transport bytes passed to `receive` //! - a deadline expiring, handled by calling `on_timer` //! - transport write results passed to `confirm_session_write` or `reject_session_write` @@ -30,10 +30,9 @@ use std::time::{Duration, Instant}; pub use bytes::Bytes; pub use error::*; use ql_wire::{ - CloseTarget, PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, - StreamCloseCode, StreamId, + PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, StreamId, }; -pub use session::{stream_rx::StreamReadIter, StreamWriter}; +pub use session::{stream_rx::StreamReadIter, StreamOps, StreamWriter}; use crate::{ replay_cache::ReplayCache, @@ -250,42 +249,13 @@ impl QlFsm { } /// opens a new outgoing stream - pub fn open_stream(&mut self) -> Result { + pub fn open_stream(&mut self) -> Result, NoSessionError> { implementation::open_stream(self) } - /// returns a writer for an open stream - pub fn write_stream(&mut self, stream_id: StreamId) -> Result, StreamError> { - implementation::write_stream(self, stream_id) - } - - /// returns the readable stream bytes as owned `Bytes` views without consuming them - pub fn stream_read(&self, stream_id: StreamId) -> Option> { - implementation::stream_read(self, stream_id) - } - - /// marks previously read bytes as consumed - pub fn stream_read_commit( - &mut self, - stream_id: StreamId, - len: usize, - ) -> Result<(), StreamError> { - implementation::stream_read_commit(self, stream_id, len) - } - - /// returns how many bytes can be read from a stream - pub fn stream_available_bytes(&self, stream_id: StreamId) -> Option { - implementation::stream_available_bytes(self, stream_id) - } - - /// closes the origin lane, return lane, or both lanes of a stream - pub fn close_stream( - &mut self, - stream_id: StreamId, - target: CloseTarget, - code: StreamCloseCode, - ) -> Result<(), StreamError> { - implementation::close_stream(self, stream_id, target, code) + /// returns a facade for an open stream + pub fn stream(&mut self, stream_id: StreamId) -> Result, StreamError> { + implementation::stream(self, stream_id) } /// queues a ping on the active session diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index f3f7453c..6467aba2 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -2,10 +2,10 @@ pub(crate) mod range_set; pub(crate) mod received_records; pub(crate) mod remote_stream_history; pub(crate) mod state; +mod stream_ops; pub(crate) mod stream_parity; pub(crate) mod stream_rx; pub(crate) mod stream_tx; -mod stream_writer; pub(crate) mod tracked; #[cfg(test)] @@ -17,11 +17,10 @@ use bytes::Bytes; use indexmap::{map::Entry, IndexMap}; use ql_wire::{ CloseTarget, RecordAck, RecordSeq, SessionClose, SessionCloseCode, SessionFrame, - SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamId, StreamWindow, VarInt, - WireError, + SessionRecordBuilder, StreamClose, StreamData, StreamId, StreamWindow, VarInt, WireError, }; -pub use self::stream_writer::StreamWriter; +pub use self::stream_ops::*; use self::{ received_records::{ReceiveOutcome, ReceivedRecords}, remote_stream_history::RemoteStreamHistory, @@ -30,7 +29,7 @@ use self::{ StreamState, }, stream_parity::StreamParity, - stream_rx::{StreamReadIter, StreamRxError}, + stream_rx::StreamRxError, stream_tx::StreamTxRange, tracked::{TrackedFrame, TrackedRecord, TrackedStreamData}, }; @@ -109,7 +108,7 @@ impl SessionFsm { } } - pub fn open_stream(&mut self) -> Result { + pub fn open_stream(&mut self) -> Result, NoSessionError> { self.ensure_session_open()?; let stream_id = self .config @@ -124,77 +123,17 @@ impl SessionFsm { self.config.initial_peer_stream_receive_window, ), ); - Ok(stream_id) + let stream_index = self.state.streams.len() - 1; + Ok(StreamOps::new(self, stream_id, stream_index)) } - pub fn write_stream(&mut self, stream_id: StreamId) -> Result, StreamError> { + pub fn stream(&mut self, stream_id: StreamId) -> Result, StreamError> { self.ensure_session_open()?; - let send_buffer_size = self.config.stream_send_buffer_size; - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; - if !stream.is_writable() { - return Err(StreamError::NotWritable); - } - - Ok(StreamWriter::new(stream, send_buffer_size)) - } - - pub fn close_stream( - &mut self, - stream_id: StreamId, - target: CloseTarget, - code: StreamCloseCode, - ) -> Result<(), StreamError> { - self.ensure_session_open()?; - { - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; - Self::apply_local_close_to_stream(stream, target); - stream.pending_close = Some(StreamClose { - stream_id, - target, - code, - }); - } - self.try_reap_stream(stream_id); - Ok(()) - } - - pub fn stream_read(&self, stream_id: StreamId) -> Option> { - let stream = self.state.streams.get(&stream_id)?; - Some(stream.rx.bytes()) - } - - pub fn stream_read_commit( - &mut self, - stream_id: StreamId, - len: usize, - ) -> Result<(), StreamError> { - let stream = self - .state - .streams - .get_mut(&stream_id) - .ok_or(StreamError::MissingStream)?; - if len > stream.readable_bytes() { - return Err(StreamError::InvalidRead); - } - stream.rx.consume(len); - if stream.recv_limit() > stream.advertised_max_offset { - stream.pending_window = true; - } - self.try_reap_stream(stream_id); - Ok(()) - } + let Some(stream_index) = self.state.streams.get_index_of(&stream_id) else { + return Err(StreamError::MissingStream); + }; - pub fn stream_available_bytes(&self, stream_id: StreamId) -> Option { - let stream = self.state.streams.get(&stream_id)?; - Some(stream.readable_bytes()) + Ok(StreamOps::new(self, stream_id, stream_index)) } pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { @@ -856,19 +795,25 @@ impl SessionFsm { } fn try_reap_stream(&mut self, stream_id: StreamId) { - let should_reap = self - .state - .streams - .get(&stream_id) - .is_some_and(|stream| self.stream_is_reapable(stream_id, stream)); - if !should_reap { + let Some(index) = self.state.streams.get_index_of(&stream_id) else { return; - } + }; + self.try_reap_stream_at(stream_id, index); + } - let Some(index) = self.state.streams.get_index_of(&stream_id) else { + fn try_reap_stream_at(&mut self, stream_id: StreamId, index: usize) { + let Some((indexed_stream_id, stream)) = self.state.streams.get_index(index) else { return; }; - self.state.streams.shift_remove(&stream_id); + debug_assert_eq!(*indexed_stream_id, stream_id); + if !self.stream_is_reapable(stream_id, stream) { + return; + } + self.reap_stream_at(index); + } + + fn reap_stream_at(&mut self, index: usize) { + self.state.streams.shift_remove_index(index); if self.state.streams.is_empty() { self.state.next_stream_index = 0; diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index c8f6dd6b..52b98134 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -68,12 +68,8 @@ impl StreamState { matches!(self.outbound_state, OutboundState::Open) } - pub fn buffered_send_bytes(&self) -> usize { - self.tx.buffered_len() - } - pub fn send_capacity(&self, send_buffer_size: usize) -> usize { - send_buffer_size.saturating_sub(self.buffered_send_bytes()) + send_buffer_size.saturating_sub(self.tx.buffered_len()) } pub fn readable_bytes(&self) -> usize { diff --git a/ql-fsm/src/session/stream_ops.rs b/ql-fsm/src/session/stream_ops.rs new file mode 100644 index 00000000..8e4f1ff7 --- /dev/null +++ b/ql-fsm/src/session/stream_ops.rs @@ -0,0 +1,131 @@ +use ql_wire::{CloseTarget, StreamClose, StreamCloseCode, StreamId}; + +use super::{state::StreamState, stream_rx::StreamReadIter, SessionFsm}; +use crate::CommitReadError; + +pub struct StreamOps<'a> { + session: &'a mut SessionFsm, + stream_id: StreamId, + stream_index: usize, + reap_on_drop: bool, +} + +impl<'a> StreamOps<'a> { + pub(super) fn new( + session: &'a mut SessionFsm, + stream_id: StreamId, + stream_index: usize, + ) -> Self { + Self { + session, + stream_id, + stream_index, + reap_on_drop: false, + } + } + + /// returns this stream's identifier + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + /// returns the readable stream bytes as owned `Bytes` views without consuming them + pub fn read(&self) -> StreamReadIter<'_> { + self.stream().rx.bytes() + } + + /// returns how many bytes can be read from the stream + pub fn readable_bytes(&self) -> usize { + self.stream().readable_bytes() + } + + /// marks previously read bytes as consumed + pub fn commit_read(&mut self, len: usize) -> Result<(), CommitReadError> { + let stream = self.stream_mut(); + if len > stream.readable_bytes() { + return Err(CommitReadError); + } + stream.rx.consume(len); + if stream.recv_limit() > stream.advertised_max_offset { + stream.pending_window = true; + } + self.reap_on_drop = true; + Ok(()) + } + + /// returns a writer if the local write side is still open + pub fn writer(&mut self) -> Option> { + let send_buffer_size = self.session.config.stream_send_buffer_size; + let stream = self.stream_mut(); + if !stream.is_writable() { + return None; + } + Some(StreamWriter::new(stream, send_buffer_size)) + } + + /// closes the origin lane, return lane, or both lanes of the stream + pub fn close(&mut self, target: CloseTarget, code: StreamCloseCode) { + let stream_id = self.stream_id; + let stream = self.stream_mut(); + SessionFsm::apply_local_close_to_stream(stream, target); + stream.pending_close = Some(StreamClose { + stream_id, + target, + code, + }); + self.reap_on_drop = true; + } + + fn stream(&self) -> &StreamState { + &self.session.state.streams[self.stream_index] + } + + fn stream_mut(&mut self) -> &mut StreamState { + &mut self.session.state.streams[self.stream_index] + } +} + +impl Drop for StreamOps<'_> { + fn drop(&mut self) { + if !self.reap_on_drop { + return; + } + + self.session + .try_reap_stream_at(self.stream_id, self.stream_index); + } +} + +pub struct StreamWriter<'a> { + stream: &'a mut StreamState, + send_buffer_size: usize, +} + +impl<'a> StreamWriter<'a> { + pub(super) fn new(stream: &'a mut StreamState, send_buffer_size: usize) -> Self { + Self { + stream, + send_buffer_size, + } + } + + /// returns how many bytes can still be buffered for local writes + pub fn capacity(&self) -> usize { + self.stream.send_capacity(self.send_buffer_size) + } + + /// appends as many bytes as possible and returns the accepted count + pub fn write(&mut self, bytes: &mut bytes::Bytes) -> usize { + let accepted = bytes.len().min(self.capacity()); + if accepted > 0 { + self.stream.tx.append(bytes.split_to(accepted)); + } + accepted + } + + /// marks the local write side as finished + pub fn finish(self) { + self.stream.tx.queue_fin(); + self.stream.outbound_state = super::state::OutboundState::FinQueued; + } +} diff --git a/ql-fsm/src/session/stream_writer.rs b/ql-fsm/src/session/stream_writer.rs deleted file mode 100644 index 1a12e471..00000000 --- a/ql-fsm/src/session/stream_writer.rs +++ /dev/null @@ -1,34 +0,0 @@ -use bytes::Bytes; - -use super::state::{OutboundState, StreamState}; - -pub struct StreamWriter<'a> { - stream: &'a mut StreamState, - send_buffer_size: usize, -} - -impl<'a> StreamWriter<'a> { - pub(super) fn new(stream: &'a mut StreamState, send_buffer_size: usize) -> Self { - Self { - stream, - send_buffer_size, - } - } - - pub fn capacity(&self) -> usize { - self.stream.send_capacity(self.send_buffer_size) - } - - pub fn write(&mut self, bytes: &mut Bytes) -> usize { - let accepted = bytes.len().min(self.capacity()); - if accepted > 0 { - self.stream.tx.append(bytes.split_to(accepted)); - } - accepted - } - - pub fn finish(self) { - self.stream.tx.queue_fin(); - self.stream.outbound_state = OutboundState::FinQueued; - } -} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 293150cb..05538f14 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -21,25 +21,21 @@ fn offset(value: u64) -> VarInt { VarInt::from_u64(value).unwrap() } +fn open_stream_id(fsm: &mut SessionFsm) -> StreamId { + fsm.open_stream().unwrap().stream_id() +} + fn write_stream_bytes(fsm: &mut SessionFsm, stream_id: StreamId, bytes: &[u8]) -> usize { let mut bytes = Bytes::copy_from_slice(bytes); - let mut writer = fsm.write_stream(stream_id).unwrap(); + let mut stream = fsm.stream(stream_id).unwrap(); + let mut writer = stream.writer().unwrap(); writer.write(&mut bytes) } fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { - let mut out = Vec::new(); - loop { - let mut read = 0; - for chunk in fsm.stream_read(stream_id).unwrap() { - out.extend_from_slice(&chunk); - read += chunk.len(); - } - if read == 0 { - break; - } - fsm.stream_read_commit(stream_id, read).unwrap(); - } + let mut stream = fsm.stream(stream_id).unwrap(); + let out = stream.read().flatten().collect::>(); + stream.commit_read(out.len()).unwrap(); out } @@ -78,7 +74,7 @@ fn receive_events( fn outbound_record_seq_increments_monotonically() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut fsm); assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"one"), 3); let (first_seq, _) = next_outbound(&mut fsm, now).unwrap(); @@ -94,7 +90,7 @@ fn outbound_record_seq_increments_monotonically() { fn retransmit_uses_new_record_seq() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut fsm); assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"retry"), 5); let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); @@ -116,8 +112,8 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { }, now, ); - let stream_id_a = fsm.open_stream().unwrap(); - let stream_id_b = fsm.open_stream().unwrap(); + let stream_id_a = open_stream_id(&mut fsm); + let stream_id_b = open_stream_id(&mut fsm); let payload_a = vec![b'a'; 40]; let payload_b = vec![b'b'; 40]; @@ -154,7 +150,7 @@ fn ack_reopens_write_capacity() { }, now, ); - let stream_id = fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut fsm); assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"abcd"), 4); let (record_seq, _record) = next_outbound(&mut fsm, now).unwrap(); @@ -207,15 +203,16 @@ fn commit_stream_read_is_what_advances_stream_window() { assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); let read = fsm - .stream_read(stream_id) + .stream(stream_id) .unwrap() + .read() .map(|chunk| chunk.len()) .sum::(); assert_eq!(read, 2); assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); - fsm.stream_read_commit(stream_id, 2).unwrap(); + fsm.stream(stream_id).unwrap().commit_read(2).unwrap(); let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); assert!(matches!( second.as_slice(), @@ -281,10 +278,11 @@ fn inbound_stream_data_emits_opened_and_readable() { fn remote_stream_close_is_reliable_and_retried() { let now = Instant::now(); let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); - let stream_id = fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut fsm); - fsm.close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)) - .unwrap(); + fsm.stream(stream_id) + .unwrap() + .close(CloseTarget::Both, StreamCloseCode(0)); let (write_id, builder) = fsm.take_next_write(now).unwrap(); fsm.confirm_write(now, write_id.expect("stream close should be tracked")); @@ -314,7 +312,8 @@ fn stream_ids_follow_even_odd_xid_ordering() { now, ) .open_stream() - .unwrap(); + .unwrap() + .stream_id(); let odd_id = SessionFsm::new( SessionFsmConfig { local_parity: odd, @@ -323,7 +322,8 @@ fn stream_ids_follow_even_odd_xid_ordering() { now, ) .open_stream() - .unwrap(); + .unwrap() + .stream_id(); assert_eq!(even_id.into_inner() % 2, 0); assert_eq!(odd_id.into_inner() % 2, 1); @@ -471,7 +471,7 @@ fn initial_peer_stream_receive_window_limits_first_send() { }, now, ); - let stream_id = fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut fsm); assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"hello"), 5); let (_first_seq, first) = next_outbound(&mut fsm, now).unwrap(); diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index c6fd67f7..841b0017 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -304,13 +304,15 @@ impl Runner { let _ = take_pending(&mut self.pending_b_to_a, *index); } Action::OpenStreamA(slot) => { - if let Ok(stream_id) = self.harness.a.fsm.open_stream() { + if let Ok(stream) = self.harness.a.fsm.open_stream() { + let stream_id = stream.stream_id(); self.slots_a[*slot] = Some(stream_id); self.known_streams.insert(stream_id); } } Action::OpenStreamB(slot) => { - if let Ok(stream_id) = self.harness.b.fsm.open_stream() { + if let Ok(stream) = self.harness.b.fsm.open_stream() { + let stream_id = stream.stream_id(); self.slots_b[*slot] = Some(stream_id); self.known_streams.insert(stream_id); } @@ -318,52 +320,55 @@ impl Runner { Action::WriteA { slot, bytes } => { if let Some(stream_id) = self.slots_a[*slot] { let mut chunk = Bytes::copy_from_slice(bytes); - if let Ok(mut writer) = self.harness.a.fsm.write_stream(stream_id) { - let accepted = writer.write(&mut chunk); - self.expected_at_b - .entry(stream_id) - .or_default() - .extend_from_slice(&bytes[..accepted]); + if let Ok(mut stream) = self.harness.a.fsm.stream(stream_id) { + if let Some(mut writer) = stream.writer() { + let accepted = writer.write(&mut chunk); + self.expected_at_b + .entry(stream_id) + .or_default() + .extend_from_slice(&bytes[..accepted]); + } } } } Action::WriteB { slot, bytes } => { if let Some(stream_id) = self.slots_b[*slot] { let mut chunk = Bytes::copy_from_slice(bytes); - if let Ok(mut writer) = self.harness.b.fsm.write_stream(stream_id) { - let accepted = writer.write(&mut chunk); - self.expected_at_a - .entry(stream_id) - .or_default() - .extend_from_slice(&bytes[..accepted]); + if let Ok(mut stream) = self.harness.b.fsm.stream(stream_id) { + if let Some(mut writer) = stream.writer() { + let accepted = writer.write(&mut chunk); + self.expected_at_a + .entry(stream_id) + .or_default() + .extend_from_slice(&bytes[..accepted]); + } } } } Action::FinishA(slot) => { if let Some(stream_id) = self.slots_a[*slot] { - if let Ok(writer) = self.harness.a.fsm.write_stream(stream_id) { - writer.finish(); - self.finished_by_a.insert(stream_id); + if let Ok(mut stream) = self.harness.a.fsm.stream(stream_id) { + if let Some(writer) = stream.writer() { + writer.finish(); + self.finished_by_a.insert(stream_id); + } } } } Action::FinishB(slot) => { if let Some(stream_id) = self.slots_b[*slot] { - if let Ok(writer) = self.harness.b.fsm.write_stream(stream_id) { - writer.finish(); - self.finished_by_b.insert(stream_id); + if let Ok(mut stream) = self.harness.b.fsm.stream(stream_id) { + if let Some(writer) = stream.writer() { + writer.finish(); + self.finished_by_b.insert(stream_id); + } } } } Action::CloseA(slot) => { if let Some(stream_id) = self.slots_a[*slot] { - if self - .harness - .a - .fsm - .close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)) - .is_ok() - { + if let Ok(mut stream) = self.harness.a.fsm.stream(stream_id) { + stream.close(CloseTarget::Both, StreamCloseCode(0)); self.closed_by_a.insert(stream_id); self.slots_a[*slot] = None; } @@ -371,13 +376,8 @@ impl Runner { } Action::CloseB(slot) => { if let Some(stream_id) = self.slots_b[*slot] { - if self - .harness - .b - .fsm - .close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)) - .is_ok() - { + if let Ok(mut stream) = self.harness.b.fsm.stream(stream_id) { + stream.close(CloseTarget::Both, StreamCloseCode(0)); self.closed_by_b.insert(stream_id); self.slots_b[*slot] = None; } @@ -903,14 +903,13 @@ fn take_taken(taken: &mut Vec, index: usize) -> Option { fn drain_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); + let Ok(mut stream) = fsm.stream(stream_id) else { + return out; + }; loop { - let Some(chunks) = fsm.stream_read(stream_id) else { - break; - }; - let mut read = 0usize; - for chunk in chunks { + for chunk in stream.read() { out.extend_from_slice(&chunk); read += chunk.len(); } @@ -919,7 +918,7 @@ fn drain_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { break; } - fsm.stream_read_commit(stream_id, read).unwrap(); + stream.commit_read(read).unwrap(); } out diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 8ccbeb50..c4259635 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -4,34 +4,46 @@ use bytes::Bytes; use ql_wire::{SessionClose, StreamId, VarInt}; use super::*; -use crate::{state::LinkState, NoSessionError, PeerStatus, QlFsmEvent, StreamError}; +use crate::{ + state::LinkState, CommitReadError, NoSessionError, PeerStatus, QlFsmEvent, StreamError, +}; fn stream_id(value: u32) -> StreamId { StreamId(VarInt::from_u32(value)) } +fn open_stream_id(fsm: &mut QlFsm) -> StreamId { + fsm.open_stream().unwrap().stream_id() +} + fn write_stream_bytes( fsm: &mut QlFsm, stream_id: StreamId, bytes: &[u8], ) -> Result { let mut bytes = Bytes::copy_from_slice(bytes); - let mut writer = fsm.write_stream(stream_id)?; + let mut stream = fsm.stream(stream_id)?; + let Some(mut writer) = stream.writer() else { + return Err(StreamError::NotWritable); + }; Ok(writer.write(&mut bytes)) } fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { let mut out = Vec::new(); + let Ok(mut stream) = fsm.stream(stream_id) else { + return out; + }; loop { let mut read = 0; - for chunk in fsm.stream_read(stream_id).unwrap() { + for chunk in stream.read() { out.extend_from_slice(&chunk); read += chunk.len(); } if read == 0 { break; } - fsm.stream_read_commit(stream_id, read).unwrap(); + stream.commit_read(read).unwrap(); } out } @@ -40,12 +52,12 @@ fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { fn connected_fsms_deliver_stream_data() { let mut harness = Harness::connected(QlFsmConfig::default()); - let stream_id = harness.a.fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut harness.a.fsm); assert_eq!( write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), 5 ); - harness.a.fsm.write_stream(stream_id).unwrap().finish(); + harness.a.fsm.stream(stream_id).unwrap().writer().unwrap().finish(); harness.pump(); @@ -69,7 +81,7 @@ fn session_retransmit_uses_new_record_seq() { let config = QlFsmConfig::default(); let mut harness = Harness::connected(config); - let stream_id = harness.a.fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut harness.a.fsm); assert_eq!( write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5 @@ -115,8 +127,8 @@ fn session_retransmit_uses_new_record_seq() { fn simultaneous_opens_use_even_and_odd_stream_ids() { let mut harness = Harness::connected(QlFsmConfig::default()); - let stream_id_a = harness.a.fsm.open_stream().unwrap(); - let stream_id_b = harness.b.fsm.open_stream().unwrap(); + let stream_id_a = open_stream_id(&mut harness.a.fsm); + let stream_id_b = open_stream_id(&mut harness.b.fsm); assert_ne!(stream_id_a, stream_id_b); assert!( @@ -170,7 +182,7 @@ fn disconnected_stream_operations_fail_with_no_session() { let mut harness = Harness::paired_known(QlFsmConfig::default()); let missing = stream_id(0); - assert_eq!(harness.a.fsm.open_stream(), Err(NoSessionError)); + assert!(matches!(harness.a.fsm.open_stream(), Err(NoSessionError))); assert_eq!( write_stream_bytes(&mut harness.a.fsm, missing, b"queued"), Err(StreamError::NoSession) @@ -179,39 +191,56 @@ fn disconnected_stream_operations_fail_with_no_session() { harness .a .fsm - .write_stream(missing) - .map(crate::session::StreamWriter::finish), + .stream(missing) + .map(|mut stream| stream.writer().unwrap().finish()), Err(StreamError::NoSession) ); assert_eq!( - harness.a.fsm.close_stream( - missing, - ql_wire::CloseTarget::Both, - ql_wire::StreamCloseCode(0) - ), + harness + .a + .fsm + .stream(missing) + .map(|mut stream| stream.close(ql_wire::CloseTarget::Both, ql_wire::StreamCloseCode(0))), Err(StreamError::NoSession) ); assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); - assert_eq!( - harness.a.fsm.stream_read_commit(missing, 1), + assert!(matches!( + harness.a.fsm.stream(missing), Err(StreamError::NoSession) - ); + )); } #[test] fn disconnected_stream_read_accessors_return_none() { - let harness = Harness::paired_known(QlFsmConfig::default()); + let mut harness = Harness::paired_known(QlFsmConfig::default()); let missing = stream_id(0); - assert!(harness.a.fsm.stream_read(missing).is_none()); - assert!(harness.a.fsm.stream_available_bytes(missing).is_none()); + assert!(matches!( + harness.a.fsm.stream(missing), + Err(StreamError::NoSession) + )); +} + +#[test] +fn commit_read_rejects_lengths_past_readable_prefix() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hi").unwrap(), + 2 + ); + harness.pump(); + + let mut stream = harness.b.fsm.stream(stream_id).unwrap(); + assert_eq!(stream.commit_read(3), Err(CommitReadError)); } #[test] fn returned_session_write_is_reissued_with_new_record_seq() { let mut harness = Harness::connected(QlFsmConfig::default()); - let stream_id = harness.a.fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut harness.a.fsm); assert_eq!( write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5 @@ -254,7 +283,7 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { let config = QlFsmConfig::default(); let mut harness = Harness::connected(config); - let stream_id = harness.a.fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut harness.a.fsm); assert_eq!( write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), 5 @@ -290,7 +319,7 @@ fn ack_frame_releases_stream_capacity_and_emits_writable() { }; let mut harness = Harness::connected(config); - let stream_id = harness.a.fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut harness.a.fsm); assert_eq!( write_stream_bytes(&mut harness.a.fsm, stream_id, b"abcd").unwrap(), 4 @@ -331,7 +360,7 @@ fn session_records_contain_ack_frames_after_delivery() { let config = QlFsmConfig::default(); let mut harness = Harness::connected(config); - let stream_id = harness.a.fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut harness.a.fsm); assert_eq!( write_stream_bytes(&mut harness.a.fsm, stream_id, b"x").unwrap(), 1 @@ -370,7 +399,7 @@ fn first_stream_data_uses_negotiated_initial_peer_credit() { let ik2 = harness.next_outbound_b().unwrap(); harness.deliver_to_a(ik2); - let stream_id = harness.a.fsm.open_stream().unwrap(); + let stream_id = open_stream_id(&mut harness.a.fsm); assert_eq!( write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), 5 diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 0f42d67e..7814bce0 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -155,40 +155,41 @@ impl DriverState { return; }; - match fsm.open_stream() { - Ok(stream_id) => { - let (response_reader, response_writer) = chunk_slot::new(); - let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); - self.streams.insert( - stream_id, - DriverStreamIo::new( - true, - Some(OutboundIo::new(request_reader, request_terminal)), - Some(InboundIo::new(response_writer, response_terminal_tx)), - ), - ); - let reader = ByteReader::new( - stream_id, - CloseTarget::Return, - response_reader, - response_terminal_rx, - RuntimeHandle::new(runtime_tx), - ); - if start.send(Ok((stream_id, reader))).is_err() { - if let Some(stream) = self.streams.get_mut(&stream_id) { - stream.inbound_close(); - stream.outbound_close(); - } - let _ = - fsm.close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)); - return; - } - self.poll_stream(fsm, stream_id); - } + let mut stream_ops = match fsm.open_stream() { + Ok(stream_ops) => stream_ops, Err(error) => { let _ = start.send(Err(error)); + return; } + }; + let stream_id = stream_ops.stream_id(); + let (response_reader, response_writer) = chunk_slot::new(); + let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); + self.streams.insert( + stream_id, + DriverStreamIo::new( + true, + Some(OutboundIo::new(request_reader, request_terminal)), + Some(InboundIo::new(response_writer, response_terminal_tx)), + ), + ); + let reader = ByteReader::new( + stream_id, + CloseTarget::Return, + response_reader, + response_terminal_rx, + RuntimeHandle::new(runtime_tx), + ); + if start.send(Ok((stream_id, reader))).is_err() { + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.inbound_close(); + stream.outbound_close(); + } + stream_ops.close(CloseTarget::Both, StreamCloseCode(0)); + return; } + drop(stream_ops); + self.poll_stream(fsm, stream_id); } RuntimeCommand::PollInbound { stream_id } => { self.handle_inbound_readable(fsm, stream_id); @@ -209,7 +210,9 @@ impl DriverState { stream.outbound_close(); } } - let _ = fsm.close_stream(stream_id, target, code); + if let Ok(mut stream) = fsm.stream(stream_id) { + stream.close(target, code); + } self.try_reap_stream(stream_id); } } @@ -296,7 +299,9 @@ impl DriverState { stream_id: StreamId, ) { let Some(runtime_tx) = self.runtime_tx.upgrade() else { - let _ = fsm.close_stream(stream_id, CloseTarget::Both, StreamCloseCode(0)); + if let Ok(mut stream) = fsm.stream(stream_id) { + stream.close(CloseTarget::Both, StreamCloseCode(0)); + } return; }; @@ -334,59 +339,52 @@ impl DriverState { } fn handle_inbound_readable(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { - loop { - let Some(_) = fsm.stream_available_bytes(stream_id) else { + let Ok(mut stream_ops) = fsm.stream(stream_id) else { + return; + }; + if stream_ops.readable_bytes() == 0 { + return; + } + let mut accepted = 0usize; + let mut peer_closed = false; + let target; + { + let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; - let mut accepted = 0usize; - let mut blocked = false; - let mut peer_closed = false; - let target; - { - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - target = stream.inbound_target(); - let Some(chunks) = fsm.stream_read(stream_id) else { - return; - }; - for chunk in chunks { - if chunk.is_empty() { - continue; + target = stream.inbound_target(); + for chunk in stream_ops.read() { + if chunk.is_empty() { + continue; + } + match stream.inbound_try_write(chunk) { + InboundWriteResult::Accepted(n) => { + accepted += n; } - match stream.inbound_try_write(chunk) { - InboundWriteResult::Accepted(n) => { - accepted += n; - } - InboundWriteResult::Full => { - blocked = true; - break; - } - InboundWriteResult::Closed => { - peer_closed = true; - break; - } + InboundWriteResult::Full => { + break; + } + InboundWriteResult::Closed => { + peer_closed = true; + break; } } } + } - if accepted > 0 { - fsm.stream_read_commit(stream_id, accepted).unwrap(); - } - if peer_closed { - let _ = fsm.close_stream(stream_id, target, StreamCloseCode(0)); - self.try_reap_stream(stream_id); - break; - } - if accepted == 0 || blocked { - break; - } + if accepted > 0 { + stream_ops.commit_read(accepted).unwrap(); + } + if peer_closed { + stream_ops.close(target, StreamCloseCode(0)); + self.try_reap_stream(stream_id); } + drop(stream_ops); self.finish_inbound_if_ready(fsm, stream_id); } - fn handle_inbound_finished(&mut self, fsm: &QlFsm, stream_id: StreamId) { + fn handle_inbound_finished(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; @@ -394,9 +392,11 @@ impl DriverState { self.finish_inbound_if_ready(fsm, stream_id); } - fn finish_inbound_if_ready(&mut self, fsm: &QlFsm, stream_id: StreamId) { - if fsm.stream_available_bytes(stream_id).unwrap_or(0) != 0 { - return; + fn finish_inbound_if_ready(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + if let Ok(stream_ops) = fsm.stream(stream_id) { + if stream_ops.readable_bytes() != 0 { + return; + } } let Some(stream) = self.streams.get_mut(&stream_id) else { @@ -466,8 +466,10 @@ impl DriverState { }; if reader.is_finished() { - if let Ok(writer) = fsm.write_stream(stream_id) { - writer.finish(); + if let Ok(mut stream_ops) = fsm.stream(stream_id) { + if let Some(writer) = stream_ops.writer() { + writer.finish(); + } } stream.outbound_close(); if stream.is_closed() { @@ -476,7 +478,10 @@ impl DriverState { return; } - let Ok(mut writer) = fsm.write_stream(stream_id) else { + let Ok(mut stream_ops) = fsm.stream(stream_id) else { + return; + }; + let Some(mut writer) = stream_ops.writer() else { return; }; diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 68198132..610c090b 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -135,7 +135,7 @@ fn new_outbound_io() -> OutboundIo { #[test] fn handle_inbound_finished_reaps_closed_initiator_stream() { - let (mut state, fsm) = new_driver_state(); + let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); state.streams.insert( @@ -143,7 +143,7 @@ fn handle_inbound_finished_reaps_closed_initiator_stream() { DriverStreamIo::new(true, None, Some(new_inbound_io(1))), ); - state.handle_inbound_finished(&fsm, stream_id); + state.handle_inbound_finished(&mut fsm, stream_id); assert!(!state.streams.contains_key(&stream_id)); } From cb3e54738f8e0e3e9eb5397fa775219b57cb0674 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 06:38:53 -0400 Subject: [PATCH 153/304] ql-fsm: private internal modules --- ql-fsm/src/implementation/handshake/mod.rs | 2 +- ql-fsm/src/lib.rs | 2 +- ql-fsm/src/session/mod.rs | 23 +++++++++++----------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 233ce82a..3529591b 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -5,7 +5,7 @@ use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHands use super::emit_peer_status; use crate::{ - session::{stream_parity::StreamParity, SessionFsm, SessionFsmConfig}, + session::{SessionFsm, SessionFsmConfig, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, QlFsm, QlFsmError, QlFsmEvent, }; diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index cc7bf309..9a4548c1 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -32,7 +32,7 @@ pub use error::*; use ql_wire::{ PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, StreamId, }; -pub use session::{stream_rx::StreamReadIter, StreamOps, StreamWriter}; +pub use session::{StreamOps, StreamReadIter, StreamWriter}; use crate::{ replay_cache::ReplayCache, diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 6467aba2..911e96bf 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,12 +1,14 @@ -pub(crate) mod range_set; -pub(crate) mod received_records; -pub(crate) mod remote_stream_history; -pub(crate) mod state; +pub use self::{stream_ops::*, stream_parity::*, stream_rx::*}; + +mod range_set; +mod received_records; +mod remote_stream_history; +mod state; mod stream_ops; -pub(crate) mod stream_parity; -pub(crate) mod stream_rx; -pub(crate) mod stream_tx; -pub(crate) mod tracked; +mod stream_parity; +mod stream_rx; +mod stream_tx; +mod tracked; #[cfg(test)] mod tests; @@ -20,7 +22,6 @@ use ql_wire::{ SessionRecordBuilder, StreamClose, StreamData, StreamId, StreamWindow, VarInt, WireError, }; -pub use self::stream_ops::*; use self::{ received_records::{ReceiveOutcome, ReceivedRecords}, remote_stream_history::RemoteStreamHistory, @@ -28,8 +29,6 @@ use self::{ AckState, InboundState, OutboundState, SessionFsmState, SessionState, StreamRole, StreamState, }, - stream_parity::StreamParity, - stream_rx::StreamRxError, stream_tx::StreamTxRange, tracked::{TrackedFrame, TrackedRecord, TrackedStreamData}, }; @@ -945,7 +944,7 @@ fn restore_stream_data(streams: &mut IndexMap, frame: Tra if matches!(stream.outbound_state, OutboundState::Closed) { return; } - stream.tx.retransmit(StreamTxRange { + stream.tx.retransmit(stream_tx::StreamTxRange { offset: frame.offset, len: frame.len, fin: frame.fin, From d73d721a0fb4f2c367603d8203568688660e7eb8 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 09:39:16 -0400 Subject: [PATCH 154/304] ql: cleanup --- ql-fsm/src/tests/mod.rs | 2 +- ql-runtime/src/chunk_slot.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 6eb55ba7..12b696eb 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -18,7 +18,7 @@ use ql_wire::{ use sha2::{Digest, Sha256}; use crate::{ - session::{stream_parity::StreamParity, SessionFsm, SessionFsmConfig}, + session::{SessionFsm, SessionFsmConfig, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, FsmTime, OutboundWrite, QlFsm, QlFsmConfig, QlFsmError, QlFsmEvent, SessionWriteId, }; diff --git a/ql-runtime/src/chunk_slot.rs b/ql-runtime/src/chunk_slot.rs index 8280d48b..7d1ba990 100644 --- a/ql-runtime/src/chunk_slot.rs +++ b/ql-runtime/src/chunk_slot.rs @@ -65,7 +65,7 @@ impl ChunkSlotRx { self.inner.try_recv(max_len) } - pub(crate) fn poll_recv( + pub fn poll_recv( &self, max_len: usize, listener: &mut Option, From d9ddd82f21c0fc3f8b8f3b1128136a5f9b655e10 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 10:22:33 -0400 Subject: [PATCH 155/304] ql: fmt --- ql-fsm/src/implementation/core.rs | 4 +--- ql-fsm/src/tests/session.rs | 9 ++++++++- ql-runtime/src/lib.rs | 7 ++----- ql-runtime/src/tests/handshake.rs | 5 +---- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 5db7bd51..dd24d47b 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -1,9 +1,7 @@ use std::time::{Duration, Instant}; use bytes::Bytes; -use ql_wire::{ - self as wire, QlCrypto, SessionCloseCode, StreamId, WireDecode, -}; +use ql_wire::{self as wire, QlCrypto, SessionCloseCode, StreamId, WireDecode}; use crate::{ session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, QlFsmError, diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index c4259635..5368948d 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -57,7 +57,14 @@ fn connected_fsms_deliver_stream_data() { write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), 5 ); - harness.a.fsm.stream(stream_id).unwrap().writer().unwrap().finish(); + harness + .a + .fsm + .stream(stream_id) + .unwrap() + .writer() + .unwrap() + .finish(); harness.pump(); diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index f1fa7d85..04a58c4c 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,10 +1,7 @@ -pub use self::{ - error::QlStreamError, - handle::*, - platform::*, -}; pub use ql_fsm::NoSessionError; +pub use self::{error::QlStreamError, handle::*, platform::*}; + pub mod chunk_slot; pub(crate) mod command; pub(crate) mod driver; diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index ac250e52..b727186f 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -47,10 +47,7 @@ async fn opening_stream_requires_connection() { tokio::task::spawn_local(async move { runtime_b.run().await }); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - assert!(matches!( - handle_a.open_stream().await, - Err(NoSessionError) - )); + assert!(matches!(handle_a.open_stream().await, Err(NoSessionError))); }) .await; } From 9a7a28adba05be1c90a7c287e6ee472c5b244678 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 10:24:31 -0400 Subject: [PATCH 156/304] ql-wire: xx handshake with pairing token --- ql-wire/src/handshake/mod.rs | 120 +++++++- ql-wire/src/handshake/xx.rs | 579 +++++++++++++++++++++++++++++++++++ ql-wire/src/record.rs | 30 +- ql-wire/src/tests.rs | 196 +++++++++++- 4 files changed, 921 insertions(+), 4 deletions(-) create mode 100644 ql-wire/src/handshake/xx.rs diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 13e61bf9..fbb1db1e 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -8,15 +8,18 @@ mod ik; mod kk; mod meta; mod transport_params; +mod xx; pub use ik::{Ik1, Ik2, IkHandshake}; pub use kk::{Kk1, Kk2, KkHandshake}; pub use meta::{HandshakeId, HandshakeMeta}; pub use transport_params::TransportParams; +pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake}; const SHA256_BLOCK_LEN: usize = 64; const PROTOCOL_IK: &[u8] = b"ql-wire:pq-ik:v1"; const PROTOCOL_KK: &[u8] = b"ql-wire:pq-kk:v1"; +const PROTOCOL_XX: &[u8] = b"ql-wire:pq-xx:v1"; const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; const HANDSHAKE_PREAMBLE_DOMAIN: &[u8] = b"ql-wire:handshake-preamble:v1"; @@ -50,6 +53,57 @@ impl codec::WireDecode for HandshakeHeader { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct PairingToken(pub [u8; Self::SIZE]); + +impl PairingToken { + pub const SIZE: usize = 16; +} + +impl WireEncode for PairingToken { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for PairingToken { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct XxHeader { + pub pairing_token: PairingToken, +} + +impl XxHeader { + pub const WIRE_SIZE: usize = PairingToken::SIZE; +} + +impl WireEncode for XxHeader { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.pairing_token.encode(out); + } +} + +impl codec::WireDecode for XxHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + pairing_token: reader.decode()?, + }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct EphemeralPublicKey { pub mlkem_public_key: MlKemPublicKey, @@ -330,6 +384,10 @@ fn init_ik_symmetric(crypto: &impl QlCrypto, responder_bundle: &PeerBundle) -> S symmetric } +fn init_xx_symmetric(crypto: &impl QlCrypto) -> SymmetricState { + SymmetricState::new(crypto, PROTOCOL_XX) +} + fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> EphemeralKeyPair { EphemeralKeyPair { mlkem: crypto.mlkem_generate_keypair(), @@ -351,9 +409,45 @@ fn mix_hash_routed_handshake( kind: HandshakeKind, meta: &HandshakeMeta, transport_params: TransportParams, +) { + mix_hash_handshake_preamble( + symmetric, + crypto, + &header.encode_vec(), + kind, + meta, + transport_params, + ); +} + +fn mix_hash_pairing_handshake( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: XxHeader, + kind: HandshakeKind, + meta: &HandshakeMeta, + transport_params: TransportParams, +) { + mix_hash_handshake_preamble( + symmetric, + crypto, + &header.encode_vec(), + kind, + meta, + transport_params, + ); +} + +fn mix_hash_handshake_preamble( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: &[u8], + kind: HandshakeKind, + meta: &HandshakeMeta, + transport_params: TransportParams, ) { symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); - symmetric.mix_hash(crypto, &header.encode_vec()); + symmetric.mix_hash(crypto, header); symmetric.mix_hash(crypto, &[kind as u8]); symmetric.mix_hash(crypto, &meta.encode_vec()); symmetric.mix_hash(crypto, &transport_params.encode_vec()); @@ -383,6 +477,30 @@ fn require_handshake_meta( } } +fn initialize_transport_params( + expected: &mut Option, + transport_params: TransportParams, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored != transport_params => Err(WireError::InvalidPayload), + Some(_) => Ok(()), + None => { + *expected = Some(transport_params); + Ok(()) + } + } +} + +fn require_transport_params( + expected: Option<&TransportParams>, + transport_params: TransportParams, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored == transport_params => Ok(()), + _ => Err(WireError::InvalidPayload), + } +} + fn encrypt_peer_bundle( crypto: &impl QlCrypto, symmetric: &mut SymmetricState, diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs new file mode 100644 index 00000000..4e8d455b --- /dev/null +++ b/ql-wire/src/handshake/xx.rs @@ -0,0 +1,579 @@ +use super::{ + decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, + finalize_handshake, generate_ephemeral_keypair, init_xx_symmetric, initialize_handshake_meta, + initialize_transport_params, mix_hash_ephemeral, mix_hash_pairing_handshake, + require_handshake_meta, require_transport_params, EncryptedMlKemCiphertext, + EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, + SymmetricState, TransportParams, XxHeader, +}; +use crate::{ + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PairingToken, PeerBundle, + QlCrypto, QlIdentity, WireEncode, WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx1 { + pub header: XxHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub ephemeral: EphemeralPublicKey, +} + +impl Xx1 { + pub const WIRE_SIZE: usize = XxHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + EphemeralPublicKey::WIRE_SIZE; +} + +impl codec::WireDecode for Xx1 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + ephemeral: reader.decode()?, + }) + } +} + +impl WireEncode for Xx1 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.ephemeral.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx2 { + pub header: XxHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub ekem_ciphertext: MlKemCiphertext, + pub static_bundle: EncryptedPeerBundle, +} + +impl Xx2 { + pub const WIRE_SIZE: usize = XxHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EncryptedPeerBundle::WIRE_SIZE; +} + +impl codec::WireDecode for Xx2 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + static_bundle: reader.decode()?, + }) + } +} + +impl WireEncode for Xx2 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.static_bundle.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx3 { + pub header: XxHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub skem_ciphertext: EncryptedMlKemCiphertext, + pub static_bundle: EncryptedPeerBundle, +} + +impl Xx3 { + pub const WIRE_SIZE: usize = XxHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE + + EncryptedPeerBundle::WIRE_SIZE; +} + +impl codec::WireDecode for Xx3 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + static_bundle: reader.decode()?, + }) + } +} + +impl WireEncode for Xx3 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.static_bundle.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx4 { + pub header: XxHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub skem_ciphertext: EncryptedMlKemCiphertext, +} + +impl Xx4 { + pub const WIRE_SIZE: usize = XxHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE; +} + +impl codec::WireDecode for Xx4 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + }) + } +} + +impl WireEncode for Xx4 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum XxStep { + Send1, + Recv1, + Send2, + Recv2, + Send3, + Recv3, + Send4, + Recv4, + Done, +} + +#[derive(Debug, Clone)] +pub struct XxHandshake { + role: Role, + step: XxStep, + symmetric: SymmetricState, + local: QlIdentity, + pairing_token: PairingToken, + remote_bundle: Option, + local_ephemeral: Option, + remote_ephemeral: Option, + handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, +} + +impl XxHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + pairing_token: PairingToken, + local_transport_params: TransportParams, + ) -> Self { + Self { + role: Role::Initiator, + step: XxStep::Send1, + symmetric: init_xx_symmetric(crypto), + local, + pairing_token, + remote_bundle: None, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + pairing_token: PairingToken, + local_transport_params: TransportParams, + ) -> Self { + Self { + role: Role::Responder, + step: XxStep::Recv1, + symmetric: init_xx_symmetric(crypto), + local, + pairing_token, + remote_bundle: None, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == XxStep::Done + } + + pub fn pairing_token(&self) -> PairingToken { + self.pairing_token + } + + pub fn remote_bundle(&self) -> Option<&PeerBundle> { + self.remote_bundle.as_ref() + } + + fn header(&self) -> XxHeader { + XxHeader { + pairing_token: self.pairing_token, + } + } + + fn ensure_inbound_header(&self, header: XxHeader) -> Result<(), WireError> { + if header == self.header() { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } + + pub fn write_1( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let header = self.header(); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx1, + &meta, + self.local_transport_params, + ); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let ephemeral = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &ephemeral); + + self.local_ephemeral = Some(local_ephemeral); + self.step = XxStep::Recv2; + Ok(Xx1 { + header, + meta, + transport_params: self.local_transport_params, + ephemeral, + }) + } + + pub fn read_1( + &mut self, + crypto: &impl QlCrypto, + now_seconds: u64, + message: &Xx1, + ) -> Result<(), WireError> { + if self.step != XxStep::Recv1 { + return Err(WireError::InvalidState); + } + message.meta.ensure_not_expired(now_seconds)?; + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_header(message.header)?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx1, + &message.meta, + message.transport_params, + ); + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + + self.remote_ephemeral = Some(message.ephemeral.clone()); + initialize_transport_params(&mut self.remote_transport_params, message.transport_params)?; + self.step = XxStep::Send2; + Ok(()) + } + + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx2, + &meta, + self.local_transport_params, + ); + + let remote_ephemeral = self + .remote_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + self.step = XxStep::Recv3; + Ok(Xx2 { + header, + meta, + transport_params: self.local_transport_params, + ekem_ciphertext, + static_bundle, + }) + } + + pub fn read_2( + &mut self, + crypto: &impl QlCrypto, + now_seconds: u64, + message: &Xx2, + ) -> Result<(), WireError> { + if self.step != XxStep::Recv2 { + return Err(WireError::InvalidState); + } + message.meta.ensure_not_expired(now_seconds)?; + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(message.header)?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx2, + &message.meta, + message.transport_params, + ); + + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + self.remote_bundle = Some(remote_bundle); + initialize_transport_params(&mut self.remote_transport_params, message.transport_params)?; + self.step = XxStep::Send3; + Ok(()) + } + + pub fn write_3( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send3 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx3, + &meta, + self.local_transport_params, + ); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + self.step = XxStep::Recv4; + Ok(Xx3 { + header, + meta, + transport_params: self.local_transport_params, + skem_ciphertext, + static_bundle, + }) + } + + pub fn read_3( + &mut self, + crypto: &impl QlCrypto, + now_seconds: u64, + message: &Xx3, + ) -> Result<(), WireError> { + if self.step != XxStep::Recv3 { + return Err(WireError::InvalidState); + } + message.meta.ensure_not_expired(now_seconds)?; + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(message.header)?; + require_transport_params( + self.remote_transport_params.as_ref(), + message.transport_params, + )?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx3, + &message.meta, + message.transport_params, + ); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + self.remote_bundle = Some(remote_bundle); + self.step = XxStep::Send4; + Ok(()) + } + + pub fn write_4( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send4 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx4, + &meta, + self.local_transport_params, + ); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = XxStep::Done; + Ok(Xx4 { + header, + meta, + transport_params: self.local_transport_params, + skem_ciphertext, + }) + } + + pub fn read_4( + &mut self, + crypto: &impl QlCrypto, + now_seconds: u64, + message: &Xx4, + ) -> Result<(), WireError> { + if self.step != XxStep::Recv4 { + return Err(WireError::InvalidState); + } + message.meta.ensure_not_expired(now_seconds)?; + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(message.header)?; + require_transport_params( + self.remote_transport_params.as_ref(), + message.transport_params, + )?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx4, + &message.meta, + message.transport_params, + ); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = XxStep::Done; + Ok(()) + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; + Ok(finalize_handshake( + crypto, + &self.symmetric, + self.role, + remote_bundle, + remote_transport_params, + )) + } +} diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 191d5e4b..163a1bff 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -1,7 +1,7 @@ use crate::{ codec, encrypted_message::EncryptedMessage, - handshake::{Ik1, Ik2, Kk1, Kk2}, + handshake::{Ik1, Ik2, Kk1, Kk2, Xx1, Xx2, Xx3, Xx4}, ByteSlice, SessionHeader, WireDecode, WireEncode, WireError, QL_WIRE_VERSION, }; @@ -104,6 +104,10 @@ pub enum QlHandshakeRecord { Ik2(Ik2), Kk1(Kk1), Kk2(Kk2), + Xx1(Xx1), + Xx2(Xx2), + Xx3(Xx3), + Xx4(Xx4), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -113,6 +117,10 @@ pub enum HandshakeKind { Ik2 = 2, Kk1 = 3, Kk2 = 4, + Xx1 = 5, + Xx2 = 6, + Xx3 = 7, + Xx4 = 8, } impl TryFrom for HandshakeKind { @@ -124,6 +132,10 @@ impl TryFrom for HandshakeKind { 2 => Ok(Self::Ik2), 3 => Ok(Self::Kk1), 4 => Ok(Self::Kk2), + 5 => Ok(Self::Xx1), + 6 => Ok(Self::Xx2), + 7 => Ok(Self::Xx3), + 8 => Ok(Self::Xx4), _ => Err(WireError::InvalidPayload), } } @@ -152,6 +164,10 @@ impl QlHandshakeRecord { Self::Ik2(_) => HandshakeKind::Ik2, Self::Kk1(_) => HandshakeKind::Kk1, Self::Kk2(_) => HandshakeKind::Kk2, + Self::Xx1(_) => HandshakeKind::Xx1, + Self::Xx2(_) => HandshakeKind::Xx2, + Self::Xx3(_) => HandshakeKind::Xx3, + Self::Xx4(_) => HandshakeKind::Xx4, } } } @@ -164,6 +180,10 @@ impl WireEncode for QlHandshakeRecord { Self::Ik2(message) => message.encoded_len(), Self::Kk1(message) => message.encoded_len(), Self::Kk2(message) => message.encoded_len(), + Self::Xx1(message) => message.encoded_len(), + Self::Xx2(message) => message.encoded_len(), + Self::Xx3(message) => message.encoded_len(), + Self::Xx4(message) => message.encoded_len(), } } @@ -174,6 +194,10 @@ impl WireEncode for QlHandshakeRecord { Self::Ik2(message) => message.encode(out), Self::Kk1(message) => message.encode(out), Self::Kk2(message) => message.encode(out), + Self::Xx1(message) => message.encode(out), + Self::Xx2(message) => message.encode(out), + Self::Xx3(message) => message.encode(out), + Self::Xx4(message) => message.encode(out), } } } @@ -186,6 +210,10 @@ impl WireDecode for QlHandshakeRecord { HandshakeKind::Ik2 => Ok(Self::Ik2(reader.decode()?)), HandshakeKind::Kk1 => Ok(Self::Kk1(reader.decode()?)), HandshakeKind::Kk2 => Ok(Self::Kk2(reader.decode()?)), + HandshakeKind::Xx1 => Ok(Self::Xx1(reader.decode()?)), + HandshakeKind::Xx2 => Ok(Self::Xx2(reader.decode()?)), + HandshakeKind::Xx3 => Ok(Self::Xx3(reader.decode()?)), + HandshakeKind::Xx4 => Ok(Self::Xx4(reader.decode()?)), } } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 5d1a3fd1..57eca414 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -196,6 +196,16 @@ fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { } } +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + +fn xx_header(byte: u8) -> XxHeader { + XxHeader { + pairing_token: pairing_token(byte), + } +} + fn encrypt_record( crypto: &impl QlCrypto, header: SessionHeader, @@ -227,7 +237,7 @@ fn peer_bundle_round_trip() { } #[test] -fn handshake_record_round_trip_supports_ik_and_kk() { +fn handshake_record_round_trip_supports_ik_kk_and_xx() { let ik = QlHandshakeRecord::Ik1(Ik1 { header: handshake_header(1, 2), meta: handshake_meta(1), @@ -266,6 +276,24 @@ fn handshake_record_round_trip_supports_ik_and_kk() { } ); assert_eq!(decode_handshake_record(kk_encoded.as_slice()), kk); + + let xx = QlHandshakeRecord::Xx1(Xx1 { + header: xx_header(3), + meta: handshake_meta(3), + transport_params: handshake_transport_params(196_608), + ephemeral: EphemeralPublicKey { + mlkem_public_key: MlKemPublicKey::new(Box::new([17; MlKemPublicKey::SIZE])), + }, + }); + let xx_encoded = encode_record_vec(RecordType::Handshake, &xx); + assert_eq!( + RecordHeader::decode_bytes(xx_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!(decode_handshake_record(xx_encoded.as_slice()), xx); } #[test] @@ -645,6 +673,135 @@ fn kk_handshake_rejects_tampered_transport_params() { ); } +#[test] +fn xx_handshake_rejects_tampered_pairing_token() { + let crypto = TestCrypto::new(32); + let initiator = make_identity(&crypto, 5); + let responder = make_identity(&crypto, 6); + let token = pairing_token(7); + + let mut initiator_state = + XxHandshake::new_initiator(&crypto, initiator, token, TransportParams::default()); + let mut responder_state = + XxHandshake::new_responder(&crypto, responder, token, TransportParams::default()); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.header = xx_header(8); + + assert_eq!( + responder_state.read_1(&crypto, 0, &m1), + Err(WireError::InvalidPayload) + ); +} + +#[test] +fn xx_handshake_rejects_repeated_transport_param_change() { + let crypto = TestCrypto::new(33); + let initiator = make_identity(&crypto, 5); + let responder = make_identity(&crypto, 6); + let token = pairing_token(9); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + token, + handshake_transport_params(12_288), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder, + token, + handshake_transport_params(24_576), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(32)) + .unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(32)) + .unwrap(); + initiator_state.read_2(&crypto, 0, &m2).unwrap(); + + let mut m3 = initiator_state + .write_3(&crypto, handshake_meta(32)) + .unwrap(); + m3.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + responder_state.read_3(&crypto, 0, &m3), + Err(WireError::InvalidPayload) + ); +} + +#[test] +fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { + let crypto = TestCrypto::new(34); + let initiator = make_identity(&crypto, 7); + let responder = make_identity(&crypto, 8); + let token = pairing_token(10); + + let initiator_params = handshake_transport_params(28_672); + let responder_params = handshake_transport_params(57_344); + let mut initiator_state = + XxHandshake::new_initiator(&crypto, initiator.clone(), token, initiator_params); + let mut responder_state = + XxHandshake::new_responder(&crypto, responder.clone(), token, responder_params); + + assert_eq!(initiator_state.pairing_token(), token); + assert_eq!(responder_state.pairing_token(), token); + assert!(initiator_state.remote_bundle().is_none()); + assert!(responder_state.remote_bundle().is_none()); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(33)) + .unwrap(); + responder_state.read_1(&crypto, 0, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(33)) + .unwrap(); + initiator_state.read_2(&crypto, 0, &m2).unwrap(); + assert_eq!(initiator_state.remote_bundle(), Some(&responder.bundle())); + assert!(responder_state.remote_bundle().is_none()); + + let m3 = initiator_state + .write_3(&crypto, handshake_meta(33)) + .unwrap(); + responder_state.read_3(&crypto, 0, &m3).unwrap(); + assert_eq!(responder_state.remote_bundle(), Some(&initiator.bundle())); + + let m4 = responder_state + .write_4(&crypto, handshake_meta(33)) + .unwrap(); + initiator_state.read_4(&crypto, 0, &m4).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + #[test] fn encrypted_session_record_round_trip_uses_connection_id_header() { let crypto = TestCrypto::new(40); @@ -786,7 +943,7 @@ fn protocol_record_size_breakdown() { ); let mut kk_responder = KkHandshake::new_responder( &crypto, - responder, + responder.clone(), initiator.bundle(), TransportParams::default(), ); @@ -800,6 +957,37 @@ fn protocol_record_size_breakdown() { let kk1 = QlHandshakeRecord::Kk1(kk1); let kk2 = QlHandshakeRecord::Kk2(kk2); + let token = pairing_token(0x42); + let mut xx_initiator = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + token, + TransportParams::default(), + ); + let mut xx_responder = XxHandshake::new_responder( + &crypto, + responder.clone(), + token, + TransportParams::default(), + ); + + let xx1 = xx_initiator.write_1(&crypto, handshake_meta(301)).unwrap(); + xx_responder.read_1(&crypto, 0, &xx1).unwrap(); + + let xx2 = xx_responder.write_2(&crypto, handshake_meta(301)).unwrap(); + xx_initiator.read_2(&crypto, 0, &xx2).unwrap(); + + let xx3 = xx_initiator.write_3(&crypto, handshake_meta(301)).unwrap(); + xx_responder.read_3(&crypto, 0, &xx3).unwrap(); + + let xx4 = xx_responder.write_4(&crypto, handshake_meta(301)).unwrap(); + xx_initiator.read_4(&crypto, 0, &xx4).unwrap(); + + let xx1 = QlHandshakeRecord::Xx1(xx1); + let xx2 = QlHandshakeRecord::Xx2(xx2); + let xx3 = QlHandshakeRecord::Xx3(xx3); + let xx4 = QlHandshakeRecord::Xx4(xx4); + let session = ik_initiator.finalize(&crypto).unwrap(); let session_ping = encrypt_record( &crypto, @@ -843,6 +1031,10 @@ fn protocol_record_size_breakdown() { print_size("ql-wire pq ik2", ik2.encode_vec().len()); print_size("ql-wire pq kk1", kk1.encode_vec().len()); print_size("ql-wire pq kk2", kk2.encode_vec().len()); + print_size("ql-wire pq xx1", xx1.encode_vec().len()); + print_size("ql-wire pq xx2", xx2.encode_vec().len()); + print_size("ql-wire pq xx3", xx3.encode_vec().len()); + print_size("ql-wire pq xx4", xx4.encode_vec().len()); print_size("ql-wire session ping", session_ping.encode_vec().len()); print_size( "ql-wire session stream empty", From bef70595c218516c78f2b3f0f2205ff4220461fc Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 10:44:14 -0400 Subject: [PATCH 157/304] ql-fsm: xx handshake --- ql-fsm/src/implementation/handshake/ik.rs | 7 +- ql-fsm/src/implementation/handshake/kk.rs | 6 +- ql-fsm/src/implementation/handshake/mod.rs | 47 ++++- ql-fsm/src/implementation/handshake/xx.rs | 222 +++++++++++++++++++++ ql-fsm/src/lib.rs | 58 +++++- ql-fsm/src/state.rs | 41 +++- ql-fsm/src/tests/handshake.rs | 135 +++++++++++++ ql-fsm/src/tests/mod.rs | 52 ++++- ql-fsm/src/tests/proptest.rs | 2 +- 9 files changed, 556 insertions(+), 14 deletions(-) create mode 100644 ql-fsm/src/implementation/handshake/xx.rs diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index b785da24..f7d9b0da 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -103,7 +103,12 @@ pub fn handle_ik2( pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { match &fsm.state.link { - LinkState::Idle | LinkState::Connected(_) | LinkState::KkInitiator(_) => false, + LinkState::Idle + | LinkState::Connected(_) + | LinkState::KkInitiator(_) + | LinkState::XxInitiator(_) + | LinkState::XxResponder(_) + | LinkState::XxResponderPending(_) => false, LinkState::IkInitiator(state) => { if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { return false; diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index bf19dcb9..33e007fc 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -102,7 +102,11 @@ pub fn handle_kk2( pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { match &fsm.state.link { - LinkState::Idle | LinkState::Connected(_) => false, + LinkState::Idle + | LinkState::Connected(_) + | LinkState::XxInitiator(_) + | LinkState::XxResponder(_) + | LinkState::XxResponderPending(_) => false, LinkState::IkInitiator(_) => true, LinkState::KkInitiator(state) => { if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 3529591b..6177a1a2 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -1,7 +1,11 @@ mod ik; mod kk; +mod xx; -use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHandshakeRecord}; +use ql_wire::{ + self as wire, EphemeralPublicKey, HandshakeMeta, PairingToken, PeerBundle, QlCrypto, + QlHandshakeRecord, +}; use super::emit_peer_status; use crate::{ @@ -30,6 +34,16 @@ pub fn handle_connect_kk( kk::start_initiator(fsm, crypto, peer, &mut emit) } +pub fn handle_connect_xx( + fsm: &mut QlFsm, + token: PairingToken, + crypto: &impl QlCrypto, + mut emit: impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { + prepare_for_outbound_connect(fsm); + xx::start_initiator(fsm, crypto, token, &mut emit) +} + pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { let handshake_id = wire::HandshakeId(fsm.state.next_control_id); fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); @@ -47,6 +61,33 @@ pub fn enqueue_handshake(fsm: &mut QlFsm, record: QlHandshakeRecord) { fsm.state.handshake = Some(record); } +pub fn pending_xx_pairing(fsm: &QlFsm) -> Option<(PairingToken, &PeerBundle)> { + match &fsm.state.link { + crate::state::LinkState::XxResponderPending(state) => state + .handshake + .remote_bundle() + .map(|peer| (state.handshake.pairing_token(), peer)), + _ => None, + } +} + +pub fn handle_accept_pairing( + fsm: &mut QlFsm, + token: PairingToken, + crypto: &impl QlCrypto, + mut emit: impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { + xx::accept_pairing(fsm, crypto, token, &mut emit) +} + +pub fn handle_reject_pairing(fsm: &mut QlFsm, token: PairingToken) -> Result<(), QlFsmError> { + xx::reject_pairing(fsm, token) +} + +pub fn handle_disarm_pairing(fsm: &mut QlFsm) { + xx::disarm_pairing(fsm); +} + fn local_transport_params(fsm: &QlFsm) -> wire::TransportParams { wire::TransportParams { initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size, @@ -75,6 +116,10 @@ pub fn handle_handshake_record( QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message, emit), QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message, emit), QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message, emit), + QlHandshakeRecord::Xx1(message) => xx::handle_xx1(fsm, crypto, message, emit), + QlHandshakeRecord::Xx2(message) => xx::handle_xx2(fsm, crypto, message, emit), + QlHandshakeRecord::Xx3(message) => xx::handle_xx3(fsm, crypto, message, emit), + QlHandshakeRecord::Xx4(message) => xx::handle_xx4(fsm, crypto, message, emit), } } diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/implementation/handshake/xx.rs new file mode 100644 index 00000000..04df7e7e --- /dev/null +++ b/ql-fsm/src/implementation/handshake/xx.rs @@ -0,0 +1,222 @@ +use ql_wire::{self as wire, PairingToken, QlCrypto, QlHandshakeRecord, Xx1, Xx2, Xx3, Xx4}; + +use super::{ + emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, + reset_connected_session_if_needed, +}; +use crate::{ + state::{ + LinkState, SessionTransport, XxInitiatorState, XxResponderPendingState, XxResponderState, + }, + QlFsm, QlFsmError, QlFsmEvent, +}; + +pub fn start_initiator( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + token: PairingToken, + emit: &mut impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::XxHandshake::new_initiator( + crypto, + fsm.identity.clone(), + token, + super::local_transport_params(fsm), + ); + let message = handshake.write_1(crypto, meta)?; + + fsm.state.link = LinkState::XxInitiator(XxInitiatorState { + handshake_id: meta.handshake_id, + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + }); + enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); + emit_peer_status(fsm, emit); + Ok(()) +} + +pub fn handle_xx1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx1, + _emit: &mut impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { + if should_ignore_inbound(fsm, message) { + return Ok(()); + } + if is_replayed_handshake_start(fsm, message.meta) { + return Ok(()); + } + if fsm.state.armed_pairing_token != Some(message.header.pairing_token) { + return Ok(()); + } + + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::XxHandshake::new_responder( + crypto, + fsm.identity.clone(), + message.header.pairing_token, + super::local_transport_params(fsm), + ); + handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; + let outbound = handshake.write_2(crypto, message.meta)?; + fsm.state.link = LinkState::XxResponder(XxResponderState { + handshake, + handshake_meta: message.meta, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + }); + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx2(outbound)); + Ok(()) +} + +pub fn handle_xx2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx2, + _emit: &mut impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { + { + let LinkState::XxInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_2(crypto, fsm.state.now.unix_secs, message)?; + let outbound = state.handshake.write_3(crypto, message.meta)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx3(outbound)); + } + + Ok(()) +} + +pub fn handle_xx3( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx3, + emit: &mut impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { + let LinkState::XxResponder(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_meta.handshake_id { + return Ok(()); + } + + state + .handshake + .read_3(crypto, fsm.state.now.unix_secs, message)?; + let deadline = state.deadline; + let handshake_meta = state.handshake_meta; + let LinkState::XxResponder(state) = fsm.state.link.take() else { + unreachable!("active XX responder was checked above"); + }; + fsm.state.link = LinkState::XxResponderPending(XxResponderPendingState { + handshake: state.handshake, + handshake_meta, + deadline, + }); + emit(QlFsmEvent::PairingPending); + Ok(()) +} + +pub fn handle_xx4( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx4, + emit: &mut impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { + { + let LinkState::XxInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_4(crypto, fsm.state.now.unix_secs, message)?; + } + + let LinkState::XxInitiator(state) = fsm.state.link.take() else { + unreachable!("active XX initiator was checked above"); + }; + let (transport, remote_bundle) = + SessionTransport::from_finalized(state.handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle, emit) +} + +pub fn accept_pairing( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + token: PairingToken, + emit: &mut impl FnMut(QlFsmEvent), +) -> Result<(), QlFsmError> { + { + let LinkState::XxResponderPending(state) = &mut fsm.state.link else { + return Err(QlFsmError::InvalidState); + }; + if state.handshake.pairing_token() != token { + return Err(QlFsmError::InvalidState); + } + let outbound = state.handshake.write_4(crypto, state.handshake_meta)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx4(outbound)); + } + + let LinkState::XxResponderPending(state) = fsm.state.link.take() else { + unreachable!("pending XX responder was checked above"); + }; + let (transport, remote_bundle) = + SessionTransport::from_finalized(state.handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle, emit) +} + +pub fn reject_pairing(fsm: &mut QlFsm, token: PairingToken) -> Result<(), QlFsmError> { + let LinkState::XxResponderPending(state) = &fsm.state.link else { + return Err(QlFsmError::InvalidState); + }; + if state.handshake.pairing_token() != token { + return Err(QlFsmError::InvalidState); + } + + fsm.state.link = LinkState::Idle; + fsm.state.handshake = None; + Ok(()) +} + +pub fn disarm_pairing(fsm: &mut QlFsm) { + if matches!( + fsm.state.link, + LinkState::XxResponder(_) | LinkState::XxResponderPending(_) + ) { + fsm.state.link = LinkState::Idle; + fsm.state.handshake = None; + } +} + +pub fn should_ignore_inbound(fsm: &QlFsm, message: &Xx1) -> bool { + match &fsm.state.link { + LinkState::Idle | LinkState::Connected(_) => false, + LinkState::IkInitiator(_) | LinkState::KkInitiator(_) => true, + LinkState::XxResponder(_) | LinkState::XxResponderPending(_) => true, + LinkState::XxInitiator(state) => { + if state.handshake.pairing_token() != message.header.pairing_token { + return false; + } + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) + } + } +} diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 9a4548c1..390c579c 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -3,15 +3,16 @@ //! a caller drives `QlFsm` inside its own event loop //! //! inputs to that loop usually include -//! - app actions like `bind_peer`, `connect_ik`, `connect_kk`, `open_stream`, or `stream` +//! - app actions like `bind_peer`, `connect_ik`, `connect_kk`, `connect_xx`, `open_stream`, or +//! `stream` //! - inbound transport bytes passed to `receive` //! - a deadline expiring, handled by calling `on_timer` //! - transport write results passed to `confirm_session_write` or `reject_session_write` //! //! outputs from `QlFsm` are //! - outbound session and handshake records from `take_next_write` -//! - callback-driven `QlFsmEvent`s emitted during `connect_ik`, `connect_kk`, `receive`, and -//! `on_timer` +//! - callback-driven `QlFsmEvent`s emitted during `connect_ik`, `connect_kk`, `connect_xx`, +//! `receive`, and `on_timer` //! //! call `next_deadline` after handling current inputs and any emitted outputs //! use it to decide how long the outer loop can wait before `on_timer` must run @@ -30,7 +31,8 @@ use std::time::{Duration, Instant}; pub use bytes::Bytes; pub use error::*; use ql_wire::{ - PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, StreamId, + PairingToken, PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, + StreamId, }; pub use session::{StreamOps, StreamReadIter, StreamWriter}; @@ -64,6 +66,8 @@ pub enum PeerStatus { pub enum QlFsmEvent { /// a peer was learned during handshake completion NewPeer, + /// an inbound xx pairing is waiting for an accept or reject decision + PairingPending, /// the peer changed connection state PeerStatusChanged(PeerStatus), /// a stream was opened @@ -151,6 +155,7 @@ impl QlFsm { replay_cache: ReplayCache::default(), next_control_id: 1, peer: None, + armed_pairing_token: None, handshake: None, link: LinkState::Idle, now, @@ -168,6 +173,51 @@ impl QlFsm { self.state.peer.as_ref() } + /// arms acceptance of inbound xx pairings for a single token + pub fn arm_pairing(&mut self, token: PairingToken) { + self.state.armed_pairing_token = Some(token); + } + + /// disarms inbound xx pairing and rejects any in-flight inbound xx responder state + pub fn disarm_pairing(&mut self) { + self.state.armed_pairing_token = None; + implementation::handle_disarm_pairing(self); + } + + /// starts or replaces an outbound xx handshake using the supplied pairing token + pub fn connect_xx( + &mut self, + now: FsmTime, + token: PairingToken, + crypto: &impl QlCrypto, + emit: impl FnMut(QlFsmEvent), + ) -> Result<(), QlFsmError> { + self.state.now = now; + implementation::handle_connect_xx(self, token, crypto, emit) + } + + /// returns the pending inbound xx candidate token and peer, if any + pub fn pending_xx_pairing(&self) -> Option<(PairingToken, &PeerBundle)> { + implementation::pending_xx_pairing(self) + } + + /// accepts a pending inbound xx pairing for the matching token + pub fn accept_pairing( + &mut self, + now: FsmTime, + token: PairingToken, + crypto: &impl QlCrypto, + emit: impl FnMut(QlFsmEvent), + ) -> Result<(), QlFsmError> { + self.state.now = now; + implementation::handle_accept_pairing(self, token, crypto, emit) + } + + /// rejects a pending inbound xx pairing for the matching token + pub fn reject_pairing(&mut self, token: PairingToken) -> Result<(), QlFsmError> { + implementation::handle_reject_pairing(self, token) + } + /// starts or replaces an IK handshake with the currently bound peer pub fn connect_ik( &mut self, diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index c4b4fe96..71d9e6da 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -1,8 +1,8 @@ use std::time::Instant; use ql_wire::{ - ConnectionId, EphemeralPublicKey, HandshakeId, IkHandshake, KkHandshake, PeerBundle, - QlHandshakeRecord, SessionKey, TransportParams, + ConnectionId, EphemeralPublicKey, HandshakeId, HandshakeMeta, IkHandshake, KkHandshake, + PairingToken, PeerBundle, QlHandshakeRecord, SessionKey, TransportParams, XxHandshake, }; use crate::{replay_cache::ReplayCache, session::SessionFsm, FsmTime, NoSessionError, PeerStatus}; @@ -11,6 +11,7 @@ pub struct QlFsmState { pub replay_cache: ReplayCache, pub next_control_id: u32, pub peer: Option, + pub armed_pairing_token: Option, pub handshake: Option, pub link: LinkState, pub now: FsmTime, @@ -44,6 +45,9 @@ pub enum LinkState { Idle, IkInitiator(IkInitiatorState), KkInitiator(KkInitiatorState), + XxInitiator(XxInitiatorState), + XxResponder(XxResponderState), + XxResponderPending(XxResponderPendingState), Connected(ConnectedState), } @@ -68,6 +72,28 @@ pub struct KkInitiatorState { pub initial_ephemeral: EphemeralPublicKey, } +#[derive(Debug, Clone)] +pub struct XxInitiatorState { + pub handshake: XxHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + +#[derive(Debug, Clone)] +pub struct XxResponderState { + pub handshake: XxHandshake, + pub handshake_meta: HandshakeMeta, + pub deadline: Instant, +} + +#[derive(Debug, Clone)] +pub struct XxResponderPendingState { + pub handshake: XxHandshake, + pub handshake_meta: HandshakeMeta, + pub deadline: Instant, +} + impl LinkState { pub fn take(&mut self) -> Self { std::mem::replace(self, Self::Idle) @@ -75,8 +101,12 @@ impl LinkState { pub fn status(&self) -> PeerStatus { match self { - Self::Idle => PeerStatus::Disconnected, - Self::IkInitiator(_) | Self::KkInitiator(_) => PeerStatus::Initiator, + Self::Idle | Self::XxResponder(_) | Self::XxResponderPending(_) => { + PeerStatus::Disconnected + } + Self::IkInitiator(_) | Self::KkInitiator(_) | Self::XxInitiator(_) => { + PeerStatus::Initiator + } Self::Connected(_) => PeerStatus::Connected, } } @@ -107,6 +137,9 @@ impl LinkState { Self::Idle | Self::Connected(_) => None, Self::IkInitiator(state) => Some(state.deadline), Self::KkInitiator(state) => Some(state.deadline), + Self::XxInitiator(state) => Some(state.deadline), + Self::XxResponder(state) => Some(state.deadline), + Self::XxResponderPending(state) => Some(state.deadline), } } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index c8cbbbc0..65bd9c6f 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -27,6 +27,38 @@ fn kk_connect_round_trip_establishes_transport() { assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); } +#[test] +fn xx_connect_round_trip_establishes_transport_after_accept() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(1); + + harness.b.fsm.arm_pairing(token); + harness.connect_xx_a(token).unwrap(); + + let xx1 = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(xx1); + let xx2 = harness.next_outbound_b().unwrap(); + harness.deliver_to_a(xx2); + let xx3 = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(xx3); + + assert_eq!(harness.take_event_b(), Some(QlFsmEvent::PairingPending)); + assert_eq!( + harness.b.fsm.pending_xx_pairing(), + Some((token, &harness.a.fsm.identity.bundle())) + ); + assert!(harness.next_outbound_b().is_none()); + + harness.accept_pairing_b(token).unwrap(); + let xx4 = harness.next_outbound_b().unwrap(); + harness.deliver_to_a(xx4); + + assert_eq!(harness.a.fsm.peer(), Some(&harness.b.fsm.identity.bundle())); + assert_eq!(harness.b.fsm.peer(), Some(&harness.a.fsm.identity.bundle())); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + #[test] fn ik_connect_learns_remote_initial_stream_receive_window() { let mut harness = Harness::paired_known_with_configs( @@ -84,6 +116,11 @@ fn connect_methods_require_bound_peer() { fsm.connect_kk(time, &crypto, |_| {}), Err(QlFsmError::NoPeerBound) ); + + assert_eq!( + fsm.connect_xx(time, pairing_token(2), &crypto, |_| {}), + Ok(()) + ); } #[test] @@ -98,6 +135,100 @@ fn connect_ik_emits_initiator_status() { ); } +#[test] +fn inbound_xx1_ignored_when_pairing_token_not_armed() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(3); + + harness.connect_xx_a(token).unwrap(); + let xx1 = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(xx1); + + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); + assert!(harness.drain_events_b().is_empty()); + assert!(harness.next_outbound_b().is_none()); +} + +#[test] +fn reject_pairing_drops_pending_xx_candidate() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(4); + + harness.b.fsm.arm_pairing(token); + harness.connect_xx_a(token).unwrap(); + let xx1 = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(xx1); + let xx2 = harness.next_outbound_b().unwrap(); + harness.deliver_to_a(xx2); + let xx3 = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(xx3); + + assert_eq!(harness.take_event_b(), Some(QlFsmEvent::PairingPending)); + harness.reject_pairing_b(token).unwrap(); + + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); + assert!(harness.next_outbound_b().is_none()); + assert!(harness.b.fsm.pending_xx_pairing().is_none()); +} + +#[test] +fn disarm_pairing_rejects_inflight_inbound_xx_responder() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(5); + + harness.b.fsm.arm_pairing(token); + harness.connect_xx_a(token).unwrap(); + let xx1 = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(xx1); + let xx2 = harness.next_outbound_b().unwrap(); + harness.deliver_to_a(xx2); + let xx3 = harness.next_outbound_a().unwrap(); + harness.deliver_to_b(xx3); + + assert_eq!(harness.take_event_b(), Some(QlFsmEvent::PairingPending)); + harness.b.fsm.disarm_pairing(); + + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); + assert!(harness.b.fsm.pending_xx_pairing().is_none()); +} + +#[test] +fn simultaneous_xx_connect_converges() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(6); + + harness.a.fsm.arm_pairing(token); + harness.b.fsm.arm_pairing(token); + harness.connect_xx_a(token).unwrap(); + harness.connect_xx_b(token).unwrap(); + + for _ in 0..2 { + if let Some(record) = harness.next_outbound_a() { + harness.deliver_to_b(record); + } + if let Some(record) = harness.next_outbound_b() { + harness.deliver_to_a(record); + } + } + + let event_a = harness.take_event_a(); + let event_b = harness.take_event_b(); + assert!( + matches!(event_a, Some(QlFsmEvent::PairingPending)) + || matches!(event_b, Some(QlFsmEvent::PairingPending)) + ); + if matches!(event_a, Some(QlFsmEvent::PairingPending)) { + harness.accept_pairing_a(token).unwrap(); + } + if matches!(event_b, Some(QlFsmEvent::PairingPending)) { + harness.accept_pairing_b(token).unwrap(); + } + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + #[test] fn connect_ik_replaces_in_flight_attempt_and_ignores_stale_reply() { let mut harness = Harness::paired_known(QlFsmConfig::default()); @@ -267,5 +398,9 @@ fn handshake_id(record: &[u8]) -> ql_wire::HandshakeId { ql_wire::QlHandshakeRecord::Ik2(message) => message.meta.handshake_id, ql_wire::QlHandshakeRecord::Kk1(message) => message.meta.handshake_id, ql_wire::QlHandshakeRecord::Kk2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx3(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx4(message) => message.meta.handshake_id, } } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 12b696eb..06e7f717 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -12,8 +12,8 @@ use libcrux_aesgcm::AesGcm256Key; use libcrux_ml_kem::mlkem1024; use ql_wire::{ self, generate_identity, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, - MlKemPublicKey, Nonce, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, SessionKey, - TransportParams, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + MlKemPublicKey, Nonce, PairingToken, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, + SessionKey, TransportParams, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; use sha2::{Digest, Sha256}; @@ -320,6 +320,50 @@ impl Harness { fsm.connect_kk(time, crypto, |event| events.push_back(event)) } + fn connect_xx_a(&mut self, token: PairingToken) -> Result<(), QlFsmError> { + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.a; + fsm.connect_xx(time, token, crypto, |event| events.push_back(event)) + } + + fn connect_xx_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.b; + fsm.connect_xx(time, token, crypto, |event| events.push_back(event)) + } + + fn accept_pairing_a(&mut self, token: PairingToken) -> Result<(), QlFsmError> { + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.a; + fsm.accept_pairing(time, token, crypto, |event| events.push_back(event)) + } + + fn accept_pairing_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { + let time = self.time(); + let Node { + fsm, + crypto, + events, + } = &mut self.b; + fsm.accept_pairing(time, token, crypto, |event| events.push_back(event)) + } + + fn reject_pairing_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { + self.b.fsm.reject_pairing(token) + } + fn deliver_to_a(&mut self, record: Vec) { let time = self.time(); let Node { @@ -406,6 +450,10 @@ fn test_identity(seed: u8) -> QlIdentity { generate_identity(&crypto, XID([seed; XID::SIZE])) } +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { let (local, peer, config) = if a { ( diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 841b0017..8de5b39e 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -455,7 +455,7 @@ impl Runner { fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { for event in events { match event { - QlFsmEvent::NewPeer => {} + QlFsmEvent::NewPeer | QlFsmEvent::PairingPending => {} QlFsmEvent::PeerStatusChanged(status) => { self.events_mut(side).note_peer_status(status); } From e24ab2dd2f273fdd5345170d31157b2c979bb5c9 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 11:27:03 -0400 Subject: [PATCH 158/304] ql-runtime: xx handshake --- ql-runtime/src/command.rs | 9 +- ql-runtime/src/driver/mod.rs | 157 +++++++++++++++++++++++------- ql-runtime/src/driver/test.rs | 15 ++- ql-runtime/src/handle/mod.rs | 14 ++- ql-runtime/src/platform.rs | 10 +- ql-runtime/src/tests/handshake.rs | 70 +++++++++++++ ql-runtime/src/tests/mod.rs | 82 ++++++++++++++-- 7 files changed, 309 insertions(+), 48 deletions(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 019610f8..b419d8ae 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,5 +1,5 @@ use ql_fsm::NoSessionError; -use ql_wire::{CloseTarget, PeerBundle, StreamCloseCode, StreamId}; +use ql_wire::{CloseTarget, PairingToken, PeerBundle, StreamCloseCode, StreamId}; use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlStreamError}; @@ -8,6 +8,13 @@ pub(crate) enum RuntimeCommand { peer: PeerBundle, }, Connect, + ArmPairing { + token: PairingToken, + }, + DisarmPairing, + StartPairing { + token: PairingToken, + }, OpenStream { request_reader: ChunkSlotRx, request_terminal: oneshot::Sender, diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 7814bce0..2ee8e09f 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -5,21 +5,21 @@ mod test; use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, future::Future, - pin::Pin, + pin::{pin, Pin}, task::Poll, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; use futures_lite::future::poll_fn; use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, SessionWriteId}; -use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; +use ql_wire::{CloseTarget, PairingToken, StreamCloseCode, StreamId}; use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; use crate::{ chunk_slot, command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, - platform::{QlPlatform, QlTimer}, + platform::{PlatformFuture, QlPlatform, QlTimer}, QlStreamError, Runtime, RuntimeHandle, }; @@ -45,32 +45,61 @@ impl Runtime

{ max_concurrent_message_writes: config.max_concurrent_message_writes, pending_fsm_events: VecDeque::new(), }; + let mut in_flight = Vec::new(); + let mut pairing_decision = None; let mut timer = platform.timer(); + let recv_future = rx.recv(); + let mut recv_future = pin!(recv_future); loop { state.fill_write_slots(&mut fsm, &platform, &mut in_flight); - - if rx.is_closed() && in_flight.is_empty() { - break; - } - + state.sync_pairing_decision_state(&fsm, &mut pairing_decision); timer.set_deadline(fsm.next_deadline()); - match next_driver_event(&rx, &mut timer, &mut in_flight).await { + match next_driver_event( + recv_future.as_mut(), + &mut timer, + &mut in_flight, + &mut pairing_decision, + ) + .await + { DriverEvent::Command(command) => { - state.drive_command(&mut fsm, command, &platform); + state.drive_command(&mut fsm, command, &platform, &mut pairing_decision); } DriverEvent::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); DriverState::drive_write_completed(&mut fsm, write.session_write_id, success); } + DriverEvent::PairingDecision { token, accept } => { + pairing_decision = None; + let _ = state.with_fsm_events( + &mut fsm, + &platform, + &mut pairing_decision, + |fsm, emit| { + if accept { + fsm.accept_pairing(now(), token, &platform, emit) + } else { + fsm.reject_pairing(token) + } + }, + ); + } DriverEvent::TimerExpired => { - state.with_fsm_events(&mut fsm, &platform, |fsm, emit| { - fsm.on_timer(now(), emit); - }); + state.with_fsm_events( + &mut fsm, + &platform, + &mut pairing_decision, + |fsm, emit| fsm.on_timer(now(), emit), + ); + } + DriverEvent::CommandsClosed => { + if in_flight.is_empty() && pairing_decision.is_none() { + break; + } } - DriverEvent::CommandsClosed => {} } } } @@ -81,25 +110,30 @@ struct InFlightWrite { future: F, } +struct InFlightPairingDecision<'a> { + token: PairingToken, + future: PlatformFuture<'a, bool>, +} + enum DriverEvent { Command(RuntimeCommand), WriteCompleted { index: usize, success: bool }, + PairingDecision { token: PairingToken, accept: bool }, TimerExpired, CommandsClosed, } #[allow(clippy::future_not_send)] async fn next_driver_event( - rx: &async_channel::Receiver, + mut recv_future: Pin<&mut async_channel::Recv<'_, RuntimeCommand>>, timer: &mut T, in_flight: &mut [InFlightWrite], + pairing_decision: &mut Option>, ) -> DriverEvent where T: QlTimer, F: Future + Unpin, { - let mut recv_future = (!rx.is_closed()).then(|| Box::pin(rx.recv())); - poll_fn(|cx| { for (index, write) in in_flight.iter_mut().enumerate() { if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { @@ -107,41 +141,57 @@ where } } - if timer.poll_wait(cx) == Poll::Ready(()) { - return Poll::Ready(DriverEvent::TimerExpired); + if let Some(decision) = pairing_decision.as_mut() { + if let Poll::Ready(accept) = Pin::new(&mut decision.future).poll(cx) { + return Poll::Ready(DriverEvent::PairingDecision { + token: decision.token, + accept, + }); + } } - if let Some(future) = recv_future.as_mut() { - if let Poll::Ready(res) = future.as_mut().poll(cx) { - return Poll::Ready( - res.map_or_else(|_| DriverEvent::CommandsClosed, DriverEvent::Command), - ); - } + if timer.poll_wait(cx) == Poll::Ready(()) { + return Poll::Ready(DriverEvent::TimerExpired); } - Poll::Pending + recv_future + .as_mut() + .poll(cx) + .map(|res| res.map_or_else(|_| DriverEvent::CommandsClosed, DriverEvent::Command)) }) .await } impl DriverState { - fn drive_command( + fn drive_command<'a, P: QlPlatform + 'a>( &mut self, fsm: &mut QlFsm, command: RuntimeCommand, - platform: &P, + platform: &'a P, + pairing_decision: &mut Option>, ) { match command { RuntimeCommand::BindPeer { peer } => { fsm.bind_peer(peer); } RuntimeCommand::Connect => { - let _ = self.with_fsm_events(fsm, platform, |fsm, emit| { + let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm, emit| { fsm.connect_ik(now(), platform, emit) }); } + RuntimeCommand::ArmPairing { token } => { + fsm.arm_pairing(token); + } + RuntimeCommand::DisarmPairing => { + fsm.disarm_pairing(); + } + RuntimeCommand::StartPairing { token } => { + let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm, emit| { + fsm.connect_xx(now(), token, platform, emit) + }); + } RuntimeCommand::Incoming(bytes) => { - let _ = self.with_fsm_events(fsm, platform, |fsm, emit| { + let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm, emit| { fsm.receive(now(), bytes, platform, emit) }); } @@ -232,10 +282,11 @@ impl DriverState { } } - fn with_fsm_events( + fn with_fsm_events<'a, P: QlPlatform + 'a, T>( &mut self, fsm: &mut QlFsm, - platform: &P, + platform: &'a P, + pairing_decision: &mut Option>, run: impl FnOnce(&mut QlFsm, &mut dyn FnMut(QlFsmEvent)) -> T, ) -> T { let output = { @@ -243,20 +294,26 @@ impl DriverState { let mut emit = |event| pending.push_back(event); run(fsm, &mut emit) }; - self.process_pending_fsm_events(fsm, platform); + self.process_pending_fsm_events(fsm, platform, pairing_decision); output } - fn process_pending_fsm_events(&mut self, fsm: &mut QlFsm, platform: &P) { + fn process_pending_fsm_events<'a, P: QlPlatform + 'a>( + &mut self, + fsm: &mut QlFsm, + platform: &'a P, + pairing_decision: &mut Option>, + ) { while let Some(event) = self.pending_fsm_events.pop_front() { - self.process_fsm_event(fsm, platform, event); + self.process_fsm_event(fsm, platform, pairing_decision, event); } } - fn process_fsm_event( + fn process_fsm_event<'a, P: QlPlatform + 'a>( &mut self, fsm: &mut QlFsm, - platform: &P, + platform: &'a P, + pairing_decision: &mut Option>, event: QlFsmEvent, ) { match event { @@ -265,6 +322,17 @@ impl DriverState { platform.persist_peer(peer); } } + QlFsmEvent::PairingPending => { + if let Some((token, peer)) = fsm.pending_xx_pairing() { + let peer = peer.clone(); + *pairing_decision = Some(InFlightPairingDecision { + token, + future: Box::pin(async move { + platform.handle_pairing_request(token, peer).await + }), + }); + } + } QlFsmEvent::PeerStatusChanged(status) => { if let Some(peer) = fsm.peer().map(|peer| peer.xid) { platform.handle_peer_status(peer, status); @@ -439,6 +507,21 @@ impl DriverState { self.streams.clear(); } + fn sync_pairing_decision_state( + &self, + fsm: &QlFsm, + pairing_decision: &mut Option>, + ) { + if let Some(decision) = pairing_decision.as_ref() { + let is_current = fsm + .pending_xx_pairing() + .is_some_and(|(token, _)| token == decision.token); + if !is_current { + *pairing_decision = None; + } + } + } + fn fill_write_slots<'a, P: QlPlatform + 'a>( &self, fsm: &mut QlFsm, diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 610c090b..df290a48 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -1,8 +1,8 @@ use std::task::{Context, Poll}; use ql_wire::{ - MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PeerBundle, QlAead, QlHash, - QlKem, QlRandom, SessionKey, StreamClose, XID, + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PairingToken, PeerBundle, + QlAead, QlHash, QlKem, QlRandom, SessionKey, StreamClose, XID, }; use super::*; @@ -87,6 +87,7 @@ impl crate::platform::QlTimer for NoopTimer { impl QlPlatform for NoopPlatform { type Timer = NoopTimer; type WriteMessageFut<'a> = std::future::Ready; + type PairingDecisionFut<'a> = std::future::Ready; fn write_message(&self, _message: Vec) -> Self::WriteMessageFut<'_> { std::future::ready(true) @@ -104,6 +105,14 @@ impl QlPlatform for NoopPlatform { fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} + fn handle_pairing_request( + &self, + _token: PairingToken, + _peer: PeerBundle, + ) -> Self::PairingDecisionFut<'_> { + std::future::ready(false) + } + fn handle_inbound(&self, _event: QlStream) {} } @@ -195,6 +204,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { let stream_id = StreamId(1u32.into()); let (request_reader, _request_writer) = chunk_slot::new(); let (request_terminal_tx, _request_terminal_rx) = oneshot::channel(); + let mut pairing_decision = None; state.streams.insert( stream_id, @@ -213,6 +223,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { code: StreamCloseCode(0), }, &NoopPlatform, + &mut pairing_decision, ); assert!(!state.streams.contains_key(&stream_id)); diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 5d38b985..9d339dad 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -2,7 +2,7 @@ mod reader; mod writer; use ql_fsm::NoSessionError; -use ql_wire::{CloseTarget, PeerBundle, StreamId}; +use ql_wire::{CloseTarget, PairingToken, PeerBundle, StreamId}; pub use self::{reader::*, writer::*}; use crate::{chunk_slot, command::RuntimeCommand}; @@ -28,6 +28,18 @@ impl RuntimeHandle { self.send(RuntimeCommand::Connect); } + pub fn arm_pairing(&self, token: PairingToken) { + self.send(RuntimeCommand::ArmPairing { token }); + } + + pub fn disarm_pairing(&self) { + self.send(RuntimeCommand::DisarmPairing); + } + + pub fn start_pairing(&self, token: PairingToken) { + self.send(RuntimeCommand::StartPairing { token }); + } + pub fn send_incoming(&self, bytes: Vec) { self.send(RuntimeCommand::Incoming(bytes)); } diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 411627c0..713efeb9 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -6,7 +6,7 @@ use std::{ }; use ql_fsm::PeerStatus; -use ql_wire::{PeerBundle, QlCrypto, XID}; +use ql_wire::{PairingToken, PeerBundle, QlCrypto, XID}; use crate::QlStream; @@ -20,6 +20,9 @@ pub trait QlTimer { pub trait QlPlatform: QlCrypto { type Timer: QlTimer; type WriteMessageFut<'a>: Future + Unpin + 'a + where + Self: 'a; + type PairingDecisionFut<'a>: Future + Unpin + 'a where Self: 'a; @@ -30,5 +33,10 @@ pub trait QlPlatform: QlCrypto { fn persist_peer(&self, peer: PeerBundle); fn handle_peer_status(&self, peer: XID, status: PeerStatus); + fn handle_pairing_request( + &self, + token: PairingToken, + peer: PeerBundle, + ) -> Self::PairingDecisionFut<'_>; fn handle_inbound(&self, event: QlStream); } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index b727186f..c4916d94 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -139,3 +139,73 @@ async fn rejected_session_write_is_reissued() { }) .await; } + +#[tokio::test(flavor = "current_thread")] +async fn start_pairing_round_trip_uses_platform_decision_to_connect() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, status_b, pairing_b) = TestPlatform::new_with_pairing(2, true); + let identity_a = new_identity(11); + let identity_b = new_identity(73); + let token = pairing_token(7); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_b.arm_pairing(token); + handle_a.start_pairing(token); + + let request = await_pairing_request(&pairing_b).await; + assert_eq!(request.token, token); + assert_eq!(request.peer, identity_a.bundle()); + + await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; + await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn start_pairing_rejects_when_platform_returns_false() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, status_a) = TestPlatform::new(1); + let (platform_b, outbound_b, _status_b, pairing_b) = + TestPlatform::new_with_pairing(2, false); + let identity_a = new_identity(11); + let identity_b = new_identity(73); + let token = pairing_token(8); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + + handle_b.arm_pairing(token); + handle_a.start_pairing(token); + + let request = await_pairing_request(&pairing_b).await; + assert_eq!(request.token, token); + assert_eq!(request.peer, identity_a.bundle()); + + assert_no_status_for( + &status_a, + identity_b.xid, + PeerStatus::Connected, + Duration::from_millis(150), + ) + .await; + }) + .await; +} diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 61602f60..f0883fc7 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -15,8 +15,8 @@ use libcrux_aesgcm::AesGcm256Key; use ql_fsm::PeerStatus; use ql_wire::{ generate_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, - PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, RecordType, SessionKey, - WireDecode, XID, + PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, + RecordType, SessionKey, WireDecode, XID, }; use sha2::{Digest, Sha256}; use tokio::{task::LocalSet, time::Sleep}; @@ -39,6 +39,12 @@ struct StatusEvent { status: PeerStatus, } +#[derive(Debug, Clone, PartialEq, Eq)] +struct PairingRequest { + token: PairingToken, + peer: PeerBundle, +} + #[derive(Debug, Clone)] struct WriteStats { active: Arc, @@ -160,6 +166,8 @@ struct TestPlatform { outbound: Sender>, status: Sender, inbound: Option>, + pairing_requests: Option>, + pairing_accept: bool, nonce_seed: u8, nonce_counter: AtomicU8, encrypted_write_counter: AtomicUsize, @@ -170,7 +178,7 @@ struct TestPlatform { impl TestPlatform { fn new(seed: u8) -> (Self, Receiver>, Receiver) { - Self::new_inner(seed, None, None, Duration::ZERO, None) + Self::new_inner(seed, None, None, false, None, Duration::ZERO, None) } fn new_with_inbound( @@ -182,11 +190,40 @@ impl TestPlatform { Receiver, ) { let (inbound_tx, inbound_rx) = async_channel::unbounded(); - let (platform, outbound_rx, status_rx) = - Self::new_inner(seed, Some(inbound_tx), None, Duration::ZERO, None); + let (platform, outbound_rx, status_rx) = Self::new_inner( + seed, + Some(inbound_tx), + None, + false, + None, + Duration::ZERO, + None, + ); (platform, outbound_rx, status_rx, inbound_rx) } + fn new_with_pairing( + seed: u8, + pairing_accept: bool, + ) -> ( + Self, + Receiver>, + Receiver, + Receiver, + ) { + let (pairing_tx, pairing_rx) = async_channel::unbounded(); + let (platform, outbound_rx, status_rx) = Self::new_inner( + seed, + None, + Some(pairing_tx), + pairing_accept, + None, + Duration::ZERO, + None, + ); + (platform, outbound_rx, status_rx, pairing_rx) + } + fn new_with_session_write_failure( seed: u8, fail_encrypted_write_at: usize, @@ -194,6 +231,8 @@ impl TestPlatform { Self::new_inner( seed, None, + None, + false, Some(fail_encrypted_write_at), Duration::ZERO, None, @@ -205,12 +244,14 @@ impl TestPlatform { delay: Duration, write_stats: WriteStats, ) -> (Self, Receiver>, Receiver) { - Self::new_inner(seed, None, None, delay, Some(write_stats)) + Self::new_inner(seed, None, None, false, None, delay, Some(write_stats)) } fn new_inner( seed: u8, inbound: Option>, + pairing_requests: Option>, + pairing_accept: bool, fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, @@ -222,6 +263,8 @@ impl TestPlatform { outbound, status, inbound, + pairing_requests, + pairing_accept, nonce_seed: seed, nonce_counter: AtomicU8::new(0), encrypted_write_counter: AtomicUsize::new(0), @@ -346,6 +389,7 @@ impl QlKem for TestPlatform { impl crate::platform::QlPlatform for TestPlatform { type Timer = TokioTimer; type WriteMessageFut<'a> = PlatformFuture<'a, bool>; + type PairingDecisionFut<'a> = PlatformFuture<'a, bool>; fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_> { let outbound = self.outbound.clone(); @@ -398,6 +442,21 @@ impl crate::platform::QlPlatform for TestPlatform { let _ = self.status.try_send(StatusEvent { peer, status }); } + fn handle_pairing_request( + &self, + token: PairingToken, + peer: PeerBundle, + ) -> Self::PairingDecisionFut<'_> { + let pairing_requests = self.pairing_requests.clone(); + let pairing_accept = self.pairing_accept; + Box::pin(async move { + if let Some(tx) = pairing_requests { + let _ = tx.send(PairingRequest { token, peer }).await; + } + pairing_accept + }) + } + fn handle_inbound(&self, event: QlStream) { if let Some(tx) = &self.inbound { let _ = tx.try_send(event); @@ -420,6 +479,10 @@ pub(crate) fn new_identity(seed: u8) -> QlIdentity { generate_identity(&crypto, XID([seed; XID::SIZE])) } +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + fn register_peers( handle_a: &RuntimeHandle, handle_b: &RuntimeHandle, @@ -513,6 +576,13 @@ async fn assert_no_status_for( assert!(res.is_err(), "unexpected status event: {status:?}"); } +async fn await_pairing_request(receiver: &Receiver) -> PairingRequest { + tokio::time::timeout(Duration::from_secs(2), receiver.recv()) + .await + .unwrap() + .unwrap() +} + async fn read_all(mut stream: crate::ByteReader) -> Result, QlStreamError> { let mut data = Vec::new(); while let Some(chunk) = next_chunk(&mut stream).await? { From 55091bf9e2e8f091447bee6d5c7837930aa8962b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 12:56:18 -0400 Subject: [PATCH 159/304] ql-fsm: fix un-needed session close --- ql-fsm/src/session/mod.rs | 18 +++++++-- ql-fsm/src/session/stream_rx.rs | 4 ++ ql-fsm/src/session/tests.rs | 66 +++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 4 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 911e96bf..86c475f1 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -591,12 +591,22 @@ impl SessionFsm { }; let frame_offset = frame.offset.into_inner(); + let frame_end = frame_offset + .checked_add(frame.bytes.len() as u64) + .ok_or(())?; match stream.inbound_state { InboundState::Open => {} - InboundState::Discarding => return Ok(()), - InboundState::Finished | InboundState::Closed(_) => { - if frame_offset.saturating_add(frame.bytes.len() as u64) <= stream.rx.start_offset() - { + InboundState::Discarding | InboundState::Closed(_) => return Ok(()), + InboundState::Finished => { + // finished stream should always have a final offset + let Some(final_offset) = stream.rx.final_offset() else { + debug_assert!(false, "finished stream must retain final offset"); + return Ok(()); + }; + + // retransmitted data for an already-finished stream is fine as long as it stays + // within the finalized byte range and any repeated FIN lands on that same offset. + if (!frame.fin || frame_end == final_offset) && frame_end <= final_offset { return Ok(()); } self.fail_session( diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs index c5012018..0f5a8eab 100644 --- a/ql-fsm/src/session/stream_rx.rs +++ b/ql-fsm/src/session/stream_rx.rs @@ -52,6 +52,10 @@ impl StreamRx { }) } + pub fn final_offset(&self) -> Option { + self.final_offset + } + pub fn max_buffered(&self) -> usize { self.max_buffered } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 05538f14..e1549f5a 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -371,6 +371,45 @@ fn duplicate_remote_close_after_reap_is_ignored() { assert!(second.is_empty()); } +#[test] +fn late_remote_stream_data_after_close_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = stream_id(1); + let close = vec![SessionFrame::StreamClose(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + })]; + let data = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + fin: false, + bytes: b"hello".to_vec(), + })]; + + let first = receive_events(&mut fsm, now, seq(1), &close); + assert_eq!( + first, + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Closed(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + }), + SessionEvent::WritableClosed(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + }), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &data); + assert!(second.is_empty()); +} + #[test] fn duplicate_finished_remote_data_after_reap_is_ignored() { let now = Instant::now(); @@ -398,6 +437,33 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { assert!(second.is_empty()); } +#[test] +fn duplicate_finished_remote_data_before_read_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + fin: true, + bytes: b"hello".to_vec(), + })]; + + let first = receive_events(&mut fsm, now, seq(1), &record); + assert_eq!( + first, + vec![ + SessionEvent::Opened(stream_id), + SessionEvent::Readable(stream_id), + SessionEvent::Finished(stream_id), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + assert!(second.is_empty()); + assert_eq!(read_stream_all(&mut fsm, stream_id), b"hello".to_vec()); +} + #[test] fn out_of_order_remote_stream_first_observations_still_open_once_each() { let now = Instant::now(); From fbcf2ff08b349f046acbe446e4b87c31f696161e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 12:58:12 -0400 Subject: [PATCH 160/304] ql-wire: implement encode/decode on identity --- ql-wire/src/identity.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 72602f70..1de12a01 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -50,6 +50,9 @@ pub struct QlIdentity { } impl QlIdentity { + pub const WIRE_SIZE: usize = + XID::SIZE + MlKemPrivateKey::SIZE + MlKemPublicKey::SIZE + size_of::(); + pub fn new( xid: XID, mlkem_private_key: MlKemPrivateKey, @@ -79,6 +82,30 @@ impl QlIdentity { } } +impl WireEncode for QlIdentity { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.xid.encode(out); + self.mlkem_private_key.as_bytes().encode(out); + self.mlkem_public_key.encode(out); + self.capabilities.encode(out); + } +} + +impl codec::WireDecode for QlIdentity { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + xid: reader.decode()?, + mlkem_private_key: MlKemPrivateKey::new(reader.decode()?), + mlkem_public_key: reader.decode()?, + capabilities: reader.decode()?, + }) + } +} + pub fn generate_identity(crypto: &impl QlCrypto, xid: XID) -> QlIdentity { let MlKemKeyPair { private: mlkem_private_key, From 4100224c94fa6354ea1413978b69c91d0ba061bc Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 13:07:48 -0400 Subject: [PATCH 161/304] ql-wire: codec for Option --- ql-wire/src/codec.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index 27e06819..a1f276e6 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -171,6 +171,32 @@ impl WireEncode for bool { } } +impl WireEncode for Option { + fn encoded_len(&self) -> usize { + 1 + self.as_ref().map_or(0, |inner| inner.encoded_len()) + } + + fn encode(&self, out: &mut W) { + match self { + None => out.put_u8(0), + Some(inner) => { + out.put_u8(1); + inner.encode(out) + } + } + } +} + +impl> WireDecode for Option { + fn decode(reader: &mut Reader) -> Result { + match reader.decode::()? { + 0 => Ok(None), + 1 => Ok(Some(reader.decode::()?)), + _ => Err(WireError::InvalidPayload), + } + } +} + pub struct Reader { remaining: Option, } From a52f5823ecfa1451a9986c924bc11fe30518e98c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 17:00:32 -0400 Subject: [PATCH 162/304] ql-fsm: remove emit callback in favor of internal queue --- ql-fsm/src/implementation/core.rs | 56 ++++++----- ql-fsm/src/implementation/handshake/ik.rs | 11 +-- ql-fsm/src/implementation/handshake/kk.rs | 11 +-- ql-fsm/src/implementation/handshake/mod.rs | 48 ++++----- ql-fsm/src/implementation/handshake/xx.rs | 14 +-- ql-fsm/src/lib.rs | 49 +++++----- ql-fsm/src/tests/handshake.rs | 15 +-- ql-fsm/src/tests/mod.rs | 108 +++++++-------------- ql-fsm/src/tests/proptest.rs | 16 +-- ql-runtime/src/driver/mod.rs | 57 ++++------- ql-runtime/src/driver/state.rs | 4 +- ql-runtime/src/driver/test.rs | 1 - 12 files changed, 144 insertions(+), 246 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index dd24d47b..d17e4597 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -18,7 +18,6 @@ pub fn receive( fsm: &mut QlFsm, mut bytes: Vec, crypto: &impl QlCrypto, - mut emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let mut reader = wire::Reader::new(bytes.as_mut_slice()); let header = wire::RecordHeader::decode(&mut reader)?; @@ -30,9 +29,10 @@ pub fn receive( match header.record_type { wire::RecordType::Handshake => { let record = wire::QlHandshakeRecord::decode(&mut reader)?; - super::handle_handshake_record(fsm, crypto, &record, &mut emit) + super::handle_handshake_record(fsm, crypto, &record) } wire::RecordType::Session => { + let pending_events = &mut fsm.pending_events; let state = fsm .state .link @@ -60,30 +60,32 @@ pub fn receive( state .session .receive(fsm.state.now.instant, seq, frames, |event| { - session_closed |= forward_session_event(event, &mut emit); + session_closed |= forward_session_event(event, pending_events); }); if session_closed { - apply_session_closed(fsm, &mut emit); + apply_session_closed(fsm); } Ok(()) } } } -pub fn on_timer(fsm: &mut QlFsm, mut emit: impl FnMut(QlFsmEvent)) { - super::handle_timer(fsm, &mut emit); - let Some(state) = fsm.state.link.connected_mut() else { - return; - }; +pub fn on_timer(fsm: &mut QlFsm) { + super::handle_timer(fsm); let mut session_closed = false; - state.session.on_timer(fsm.state.now.instant, |event| { - session_closed |= forward_session_event(event, &mut emit); - }); + if let Some(state) = fsm.state.link.connected_mut() { + let pending_events = &mut fsm.pending_events; + state.session.on_timer(fsm.state.now.instant, |event| { + session_closed |= forward_session_event(event, pending_events); + }); + } else { + return; + } if session_closed { - apply_session_closed(fsm, &mut emit); + apply_session_closed(fsm); } } @@ -162,49 +164,53 @@ pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { state.session.queue_ping() } -pub fn emit_peer_status(fsm: &QlFsm, emit: &mut impl FnMut(QlFsmEvent)) { +pub fn emit_peer_status(fsm: &mut QlFsm) { if fsm.state.peer.is_some() { - emit(QlFsmEvent::PeerStatusChanged(fsm.state.link.status())); + fsm.pending_events + .push_back(QlFsmEvent::PeerStatusChanged(fsm.state.link.status())); } } -fn forward_session_event(event: SessionEvent, emit: &mut impl FnMut(QlFsmEvent)) -> bool { +fn forward_session_event( + event: SessionEvent, + pending_events: &mut std::collections::VecDeque, +) -> bool { match event { SessionEvent::Opened(stream_id) => { - emit(QlFsmEvent::Opened(stream_id)); + pending_events.push_back(QlFsmEvent::Opened(stream_id)); false } SessionEvent::Readable(stream_id) => { - emit(QlFsmEvent::Readable(stream_id)); + pending_events.push_back(QlFsmEvent::Readable(stream_id)); false } SessionEvent::Writable(stream_id) => { - emit(QlFsmEvent::Writable(stream_id)); + pending_events.push_back(QlFsmEvent::Writable(stream_id)); false } SessionEvent::Finished(stream_id) => { - emit(QlFsmEvent::Finished(stream_id)); + pending_events.push_back(QlFsmEvent::Finished(stream_id)); false } SessionEvent::Closed(frame) => { - emit(QlFsmEvent::Closed(frame)); + pending_events.push_back(QlFsmEvent::Closed(frame)); false } SessionEvent::WritableClosed(frame) => { - emit(QlFsmEvent::WritableClosed(frame)); + pending_events.push_back(QlFsmEvent::WritableClosed(frame)); false } SessionEvent::SessionClosed(close) => { - emit(QlFsmEvent::SessionClosed(close)); + pending_events.push_back(QlFsmEvent::SessionClosed(close)); true } } } -fn apply_session_closed(fsm: &mut QlFsm, emit: &mut impl FnMut(QlFsmEvent)) { +fn apply_session_closed(fsm: &mut QlFsm) { if matches!(fsm.state.link, crate::state::LinkState::Connected(_)) { fsm.state.link = crate::state::LinkState::Idle; - emit_peer_status(fsm, emit); + emit_peer_status(fsm); } } diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index f7d9b0da..ee9335e8 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -6,14 +6,13 @@ use super::{ }; use crate::{ state::{IkInitiatorState, LinkState, SessionTransport}, - QlFsm, QlFsmError, QlFsmEvent, + QlFsm, QlFsmError, }; pub fn start_initiator( fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::IkHandshake::new_initiator( @@ -31,7 +30,7 @@ pub fn start_initiator( deadline: fsm.state.now.instant + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); - emit_peer_status(fsm, emit); + emit_peer_status(fsm); Ok(()) } @@ -39,7 +38,6 @@ pub fn handle_ik1( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Ik1, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { if should_ignore_inbound(fsm, message) { return Ok(()); @@ -67,7 +65,7 @@ pub fn handle_ik1( handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle, emit)?; + finish_handshake(fsm, transport, remote_bundle)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Ik2(outbound)); Ok(()) @@ -77,7 +75,6 @@ pub fn handle_ik2( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Ik2, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { { let LinkState::IkInitiator(state) = &mut fsm.state.link else { @@ -98,7 +95,7 @@ pub fn handle_ik2( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle, emit) + finish_handshake(fsm, transport, remote_bundle) } pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 33e007fc..8f23d337 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -6,14 +6,13 @@ use super::{ }; use crate::{ state::{KkInitiatorState, LinkState, SessionTransport}, - QlFsm, QlFsmError, QlFsmEvent, + QlFsm, QlFsmError, }; pub fn start_initiator( fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::KkHandshake::new_initiator( @@ -31,7 +30,7 @@ pub fn start_initiator( deadline: fsm.state.now.instant + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); - emit_peer_status(fsm, emit); + emit_peer_status(fsm); Ok(()) } @@ -39,7 +38,6 @@ pub fn handle_kk1( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Kk1, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { if should_ignore_inbound(fsm, message) { return Ok(()); @@ -66,7 +64,7 @@ pub fn handle_kk1( handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle, emit)?; + finish_handshake(fsm, transport, remote_bundle)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Kk2(outbound)); Ok(()) @@ -76,7 +74,6 @@ pub fn handle_kk2( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Kk2, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { { let LinkState::KkInitiator(state) = &mut fsm.state.link else { @@ -97,7 +94,7 @@ pub fn handle_kk2( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle, emit) + finish_handshake(fsm, transport, remote_bundle) } pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 6177a1a2..e0f22224 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -14,34 +14,25 @@ use crate::{ QlFsm, QlFsmError, QlFsmEvent, }; -pub fn handle_connect_ik( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - mut emit: impl FnMut(QlFsmEvent), -) -> Result<(), QlFsmError> { +pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; prepare_for_outbound_connect(fsm); - ik::start_initiator(fsm, crypto, peer, &mut emit) + ik::start_initiator(fsm, crypto, peer) } -pub fn handle_connect_kk( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - mut emit: impl FnMut(QlFsmEvent), -) -> Result<(), QlFsmError> { +pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; prepare_for_outbound_connect(fsm); - kk::start_initiator(fsm, crypto, peer, &mut emit) + kk::start_initiator(fsm, crypto, peer) } pub fn handle_connect_xx( fsm: &mut QlFsm, token: PairingToken, crypto: &impl QlCrypto, - mut emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { prepare_for_outbound_connect(fsm); - xx::start_initiator(fsm, crypto, token, &mut emit) + xx::start_initiator(fsm, crypto, token) } pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { @@ -75,9 +66,8 @@ pub fn handle_accept_pairing( fsm: &mut QlFsm, token: PairingToken, crypto: &impl QlCrypto, - mut emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { - xx::accept_pairing(fsm, crypto, token, &mut emit) + xx::accept_pairing(fsm, crypto, token) } pub fn handle_reject_pairing(fsm: &mut QlFsm, token: PairingToken) -> Result<(), QlFsmError> { @@ -109,21 +99,20 @@ pub fn handle_handshake_record( fsm: &mut QlFsm, crypto: &impl QlCrypto, record: &QlHandshakeRecord, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { match record { - QlHandshakeRecord::Ik1(message) => ik::handle_ik1(fsm, crypto, message, emit), - QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message, emit), - QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message, emit), - QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message, emit), - QlHandshakeRecord::Xx1(message) => xx::handle_xx1(fsm, crypto, message, emit), - QlHandshakeRecord::Xx2(message) => xx::handle_xx2(fsm, crypto, message, emit), - QlHandshakeRecord::Xx3(message) => xx::handle_xx3(fsm, crypto, message, emit), - QlHandshakeRecord::Xx4(message) => xx::handle_xx4(fsm, crypto, message, emit), + QlHandshakeRecord::Ik1(message) => ik::handle_ik1(fsm, crypto, message), + QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message), + QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message), + QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message), + QlHandshakeRecord::Xx1(message) => xx::handle_xx1(fsm, crypto, message), + QlHandshakeRecord::Xx2(message) => xx::handle_xx2(fsm, crypto, message), + QlHandshakeRecord::Xx3(message) => xx::handle_xx3(fsm, crypto, message), + QlHandshakeRecord::Xx4(message) => xx::handle_xx4(fsm, crypto, message), } } -pub fn handle_timer(fsm: &mut QlFsm, emit: &mut impl FnMut(QlFsmEvent)) { +pub fn handle_timer(fsm: &mut QlFsm) { let Some(deadline) = fsm.state.link.handshake_deadline() else { return; }; @@ -133,7 +122,7 @@ pub fn handle_timer(fsm: &mut QlFsm, emit: &mut impl FnMut(QlFsmEvent)) { fsm.state.link = LinkState::Idle; fsm.state.handshake = None; - emit_peer_status(fsm, emit); + emit_peer_status(fsm); } pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { @@ -144,7 +133,6 @@ pub fn finish_handshake( fsm: &mut QlFsm, transport: SessionTransport, remote_bundle: wire::PeerBundle, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let xid = remote_bundle.xid; if let Some(peer) = fsm.state.peer.as_ref() { @@ -153,7 +141,7 @@ pub fn finish_handshake( } } else { fsm.state.peer = Some(remote_bundle); - emit(QlFsmEvent::NewPeer); + fsm.pending_events.push_back(QlFsmEvent::NewPeer); } let config = &fsm.config; @@ -174,7 +162,7 @@ pub fn finish_handshake( fsm.state.now.instant, ); fsm.state.link = LinkState::Connected(ConnectedState { transport, session }); - emit_peer_status(fsm, emit); + emit_peer_status(fsm); Ok(()) } diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/implementation/handshake/xx.rs index 04df7e7e..3171f743 100644 --- a/ql-fsm/src/implementation/handshake/xx.rs +++ b/ql-fsm/src/implementation/handshake/xx.rs @@ -15,7 +15,6 @@ pub fn start_initiator( fsm: &mut QlFsm, crypto: &impl QlCrypto, token: PairingToken, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::XxHandshake::new_initiator( @@ -33,7 +32,7 @@ pub fn start_initiator( deadline: fsm.state.now.instant + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); - emit_peer_status(fsm, emit); + emit_peer_status(fsm); Ok(()) } @@ -41,7 +40,6 @@ pub fn handle_xx1( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Xx1, - _emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { if should_ignore_inbound(fsm, message) { return Ok(()); @@ -77,7 +75,6 @@ pub fn handle_xx2( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Xx2, - _emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { { let LinkState::XxInitiator(state) = &mut fsm.state.link else { @@ -103,7 +100,6 @@ pub fn handle_xx3( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Xx3, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { let LinkState::XxResponder(state) = &mut fsm.state.link else { return Ok(()); @@ -126,7 +122,7 @@ pub fn handle_xx3( handshake_meta, deadline, }); - emit(QlFsmEvent::PairingPending); + fsm.pending_events.push_back(QlFsmEvent::PairingPending); Ok(()) } @@ -134,7 +130,6 @@ pub fn handle_xx4( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Xx4, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { { let LinkState::XxInitiator(state) = &mut fsm.state.link else { @@ -155,14 +150,13 @@ pub fn handle_xx4( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle, emit) + finish_handshake(fsm, transport, remote_bundle) } pub fn accept_pairing( fsm: &mut QlFsm, crypto: &impl QlCrypto, token: PairingToken, - emit: &mut impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { { let LinkState::XxResponderPending(state) = &mut fsm.state.link else { @@ -181,7 +175,7 @@ pub fn accept_pairing( }; let (transport, remote_bundle) = SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle, emit) + finish_handshake(fsm, transport, remote_bundle) } pub fn reject_pairing(fsm: &mut QlFsm, token: PairingToken) -> Result<(), QlFsmError> { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 390c579c..fce0dd4f 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -11,10 +11,10 @@ //! //! outputs from `QlFsm` are //! - outbound session and handshake records from `take_next_write` -//! - callback-driven `QlFsmEvent`s emitted during `connect_ik`, `connect_kk`, `connect_xx`, -//! `receive`, and `on_timer` +//! - queued `QlFsmEvent`s returned by `poll_event` after `connect_ik`, `connect_kk`, +//! `connect_xx`, `accept_pairing`, `receive`, and `on_timer` //! -//! call `next_deadline` after handling current inputs and any emitted outputs +//! call `next_deadline` after handling current inputs and any queued outputs //! use it to decide how long the outer loop can wait before `on_timer` must run //! another input may arrive before that deadline, which is fine @@ -26,7 +26,10 @@ pub(crate) mod state; #[cfg(test)] mod tests; -use std::time::{Duration, Instant}; +use std::{ + collections::VecDeque, + time::{Duration, Instant}, +}; pub use bytes::Bytes; pub use error::*; @@ -143,6 +146,7 @@ pub struct QlFsm { /// local identity and private keys pub identity: QlIdentity, pub(crate) state: QlFsmState, + pending_events: VecDeque, } impl QlFsm { @@ -160,6 +164,7 @@ impl QlFsm { link: LinkState::Idle, now, }, + pending_events: VecDeque::new(), } } @@ -190,10 +195,9 @@ impl QlFsm { now: FsmTime, token: PairingToken, crypto: &impl QlCrypto, - emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { self.state.now = now; - implementation::handle_connect_xx(self, token, crypto, emit) + implementation::handle_connect_xx(self, token, crypto) } /// returns the pending inbound xx candidate token and peer, if any @@ -207,10 +211,9 @@ impl QlFsm { now: FsmTime, token: PairingToken, crypto: &impl QlCrypto, - emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { self.state.now = now; - implementation::handle_accept_pairing(self, token, crypto, emit) + implementation::handle_accept_pairing(self, token, crypto) } /// rejects a pending inbound xx pairing for the matching token @@ -219,25 +222,15 @@ impl QlFsm { } /// starts or replaces an IK handshake with the currently bound peer - pub fn connect_ik( - &mut self, - now: FsmTime, - crypto: &impl QlCrypto, - emit: impl FnMut(QlFsmEvent), - ) -> Result<(), QlFsmError> { + pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; - implementation::handle_connect_ik(self, crypto, emit) + implementation::handle_connect_ik(self, crypto) } /// starts or replaces a KK handshake with the currently bound peer - pub fn connect_kk( - &mut self, - now: FsmTime, - crypto: &impl QlCrypto, - emit: impl FnMut(QlFsmEvent), - ) -> Result<(), QlFsmError> { + pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; - implementation::handle_connect_kk(self, crypto, emit) + implementation::handle_connect_kk(self, crypto) } /// handles one inbound wire message @@ -246,16 +239,20 @@ impl QlFsm { now: FsmTime, bytes: Vec, crypto: &impl QlCrypto, - emit: impl FnMut(QlFsmEvent), ) -> Result<(), QlFsmError> { self.state.now = now; - implementation::receive(self, bytes, crypto, emit) + implementation::receive(self, bytes, crypto) } /// advances time-based state - pub fn on_timer(&mut self, now: FsmTime, emit: impl FnMut(QlFsmEvent)) { + pub fn on_timer(&mut self, now: FsmTime) { self.state.now = now; - implementation::on_timer(self, emit); + implementation::on_timer(self); + } + + /// returns the next queued event, if any + pub fn poll_event(&mut self) -> Option { + self.pending_events.pop_front() } /// returns the next timer deadline, if any diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 65bd9c6f..204d41d4 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -108,19 +108,10 @@ fn connect_methods_require_bound_peer() { let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); let crypto = TestCrypto::new(9); - assert_eq!( - fsm.connect_ik(time, &crypto, |_| {}), - Err(QlFsmError::NoPeerBound) - ); - assert_eq!( - fsm.connect_kk(time, &crypto, |_| {}), - Err(QlFsmError::NoPeerBound) - ); + assert_eq!(fsm.connect_ik(time, &crypto), Err(QlFsmError::NoPeerBound)); + assert_eq!(fsm.connect_kk(time, &crypto), Err(QlFsmError::NoPeerBound)); - assert_eq!( - fsm.connect_xx(time, pairing_token(2), &crypto, |_| {}), - Ok(()) - ); + assert_eq!(fsm.connect_xx(time, pairing_token(2), &crypto), Ok(())); } #[test] diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 06e7f717..32d96a52 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -4,7 +4,6 @@ mod session; use std::{ cell::Cell, - collections::VecDeque, time::{Duration, Instant}, }; @@ -144,7 +143,6 @@ impl QlKem for TestCrypto { struct Node { fsm: QlFsm, crypto: TestCrypto, - events: VecDeque, } struct Harness { @@ -187,12 +185,10 @@ impl Harness { a: Node { fsm: QlFsm::new(config_a, identity_a.clone(), time), crypto: TestCrypto::new(1), - events: Default::default(), }, b: Node { fsm: QlFsm::new(config_b, identity_b.clone(), time), crypto: TestCrypto::new(2), - events: Default::default(), }, }; @@ -282,82 +278,50 @@ impl Harness { fn connect_ik_a(&mut self) -> Result<(), QlFsmError> { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.a; - fsm.connect_ik(time, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut self.a; + fsm.connect_ik(time, crypto) } fn connect_ik_b(&mut self) -> Result<(), QlFsmError> { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.b; - fsm.connect_ik(time, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut self.b; + fsm.connect_ik(time, crypto) } fn connect_kk_a(&mut self) -> Result<(), QlFsmError> { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.a; - fsm.connect_kk(time, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut self.a; + fsm.connect_kk(time, crypto) } fn connect_kk_b(&mut self) -> Result<(), QlFsmError> { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.b; - fsm.connect_kk(time, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut self.b; + fsm.connect_kk(time, crypto) } fn connect_xx_a(&mut self, token: PairingToken) -> Result<(), QlFsmError> { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.a; - fsm.connect_xx(time, token, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut self.a; + fsm.connect_xx(time, token, crypto) } fn connect_xx_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.b; - fsm.connect_xx(time, token, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut self.b; + fsm.connect_xx(time, token, crypto) } fn accept_pairing_a(&mut self, token: PairingToken) -> Result<(), QlFsmError> { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.a; - fsm.accept_pairing(time, token, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut self.a; + fsm.accept_pairing(time, token, crypto) } fn accept_pairing_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.b; - fsm.accept_pairing(time, token, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut self.b; + fsm.accept_pairing(time, token, crypto) } fn reject_pairing_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { @@ -366,24 +330,14 @@ impl Harness { fn deliver_to_a(&mut self, record: Vec) { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.a; - fsm.receive(time, record, crypto, |event| events.push_back(event)) - .unwrap(); + let Node { fsm, crypto } = &mut self.a; + fsm.receive(time, record, crypto).unwrap(); } fn deliver_to_b(&mut self, record: Vec) { let time = self.time(); - let Node { - fsm, - crypto, - events, - } = &mut self.b; - fsm.receive(time, record, crypto, |event| events.push_back(event)) - .unwrap(); + let Node { fsm, crypto } = &mut self.b; + fsm.receive(time, record, crypto).unwrap(); } fn confirm_write_a(&mut self, write_id: SessionWriteId) { @@ -396,30 +350,36 @@ impl Harness { fn on_timer_a(&mut self) { let time = self.time(); - let Node { fsm, events, .. } = &mut self.a; - fsm.on_timer(time, |event| events.push_back(event)); + self.a.fsm.on_timer(time); } fn on_timer_b(&mut self) { let time = self.time(); - let Node { fsm, events, .. } = &mut self.b; - fsm.on_timer(time, |event| events.push_back(event)); + self.b.fsm.on_timer(time); } fn take_event_a(&mut self) -> Option { - self.a.events.pop_front() + self.a.fsm.poll_event() } fn take_event_b(&mut self) -> Option { - self.b.events.pop_front() + self.b.fsm.poll_event() } fn drain_events_a(&mut self) -> Vec { - self.a.events.drain(..).collect() + let mut events = Vec::new(); + while let Some(event) = self.a.fsm.poll_event() { + events.push(event); + } + events } fn drain_events_b(&mut self) -> Vec { - self.b.events.drain(..).collect() + let mut events = Vec::new(); + while let Some(event) = self.b.fsm.poll_event() { + events.push(event); + } + events } fn pump(&mut self) { diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 8de5b39e..79bf9364 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -851,22 +851,14 @@ fn reject_taken_b(harness: &mut Harness, write: &TakenWrite) { fn deliver_to_a(harness: &mut Harness, record: Vec) -> Result<(), QlFsmError> { let time = harness.time(); - let Node { - fsm, - crypto, - events, - } = &mut harness.a; - fsm.receive(time, record, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut harness.a; + fsm.receive(time, record, crypto) } fn deliver_to_b(harness: &mut Harness, record: Vec) -> Result<(), QlFsmError> { let time = harness.time(); - let Node { - fsm, - crypto, - events, - } = &mut harness.b; - fsm.receive(time, record, crypto, |event| events.push_back(event)) + let Node { fsm, crypto } = &mut harness.b; + fsm.receive(time, record, crypto) } fn take_pending(pending: &mut Vec>, index: usize) -> Option> { diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 2ee8e09f..92ce9e79 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -3,7 +3,7 @@ mod state; mod test; use std::{ - collections::{hash_map::Entry, HashMap, VecDeque}, + collections::{hash_map::Entry, HashMap}, future::Future, pin::{pin, Pin}, task::Poll, @@ -43,7 +43,6 @@ impl Runtime

{ streams: HashMap::new(), runtime_tx: tx, max_concurrent_message_writes: config.max_concurrent_message_writes, - pending_fsm_events: VecDeque::new(), }; let mut in_flight = Vec::new(); @@ -74,26 +73,19 @@ impl Runtime

{ } DriverEvent::PairingDecision { token, accept } => { pairing_decision = None; - let _ = state.with_fsm_events( - &mut fsm, - &platform, - &mut pairing_decision, - |fsm, emit| { + let _ = + state.with_fsm_events(&mut fsm, &platform, &mut pairing_decision, |fsm| { if accept { - fsm.accept_pairing(now(), token, &platform, emit) + fsm.accept_pairing(now(), token, &platform) } else { fsm.reject_pairing(token) } - }, - ); + }); } DriverEvent::TimerExpired => { - state.with_fsm_events( - &mut fsm, - &platform, - &mut pairing_decision, - |fsm, emit| fsm.on_timer(now(), emit), - ); + state.with_fsm_events(&mut fsm, &platform, &mut pairing_decision, |fsm| { + fsm.on_timer(now()); + }); } DriverEvent::CommandsClosed => { if in_flight.is_empty() && pairing_decision.is_none() { @@ -175,8 +167,8 @@ impl DriverState { fsm.bind_peer(peer); } RuntimeCommand::Connect => { - let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm, emit| { - fsm.connect_ik(now(), platform, emit) + let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm| { + fsm.connect_ik(now(), platform) }); } RuntimeCommand::ArmPairing { token } => { @@ -186,13 +178,13 @@ impl DriverState { fsm.disarm_pairing(); } RuntimeCommand::StartPairing { token } => { - let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm, emit| { - fsm.connect_xx(now(), token, platform, emit) + let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm| { + fsm.connect_xx(now(), token, platform) }); } RuntimeCommand::Incoming(bytes) => { - let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm, emit| { - fsm.receive(now(), bytes, platform, emit) + let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm| { + fsm.receive(now(), bytes, platform) }); } RuntimeCommand::OpenStream { @@ -287,26 +279,13 @@ impl DriverState { fsm: &mut QlFsm, platform: &'a P, pairing_decision: &mut Option>, - run: impl FnOnce(&mut QlFsm, &mut dyn FnMut(QlFsmEvent)) -> T, + run: impl FnOnce(&mut QlFsm) -> T, ) -> T { - let output = { - let pending = &mut self.pending_fsm_events; - let mut emit = |event| pending.push_back(event); - run(fsm, &mut emit) - }; - self.process_pending_fsm_events(fsm, platform, pairing_decision); - output - } - - fn process_pending_fsm_events<'a, P: QlPlatform + 'a>( - &mut self, - fsm: &mut QlFsm, - platform: &'a P, - pairing_decision: &mut Option>, - ) { - while let Some(event) = self.pending_fsm_events.pop_front() { + let output = run(fsm); + while let Some(event) = fsm.poll_event() { self.process_fsm_event(fsm, platform, pairing_decision, event); } + output } fn process_fsm_event<'a, P: QlPlatform + 'a>( diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 0279d66f..8e07368f 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -1,7 +1,6 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::HashMap; use bytes::Bytes; -use ql_fsm::QlFsmEvent; use ql_wire::{CloseTarget, StreamId}; use crate::{ @@ -14,7 +13,6 @@ pub struct DriverState { pub streams: HashMap, pub runtime_tx: async_channel::WeakSender, pub max_concurrent_message_writes: usize, - pub pending_fsm_events: VecDeque, } pub struct DriverStreamIo { diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index df290a48..93d19bd0 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -123,7 +123,6 @@ fn new_driver_state() -> (DriverState, QlFsm) { streams: HashMap::new(), runtime_tx: runtime_tx.downgrade(), max_concurrent_message_writes: 1, - pending_fsm_events: VecDeque::new(), }, QlFsm::new(ql_fsm::QlFsmConfig::default(), new_identity(7), now()), ) From 5730aacbb69e98192043a7bd0d8ea22301ca4174 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 17:27:32 -0400 Subject: [PATCH 163/304] ql-runtime: use entry api to avoid double lookup --- ql-runtime/src/driver/mod.rs | 39 ++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 92ce9e79..9a9c044d 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -3,7 +3,10 @@ mod state; mod test; use std::{ - collections::{hash_map::Entry, HashMap}, + collections::{ + hash_map::{Entry, OccupiedEntry}, + HashMap, + }, future::Future, pin::{pin, Pin}, task::Poll, @@ -244,18 +247,19 @@ impl DriverState { target, code, } => { - if let Some(stream) = self.streams.get_mut(&stream_id) { + if let Entry::Occupied(mut entry) = self.streams.entry(stream_id) { + let stream = entry.get_mut(); if target == CloseTarget::Both || target == stream.inbound_target() { stream.inbound_close(); } if target == CloseTarget::Both || target == stream.outbound_target() { stream.outbound_close(); } + Self::try_reap_stream(entry); } if let Ok(mut stream) = fsm.stream(stream_id) { stream.close(target, code); } - self.try_reap_stream(stream_id); } } } @@ -424,7 +428,9 @@ impl DriverState { } if peer_closed { stream_ops.close(target, StreamCloseCode(0)); - self.try_reap_stream(stream_id); + if let Entry::Occupied(entry) = self.streams.entry(stream_id) { + Self::try_reap_stream(entry); + } } drop(stream_ops); @@ -446,21 +452,23 @@ impl DriverState { } } - let Some(stream) = self.streams.get_mut(&stream_id) else { + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { return; }; + let stream = entry.get_mut(); if !stream.inbound_finish_pending() { return; } stream.inbound_finish(); - self.try_reap_stream(stream_id); + Self::try_reap_stream(entry); } fn handle_closed_stream(&mut self, frame: &ql_wire::StreamClose) { - let Some(stream) = self.streams.get_mut(&frame.stream_id) else { + let Entry::Occupied(mut entry) = self.streams.entry(frame.stream_id) else { return; }; + let stream = entry.get_mut(); if frame.target == CloseTarget::Both || frame.target == stream.inbound_target() { stream.inbound_fail(QlStreamError::StreamClosed { code: frame.code }); @@ -468,15 +476,16 @@ impl DriverState { if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { stream.outbound_fail(QlStreamError::StreamClosed { code: frame.code }); } - self.try_reap_stream(frame.stream_id); + Self::try_reap_stream(entry); } fn handle_writable_closed(&mut self, frame: &ql_wire::StreamClose) { - let Some(stream) = self.streams.get_mut(&frame.stream_id) else { + let Entry::Occupied(mut entry) = self.streams.entry(frame.stream_id) else { return; }; + let stream = entry.get_mut(); stream.outbound_fail(QlStreamError::StreamClosed { code: frame.code }); - self.try_reap_stream(frame.stream_id); + Self::try_reap_stream(entry); } fn fail_all_streams(&mut self) { @@ -563,13 +572,9 @@ impl DriverState { } } - fn try_reap_stream(&mut self, stream_id: StreamId) { - let should_reap = self - .streams - .get(&stream_id) - .is_some_and(DriverStreamIo::is_closed); - if should_reap { - self.streams.remove(&stream_id); + fn try_reap_stream(entry: OccupiedEntry<'_, StreamId, DriverStreamIo>) { + if entry.get().is_closed() { + entry.remove(); } } } From 599d4589fb8d99f2e98b5a6883797758311b0e9e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 17:48:21 -0400 Subject: [PATCH 164/304] ql-fsm: remove pairing approval --- ql-fsm/src/implementation/handshake/ik.rs | 3 +- ql-fsm/src/implementation/handshake/kk.rs | 3 +- ql-fsm/src/implementation/handshake/mod.rs | 25 +------- ql-fsm/src/implementation/handshake/xx.rs | 67 ++++------------------ ql-fsm/src/lib.rs | 25 +------- ql-fsm/src/state.rs | 13 +---- ql-fsm/src/tests/handshake.rs | 51 +--------------- ql-fsm/src/tests/mod.rs | 16 ------ ql-fsm/src/tests/proptest.rs | 2 +- 9 files changed, 20 insertions(+), 185 deletions(-) diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index ee9335e8..ac8b62e1 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -104,8 +104,7 @@ pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { | LinkState::Connected(_) | LinkState::KkInitiator(_) | LinkState::XxInitiator(_) - | LinkState::XxResponder(_) - | LinkState::XxResponderPending(_) => false, + | LinkState::XxResponder(_) => false, LinkState::IkInitiator(state) => { if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { return false; diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 8f23d337..38bafd5e 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -102,8 +102,7 @@ pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { LinkState::Idle | LinkState::Connected(_) | LinkState::XxInitiator(_) - | LinkState::XxResponder(_) - | LinkState::XxResponderPending(_) => false, + | LinkState::XxResponder(_) => false, LinkState::IkInitiator(_) => true, LinkState::KkInitiator(state) => { if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index e0f22224..16d0856d 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -3,8 +3,7 @@ mod kk; mod xx; use ql_wire::{ - self as wire, EphemeralPublicKey, HandshakeMeta, PairingToken, PeerBundle, QlCrypto, - QlHandshakeRecord, + self as wire, EphemeralPublicKey, HandshakeMeta, PairingToken, QlCrypto, QlHandshakeRecord, }; use super::emit_peer_status; @@ -52,28 +51,6 @@ pub fn enqueue_handshake(fsm: &mut QlFsm, record: QlHandshakeRecord) { fsm.state.handshake = Some(record); } -pub fn pending_xx_pairing(fsm: &QlFsm) -> Option<(PairingToken, &PeerBundle)> { - match &fsm.state.link { - crate::state::LinkState::XxResponderPending(state) => state - .handshake - .remote_bundle() - .map(|peer| (state.handshake.pairing_token(), peer)), - _ => None, - } -} - -pub fn handle_accept_pairing( - fsm: &mut QlFsm, - token: PairingToken, - crypto: &impl QlCrypto, -) -> Result<(), QlFsmError> { - xx::accept_pairing(fsm, crypto, token) -} - -pub fn handle_reject_pairing(fsm: &mut QlFsm, token: PairingToken) -> Result<(), QlFsmError> { - xx::reject_pairing(fsm, token) -} - pub fn handle_disarm_pairing(fsm: &mut QlFsm) { xx::disarm_pairing(fsm); } diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/implementation/handshake/xx.rs index 3171f743..f346dae8 100644 --- a/ql-fsm/src/implementation/handshake/xx.rs +++ b/ql-fsm/src/implementation/handshake/xx.rs @@ -5,10 +5,8 @@ use super::{ reset_connected_session_if_needed, }; use crate::{ - state::{ - LinkState, SessionTransport, XxInitiatorState, XxResponderPendingState, XxResponderState, - }, - QlFsm, QlFsmError, QlFsmEvent, + state::{LinkState, SessionTransport, XxInitiatorState, XxResponderState}, + QlFsm, QlFsmError, }; pub fn start_initiator( @@ -112,18 +110,16 @@ pub fn handle_xx3( state .handshake .read_3(crypto, fsm.state.now.unix_secs, message)?; - let deadline = state.deadline; let handshake_meta = state.handshake_meta; - let LinkState::XxResponder(state) = fsm.state.link.take() else { + let LinkState::XxResponder(mut state) = fsm.state.link.take() else { unreachable!("active XX responder was checked above"); }; - fsm.state.link = LinkState::XxResponderPending(XxResponderPendingState { - handshake: state.handshake, - handshake_meta, - deadline, - }); - fsm.pending_events.push_back(QlFsmEvent::PairingPending); - Ok(()) + let outbound = state.handshake.write_4(crypto, handshake_meta)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx4(outbound)); + let (transport, remote_bundle) = + SessionTransport::from_finalized(state.handshake.finalize(crypto)?); + finish_handshake(fsm, transport, remote_bundle) } pub fn handle_xx4( @@ -153,49 +149,8 @@ pub fn handle_xx4( finish_handshake(fsm, transport, remote_bundle) } -pub fn accept_pairing( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - token: PairingToken, -) -> Result<(), QlFsmError> { - { - let LinkState::XxResponderPending(state) = &mut fsm.state.link else { - return Err(QlFsmError::InvalidState); - }; - if state.handshake.pairing_token() != token { - return Err(QlFsmError::InvalidState); - } - let outbound = state.handshake.write_4(crypto, state.handshake_meta)?; - fsm.state.handshake = None; - enqueue_handshake(fsm, QlHandshakeRecord::Xx4(outbound)); - } - - let LinkState::XxResponderPending(state) = fsm.state.link.take() else { - unreachable!("pending XX responder was checked above"); - }; - let (transport, remote_bundle) = - SessionTransport::from_finalized(state.handshake.finalize(crypto)?); - finish_handshake(fsm, transport, remote_bundle) -} - -pub fn reject_pairing(fsm: &mut QlFsm, token: PairingToken) -> Result<(), QlFsmError> { - let LinkState::XxResponderPending(state) = &fsm.state.link else { - return Err(QlFsmError::InvalidState); - }; - if state.handshake.pairing_token() != token { - return Err(QlFsmError::InvalidState); - } - - fsm.state.link = LinkState::Idle; - fsm.state.handshake = None; - Ok(()) -} - pub fn disarm_pairing(fsm: &mut QlFsm) { - if matches!( - fsm.state.link, - LinkState::XxResponder(_) | LinkState::XxResponderPending(_) - ) { + if matches!(fsm.state.link, LinkState::XxResponder(_)) { fsm.state.link = LinkState::Idle; fsm.state.handshake = None; } @@ -205,7 +160,7 @@ pub fn should_ignore_inbound(fsm: &QlFsm, message: &Xx1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, LinkState::IkInitiator(_) | LinkState::KkInitiator(_) => true, - LinkState::XxResponder(_) | LinkState::XxResponderPending(_) => true, + LinkState::XxResponder(_) => true, LinkState::XxInitiator(state) => { if state.handshake.pairing_token() != message.header.pairing_token { return false; diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index fce0dd4f..0d0cc884 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -12,7 +12,7 @@ //! outputs from `QlFsm` are //! - outbound session and handshake records from `take_next_write` //! - queued `QlFsmEvent`s returned by `poll_event` after `connect_ik`, `connect_kk`, -//! `connect_xx`, `accept_pairing`, `receive`, and `on_timer` +//! `connect_xx`, `receive`, and `on_timer` //! //! call `next_deadline` after handling current inputs and any queued outputs //! use it to decide how long the outer loop can wait before `on_timer` must run @@ -69,8 +69,6 @@ pub enum PeerStatus { pub enum QlFsmEvent { /// a peer was learned during handshake completion NewPeer, - /// an inbound xx pairing is waiting for an accept or reject decision - PairingPending, /// the peer changed connection state PeerStatusChanged(PeerStatus), /// a stream was opened @@ -200,27 +198,6 @@ impl QlFsm { implementation::handle_connect_xx(self, token, crypto) } - /// returns the pending inbound xx candidate token and peer, if any - pub fn pending_xx_pairing(&self) -> Option<(PairingToken, &PeerBundle)> { - implementation::pending_xx_pairing(self) - } - - /// accepts a pending inbound xx pairing for the matching token - pub fn accept_pairing( - &mut self, - now: FsmTime, - token: PairingToken, - crypto: &impl QlCrypto, - ) -> Result<(), QlFsmError> { - self.state.now = now; - implementation::handle_accept_pairing(self, token, crypto) - } - - /// rejects a pending inbound xx pairing for the matching token - pub fn reject_pairing(&mut self, token: PairingToken) -> Result<(), QlFsmError> { - implementation::handle_reject_pairing(self, token) - } - /// starts or replaces an IK handshake with the currently bound peer pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 71d9e6da..79a8c5ee 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -47,7 +47,6 @@ pub enum LinkState { KkInitiator(KkInitiatorState), XxInitiator(XxInitiatorState), XxResponder(XxResponderState), - XxResponderPending(XxResponderPendingState), Connected(ConnectedState), } @@ -87,13 +86,6 @@ pub struct XxResponderState { pub deadline: Instant, } -#[derive(Debug, Clone)] -pub struct XxResponderPendingState { - pub handshake: XxHandshake, - pub handshake_meta: HandshakeMeta, - pub deadline: Instant, -} - impl LinkState { pub fn take(&mut self) -> Self { std::mem::replace(self, Self::Idle) @@ -101,9 +93,7 @@ impl LinkState { pub fn status(&self) -> PeerStatus { match self { - Self::Idle | Self::XxResponder(_) | Self::XxResponderPending(_) => { - PeerStatus::Disconnected - } + Self::Idle | Self::XxResponder(_) => PeerStatus::Disconnected, Self::IkInitiator(_) | Self::KkInitiator(_) | Self::XxInitiator(_) => { PeerStatus::Initiator } @@ -139,7 +129,6 @@ impl LinkState { Self::KkInitiator(state) => Some(state.deadline), Self::XxInitiator(state) => Some(state.deadline), Self::XxResponder(state) => Some(state.deadline), - Self::XxResponderPending(state) => Some(state.deadline), } } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 204d41d4..19bfe7c4 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -28,7 +28,7 @@ fn kk_connect_round_trip_establishes_transport() { } #[test] -fn xx_connect_round_trip_establishes_transport_after_accept() { +fn xx_connect_round_trip_establishes_transport_when_armed() { let mut harness = Harness::paired(QlFsmConfig::default(), false, false); let token = pairing_token(1); @@ -42,14 +42,6 @@ fn xx_connect_round_trip_establishes_transport_after_accept() { let xx3 = harness.next_outbound_a().unwrap(); harness.deliver_to_b(xx3); - assert_eq!(harness.take_event_b(), Some(QlFsmEvent::PairingPending)); - assert_eq!( - harness.b.fsm.pending_xx_pairing(), - Some((token, &harness.a.fsm.identity.bundle())) - ); - assert!(harness.next_outbound_b().is_none()); - - harness.accept_pairing_b(token).unwrap(); let xx4 = harness.next_outbound_b().unwrap(); harness.deliver_to_a(xx4); @@ -140,28 +132,6 @@ fn inbound_xx1_ignored_when_pairing_token_not_armed() { assert!(harness.next_outbound_b().is_none()); } -#[test] -fn reject_pairing_drops_pending_xx_candidate() { - let mut harness = Harness::paired(QlFsmConfig::default(), false, false); - let token = pairing_token(4); - - harness.b.fsm.arm_pairing(token); - harness.connect_xx_a(token).unwrap(); - let xx1 = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(xx1); - let xx2 = harness.next_outbound_b().unwrap(); - harness.deliver_to_a(xx2); - let xx3 = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(xx3); - - assert_eq!(harness.take_event_b(), Some(QlFsmEvent::PairingPending)); - harness.reject_pairing_b(token).unwrap(); - - assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); - assert!(harness.next_outbound_b().is_none()); - assert!(harness.b.fsm.pending_xx_pairing().is_none()); -} - #[test] fn disarm_pairing_rejects_inflight_inbound_xx_responder() { let mut harness = Harness::paired(QlFsmConfig::default(), false, false); @@ -174,13 +144,11 @@ fn disarm_pairing_rejects_inflight_inbound_xx_responder() { let xx2 = harness.next_outbound_b().unwrap(); harness.deliver_to_a(xx2); let xx3 = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(xx3); - - assert_eq!(harness.take_event_b(), Some(QlFsmEvent::PairingPending)); harness.b.fsm.disarm_pairing(); + harness.deliver_to_b(xx3); assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); - assert!(harness.b.fsm.pending_xx_pairing().is_none()); + assert!(harness.next_outbound_b().is_none()); } #[test] @@ -201,19 +169,6 @@ fn simultaneous_xx_connect_converges() { harness.deliver_to_a(record); } } - - let event_a = harness.take_event_a(); - let event_b = harness.take_event_b(); - assert!( - matches!(event_a, Some(QlFsmEvent::PairingPending)) - || matches!(event_b, Some(QlFsmEvent::PairingPending)) - ); - if matches!(event_a, Some(QlFsmEvent::PairingPending)) { - harness.accept_pairing_a(token).unwrap(); - } - if matches!(event_b, Some(QlFsmEvent::PairingPending)) { - harness.accept_pairing_b(token).unwrap(); - } harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 32d96a52..887e191e 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -312,22 +312,6 @@ impl Harness { fsm.connect_xx(time, token, crypto) } - fn accept_pairing_a(&mut self, token: PairingToken) -> Result<(), QlFsmError> { - let time = self.time(); - let Node { fsm, crypto } = &mut self.a; - fsm.accept_pairing(time, token, crypto) - } - - fn accept_pairing_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { - let time = self.time(); - let Node { fsm, crypto } = &mut self.b; - fsm.accept_pairing(time, token, crypto) - } - - fn reject_pairing_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { - self.b.fsm.reject_pairing(token) - } - fn deliver_to_a(&mut self, record: Vec) { let time = self.time(); let Node { fsm, crypto } = &mut self.a; diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 79bf9364..e0206e4b 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -455,7 +455,7 @@ impl Runner { fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { for event in events { match event { - QlFsmEvent::NewPeer | QlFsmEvent::PairingPending => {} + QlFsmEvent::NewPeer => {} QlFsmEvent::PeerStatusChanged(status) => { self.events_mut(side).note_peer_status(status); } From 0d835c2d7b8de02d2fd7a1d7ce5fb48a440b203b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 17:50:45 -0400 Subject: [PATCH 165/304] ql-runtime: remove pair approval --- ql-runtime/src/driver/mod.rs | 105 +++++------------------------- ql-runtime/src/driver/test.rs | 15 +---- ql-runtime/src/platform.rs | 10 +-- ql-runtime/src/tests/handshake.rs | 23 ++----- ql-runtime/src/tests/mod.rs | 74 ++------------------- 5 files changed, 31 insertions(+), 196 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 9a9c044d..ca86ab72 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -15,14 +15,14 @@ use std::{ use futures_lite::future::poll_fn; use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, SessionWriteId}; -use ql_wire::{CloseTarget, PairingToken, StreamCloseCode, StreamId}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; use crate::{ chunk_slot, command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, - platform::{PlatformFuture, QlPlatform, QlTimer}, + platform::{QlPlatform, QlTimer}, QlStreamError, Runtime, RuntimeHandle, }; @@ -49,49 +49,29 @@ impl Runtime

{ }; let mut in_flight = Vec::new(); - let mut pairing_decision = None; let mut timer = platform.timer(); let recv_future = rx.recv(); let mut recv_future = pin!(recv_future); loop { state.fill_write_slots(&mut fsm, &platform, &mut in_flight); - state.sync_pairing_decision_state(&fsm, &mut pairing_decision); timer.set_deadline(fsm.next_deadline()); - match next_driver_event( - recv_future.as_mut(), - &mut timer, - &mut in_flight, - &mut pairing_decision, - ) - .await - { + match next_driver_event(recv_future.as_mut(), &mut timer, &mut in_flight).await { DriverEvent::Command(command) => { - state.drive_command(&mut fsm, command, &platform, &mut pairing_decision); + state.drive_command(&mut fsm, command, &platform); } DriverEvent::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); DriverState::drive_write_completed(&mut fsm, write.session_write_id, success); } - DriverEvent::PairingDecision { token, accept } => { - pairing_decision = None; - let _ = - state.with_fsm_events(&mut fsm, &platform, &mut pairing_decision, |fsm| { - if accept { - fsm.accept_pairing(now(), token, &platform) - } else { - fsm.reject_pairing(token) - } - }); - } DriverEvent::TimerExpired => { - state.with_fsm_events(&mut fsm, &platform, &mut pairing_decision, |fsm| { + state.with_fsm_events(&mut fsm, &platform, |fsm| { fsm.on_timer(now()); }); } DriverEvent::CommandsClosed => { - if in_flight.is_empty() && pairing_decision.is_none() { + if in_flight.is_empty() { break; } } @@ -105,15 +85,9 @@ struct InFlightWrite { future: F, } -struct InFlightPairingDecision<'a> { - token: PairingToken, - future: PlatformFuture<'a, bool>, -} - enum DriverEvent { Command(RuntimeCommand), WriteCompleted { index: usize, success: bool }, - PairingDecision { token: PairingToken, accept: bool }, TimerExpired, CommandsClosed, } @@ -123,7 +97,6 @@ async fn next_driver_event( mut recv_future: Pin<&mut async_channel::Recv<'_, RuntimeCommand>>, timer: &mut T, in_flight: &mut [InFlightWrite], - pairing_decision: &mut Option>, ) -> DriverEvent where T: QlTimer, @@ -136,15 +109,6 @@ where } } - if let Some(decision) = pairing_decision.as_mut() { - if let Poll::Ready(accept) = Pin::new(&mut decision.future).poll(cx) { - return Poll::Ready(DriverEvent::PairingDecision { - token: decision.token, - accept, - }); - } - } - if timer.poll_wait(cx) == Poll::Ready(()) { return Poll::Ready(DriverEvent::TimerExpired); } @@ -158,21 +122,18 @@ where } impl DriverState { - fn drive_command<'a, P: QlPlatform + 'a>( + fn drive_command( &mut self, fsm: &mut QlFsm, command: RuntimeCommand, - platform: &'a P, - pairing_decision: &mut Option>, + platform: &P, ) { match command { RuntimeCommand::BindPeer { peer } => { fsm.bind_peer(peer); } RuntimeCommand::Connect => { - let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm| { - fsm.connect_ik(now(), platform) - }); + let _ = self.with_fsm_events(fsm, platform, |fsm| fsm.connect_ik(now(), platform)); } RuntimeCommand::ArmPairing { token } => { fsm.arm_pairing(token); @@ -181,14 +142,12 @@ impl DriverState { fsm.disarm_pairing(); } RuntimeCommand::StartPairing { token } => { - let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm| { - fsm.connect_xx(now(), token, platform) - }); + let _ = self + .with_fsm_events(fsm, platform, |fsm| fsm.connect_xx(now(), token, platform)); } RuntimeCommand::Incoming(bytes) => { - let _ = self.with_fsm_events(fsm, platform, pairing_decision, |fsm| { - fsm.receive(now(), bytes, platform) - }); + let _ = + self.with_fsm_events(fsm, platform, |fsm| fsm.receive(now(), bytes, platform)); } RuntimeCommand::OpenStream { request_reader, @@ -278,25 +237,23 @@ impl DriverState { } } - fn with_fsm_events<'a, P: QlPlatform + 'a, T>( + fn with_fsm_events( &mut self, fsm: &mut QlFsm, - platform: &'a P, - pairing_decision: &mut Option>, + platform: &P, run: impl FnOnce(&mut QlFsm) -> T, ) -> T { let output = run(fsm); while let Some(event) = fsm.poll_event() { - self.process_fsm_event(fsm, platform, pairing_decision, event); + self.process_fsm_event(fsm, platform, event); } output } - fn process_fsm_event<'a, P: QlPlatform + 'a>( + fn process_fsm_event( &mut self, fsm: &mut QlFsm, - platform: &'a P, - pairing_decision: &mut Option>, + platform: &P, event: QlFsmEvent, ) { match event { @@ -305,17 +262,6 @@ impl DriverState { platform.persist_peer(peer); } } - QlFsmEvent::PairingPending => { - if let Some((token, peer)) = fsm.pending_xx_pairing() { - let peer = peer.clone(); - *pairing_decision = Some(InFlightPairingDecision { - token, - future: Box::pin(async move { - platform.handle_pairing_request(token, peer).await - }), - }); - } - } QlFsmEvent::PeerStatusChanged(status) => { if let Some(peer) = fsm.peer().map(|peer| peer.xid) { platform.handle_peer_status(peer, status); @@ -495,21 +441,6 @@ impl DriverState { self.streams.clear(); } - fn sync_pairing_decision_state( - &self, - fsm: &QlFsm, - pairing_decision: &mut Option>, - ) { - if let Some(decision) = pairing_decision.as_ref() { - let is_current = fsm - .pending_xx_pairing() - .is_some_and(|(token, _)| token == decision.token); - if !is_current { - *pairing_decision = None; - } - } - } - fn fill_write_slots<'a, P: QlPlatform + 'a>( &self, fsm: &mut QlFsm, diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 93d19bd0..05333db5 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -1,8 +1,8 @@ use std::task::{Context, Poll}; use ql_wire::{ - MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PairingToken, PeerBundle, - QlAead, QlHash, QlKem, QlRandom, SessionKey, StreamClose, XID, + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PeerBundle, QlAead, QlHash, + QlKem, QlRandom, SessionKey, StreamClose, XID, }; use super::*; @@ -87,7 +87,6 @@ impl crate::platform::QlTimer for NoopTimer { impl QlPlatform for NoopPlatform { type Timer = NoopTimer; type WriteMessageFut<'a> = std::future::Ready; - type PairingDecisionFut<'a> = std::future::Ready; fn write_message(&self, _message: Vec) -> Self::WriteMessageFut<'_> { std::future::ready(true) @@ -105,14 +104,6 @@ impl QlPlatform for NoopPlatform { fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} - fn handle_pairing_request( - &self, - _token: PairingToken, - _peer: PeerBundle, - ) -> Self::PairingDecisionFut<'_> { - std::future::ready(false) - } - fn handle_inbound(&self, _event: QlStream) {} } @@ -203,7 +194,6 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { let stream_id = StreamId(1u32.into()); let (request_reader, _request_writer) = chunk_slot::new(); let (request_terminal_tx, _request_terminal_rx) = oneshot::channel(); - let mut pairing_decision = None; state.streams.insert( stream_id, @@ -222,7 +212,6 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { code: StreamCloseCode(0), }, &NoopPlatform, - &mut pairing_decision, ); assert!(!state.streams.contains_key(&stream_id)); diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 713efeb9..411627c0 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -6,7 +6,7 @@ use std::{ }; use ql_fsm::PeerStatus; -use ql_wire::{PairingToken, PeerBundle, QlCrypto, XID}; +use ql_wire::{PeerBundle, QlCrypto, XID}; use crate::QlStream; @@ -20,9 +20,6 @@ pub trait QlTimer { pub trait QlPlatform: QlCrypto { type Timer: QlTimer; type WriteMessageFut<'a>: Future + Unpin + 'a - where - Self: 'a; - type PairingDecisionFut<'a>: Future + Unpin + 'a where Self: 'a; @@ -33,10 +30,5 @@ pub trait QlPlatform: QlCrypto { fn persist_peer(&self, peer: PeerBundle); fn handle_peer_status(&self, peer: XID, status: PeerStatus); - fn handle_pairing_request( - &self, - token: PairingToken, - peer: PeerBundle, - ) -> Self::PairingDecisionFut<'_>; fn handle_inbound(&self, event: QlStream); } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index c4916d94..9483b631 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -141,11 +141,11 @@ async fn rejected_session_write_is_reissued() { } #[tokio::test(flavor = "current_thread")] -async fn start_pairing_round_trip_uses_platform_decision_to_connect() { +async fn start_pairing_round_trip_connects_when_armed() { run_local_test(async { let config = default_runtime_config(); let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, pairing_b) = TestPlatform::new_with_pairing(2, true); + let (platform_b, outbound_b, status_b) = TestPlatform::new(2); let identity_a = new_identity(11); let identity_b = new_identity(73); let token = pairing_token(7); @@ -162,10 +162,6 @@ async fn start_pairing_round_trip_uses_platform_decision_to_connect() { handle_b.arm_pairing(token); handle_a.start_pairing(token); - let request = await_pairing_request(&pairing_b).await; - assert_eq!(request.token, token); - assert_eq!(request.peer, identity_a.bundle()); - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; }) @@ -173,18 +169,16 @@ async fn start_pairing_round_trip_uses_platform_decision_to_connect() { } #[tokio::test(flavor = "current_thread")] -async fn start_pairing_rejects_when_platform_returns_false() { +async fn start_pairing_does_not_connect_when_unarmed() { run_local_test(async { let config = default_runtime_config(); let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, _status_b, pairing_b) = - TestPlatform::new_with_pairing(2, false); + let (platform_b, outbound_b, _status_b) = TestPlatform::new(2); let identity_a = new_identity(11); - let identity_b = new_identity(73); let token = pairing_token(8); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + let (runtime_b, handle_b) = new_runtime(new_identity(73), platform_b, config); tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); @@ -192,16 +186,11 @@ async fn start_pairing_rejects_when_platform_returns_false() { spawn_forwarder(outbound_a, handle_b.clone()); spawn_forwarder(outbound_b, handle_a.clone()); - handle_b.arm_pairing(token); handle_a.start_pairing(token); - let request = await_pairing_request(&pairing_b).await; - assert_eq!(request.token, token); - assert_eq!(request.peer, identity_a.bundle()); - assert_no_status_for( &status_a, - identity_b.xid, + XID([73; XID::SIZE]), PeerStatus::Connected, Duration::from_millis(150), ) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index f0883fc7..673831a6 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -39,12 +39,6 @@ struct StatusEvent { status: PeerStatus, } -#[derive(Debug, Clone, PartialEq, Eq)] -struct PairingRequest { - token: PairingToken, - peer: PeerBundle, -} - #[derive(Debug, Clone)] struct WriteStats { active: Arc, @@ -166,8 +160,6 @@ struct TestPlatform { outbound: Sender>, status: Sender, inbound: Option>, - pairing_requests: Option>, - pairing_accept: bool, nonce_seed: u8, nonce_counter: AtomicU8, encrypted_write_counter: AtomicUsize, @@ -178,7 +170,7 @@ struct TestPlatform { impl TestPlatform { fn new(seed: u8) -> (Self, Receiver>, Receiver) { - Self::new_inner(seed, None, None, false, None, Duration::ZERO, None) + Self::new_inner(seed, None, None, Duration::ZERO, None) } fn new_with_inbound( @@ -190,40 +182,11 @@ impl TestPlatform { Receiver, ) { let (inbound_tx, inbound_rx) = async_channel::unbounded(); - let (platform, outbound_rx, status_rx) = Self::new_inner( - seed, - Some(inbound_tx), - None, - false, - None, - Duration::ZERO, - None, - ); + let (platform, outbound_rx, status_rx) = + Self::new_inner(seed, Some(inbound_tx), None, Duration::ZERO, None); (platform, outbound_rx, status_rx, inbound_rx) } - fn new_with_pairing( - seed: u8, - pairing_accept: bool, - ) -> ( - Self, - Receiver>, - Receiver, - Receiver, - ) { - let (pairing_tx, pairing_rx) = async_channel::unbounded(); - let (platform, outbound_rx, status_rx) = Self::new_inner( - seed, - None, - Some(pairing_tx), - pairing_accept, - None, - Duration::ZERO, - None, - ); - (platform, outbound_rx, status_rx, pairing_rx) - } - fn new_with_session_write_failure( seed: u8, fail_encrypted_write_at: usize, @@ -231,8 +194,6 @@ impl TestPlatform { Self::new_inner( seed, None, - None, - false, Some(fail_encrypted_write_at), Duration::ZERO, None, @@ -244,14 +205,12 @@ impl TestPlatform { delay: Duration, write_stats: WriteStats, ) -> (Self, Receiver>, Receiver) { - Self::new_inner(seed, None, None, false, None, delay, Some(write_stats)) + Self::new_inner(seed, None, None, delay, Some(write_stats)) } fn new_inner( seed: u8, inbound: Option>, - pairing_requests: Option>, - pairing_accept: bool, fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, @@ -263,8 +222,6 @@ impl TestPlatform { outbound, status, inbound, - pairing_requests, - pairing_accept, nonce_seed: seed, nonce_counter: AtomicU8::new(0), encrypted_write_counter: AtomicUsize::new(0), @@ -389,7 +346,6 @@ impl QlKem for TestPlatform { impl crate::platform::QlPlatform for TestPlatform { type Timer = TokioTimer; type WriteMessageFut<'a> = PlatformFuture<'a, bool>; - type PairingDecisionFut<'a> = PlatformFuture<'a, bool>; fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_> { let outbound = self.outbound.clone(); @@ -442,21 +398,6 @@ impl crate::platform::QlPlatform for TestPlatform { let _ = self.status.try_send(StatusEvent { peer, status }); } - fn handle_pairing_request( - &self, - token: PairingToken, - peer: PeerBundle, - ) -> Self::PairingDecisionFut<'_> { - let pairing_requests = self.pairing_requests.clone(); - let pairing_accept = self.pairing_accept; - Box::pin(async move { - if let Some(tx) = pairing_requests { - let _ = tx.send(PairingRequest { token, peer }).await; - } - pairing_accept - }) - } - fn handle_inbound(&self, event: QlStream) { if let Some(tx) = &self.inbound { let _ = tx.try_send(event); @@ -576,13 +517,6 @@ async fn assert_no_status_for( assert!(res.is_err(), "unexpected status event: {status:?}"); } -async fn await_pairing_request(receiver: &Receiver) -> PairingRequest { - tokio::time::timeout(Duration::from_secs(2), receiver.recv()) - .await - .unwrap() - .unwrap() -} - async fn read_all(mut stream: crate::ByteReader) -> Result, QlStreamError> { let mut data = Vec::new(); while let Some(chunk) = next_chunk(&mut stream).await? { From 7ee24fa002bea054b0011d550ba8aa16139fcf1a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 19:48:46 -0400 Subject: [PATCH 166/304] ql-wire: streamdata with header --- ql-wire/src/encrypted/mod.rs | 2 + ql-wire/src/encrypted/route_id.rs | 29 +++++++++ ql-wire/src/encrypted/stream_data.rs | 91 +++++++++++++++++++++++----- ql-wire/src/tests.rs | 3 + 4 files changed, 111 insertions(+), 14 deletions(-) create mode 100644 ql-wire/src/encrypted/route_id.rs diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index ec6d3bba..bb8b148f 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -6,6 +6,7 @@ use crate::{ mod ack; mod builder; mod close; +mod route_id; mod stream_close; mod stream_data; mod stream_id; @@ -14,6 +15,7 @@ mod stream_window; pub use ack::*; pub use builder::*; pub use close::*; +pub use route_id::*; pub use stream_close::*; pub use stream_data::*; pub use stream_id::*; diff --git a/ql-wire/src/encrypted/route_id.rs b/ql-wire/src/encrypted/route_id.rs new file mode 100644 index 00000000..2338e634 --- /dev/null +++ b/ql-wire/src/encrypted/route_id.rs @@ -0,0 +1,29 @@ +use crate::{ByteSlice, Reader, VarInt, WireDecode, WireEncode, WireError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RouteId(pub VarInt); + +impl RouteId { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } +} + +impl WireEncode for RouteId { + fn encoded_len(&self) -> usize { + self.0.size() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl WireDecode for RouteId { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 962ae600..a0ab5011 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -1,4 +1,4 @@ -use super::StreamId; +use super::{RouteId, StreamId}; use crate::{codec, ByteChunks, ByteSlice, VarInt, WireDecode, WireEncode, WireError}; /// carries bytes for a stream and may finish that sending direction. @@ -6,21 +6,35 @@ use crate::{codec, ByteChunks, ByteSlice, VarInt, WireDecode, WireEncode, WireEr pub struct StreamData { pub stream_id: StreamId, pub offset: VarInt, + pub header: Option, pub fin: bool, pub bytes: B, } impl StreamData { - /// Conservative constant overhead for callers that still budget with a fixed header size. - pub const MIN_WIRE_SIZE: usize = StreamId::MAX_ENCODED_LEN + VarInt::MAX_SIZE + size_of::(); + pub const MIN_WIRE_SIZE: usize = StreamId::MAX_ENCODED_LEN + + VarInt::MAX_SIZE + + size_of::() + + StreamHeader::MAX_WIRE_SIZE; } impl WireDecode for StreamData { fn decode(reader: &mut codec::Reader) -> Result { + let stream_id = reader.decode()?; + let offset: VarInt = reader.decode()?; + let flags = reader.decode::()?; + let fin = (flags & flag::FIN) != 0; + let has_header = (flags & flag::HEADER) != 0; + Ok(Self { - stream_id: reader.decode()?, - offset: reader.decode()?, - fin: reader.decode()?, + stream_id, + offset, + header: if has_header { + Some(reader.decode()?) + } else { + None + }, + fin, bytes: reader.take_rest(), }) } @@ -34,29 +48,78 @@ impl StreamData { StreamData { stream_id: self.stream_id, offset: self.offset, + header: self.header, fin: self.fin, bytes: self.bytes.to_vec(), } } } -impl StreamData { - pub fn header_len(&self) -> usize { - self.stream_id.encoded_len() + self.offset.encoded_len() + size_of::() - } -} - impl WireEncode for StreamData { fn encoded_len(&self) -> usize { - self.header_len() + self.bytes.len() + self.stream_id.encoded_len() + + self.offset.encoded_len() + + size_of::() + + self + .header + .as_ref() + .map_or(0, |header| header.encoded_len()) + + self.bytes.len() } fn encode(&self, out: &mut W) { + debug_assert!( + self.offset.into_inner() == 0 || self.header.is_none(), + "stream header is only valid at offset 0" + ); + self.stream_id.encode(out); self.offset.encode(out); - self.fin.encode(out); + let mut flags = 0; + if self.fin { + flags |= flag::FIN; + } + if self.header.is_some() { + flags |= flag::HEADER; + } + flags.encode(out); + if let Some(header) = &self.header { + header.encode(out); + } for chunk in self.bytes.chunks() { chunk.encode(out); } } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamHeader { + pub route_id: RouteId, +} + +impl StreamHeader { + pub const MAX_WIRE_SIZE: usize = RouteId::MAX_ENCODED_LEN; +} + +impl WireDecode for StreamHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + route_id: reader.decode()?, + }) + } +} + +impl WireEncode for StreamHeader { + fn encoded_len(&self) -> usize { + self.route_id.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.route_id.encode(out); + } +} + +mod flag { + pub const FIN: u8 = 0x01; + pub const HEADER: u8 = 0x02; +} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 57eca414..cacc77a8 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -827,6 +827,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { SessionFrame::StreamData(StreamData { stream_id: stream_id(9), offset: varint(1024), + header: None, bytes: b"hello".to_vec(), fin: true, }), @@ -894,6 +895,7 @@ fn session_varint_fields_expand_at_expected_boundaries() { let frame = StreamData { stream_id: stream_id(64), offset: varint(16_384), + header: None, fin: true, bytes: b"abc".to_vec(), }; @@ -1008,6 +1010,7 @@ fn protocol_record_size_breakdown() { &[SessionFrame::StreamData(StreamData { stream_id: stream_id(1), offset: varint(0), + header: None, fin: false, bytes: Vec::new(), })], From d60ec48163143ccccf7ece2a9007b575942353da Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 20:16:37 -0400 Subject: [PATCH 167/304] ql: streamdata with header --- ql-fsm/src/implementation/core.rs | 10 +- ql-fsm/src/lib.rs | 10 +- ql-fsm/src/session/mod.rs | 246 +++++++++++++++--------------- ql-fsm/src/session/state.rs | 5 +- ql-fsm/src/session/tests.rs | 101 ++++++++++-- ql-fsm/src/tests/proptest.rs | 10 +- ql-fsm/src/tests/session.rs | 27 +++- ql-runtime/src/command.rs | 3 +- ql-runtime/src/driver/mod.rs | 9 +- ql-runtime/src/handle/mod.rs | 7 +- ql-runtime/src/tests/handshake.rs | 7 +- ql-runtime/src/tests/heartbeat.rs | 2 +- ql-runtime/src/tests/mod.rs | 6 +- ql-runtime/src/tests/stream.rs | 14 +- ql-runtime/src/tests/unpair.rs | 6 +- 15 files changed, 284 insertions(+), 179 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index d17e4597..314d4506 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -1,7 +1,7 @@ use std::time::{Duration, Instant}; use bytes::Bytes; -use ql_wire::{self as wire, QlCrypto, SessionCloseCode, StreamId, WireDecode}; +use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; use crate::{ session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, QlFsmError, @@ -149,9 +149,9 @@ pub fn kill_session(fsm: &mut QlFsm, _code: SessionCloseCode) { fsm.state.link = crate::state::LinkState::Idle; } -pub fn open_stream(fsm: &mut QlFsm) -> Result, NoSessionError> { +pub fn open_stream(fsm: &mut QlFsm, route_id: RouteId) -> Result, NoSessionError> { let state = fsm.state.link.connected_mut_or_err()?; - state.session.open_stream() + state.session.open_stream(route_id) } pub fn stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { @@ -176,8 +176,8 @@ fn forward_session_event( pending_events: &mut std::collections::VecDeque, ) -> bool { match event { - SessionEvent::Opened(stream_id) => { - pending_events.push_back(QlFsmEvent::Opened(stream_id)); + SessionEvent::Opened { stream_id, route_id } => { + pending_events.push_back(QlFsmEvent::Opened { stream_id, route_id }); false } SessionEvent::Readable(stream_id) => { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 0d0cc884..21e9a8df 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -34,8 +34,8 @@ use std::{ pub use bytes::Bytes; pub use error::*; use ql_wire::{ - PairingToken, PeerBundle, QlCrypto, QlIdentity, SessionClose, SessionCloseCode, StreamClose, - StreamId, + PairingToken, PeerBundle, QlCrypto, QlIdentity, RouteId, SessionClose, SessionCloseCode, + StreamClose, StreamId, }; pub use session::{StreamOps, StreamReadIter, StreamWriter}; @@ -72,7 +72,7 @@ pub enum QlFsmEvent { /// the peer changed connection state PeerStatusChanged(PeerStatus), /// a stream was opened - Opened(StreamId), + Opened { stream_id: StreamId, route_id: RouteId }, /// a stream has bytes ready to read Readable(StreamId), /// a stream has room for more local writes @@ -273,8 +273,8 @@ impl QlFsm { } /// opens a new outgoing stream - pub fn open_stream(&mut self) -> Result, NoSessionError> { - implementation::open_stream(self) + pub fn open_stream(&mut self, route_id: RouteId) -> Result, NoSessionError> { + implementation::open_stream(self, route_id) } /// returns a facade for an open stream diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 86c475f1..26df092a 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -16,10 +16,11 @@ mod tests; use std::time::{Duration, Instant}; use bytes::Bytes; -use indexmap::{map::Entry, IndexMap}; +use indexmap::IndexMap; use ql_wire::{ - CloseTarget, RecordAck, RecordSeq, SessionClose, SessionCloseCode, SessionFrame, - SessionRecordBuilder, StreamClose, StreamData, StreamId, StreamWindow, VarInt, WireError, + CloseTarget, RecordAck, RecordSeq, RouteId, SessionClose, SessionCloseCode, SessionFrame, + SessionRecordBuilder, StreamClose, StreamData, StreamHeader, StreamId, StreamWindow, VarInt, + WireError, }; use self::{ @@ -65,7 +66,10 @@ impl Default for SessionFsmConfig { #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionEvent { - Opened(StreamId), + Opened { + stream_id: StreamId, + route_id: RouteId, + }, Readable(StreamId), Writable(StreamId), Finished(StreamId), @@ -107,7 +111,7 @@ impl SessionFsm { } } - pub fn open_stream(&mut self) -> Result, NoSessionError> { + pub fn open_stream(&mut self, route_id: RouteId) -> Result, NoSessionError> { self.ensure_session_open()?; let stream_id = self .config @@ -118,6 +122,7 @@ impl SessionFsm { stream_id, StreamState::new( StreamRole::Initiator, + Some(route_id), self.config.stream_receive_buffer_size, self.config.initial_peer_stream_receive_window, ), @@ -174,12 +179,7 @@ impl SessionFsm { for frame in frames { let Ok(frame) = frame else { - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - &mut emit, - ); + self.fail_session(SessionCloseCode::PROTOCOL, &mut emit); return; }; ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); @@ -188,12 +188,14 @@ impl SessionFsm { SessionFrame::Ack(ack) => self.process_record_ack(&ack, &mut emit), SessionFrame::StreamData(frame) => { if self.handle_stream_data(frame, &mut emit).is_err() { + self.fail_session(SessionCloseCode::PROTOCOL, &mut emit); return; } } SessionFrame::StreamWindow(frame) => self.handle_stream_window(&frame, &mut emit), SessionFrame::StreamClose(frame) => { if self.handle_stream_close(&frame, &mut emit).is_err() { + self.fail_session(SessionCloseCode::PROTOCOL, &mut emit); return; } } @@ -256,12 +258,7 @@ impl SessionFsm { if !self.config.peer_timeout.is_zero() && self.state.last_inbound_at + self.config.peer_timeout <= self.state.now { - self.fail_session( - SessionClose { - code: SessionCloseCode::TIMEOUT, - }, - &mut emit, - ); + self.fail_session(SessionCloseCode::TIMEOUT, &mut emit); return; } if self.state.session_state == SessionState::Open @@ -469,6 +466,11 @@ impl SessionFsm { let frame = StreamData { stream_id, offset, + header: if matches!(stream.role, StreamRole::Initiator) && candidate.offset == 0 { + stream.route_id.map(|route_id| StreamHeader { route_id }) + } else { + None + }, fin: candidate.fin, bytes: stream.tx.ranged_bytes(candidate), }; @@ -558,42 +560,39 @@ impl SessionFsm { frame: StreamData, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { - let stream_id = frame.stream_id; - let stream = match self.state.streams.entry(stream_id) { - Entry::Occupied(entry) => entry.into_mut(), - Entry::Vacant(entry) => { - match classify_missing_stream( - self.config.local_parity, - self.state.next_stream_ordinal, - stream_id, - &mut self.state.remote_stream_history, - ) { - MissingStreamAction::Create => {} - MissingStreamAction::Ignore => return Ok(()), - MissingStreamAction::FailProtocol => { - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - emit, - ); - return Err(()); - } - } + let StreamData { + stream_id, + offset, + header, + fin, + bytes, + } = frame; + let stream = match self.state.streams.get_mut(&stream_id) { + Some(stream) => stream, + None => match self.create_remote_stream(stream_id)? { + Some(stream) => stream, + None => return Ok(()), + }, + }; + + let frame_offset = offset.into_inner(); + let Some(frame_end) = frame_offset.checked_add(bytes.len() as u64) else { + return Err(()); + }; + let readable_before = stream.readable_bytes(); + let was_finished = matches!(stream.inbound_state, InboundState::Finished); - emit(SessionEvent::Opened(stream_id)); - entry.insert(StreamState::new( - StreamRole::Responder, - self.config.stream_receive_buffer_size, - self.config.initial_peer_stream_receive_window, - )) + let opened_route = match (stream.role, stream.route_id, header, frame_offset) { + (StreamRole::Responder, None, Some(header), 0) => { + stream.route_id = Some(header.route_id); + Some(header.route_id) } + (StreamRole::Initiator, _, Some(_), _) + | (StreamRole::Responder, None, Some(_), _) + | (StreamRole::Responder, None, None, 0) => return Err(()), + _ => None, }; - let frame_offset = frame.offset.into_inner(); - let frame_end = frame_offset - .checked_add(frame.bytes.len() as u64) - .ok_or(())?; match stream.inbound_state { InboundState::Open => {} InboundState::Discarding | InboundState::Closed(_) => return Ok(()), @@ -607,48 +606,49 @@ impl SessionFsm { // retransmitted data for an already-finished stream is fine as long as it stays // within the finalized byte range and any repeated FIN lands on that same offset. if (!frame.fin || frame_end == final_offset) && frame_end <= final_offset { + if let Some(route_id) = opened_route { + emit(SessionEvent::Opened { + stream_id, + route_id, + }); + if readable_before > 0 { + emit(SessionEvent::Readable(stream_id)); + } + emit(SessionEvent::Finished(stream_id)); + } return Ok(()); } - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - emit, - ); + return Err(()); } } - let was_readable = stream.readable_bytes() > 0; - let insert = stream.rx.insert(frame_offset, frame.fin, frame.bytes); - match insert { - Ok(outcome) => { - if !was_readable && outcome.newly_readable_bytes > 0 { - emit(SessionEvent::Readable(stream_id)); - } - if outcome.became_complete { - stream.inbound_state = InboundState::Finished; - emit(SessionEvent::Finished(stream_id)); - } - self.try_reap_stream(stream_id); - Ok(()) - } - Err( - StreamRxError::OutOfWindow - | StreamRxError::InconsistentFinalOffset - | StreamRxError::FinalOffsetBeforeBufferedData - | StreamRxError::BeyondFinalOffset - | StreamRxError::OffsetOverflow, - ) => { - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - emit, - ); - Err(()) - } + let outcome = stream.rx.insert(frame_offset, fin, bytes).map_err(|_| ())?; + + if outcome.became_complete { + stream.inbound_state = InboundState::Finished; + } + + if let Some(route_id) = opened_route { + emit(SessionEvent::Opened { + stream_id, + route_id, + }); + } + + if stream.route_id.is_some() && readable_before == 0 && stream.readable_bytes() > 0 { + emit(SessionEvent::Readable(stream_id)); + } + + if stream.route_id.is_some() + && !was_finished + && matches!(stream.inbound_state, InboundState::Finished) + { + emit(SessionEvent::Finished(stream_id)); } + + self.try_reap_stream(stream_id); + Ok(()) } fn handle_stream_window(&mut self, frame: &StreamWindow, emit: &mut impl FnMut(SessionEvent)) { @@ -671,42 +671,15 @@ impl SessionFsm { frame: &StreamClose, emit: &mut impl FnMut(SessionEvent), ) -> Result<(), ()> { - let mut created = false; - let stream = match self.state.streams.entry(frame.stream_id) { - Entry::Occupied(entry) => entry.into_mut(), - Entry::Vacant(entry) => { - match classify_missing_stream( - self.config.local_parity, - self.state.next_stream_ordinal, - frame.stream_id, - &mut self.state.remote_stream_history, - ) { - MissingStreamAction::Create => {} - MissingStreamAction::Ignore => return Ok(()), - MissingStreamAction::FailProtocol => { - self.fail_session( - SessionClose { - code: SessionCloseCode::PROTOCOL, - }, - emit, - ); - return Err(()); - } - } - - created = true; - entry.insert(StreamState::new( - StreamRole::Responder, - self.config.stream_receive_buffer_size, - self.config.initial_peer_stream_receive_window, - )) - } + let stream_id = frame.stream_id; + let stream = match self.state.streams.get_mut(&stream_id) { + Some(stream) => stream, + None => match self.create_remote_stream(stream_id)? { + Some(stream) => stream, + None => return Ok(()), + }, }; - if created { - emit(SessionEvent::Opened(frame.stream_id)); - } - if Self::target_affects_inbound(stream.role, frame.target) && !matches!( stream.inbound_state, @@ -836,7 +809,7 @@ impl SessionFsm { } } - fn fail_session(&mut self, close: SessionClose, emit: &mut impl FnMut(SessionEvent)) { + fn fail_session(&mut self, code: SessionCloseCode, emit: &mut impl FnMut(SessionEvent)) { if self.state.session_state == SessionState::Closed { return; } @@ -844,15 +817,46 @@ impl SessionFsm { self.state.session_state = SessionState::Closed; self.state.tracked_records.clear(); self.state.pending_control = Default::default(); - self.state.pending_control.close = Some(close.clone()); + self.state.pending_control.close = Some(SessionClose { code }); self.clear_streams(); - emit(SessionEvent::SessionClosed(close)); + emit(SessionEvent::SessionClosed(SessionClose { code })); } fn clear_streams(&mut self) { self.state.next_stream_index = 0; self.state.streams.clear(); } + + fn create_remote_stream( + &mut self, + stream_id: StreamId, + ) -> Result, ()> { + match classify_missing_stream( + self.config.local_parity, + self.state.next_stream_ordinal, + stream_id, + &mut self.state.remote_stream_history, + ) { + MissingStreamAction::Create => {} + MissingStreamAction::Ignore => return Ok(None), + MissingStreamAction::FailProtocol => { + return Err(()); + } + } + + let stream = self + .state + .streams + .entry(stream_id) + .insert_entry(StreamState::new( + StreamRole::Responder, + None, + self.config.stream_receive_buffer_size, + self.config.initial_peer_stream_receive_window, + )); + + Ok(Some(stream.into_mut())) + } } fn schedule_ack(ack_state: &mut AckState, due_at: Instant) { diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 52b98134..58063499 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -1,7 +1,7 @@ use std::time::Instant; use indexmap::IndexMap; -use ql_wire::{CloseTarget, RecordSeq, SessionClose, StreamClose, StreamId}; +use ql_wire::{CloseTarget, RecordSeq, RouteId, SessionClose, StreamClose, StreamId}; use super::{ received_records::ReceivedRecords, remote_stream_history::RemoteStreamHistory, @@ -34,6 +34,7 @@ pub enum SessionState { #[derive(Debug)] pub struct StreamState { pub role: StreamRole, + pub route_id: Option, pub rx: StreamRx, pub tx: StreamTx, pub pending_close: Option, @@ -47,12 +48,14 @@ pub struct StreamState { impl StreamState { pub fn new( role: StreamRole, + route_id: Option, receive_buffer_size: u32, initial_peer_stream_receive_window: u32, ) -> Self { let receive_buffer_size = receive_buffer_size as usize; Self { role, + route_id, tx: StreamTx::new(), pending_close: None, peer_max_offset: u64::from(initial_peer_stream_receive_window), diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index e1549f5a..8ef6e559 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -2,8 +2,9 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use ql_wire::{ - decode_session_frames, parse_session_frames, CloseTarget, RecordAck, RecordSeq, SessionFrame, - SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamId, VarInt, XID, + decode_session_frames, parse_session_frames, CloseTarget, RecordAck, RecordSeq, RouteId, + SessionFrame, SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamHeader, + StreamId, VarInt, XID, }; use super::{SessionEvent, SessionFsm, SessionFsmConfig}; @@ -21,8 +22,25 @@ fn offset(value: u64) -> VarInt { VarInt::from_u64(value).unwrap() } +fn route_id(value: u64) -> RouteId { + RouteId(VarInt::from_u64(value).unwrap()) +} + +fn header(value: u64) -> Option { + Some(StreamHeader { + route_id: route_id(value), + }) +} + +fn opened(stream_id: StreamId) -> SessionEvent { + SessionEvent::Opened { + stream_id, + route_id: route_id(1), + } +} + fn open_stream_id(fsm: &mut SessionFsm) -> StreamId { - fsm.open_stream().unwrap().stream_id() + fsm.open_stream(route_id(1)).unwrap().stream_id() } fn write_stream_bytes(fsm: &mut SessionFsm, stream_id: StreamId, bytes: &[u8]) -> usize { @@ -185,16 +203,14 @@ fn commit_stream_read_is_what_advances_stream_window() { let data = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), + header: header(1), fin: false, bytes: b"hi".to_vec(), })]; let events = receive_events(&mut fsm, now, seq(7), &data); assert_eq!( events, - vec![ - SessionEvent::Opened(stream_id), - SessionEvent::Readable(stream_id) - ] + vec![opened(stream_id), SessionEvent::Readable(stream_id)] ); let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); @@ -233,6 +249,7 @@ fn pure_ack_only_records_are_fire_and_forget() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), + header: header(1), fin: false, bytes: b"hi".to_vec(), })]; @@ -258,6 +275,7 @@ fn inbound_stream_data_emits_opened_and_readable() { let record = vec![SessionFrame::StreamData(ql_wire::StreamData { stream_id, offset: offset(0), + header: header(1), fin: true, bytes: b"hello".to_vec(), })]; @@ -266,7 +284,7 @@ fn inbound_stream_data_emits_opened_and_readable() { assert_eq!( events, vec![ - SessionEvent::Opened(stream_id), + opened(stream_id), SessionEvent::Readable(stream_id), SessionEvent::Finished(stream_id) ] @@ -311,7 +329,7 @@ fn stream_ids_follow_even_odd_xid_ordering() { }, now, ) - .open_stream() + .open_stream(route_id(1)) .unwrap() .stream_id(); let odd_id = SessionFsm::new( @@ -321,7 +339,7 @@ fn stream_ids_follow_even_odd_xid_ordering() { }, now, ) - .open_stream() + .open_stream(route_id(1)) .unwrap() .stream_id(); @@ -337,6 +355,7 @@ fn duplicate_stream_data_is_not_redelivered() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), + header: header(1), fin: false, bytes: b"hi".to_vec(), })]; @@ -361,7 +380,6 @@ fn duplicate_remote_close_after_reap_is_ignored() { assert_eq!( first, vec![ - SessionEvent::Opened(close.stream_id), SessionEvent::Closed(close.clone()), SessionEvent::WritableClosed(close), ] @@ -384,6 +402,7 @@ fn late_remote_stream_data_after_close_is_ignored() { let data = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), + header: header(1), fin: false, bytes: b"hello".to_vec(), })]; @@ -392,7 +411,6 @@ fn late_remote_stream_data_after_close_is_ignored() { assert_eq!( first, vec![ - SessionEvent::Opened(stream_id), SessionEvent::Closed(StreamClose { stream_id, target: CloseTarget::Both, @@ -418,6 +436,7 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), + header: header(1), fin: true, bytes: b"hello".to_vec(), })]; @@ -426,7 +445,7 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { assert_eq!( first, vec![ - SessionEvent::Opened(stream_id), + opened(stream_id), SessionEvent::Readable(stream_id), SessionEvent::Finished(stream_id), ] @@ -445,6 +464,7 @@ fn duplicate_finished_remote_data_before_read_is_ignored() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), + header: header(1), fin: true, bytes: b"hello".to_vec(), })]; @@ -453,7 +473,7 @@ fn duplicate_finished_remote_data_before_read_is_ignored() { assert_eq!( first, vec![ - SessionEvent::Opened(stream_id), + opened(stream_id), SessionEvent::Readable(stream_id), SessionEvent::Finished(stream_id), ] @@ -480,15 +500,63 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { })]; let first = receive_events(&mut fsm, now, seq(1), &close3); - assert!(first.contains(&SessionEvent::Opened(stream_id(3)))); + assert_eq!( + first, + vec![ + SessionEvent::Closed(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: StreamCloseCode(1), + }), + SessionEvent::WritableClosed(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: StreamCloseCode(1), + }), + ] + ); let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &close1); - assert!(second.contains(&SessionEvent::Opened(stream_id(1)))); + assert_eq!( + second, + vec![ + SessionEvent::Closed(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: StreamCloseCode(2), + }), + SessionEvent::WritableClosed(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: StreamCloseCode(2), + }), + ] + ); let third = receive_events(&mut fsm, now + Duration::from_millis(2), seq(3), &close3); assert!(third.is_empty()); } +#[test] +fn invalid_remote_stream_close_closes_session() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + + let invalid = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(0), + target: CloseTarget::Both, + code: StreamCloseCode(9), + })]; + let events = receive_events(&mut fsm, now, seq(1), &invalid); + + assert_eq!( + events, + vec![SessionEvent::SessionClosed(ql_wire::SessionClose { + code: ql_wire::SessionCloseCode::PROTOCOL, + })] + ); +} + #[test] fn close_does_not_ack_rejected_record_seq() { let now = Instant::now(); @@ -503,6 +571,7 @@ fn close_does_not_ack_rejected_record_seq() { let invalid = vec![SessionFrame::StreamData(StreamData { stream_id: stream_id(0), offset: offset(0), + header: header(1), fin: false, bytes: b"bad".to_vec(), })]; diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index e0206e4b..e5840af0 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -10,6 +10,10 @@ use proptest_crate::{collection::vec, prelude::*, test_runner::TestCaseResult}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use super::*; + +fn test_route_id() -> ql_wire::RouteId { + ql_wire::RouteId(ql_wire::VarInt::from_u32(1)) +} use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent, SessionWriteId}; const SLOT_COUNT: usize = 4; @@ -304,14 +308,14 @@ impl Runner { let _ = take_pending(&mut self.pending_b_to_a, *index); } Action::OpenStreamA(slot) => { - if let Ok(stream) = self.harness.a.fsm.open_stream() { + if let Ok(stream) = self.harness.a.fsm.open_stream(test_route_id()) { let stream_id = stream.stream_id(); self.slots_a[*slot] = Some(stream_id); self.known_streams.insert(stream_id); } } Action::OpenStreamB(slot) => { - if let Ok(stream) = self.harness.b.fsm.open_stream() { + if let Ok(stream) = self.harness.b.fsm.open_stream(test_route_id()) { let stream_id = stream.stream_id(); self.slots_b[*slot] = Some(stream_id); self.known_streams.insert(stream_id); @@ -459,7 +463,7 @@ impl Runner { QlFsmEvent::PeerStatusChanged(status) => { self.events_mut(side).note_peer_status(status); } - QlFsmEvent::Opened(stream_id) => { + QlFsmEvent::Opened { stream_id, .. } => { prop_assert!( self.known_streams.contains(&stream_id), "side {side:?} emitted Opened for unknown stream {stream_id:?}" diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 5368948d..52ea9a97 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -1,7 +1,7 @@ use std::time::Duration; use bytes::Bytes; -use ql_wire::{SessionClose, StreamId, VarInt}; +use ql_wire::{RouteId, SessionClose, StreamId, VarInt}; use super::*; use crate::{ @@ -12,8 +12,19 @@ fn stream_id(value: u32) -> StreamId { StreamId(VarInt::from_u32(value)) } +fn route_id(value: u32) -> RouteId { + RouteId(VarInt::from_u32(value)) +} + +fn opened(stream_id: StreamId) -> QlFsmEvent { + QlFsmEvent::Opened { + stream_id, + route_id: route_id(1), + } +} + fn open_stream_id(fsm: &mut QlFsm) -> StreamId { - fsm.open_stream().unwrap().stream_id() + fsm.open_stream(route_id(1)).unwrap().stream_id() } fn write_stream_bytes( @@ -68,7 +79,7 @@ fn connected_fsms_deliver_stream_data() { harness.pump(); - assert_eq!(harness.take_event_b(), Some(QlFsmEvent::Opened(stream_id))); + assert_eq!(harness.take_event_b(), Some(opened(stream_id))); assert_eq!( harness.take_event_b(), Some(QlFsmEvent::Readable(stream_id)) @@ -115,7 +126,7 @@ fn session_retransmit_uses_new_record_seq() { harness.on_timer_b(); harness.pump(); - assert_eq!(harness.take_event_b(), Some(QlFsmEvent::Opened(stream_id))); + assert_eq!(harness.take_event_b(), Some(opened(stream_id))); assert_eq!( harness.take_event_b(), Some(QlFsmEvent::Readable(stream_id)) @@ -160,7 +171,7 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { assert_eq!( harness.take_event_a(), - Some(QlFsmEvent::Opened(stream_id_b)) + Some(opened(stream_id_b)) ); assert_eq!( harness.take_event_a(), @@ -172,7 +183,7 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { ); assert_eq!( harness.take_event_b(), - Some(QlFsmEvent::Opened(stream_id_a)) + Some(opened(stream_id_a)) ); assert_eq!( harness.take_event_b(), @@ -189,7 +200,7 @@ fn disconnected_stream_operations_fail_with_no_session() { let mut harness = Harness::paired_known(QlFsmConfig::default()); let missing = stream_id(0); - assert!(matches!(harness.a.fsm.open_stream(), Err(NoSessionError))); + assert!(matches!(harness.a.fsm.open_stream(route_id(1)), Err(NoSessionError))); assert_eq!( write_stream_bytes(&mut harness.a.fsm, missing, b"queued"), Err(StreamError::NoSession) @@ -274,7 +285,7 @@ fn returned_session_write_is_reissued_with_new_record_seq() { harness.deliver_to_b(record); harness.pump(); - assert_eq!(harness.take_event_b(), Some(QlFsmEvent::Opened(stream_id))); + assert_eq!(harness.take_event_b(), Some(opened(stream_id))); assert_eq!( harness.take_event_b(), Some(QlFsmEvent::Readable(stream_id)) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index b419d8ae..752288d7 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,5 +1,5 @@ use ql_fsm::NoSessionError; -use ql_wire::{CloseTarget, PairingToken, PeerBundle, StreamCloseCode, StreamId}; +use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamCloseCode, StreamId}; use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlStreamError}; @@ -16,6 +16,7 @@ pub(crate) enum RuntimeCommand { token: PairingToken, }, OpenStream { + route_id: RouteId, request_reader: ChunkSlotRx, request_terminal: oneshot::Sender, start: oneshot::Sender>, diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index ca86ab72..2ecb1919 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -150,6 +150,7 @@ impl DriverState { self.with_fsm_events(fsm, platform, |fsm| fsm.receive(now(), bytes, platform)); } RuntimeCommand::OpenStream { + route_id, request_reader, request_terminal, start, @@ -159,7 +160,7 @@ impl DriverState { return; }; - let mut stream_ops = match fsm.open_stream() { + let mut stream_ops = match fsm.open_stream(route_id) { Ok(stream_ops) => stream_ops, Err(error) => { let _ = start.send(Err(error)); @@ -267,8 +268,8 @@ impl DriverState { platform.handle_peer_status(peer, status); } } - QlFsmEvent::Opened(stream_id) => { - self.handle_opened_stream(fsm, platform, stream_id); + QlFsmEvent::Opened { stream_id, route_id } => { + self.handle_opened_stream(fsm, platform, stream_id, route_id); } QlFsmEvent::Readable(stream_id) => { self.handle_inbound_readable(fsm, stream_id); @@ -294,6 +295,7 @@ impl DriverState { fsm: &mut QlFsm, platform: &P, stream_id: StreamId, + route_id: ql_wire::RouteId, ) { let Some(runtime_tx) = self.runtime_tx.upgrade() else { if let Ok(mut stream) = fsm.stream(stream_id) { @@ -318,6 +320,7 @@ impl DriverState { platform.handle_inbound(QlStream { stream_id, + route_id, reader: ByteReader::new( stream_id, CloseTarget::Origin, diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 9d339dad..c6274192 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -2,7 +2,7 @@ mod reader; mod writer; use ql_fsm::NoSessionError; -use ql_wire::{CloseTarget, PairingToken, PeerBundle, StreamId}; +use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamId}; pub use self::{reader::*, writer::*}; use crate::{chunk_slot, command::RuntimeCommand}; @@ -10,6 +10,7 @@ use crate::{chunk_slot, command::RuntimeCommand}; #[derive(Debug)] pub struct QlStream { pub stream_id: StreamId, + pub route_id: RouteId, pub writer: ByteWriter, pub reader: ByteReader, } @@ -44,12 +45,13 @@ impl RuntimeHandle { self.send(RuntimeCommand::Incoming(bytes)); } - pub async fn open_stream(&self) -> Result { + pub async fn open_stream(&self, route_id: RouteId) -> Result { let (request_reader, request_writer) = chunk_slot::new(); let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); let (start_tx, start_rx) = oneshot::channel(); self.send(RuntimeCommand::OpenStream { + route_id, request_reader, request_terminal: request_terminal_tx, start: start_tx, @@ -60,6 +62,7 @@ impl RuntimeHandle { Ok(QlStream { stream_id, + route_id, writer: ByteWriter::new( stream_id, CloseTarget::Origin, diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 9483b631..923d46c2 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -47,7 +47,10 @@ async fn opening_stream_requires_connection() { tokio::task::spawn_local(async move { runtime_b.run().await }); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - assert!(matches!(handle_a.open_stream().await, Err(NoSessionError))); + assert!(matches!( + handle_a.open_stream(test_route_id()).await, + Err(NoSessionError) + )); }) .await; } @@ -112,7 +115,7 @@ async fn rejected_session_write_is_reissued() { request }); - let mut stream = handle_a.open_stream().await.unwrap(); + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); stream .writer .write(Bytes::from_static(b"retry")) diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 412c9393..77412a2c 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -50,7 +50,7 @@ async fn session_timeout_disconnects_and_fails_pending_open() { drop_flag.store(true, Ordering::Relaxed); - let mut pending = handle_a.open_stream().await.unwrap(); + let mut pending = handle_a.open_stream(test_route_id()).await.unwrap(); pending.writer.finish(); await_status(&status_a, identity_b.xid, PeerStatus::Disconnected).await; diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 673831a6..cdae79ad 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -16,7 +16,7 @@ use ql_fsm::PeerStatus; use ql_wire::{ generate_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, - RecordType, SessionKey, WireDecode, XID, + RecordType, RouteId, SessionKey, VarInt, WireDecode, XID, }; use sha2::{Digest, Sha256}; use tokio::{task::LocalSet, time::Sleep}; @@ -39,6 +39,10 @@ struct StatusEvent { status: PeerStatus, } +fn test_route_id() -> RouteId { + RouteId(VarInt::from_u32(1)) +} + #[derive(Debug, Clone)] struct WriteStats { active: Arc, diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index ffb5fe53..7e378781 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -44,7 +44,7 @@ async fn open_stream_duplex_happy_path() { writer.finish(); }); - let mut stream = handle_a.open_stream().await.unwrap(); + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); stream .writer .write(Bytes::from_static(&[1, 2])) @@ -116,7 +116,7 @@ async fn reader_exposes_bounded_chunk_reads() { inbound.writer.finish(); }); - let mut stream = handle_a.open_stream().await.unwrap(); + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); stream .writer .write(Bytes::from_static(&[1, 2, 3, 4])) @@ -172,7 +172,7 @@ async fn large_stream_payload_round_trips() { done_tx.send(request_data).await.unwrap(); }); - let mut stream = handle_a.open_stream().await.unwrap(); + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); stream .writer .write(Bytes::from(payload.clone())) @@ -224,7 +224,7 @@ async fn dropping_responder_closes_initiator_response() { drop(stream.reader); }); - let mut stream = handle_a.open_stream().await.unwrap(); + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); stream.writer.finish(); let err = next_chunk(&mut stream.reader).await.unwrap_err(); @@ -280,7 +280,7 @@ async fn dropping_inbound_reader_cancels_remote_writer() { writer.finish(); }); - let mut stream = handle_a.open_stream().await.unwrap(); + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); stream.writer.finish(); assert_eq!( next_chunk(&mut stream.reader).await.unwrap(), @@ -338,7 +338,7 @@ async fn max_concurrent_message_writes_is_respected() { for i in 0..4u8 { let handle = handle_a.clone(); tasks.push(tokio::task::spawn_local(async move { - let mut stream = handle.open_stream().await.unwrap(); + let mut stream = handle.open_stream(test_route_id()).await.unwrap(); stream.writer.write(Bytes::from(vec![i; 8])).await.unwrap(); stream.writer.finish(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); @@ -412,7 +412,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { received_request }); - let mut stream = handle_a.open_stream().await.unwrap(); + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); stream .writer .write(Bytes::from(request_payload.clone())) diff --git a/ql-runtime/src/tests/unpair.rs b/ql-runtime/src/tests/unpair.rs index 600ee3cd..133c156c 100644 --- a/ql-runtime/src/tests/unpair.rs +++ b/ql-runtime/src/tests/unpair.rs @@ -32,12 +32,12 @@ async fn unpair_clears_remote_peer_and_aborts_active_stream() { assert!(matches!(second, Ok(None) | Err(QlError::Cancelled))); }); - let mut stream = handle_a.open_stream().await.unwrap(); + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); stream.request.write_all(&[1, 2, 3, 4]).await.unwrap(); handle_a.unpair().unwrap(); assert!(matches!( - handle_a.open_stream().await, + handle_a.open_stream(test_route_id()).await, Err(QlError::NoPeerBound) )); @@ -48,7 +48,7 @@ async fn unpair_clears_remote_peer_and_aborts_active_stream() { let open_err_b = tokio::time::timeout(std::time::Duration::from_secs(2), async { loop { - match handle_b.open_stream().await { + match handle_b.open_stream(test_route_id()).await { Err(QlError::NoPeerBound) => return, _ => tokio::time::sleep(std::time::Duration::from_millis(10)).await, } From 5fe64c5140704d2c8beef29051c83580fdfd55c5 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 21:19:11 -0400 Subject: [PATCH 168/304] ql-rpc: remove methodid --- Cargo.lock | 1 - ql-rpc/Cargo.toml | 1 - ql-rpc/src/codec.rs | 2 +- ql-rpc/src/error.rs | 27 ---------------- ql-rpc/src/header.rs | 42 ------------------------- ql-rpc/src/lib.rs | 5 +-- ql-rpc/src/rpc/notification.rs | 13 ++------ ql-rpc/src/rpc/request.rs | 13 ++------ ql-rpc/src/rpc/request_with_progress.rs | 13 ++------ ql-rpc/src/rpc/subscription.rs | 13 ++------ ql-runtime/src/rpc/mod.rs | 18 +++++++---- ql-runtime/src/tests/rpc.rs | 33 ++++++++++--------- 12 files changed, 42 insertions(+), 139 deletions(-) delete mode 100644 ql-rpc/src/header.rs diff --git a/Cargo.lock b/Cargo.lock index 4cdd8854..ed587aa8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2231,7 +2231,6 @@ name = "ql-rpc" version = "0.1.0" dependencies = [ "bytes", - "ql-wire", ] [[package]] diff --git a/ql-rpc/Cargo.toml b/ql-rpc/Cargo.toml index 8fe76b8c..897ee9ef 100644 --- a/ql-rpc/Cargo.toml +++ b/ql-rpc/Cargo.toml @@ -7,4 +7,3 @@ license = "Proprietary" [dependencies] bytes = { version = "1" } -ql-wire = { path = "../ql-wire" } diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index ae896561..f0caece3 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -81,7 +81,7 @@ impl ChunkQueue { &chunk[..chunk.len().min(limit)] } - pub(crate) fn advance_inner(&mut self, mut cnt: usize) { + fn advance_inner(&mut self, mut cnt: usize) { assert!(cnt <= self.remaining, "advanced past buffered data"); self.remaining -= cnt; while cnt > 0 { diff --git a/ql-rpc/src/error.rs b/ql-rpc/src/error.rs index 8675c493..cc236898 100644 --- a/ql-rpc/src/error.rs +++ b/ql-rpc/src/error.rs @@ -1,16 +1,7 @@ -use ql_wire::StreamCloseCode; - -use crate::MethodId; - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RpcError { Truncated, LengthOverflow, - InvalidVersion(u8), - UnexpectedMethod { - expected: MethodId, - actual: MethodId, - }, UnexpectedFrameKind(u8), MissingResponse, TrailingBytes, @@ -21,10 +12,6 @@ impl std::fmt::Display for RpcError { match self { Self::Truncated => f.write_str("truncated rpc payload"), Self::LengthOverflow => f.write_str("rpc payload length overflow"), - Self::InvalidVersion(version) => write!(f, "invalid rpc version {version}"), - Self::UnexpectedMethod { expected, actual } => { - write!(f, "unexpected rpc method {actual:?}, expected {expected:?}") - } Self::UnexpectedFrameKind(kind) => write!(f, "unexpected rpc frame kind {kind}"), Self::MissingResponse => f.write_str("missing terminal rpc response"), Self::TrailingBytes => f.write_str("trailing rpc bytes"), @@ -34,20 +21,6 @@ impl std::fmt::Display for RpcError { impl std::error::Error for RpcError {} -impl RpcError { - pub const fn close_code(self) -> StreamCloseCode { - match self { - Self::UnexpectedMethod { .. } => StreamCloseCode(404), - Self::Truncated - | Self::LengthOverflow - | Self::InvalidVersion(_) - | Self::UnexpectedFrameKind(_) - | Self::MissingResponse - | Self::TrailingBytes => StreamCloseCode(400), - } - } -} - #[derive(Debug, Clone, PartialEq, Eq)] pub enum RpcCodecError { Rpc(RpcError), diff --git a/ql-rpc/src/header.rs b/ql-rpc/src/header.rs deleted file mode 100644 index bb6cafac..00000000 --- a/ql-rpc/src/header.rs +++ /dev/null @@ -1,42 +0,0 @@ -use bytes::{Buf, BufMut}; - -use crate::{MethodId, RpcCodec, RpcError, RPC_VERSION}; - -const HEADER_SIZE: usize = 1 + 8; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct RpcHeader { - pub version: u8, - pub method: MethodId, -} - -impl RpcHeader { - pub const WIRE_SIZE: usize = HEADER_SIZE; - - pub const fn new(method: MethodId) -> Self { - Self { - version: RPC_VERSION, - method, - } - } -} - -impl RpcCodec for RpcHeader { - type Error = RpcError; - - fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { - out.put_u8(self.version); - out.put_u64_le(self.method.0); - Ok(()) - } - - fn decode_value(bytes: &mut B) -> Result { - let version = bytes.try_get_u8().map_err(|_| RpcError::Truncated)?; - if version != RPC_VERSION { - return Err(RpcError::InvalidVersion(version)); - } - - let method = MethodId(bytes.try_get_u64_le().map_err(|_| RpcError::Truncated)?); - Ok(Self { version, method }) - } -} diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index ed67d311..994e473d 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -4,17 +4,14 @@ use bytes::{Buf, BufMut}; pub(crate) mod codec; mod error; -pub mod header; pub mod rpc; pub use error::*; pub use rpc::*; -pub const RPC_VERSION: u8 = 1; - #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct MethodId(pub u64); +pub struct MethodId(pub u32); pub trait RpcCodec: Sized { type Error; diff --git a/ql-rpc/src/rpc/notification.rs b/ql-rpc/src/rpc/notification.rs index fb9b93f8..ae452d3d 100644 --- a/ql-rpc/src/rpc/notification.rs +++ b/ql-rpc/src/rpc/notification.rs @@ -12,9 +12,6 @@ pub fn encode_event( event: &M::Event, out: &mut impl BufMut, ) -> Result<(), M::Error> { - crate::header::RpcHeader::new(M::METHOD) - .encode_value(out) - .expect("rpc header encoding cannot fail"); event.encode_value(out) } @@ -27,7 +24,7 @@ mod tests { use bytes::{Buf, BufMut}; use super::{decode_event, encode_event, Notification}; - use crate::{header::RpcHeader, MethodId, RpcCodec}; + use crate::{MethodId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -54,15 +51,11 @@ mod tests { } #[test] - fn event_round_trip_preserves_header_and_payload() { + fn event_round_trip_preserves_payload() { let mut encoded = Vec::new(); encode_event::(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); - - let mut body = encoded.as_slice(); - let header = RpcHeader::decode_value(&mut body).unwrap(); - assert_eq!(header.method, Notify::METHOD); assert_eq!( - decode_event::(body).unwrap(), + decode_event::(&encoded).unwrap(), BytesValue(b"hello".to_vec()) ); } diff --git a/ql-rpc/src/rpc/request.rs b/ql-rpc/src/rpc/request.rs index 6c9a0d22..0483dd2d 100644 --- a/ql-rpc/src/rpc/request.rs +++ b/ql-rpc/src/rpc/request.rs @@ -13,9 +13,6 @@ pub fn encode_request( request: &M::Request, out: &mut impl BufMut, ) -> Result<(), M::Error> { - crate::header::RpcHeader::new(M::METHOD) - .encode_value(out) - .expect("rpc header encoding cannot fail"); request.encode_value(out) } @@ -41,7 +38,7 @@ mod tests { use bytes::{Buf, BufMut}; use super::*; - use crate::{header::RpcHeader, MethodId, RpcCodec}; + use crate::{MethodId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -69,15 +66,11 @@ mod tests { } #[test] - fn request_round_trip_preserves_header_and_payload() { + fn request_round_trip_preserves_payload() { let mut encoded = Vec::new(); encode_request::(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); - - let mut body = encoded.as_slice(); - let header = RpcHeader::decode_value(&mut body).unwrap(); - assert_eq!(header.method, Echo::METHOD); assert_eq!( - decode_request::(body).unwrap(), + decode_request::(&encoded).unwrap(), BytesValue(b"hello".to_vec()) ); } diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index 55d576d7..03d39fd6 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -92,9 +92,6 @@ pub fn encode_request( request: &M::Request, out: &mut impl BufMut, ) -> Result<(), M::Error> { - crate::header::RpcHeader::new(M::METHOD) - .encode_value(out) - .expect("rpc header encoding cannot fail"); request.encode_value(out) } @@ -136,7 +133,7 @@ mod tests { decode_request, encode_progress, encode_request, encode_response, ReadStep, RequestWithProgress, ResponseReader, }; - use crate::{header::RpcHeader, MethodId, RpcCodec, RpcCodecError, RpcError}; + use crate::{MethodId, RpcCodec, RpcCodecError, RpcError}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -165,15 +162,11 @@ mod tests { } #[test] - fn request_round_trip_preserves_header_and_payload() { + fn request_round_trip_preserves_payload() { let mut encoded = Vec::new(); encode_request::(&BytesValue(b"watch".to_vec()), &mut encoded).unwrap(); - - let mut body = encoded.as_slice(); - let header = RpcHeader::decode_value(&mut body).unwrap(); - assert_eq!(header.method, Watch::METHOD); assert_eq!( - decode_request::(body).unwrap(), + decode_request::(&encoded).unwrap(), BytesValue(b"watch".to_vec()) ); } diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 1cbbf4b8..78f398e4 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -74,9 +74,6 @@ pub fn encode_request( request: &M::Request, out: &mut impl BufMut, ) -> Result<(), M::Error> { - crate::header::RpcHeader::new(M::METHOD) - .encode_value(out) - .expect("rpc header encoding cannot fail"); request.encode_value(out) } @@ -103,7 +100,7 @@ mod tests { decode_request, encode_end, encode_item, encode_request, ReadStep, ResponseReader, Subscription, }; - use crate::{header::RpcHeader, MethodId, RpcCodec}; + use crate::{MethodId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -131,15 +128,11 @@ mod tests { } #[test] - fn request_round_trip_preserves_header_and_payload() { + fn request_round_trip_preserves_payload() { let mut encoded = Vec::new(); encode_request::(&BytesValue(b"watch".to_vec()), &mut encoded).unwrap(); - - let mut body = encoded.as_slice(); - let header = RpcHeader::decode_value(&mut body).unwrap(); - assert_eq!(header.method, Feed::METHOD); assert_eq!( - decode_request::(body).unwrap(), + decode_request::(&encoded).unwrap(), BytesValue(b"watch".to_vec()) ); } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 48b06d23..aa684fd0 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -12,6 +12,7 @@ use ql_rpc::{ subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, RpcError, }; +use ql_wire::{RouteId, VarInt}; pub use self::{error::*, request_with_progress::*, subscription::*}; use crate::{ByteReader, QlStreamError, RuntimeHandle}; @@ -28,7 +29,7 @@ impl RpcHandle { { let mut payload = Vec::new(); notification::encode_event::(event, &mut payload).map_err(RpcCallError::Codec)?; - let response = self.start_request(payload).await?; + let response = self.start_request(M::METHOD, payload).await?; let response = read_all(response).await?; if response.is_empty() { Ok(()) @@ -46,7 +47,7 @@ impl RpcHandle { { let mut payload = Vec::new(); request::encode_request::(request, &mut payload).map_err(RpcCallError::Codec)?; - let response = self.start_request(payload).await?; + let response = self.start_request(M::METHOD, payload).await?; let response = read_all(response).await?; request::decode_response::(&response).map_err(RpcCallError::Codec) } @@ -61,7 +62,7 @@ impl RpcHandle { let mut payload = Vec::new(); rpc_subscription::encode_request::(request, &mut payload) .map_err(RpcCallError::Codec)?; - let response = self.start_request(payload).await?; + let response = self.start_request(M::METHOD, payload).await?; Ok(Subscription { stream: response, reader: Some(rpc_subscription::ResponseReader::new()), @@ -78,7 +79,7 @@ impl RpcHandle { let mut payload = Vec::new(); rpc_request_with_progress::encode_request::(request, &mut payload) .map_err(RpcCallError::Codec)?; - let response = self.start_request(payload).await?; + let response = self.start_request(M::METHOD, payload).await?; Ok(ProgressCall { stream: response, reader: Some(rpc_request_with_progress::ResponseReader::new()), @@ -86,8 +87,13 @@ impl RpcHandle { }) } - async fn start_request(&self, payload: Vec) -> Result> { - let mut stream = self.inner.open_stream().await?; + async fn start_request( + &self, + method: ql_rpc::MethodId, + payload: Vec, + ) -> Result> { + let route_id = RouteId(VarInt::from_u32(method.0)); + let mut stream = self.inner.open_stream(route_id).await?; stream.writer.write(Bytes::from(payload)).await?; stream.writer.finish(); Ok(stream.reader) diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 0912992f..983b77f2 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -2,6 +2,7 @@ use std::time::Duration; use bytes::{Buf, BufMut, Bytes}; use futures_lite::StreamExt; +use ql_wire::RouteId; use super::*; @@ -76,12 +77,12 @@ async fn rpc_request_round_trips() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); let request = read_all(inbound.reader).await.unwrap(); - let mut body = request.as_slice(); - let header = - ::decode_value(&mut body).unwrap(); - assert_eq!(header.method, ::METHOD); assert_eq!( - ql_rpc::request::decode_request::(body).unwrap(), + inbound.route_id, + route_id(::METHOD) + ); + assert_eq!( + ql_rpc::request::decode_request::(&request).unwrap(), BytesValue(b"hello".to_vec()) ); @@ -135,15 +136,12 @@ async fn rpc_subscription_streams_events() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); let request = read_all(inbound.reader).await.unwrap(); - let mut body = request.as_slice(); - let header = - ::decode_value(&mut body).unwrap(); assert_eq!( - header.method, - ::METHOD + inbound.route_id, + route_id(::METHOD) ); assert_eq!( - ql_rpc::subscription::decode_request::(body).unwrap(), + ql_rpc::subscription::decode_request::(&request).unwrap(), BytesValue(b"watch".to_vec()) ); @@ -209,15 +207,12 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); let request = read_all(inbound.reader).await.unwrap(); - let mut body = request.as_slice(); - let header = - ::decode_value(&mut body).unwrap(); assert_eq!( - header.method, - ::METHOD + inbound.route_id, + route_id(::METHOD) ); assert_eq!( - ql_rpc::request_with_progress::decode_request::(body).unwrap(), + ql_rpc::request_with_progress::decode_request::(&request).unwrap(), BytesValue(b"logo".to_vec()) ); @@ -261,3 +256,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { }) .await; } + +fn route_id(method: ql_rpc::MethodId) -> RouteId { + RouteId(ql_wire::VarInt::from_u32(method.0)) +} From 4546240913195887408c2594d100119b96beedae Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 21:21:29 -0400 Subject: [PATCH 169/304] fmt + clippy --- ql-fsm/src/implementation/core.rs | 10 ++++++++-- ql-fsm/src/implementation/handshake/xx.rs | 5 +++-- ql-fsm/src/lib.rs | 5 ++++- ql-fsm/src/tests/session.rs | 15 ++++++--------- ql-runtime/src/driver/mod.rs | 5 ++++- ql-wire/src/codec.rs | 4 ++-- ql-wire/src/encrypted/stream_data.rs | 5 +---- 7 files changed, 28 insertions(+), 21 deletions(-) diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index 314d4506..f39f56c5 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -176,8 +176,14 @@ fn forward_session_event( pending_events: &mut std::collections::VecDeque, ) -> bool { match event { - SessionEvent::Opened { stream_id, route_id } => { - pending_events.push_back(QlFsmEvent::Opened { stream_id, route_id }); + SessionEvent::Opened { + stream_id, + route_id, + } => { + pending_events.push_back(QlFsmEvent::Opened { + stream_id, + route_id, + }); false } SessionEvent::Readable(stream_id) => { diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/implementation/handshake/xx.rs index f346dae8..254d10b7 100644 --- a/ql-fsm/src/implementation/handshake/xx.rs +++ b/ql-fsm/src/implementation/handshake/xx.rs @@ -159,8 +159,9 @@ pub fn disarm_pairing(fsm: &mut QlFsm) { pub fn should_ignore_inbound(fsm: &QlFsm, message: &Xx1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, - LinkState::IkInitiator(_) | LinkState::KkInitiator(_) => true, - LinkState::XxResponder(_) => true, + LinkState::IkInitiator(_) | LinkState::KkInitiator(_) | LinkState::XxResponder(_) => { + true + } LinkState::XxInitiator(state) => { if state.handshake.pairing_token() != message.header.pairing_token { return false; diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 21e9a8df..6cb3021f 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -72,7 +72,10 @@ pub enum QlFsmEvent { /// the peer changed connection state PeerStatusChanged(PeerStatus), /// a stream was opened - Opened { stream_id: StreamId, route_id: RouteId }, + Opened { + stream_id: StreamId, + route_id: RouteId, + }, /// a stream has bytes ready to read Readable(StreamId), /// a stream has room for more local writes diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 52ea9a97..71ca159a 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -169,10 +169,7 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { harness.pump(); - assert_eq!( - harness.take_event_a(), - Some(opened(stream_id_b)) - ); + assert_eq!(harness.take_event_a(), Some(opened(stream_id_b))); assert_eq!( harness.take_event_a(), Some(QlFsmEvent::Readable(stream_id_b)) @@ -181,10 +178,7 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { read_stream_all(&mut harness.a.fsm, stream_id_b), b"from-b".to_vec() ); - assert_eq!( - harness.take_event_b(), - Some(opened(stream_id_a)) - ); + assert_eq!(harness.take_event_b(), Some(opened(stream_id_a))); assert_eq!( harness.take_event_b(), Some(QlFsmEvent::Readable(stream_id_a)) @@ -200,7 +194,10 @@ fn disconnected_stream_operations_fail_with_no_session() { let mut harness = Harness::paired_known(QlFsmConfig::default()); let missing = stream_id(0); - assert!(matches!(harness.a.fsm.open_stream(route_id(1)), Err(NoSessionError))); + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); assert_eq!( write_stream_bytes(&mut harness.a.fsm, missing, b"queued"), Err(StreamError::NoSession) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 2ecb1919..30b09fc4 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -268,7 +268,10 @@ impl DriverState { platform.handle_peer_status(peer, status); } } - QlFsmEvent::Opened { stream_id, route_id } => { + QlFsmEvent::Opened { + stream_id, + route_id, + } => { self.handle_opened_stream(fsm, platform, stream_id, route_id); } QlFsmEvent::Readable(stream_id) => { diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index a1f276e6..c2e4ba3c 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -173,7 +173,7 @@ impl WireEncode for bool { impl WireEncode for Option { fn encoded_len(&self) -> usize { - 1 + self.as_ref().map_or(0, |inner| inner.encoded_len()) + 1 + self.as_ref().map_or(0, WireEncode::encoded_len) } fn encode(&self, out: &mut W) { @@ -181,7 +181,7 @@ impl WireEncode for Option { None => out.put_u8(0), Some(inner) => { out.put_u8(1); - inner.encode(out) + inner.encode(out); } } } diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index a0ab5011..bc825247 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -60,10 +60,7 @@ impl WireEncode for StreamData { self.stream_id.encoded_len() + self.offset.encoded_len() + size_of::() - + self - .header - .as_ref() - .map_or(0, |header| header.encoded_len()) + + self.header.as_ref().map_or(0, WireEncode::encoded_len) + self.bytes.len() } From 2096ebf33c9a9d8664f8a03632c45bbb9e542d18 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 8 Apr 2026 21:58:11 -0400 Subject: [PATCH 170/304] ql: docs --- ql-fsm/src/lib.rs | 8 ++++---- ql-runtime/src/handle/mod.rs | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 6cb3021f..4f09640d 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -169,7 +169,7 @@ impl QlFsm { } } - /// binds or replaces the remote peer + /// binds the remote peer pub fn bind_peer(&mut self, peer: PeerBundle) { implementation::handle_bind_peer(self, peer); } @@ -190,7 +190,7 @@ impl QlFsm { implementation::handle_disarm_pairing(self); } - /// starts or replaces an outbound xx handshake using the supplied pairing token + /// starts an outbound xx handshake using the supplied pairing token pub fn connect_xx( &mut self, now: FsmTime, @@ -201,13 +201,13 @@ impl QlFsm { implementation::handle_connect_xx(self, token, crypto) } - /// starts or replaces an IK handshake with the currently bound peer + /// starts an IK handshake with the currently bound peer pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; implementation::handle_connect_ik(self, crypto) } - /// starts or replaces a KK handshake with the currently bound peer + /// starts a KK handshake with the currently bound peer pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { self.state.now = now; implementation::handle_connect_kk(self, crypto) diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index c6274192..0063b169 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -21,30 +21,37 @@ pub struct RuntimeHandle { } impl RuntimeHandle { + /// binds the remote peer pub fn bind_peer(&self, peer: PeerBundle) { self.send(RuntimeCommand::BindPeer { peer }); } + /// starts an IK handshake with the bound peer pub fn connect(&self) { self.send(RuntimeCommand::Connect); } + /// arms acceptance of inbound xx pairings for a single token pub fn arm_pairing(&self, token: PairingToken) { self.send(RuntimeCommand::ArmPairing { token }); } + /// disarms inbound xx pairing pub fn disarm_pairing(&self) { self.send(RuntimeCommand::DisarmPairing); } + /// starts an outbound xx handshake using the supplied pairing token pub fn start_pairing(&self, token: PairingToken) { self.send(RuntimeCommand::StartPairing { token }); } + /// hands inbound transport bytes to the runtime pub fn send_incoming(&self, bytes: Vec) { self.send(RuntimeCommand::Incoming(bytes)); } + /// opens a new stream on the active encrypted session pub async fn open_stream(&self, route_id: RouteId) -> Result { let (request_reader, request_writer) = chunk_slot::new(); let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); From 1d8b0a2285db4e555fd7970972f185c93a4a02a2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 07:06:45 -0400 Subject: [PATCH 171/304] ql-fsm: granular errors --- ql-fsm/src/error.rs | 23 ++++++++++++----- ql-fsm/src/implementation/core.rs | 12 ++++----- ql-fsm/src/implementation/handshake/ik.rs | 19 +++++--------- ql-fsm/src/implementation/handshake/kk.rs | 19 +++++--------- ql-fsm/src/implementation/handshake/mod.rs | 30 ++++++++++------------ ql-fsm/src/implementation/handshake/xx.rs | 23 ++++++----------- ql-fsm/src/lib.rs | 15 ++++------- ql-fsm/src/tests/handshake.rs | 18 ++++++------- ql-fsm/src/tests/mod.rs | 18 ++++++------- ql-fsm/src/tests/proptest.rs | 18 ++++++------- ql-runtime/src/driver/mod.rs | 3 +-- 11 files changed, 91 insertions(+), 107 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index b5706503..8470cd1f 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -6,17 +6,16 @@ use std::{ use ql_wire::WireError; #[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlFsmError { +pub enum ReceiveError { InvalidPayload, InvalidState, Expired, DecryptFailed, InvalidXid, - NoPeerBound, NoSession, } -impl Display for QlFsmError { +impl Display for ReceiveError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let message = match self { Self::InvalidPayload => "invalid payload", @@ -24,16 +23,15 @@ impl Display for QlFsmError { Self::Expired => "expired", Self::DecryptFailed => "decryption failed", Self::InvalidXid => "invalid xid", - Self::NoPeerBound => "no peer bound", Self::NoSession => "no active session", }; f.write_str(message) } } -impl std::error::Error for QlFsmError {} +impl std::error::Error for ReceiveError {} -impl From for QlFsmError { +impl From for ReceiveError { fn from(value: WireError) -> Self { match value { WireError::InvalidPayload => Self::InvalidPayload, @@ -44,12 +42,23 @@ impl From for QlFsmError { } } -impl From for QlFsmError { +impl From for ReceiveError { fn from(_: NoSessionError) -> Self { Self::NoSession } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NoPeerError; + +impl Display for NoPeerError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("no peer bound") + } +} + +impl Error for NoPeerError {} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct NoSessionError; diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/implementation/core.rs index f39f56c5..22276e9e 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/implementation/core.rs @@ -4,8 +4,8 @@ use bytes::Bytes; use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; use crate::{ - session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, QlFsmError, - QlFsmEvent, SessionWriteId, StreamError, StreamOps, + session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, QlFsmEvent, + ReceiveError, SessionWriteId, StreamError, StreamOps, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { @@ -18,12 +18,12 @@ pub fn receive( fsm: &mut QlFsm, mut bytes: Vec, crypto: &impl QlCrypto, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { let mut reader = wire::Reader::new(bytes.as_mut_slice()); let header = wire::RecordHeader::decode(&mut reader)?; if header.version != wire::QL_WIRE_VERSION { - return Err(QlFsmError::InvalidPayload); + return Err(ReceiveError::InvalidPayload); } match header.record_type { @@ -37,11 +37,11 @@ pub fn receive( .state .link .connected_mut() - .ok_or(QlFsmError::NoSession)?; + .ok_or(ReceiveError::NoSession)?; let (decrypt_len, seq) = { let record = wire::QlSessionRecord::decode(&mut reader)?; if record.header.connection_id != state.transport.rx_connection_id { - return Err(QlFsmError::InvalidPayload); + return Err(ReceiveError::InvalidPayload); } let payload = wire::decrypt_record( crypto, diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/implementation/handshake/ik.rs index ac8b62e1..06292816 100644 --- a/ql-fsm/src/implementation/handshake/ik.rs +++ b/ql-fsm/src/implementation/handshake/ik.rs @@ -6,14 +6,10 @@ use super::{ }; use crate::{ state::{IkInitiatorState, LinkState, SessionTransport}, - QlFsm, QlFsmError, + QlFsm, ReceiveError, }; -pub fn start_initiator( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - peer: PeerBundle, -) -> Result<(), QlFsmError> { +pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle) { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::IkHandshake::new_initiator( crypto, @@ -21,7 +17,7 @@ pub fn start_initiator( peer, super::local_transport_params(fsm), ); - let message = handshake.write_1(crypto, meta)?; + let message = handshake.write_1(crypto, meta).unwrap(); fsm.state.link = LinkState::IkInitiator(IkInitiatorState { handshake_id: meta.handshake_id, @@ -31,14 +27,13 @@ pub fn start_initiator( }); enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); emit_peer_status(fsm); - Ok(()) } pub fn handle_ik1( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Ik1, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { if should_ignore_inbound(fsm, message) { return Ok(()); } @@ -46,11 +41,11 @@ pub fn handle_ik1( return Ok(()); } if message.header.recipient != fsm.identity.xid { - return Err(QlFsmError::InvalidXid); + return Err(ReceiveError::InvalidXid); } if let Some(peer) = fsm.state.peer.as_ref() { if message.header.sender != peer.xid { - return Err(QlFsmError::InvalidXid); + return Err(ReceiveError::InvalidXid); } } @@ -75,7 +70,7 @@ pub fn handle_ik2( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Ik2, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { { let LinkState::IkInitiator(state) = &mut fsm.state.link else { return Ok(()); diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/implementation/handshake/kk.rs index 38bafd5e..b46f612a 100644 --- a/ql-fsm/src/implementation/handshake/kk.rs +++ b/ql-fsm/src/implementation/handshake/kk.rs @@ -6,14 +6,10 @@ use super::{ }; use crate::{ state::{KkInitiatorState, LinkState, SessionTransport}, - QlFsm, QlFsmError, + QlFsm, ReceiveError, }; -pub fn start_initiator( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - peer: PeerBundle, -) -> Result<(), QlFsmError> { +pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle) { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::KkHandshake::new_initiator( crypto, @@ -21,7 +17,7 @@ pub fn start_initiator( peer, super::local_transport_params(fsm), ); - let message = handshake.write_1(crypto, meta)?; + let message = handshake.write_1(crypto, meta).unwrap(); fsm.state.link = LinkState::KkInitiator(KkInitiatorState { handshake_id: meta.handshake_id, @@ -31,14 +27,13 @@ pub fn start_initiator( }); enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); emit_peer_status(fsm); - Ok(()) } pub fn handle_kk1( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Kk1, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { if should_ignore_inbound(fsm, message) { return Ok(()); } @@ -47,10 +42,10 @@ pub fn handle_kk1( } let Some(peer) = fsm.state.peer.clone() else { - return Err(QlFsmError::InvalidPayload); + return Err(ReceiveError::InvalidPayload); }; if message.header.recipient != fsm.identity.xid || message.header.sender != peer.xid { - return Err(QlFsmError::InvalidXid); + return Err(ReceiveError::InvalidXid); } reset_connected_session_if_needed(fsm); @@ -74,7 +69,7 @@ pub fn handle_kk2( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Kk2, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { { let LinkState::KkInitiator(state) = &mut fsm.state.link else { return Ok(()); diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/implementation/handshake/mod.rs index 16d0856d..01908d80 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/implementation/handshake/mod.rs @@ -10,28 +10,26 @@ use super::emit_peer_status; use crate::{ session::{SessionFsm, SessionFsmConfig, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, - QlFsm, QlFsmError, QlFsmEvent, + NoPeerError, QlFsm, QlFsmEvent, ReceiveError, }; -pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; +pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + let peer = fsm.state.peer.clone().ok_or(NoPeerError)?; prepare_for_outbound_connect(fsm); - ik::start_initiator(fsm, crypto, peer) + ik::start_initiator(fsm, crypto, peer); + Ok(()) } -pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { - let peer = fsm.state.peer.clone().ok_or(QlFsmError::NoPeerBound)?; +pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + let peer = fsm.state.peer.clone().ok_or(NoPeerError)?; prepare_for_outbound_connect(fsm); - kk::start_initiator(fsm, crypto, peer) + kk::start_initiator(fsm, crypto, peer); + Ok(()) } -pub fn handle_connect_xx( - fsm: &mut QlFsm, - token: PairingToken, - crypto: &impl QlCrypto, -) -> Result<(), QlFsmError> { +pub fn handle_connect_xx(fsm: &mut QlFsm, token: PairingToken, crypto: &impl QlCrypto) { prepare_for_outbound_connect(fsm); - xx::start_initiator(fsm, crypto, token) + xx::start_initiator(fsm, crypto, token); } pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { @@ -76,7 +74,7 @@ pub fn handle_handshake_record( fsm: &mut QlFsm, crypto: &impl QlCrypto, record: &QlHandshakeRecord, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { match record { QlHandshakeRecord::Ik1(message) => ik::handle_ik1(fsm, crypto, message), QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message), @@ -110,11 +108,11 @@ pub fn finish_handshake( fsm: &mut QlFsm, transport: SessionTransport, remote_bundle: wire::PeerBundle, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { let xid = remote_bundle.xid; if let Some(peer) = fsm.state.peer.as_ref() { if peer != &remote_bundle { - return Err(QlFsmError::InvalidPayload); + return Err(ReceiveError::InvalidPayload); } } else { fsm.state.peer = Some(remote_bundle); diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/implementation/handshake/xx.rs index 254d10b7..5e1de43f 100644 --- a/ql-fsm/src/implementation/handshake/xx.rs +++ b/ql-fsm/src/implementation/handshake/xx.rs @@ -6,14 +6,10 @@ use super::{ }; use crate::{ state::{LinkState, SessionTransport, XxInitiatorState, XxResponderState}, - QlFsm, QlFsmError, + QlFsm, ReceiveError, }; -pub fn start_initiator( - fsm: &mut QlFsm, - crypto: &impl QlCrypto, - token: PairingToken, -) -> Result<(), QlFsmError> { +pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, token: PairingToken) { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::XxHandshake::new_initiator( crypto, @@ -21,7 +17,7 @@ pub fn start_initiator( token, super::local_transport_params(fsm), ); - let message = handshake.write_1(crypto, meta)?; + let message = handshake.write_1(crypto, meta).unwrap(); fsm.state.link = LinkState::XxInitiator(XxInitiatorState { handshake_id: meta.handshake_id, @@ -31,14 +27,13 @@ pub fn start_initiator( }); enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); emit_peer_status(fsm); - Ok(()) } pub fn handle_xx1( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Xx1, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { if should_ignore_inbound(fsm, message) { return Ok(()); } @@ -73,7 +68,7 @@ pub fn handle_xx2( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Xx2, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { { let LinkState::XxInitiator(state) = &mut fsm.state.link else { return Ok(()); @@ -98,7 +93,7 @@ pub fn handle_xx3( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Xx3, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { let LinkState::XxResponder(state) = &mut fsm.state.link else { return Ok(()); }; @@ -126,7 +121,7 @@ pub fn handle_xx4( fsm: &mut QlFsm, crypto: &impl QlCrypto, message: &Xx4, -) -> Result<(), QlFsmError> { +) -> Result<(), ReceiveError> { { let LinkState::XxInitiator(state) = &mut fsm.state.link else { return Ok(()); @@ -159,9 +154,7 @@ pub fn disarm_pairing(fsm: &mut QlFsm) { pub fn should_ignore_inbound(fsm: &QlFsm, message: &Xx1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, - LinkState::IkInitiator(_) | LinkState::KkInitiator(_) | LinkState::XxResponder(_) => { - true - } + LinkState::IkInitiator(_) | LinkState::KkInitiator(_) | LinkState::XxResponder(_) => true, LinkState::XxInitiator(state) => { if state.handshake.pairing_token() != message.header.pairing_token { return false; diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 4f09640d..cfc7e80d 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -191,24 +191,19 @@ impl QlFsm { } /// starts an outbound xx handshake using the supplied pairing token - pub fn connect_xx( - &mut self, - now: FsmTime, - token: PairingToken, - crypto: &impl QlCrypto, - ) -> Result<(), QlFsmError> { + pub fn connect_xx(&mut self, now: FsmTime, token: PairingToken, crypto: &impl QlCrypto) { self.state.now = now; - implementation::handle_connect_xx(self, token, crypto) + implementation::handle_connect_xx(self, token, crypto); } /// starts an IK handshake with the currently bound peer - pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { self.state.now = now; implementation::handle_connect_ik(self, crypto) } /// starts a KK handshake with the currently bound peer - pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), QlFsmError> { + pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { self.state.now = now; implementation::handle_connect_kk(self, crypto) } @@ -219,7 +214,7 @@ impl QlFsm { now: FsmTime, bytes: Vec, crypto: &impl QlCrypto, - ) -> Result<(), QlFsmError> { + ) -> Result<(), ReceiveError> { self.state.now = now; implementation::receive(self, bytes, crypto) } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 19bfe7c4..cea913be 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ql_wire::QlHandshakeRecord; use super::*; -use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent}; +use crate::{state::LinkState, NoPeerError, PeerStatus, QlFsmEvent}; #[test] fn ik_connect_round_trip_establishes_transport() { @@ -33,7 +33,7 @@ fn xx_connect_round_trip_establishes_transport_when_armed() { let token = pairing_token(1); harness.b.fsm.arm_pairing(token); - harness.connect_xx_a(token).unwrap(); + harness.connect_xx_a(token); let xx1 = harness.next_outbound_a().unwrap(); harness.deliver_to_b(xx1); @@ -100,10 +100,10 @@ fn connect_methods_require_bound_peer() { let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); let crypto = TestCrypto::new(9); - assert_eq!(fsm.connect_ik(time, &crypto), Err(QlFsmError::NoPeerBound)); - assert_eq!(fsm.connect_kk(time, &crypto), Err(QlFsmError::NoPeerBound)); + assert_eq!(fsm.connect_ik(time, &crypto), Err(NoPeerError)); + assert_eq!(fsm.connect_kk(time, &crypto), Err(NoPeerError)); - assert_eq!(fsm.connect_xx(time, pairing_token(2), &crypto), Ok(())); + fsm.connect_xx(time, pairing_token(2), &crypto); } #[test] @@ -123,7 +123,7 @@ fn inbound_xx1_ignored_when_pairing_token_not_armed() { let mut harness = Harness::paired(QlFsmConfig::default(), false, false); let token = pairing_token(3); - harness.connect_xx_a(token).unwrap(); + harness.connect_xx_a(token); let xx1 = harness.next_outbound_a().unwrap(); harness.deliver_to_b(xx1); @@ -138,7 +138,7 @@ fn disarm_pairing_rejects_inflight_inbound_xx_responder() { let token = pairing_token(5); harness.b.fsm.arm_pairing(token); - harness.connect_xx_a(token).unwrap(); + harness.connect_xx_a(token); let xx1 = harness.next_outbound_a().unwrap(); harness.deliver_to_b(xx1); let xx2 = harness.next_outbound_b().unwrap(); @@ -158,8 +158,8 @@ fn simultaneous_xx_connect_converges() { harness.a.fsm.arm_pairing(token); harness.b.fsm.arm_pairing(token); - harness.connect_xx_a(token).unwrap(); - harness.connect_xx_b(token).unwrap(); + harness.connect_xx_a(token); + harness.connect_xx_b(token); for _ in 0..2 { if let Some(record) = harness.next_outbound_a() { diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 887e191e..bf7022d6 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -19,7 +19,7 @@ use sha2::{Digest, Sha256}; use crate::{ session::{SessionFsm, SessionFsmConfig, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, - FsmTime, OutboundWrite, QlFsm, QlFsmConfig, QlFsmError, QlFsmEvent, SessionWriteId, + FsmTime, NoPeerError, OutboundWrite, QlFsm, QlFsmConfig, QlFsmEvent, SessionWriteId, }; #[derive(Clone)] @@ -276,40 +276,40 @@ impl Harness { self.a.fsm.take_next_write(self.time(), &self.a.crypto) } - fn connect_ik_a(&mut self) -> Result<(), QlFsmError> { + fn connect_ik_a(&mut self) -> Result<(), NoPeerError> { let time = self.time(); let Node { fsm, crypto } = &mut self.a; fsm.connect_ik(time, crypto) } - fn connect_ik_b(&mut self) -> Result<(), QlFsmError> { + fn connect_ik_b(&mut self) -> Result<(), NoPeerError> { let time = self.time(); let Node { fsm, crypto } = &mut self.b; fsm.connect_ik(time, crypto) } - fn connect_kk_a(&mut self) -> Result<(), QlFsmError> { + fn connect_kk_a(&mut self) -> Result<(), NoPeerError> { let time = self.time(); let Node { fsm, crypto } = &mut self.a; fsm.connect_kk(time, crypto) } - fn connect_kk_b(&mut self) -> Result<(), QlFsmError> { + fn connect_kk_b(&mut self) -> Result<(), NoPeerError> { let time = self.time(); let Node { fsm, crypto } = &mut self.b; fsm.connect_kk(time, crypto) } - fn connect_xx_a(&mut self, token: PairingToken) -> Result<(), QlFsmError> { + fn connect_xx_a(&mut self, token: PairingToken) { let time = self.time(); let Node { fsm, crypto } = &mut self.a; - fsm.connect_xx(time, token, crypto) + fsm.connect_xx(time, token, crypto); } - fn connect_xx_b(&mut self, token: PairingToken) -> Result<(), QlFsmError> { + fn connect_xx_b(&mut self, token: PairingToken) { let time = self.time(); let Node { fsm, crypto } = &mut self.b; - fsm.connect_xx(time, token, crypto) + fsm.connect_xx(time, token, crypto); } fn deliver_to_a(&mut self, record: Vec) { diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index e5840af0..047bf3f2 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -14,7 +14,7 @@ use super::*; fn test_route_id() -> ql_wire::RouteId { ql_wire::RouteId(ql_wire::VarInt::from_u32(1)) } -use crate::{state::LinkState, PeerStatus, QlFsmError, QlFsmEvent, SessionWriteId}; +use crate::{state::LinkState, PeerStatus, QlFsmEvent, ReceiveError, SessionWriteId}; const SLOT_COUNT: usize = 4; @@ -99,7 +99,7 @@ struct Runner { taken_b_to_a: Vec, pending_a_to_b: Vec>, pending_b_to_a: Vec>, - receive_errors: Vec<(Side, QlFsmError)>, + receive_errors: Vec<(Side, ReceiveError)>, events_a: SideEventState, events_b: SideEventState, known_streams: BTreeSet, @@ -582,11 +582,11 @@ impl Runner { prop_assert!( matches!( error, - QlFsmError::NoSession - | QlFsmError::InvalidState - | QlFsmError::Expired - | QlFsmError::InvalidPayload - | QlFsmError::DecryptFailed + ReceiveError::NoSession + | ReceiveError::InvalidState + | ReceiveError::Expired + | ReceiveError::InvalidPayload + | ReceiveError::DecryptFailed ), "unexpected receive error on side {side:?}: {error:?}" ); @@ -853,13 +853,13 @@ fn reject_taken_b(harness: &mut Harness, write: &TakenWrite) { } } -fn deliver_to_a(harness: &mut Harness, record: Vec) -> Result<(), QlFsmError> { +fn deliver_to_a(harness: &mut Harness, record: Vec) -> Result<(), ReceiveError> { let time = harness.time(); let Node { fsm, crypto } = &mut harness.a; fsm.receive(time, record, crypto) } -fn deliver_to_b(harness: &mut Harness, record: Vec) -> Result<(), QlFsmError> { +fn deliver_to_b(harness: &mut Harness, record: Vec) -> Result<(), ReceiveError> { let time = harness.time(); let Node { fsm, crypto } = &mut harness.b; fsm.receive(time, record, crypto) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 30b09fc4..d337b733 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -142,8 +142,7 @@ impl DriverState { fsm.disarm_pairing(); } RuntimeCommand::StartPairing { token } => { - let _ = self - .with_fsm_events(fsm, platform, |fsm| fsm.connect_xx(now(), token, platform)); + self.with_fsm_events(fsm, platform, |fsm| fsm.connect_xx(now(), token, platform)); } RuntimeCommand::Incoming(bytes) => { let _ = From fef6d76cc0dedde2eebdb537ebf3fb9354896911 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 07:13:32 -0400 Subject: [PATCH 172/304] ql-runtime: platform handle_recv_error --- ql-runtime/src/command.rs | 2 +- ql-runtime/src/driver/mod.rs | 9 ++++++--- ql-runtime/src/handle/mod.rs | 4 ++-- ql-runtime/src/platform.rs | 3 ++- ql-runtime/src/tests/mod.rs | 6 +++--- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 752288d7..deccec52 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -32,5 +32,5 @@ pub(crate) enum RuntimeCommand { target: CloseTarget, code: StreamCloseCode, }, - Incoming(Vec), + Receive(Vec), } diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index d337b733..2c13ff38 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -144,9 +144,12 @@ impl DriverState { RuntimeCommand::StartPairing { token } => { self.with_fsm_events(fsm, platform, |fsm| fsm.connect_xx(now(), token, platform)); } - RuntimeCommand::Incoming(bytes) => { - let _ = - self.with_fsm_events(fsm, platform, |fsm| fsm.receive(now(), bytes, platform)); + RuntimeCommand::Receive(bytes) => { + if let Err(e) = + self.with_fsm_events(fsm, platform, |fsm| fsm.receive(now(), bytes, platform)) + { + platform.handle_recv_error(e); + } } RuntimeCommand::OpenStream { route_id, diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 0063b169..20c5581a 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -47,8 +47,8 @@ impl RuntimeHandle { } /// hands inbound transport bytes to the runtime - pub fn send_incoming(&self, bytes: Vec) { - self.send(RuntimeCommand::Incoming(bytes)); + pub fn receive(&self, bytes: Vec) { + self.send(RuntimeCommand::Receive(bytes)); } /// opens a new stream on the active encrypted session diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 411627c0..886504d9 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -5,7 +5,7 @@ use std::{ time::Instant, }; -use ql_fsm::PeerStatus; +use ql_fsm::{PeerStatus, ReceiveError}; use ql_wire::{PeerBundle, QlCrypto, XID}; use crate::QlStream; @@ -31,4 +31,5 @@ pub trait QlPlatform: QlCrypto { fn handle_peer_status(&self, peer: XID, status: PeerStatus); fn handle_inbound(&self, event: QlStream); + fn handle_recv_error(&self, _error: ReceiveError) {} } diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index cdae79ad..70913eb1 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -441,7 +441,7 @@ fn register_peers( fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { tokio::task::spawn_local(async move { while let Ok(bytes) = outbound.recv().await { - handle.send_incoming(bytes); + handle.receive(bytes); } }); } @@ -460,7 +460,7 @@ fn spawn_drop_every_nth_encrypted_forwarder( continue; } } - handle.send_incoming(bytes); + handle.receive(bytes); } }); } @@ -475,7 +475,7 @@ fn spawn_gated_forwarder( if drop_flag.load(Ordering::Relaxed) { continue; } - handle.send_incoming(bytes); + handle.receive(bytes); } }); } From eecb61a947fa7973eb7c606f54a97a8772b0d3cc Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 07:28:35 -0400 Subject: [PATCH 173/304] ql-fsm: refactor modules --- ql-fsm/src/{implementation/core.rs => fsm.rs} | 20 +++++------ .../src/{implementation => }/handshake/ik.rs | 0 .../src/{implementation => }/handshake/kk.rs | 0 .../src/{implementation => }/handshake/mod.rs | 7 ++-- .../src/{implementation => }/handshake/xx.rs | 0 ql-fsm/src/implementation/mod.rs | 6 ---- ql-fsm/src/lib.rs | 33 ++++++++++--------- 7 files changed, 29 insertions(+), 37 deletions(-) rename ql-fsm/src/{implementation/core.rs => fsm.rs} (92%) rename ql-fsm/src/{implementation => }/handshake/ik.rs (100%) rename ql-fsm/src/{implementation => }/handshake/kk.rs (100%) rename ql-fsm/src/{implementation => }/handshake/mod.rs (96%) rename ql-fsm/src/{implementation => }/handshake/xx.rs (100%) delete mode 100644 ql-fsm/src/implementation/mod.rs diff --git a/ql-fsm/src/implementation/core.rs b/ql-fsm/src/fsm.rs similarity index 92% rename from ql-fsm/src/implementation/core.rs rename to ql-fsm/src/fsm.rs index 22276e9e..f43ccb0a 100644 --- a/ql-fsm/src/implementation/core.rs +++ b/ql-fsm/src/fsm.rs @@ -4,8 +4,8 @@ use bytes::Bytes; use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; use crate::{ - session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, QlFsmEvent, - ReceiveError, SessionWriteId, StreamError, StreamOps, + handshake, session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, + QlFsmEvent, ReceiveError, SessionWriteId, StreamError, StreamOps, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { @@ -29,10 +29,9 @@ pub fn receive( match header.record_type { wire::RecordType::Handshake => { let record = wire::QlHandshakeRecord::decode(&mut reader)?; - super::handle_handshake_record(fsm, crypto, &record) + handshake::handle_handshake_record(fsm, crypto, &record) } wire::RecordType::Session => { - let pending_events = &mut fsm.pending_events; let state = fsm .state .link @@ -57,11 +56,12 @@ pub fn receive( let frames = wire::parse_session_frames(plaintext); let mut session_closed = false; - state - .session - .receive(fsm.state.now.instant, seq, frames, |event| { + state.session.receive(fsm.state.now.instant, seq, frames, { + let pending_events = &mut fsm.pending_events; + |event| { session_closed |= forward_session_event(event, pending_events); - }); + } + }); if session_closed { apply_session_closed(fsm); @@ -72,7 +72,7 @@ pub fn receive( } pub fn on_timer(fsm: &mut QlFsm) { - super::handle_timer(fsm); + handshake::handle_timer(fsm); let mut session_closed = false; if let Some(state) = fsm.state.link.connected_mut() { @@ -91,7 +91,7 @@ pub fn on_timer(fsm: &mut QlFsm) { pub fn next_deadline(fsm: &QlFsm) -> Option { [ - super::next_handshake_deadline(fsm), + handshake::next_handshake_deadline(fsm), fsm.state .link .connected() diff --git a/ql-fsm/src/implementation/handshake/ik.rs b/ql-fsm/src/handshake/ik.rs similarity index 100% rename from ql-fsm/src/implementation/handshake/ik.rs rename to ql-fsm/src/handshake/ik.rs diff --git a/ql-fsm/src/implementation/handshake/kk.rs b/ql-fsm/src/handshake/kk.rs similarity index 100% rename from ql-fsm/src/implementation/handshake/kk.rs rename to ql-fsm/src/handshake/kk.rs diff --git a/ql-fsm/src/implementation/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs similarity index 96% rename from ql-fsm/src/implementation/handshake/mod.rs rename to ql-fsm/src/handshake/mod.rs index 01908d80..0e528922 100644 --- a/ql-fsm/src/implementation/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -6,8 +6,8 @@ use ql_wire::{ self as wire, EphemeralPublicKey, HandshakeMeta, PairingToken, QlCrypto, QlHandshakeRecord, }; -use super::emit_peer_status; use crate::{ + fsm::{deadline_after_secs, emit_peer_status}, session::{SessionFsm, SessionFsmConfig, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, NoPeerError, QlFsm, QlFsmEvent, ReceiveError, @@ -37,10 +37,7 @@ pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); HandshakeMeta { handshake_id, - valid_until: super::deadline_after_secs( - fsm.state.now.unix_secs, - fsm.config.handshake_timeout, - ), + valid_until: deadline_after_secs(fsm.state.now.unix_secs, fsm.config.handshake_timeout), } } diff --git a/ql-fsm/src/implementation/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs similarity index 100% rename from ql-fsm/src/implementation/handshake/xx.rs rename to ql-fsm/src/handshake/xx.rs diff --git a/ql-fsm/src/implementation/mod.rs b/ql-fsm/src/implementation/mod.rs deleted file mode 100644 index 64b0b3d3..00000000 --- a/ql-fsm/src/implementation/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod core; -mod handshake; - -pub use core::*; - -pub use handshake::*; diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index cfc7e80d..ef0e4369 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -19,7 +19,8 @@ //! another input may arrive before that deadline, which is fine mod error; -pub(crate) mod implementation; +mod fsm; +mod handshake; pub(crate) mod replay_cache; mod session; pub(crate) mod state; @@ -171,7 +172,7 @@ impl QlFsm { /// binds the remote peer pub fn bind_peer(&mut self, peer: PeerBundle) { - implementation::handle_bind_peer(self, peer); + fsm::handle_bind_peer(self, peer); } /// returns the currently bound peer, if any @@ -187,25 +188,25 @@ impl QlFsm { /// disarms inbound xx pairing and rejects any in-flight inbound xx responder state pub fn disarm_pairing(&mut self) { self.state.armed_pairing_token = None; - implementation::handle_disarm_pairing(self); + handshake::handle_disarm_pairing(self); } /// starts an outbound xx handshake using the supplied pairing token pub fn connect_xx(&mut self, now: FsmTime, token: PairingToken, crypto: &impl QlCrypto) { self.state.now = now; - implementation::handle_connect_xx(self, token, crypto); + handshake::handle_connect_xx(self, token, crypto); } /// starts an IK handshake with the currently bound peer pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { self.state.now = now; - implementation::handle_connect_ik(self, crypto) + handshake::handle_connect_ik(self, crypto) } /// starts a KK handshake with the currently bound peer pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { self.state.now = now; - implementation::handle_connect_kk(self, crypto) + handshake::handle_connect_kk(self, crypto) } /// handles one inbound wire message @@ -216,13 +217,13 @@ impl QlFsm { crypto: &impl QlCrypto, ) -> Result<(), ReceiveError> { self.state.now = now; - implementation::receive(self, bytes, crypto) + fsm::receive(self, bytes, crypto) } /// advances time-based state pub fn on_timer(&mut self, now: FsmTime) { self.state.now = now; - implementation::on_timer(self); + fsm::on_timer(self); } /// returns the next queued event, if any @@ -232,7 +233,7 @@ impl QlFsm { /// returns the next timer deadline, if any pub fn next_deadline(&self) -> Option { - implementation::next_deadline(self) + fsm::next_deadline(self) } /// returns the next outbound record @@ -247,7 +248,7 @@ impl QlFsm { crypto: &impl QlCrypto, ) -> Option { self.state.now = now; - implementation::take_next_write(self, crypto) + fsm::take_next_write(self, crypto) } /// marks a `SessionWriteId` from `take_next_write` as handed to the transport @@ -255,33 +256,33 @@ impl QlFsm { /// call this at most once for each returned `SessionWriteId` pub fn confirm_session_write(&mut self, now: FsmTime, write_id: SessionWriteId) { self.state.now = now; - implementation::confirm_session_write(self, write_id); + fsm::confirm_session_write(self, write_id); } /// reports that a `SessionWriteId` from `take_next_write` was not accepted /// /// call this at most once for each returned `SessionWriteId` pub fn reject_session_write(&mut self, write_id: SessionWriteId) { - implementation::reject_session_write(self, write_id); + fsm::reject_session_write(self, write_id); } /// closes the current encrypted session locally pub fn kill_session(&mut self, code: SessionCloseCode) { - implementation::kill_session(self, code); + fsm::kill_session(self, code); } /// opens a new outgoing stream pub fn open_stream(&mut self, route_id: RouteId) -> Result, NoSessionError> { - implementation::open_stream(self, route_id) + fsm::open_stream(self, route_id) } /// returns a facade for an open stream pub fn stream(&mut self, stream_id: StreamId) -> Result, StreamError> { - implementation::stream(self, stream_id) + fsm::stream(self, stream_id) } /// queues a ping on the active session pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { - implementation::queue_ping(self) + fsm::queue_ping(self) } } From c91b30ce5c7908fc2711a6a292ba4a7ad27fc615 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 07:39:01 -0400 Subject: [PATCH 174/304] ql: unified test impls --- Cargo.lock | 6 +- ql-fsm/Cargo.toml | 4 +- ql-fsm/src/tests/handshake.rs | 9 +- ql-fsm/src/tests/mod.rs | 170 +--------------------- ql-runtime/Cargo.toml | 3 +- ql-runtime/src/driver/test.rs | 79 ++--------- ql-runtime/src/lib.rs | 9 -- ql-runtime/src/tests/handshake.rs | 45 +++--- ql-runtime/src/tests/heartbeat.rs | 7 +- ql-runtime/src/tests/mod.rs | 204 ++++---------------------- ql-runtime/src/tests/rpc.rs | 21 ++- ql-runtime/src/tests/stream.rs | 49 +++---- ql-runtime/src/tests/unpair.rs | 7 +- ql-wire/Cargo.toml | 13 ++ ql-wire/src/lib.rs | 4 + ql-wire/src/testing.rs | 182 ++++++++++++++++++++++++ ql-wire/src/tests.rs | 229 +++++------------------------- 17 files changed, 337 insertions(+), 704 deletions(-) create mode 100644 ql-wire/src/testing.rs diff --git a/Cargo.lock b/Cargo.lock index ed587aa8..21cee875 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2219,11 +2219,8 @@ version = "0.1.0" dependencies = [ "bytes", "indexmap", - "libcrux-aesgcm", - "libcrux-ml-kem", "proptest", "ql-wire", - "sha2", ] [[package]] @@ -2241,13 +2238,11 @@ dependencies = [ "bytes", "event-listener", "futures-lite", - "libcrux-aesgcm", "loom", "oneshot", "ql-fsm", "ql-rpc", "ql-wire", - "sha2", "tokio", ] @@ -2256,6 +2251,7 @@ name = "ql-wire" version = "0.1.0" dependencies = [ "bytes", + "getrandom 0.2.16", "libcrux-aesgcm", "libcrux-ml-kem", "sha2", diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml index 45ccca28..47da2e1b 100644 --- a/ql-fsm/Cargo.toml +++ b/ql-fsm/Cargo.toml @@ -11,7 +11,5 @@ indexmap = "2" ql-wire = { path = "../ql-wire" } [dev-dependencies] -libcrux-aesgcm = "0.0.7" -libcrux-ml-kem = "0.0.7" proptest = "1.6" -sha2 = "0.10" +ql-wire = { path = "../ql-wire", features = ["test-utils"] } diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index cea913be..19be4761 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -96,9 +96,9 @@ fn ik_connect_learns_remote_initial_stream_receive_window() { #[test] fn connect_methods_require_bound_peer() { let time = Harness::paired_known(QlFsmConfig::default()).time(); - let identity = test_identity(55); + let identity = test_identity(&SoftwareCrypto); let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); - let crypto = TestCrypto::new(9); + let crypto = SoftwareCrypto; assert_eq!(fsm.connect_ik(time, &crypto), Err(NoPeerError)); assert_eq!(fsm.connect_kk(time, &crypto), Err(NoPeerError)); @@ -307,7 +307,10 @@ fn bind_peer_clears_queued_handshake_output() { harness.connect_ik_a().unwrap(); harness.drain_events_a(); - harness.a.fsm.bind_peer(test_identity(99).bundle()); + harness + .a + .fsm + .bind_peer(test_identity(&SoftwareCrypto).bundle()); assert!(harness.drain_events_a().is_empty()); assert!(harness.next_outbound_a().is_none()); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index bf7022d6..174be582 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -2,19 +2,12 @@ mod handshake; mod proptest; mod session; -use std::{ - cell::Cell, - time::{Duration, Instant}, -}; +use std::time::{Duration, Instant}; -use libcrux_aesgcm::AesGcm256Key; -use libcrux_ml_kem::mlkem1024; use ql_wire::{ - self, generate_identity, ConnectionId, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, - MlKemPublicKey, Nonce, PairingToken, QlAead, QlCrypto, QlHash, QlIdentity, QlKem, QlRandom, - SessionKey, TransportParams, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + self, test_identities, test_identity, ConnectionId, PairingToken, QlCrypto, SessionKey, + SoftwareCrypto, TransportParams, }; -use sha2::{Digest, Sha256}; use crate::{ session::{SessionFsm, SessionFsmConfig, StreamParity}, @@ -22,123 +15,7 @@ use crate::{ FsmTime, NoPeerError, OutboundWrite, QlFsm, QlFsmConfig, QlFsmEvent, SessionWriteId, }; -#[derive(Clone)] -struct TestCrypto { - seed: u8, - counter: Cell, -} - -impl TestCrypto { - fn new(seed: u8) -> Self { - Self { - seed, - counter: Cell::new(0), - } - } - - fn next_block(&self) -> [u8; 32] { - let counter = self.counter.get(); - self.counter.set(counter.wrapping_add(1)); - sha256_parts(&[b"ql-fsm:test-rng:v1", &[self.seed], &counter.to_le_bytes()]) - } - - fn random_array(&self) -> [u8; L] { - let mut out = [0u8; L]; - self.fill_random_bytes(&mut out); - out - } -} - -impl QlRandom for TestCrypto { - fn fill_random_bytes(&self, out: &mut [u8]) { - fill_expanded(self, &[b"ql-fsm:test-fill:v1"], out); - } -} - -impl QlHash for TestCrypto { - fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { - sha256_parts(parts) - } -} - -impl QlAead for TestCrypto { - fn aes256_gcm_encrypt( - &self, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - buffer: &mut [u8], - ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { - let key: AesGcm256Key = (*key.data()).into(); - let plaintext = buffer.to_vec(); - let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; - key.encrypt( - buffer, - (&mut auth).into(), - (&nonce.0).into(), - aad, - &plaintext, - ) - .unwrap(); - auth - } - - fn aes256_gcm_decrypt( - &self, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - buffer: &mut [u8], - auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool { - let key: AesGcm256Key = (*key.data()).into(); - let ciphertext = buffer.to_vec(); - key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) - .is_ok() - } -} - -impl QlKem for TestCrypto { - fn mlkem_generate_keypair(&self) -> MlKemKeyPair { - let key_pair = mlkem1024::generate_key_pair(self.random_array()); - let mut public = [0u8; MlKemPublicKey::SIZE]; - public.copy_from_slice(key_pair.pk()); - let mut private = [0u8; MlKemPrivateKey::SIZE]; - private.copy_from_slice(key_pair.sk()); - - MlKemKeyPair { - private: MlKemPrivateKey::new(Box::new(private)), - public: MlKemPublicKey::new(Box::new(public)), - } - } - - fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { - let public_key = public_key.as_bytes().into(); - let (ciphertext_value, shared_value) = - mlkem1024::encapsulate(&public_key, self.random_array()); - let mut ciphertext = [0u8; MlKemCiphertext::SIZE]; - ciphertext.copy_from_slice(ciphertext_value.as_slice()); - let mut shared = [0u8; SessionKey::SIZE]; - shared.copy_from_slice(shared_value.as_slice()); - ( - MlKemCiphertext::new(Box::new(ciphertext)), - SessionKey::from_data(shared), - ) - } - - fn mlkem_decapsulate( - &self, - private_key: &MlKemPrivateKey, - ciphertext: &MlKemCiphertext, - ) -> SessionKey { - let private_key = private_key.as_bytes().into(); - let ciphertext = ciphertext.as_bytes().into(); - let shared = mlkem1024::decapsulate(&private_key, &ciphertext); - let mut out = [0u8; SessionKey::SIZE]; - out.copy_from_slice(shared.as_slice()); - SessionKey::from_data(out) - } -} +type TestCrypto = SoftwareCrypto; struct Node { fsm: QlFsm, @@ -171,8 +48,7 @@ impl Harness { know_a: bool, know_b: bool, ) -> Self { - let identity_a = test_identity(11); - let identity_b = test_identity(73); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let now = Instant::now(); let time = FsmTime { instant: now, @@ -184,11 +60,11 @@ impl Harness { unix_secs: time.unix_secs, a: Node { fsm: QlFsm::new(config_a, identity_a.clone(), time), - crypto: TestCrypto::new(1), + crypto: SoftwareCrypto, }, b: Node { fsm: QlFsm::new(config_b, identity_b.clone(), time), - crypto: TestCrypto::new(2), + crypto: SoftwareCrypto, }, }; @@ -389,11 +265,6 @@ impl Harness { } } -fn test_identity(seed: u8) -> QlIdentity { - let crypto = TestCrypto::new(seed); - generate_identity(&crypto, XID([seed; XID::SIZE])) -} - fn pairing_token(byte: u8) -> PairingToken { PairingToken([byte; PairingToken::SIZE]) } @@ -449,30 +320,3 @@ fn decrypt_record( ql_wire::decode_session_frames(&plaintext).unwrap(), ) } - -fn sha256_parts(parts: &[&[u8]]) -> [u8; 32] { - let mut hasher = Sha256::new(); - for part in parts { - hasher.update(part); - } - hasher.finalize().into() -} - -fn fill_expanded(crypto: &TestCrypto, parts: &[&[u8]], out: &mut [u8]) { - let mut written = 0usize; - let mut counter = 0u64; - while written < out.len() { - let random = crypto.next_block(); - let counter_bytes = counter.to_le_bytes(); - let mut inputs = Vec::with_capacity(parts.len() + 3); - inputs.push(b"ql-fsm:test-expand:v1".as_slice()); - inputs.push(&random); - inputs.push(&counter_bytes); - inputs.extend_from_slice(parts); - let block = sha256_parts(&inputs); - let take = (out.len() - written).min(block.len()); - out[written..written + take].copy_from_slice(&block[..take]); - written += take; - counter = counter.wrapping_add(1); - } -} diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index 116c54e7..208e5325 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -20,8 +20,7 @@ ql-rpc = { path = "../ql-rpc", optional = true } ql-wire = { path = "../ql-wire" } [dev-dependencies] -libcrux-aesgcm = "0.0.7" -sha2 = "0.10" +ql-wire = { path = "../ql-wire", features = ["test-utils"] } tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } [target.'cfg(loom)'.dev-dependencies] diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 05333db5..644091c7 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -1,80 +1,15 @@ use std::task::{Context, Poll}; -use ql_wire::{ - MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, PeerBundle, QlAead, QlHash, - QlKem, QlRandom, SessionKey, StreamClose, XID, -}; +use ql_wire::{test_identity, NoopCrypto, PeerBundle, SoftwareCrypto, StreamClose, XID}; use super::*; use crate::{ chunk_slot, driver::state::{InboundIo, OutboundIo}, platform::PlatformFuture, - tests::new_identity, }; -struct NoopPlatform; - -struct NoopTimer; - -impl QlRandom for NoopPlatform { - fn fill_random_bytes(&self, data: &mut [u8]) { - data.fill(0); - } -} - -impl QlHash for NoopPlatform { - fn sha256(&self, _parts: &[&[u8]]) -> [u8; 32] { - [0; 32] - } -} - -impl QlAead for NoopPlatform { - fn aes256_gcm_encrypt( - &self, - _key: &SessionKey, - _nonce: &ql_wire::Nonce, - _aad: &[u8], - _buffer: &mut [u8], - ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { - [0; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] - } - - fn aes256_gcm_decrypt( - &self, - _key: &SessionKey, - _nonce: &ql_wire::Nonce, - _aad: &[u8], - _buffer: &mut [u8], - _auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool { - false - } -} - -impl QlKem for NoopPlatform { - fn mlkem_generate_keypair(&self) -> MlKemKeyPair { - MlKemKeyPair { - private: MlKemPrivateKey::new(Box::new([0; MlKemPrivateKey::SIZE])), - public: MlKemPublicKey::new(Box::new([0; MlKemPublicKey::SIZE])), - } - } - - fn mlkem_encapsulate(&self, _public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { - ( - MlKemCiphertext::new(Box::new([0; MlKemCiphertext::SIZE])), - SessionKey::from_data([0; SessionKey::SIZE]), - ) - } - - fn mlkem_decapsulate( - &self, - _private_key: &MlKemPrivateKey, - _ciphertext: &MlKemCiphertext, - ) -> SessionKey { - SessionKey::from_data([0; SessionKey::SIZE]) - } -} +pub struct NoopTimer; impl crate::platform::QlTimer for NoopTimer { fn set_deadline(&mut self, _deadline: Option) {} @@ -84,7 +19,7 @@ impl crate::platform::QlTimer for NoopTimer { } } -impl QlPlatform for NoopPlatform { +impl QlPlatform for NoopCrypto { type Timer = NoopTimer; type WriteMessageFut<'a> = std::future::Ready; @@ -115,7 +50,11 @@ fn new_driver_state() -> (DriverState, QlFsm) { runtime_tx: runtime_tx.downgrade(), max_concurrent_message_writes: 1, }, - QlFsm::new(ql_fsm::QlFsmConfig::default(), new_identity(7), now()), + QlFsm::new( + ql_fsm::QlFsmConfig::default(), + test_identity(&SoftwareCrypto), + now(), + ), ) } @@ -211,7 +150,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { target: CloseTarget::Origin, code: StreamCloseCode(0), }, - &NoopPlatform, + &NoopCrypto, ); assert!(!state.streams.contains_key(&stream_id)); diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 04a58c4c..d423adb0 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -34,14 +34,6 @@ impl Default for RuntimeConfig { } } -impl RuntimeConfig { - pub(crate) fn normalized(mut self) -> Self { - self.stream_send_buffer_bytes = self.stream_send_buffer_bytes.max(1); - self.max_concurrent_message_writes = self.max_concurrent_message_writes.max(1); - self - } -} - pub struct Runtime

{ identity: QlIdentity, platform: P, @@ -58,7 +50,6 @@ pub fn new_runtime

( where P: QlPlatform, { - let config = config.normalized(); let (tx, rx) = async_channel::unbounded(); ( Runtime { diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 923d46c2..fa2ab8b5 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -8,10 +8,9 @@ use super::*; async fn connect_round_trip_changes_peer_status() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -35,10 +34,9 @@ async fn connect_round_trip_changes_peer_status() { async fn opening_stream_requires_connection() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, _outbound_a, _status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b, _inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, _outbound_a, _status_a) = TestPlatform::new(); + let (platform_b, _outbound_b, _status_b, _inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -65,10 +63,9 @@ async fn handshake_timeout_disconnects() { }, ..default_runtime_config() }; - let (platform_a, _outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, _outbound_a, status_a) = TestPlatform::new(); + let (platform_b, _outbound_b, _status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -88,10 +85,9 @@ async fn handshake_timeout_disconnects() { async fn rejected_session_write_is_reissued() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new_with_session_write_failure(1, 1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new_with_session_write_failure(1); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -147,10 +143,9 @@ async fn rejected_session_write_is_reissued() { async fn start_pairing_round_trip_connects_when_armed() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b) = TestPlatform::new(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let token = pairing_token(7); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -175,13 +170,13 @@ async fn start_pairing_round_trip_connects_when_armed() { async fn start_pairing_does_not_connect_when_unarmed() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, _status_b) = TestPlatform::new(2); - let identity_a = new_identity(11); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, _status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let token = pairing_token(8); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(new_identity(73), platform_b, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); @@ -193,7 +188,7 @@ async fn start_pairing_does_not_connect_when_unarmed() { assert_no_status_for( &status_a, - XID([73; XID::SIZE]), + identity_b.xid, PeerStatus::Connected, Duration::from_millis(150), ) diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 77412a2c..2fd10383 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -21,10 +21,9 @@ async fn session_timeout_disconnects_and_fails_pending_open() { ..default_runtime_config() }; let config_b = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(2); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(1); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 70913eb1..587b0bb1 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -1,9 +1,8 @@ use std::{ - cell::Cell, future::Future, pin::Pin, sync::{ - atomic::{AtomicU8, AtomicUsize, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, }, task::{Context, Poll}, @@ -11,14 +10,12 @@ use std::{ }; use async_channel::{Receiver, Sender}; -use libcrux_aesgcm::AesGcm256Key; use ql_fsm::PeerStatus; use ql_wire::{ - generate_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, - PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, - RecordType, RouteId, SessionKey, VarInt, WireDecode, XID, + test_identities, test_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, + Nonce, PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, + RecordType, RouteId, SessionKey, SoftwareCrypto, VarInt, WireDecode, XID, }; -use sha2::{Digest, Sha256}; use tokio::{task::LocalSet, time::Sleep}; use crate::{ @@ -62,110 +59,11 @@ impl WriteStats { } } -struct DeterministicCrypto { - seed: u8, - counter: Cell, -} - -impl DeterministicCrypto { - fn new(seed: u8) -> Self { - Self { - seed, - counter: Cell::new(0), - } - } -} - -impl QlRandom for DeterministicCrypto { - fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self.seed.wrapping_add(self.counter.get()); - self.counter.set(self.counter.get().wrapping_add(1)); - data.fill(value); - } -} - -impl QlHash for DeterministicCrypto { - fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { - let mut hasher = Sha256::new(); - for part in parts { - hasher.update(part); - } - hasher.finalize().into() - } -} - -impl QlAead for DeterministicCrypto { - fn aes256_gcm_encrypt( - &self, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - buffer: &mut [u8], - ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { - let key: AesGcm256Key = (*key.data()).into(); - let plaintext = buffer.to_vec(); - let mut auth = [0u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE]; - key.encrypt( - buffer, - (&mut auth).into(), - (&nonce.0).into(), - aad, - &plaintext, - ) - .unwrap(); - auth - } - - fn aes256_gcm_decrypt( - &self, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - buffer: &mut [u8], - auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool { - let key: AesGcm256Key = (*key.data()).into(); - let ciphertext = buffer.to_vec(); - key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) - .is_ok() - } -} - -impl QlKem for DeterministicCrypto { - fn mlkem_generate_keypair(&self) -> MlKemKeyPair { - let data = Box::new([self.seed; MlKemPublicKey::SIZE]); - MlKemKeyPair { - private: MlKemPrivateKey::new(Box::new([self.seed; MlKemPrivateKey::SIZE])), - public: MlKemPublicKey::new(data), - } - } - - fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { - let mut secret = [0u8; SessionKey::SIZE]; - secret.copy_from_slice(&public_key.as_bytes()[..SessionKey::SIZE]); - ( - MlKemCiphertext::new(Box::new([self.seed; MlKemCiphertext::SIZE])), - SessionKey::from_data(secret), - ) - } - - fn mlkem_decapsulate( - &self, - private_key: &MlKemPrivateKey, - _ciphertext: &MlKemCiphertext, - ) -> SessionKey { - let mut secret = [0u8; SessionKey::SIZE]; - secret.copy_from_slice(&private_key.as_bytes()[..SessionKey::SIZE]); - SessionKey::from_data(secret) - } -} - struct TestPlatform { outbound: Sender>, status: Sender, inbound: Option>, - nonce_seed: u8, - nonce_counter: AtomicU8, + crypto: SoftwareCrypto, encrypted_write_counter: AtomicUsize, fail_encrypted_write_at: Option, write_delay: Duration, @@ -173,13 +71,11 @@ struct TestPlatform { } impl TestPlatform { - fn new(seed: u8) -> (Self, Receiver>, Receiver) { - Self::new_inner(seed, None, None, Duration::ZERO, None) + fn new() -> (Self, Receiver>, Receiver) { + Self::new_inner(None, None, Duration::ZERO, None) } - fn new_with_inbound( - seed: u8, - ) -> ( + fn new_with_inbound() -> ( Self, Receiver>, Receiver, @@ -187,33 +83,24 @@ impl TestPlatform { ) { let (inbound_tx, inbound_rx) = async_channel::unbounded(); let (platform, outbound_rx, status_rx) = - Self::new_inner(seed, Some(inbound_tx), None, Duration::ZERO, None); + Self::new_inner(Some(inbound_tx), None, Duration::ZERO, None); (platform, outbound_rx, status_rx, inbound_rx) } fn new_with_session_write_failure( - seed: u8, fail_encrypted_write_at: usize, ) -> (Self, Receiver>, Receiver) { - Self::new_inner( - seed, - None, - Some(fail_encrypted_write_at), - Duration::ZERO, - None, - ) + Self::new_inner(None, Some(fail_encrypted_write_at), Duration::ZERO, None) } fn new_with_delayed_writes( - seed: u8, delay: Duration, write_stats: WriteStats, ) -> (Self, Receiver>, Receiver) { - Self::new_inner(seed, None, None, delay, Some(write_stats)) + Self::new_inner(None, None, delay, Some(write_stats)) } fn new_inner( - seed: u8, inbound: Option>, fail_encrypted_write_at: Option, write_delay: Duration, @@ -226,8 +113,7 @@ impl TestPlatform { outbound, status, inbound, - nonce_seed: seed, - nonce_counter: AtomicU8::new(0), + crypto: SoftwareCrypto, encrypted_write_counter: AtomicUsize::new(0), fail_encrypted_write_at, write_delay, @@ -264,20 +150,13 @@ impl QlTimer for TokioTimer { impl QlRandom for TestPlatform { fn fill_random_bytes(&self, data: &mut [u8]) { - let value = self - .nonce_seed - .wrapping_add(self.nonce_counter.fetch_add(1, Ordering::Relaxed)); - data.fill(value); + self.crypto.fill_random_bytes(data); } } impl QlHash for TestPlatform { fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { - let mut hasher = Sha256::new(); - for part in parts { - hasher.update(part); - } - hasher.finalize().into() + self.crypto.sha256(parts) } } @@ -289,18 +168,7 @@ impl QlAead for TestPlatform { aad: &[u8], buffer: &mut [u8], ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { - let key: AesGcm256Key = (*key.data()).into(); - let plaintext = buffer.to_vec(); - let mut auth = [0u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE]; - key.encrypt( - buffer, - (&mut auth).into(), - (&nonce.0).into(), - aad, - &plaintext, - ) - .unwrap(); - auth + self.crypto.aes256_gcm_encrypt(key, nonce, aad, buffer) } fn aes256_gcm_decrypt( @@ -311,39 +179,22 @@ impl QlAead for TestPlatform { buffer: &mut [u8], auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], ) -> bool { - let key: AesGcm256Key = (*key.data()).into(); - let ciphertext = buffer.to_vec(); - key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) - .is_ok() + self.crypto + .aes256_gcm_decrypt(key, nonce, aad, buffer, auth_tag) } } impl QlKem for TestPlatform { fn mlkem_generate_keypair(&self) -> MlKemKeyPair { - let byte = self.nonce_seed; - MlKemKeyPair { - private: MlKemPrivateKey::new(Box::new([byte; MlKemPrivateKey::SIZE])), - public: MlKemPublicKey::new(Box::new([byte; MlKemPublicKey::SIZE])), - } + self.crypto.mlkem_generate_keypair() } fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { - let mut secret = [0u8; SessionKey::SIZE]; - secret.copy_from_slice(&public_key.as_bytes()[..SessionKey::SIZE]); - ( - MlKemCiphertext::new(Box::new([self.nonce_seed; MlKemCiphertext::SIZE])), - SessionKey::from_data(secret), - ) + self.crypto.mlkem_encapsulate(public_key) } - fn mlkem_decapsulate( - &self, - private_key: &MlKemPrivateKey, - _ciphertext: &MlKemCiphertext, - ) -> SessionKey { - let mut secret = [0u8; SessionKey::SIZE]; - secret.copy_from_slice(&private_key.as_bytes()[..SessionKey::SIZE]); - SessionKey::from_data(secret) + fn mlkem_decapsulate(&self, pk: &MlKemPrivateKey, cipher: &MlKemCiphertext) -> SessionKey { + self.crypto.mlkem_decapsulate(pk, cipher) } } @@ -419,11 +270,6 @@ fn is_encrypted_payload(bytes: &[u8]) -> bool { .is_some_and(|header| header.record_type == RecordType::Session) } -pub(crate) fn new_identity(seed: u8) -> QlIdentity { - let crypto = DeterministicCrypto::new(seed); - generate_identity(&crypto, XID([seed; XID::SIZE])) -} - fn pairing_token(byte: u8) -> PairingToken { PairingToken([byte; PairingToken::SIZE]) } @@ -560,8 +406,8 @@ fn default_runtime_config() -> RuntimeConfig { #[test] fn runtime_is_send() { let config = default_runtime_config(); - let identity_a = new_identity(11); - let (platform_a, _, _) = TestPlatform::new(1); + let identity_a = test_identity(&SoftwareCrypto); + let (platform_a, _, _) = TestPlatform::new(); let (runtime_a, _handle) = new_runtime(identity_a, platform_a, config); std::thread::spawn(move || { tokio::runtime::Builder::new_current_thread() @@ -575,8 +421,8 @@ fn runtime_is_send() { #[test] fn runtime_exits_when_last_handle_drops() { let config = default_runtime_config(); - let identity = new_identity(11); - let (platform, _, _) = TestPlatform::new(1); + let identity = test_identity(&SoftwareCrypto); + let (platform, _, _) = TestPlatform::new(); let (runtime, handle) = new_runtime(identity, platform, config); let (done_tx, done_rx) = oneshot::channel(); diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 983b77f2..93dc82d2 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -54,10 +54,9 @@ impl ql_rpc::request_with_progress::RequestWithProgress for Download { async fn rpc_request_round_trips() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -113,10 +112,9 @@ async fn rpc_request_round_trips() { async fn rpc_subscription_streams_events() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -184,10 +182,9 @@ async fn rpc_subscription_streams_events() { async fn rpc_request_with_progress_supports_progress_then_await() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 7e378781..cce7d65f 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -10,10 +10,9 @@ use crate::QlStreamError; async fn open_stream_duplex_happy_path() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -75,10 +74,9 @@ async fn open_stream_duplex_happy_path() { async fn reader_exposes_bounded_chunk_reads() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -144,10 +142,9 @@ async fn large_stream_payload_round_trips() { let config = default_runtime_config(); let payload: Vec = (0..40).collect(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (done_tx, done_rx) = async_channel::bounded(1); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -199,10 +196,9 @@ async fn large_stream_payload_round_trips() { async fn dropping_responder_closes_initiator_response() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -245,10 +241,9 @@ async fn dropping_responder_closes_initiator_response() { async fn dropping_inbound_reader_cancels_remote_writer() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (go_tx, go_rx) = async_channel::bounded(1); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -306,10 +301,9 @@ async fn max_concurrent_message_writes_is_respected() { ..default_runtime_config() }; let (platform_a, outbound_a, status_a) = - TestPlatform::new_with_delayed_writes(1, Duration::from_millis(40), stats.clone()); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + TestPlatform::new_with_delayed_writes(Duration::from_millis(40), stats.clone()); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -376,10 +370,9 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { }, ..default_runtime_config() }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let request_payload: Vec = (0..32).collect(); let response_payload: Vec = (100..132).collect(); diff --git a/ql-runtime/src/tests/unpair.rs b/ql-runtime/src/tests/unpair.rs index 133c156c..93c78177 100644 --- a/ql-runtime/src/tests/unpair.rs +++ b/ql-runtime/src/tests/unpair.rs @@ -4,10 +4,9 @@ use super::*; async fn unpair_clears_remote_peer_and_aborts_active_stream() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(2); - let identity_a = new_identity(11); - let identity_b = new_identity(73); + let (platform_a, outbound_a, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml index 4e713826..5fccc95e 100644 --- a/ql-wire/Cargo.toml +++ b/ql-wire/Cargo.toml @@ -5,10 +5,23 @@ edition = "2021" description = "Quantum Link wire format types and crypto helpers" license = "Proprietary" +[features] +test-utils = [ + "dep:getrandom", + "dep:libcrux-aesgcm", + "dep:libcrux-ml-kem", + "dep:sha2", +] + [dependencies] bytes = "1" +getrandom = { workspace = true, optional = true } +libcrux-aesgcm = { version = "0.0.7", optional = true } +libcrux-ml-kem = { version = "0.0.7", optional = true } +sha2 = { version = "0.10", optional = true } [dev-dependencies] +getrandom = { workspace = true } libcrux-aesgcm = "0.0.7" libcrux-ml-kem = "0.0.7" sha2 = "0.10" diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 4d6077e4..63b5e633 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -16,6 +16,8 @@ mod identity; mod nonce; mod pq; mod record; +#[cfg(any(feature = "test-utils", test))] +mod testing; mod varint; mod xid; @@ -31,6 +33,8 @@ pub use identity::*; pub use nonce::*; pub use pq::*; pub use record::*; +#[cfg(any(feature = "test-utils", test))] +pub use testing::*; pub use varint::*; pub use xid::*; diff --git a/ql-wire/src/testing.rs b/ql-wire/src/testing.rs new file mode 100644 index 00000000..83b4fbde --- /dev/null +++ b/ql-wire/src/testing.rs @@ -0,0 +1,182 @@ +use libcrux_aesgcm::AesGcm256Key; +use libcrux_ml_kem::mlkem1024; +use sha2::{Digest, Sha256}; + +use crate::{ + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, + QlHash, QlIdentity, QlKem, QlRandom, SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, +}; + +#[derive(Debug, Default, Clone, Copy)] +pub struct SoftwareCrypto; + +#[derive(Debug, Default, Clone, Copy)] +pub struct NoopCrypto; + +pub fn test_identity(crypto: &impl QlCrypto) -> QlIdentity { + crate::generate_identity(crypto, XID(random_array(crypto))) +} + +pub fn test_identities(crypto: &impl QlCrypto) -> (QlIdentity, QlIdentity) { + (test_identity(crypto), test_identity(crypto)) +} + +impl QlRandom for SoftwareCrypto { + fn fill_random_bytes(&self, out: &mut [u8]) { + getrandom::getrandom(out).unwrap(); + } +} + +impl QlHash for SoftwareCrypto { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() + } +} + +impl QlAead for SoftwareCrypto { + fn aes256_gcm_encrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { + let key: AesGcm256Key = (*key.data()).into(); + let plaintext = buffer.to_vec(); + let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + key.encrypt( + buffer, + (&mut auth).into(), + (&nonce.0).into(), + aad, + &plaintext, + ) + .unwrap(); + auth + } + + fn aes256_gcm_decrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + let key: AesGcm256Key = (*key.data()).into(); + let ciphertext = buffer.to_vec(); + key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) + .is_ok() + } +} + +impl QlKem for SoftwareCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + let key_pair = mlkem1024::generate_key_pair(random_array(self)); + let mut public = [0u8; MlKemPublicKey::SIZE]; + public.copy_from_slice(key_pair.pk()); + let mut private = [0u8; MlKemPrivateKey::SIZE]; + private.copy_from_slice(key_pair.sk()); + + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new(private)), + public: MlKemPublicKey::new(Box::new(public)), + } + } + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + let public_key = public_key.as_bytes().into(); + let (ciphertext_value, shared_value) = + mlkem1024::encapsulate(&public_key, random_array(self)); + let mut ciphertext = [0u8; MlKemCiphertext::SIZE]; + ciphertext.copy_from_slice(ciphertext_value.as_slice()); + let mut shared = [0u8; SessionKey::SIZE]; + shared.copy_from_slice(shared_value.as_slice()); + ( + MlKemCiphertext::new(Box::new(ciphertext)), + SessionKey::from_data(shared), + ) + } + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + ciphertext: &MlKemCiphertext, + ) -> SessionKey { + let private_key = private_key.as_bytes().into(); + let ciphertext = ciphertext.as_bytes().into(); + let shared = mlkem1024::decapsulate(&private_key, &ciphertext); + let mut out = [0u8; SessionKey::SIZE]; + out.copy_from_slice(shared.as_slice()); + SessionKey::from_data(out) + } +} + +impl QlRandom for NoopCrypto { + fn fill_random_bytes(&self, out: &mut [u8]) { + out.fill(0); + } +} + +impl QlHash for NoopCrypto { + fn sha256(&self, _parts: &[&[u8]]) -> [u8; 32] { + [0; 32] + } +} + +impl QlAead for NoopCrypto { + fn aes256_gcm_encrypt( + &self, + _key: &SessionKey, + _nonce: &Nonce, + _aad: &[u8], + _buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { + [0; ENCRYPTED_MESSAGE_AUTH_SIZE] + } + + fn aes256_gcm_decrypt( + &self, + _key: &SessionKey, + _nonce: &Nonce, + _aad: &[u8], + _buffer: &mut [u8], + _auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + false + } +} + +impl QlKem for NoopCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new([0; MlKemPrivateKey::SIZE])), + public: MlKemPublicKey::new(Box::new([0; MlKemPublicKey::SIZE])), + } + } + + fn mlkem_encapsulate(&self, _public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + ( + MlKemCiphertext::new(Box::new([0; MlKemCiphertext::SIZE])), + SessionKey::from_data([0; SessionKey::SIZE]), + ) + } + + fn mlkem_decapsulate( + &self, + _private_key: &MlKemPrivateKey, + _ciphertext: &MlKemCiphertext, + ) -> SessionKey { + SessionKey::from_data([0; SessionKey::SIZE]) + } +} + +fn random_array(crypto: &impl QlRandom) -> [u8; L] { + let mut out = [0u8; L]; + crypto.fill_random_bytes(&mut out); + out +} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index cacc77a8..713f3189 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -1,15 +1,5 @@ -use std::sync::atomic::{AtomicU64, Ordering}; - -use libcrux_aesgcm::AesGcm256Key; -use libcrux_ml_kem::mlkem1024; -use sha2::{Digest, Sha256}; - use super::*; -struct TestCrypto { - counter: AtomicU64, -} - fn decode_handshake_record(bytes: &[u8]) -> QlHandshakeRecord { decode_record(bytes).unwrap().1 } @@ -19,143 +9,6 @@ fn decode_session_record(bytes: &[u8]) -> QlSessionRecord> { record.into_owned() } -impl TestCrypto { - fn new(seed: u64) -> Self { - Self { - counter: AtomicU64::new(seed), - } - } - - fn next_block(&self) -> [u8; 32] { - let value = self.counter.fetch_add(1, Ordering::Relaxed).to_le_bytes(); - sha256_parts(&[b"ql-wire:test-rng:v1", &value]) - } - - fn random_array(&self) -> [u8; L] { - let mut out = [0u8; L]; - self.fill_random_bytes(&mut out); - out - } -} - -impl QlRandom for TestCrypto { - fn fill_random_bytes(&self, out: &mut [u8]) { - fill_expanded(self, &[b"ql-wire:test-fill:v1"], out); - } -} - -impl QlHash for TestCrypto { - fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { - sha256_parts(parts) - } -} - -impl QlAead for TestCrypto { - fn aes256_gcm_encrypt( - &self, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - buffer: &mut [u8], - ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { - let key: AesGcm256Key = (*key.data()).into(); - let plaintext = buffer.to_vec(); - let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; - key.encrypt( - buffer, - (&mut auth).into(), - (&nonce.0).into(), - aad, - &plaintext, - ) - .unwrap(); - auth - } - - fn aes256_gcm_decrypt( - &self, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - buffer: &mut [u8], - auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool { - let key: AesGcm256Key = (*key.data()).into(); - let ciphertext = buffer.to_vec(); - key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) - .is_ok() - } -} - -impl QlKem for TestCrypto { - fn mlkem_generate_keypair(&self) -> MlKemKeyPair { - let key_pair = mlkem1024::generate_key_pair(self.random_array()); - let mut public = [0u8; MlKemPublicKey::SIZE]; - public.copy_from_slice(key_pair.pk()); - let mut private = [0u8; MlKemPrivateKey::SIZE]; - private.copy_from_slice(key_pair.sk()); - - MlKemKeyPair { - private: MlKemPrivateKey::new(Box::new(private)), - public: MlKemPublicKey::new(Box::new(public)), - } - } - - fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { - let public_key = public_key.as_bytes().into(); - let (ciphertext_value, shared_value) = - mlkem1024::encapsulate(&public_key, self.random_array()); - let mut ciphertext = [0u8; MlKemCiphertext::SIZE]; - ciphertext.copy_from_slice(ciphertext_value.as_slice()); - let mut shared = [0u8; SessionKey::SIZE]; - shared.copy_from_slice(shared_value.as_slice()); - ( - MlKemCiphertext::new(Box::new(ciphertext)), - SessionKey::from_data(shared), - ) - } - - fn mlkem_decapsulate( - &self, - private_key: &MlKemPrivateKey, - ciphertext: &MlKemCiphertext, - ) -> SessionKey { - let private_key = private_key.as_bytes().into(); - let ciphertext = ciphertext.as_bytes().into(); - let shared = mlkem1024::decapsulate(&private_key, &ciphertext); - let mut out = [0u8; SessionKey::SIZE]; - out.copy_from_slice(shared.as_slice()); - SessionKey::from_data(out) - } -} - -fn sha256_parts(parts: &[&[u8]]) -> [u8; 32] { - let mut hasher = Sha256::new(); - for part in parts { - hasher.update(part); - } - hasher.finalize().into() -} - -fn fill_expanded(crypto: &TestCrypto, parts: &[&[u8]], out: &mut [u8]) { - let mut written = 0usize; - let mut counter = 0u64; - while written < out.len() { - let random = crypto.next_block(); - let counter_bytes = counter.to_le_bytes(); - let mut inputs = Vec::with_capacity(parts.len() + 3); - inputs.push(b"ql-wire:test-expand:v1".as_slice()); - inputs.push(&random); - inputs.push(&counter_bytes); - inputs.extend_from_slice(parts); - let block = sha256_parts(&inputs); - let take = (out.len() - written).min(block.len()); - out[written..written + take].copy_from_slice(&block[..take]); - written += take; - counter = counter.wrapping_add(1); - } -} - fn xid(byte: u8) -> XID { XID([byte; XID::SIZE]) } @@ -185,10 +38,6 @@ fn handshake_transport_params(window: u32) -> TransportParams { } } -fn make_identity(crypto: &impl QlCrypto, byte: u8) -> QlIdentity { - generate_identity(crypto, xid(byte)) -} - fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { HandshakeHeader { sender: xid(sender), @@ -226,8 +75,8 @@ fn encrypt_record( #[test] fn peer_bundle_round_trip() { - let crypto = TestCrypto::new(1); - let identity = make_identity(&crypto, 7).with_capabilities(0x55aa_33cc); + let crypto = SoftwareCrypto; + let identity = test_identity(&crypto).with_capabilities(0x55aa_33cc); let bundle = identity.bundle(); let encoded = bundle.encode_vec(); @@ -298,9 +147,8 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { #[test] fn ik_handshake_rejects_tampered_handshake_meta() { - let crypto = TestCrypto::new(9); - let initiator = make_identity(&crypto, 1); - let responder = make_identity(&crypto, 2); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let mut initiator_state = IkHandshake::new_initiator( &crypto, @@ -329,9 +177,8 @@ fn ik_handshake_rejects_tampered_handshake_meta() { #[test] fn kk_handshake_rejects_tampered_handshake_header() { - let crypto = TestCrypto::new(10); - let initiator = make_identity(&crypto, 1); - let responder = make_identity(&crypto, 2); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let mut initiator_state = KkHandshake::new_initiator( &crypto, @@ -364,9 +211,8 @@ fn kk_handshake_rejects_tampered_handshake_header() { #[test] fn ik_handshake_rejects_tampered_transport_params() { - let crypto = TestCrypto::new(101); - let initiator = make_identity(&crypto, 1); - let responder = make_identity(&crypto, 2); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let mut initiator_state = IkHandshake::new_initiator( &crypto, @@ -395,9 +241,8 @@ fn ik_handshake_rejects_tampered_transport_params() { #[test] fn ik_handshake_rejects_tampered_handshake_header() { - let crypto = TestCrypto::new(11); - let initiator = make_identity(&crypto, 1); - let responder = make_identity(&crypto, 2); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let mut initiator_state = IkHandshake::new_initiator( &crypto, @@ -421,10 +266,9 @@ fn ik_handshake_rejects_tampered_handshake_header() { #[test] fn ik_handshake_rejects_bound_remote_bundle_mismatch() { - let crypto = TestCrypto::new(12); - let initiator = make_identity(&crypto, 1); - let bogus = make_identity(&crypto, 1); - let responder = make_identity(&crypto, 2); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let bogus = test_identity(&crypto); let mut initiator_state = IkHandshake::new_initiator( &crypto, @@ -451,9 +295,8 @@ fn ik_handshake_rejects_bound_remote_bundle_mismatch() { #[test] fn ik_handshake_rejects_expired_message() { - let crypto = TestCrypto::new(13); - let initiator = make_identity(&crypto, 1); - let responder = make_identity(&crypto, 2); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let mut initiator_state = IkHandshake::new_initiator( &crypto, @@ -482,9 +325,8 @@ fn ik_handshake_rejects_expired_message() { #[test] fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { - let crypto = TestCrypto::new(20); - let initiator = make_identity(&crypto, 3); - let responder = make_identity(&crypto, 4); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let initiator_params = handshake_transport_params(4096); let responder_params = handshake_transport_params(8192); @@ -532,9 +374,8 @@ fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { #[test] fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { - let crypto = TestCrypto::new(21); - let initiator = make_identity(&crypto, 3); - let responder = make_identity(&crypto, 4); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let initiator_params = handshake_transport_params(16_384); let responder_params = handshake_transport_params(32_768); @@ -586,9 +427,8 @@ fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { #[test] fn kk_handshake_round_trip_derives_matching_transport() { - let crypto = TestCrypto::new(30); - let initiator = make_identity(&crypto, 3); - let responder = make_identity(&crypto, 4); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let initiator_params = handshake_transport_params(24_576); let responder_params = handshake_transport_params(49_152); @@ -640,9 +480,8 @@ fn kk_handshake_round_trip_derives_matching_transport() { #[test] fn kk_handshake_rejects_tampered_transport_params() { - let crypto = TestCrypto::new(31); - let initiator = make_identity(&crypto, 3); - let responder = make_identity(&crypto, 4); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let mut initiator_state = KkHandshake::new_initiator( &crypto, @@ -675,9 +514,8 @@ fn kk_handshake_rejects_tampered_transport_params() { #[test] fn xx_handshake_rejects_tampered_pairing_token() { - let crypto = TestCrypto::new(32); - let initiator = make_identity(&crypto, 5); - let responder = make_identity(&crypto, 6); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let token = pairing_token(7); let mut initiator_state = @@ -698,9 +536,8 @@ fn xx_handshake_rejects_tampered_pairing_token() { #[test] fn xx_handshake_rejects_repeated_transport_param_change() { - let crypto = TestCrypto::new(33); - let initiator = make_identity(&crypto, 5); - let responder = make_identity(&crypto, 6); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let token = pairing_token(9); let mut initiator_state = XxHandshake::new_initiator( @@ -739,9 +576,8 @@ fn xx_handshake_rejects_repeated_transport_param_change() { #[test] fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { - let crypto = TestCrypto::new(34); - let initiator = make_identity(&crypto, 7); - let responder = make_identity(&crypto, 8); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let token = pairing_token(10); let initiator_params = handshake_transport_params(28_672); @@ -804,7 +640,7 @@ fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { #[test] fn encrypted_session_record_round_trip_uses_connection_id_header() { - let crypto = TestCrypto::new(40); + let crypto = SoftwareCrypto; let header = SessionHeader { connection_id: ConnectionId::from_data([0x44; ConnectionId::SIZE]), seq: record_seq(11), @@ -915,9 +751,8 @@ fn protocol_record_size_breakdown() { println!("{label:<32}: {size} bytes"); } - let crypto = TestCrypto::new(50); - let initiator = make_identity(&crypto, 1); - let responder = make_identity(&crypto, 2); + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); let mut ik_initiator = IkHandshake::new_initiator( &crypto, From e33e609c90e2183e16664efca9781437504229de Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 19:01:54 -0400 Subject: [PATCH 175/304] ql-fsm: better session close --- ql-fsm/src/fsm.rs | 49 +++++------- ql-fsm/src/lib.rs | 12 ++- ql-fsm/src/session/mod.rs | 143 ++++++++++++++++++---------------- ql-fsm/src/session/state.rs | 17 ++-- ql-fsm/src/session/tracked.rs | 3 +- ql-fsm/src/tests/session.rs | 42 +++++++--- ql-runtime/src/driver/mod.rs | 4 +- 7 files changed, 153 insertions(+), 117 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index f43ccb0a..ee111f52 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -55,15 +55,14 @@ pub fn receive( let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); let frames = wire::parse_session_frames(plaintext); - let mut session_closed = false; state.session.receive(fsm.state.now.instant, seq, frames, { let pending_events = &mut fsm.pending_events; |event| { - session_closed |= forward_session_event(event, pending_events); + forward_session_event(event, pending_events); } }); - if session_closed { + if state.session.is_closed() { apply_session_closed(fsm); } Ok(()) @@ -74,17 +73,16 @@ pub fn receive( pub fn on_timer(fsm: &mut QlFsm) { handshake::handle_timer(fsm); - let mut session_closed = false; - if let Some(state) = fsm.state.link.connected_mut() { - let pending_events = &mut fsm.pending_events; - state.session.on_timer(fsm.state.now.instant, |event| { - session_closed |= forward_session_event(event, pending_events); - }); - } else { + let Some(state) = fsm.state.link.connected_mut() else { return; - } + }; - if session_closed { + let pending_events = &mut fsm.pending_events; + state.session.on_timer(fsm.state.now.instant, |event| { + forward_session_event(event, pending_events); + }); + + if state.session.is_closed() { apply_session_closed(fsm); } } @@ -118,6 +116,9 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Result, NoSessionError> { @@ -174,7 +174,7 @@ pub fn emit_peer_status(fsm: &mut QlFsm) { fn forward_session_event( event: SessionEvent, pending_events: &mut std::collections::VecDeque, -) -> bool { +) { match event { SessionEvent::Opened { stream_id, @@ -184,31 +184,24 @@ fn forward_session_event( stream_id, route_id, }); - false } SessionEvent::Readable(stream_id) => { pending_events.push_back(QlFsmEvent::Readable(stream_id)); - false } SessionEvent::Writable(stream_id) => { pending_events.push_back(QlFsmEvent::Writable(stream_id)); - false } SessionEvent::Finished(stream_id) => { pending_events.push_back(QlFsmEvent::Finished(stream_id)); - false } SessionEvent::Closed(frame) => { pending_events.push_back(QlFsmEvent::Closed(frame)); - false } SessionEvent::WritableClosed(frame) => { pending_events.push_back(QlFsmEvent::WritableClosed(frame)); - false } SessionEvent::SessionClosed(close) => { pending_events.push_back(QlFsmEvent::SessionClosed(close)); - true } } } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index ef0e4369..ac14bbc3 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -88,6 +88,10 @@ pub enum QlFsmEvent { /// local writes on this stream are closed WritableClosed(StreamClose), /// the encrypted session was closed + /// + /// session close is abortive and best-effort. the session ends immediately + /// one final write remains: a record containing only `SessionFrame::Close` + /// the FSM does not wait for an ack for that record SessionClosed(SessionClose), } @@ -267,8 +271,12 @@ impl QlFsm { } /// closes the current encrypted session locally - pub fn kill_session(&mut self, code: SessionCloseCode) { - fsm::kill_session(self, code); + /// + /// This transition is abortive and best-effort. It ends normal session use immediately and + /// may emit one final outbound close record, but it does not wait for the peer to acknowledge + /// that close. + pub fn close_session(&mut self, code: SessionCloseCode) { + fsm::close_session(self, code); } /// opens a new outgoing stream diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 26df092a..162fe37c 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -103,7 +103,7 @@ impl SessionFsm { tracked_records: Default::default(), received_records: ReceivedRecords::default(), ack_state: AckState::Idle, - pending_control: Default::default(), + pending_ping: false, streams: Default::default(), next_stream_index: 0, remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), @@ -142,10 +142,27 @@ impl SessionFsm { pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { self.ensure_session_open()?; - self.state.pending_control.ping = true; + self.state.pending_ping = true; Ok(()) } + pub(crate) fn close(&mut self, code: SessionCloseCode, mut emit: impl FnMut(SessionEvent)) { + if self.state.session_state != SessionState::Open { + return; + } + + let close = SessionClose { code }; + self.state.session_state = SessionState::Closing(close.clone()); + self.state.tracked_records.clear(); + self.state.ack_state = AckState::Idle; + self.clear_streams(); + emit(SessionEvent::SessionClosed(close)); + } + + pub(crate) fn is_closed(&self) -> bool { + self.state.session_state == SessionState::Closed + } + pub(crate) fn receive( &mut self, now: Instant, @@ -156,14 +173,15 @@ impl SessionFsm { I: IntoIterator, WireError>>, { self.state.now = now; - self.collect_timeouts(); self.state.last_activity_at = self.state.now; self.state.last_inbound_at = self.state.now; - if self.state.session_state == SessionState::Closed { + if self.state.session_state != SessionState::Open { return; } + self.collect_timeouts(); + let mut received_records = self.state.received_records.clone(); let out_of_order = match received_records.insert(seq) { ReceiveOutcome::TooOld => return, @@ -179,7 +197,7 @@ impl SessionFsm { for frame in frames { let Ok(frame) = frame else { - self.fail_session(SessionCloseCode::PROTOCOL, &mut emit); + self.close(SessionCloseCode::PROTOCOL, &mut emit); return; }; ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); @@ -188,19 +206,19 @@ impl SessionFsm { SessionFrame::Ack(ack) => self.process_record_ack(&ack, &mut emit), SessionFrame::StreamData(frame) => { if self.handle_stream_data(frame, &mut emit).is_err() { - self.fail_session(SessionCloseCode::PROTOCOL, &mut emit); + self.close(SessionCloseCode::PROTOCOL, &mut emit); return; } } SessionFrame::StreamWindow(frame) => self.handle_stream_window(&frame, &mut emit), SessionFrame::StreamClose(frame) => { if self.handle_stream_close(&frame, &mut emit).is_err() { - self.fail_session(SessionCloseCode::PROTOCOL, &mut emit); + self.close(SessionCloseCode::PROTOCOL, &mut emit); return; } } SessionFrame::Close(close) => { - self.handle_session_close(close, &mut emit); + self.close(close.code, &mut emit); handled_close = true; break; } @@ -221,6 +239,9 @@ impl SessionFsm { pub fn confirm_write(&mut self, now: Instant, write_id: u64) { self.state.now = now; + if !self.state.session_state.is_open() { + return; + } let Some(record) = self.state.tracked_records.get_mut(&write_id) else { return; }; @@ -232,6 +253,9 @@ impl SessionFsm { } pub fn reject_write(&mut self, write_id: u64) { + if !self.state.session_state.is_open() { + return; + } if self .state .tracked_records @@ -246,7 +270,7 @@ impl SessionFsm { restore_tracked_record( self.state.now, &mut self.state.ack_state, - &mut self.state.pending_control, + &mut self.state.pending_ping, &mut self.state.streams, record, ); @@ -254,22 +278,28 @@ impl SessionFsm { pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { self.state.now = now; + if !self.state.session_state.is_open() { + return; + } self.collect_timeouts(); if !self.config.peer_timeout.is_zero() && self.state.last_inbound_at + self.config.peer_timeout <= self.state.now { - self.fail_session(SessionCloseCode::TIMEOUT, &mut emit); + self.close(SessionCloseCode::TIMEOUT, &mut emit); return; } if self.state.session_state == SessionState::Open && !self.config.keepalive_interval.is_zero() && self.state.last_activity_at + self.config.keepalive_interval <= self.state.now { - self.state.pending_control.ping = true; + self.state.pending_ping = true; } } pub fn next_deadline(&self) -> Option { + if !self.state.session_state.is_open() { + return None; + } let ack_deadline = match self.state.ack_state { AckState::Idle => None, AckState::Dirty { due_at } => Some(due_at), @@ -286,7 +316,7 @@ impl SessionFsm { .min(); let keepalive_deadline = (self.state.session_state == SessionState::Open && !self.config.keepalive_interval.is_zero() - && !self.state.pending_control.ping) + && !self.state.pending_ping) .then_some(self.state.last_activity_at + self.config.keepalive_interval); let peer_timeout_deadline = (self.state.session_state == SessionState::Open && !self.config.peer_timeout.is_zero()) @@ -304,6 +334,20 @@ impl SessionFsm { pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { self.state.now = now; + match &self.state.session_state { + SessionState::Closing(close) => { + let seq = self.state.next_record_seq; + next_seq(&mut self.state.next_record_seq); + let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); + assert!(builder.push_close(&close), "builder has capacity"); + self.state.session_state = SessionState::Closed; + return Some((None, builder)); + } + SessionState::Closed => { + return None; + } + SessionState::Open => {} + } self.collect_timeouts(); let (builder, outbound) = self.build_next_record()?; @@ -334,18 +378,10 @@ impl SessionFsm { sent_at: None, }; - if let Some(close) = self.state.pending_control.close.take() { - if builder.push_close(&close) { - outbound.frames.push(TrackedFrame::Close(close.clone())); - } else { - self.state.pending_control.close = Some(close); - } - } - self.push_next_pending_stream_close(&mut builder, &mut outbound); - if self.state.pending_control.ping && builder.push_ping() { - self.state.pending_control.ping = false; + if self.state.pending_ping && builder.push_ping() { + self.state.pending_ping = false; outbound.ping_included = true; } @@ -364,11 +400,7 @@ impl SessionFsm { return None; } - self.state.next_record_seq = seq - .into_inner() - .checked_add(1) - .and_then(|next| RecordSeq::from_u64(next).ok()) - .expect("record sequence overflow"); + next_seq(&mut self.state.next_record_seq); Some((builder, outbound)) } @@ -495,7 +527,7 @@ impl SessionFsm { } fn ensure_session_open(&self) -> Result<(), NoSessionError> { - if self.state.session_state == SessionState::Closed { + if self.state.session_state != SessionState::Open { Err(NoSessionError) } else { Ok(()) @@ -548,7 +580,7 @@ impl SessionFsm { restore_tracked_record( self.state.now, &mut self.state.ack_state, - &mut self.state.pending_control, + &mut self.state.pending_ping, &mut self.state.streams, record, ); @@ -702,18 +734,6 @@ impl SessionFsm { Ok(()) } - fn handle_session_close(&mut self, close: SessionClose, emit: &mut impl FnMut(SessionEvent)) { - if self.state.session_state == SessionState::Closed { - return; - } - - self.state.session_state = SessionState::Closed; - self.state.tracked_records.clear(); - self.clear_streams(); - self.state.pending_control = Default::default(); - emit(SessionEvent::SessionClosed(close)); - } - fn apply_local_close_to_stream(stream: &mut StreamState, target: CloseTarget) { if Self::target_affects_inbound(stream.role, target) { stream.inbound_state = InboundState::Discarding; @@ -739,7 +759,6 @@ impl SessionFsm { || record.frames.iter().any(|frame| match frame { TrackedFrame::StreamData(frame) => frame.stream_id == stream_id, TrackedFrame::StreamClose(frame) => frame.stream_id == stream_id, - TrackedFrame::Close(_) => false, }) }); if tracked_refs_stream { @@ -809,19 +828,6 @@ impl SessionFsm { } } - fn fail_session(&mut self, code: SessionCloseCode, emit: &mut impl FnMut(SessionEvent)) { - if self.state.session_state == SessionState::Closed { - return; - } - - self.state.session_state = SessionState::Closed; - self.state.tracked_records.clear(); - self.state.pending_control = Default::default(); - self.state.pending_control.close = Some(SessionClose { code }); - self.clear_streams(); - emit(SessionEvent::SessionClosed(SessionClose { code })); - } - fn clear_streams(&mut self) { self.state.next_stream_index = 0; self.state.streams.clear(); @@ -911,7 +917,7 @@ fn local_stream_was_opened( fn restore_tracked_record( now: Instant, ack_state: &mut AckState, - pending_control: &mut state::PendingSessionControl, + pending_ping: &mut bool, streams: &mut IndexMap, record: TrackedRecord, ) { @@ -919,7 +925,7 @@ fn restore_tracked_record( schedule_ack(ack_state, now); } if record.ping_included { - pending_control.ping = true; + *pending_ping = true; } for (stream_id, maximum_offset) in record.window_updates { if let Some(stream) = streams.get_mut(&stream_id) { @@ -929,19 +935,12 @@ fn restore_tracked_record( } } for frame in record.frames { - requeue_tracked_frame(pending_control, streams, frame); + requeue_tracked_frame(streams, frame); } } -fn requeue_tracked_frame( - pending_control: &mut state::PendingSessionControl, - streams: &mut IndexMap, - frame: TrackedFrame, -) { +fn requeue_tracked_frame(streams: &mut IndexMap, frame: TrackedFrame) { match frame { - TrackedFrame::Close(close) => { - pending_control.close = Some(close); - } TrackedFrame::StreamClose(close) => restore_stream_close(streams, close), TrackedFrame::StreamData(frame) => restore_stream_data(streams, frame), } @@ -976,7 +975,7 @@ fn acknowledge_tracked_frame( emit: &mut impl FnMut(SessionEvent), ) { match frame { - TrackedFrame::Close(_) | TrackedFrame::StreamClose(_) => {} + TrackedFrame::StreamClose(_) => {} TrackedFrame::StreamData(frame) => { let stream_id = frame.stream_id; if let Some(stream) = streams.get_mut(&stream_id) { @@ -993,3 +992,13 @@ fn acknowledge_tracked_frame( } } } + +#[inline] +#[track_caller] +fn next_seq(seq: &mut RecordSeq) { + *seq = seq + .into_inner() + .checked_add(1) + .and_then(|next| RecordSeq::from_u64(next).ok()) + .expect("record sequence overflow"); +} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 58063499..237750cb 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -19,18 +19,25 @@ pub struct SessionFsmState { pub tracked_records: IndexMap, pub received_records: ReceivedRecords, pub ack_state: AckState, - pub pending_control: PendingSessionControl, + pub pending_ping: bool, pub streams: IndexMap, pub next_stream_index: usize, pub remote_stream_history: RemoteStreamHistory, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionState { Open, + Closing(SessionClose), Closed, } +impl SessionState { + pub fn is_open(&self) -> bool { + self == &Self::Open + } +} + #[derive(Debug)] pub struct StreamState { pub role: StreamRole, @@ -128,12 +135,6 @@ pub enum InboundState { Discarding, } -#[derive(Debug, Clone, Default)] -pub struct PendingSessionControl { - pub ping: bool, - pub close: Option, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AckState { // ack state is not dirty diff --git a/ql-fsm/src/session/tracked.rs b/ql-fsm/src/session/tracked.rs index 1c7bd798..fa97a77b 100644 --- a/ql-fsm/src/session/tracked.rs +++ b/ql-fsm/src/session/tracked.rs @@ -2,7 +2,7 @@ use std::time::Instant; -use ql_wire::{RecordSeq, SessionClose, StreamClose, StreamId}; +use ql_wire::{RecordSeq, StreamClose, StreamId}; #[derive(Debug, Clone)] pub struct TrackedRecord { @@ -18,7 +18,6 @@ pub struct TrackedRecord { pub enum TrackedFrame { StreamData(TrackedStreamData), StreamClose(StreamClose), - Close(SessionClose), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 71ca159a..a210af4e 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -358,16 +358,34 @@ fn ack_frame_releases_stream_capacity_and_emits_writable() { } #[test] -fn kill_session_disconnects_locally() { +fn close_session_disconnects_locally() { let mut harness = Harness::connected(QlFsmConfig::default()); harness .a .fsm - .kill_session(ql_wire::SessionCloseCode::CANCELLED); + .close_session(ql_wire::SessionCloseCode::CANCELLED); + + assert!(matches!(harness.take_event_a(), Some(QlFsmEvent::SessionClosed(SessionClose { + code: ql_wire::SessionCloseCode::CANCELLED, + })))); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); + + let close = harness.next_outbound_a().unwrap(); + let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); + let (_header, record) = decrypt_record(&harness.b.crypto, &close, &session_key); + assert!(matches!(record.as_slice(), [ql_wire::SessionFrame::Close(_)])); assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); - assert!(harness.drain_events_a().is_empty()); + assert_eq!( + harness.take_event_a(), + Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) + ); } #[test] @@ -443,11 +461,17 @@ fn session_timeout_emits_close_before_disconnect() { assert_eq!( harness.drain_events_a(), - vec![ - QlFsmEvent::SessionClosed(SessionClose { - code: ql_wire::SessionCloseCode::TIMEOUT, - }), - QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected), - ] + vec![QlFsmEvent::SessionClosed(SessionClose { + code: ql_wire::SessionCloseCode::TIMEOUT, + })] + ); + + let close = harness.next_outbound_a().unwrap(); + let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); + let (_header, record) = decrypt_record(&harness.b.crypto, &close, &session_key); + assert!(matches!(record.as_slice(), [ql_wire::SessionFrame::Close(_)])); + assert_eq!( + harness.take_event_a(), + Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) ); } diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 2c13ff38..928e9815 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -63,7 +63,9 @@ impl Runtime

{ } DriverEvent::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); - DriverState::drive_write_completed(&mut fsm, write.session_write_id, success); + state.with_fsm_events(&mut fsm, &platform, |fsm| { + DriverState::drive_write_completed(fsm, write.session_write_id, success) + }) } DriverEvent::TimerExpired => { state.with_fsm_events(&mut fsm, &platform, |fsm| { From 02b91810f5a48ab2cf549e2ffa05afeb05801960 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 19:20:06 -0400 Subject: [PATCH 176/304] ql-fsm: test compaction --- ql-fsm/src/tests/handshake.rs | 140 ++--- ql-fsm/src/tests/mod.rs | 153 +++--- ql-fsm/src/tests/proptest.rs | 943 ++++++++++++++-------------------- ql-fsm/src/tests/session.rs | 153 +++--- 4 files changed, 598 insertions(+), 791 deletions(-) diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 19be4761..1112012d 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -9,7 +9,7 @@ use crate::{state::LinkState, NoPeerError, PeerStatus, QlFsmEvent}; fn ik_connect_round_trip_establishes_transport() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness.connect_ik_a().unwrap(); + harness.connect_ik(Side::A).unwrap(); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -20,7 +20,7 @@ fn ik_connect_round_trip_establishes_transport() { fn kk_connect_round_trip_establishes_transport() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness.connect_kk_a().unwrap(); + harness.connect_kk(Side::A).unwrap(); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -33,17 +33,17 @@ fn xx_connect_round_trip_establishes_transport_when_armed() { let token = pairing_token(1); harness.b.fsm.arm_pairing(token); - harness.connect_xx_a(token); + harness.connect_xx(Side::A, token); - let xx1 = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(xx1); - let xx2 = harness.next_outbound_b().unwrap(); - harness.deliver_to_a(xx2); - let xx3 = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(xx3); + let xx1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx1); + let xx2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx2); + let xx3 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx3); - let xx4 = harness.next_outbound_b().unwrap(); - harness.deliver_to_a(xx4); + let xx4 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx4); assert_eq!(harness.a.fsm.peer(), Some(&harness.b.fsm.identity.bundle())); assert_eq!(harness.b.fsm.peer(), Some(&harness.a.fsm.identity.bundle())); @@ -64,7 +64,7 @@ fn ik_connect_learns_remote_initial_stream_receive_window() { }, ); - harness.connect_ik_a().unwrap(); + harness.connect_ik(Side::A).unwrap(); harness.pump(); assert_eq!( @@ -110,10 +110,10 @@ fn connect_methods_require_bound_peer() { fn connect_ik_emits_initiator_status() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness.connect_ik_a().unwrap(); + harness.connect_ik(Side::A).unwrap(); assert_eq!( - harness.drain_events_a(), + harness.drain_events(Side::A), vec![QlFsmEvent::PeerStatusChanged(PeerStatus::Initiator)] ); } @@ -123,13 +123,13 @@ fn inbound_xx1_ignored_when_pairing_token_not_armed() { let mut harness = Harness::paired(QlFsmConfig::default(), false, false); let token = pairing_token(3); - harness.connect_xx_a(token); - let xx1 = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(xx1); + harness.connect_xx(Side::A, token); + let xx1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx1); assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); - assert!(harness.drain_events_b().is_empty()); - assert!(harness.next_outbound_b().is_none()); + assert!(harness.drain_events(Side::B).is_empty()); + assert!(harness.next_outbound(Side::B).is_none()); } #[test] @@ -138,17 +138,17 @@ fn disarm_pairing_rejects_inflight_inbound_xx_responder() { let token = pairing_token(5); harness.b.fsm.arm_pairing(token); - harness.connect_xx_a(token); - let xx1 = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(xx1); - let xx2 = harness.next_outbound_b().unwrap(); - harness.deliver_to_a(xx2); - let xx3 = harness.next_outbound_a().unwrap(); + harness.connect_xx(Side::A, token); + let xx1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx1); + let xx2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx2); + let xx3 = harness.next_outbound(Side::A).unwrap(); harness.b.fsm.disarm_pairing(); - harness.deliver_to_b(xx3); + harness.deliver(Side::B, xx3); assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); - assert!(harness.next_outbound_b().is_none()); + assert!(harness.next_outbound(Side::B).is_none()); } #[test] @@ -158,15 +158,15 @@ fn simultaneous_xx_connect_converges() { harness.a.fsm.arm_pairing(token); harness.b.fsm.arm_pairing(token); - harness.connect_xx_a(token); - harness.connect_xx_b(token); + harness.connect_xx(Side::A, token); + harness.connect_xx(Side::B, token); for _ in 0..2 { - if let Some(record) = harness.next_outbound_a() { - harness.deliver_to_b(record); + if let Some(record) = harness.next_outbound(Side::A) { + harness.deliver(Side::B, record); } - if let Some(record) = harness.next_outbound_b() { - harness.deliver_to_a(record); + if let Some(record) = harness.next_outbound(Side::B) { + harness.deliver(Side::A, record); } } harness.pump(); @@ -179,28 +179,28 @@ fn simultaneous_xx_connect_converges() { fn connect_ik_replaces_in_flight_attempt_and_ignores_stale_reply() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness.connect_ik_a().unwrap(); - harness.drain_events_a(); - let first = harness.next_outbound_a().unwrap(); + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); + let first = harness.next_outbound(Side::A).unwrap(); let first_id = handshake_id(&first); - harness.connect_ik_a().unwrap(); - let second = harness.next_outbound_a().unwrap(); + harness.connect_ik(Side::A).unwrap(); + let second = harness.next_outbound(Side::A).unwrap(); let second_id = handshake_id(&second); assert_ne!(first_id, second_id); - harness.deliver_to_b(first); - let stale_reply = harness.next_outbound_b().unwrap(); + harness.deliver(Side::B, first); + let stale_reply = harness.next_outbound(Side::B).unwrap(); assert_eq!(handshake_id(&stale_reply), first_id); - harness.deliver_to_a(stale_reply); + harness.deliver(Side::A, stale_reply); assert!(matches!( harness.a.fsm.state.link, LinkState::IkInitiator(_) )); - harness.deliver_to_b(second); + harness.deliver(Side::B, second); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -211,27 +211,27 @@ fn connect_ik_replaces_in_flight_attempt_and_ignores_stale_reply() { fn connect_kk_replaces_in_flight_attempt_and_ignores_stale_reply() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness.connect_kk_a().unwrap(); - let first = harness.next_outbound_a().unwrap(); + harness.connect_kk(Side::A).unwrap(); + let first = harness.next_outbound(Side::A).unwrap(); let first_id = handshake_id(&first); - harness.connect_kk_a().unwrap(); - let second = harness.next_outbound_a().unwrap(); + harness.connect_kk(Side::A).unwrap(); + let second = harness.next_outbound(Side::A).unwrap(); let second_id = handshake_id(&second); assert_ne!(first_id, second_id); - harness.deliver_to_b(first); - let stale_reply = harness.next_outbound_b().unwrap(); + harness.deliver(Side::B, first); + let stale_reply = harness.next_outbound(Side::B).unwrap(); assert_eq!(handshake_id(&stale_reply), first_id); - harness.deliver_to_a(stale_reply); + harness.deliver(Side::A, stale_reply); assert!(matches!( harness.a.fsm.state.link, LinkState::KkInitiator(_) )); - harness.deliver_to_b(second); + harness.deliver(Side::B, second); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -242,13 +242,13 @@ fn connect_kk_replaces_in_flight_attempt_and_ignores_stale_reply() { fn inbound_ik1_auto_binds_unbound_responder() { let mut harness = Harness::paired(QlFsmConfig::default(), true, false); - harness.connect_ik_a().unwrap(); + harness.connect_ik(Side::A).unwrap(); harness.pump(); let expected_peer = harness.a.fsm.identity.bundle(); assert_eq!(harness.b.fsm.peer(), Some(&expected_peer)); assert_eq!( - harness.drain_events_b(), + harness.drain_events(Side::B), vec![ QlFsmEvent::NewPeer, QlFsmEvent::PeerStatusChanged(PeerStatus::Connected), @@ -266,22 +266,22 @@ fn handshake_timeout_drops_single_ik_attempt_without_resend() { }; let mut harness = Harness::paired_known(config); - harness.connect_ik_a().unwrap(); - harness.drain_events_a(); - let first = harness.next_outbound_a().unwrap(); + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); + let first = harness.next_outbound(Side::A).unwrap(); let (_, first) = ql_wire::decode_record::(first.as_slice()).unwrap(); assert!(matches!(first, ql_wire::QlHandshakeRecord::Ik1(_))); - assert!(harness.next_outbound_a().is_none()); + assert!(harness.next_outbound(Side::A).is_none()); harness.advance(config.handshake_timeout); - harness.on_timer_a(); + harness.on_timer(Side::A); assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert_eq!( - harness.take_event_a(), + harness.take_event(Side::A), Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) ); - assert!(harness.next_outbound_a().is_none()); + assert!(harness.next_outbound(Side::A).is_none()); } #[test] @@ -292,36 +292,36 @@ fn handshake_timeout_clears_queued_kk_output() { }; let mut harness = Harness::paired_known(config); - harness.connect_kk_a().unwrap(); + harness.connect_kk(Side::A).unwrap(); harness.advance(config.handshake_timeout); - harness.on_timer_a(); + harness.on_timer(Side::A); assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); - assert!(harness.next_outbound_a().is_none()); + assert!(harness.next_outbound(Side::A).is_none()); } #[test] fn bind_peer_clears_queued_handshake_output() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness.connect_ik_a().unwrap(); - harness.drain_events_a(); + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); harness .a .fsm .bind_peer(test_identity(&SoftwareCrypto).bundle()); - assert!(harness.drain_events_a().is_empty()); - assert!(harness.next_outbound_a().is_none()); + assert!(harness.drain_events(Side::A).is_empty()); + assert!(harness.next_outbound(Side::A).is_none()); } #[test] fn simultaneous_ik_connect_converges() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness.connect_ik_a().unwrap(); - harness.connect_ik_b().unwrap(); + harness.connect_ik(Side::A).unwrap(); + harness.connect_ik(Side::B).unwrap(); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -332,8 +332,8 @@ fn simultaneous_ik_connect_converges() { fn simultaneous_ik_and_kk_connect_prefers_ik() { let mut harness = Harness::paired_known(QlFsmConfig::default()); - harness.connect_ik_a().unwrap(); - harness.connect_kk_b().unwrap(); + harness.connect_ik(Side::A).unwrap(); + harness.connect_kk(Side::B).unwrap(); harness.pump(); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 174be582..f854a6c6 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -17,6 +17,21 @@ use crate::{ type TestCrypto = SoftwareCrypto; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Side { + A, + B, +} + +impl Side { + fn idx(self) -> usize { + match self { + Side::A => 0, + Side::B => 1, + } + } +} + struct Node { fsm: QlFsm, crypto: TestCrypto, @@ -29,6 +44,13 @@ struct Harness { b: Node, } +struct DecodedSessionWrite { + record: Vec, + write_id: Option, + header: ql_wire::SessionHeader, + frames: Vec>>, +} + impl Harness { fn paired_known(config: QlFsmConfig) -> Self { Self::paired_with_configs(config, config, true, true) @@ -132,111 +154,108 @@ impl Harness { self.unix_secs = self.unix_secs.saturating_add(duration.as_secs()); } - fn next_outbound_a(&mut self) -> Option> { - let write = self.a.fsm.take_next_write(self.time(), &self.a.crypto)?; - if let Some(id) = write.session_write_id { - self.a.fsm.confirm_session_write(self.time(), id); + fn node(&self, side: Side) -> &Node { + match side { + Side::A => &self.a, + Side::B => &self.b, } - Some(write.record) } - fn next_outbound_b(&mut self) -> Option> { - let write = self.b.fsm.take_next_write(self.time(), &self.b.crypto)?; - if let Some(id) = write.session_write_id { - self.b.fsm.confirm_session_write(self.time(), id); + fn node_mut(&mut self, side: Side) -> &mut Node { + match side { + Side::A => &mut self.a, + Side::B => &mut self.b, } - Some(write.record) } - fn next_write_a(&mut self) -> Option { - self.a.fsm.take_next_write(self.time(), &self.a.crypto) + fn next_outbound(&mut self, side: Side) -> Option> { + let write = self.next_write(side)?; + if let Some(id) = write.session_write_id { + self.confirm_write(side, id); + } + Some(write.record) } - fn connect_ik_a(&mut self) -> Result<(), NoPeerError> { + fn next_write(&mut self, side: Side) -> Option { let time = self.time(); - let Node { fsm, crypto } = &mut self.a; - fsm.connect_ik(time, crypto) + let Node { fsm, crypto } = self.node_mut(side); + fsm.take_next_write(time, crypto) } - fn connect_ik_b(&mut self) -> Result<(), NoPeerError> { - let time = self.time(); - let Node { fsm, crypto } = &mut self.b; - fsm.connect_ik(time, crypto) + fn next_decoded_outbound(&mut self, side: Side) -> Option { + let write = self.next_write(side)?; + if let Some(id) = write.session_write_id { + self.confirm_write(side, id); + } + Some(self.decode_session_write(write, side)) } - fn connect_kk_a(&mut self) -> Result<(), NoPeerError> { - let time = self.time(); - let Node { fsm, crypto } = &mut self.a; - fsm.connect_kk(time, crypto) + fn next_decoded_write(&mut self, side: Side) -> Option { + let write = self.next_write(side)?; + Some(self.decode_session_write(write, side)) } - fn connect_kk_b(&mut self) -> Result<(), NoPeerError> { + fn connect_ik(&mut self, side: Side) -> Result<(), NoPeerError> { let time = self.time(); - let Node { fsm, crypto } = &mut self.b; - fsm.connect_kk(time, crypto) + let Node { fsm, crypto } = self.node_mut(side); + fsm.connect_ik(time, crypto) } - fn connect_xx_a(&mut self, token: PairingToken) { + fn connect_kk(&mut self, side: Side) -> Result<(), NoPeerError> { let time = self.time(); - let Node { fsm, crypto } = &mut self.a; - fsm.connect_xx(time, token, crypto); + let Node { fsm, crypto } = self.node_mut(side); + fsm.connect_kk(time, crypto) } - fn connect_xx_b(&mut self, token: PairingToken) { + fn connect_xx(&mut self, side: Side, token: PairingToken) { let time = self.time(); - let Node { fsm, crypto } = &mut self.b; + let Node { fsm, crypto } = self.node_mut(side); fsm.connect_xx(time, token, crypto); } - fn deliver_to_a(&mut self, record: Vec) { + fn deliver(&mut self, side: Side, record: Vec) { let time = self.time(); - let Node { fsm, crypto } = &mut self.a; + let Node { fsm, crypto } = self.node_mut(side); fsm.receive(time, record, crypto).unwrap(); } - fn deliver_to_b(&mut self, record: Vec) { + fn confirm_write(&mut self, side: Side, write_id: SessionWriteId) { let time = self.time(); - let Node { fsm, crypto } = &mut self.b; - fsm.receive(time, record, crypto).unwrap(); + self.node_mut(side).fsm.confirm_session_write(time, write_id); } - fn confirm_write_a(&mut self, write_id: SessionWriteId) { - self.a.fsm.confirm_session_write(self.time(), write_id); + fn reject_write(&mut self, side: Side, write_id: SessionWriteId) { + self.node_mut(side).fsm.reject_session_write(write_id); } - fn return_write_a(&mut self, write_id: SessionWriteId) { - self.a.fsm.reject_session_write(write_id); - } - - fn on_timer_a(&mut self) { - let time = self.time(); - self.a.fsm.on_timer(time); + fn decode_session_write(&self, write: OutboundWrite, side: Side) -> DecodedSessionWrite { + let peer = self.node(match side { + Side::A => Side::B, + Side::B => Side::A, + }); + let crypto = &peer.crypto; + let session_key = &peer.fsm.state.link.transport().unwrap().rx_key; + let (header, frames) = decrypt_record(crypto, &write.record, session_key); + DecodedSessionWrite { + record: write.record, + write_id: write.session_write_id, + header, + frames, + } } - fn on_timer_b(&mut self) { + fn on_timer(&mut self, side: Side) { let time = self.time(); - self.b.fsm.on_timer(time); - } - - fn take_event_a(&mut self) -> Option { - self.a.fsm.poll_event() - } - - fn take_event_b(&mut self) -> Option { - self.b.fsm.poll_event() + self.node_mut(side).fsm.on_timer(time); } - fn drain_events_a(&mut self) -> Vec { - let mut events = Vec::new(); - while let Some(event) = self.a.fsm.poll_event() { - events.push(event); - } - events + fn take_event(&mut self, side: Side) -> Option { + self.node_mut(side).fsm.poll_event() } - fn drain_events_b(&mut self) -> Vec { + fn drain_events(&mut self, side: Side) -> Vec { let mut events = Vec::new(); - while let Some(event) = self.b.fsm.poll_event() { + while let Some(event) = self.take_event(side) { events.push(event); } events @@ -246,14 +265,14 @@ impl Harness { for _ in 0..128 { let mut progressed = false; - while let Some(record) = self.next_outbound_a() { + while let Some(record) = self.next_outbound(Side::A) { progressed = true; - self.deliver_to_b(record); + self.deliver(Side::B, record); } - while let Some(record) = self.next_outbound_b() { + while let Some(record) = self.next_outbound(Side::B) { progressed = true; - self.deliver_to_a(record); + self.deliver(Side::A, record); } if !progressed { diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 047bf3f2..04ce4bf4 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -18,49 +18,93 @@ use crate::{state::LinkState, PeerStatus, QlFsmEvent, ReceiveError, SessionWrite const SLOT_COUNT: usize = 4; -#[derive(Clone, Copy, Debug)] -enum Side { - A, - B, -} - #[derive(Clone, Debug)] enum Action { - ConnectIkA, - ConnectIkB, - ConnectKkA, - ConnectKkB, + ConnectIk(Side), + ConnectKk(Side), AdvanceMs(u8), - OnTimerA, - OnTimerB, + OnTimer(Side), OnTimerBoth, Pump, - TakeNextAToB, - TakeNextBToA, - ConfirmTakenAToB(usize), - ConfirmTakenBToA(usize), - RejectTakenAToB(usize), - RejectTakenBToA(usize), - CaptureNextAToB, - CaptureNextBToA, - DeliverNextAToB, - DeliverNextBToA, - DropNextAToB, - DropNextBToA, - DeliverQueuedAToB(usize), - DeliverQueuedBToA(usize), - DuplicateQueuedAToB(usize), - DuplicateQueuedBToA(usize), - DropQueuedAToB(usize), - DropQueuedBToA(usize), - OpenStreamA(usize), - OpenStreamB(usize), - WriteA { slot: usize, bytes: Vec }, - WriteB { slot: usize, bytes: Vec }, - FinishA(usize), - FinishB(usize), - CloseA(usize), - CloseB(usize), + TakeNext(Side), + ConfirmTaken { + side: Side, + index: usize, + }, + RejectTaken { + side: Side, + index: usize, + }, + CaptureNext(Side), + DeliverNext(Side), + DropNext(Side), + DeliverQueued { + side: Side, + index: usize, + }, + DuplicateQueued { + side: Side, + index: usize, + }, + DropQueued { + side: Side, + index: usize, + }, + OpenStream { + side: Side, + slot: usize, + }, + Write { + side: Side, + slot: usize, + bytes: Vec, + }, + Finish { + side: Side, + slot: usize, + }, + Close { + side: Side, + slot: usize, + }, +} + +impl Action { + fn confirm_taken(side: Side, index: usize) -> Self { + Self::ConfirmTaken { side, index } + } + + fn reject_taken(side: Side, index: usize) -> Self { + Self::RejectTaken { side, index } + } + + fn deliver_queued(side: Side, index: usize) -> Self { + Self::DeliverQueued { side, index } + } + + fn duplicate_queued(side: Side, index: usize) -> Self { + Self::DuplicateQueued { side, index } + } + + fn drop_queued(side: Side, index: usize) -> Self { + Self::DropQueued { side, index } + } + + fn open_stream(side: Side, slot: usize) -> Self { + Self::OpenStream { side, slot } + } + + fn write(side: Side, slot: usize, bytes: Vec) -> Self { + Self::Write { side, slot, bytes } + } + + fn finish(side: Side, slot: usize) -> Self { + Self::Finish { side, slot } + } + + fn close(side: Side, slot: usize) -> Self { + Self::Close { side, slot } + } } #[derive(Clone, Debug)] @@ -93,24 +137,16 @@ impl SideEventState { struct Runner { harness: Harness, - slots_a: [Option; SLOT_COUNT], - slots_b: [Option; SLOT_COUNT], - taken_a_to_b: Vec, - taken_b_to_a: Vec, - pending_a_to_b: Vec>, - pending_b_to_a: Vec>, + slots: [[Option; SLOT_COUNT]; 2], + taken: [Vec; 2], + pending: [Vec>; 2], receive_errors: Vec<(Side, ReceiveError)>, - events_a: SideEventState, - events_b: SideEventState, + events: [SideEventState; 2], known_streams: BTreeSet, - expected_at_a: BTreeMap>, - expected_at_b: BTreeMap>, - received_at_a: BTreeMap>, - received_at_b: BTreeMap>, - finished_by_a: BTreeSet, - finished_by_b: BTreeSet, - closed_by_a: BTreeSet, - closed_by_b: BTreeSet, + expected: [BTreeMap>; 2], + received: [BTreeMap>; 2], + finished_by: [BTreeSet; 2], + closed_by: [BTreeSet; 2], } impl Runner { @@ -125,24 +161,16 @@ impl Runner { Self { harness: Harness::paired_known(config), - slots_a: [None; SLOT_COUNT], - slots_b: [None; SLOT_COUNT], - taken_a_to_b: Vec::new(), - taken_b_to_a: Vec::new(), - pending_a_to_b: Vec::new(), - pending_b_to_a: Vec::new(), + slots: [[None; SLOT_COUNT]; 2], + taken: [Vec::new(), Vec::new()], + pending: [Vec::new(), Vec::new()], receive_errors: Vec::new(), - events_a: SideEventState::default(), - events_b: SideEventState::default(), + events: [SideEventState::default(), SideEventState::default()], known_streams: BTreeSet::new(), - expected_at_a: BTreeMap::new(), - expected_at_b: BTreeMap::new(), - received_at_a: BTreeMap::new(), - received_at_b: BTreeMap::new(), - finished_by_a: BTreeSet::new(), - finished_by_b: BTreeSet::new(), - closed_by_a: BTreeSet::new(), - closed_by_b: BTreeSet::new(), + expected: [BTreeMap::new(), BTreeMap::new()], + received: [BTreeMap::new(), BTreeMap::new()], + finished_by: [BTreeSet::new(), BTreeSet::new()], + closed_by: [BTreeSet::new(), BTreeSet::new()], } } @@ -153,35 +181,24 @@ impl Runner { session_peer_timeout: Duration::from_secs(5), ..QlFsmConfig::default() }; + let connected_events = || SideEventState { + last_peer_status: Some(PeerStatus::Connected), + session_epoch: 1, + ..SideEventState::default() + }; Self { harness: Harness::connected(config), - slots_a: [None; SLOT_COUNT], - slots_b: [None; SLOT_COUNT], - taken_a_to_b: Vec::new(), - taken_b_to_a: Vec::new(), - pending_a_to_b: Vec::new(), - pending_b_to_a: Vec::new(), + slots: [[None; SLOT_COUNT]; 2], + taken: [Vec::new(), Vec::new()], + pending: [Vec::new(), Vec::new()], receive_errors: Vec::new(), - events_a: SideEventState { - last_peer_status: Some(PeerStatus::Connected), - session_epoch: 1, - ..SideEventState::default() - }, - events_b: SideEventState { - last_peer_status: Some(PeerStatus::Connected), - session_epoch: 1, - ..SideEventState::default() - }, + events: [connected_events(), connected_events()], known_streams: BTreeSet::new(), - expected_at_a: BTreeMap::new(), - expected_at_b: BTreeMap::new(), - received_at_a: BTreeMap::new(), - received_at_b: BTreeMap::new(), - finished_by_a: BTreeSet::new(), - finished_by_b: BTreeSet::new(), - closed_by_a: BTreeSet::new(), - closed_by_b: BTreeSet::new(), + expected: [BTreeMap::new(), BTreeMap::new()], + received: [BTreeMap::new(), BTreeMap::new()], + finished_by: [BTreeSet::new(), BTreeSet::new()], + closed_by: [BTreeSet::new(), BTreeSet::new()], } } @@ -200,190 +217,131 @@ impl Runner { #[allow(clippy::cognitive_complexity, clippy::too_many_lines)] fn apply(&mut self, action: &Action) { match action { - Action::ConnectIkA => { - let _ = self.harness.connect_ik_a(); - } - Action::ConnectIkB => { - let _ = self.harness.connect_ik_b(); + Action::ConnectIk(side) => { + let _ = self.harness.connect_ik(*side); } - Action::ConnectKkA => { - let _ = self.harness.connect_kk_a(); - } - Action::ConnectKkB => { - let _ = self.harness.connect_kk_b(); + Action::ConnectKk(side) => { + let _ = self.harness.connect_kk(*side); } Action::AdvanceMs(ms) => { self.harness .advance(Duration::from_millis(u64::from(*ms) + 1)); } - Action::OnTimerA => self.harness.on_timer_a(), - Action::OnTimerB => self.harness.on_timer_b(), + Action::OnTimer(side) => self.harness.on_timer(*side), Action::OnTimerBoth => { - self.harness.on_timer_a(); - self.harness.on_timer_b(); + self.harness.on_timer(Side::A); + self.harness.on_timer(Side::B); } Action::Pump => self.capture_all_outbound(), - Action::TakeNextAToB => { - if let Some(write) = take_unconfirmed_outbound_a(&mut self.harness) { - self.taken_a_to_b.push(write); - } - } - Action::TakeNextBToA => { - if let Some(write) = take_unconfirmed_outbound_b(&mut self.harness) { - self.taken_b_to_a.push(write); - } - } - Action::ConfirmTakenAToB(index) => { - if let Some(write) = take_taken(&mut self.taken_a_to_b, *index) { - confirm_taken_a(&mut self.harness, &write); - self.pending_a_to_b.push(write.record); - } - } - Action::ConfirmTakenBToA(index) => { - if let Some(write) = take_taken(&mut self.taken_b_to_a, *index) { - confirm_taken_b(&mut self.harness, &write); - self.pending_b_to_a.push(write.record); + Action::TakeNext(side) => { + if let Some(write) = take_unconfirmed_outbound(&mut self.harness, *side) { + self.taken[side.idx()].push(write); } } - Action::RejectTakenAToB(index) => { - if let Some(write) = take_taken(&mut self.taken_a_to_b, *index) { - reject_taken_a(&mut self.harness, &write); + Action::ConfirmTaken { side, index } => { + if let Some(write) = take_taken(&mut self.taken[side.idx()], *index) { + confirm_taken(&mut self.harness, *side, &write); + self.pending[side.idx()].push(write.record); } } - Action::RejectTakenBToA(index) => { - if let Some(write) = take_taken(&mut self.taken_b_to_a, *index) { - reject_taken_b(&mut self.harness, &write); + Action::RejectTaken { side, index } => { + if let Some(write) = take_taken(&mut self.taken[side.idx()], *index) { + reject_taken(&mut self.harness, *side, &write); } } - Action::CaptureNextAToB => { - if let Some(record) = take_confirmed_outbound_a(&mut self.harness) { - self.pending_a_to_b.push(record); - } - } - Action::CaptureNextBToA => { - if let Some(record) = take_confirmed_outbound_b(&mut self.harness) { - self.pending_b_to_a.push(record); - } - } - Action::DeliverNextAToB => { - if let Some(record) = take_confirmed_outbound_a(&mut self.harness) { - self.deliver_to_b(record); - } - } - Action::DeliverNextBToA => { - if let Some(record) = take_confirmed_outbound_b(&mut self.harness) { - self.deliver_to_a(record); - } - } - Action::DropNextAToB => { - let _ = take_confirmed_outbound_a(&mut self.harness); - } - Action::DropNextBToA => { - let _ = take_confirmed_outbound_b(&mut self.harness); - } - Action::DeliverQueuedAToB(index) => { - if let Some(record) = take_pending(&mut self.pending_a_to_b, *index) { - self.deliver_to_b(record); + Action::CaptureNext(side) => { + if let Some(record) = take_confirmed_outbound(&mut self.harness, *side) { + self.pending[side.idx()].push(record); } } - Action::DeliverQueuedBToA(index) => { - if let Some(record) = take_pending(&mut self.pending_b_to_a, *index) { - self.deliver_to_a(record); + Action::DeliverNext(side) => { + if let Some(record) = take_confirmed_outbound(&mut self.harness, *side) { + self.deliver_to(opposite(*side), record); } } - Action::DuplicateQueuedAToB(index) => { - if let Some(record) = peek_pending(&self.pending_a_to_b, *index) { - self.deliver_to_b(record); - } + Action::DropNext(side) => { + let _ = take_confirmed_outbound(&mut self.harness, *side); } - Action::DuplicateQueuedBToA(index) => { - if let Some(record) = peek_pending(&self.pending_b_to_a, *index) { - self.deliver_to_a(record); + Action::DeliverQueued { side, index } => { + if let Some(record) = take_pending(&mut self.pending[side.idx()], *index) { + self.deliver_to(opposite(*side), record); } } - Action::DropQueuedAToB(index) => { - let _ = take_pending(&mut self.pending_a_to_b, *index); - } - Action::DropQueuedBToA(index) => { - let _ = take_pending(&mut self.pending_b_to_a, *index); - } - Action::OpenStreamA(slot) => { - if let Ok(stream) = self.harness.a.fsm.open_stream(test_route_id()) { - let stream_id = stream.stream_id(); - self.slots_a[*slot] = Some(stream_id); - self.known_streams.insert(stream_id); + Action::DuplicateQueued { side, index } => { + if let Some(record) = peek_pending(&self.pending[side.idx()], *index) { + self.deliver_to(opposite(*side), record); } } - Action::OpenStreamB(slot) => { - if let Ok(stream) = self.harness.b.fsm.open_stream(test_route_id()) { - let stream_id = stream.stream_id(); - self.slots_b[*slot] = Some(stream_id); + Action::DropQueued { side, index } => { + let _ = take_pending(&mut self.pending[side.idx()], *index); + } + Action::OpenStream { side, slot } => { + let stream_id = self + .harness + .node_mut(*side) + .fsm + .open_stream(test_route_id()) + .ok() + .map(|stream| stream.stream_id()); + if let Some(stream_id) = stream_id { + self.slots[side.idx()][*slot] = Some(stream_id); self.known_streams.insert(stream_id); } } - Action::WriteA { slot, bytes } => { - if let Some(stream_id) = self.slots_a[*slot] { - let mut chunk = Bytes::copy_from_slice(bytes); - if let Ok(mut stream) = self.harness.a.fsm.stream(stream_id) { - if let Some(mut writer) = stream.writer() { - let accepted = writer.write(&mut chunk); - self.expected_at_b - .entry(stream_id) - .or_default() - .extend_from_slice(&bytes[..accepted]); - } - } - } - } - Action::WriteB { slot, bytes } => { - if let Some(stream_id) = self.slots_b[*slot] { + Action::Write { side, slot, bytes } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { let mut chunk = Bytes::copy_from_slice(bytes); - if let Ok(mut stream) = self.harness.b.fsm.stream(stream_id) { + let accepted = if let Ok(mut stream) = + self.harness.node_mut(*side).fsm.stream(stream_id) + { if let Some(mut writer) = stream.writer() { - let accepted = writer.write(&mut chunk); - self.expected_at_a - .entry(stream_id) - .or_default() - .extend_from_slice(&bytes[..accepted]); - } - } - } - } - Action::FinishA(slot) => { - if let Some(stream_id) = self.slots_a[*slot] { - if let Ok(mut stream) = self.harness.a.fsm.stream(stream_id) { - if let Some(writer) = stream.writer() { - writer.finish(); - self.finished_by_a.insert(stream_id); + writer.write(&mut chunk) + } else { + 0 } + } else { + 0 + }; + if accepted != 0 { + self.expected[opposite(*side).idx()] + .entry(stream_id) + .or_default() + .extend_from_slice(&bytes[..accepted]); } } } - Action::FinishB(slot) => { - if let Some(stream_id) = self.slots_b[*slot] { - if let Ok(mut stream) = self.harness.b.fsm.stream(stream_id) { + Action::Finish { side, slot } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { + let finished = if let Ok(mut stream) = + self.harness.node_mut(*side).fsm.stream(stream_id) + { if let Some(writer) = stream.writer() { writer.finish(); - self.finished_by_b.insert(stream_id); + true + } else { + false } + } else { + false + }; + if finished { + self.finished_by[side.idx()].insert(stream_id); } } } - Action::CloseA(slot) => { - if let Some(stream_id) = self.slots_a[*slot] { - if let Ok(mut stream) = self.harness.a.fsm.stream(stream_id) { + Action::Close { side, slot } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { + let closed = if let Ok(mut stream) = + self.harness.node_mut(*side).fsm.stream(stream_id) + { stream.close(CloseTarget::Both, StreamCloseCode(0)); - self.closed_by_a.insert(stream_id); - self.slots_a[*slot] = None; - } - } - } - Action::CloseB(slot) => { - if let Some(stream_id) = self.slots_b[*slot] { - if let Ok(mut stream) = self.harness.b.fsm.stream(stream_id) { - stream.close(CloseTarget::Both, StreamCloseCode(0)); - self.closed_by_b.insert(stream_id); - self.slots_b[*slot] = None; + true + } else { + false + }; + if closed { + self.closed_by[side.idx()].insert(stream_id); + self.slots[side.idx()][*slot] = None; } } } @@ -391,10 +349,10 @@ impl Runner { } fn observe_and_assert(&mut self) -> TestCaseResult { - self.drain_reads_a(); - self.drain_reads_b(); - let events_a = self.harness.drain_events_a(); - let events_b = self.harness.drain_events_b(); + self.drain_reads(Side::A); + self.drain_reads(Side::B); + let events_a = self.harness.drain_events(Side::A); + let events_b = self.harness.drain_events(Side::B); self.process_events(Side::A, events_a)?; self.process_events(Side::B, events_b)?; self.assert_prefix_invariants()?; @@ -421,8 +379,8 @@ impl Runner { self.flush_pending_in_order(); self.observe_and_assert()?; self.harness.advance(tick); - self.harness.on_timer_a(); - self.harness.on_timer_b(); + self.harness.on_timer(Side::A); + self.harness.on_timer(Side::B); self.capture_all_outbound(); self.flush_pending_in_order(); self.observe_and_assert()?; @@ -432,23 +390,11 @@ impl Runner { Ok(()) } - fn drain_reads_a(&mut self) { - for stream_id in self.known_streams.clone() { - let appended = drain_stream(&mut self.harness.a.fsm, stream_id); - if !appended.is_empty() { - self.received_at_a - .entry(stream_id) - .or_default() - .extend_from_slice(&appended); - } - } - } - - fn drain_reads_b(&mut self) { + fn drain_reads(&mut self, side: Side) { for stream_id in self.known_streams.clone() { - let appended = drain_stream(&mut self.harness.b.fsm, stream_id); + let appended = drain_stream(&mut self.harness.node_mut(side).fsm, stream_id); if !appended.is_empty() { - self.received_at_b + self.received[side.idx()] .entry(stream_id) .or_default() .extend_from_slice(&appended); @@ -461,7 +407,7 @@ impl Runner { match event { QlFsmEvent::NewPeer => {} QlFsmEvent::PeerStatusChanged(status) => { - self.events_mut(side).note_peer_status(status); + self.events[side.idx()].note_peer_status(status); } QlFsmEvent::Opened { stream_id, .. } => { prop_assert!( @@ -469,7 +415,7 @@ impl Runner { "side {side:?} emitted Opened for unknown stream {stream_id:?}" ); prop_assert!( - self.events_mut(side).opened.insert(stream_id), + self.events[side.idx()].opened.insert(stream_id), "side {side:?} emitted duplicate Opened for {stream_id:?}" ); } @@ -485,11 +431,11 @@ impl Runner { "side {side:?} emitted Finished for unknown stream {stream_id:?}" ); prop_assert!( - self.events_mut(side).finished.insert(stream_id), + self.events[side.idx()].finished.insert(stream_id), "side {side:?} emitted duplicate Finished for {stream_id:?}" ); prop_assert!( - !self.events(side).closed.contains(&stream_id), + !self.events[side.idx()].closed.contains(&stream_id), "side {side:?} emitted Finished after Closed for {stream_id:?}" ); } @@ -500,7 +446,7 @@ impl Runner { frame.stream_id ); prop_assert!( - self.events_mut(side).closed.insert(frame.stream_id), + self.events[side.idx()].closed.insert(frame.stream_id), "side {side:?} emitted duplicate Closed for {:?}", frame.stream_id ); @@ -512,12 +458,12 @@ impl Runner { "side {side:?} emitted WritableClosed for unknown stream {stream_id:?}" ); prop_assert!( - self.events_mut(side).writable_closed.insert(stream_id), + self.events[side.idx()].writable_closed.insert(stream_id), "side {side:?} emitted duplicate WritableClosed for {stream_id:?}" ); } QlFsmEvent::SessionClosed(_) => { - let state = self.events_mut(side); + let state = &mut self.events[side.idx()]; prop_assert!( state.session_epoch > 0, "side {side:?} emitted SessionClosed without a connected session" @@ -536,26 +482,16 @@ impl Runner { } fn assert_prefix_invariants(&self) -> TestCaseResult { - for (stream_id, received) in &self.received_at_a { - let expected = self - .expected_at_a - .get(stream_id) - .map_or(&[][..], Vec::as_slice); - prop_assert!( - expected.starts_with(received), - "side A observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" - ); - } - - for (stream_id, received) in &self.received_at_b { - let expected = self - .expected_at_b - .get(stream_id) - .map_or(&[][..], Vec::as_slice); - prop_assert!( - expected.starts_with(received), - "side B observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" - ); + for side in [Side::A, Side::B] { + for (stream_id, received) in &self.received[side.idx()] { + let expected = self.expected[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert!( + expected.starts_with(received), + "side {side:?} observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" + ); + } } Ok(()) @@ -596,79 +532,48 @@ impl Runner { } fn assert_terminal_semantics(&self) -> TestCaseResult { - for stream_id in &self.events_a.finished { - if self.inbound_aborted(Side::A, *stream_id) { - continue; - } - let expected = self - .expected_at_a - .get(stream_id) - .map_or(&[][..], Vec::as_slice); - let received = self - .received_at_a - .get(stream_id) - .map_or(&[][..], Vec::as_slice); - prop_assert_eq!( - received, - expected, - "side A finished {:?} without receiving all expected bytes", - stream_id - ); - } - - for stream_id in &self.events_b.finished { - if self.inbound_aborted(Side::B, *stream_id) { - continue; - } - let expected = self - .expected_at_b - .get(stream_id) - .map_or(&[][..], Vec::as_slice); - let received = self - .received_at_b - .get(stream_id) - .map_or(&[][..], Vec::as_slice); - prop_assert_eq!( - received, - expected, - "side B finished {:?} without receiving all expected bytes", - stream_id - ); - } - let a_connected = matches!(self.harness.a.fsm.state.link, LinkState::Connected(_)); let b_connected = matches!(self.harness.b.fsm.state.link, LinkState::Connected(_)); + let connected = [a_connected, b_connected]; - for stream_id in &self.finished_by_a { - prop_assert!( - self.events_b.finished.contains(stream_id) - || self.events_b.closed.contains(stream_id) - || !b_connected, - "side A finished {stream_id:?} but side B saw neither Finished nor Closed" - ); - } - - for stream_id in &self.finished_by_b { - prop_assert!( - self.events_a.finished.contains(stream_id) - || self.events_a.closed.contains(stream_id) - || !a_connected, - "side B finished {stream_id:?} but side A saw neither Finished nor Closed" - ); - } - - for stream_id in &self.closed_by_a { - prop_assert!( - self.events_b.closed.contains(stream_id) || !b_connected, - "side A closed {stream_id:?} but side B saw no Closed event" - ); - } - - for stream_id in &self.closed_by_b { - prop_assert!( - self.events_a.closed.contains(stream_id) || !a_connected, - "side B closed {stream_id:?} but side A saw no Closed event" - ); + for side in [Side::A, Side::B] { + for stream_id in &self.events[side.idx()].finished { + if self.inbound_aborted(side, *stream_id) { + continue; + } + let expected = self.expected[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + let received = self.received[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert_eq!( + received, + expected, + "side {:?} finished {:?} without receiving all expected bytes", + side, + stream_id + ); + } + + for stream_id in &self.finished_by[side.idx()] { + prop_assert!( + self.events[opposite(side).idx()].finished.contains(stream_id) + || self.events[opposite(side).idx()].closed.contains(stream_id) + || !connected[opposite(side).idx()], + "side {side:?} finished {stream_id:?} but side {:?} saw neither Finished nor Closed", + opposite(side) + ); + } + + for stream_id in &self.closed_by[side.idx()] { + prop_assert!( + self.events[opposite(side).idx()].closed.contains(stream_id) + || !connected[opposite(side).idx()], + "side {side:?} closed {stream_id:?} but side {:?} saw no Closed event", + opposite(side) + ); + } } Ok(()) @@ -677,14 +582,12 @@ impl Runner { fn assert_no_stream_events(&self) -> TestCaseResult { prop_assert!( self.known_streams.is_empty() - && self.events_a.opened.is_empty() - && self.events_b.opened.is_empty() - && self.events_a.finished.is_empty() - && self.events_b.finished.is_empty() - && self.events_a.closed.is_empty() - && self.events_b.closed.is_empty() - && self.events_a.writable_closed.is_empty() - && self.events_b.writable_closed.is_empty(), + && self.events.iter().all(|events| { + events.opened.is_empty() + && events.finished.is_empty() + && events.closed.is_empty() + && events.writable_closed.is_empty() + }), "handshake-only property observed stream activity" ); Ok(()) @@ -692,7 +595,7 @@ impl Runner { fn assert_no_taken_writes(&self) -> TestCaseResult { prop_assert!( - self.taken_a_to_b.is_empty() && self.taken_b_to_a.is_empty(), + self.taken.iter().all(Vec::is_empty), "cleanup left taken writes queued" ); Ok(()) @@ -703,7 +606,7 @@ impl Runner { for _ in 0..8 { self.capture_all_outbound(); - if self.pending_a_to_b.is_empty() && self.pending_b_to_a.is_empty() { + if self.pending.iter().all(Vec::is_empty) { break; } self.flush_pending_in_order(); @@ -712,156 +615,82 @@ impl Runner { self.capture_all_outbound(); prop_assert!( - self.pending_a_to_b.is_empty() - && self.pending_b_to_a.is_empty() - && self.taken_a_to_b.is_empty() - && self.taken_b_to_a.is_empty(), + self.pending.iter().all(Vec::is_empty) && self.taken.iter().all(Vec::is_empty), "cleanup did not quiesce: taken_a={} taken_b={} pending_a={} pending_b={}", - self.taken_a_to_b.len(), - self.taken_b_to_a.len(), - self.pending_a_to_b.len(), - self.pending_b_to_a.len() + self.taken[Side::A.idx()].len(), + self.taken[Side::B.idx()].len(), + self.pending[Side::A.idx()].len(), + self.pending[Side::B.idx()].len() ); Ok(()) } fn capture_all_outbound(&mut self) { - while let Some(record) = take_confirmed_outbound_a(&mut self.harness) { - self.pending_a_to_b.push(record); - } - - while let Some(record) = take_confirmed_outbound_b(&mut self.harness) { - self.pending_b_to_a.push(record); + for side in [Side::A, Side::B] { + while let Some(record) = take_confirmed_outbound(&mut self.harness, side) { + self.pending[side.idx()].push(record); + } } } fn flush_pending_in_order(&mut self) { - while let Some(record) = pop_front_pending(&mut self.pending_a_to_b) { - self.deliver_to_b(record); - } - - while let Some(record) = pop_front_pending(&mut self.pending_b_to_a) { - self.deliver_to_a(record); + for side in [Side::A, Side::B] { + while let Some(record) = pop_front_pending(&mut self.pending[side.idx()]) { + self.deliver_to(opposite(side), record); + } } } fn reject_all_taken(&mut self) { - while let Some(write) = self.taken_a_to_b.pop() { - reject_taken_a(&mut self.harness, &write); - } - - while let Some(write) = self.taken_b_to_a.pop() { - reject_taken_b(&mut self.harness, &write); - } - } - - fn deliver_to_a(&mut self, record: Vec) { - if let Err(error) = deliver_to_a(&mut self.harness, record) { - self.receive_errors.push((Side::A, error)); - } - } - - fn deliver_to_b(&mut self, record: Vec) { - if let Err(error) = deliver_to_b(&mut self.harness, record) { - self.receive_errors.push((Side::B, error)); - } - } - - fn events_mut(&mut self, side: Side) -> &mut SideEventState { - match side { - Side::A => &mut self.events_a, - Side::B => &mut self.events_b, + for side in [Side::A, Side::B] { + while let Some(write) = self.taken[side.idx()].pop() { + reject_taken(&mut self.harness, side, &write); + } } } - fn events(&self, side: Side) -> &SideEventState { - match side { - Side::A => &self.events_a, - Side::B => &self.events_b, + fn deliver_to(&mut self, side: Side, record: Vec) { + if let Err(error) = deliver_to(&mut self.harness, side, record) { + self.receive_errors.push((side, error)); } } fn inbound_aborted(&self, side: Side, stream_id: StreamId) -> bool { - self.events(side).closed.contains(&stream_id) - || match side { - Side::A => self.closed_by_a.contains(&stream_id), - Side::B => self.closed_by_b.contains(&stream_id), - } + self.events[side.idx()].closed.contains(&stream_id) + || self.closed_by[side.idx()].contains(&stream_id) } } -fn take_unconfirmed_outbound_a(harness: &mut Harness) -> Option { - let time = harness.time(); - let Node { fsm, crypto, .. } = &mut harness.a; - let write = fsm.take_next_write(time, crypto)?; - Some(TakenWrite { - record: write.record, - write_id: write.session_write_id, - }) -} - -fn take_unconfirmed_outbound_b(harness: &mut Harness) -> Option { - let time = harness.time(); - let Node { fsm, crypto, .. } = &mut harness.b; - let write = fsm.take_next_write(time, crypto)?; +fn take_unconfirmed_outbound(harness: &mut Harness, side: Side) -> Option { + let write = harness.next_write(side)?; Some(TakenWrite { record: write.record, write_id: write.session_write_id, }) } -fn take_confirmed_outbound_a(harness: &mut Harness) -> Option> { - let write = take_unconfirmed_outbound_a(harness)?; - confirm_taken_a(harness, &write); +fn take_confirmed_outbound(harness: &mut Harness, side: Side) -> Option> { + let write = take_unconfirmed_outbound(harness, side)?; + confirm_taken(harness, side, &write); Some(write.record) } -fn take_confirmed_outbound_b(harness: &mut Harness) -> Option> { - let write = take_unconfirmed_outbound_b(harness)?; - confirm_taken_b(harness, &write); - Some(write.record) -} - -fn confirm_taken_a(harness: &mut Harness, write: &TakenWrite) { - if let Some(write_id) = write.write_id { - harness - .a - .fsm - .confirm_session_write(harness.time(), write_id); - } -} - -fn confirm_taken_b(harness: &mut Harness, write: &TakenWrite) { +fn confirm_taken(harness: &mut Harness, side: Side, write: &TakenWrite) { if let Some(write_id) = write.write_id { - harness - .b - .fsm - .confirm_session_write(harness.time(), write_id); + harness.confirm_write(side, write_id); } } -fn reject_taken_a(harness: &mut Harness, write: &TakenWrite) { +fn reject_taken(harness: &mut Harness, side: Side, write: &TakenWrite) { if let Some(write_id) = write.write_id { - harness.a.fsm.reject_session_write(write_id); + harness.reject_write(side, write_id); } } -fn reject_taken_b(harness: &mut Harness, write: &TakenWrite) { - if let Some(write_id) = write.write_id { - harness.b.fsm.reject_session_write(write_id); - } -} - -fn deliver_to_a(harness: &mut Harness, record: Vec) -> Result<(), ReceiveError> { - let time = harness.time(); - let Node { fsm, crypto } = &mut harness.a; - fsm.receive(time, record, crypto) -} - -fn deliver_to_b(harness: &mut Harness, record: Vec) -> Result<(), ReceiveError> { +fn deliver_to(harness: &mut Harness, side: Side, record: Vec) -> Result<(), ReceiveError> { let time = harness.time(); - let Node { fsm, crypto } = &mut harness.b; + let Node { fsm, crypto } = harness.node_mut(side); fsm.receive(time, record, crypto) } @@ -920,36 +749,54 @@ fn drain_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { out } +fn opposite(side: Side) -> Side { + match side { + Side::A => Side::B, + Side::B => Side::A, + } +} + +fn side_strategy() -> impl Strategy { + prop_oneof![Just(Side::A), Just(Side::B)] +} + +fn side_action(f: fn(Side) -> Action) -> impl Strategy { + side_strategy().prop_map(f) +} + +fn side_usize_action( + values: impl Strategy, + f: fn(Side, usize) -> Action, +) -> impl Strategy { + (side_strategy(), values).prop_map(move |(side, value)| f(side, value)) +} + +fn side_usize_vec_action( + values: impl Strategy, + bytes: impl Strategy>, + f: fn(Side, usize, Vec) -> Action, +) -> impl Strategy { + (side_strategy(), values, bytes).prop_map(move |(side, value, bytes)| f(side, value, bytes)) +} + fn handshake_action_strategy() -> impl Strategy { let queue_index = 0usize..6; prop_oneof![ - Just(Action::ConnectIkA), - Just(Action::ConnectIkB), - Just(Action::ConnectKkA), - Just(Action::ConnectKkB), + side_action(Action::ConnectIk), + side_action(Action::ConnectKk), (0u8..40).prop_map(Action::AdvanceMs), - Just(Action::OnTimerA), - Just(Action::OnTimerB), + side_action(Action::OnTimer), Just(Action::OnTimerBoth), Just(Action::Pump), - Just(Action::TakeNextAToB), - Just(Action::TakeNextBToA), - queue_index.clone().prop_map(Action::ConfirmTakenAToB), - queue_index.clone().prop_map(Action::ConfirmTakenBToA), - queue_index.clone().prop_map(Action::RejectTakenAToB), - queue_index.clone().prop_map(Action::RejectTakenBToA), - Just(Action::CaptureNextAToB), - Just(Action::CaptureNextBToA), - Just(Action::DeliverNextAToB), - Just(Action::DeliverNextBToA), - Just(Action::DropNextAToB), - Just(Action::DropNextBToA), - queue_index.clone().prop_map(Action::DeliverQueuedAToB), - queue_index.clone().prop_map(Action::DeliverQueuedBToA), - queue_index.clone().prop_map(Action::DuplicateQueuedAToB), - queue_index.clone().prop_map(Action::DuplicateQueuedBToA), - queue_index.clone().prop_map(Action::DropQueuedAToB), - queue_index.prop_map(Action::DropQueuedBToA), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_action(Action::CaptureNext), + side_action(Action::DeliverNext), + side_action(Action::DropNext), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), ] } @@ -959,36 +806,22 @@ fn connected_action_strategy() -> impl Strategy { let queue_index = 0usize..6; prop_oneof![ (0u8..30).prop_map(Action::AdvanceMs), - Just(Action::OnTimerA), - Just(Action::OnTimerB), + side_action(Action::OnTimer), Just(Action::OnTimerBoth), Just(Action::Pump), - Just(Action::TakeNextAToB), - Just(Action::TakeNextBToA), - queue_index.clone().prop_map(Action::ConfirmTakenAToB), - queue_index.clone().prop_map(Action::ConfirmTakenBToA), - queue_index.clone().prop_map(Action::RejectTakenAToB), - queue_index.clone().prop_map(Action::RejectTakenBToA), - Just(Action::CaptureNextAToB), - Just(Action::CaptureNextBToA), - Just(Action::DeliverNextAToB), - Just(Action::DeliverNextBToA), - Just(Action::DropNextAToB), - Just(Action::DropNextBToA), - queue_index.clone().prop_map(Action::DeliverQueuedAToB), - queue_index.clone().prop_map(Action::DeliverQueuedBToA), - queue_index.clone().prop_map(Action::DuplicateQueuedAToB), - queue_index.clone().prop_map(Action::DuplicateQueuedBToA), - queue_index.clone().prop_map(Action::DropQueuedAToB), - queue_index.prop_map(Action::DropQueuedBToA), - slot.clone().prop_map(Action::OpenStreamA), - slot.clone().prop_map(Action::OpenStreamB), - (slot.clone(), bytes.clone()).prop_map(|(slot, bytes)| Action::WriteA { slot, bytes }), - (slot.clone(), bytes).prop_map(|(slot, bytes)| Action::WriteB { slot, bytes }), - slot.clone().prop_map(Action::FinishA), - slot.clone().prop_map(Action::FinishB), - slot.clone().prop_map(Action::CloseA), - slot.prop_map(Action::CloseB), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_action(Action::CaptureNext), + side_action(Action::DeliverNext), + side_action(Action::DropNext), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index.clone(), Action::drop_queued), + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot.clone(), bytes.clone(), Action::write), + side_usize_action(slot.clone(), Action::finish), + side_usize_action(slot, Action::close), ] } @@ -997,25 +830,16 @@ fn write_tracking_action_strategy() -> impl Strategy { let slot = 0usize..SLOT_COUNT; let queue_index = 0usize..6; prop_oneof![ - slot.clone().prop_map(Action::OpenStreamA), - slot.clone().prop_map(Action::OpenStreamB), - (slot.clone(), bytes.clone()).prop_map(|(slot, bytes)| Action::WriteA { slot, bytes }), - (slot, bytes).prop_map(|(slot, bytes)| Action::WriteB { slot, bytes }), - Just(Action::TakeNextAToB), - Just(Action::TakeNextBToA), - queue_index.clone().prop_map(Action::ConfirmTakenAToB), - queue_index.clone().prop_map(Action::ConfirmTakenBToA), - queue_index.clone().prop_map(Action::RejectTakenAToB), - queue_index.clone().prop_map(Action::RejectTakenBToA), - queue_index.clone().prop_map(Action::DeliverQueuedAToB), - queue_index.clone().prop_map(Action::DeliverQueuedBToA), - queue_index.clone().prop_map(Action::DuplicateQueuedAToB), - queue_index.clone().prop_map(Action::DuplicateQueuedBToA), - queue_index.clone().prop_map(Action::DropQueuedAToB), - queue_index.prop_map(Action::DropQueuedBToA), + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot, bytes, Action::write), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), Just(Action::Pump), - Just(Action::OnTimerA), - Just(Action::OnTimerB), + side_action(Action::OnTimer), Just(Action::OnTimerBoth), (0u8..20).prop_map(Action::AdvanceMs), ] @@ -1026,29 +850,18 @@ fn terminal_action_strategy() -> impl Strategy { let slot = 0usize..SLOT_COUNT; let queue_index = 0usize..6; prop_oneof![ - slot.clone().prop_map(Action::OpenStreamA), - slot.clone().prop_map(Action::OpenStreamB), - (slot.clone(), bytes.clone()).prop_map(|(slot, bytes)| Action::WriteA { slot, bytes }), - (slot.clone(), bytes).prop_map(|(slot, bytes)| Action::WriteB { slot, bytes }), - slot.clone().prop_map(Action::FinishA), - slot.clone().prop_map(Action::FinishB), - slot.clone().prop_map(Action::CloseA), - slot.prop_map(Action::CloseB), - Just(Action::TakeNextAToB), - Just(Action::TakeNextBToA), - queue_index.clone().prop_map(Action::ConfirmTakenAToB), - queue_index.clone().prop_map(Action::ConfirmTakenBToA), - queue_index.clone().prop_map(Action::RejectTakenAToB), - queue_index.clone().prop_map(Action::RejectTakenBToA), - queue_index.clone().prop_map(Action::DeliverQueuedAToB), - queue_index.clone().prop_map(Action::DeliverQueuedBToA), - queue_index.clone().prop_map(Action::DuplicateQueuedAToB), - queue_index.clone().prop_map(Action::DuplicateQueuedBToA), - queue_index.clone().prop_map(Action::DropQueuedAToB), - queue_index.prop_map(Action::DropQueuedBToA), + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot.clone(), bytes.clone(), Action::write), + side_usize_action(slot.clone(), Action::finish), + side_usize_action(slot, Action::close), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), Just(Action::Pump), - Just(Action::OnTimerA), - Just(Action::OnTimerB), + side_action(Action::OnTimer), Just(Action::OnTimerBoth), (0u8..20).prop_map(Action::AdvanceMs), ] diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index a210af4e..2c4fbb0e 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -79,9 +79,9 @@ fn connected_fsms_deliver_stream_data() { harness.pump(); - assert_eq!(harness.take_event_b(), Some(opened(stream_id))); + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); assert_eq!( - harness.take_event_b(), + harness.take_event(Side::B), Some(QlFsmEvent::Readable(stream_id)) ); assert_eq!( @@ -89,7 +89,7 @@ fn connected_fsms_deliver_stream_data() { b"hello".to_vec() ); assert_eq!( - harness.take_event_b(), + harness.take_event(Side::B), Some(QlFsmEvent::Finished(stream_id)) ); } @@ -105,30 +105,25 @@ fn session_retransmit_uses_new_record_seq() { 5 ); - let first = harness.next_outbound_a().unwrap(); - let first_transport = harness.b.fsm.state.link.transport().unwrap().clone(); - let (first_header, first_record) = - decrypt_record(&harness.b.crypto, &first, &first_transport.rx_key); + let first = harness.next_decoded_outbound(Side::A).unwrap(); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); - harness.on_timer_a(); + harness.on_timer(Side::A); - let retried = harness.next_outbound_a().unwrap(); - let (retried_header, retried_record) = - decrypt_record(&harness.b.crypto, &retried, &first_transport.rx_key); + let retried = harness.next_decoded_outbound(Side::A).unwrap(); - assert_ne!(retried_header.seq, first_header.seq); - assert_eq!(retried_record, first_record); + assert_ne!(retried.header.seq, first.header.seq); + assert_eq!(retried.frames, first.frames); - harness.deliver_to_b(retried); + harness.deliver(Side::B, retried.record); harness.advance(config.session_record_ack_delay); - harness.on_timer_a(); - harness.on_timer_b(); + harness.on_timer(Side::A); + harness.on_timer(Side::B); harness.pump(); - assert_eq!(harness.take_event_b(), Some(opened(stream_id))); + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); assert_eq!( - harness.take_event_b(), + harness.take_event(Side::B), Some(QlFsmEvent::Readable(stream_id)) ); assert_eq!( @@ -137,8 +132,8 @@ fn session_retransmit_uses_new_record_seq() { ); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); - harness.on_timer_a(); - assert!(harness.next_outbound_a().is_none()); + harness.on_timer(Side::A); + assert!(harness.next_outbound(Side::A).is_none()); } #[test] @@ -169,18 +164,18 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { harness.pump(); - assert_eq!(harness.take_event_a(), Some(opened(stream_id_b))); + assert_eq!(harness.take_event(Side::A), Some(opened(stream_id_b))); assert_eq!( - harness.take_event_a(), + harness.take_event(Side::A), Some(QlFsmEvent::Readable(stream_id_b)) ); assert_eq!( read_stream_all(&mut harness.a.fsm, stream_id_b), b"from-b".to_vec() ); - assert_eq!(harness.take_event_b(), Some(opened(stream_id_a))); + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id_a))); assert_eq!( - harness.take_event_b(), + harness.take_event(Side::B), Some(QlFsmEvent::Readable(stream_id_a)) ); assert_eq!( @@ -261,30 +256,25 @@ fn returned_session_write_is_reissued_with_new_record_seq() { 5 ); - let write = harness.next_write_a().unwrap(); - let id = write.session_write_id.expect("expected session write"); - let record = write.record; - let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); - let (first_header, first) = decrypt_record(&harness.b.crypto, &record, &session_key); + let first = harness.next_decoded_write(Side::A).unwrap(); + let id = first.write_id.expect("expected session write"); - harness.return_write_a(id); + harness.reject_write(Side::A, id); - let write = harness.next_write_a().unwrap(); - let reissued_id = write.session_write_id.expect("expected reissued write"); - let record = write.record; - let (reissued_header, reissued) = decrypt_record(&harness.b.crypto, &record, &session_key); + let reissued = harness.next_decoded_write(Side::A).unwrap(); + let reissued_id = reissued.write_id.expect("expected reissued write"); assert_ne!(reissued_id, id); - assert_ne!(reissued_header.seq, first_header.seq); - assert_eq!(reissued, first); + assert_ne!(reissued.header.seq, first.header.seq); + assert_eq!(reissued.frames, first.frames); - harness.confirm_write_a(reissued_id); - harness.deliver_to_b(record); + harness.confirm_write(Side::A, reissued_id); + harness.deliver(Side::B, reissued.record); harness.pump(); - assert_eq!(harness.take_event_b(), Some(opened(stream_id))); + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); assert_eq!( - harness.take_event_b(), + harness.take_event(Side::B), Some(QlFsmEvent::Readable(stream_id)) ); assert_eq!( @@ -304,26 +294,21 @@ fn unconfirmed_session_write_does_not_start_retransmit_timer() { 5 ); - let write = harness.next_write_a().unwrap(); - let id = write.session_write_id.expect("expected session write"); - let record = write.record; - let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); - let (first_header, first) = decrypt_record(&harness.b.crypto, &record, &session_key); + let first = harness.next_decoded_write(Side::A).unwrap(); + let id = first.write_id.expect("expected session write"); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); - harness.on_timer_a(); - assert!(harness.next_write_a().is_none()); + harness.on_timer(Side::A); + assert!(harness.next_write(Side::A).is_none()); - harness.confirm_write_a(id); + harness.confirm_write(Side::A, id); harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); - harness.on_timer_a(); + harness.on_timer(Side::A); - let write = harness.next_write_a().unwrap(); - let record = write.record; - let (retried_header, retried) = decrypt_record(&harness.b.crypto, &record, &session_key); + let retried = harness.next_decoded_write(Side::A).unwrap(); - assert_ne!(retried_header.seq, first_header.seq); - assert_eq!(retried, first); + assert_ne!(retried.header.seq, first.header.seq); + assert_eq!(retried.frames, first.frames); } #[test] @@ -344,15 +329,15 @@ fn ack_frame_releases_stream_capacity_and_emits_writable() { 0 ); - let record = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(record); + let record = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, record); harness.advance(config.session_record_ack_delay); - harness.on_timer_a(); - harness.on_timer_b(); + harness.on_timer(Side::A); + harness.on_timer(Side::B); harness.pump(); assert_eq!( - harness.take_event_a(), + harness.take_event(Side::A), Some(QlFsmEvent::Writable(stream_id)) ); } @@ -366,7 +351,7 @@ fn close_session_disconnects_locally() { .fsm .close_session(ql_wire::SessionCloseCode::CANCELLED); - assert!(matches!(harness.take_event_a(), Some(QlFsmEvent::SessionClosed(SessionClose { + assert!(matches!(harness.take_event(Side::A), Some(QlFsmEvent::SessionClosed(SessionClose { code: ql_wire::SessionCloseCode::CANCELLED, })))); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -376,14 +361,12 @@ fn close_session_disconnects_locally() { )); assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); - let close = harness.next_outbound_a().unwrap(); - let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); - let (_header, record) = decrypt_record(&harness.b.crypto, &close, &session_key); - assert!(matches!(record.as_slice(), [ql_wire::SessionFrame::Close(_)])); + let close = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!(close.frames.as_slice(), [ql_wire::SessionFrame::Close(_)])); assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert_eq!( - harness.take_event_a(), + harness.take_event(Side::A), Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) ); } @@ -399,16 +382,14 @@ fn session_records_contain_ack_frames_after_delivery() { 1 ); - let data = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(data); + let data = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, data); harness.advance(config.session_record_ack_delay); - harness.on_timer_b(); + harness.on_timer(Side::B); - let ack = harness.next_outbound_b().unwrap(); - let session_key = harness.a.fsm.state.link.transport().unwrap().rx_key.clone(); - let (_ack_header, ack_record) = decrypt_record(&harness.a.crypto, &ack, &session_key); + let ack = harness.next_decoded_outbound(Side::B).unwrap(); assert!(matches!( - ack_record.as_slice(), + ack.frames.as_slice(), [ql_wire::SessionFrame::Ack(_)] )); } @@ -426,11 +407,11 @@ fn first_stream_data_uses_negotiated_initial_peer_credit() { }, ); - harness.connect_ik_a().unwrap(); - let ik1 = harness.next_outbound_a().unwrap(); - harness.deliver_to_b(ik1); - let ik2 = harness.next_outbound_b().unwrap(); - harness.deliver_to_a(ik2); + harness.connect_ik(Side::A).unwrap(); + let ik1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, ik1); + let ik2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, ik2); let stream_id = open_stream_id(&mut harness.a.fsm); assert_eq!( @@ -438,12 +419,8 @@ fn first_stream_data_uses_negotiated_initial_peer_credit() { 5 ); - let data = harness.next_outbound_a().unwrap(); - let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); - let (_header, record) = decrypt_record(&harness.b.crypto, &data, &session_key); - assert!(matches!( - record.as_slice(), + harness.next_decoded_outbound(Side::A).unwrap().frames.as_slice(), [ql_wire::SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" )); } @@ -457,21 +434,19 @@ fn session_timeout_emits_close_before_disconnect() { let mut harness = Harness::connected(config); harness.advance(config.session_peer_timeout); - harness.on_timer_a(); + harness.on_timer(Side::A); assert_eq!( - harness.drain_events_a(), + harness.drain_events(Side::A), vec![QlFsmEvent::SessionClosed(SessionClose { code: ql_wire::SessionCloseCode::TIMEOUT, })] ); - let close = harness.next_outbound_a().unwrap(); - let session_key = harness.b.fsm.state.link.transport().unwrap().rx_key.clone(); - let (_header, record) = decrypt_record(&harness.b.crypto, &close, &session_key); - assert!(matches!(record.as_slice(), [ql_wire::SessionFrame::Close(_)])); + let close = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!(close.frames.as_slice(), [ql_wire::SessionFrame::Close(_)])); assert_eq!( - harness.take_event_a(), + harness.take_event(Side::A), Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) ); } From 3c0ce5ae2c8664ec382ff4d4aadf93c782c05a69 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 20:03:27 -0400 Subject: [PATCH 177/304] ql-runtime: test compaction --- ql-runtime/src/tests/handshake.rs | 36 +------- ql-runtime/src/tests/mod.rs | 95 +++++++++++++++++++ ql-runtime/src/tests/rpc.rs | 72 +++------------ ql-runtime/src/tests/stream.rs | 146 +++++++++--------------------- 4 files changed, 156 insertions(+), 193 deletions(-) diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index fa2ab8b5..20629a99 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -7,25 +7,8 @@ use super::*; #[tokio::test(flavor = "current_thread")] async fn connect_round_trip_changes_peer_status() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b) = TestPlatform::new(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + let pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; }) .await; } @@ -33,20 +16,9 @@ async fn connect_round_trip_changes_peer_status() { #[tokio::test(flavor = "current_thread")] async fn opening_stream_requires_connection() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, _outbound_a, _status_a) = TestPlatform::new(); - let (platform_b, _outbound_b, _status_b, _inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + let pair = TestPair::new(default_runtime_config()); assert!(matches!( - handle_a.open_stream(test_route_id()).await, + pair.side(Side::A).handle.open_stream(test_route_id()).await, Err(NoSessionError) )); }) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 587b0bb1..fc82bcda 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -36,6 +36,21 @@ struct StatusEvent { status: PeerStatus, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Side { + A, + B, +} + +impl Side { + fn opposite(self) -> Self { + match self { + Side::A => Side::B, + Side::B => Side::A, + } + } +} + fn test_route_id() -> RouteId { RouteId(VarInt::from_u32(1)) } @@ -125,6 +140,86 @@ impl TestPlatform { } } +struct TestSide { + handle: RuntimeHandle, + status: Receiver, + peer: XID, + inbound: Receiver, +} + +struct TestPair { + a: TestSide, + b: TestSide, +} + +impl TestPair { + fn new(config: RuntimeConfig) -> Self { + let (platform_a, outbound_a, status_a, inbound_a) = TestPlatform::new_with_inbound(); + let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config.clone()); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, handle_b.clone()); + spawn_forwarder(outbound_b, handle_a.clone()); + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + + Self { + a: TestSide { + handle: handle_a, + status: status_a, + peer: identity_a.xid, + inbound: inbound_a, + }, + b: TestSide { + handle: handle_b, + status: status_b, + peer: identity_b.xid, + inbound: inbound_b, + }, + } + } + + fn side(&self, side: Side) -> &TestSide { + match side { + Side::A => &self.a, + Side::B => &self.b, + } + } + + fn side_mut(&mut self, side: Side) -> &mut TestSide { + match side { + Side::A => &mut self.a, + Side::B => &mut self.b, + } + } + + async fn connect_and_wait(&self, initiator: Side) { + self.side(initiator).handle.connect(); + await_status( + &self.side(initiator).status, + self.side(initiator.opposite()).peer, + PeerStatus::Connected, + ) + .await; + await_status( + &self.side(initiator.opposite()).status, + self.side(initiator).peer, + PeerStatus::Connected, + ) + .await; + } + + fn take_inbound(&mut self, side: Side) -> Receiver { + let replacement = async_channel::unbounded().1; + std::mem::replace(&mut self.side_mut(side).inbound, replacement) + } +} + struct TokioTimer { sleep: Pin>, } diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 93dc82d2..d6b8952a 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -53,25 +53,9 @@ impl ql_rpc::request_with_progress::RequestWithProgress for Download { #[tokio::test(flavor = "current_thread")] async fn rpc_request_round_trips() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -93,7 +77,7 @@ async fn rpc_request_round_trips() { writer.finish(); }); - let rpc = handle_a.rpc(); + let rpc = pair.handle(Side::A).rpc(); let response = rpc .request::(&BytesValue(b"hello".to_vec())) .await @@ -111,25 +95,9 @@ async fn rpc_request_round_trips() { #[tokio::test(flavor = "current_thread")] async fn rpc_subscription_streams_events() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -155,7 +123,7 @@ async fn rpc_subscription_streams_events() { writer.finish(); }); - let rpc = handle_a.rpc(); + let rpc = pair.handle(Side::A).rpc(); let mut subscription = rpc .subscribe::(&BytesValue(b"watch".to_vec())) .await @@ -181,25 +149,9 @@ async fn rpc_subscription_streams_events() { #[tokio::test(flavor = "current_thread")] async fn rpc_request_with_progress_supports_progress_then_await() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -235,7 +187,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { writer.finish(); }); - let rpc = handle_a.rpc(); + let rpc = pair.handle(Side::A).rpc(); let mut download = rpc .request_with_progress::(&BytesValue(b"logo".to_vec())) .await diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index cce7d65f..4c114f2e 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -9,25 +9,9 @@ use crate::QlStreamError; #[tokio::test(flavor = "current_thread")] async fn open_stream_duplex_happy_path() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -43,7 +27,12 @@ async fn open_stream_duplex_happy_path() { writer.finish(); }); - let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); stream .writer .write(Bytes::from_static(&[1, 2])) @@ -73,25 +62,9 @@ async fn open_stream_duplex_happy_path() { #[tokio::test(flavor = "current_thread")] async fn reader_exposes_bounded_chunk_reads() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -114,7 +87,12 @@ async fn reader_exposes_bounded_chunk_reads() { inbound.writer.finish(); }); - let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); stream .writer .write(Bytes::from_static(&[1, 2, 3, 4])) @@ -139,28 +117,11 @@ async fn reader_exposes_bounded_chunk_reads() { #[tokio::test(flavor = "current_thread")] async fn large_stream_payload_round_trips() { run_local_test(async { - let config = default_runtime_config(); let payload: Vec = (0..40).collect(); - - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let mut pair = TestPair::new(default_runtime_config()); let (done_tx, done_rx) = async_channel::bounded(1); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -169,7 +130,12 @@ async fn large_stream_payload_round_trips() { done_tx.send(request_data).await.unwrap(); }); - let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); stream .writer .write(Bytes::from(payload.clone())) @@ -195,32 +161,21 @@ async fn large_stream_payload_round_trips() { #[tokio::test(flavor = "current_thread")] async fn dropping_responder_closes_initiator_response() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); drop(stream.reader); }); - let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); stream.writer.finish(); let err = next_chunk(&mut stream.reader).await.unwrap_err(); @@ -240,26 +195,10 @@ async fn dropping_responder_closes_initiator_response() { #[tokio::test(flavor = "current_thread")] async fn dropping_inbound_reader_cancels_remote_writer() { run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); let (go_tx, go_rx) = async_channel::bounded(1); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + pair.connect_and_wait(Side::A).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -275,7 +214,12 @@ async fn dropping_inbound_reader_cancels_remote_writer() { writer.finish(); }); - let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); stream.writer.finish(); assert_eq!( next_chunk(&mut stream.reader).await.unwrap(), From 19c446a15dcf80b990f1a3c0df3a24d5cf98ebc4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 21:59:21 -0400 Subject: [PATCH 178/304] ql: refactoring internals --- ql-fsm/src/fsm.rs | 28 +++++++++--------- ql-fsm/src/handshake/mod.rs | 8 ++--- ql-fsm/src/lib.rs | 22 +++++++------- ql-fsm/src/session/mod.rs | 51 ++++++++++++++++--------------- ql-fsm/src/session/state.rs | 8 ++--- ql-fsm/src/session/tests.rs | 56 +++++++++++++++++------------------ ql-fsm/src/tests/handshake.rs | 10 +++---- ql-fsm/src/tests/mod.rs | 22 +++++++------- ql-fsm/src/tests/proptest.rs | 22 +++++++------- ql-fsm/src/tests/session.rs | 28 +++++++++--------- ql-runtime/src/driver/mod.rs | 46 +++++++++++----------------- 11 files changed, 145 insertions(+), 156 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index ee111f52..ec999d03 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -4,8 +4,8 @@ use bytes::Bytes; use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; use crate::{ - handshake, session::SessionEvent, state::LinkState, NoSessionError, OutboundWrite, QlFsm, - QlFsmEvent, ReceiveError, SessionWriteId, StreamError, StreamOps, + handshake, session::SessionEvent, state::LinkState, Event, NoSessionError, OutboundWrite, + QlFsm, ReceiveError, StreamError, StreamOps, WriteId, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { @@ -121,11 +121,11 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Result<(), NoSessionError> { pub fn emit_peer_status(fsm: &mut QlFsm) { if fsm.state.peer.is_some() { fsm.pending_events - .push_back(QlFsmEvent::PeerStatusChanged(fsm.state.link.status())); + .push_back(Event::PeerStatusChanged(fsm.state.link.status())); } } fn forward_session_event( event: SessionEvent, - pending_events: &mut std::collections::VecDeque, + pending_events: &mut std::collections::VecDeque, ) { match event { SessionEvent::Opened { stream_id, route_id, } => { - pending_events.push_back(QlFsmEvent::Opened { + pending_events.push_back(Event::Opened { stream_id, route_id, }); } SessionEvent::Readable(stream_id) => { - pending_events.push_back(QlFsmEvent::Readable(stream_id)); + pending_events.push_back(Event::Readable(stream_id)); } SessionEvent::Writable(stream_id) => { - pending_events.push_back(QlFsmEvent::Writable(stream_id)); + pending_events.push_back(Event::Writable(stream_id)); } SessionEvent::Finished(stream_id) => { - pending_events.push_back(QlFsmEvent::Finished(stream_id)); + pending_events.push_back(Event::Finished(stream_id)); } SessionEvent::Closed(frame) => { - pending_events.push_back(QlFsmEvent::Closed(frame)); + pending_events.push_back(Event::Closed(frame)); } SessionEvent::WritableClosed(frame) => { - pending_events.push_back(QlFsmEvent::WritableClosed(frame)); + pending_events.push_back(Event::WritableClosed(frame)); } SessionEvent::SessionClosed(close) => { - pending_events.push_back(QlFsmEvent::SessionClosed(close)); + pending_events.push_back(Event::SessionClosed(close)); } } } diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index 0e528922..8146be08 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -8,9 +8,9 @@ use ql_wire::{ use crate::{ fsm::{deadline_after_secs, emit_peer_status}, - session::{SessionFsm, SessionFsmConfig, StreamParity}, + session::{SessionConfig, SessionFsm, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, - NoPeerError, QlFsm, QlFsmEvent, ReceiveError, + Event, NoPeerError, QlFsm, ReceiveError, }; pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { @@ -113,12 +113,12 @@ pub fn finish_handshake( } } else { fsm.state.peer = Some(remote_bundle); - fsm.pending_events.push_back(QlFsmEvent::NewPeer); + fsm.pending_events.push_back(Event::NewPeer); } let config = &fsm.config; let session = SessionFsm::new( - SessionFsmConfig { + SessionConfig { local_parity: StreamParity::for_local(fsm.identity.xid, xid), record_max_size: config.session_record_max_size, ack_delay: config.session_record_ack_delay, diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index ac14bbc3..2cb32683 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -67,7 +67,7 @@ pub enum PeerStatus { /// events emitted by `QlFsm` #[derive(Debug, Clone, PartialEq, Eq)] -pub enum QlFsmEvent { +pub enum Event { /// a peer was learned during handshake completion NewPeer, /// the peer changed connection state @@ -97,7 +97,7 @@ pub enum QlFsmEvent { /// handle for a session write returned by `QlFsm::take_next_write` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct SessionWriteId(pub(crate) u64); +pub struct WriteId(pub(crate) u64); /// outbound record produced by `QlFsm` #[derive(Debug, Clone, PartialEq, Eq)] @@ -105,7 +105,7 @@ pub struct OutboundWrite { /// wire bytes to hand to the transport pub record: Vec, /// write handle that must be confirmed or rejected - pub session_write_id: Option, + pub session_write_id: Option, } /// timing and buffering knobs for `QlFsm` @@ -131,7 +131,7 @@ pub struct QlFsmConfig { impl Default for QlFsmConfig { fn default() -> Self { - let s = session::SessionFsmConfig::default(); + let s = session::SessionConfig::default(); Self { handshake_timeout: Duration::from_secs(5), session_record_ack_delay: s.ack_delay, @@ -148,11 +148,11 @@ impl Default for QlFsmConfig { /// synchronous driver for peer binding, handshake, and encrypted streams pub struct QlFsm { /// active configuration - pub config: QlFsmConfig, + config: QlFsmConfig, /// local identity and private keys - pub identity: QlIdentity, - pub(crate) state: QlFsmState, - pending_events: VecDeque, + identity: QlIdentity, + state: QlFsmState, + pending_events: VecDeque, } impl QlFsm { @@ -231,7 +231,7 @@ impl QlFsm { } /// returns the next queued event, if any - pub fn poll_event(&mut self) -> Option { + pub fn poll_event(&mut self) -> Option { self.pending_events.pop_front() } @@ -258,7 +258,7 @@ impl QlFsm { /// marks a `SessionWriteId` from `take_next_write` as handed to the transport /// /// call this at most once for each returned `SessionWriteId` - pub fn confirm_session_write(&mut self, now: FsmTime, write_id: SessionWriteId) { + pub fn confirm_session_write(&mut self, now: FsmTime, write_id: WriteId) { self.state.now = now; fsm::confirm_session_write(self, write_id); } @@ -266,7 +266,7 @@ impl QlFsm { /// reports that a `SessionWriteId` from `take_next_write` was not accepted /// /// call this at most once for each returned `SessionWriteId` - pub fn reject_session_write(&mut self, write_id: SessionWriteId) { + pub fn reject_session_write(&mut self, write_id: WriteId) { fsm::reject_session_write(self, write_id); } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 162fe37c..243b0d51 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -27,8 +27,7 @@ use self::{ received_records::{ReceiveOutcome, ReceivedRecords}, remote_stream_history::RemoteStreamHistory, state::{ - AckState, InboundState, OutboundState, SessionFsmState, SessionState, StreamRole, - StreamState, + AckState, InboundState, OutboundState, SessionPhase, SessionState, StreamRole, StreamState, }, stream_tx::StreamTxRange, tracked::{TrackedFrame, TrackedRecord, TrackedStreamData}, @@ -36,7 +35,7 @@ use self::{ use crate::{NoSessionError, StreamError}; #[derive(Debug, Clone, Copy)] -pub struct SessionFsmConfig { +pub struct SessionConfig { pub local_parity: StreamParity, pub record_max_size: usize, pub ack_delay: Duration, @@ -48,7 +47,7 @@ pub struct SessionFsmConfig { pub initial_peer_stream_receive_window: u32, } -impl Default for SessionFsmConfig { +impl Default for SessionConfig { fn default() -> Self { Self { local_parity: StreamParity::Even, @@ -79,12 +78,12 @@ pub enum SessionEvent { } pub struct SessionFsm { - config: SessionFsmConfig, - state: SessionFsmState, + config: SessionConfig, + state: SessionState, } impl SessionFsm { - pub fn new(mut config: SessionFsmConfig, now: Instant) -> Self { + pub fn new(mut config: SessionConfig, now: Instant) -> Self { config.record_max_size = config .record_max_size .max(SessionRecordBuilder::MIN_CAPACITY); @@ -92,11 +91,11 @@ impl SessionFsm { config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); Self { config, - state: SessionFsmState { + state: SessionState { now, last_activity_at: now, last_inbound_at: now, - session_state: SessionState::Open, + phase: SessionPhase::Open, next_stream_ordinal: 0, next_record_seq: RecordSeq::from_u32(0), next_write_id: 0, @@ -147,12 +146,12 @@ impl SessionFsm { } pub(crate) fn close(&mut self, code: SessionCloseCode, mut emit: impl FnMut(SessionEvent)) { - if self.state.session_state != SessionState::Open { + if self.state.phase != SessionPhase::Open { return; } let close = SessionClose { code }; - self.state.session_state = SessionState::Closing(close.clone()); + self.state.phase = SessionPhase::Closing(close.clone()); self.state.tracked_records.clear(); self.state.ack_state = AckState::Idle; self.clear_streams(); @@ -160,7 +159,7 @@ impl SessionFsm { } pub(crate) fn is_closed(&self) -> bool { - self.state.session_state == SessionState::Closed + self.state.phase == SessionPhase::Closed } pub(crate) fn receive( @@ -176,7 +175,7 @@ impl SessionFsm { self.state.last_activity_at = self.state.now; self.state.last_inbound_at = self.state.now; - if self.state.session_state != SessionState::Open { + if self.state.phase != SessionPhase::Open { return; } @@ -239,7 +238,7 @@ impl SessionFsm { pub fn confirm_write(&mut self, now: Instant, write_id: u64) { self.state.now = now; - if !self.state.session_state.is_open() { + if !self.state.phase.is_open() { return; } let Some(record) = self.state.tracked_records.get_mut(&write_id) else { @@ -253,7 +252,7 @@ impl SessionFsm { } pub fn reject_write(&mut self, write_id: u64) { - if !self.state.session_state.is_open() { + if !self.state.phase.is_open() { return; } if self @@ -278,7 +277,7 @@ impl SessionFsm { pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { self.state.now = now; - if !self.state.session_state.is_open() { + if !self.state.phase.is_open() { return; } self.collect_timeouts(); @@ -288,7 +287,7 @@ impl SessionFsm { self.close(SessionCloseCode::TIMEOUT, &mut emit); return; } - if self.state.session_state == SessionState::Open + if self.state.phase == SessionPhase::Open && !self.config.keepalive_interval.is_zero() && self.state.last_activity_at + self.config.keepalive_interval <= self.state.now { @@ -297,7 +296,7 @@ impl SessionFsm { } pub fn next_deadline(&self) -> Option { - if !self.state.session_state.is_open() { + if !self.state.phase.is_open() { return None; } let ack_deadline = match self.state.ack_state { @@ -314,11 +313,11 @@ impl SessionFsm { .map(|sent_at| sent_at + self.config.retransmit_timeout) }) .min(); - let keepalive_deadline = (self.state.session_state == SessionState::Open + let keepalive_deadline = (self.state.phase == SessionPhase::Open && !self.config.keepalive_interval.is_zero() && !self.state.pending_ping) .then_some(self.state.last_activity_at + self.config.keepalive_interval); - let peer_timeout_deadline = (self.state.session_state == SessionState::Open + let peer_timeout_deadline = (self.state.phase == SessionPhase::Open && !self.config.peer_timeout.is_zero()) .then_some(self.state.last_inbound_at + self.config.peer_timeout); [ @@ -334,19 +333,19 @@ impl SessionFsm { pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { self.state.now = now; - match &self.state.session_state { - SessionState::Closing(close) => { + match &self.state.phase { + SessionPhase::Closing(close) => { let seq = self.state.next_record_seq; next_seq(&mut self.state.next_record_seq); let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); assert!(builder.push_close(&close), "builder has capacity"); - self.state.session_state = SessionState::Closed; + self.state.phase = SessionPhase::Closed; return Some((None, builder)); } - SessionState::Closed => { + SessionPhase::Closed => { return None; } - SessionState::Open => {} + SessionPhase::Open => {} } self.collect_timeouts(); @@ -527,7 +526,7 @@ impl SessionFsm { } fn ensure_session_open(&self) -> Result<(), NoSessionError> { - if self.state.session_state != SessionState::Open { + if self.state.phase != SessionPhase::Open { Err(NoSessionError) } else { Ok(()) diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 237750cb..65a3950f 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -8,11 +8,11 @@ use super::{ stream_rx::StreamRx, stream_tx::StreamTx, tracked::TrackedRecord, }; -pub struct SessionFsmState { +pub struct SessionState { pub now: Instant, pub last_activity_at: Instant, pub last_inbound_at: Instant, - pub session_state: SessionState, + pub phase: SessionPhase, pub next_stream_ordinal: u32, pub next_record_seq: RecordSeq, pub next_write_id: u64, @@ -26,13 +26,13 @@ pub struct SessionFsmState { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum SessionState { +pub enum SessionPhase { Open, Closing(SessionClose), Closed, } -impl SessionState { +impl SessionPhase { pub fn is_open(&self) -> bool { self == &Self::Open } diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 8ef6e559..2665696c 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -7,7 +7,7 @@ use ql_wire::{ StreamId, VarInt, XID, }; -use super::{SessionEvent, SessionFsm, SessionFsmConfig}; +use super::{SessionEvent, SessionFsm, SessionConfig}; use crate::session::stream_parity::StreamParity; fn seq(value: u64) -> RecordSeq { @@ -91,7 +91,7 @@ fn receive_events( #[test] fn outbound_record_seq_increments_monotonically() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = open_stream_id(&mut fsm); assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"one"), 3); @@ -107,7 +107,7 @@ fn outbound_record_seq_increments_monotonically() { #[test] fn retransmit_uses_new_record_seq() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = open_stream_id(&mut fsm); assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"retry"), 5); @@ -124,9 +124,9 @@ fn retransmit_uses_new_record_seq() { fn lost_record_on_one_stream_does_not_block_another_stream() { let now = Instant::now(); let mut fsm = SessionFsm::new( - SessionFsmConfig { + SessionConfig { record_max_size: 80 + SessionRecordBuilder::MIN_CAPACITY, - ..SessionFsmConfig::default() + ..SessionConfig::default() }, now, ); @@ -162,9 +162,9 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { fn ack_reopens_write_capacity() { let now = Instant::now(); let mut fsm = SessionFsm::new( - SessionFsmConfig { + SessionConfig { stream_send_buffer_size: 4, - ..SessionFsmConfig::default() + ..SessionConfig::default() }, now, ); @@ -192,10 +192,10 @@ fn ack_reopens_write_capacity() { fn commit_stream_read_is_what_advances_stream_window() { let now = Instant::now(); let mut fsm = SessionFsm::new( - SessionFsmConfig { + SessionConfig { local_parity: StreamParity::Even, ack_delay: Duration::ZERO, - ..SessionFsmConfig::default() + ..SessionConfig::default() }, now, ); @@ -239,9 +239,9 @@ fn commit_stream_read_is_what_advances_stream_window() { #[test] fn pure_ack_only_records_are_fire_and_forget() { let now = Instant::now(); - let config = SessionFsmConfig { + let config = SessionConfig { ack_delay: Duration::ZERO, - ..SessionFsmConfig::default() + ..SessionConfig::default() }; let retransmit_timeout = config.retransmit_timeout; let mut fsm = SessionFsm::new(config, now); @@ -270,7 +270,7 @@ fn pure_ack_only_records_are_fire_and_forget() { #[test] fn inbound_stream_data_emits_opened_and_readable() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = stream_id(1); let record = vec![SessionFrame::StreamData(ql_wire::StreamData { stream_id, @@ -295,7 +295,7 @@ fn inbound_stream_data_emits_opened_and_readable() { #[test] fn remote_stream_close_is_reliable_and_retried() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = open_stream_id(&mut fsm); fsm.stream(stream_id) @@ -323,9 +323,9 @@ fn stream_ids_follow_even_odd_xid_ordering() { let odd = StreamParity::for_local(XID([2; XID::SIZE]), XID([1; XID::SIZE])); let even_id = SessionFsm::new( - SessionFsmConfig { + SessionConfig { local_parity: even, - ..SessionFsmConfig::default() + ..SessionConfig::default() }, now, ) @@ -333,9 +333,9 @@ fn stream_ids_follow_even_odd_xid_ordering() { .unwrap() .stream_id(); let odd_id = SessionFsm::new( - SessionFsmConfig { + SessionConfig { local_parity: odd, - ..SessionFsmConfig::default() + ..SessionConfig::default() }, now, ) @@ -350,7 +350,7 @@ fn stream_ids_follow_even_odd_xid_ordering() { #[test] fn duplicate_stream_data_is_not_redelivered() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = stream_id(1); let record = vec![SessionFrame::StreamData(StreamData { stream_id, @@ -368,7 +368,7 @@ fn duplicate_stream_data_is_not_redelivered() { #[test] fn duplicate_remote_close_after_reap_is_ignored() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let close = StreamClose { stream_id: stream_id(1), target: CloseTarget::Both, @@ -392,7 +392,7 @@ fn duplicate_remote_close_after_reap_is_ignored() { #[test] fn late_remote_stream_data_after_close_is_ignored() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = stream_id(1); let close = vec![SessionFrame::StreamClose(StreamClose { stream_id, @@ -431,7 +431,7 @@ fn late_remote_stream_data_after_close_is_ignored() { #[test] fn duplicate_finished_remote_data_after_reap_is_ignored() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = stream_id(1); let record = vec![SessionFrame::StreamData(StreamData { stream_id, @@ -459,7 +459,7 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { #[test] fn duplicate_finished_remote_data_before_read_is_ignored() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = stream_id(1); let record = vec![SessionFrame::StreamData(StreamData { stream_id, @@ -487,7 +487,7 @@ fn duplicate_finished_remote_data_before_read_is_ignored() { #[test] fn out_of_order_remote_stream_first_observations_still_open_once_each() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let close3 = vec![SessionFrame::StreamClose(StreamClose { stream_id: stream_id(3), target: CloseTarget::Both, @@ -540,7 +540,7 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { #[test] fn invalid_remote_stream_close_closes_session() { let now = Instant::now(); - let mut fsm = SessionFsm::new(SessionFsmConfig::default(), now); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); let invalid = vec![SessionFrame::StreamClose(StreamClose { stream_id: stream_id(0), @@ -561,9 +561,9 @@ fn invalid_remote_stream_close_closes_session() { fn close_does_not_ack_rejected_record_seq() { let now = Instant::now(); let mut fsm = SessionFsm::new( - SessionFsmConfig { + SessionConfig { ack_delay: Duration::ZERO, - ..SessionFsmConfig::default() + ..SessionConfig::default() }, now, ); @@ -600,9 +600,9 @@ fn close_does_not_ack_rejected_record_seq() { fn initial_peer_stream_receive_window_limits_first_send() { let now = Instant::now(); let mut fsm = SessionFsm::new( - SessionFsmConfig { + SessionConfig { initial_peer_stream_receive_window: 3, - ..SessionFsmConfig::default() + ..SessionConfig::default() }, now, ); diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 1112012d..edf02a03 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ql_wire::QlHandshakeRecord; use super::*; -use crate::{state::LinkState, NoPeerError, PeerStatus, QlFsmEvent}; +use crate::{state::LinkState, NoPeerError, PeerStatus, Event}; #[test] fn ik_connect_round_trip_establishes_transport() { @@ -114,7 +114,7 @@ fn connect_ik_emits_initiator_status() { assert_eq!( harness.drain_events(Side::A), - vec![QlFsmEvent::PeerStatusChanged(PeerStatus::Initiator)] + vec![Event::PeerStatusChanged(PeerStatus::Initiator)] ); } @@ -250,8 +250,8 @@ fn inbound_ik1_auto_binds_unbound_responder() { assert_eq!( harness.drain_events(Side::B), vec![ - QlFsmEvent::NewPeer, - QlFsmEvent::PeerStatusChanged(PeerStatus::Connected), + Event::NewPeer, + Event::PeerStatusChanged(PeerStatus::Connected), ] ); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -279,7 +279,7 @@ fn handshake_timeout_drops_single_ik_attempt_without_resend() { assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert_eq!( harness.take_event(Side::A), - Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) ); assert!(harness.next_outbound(Side::A).is_none()); } diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index f854a6c6..c2daffff 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -10,9 +10,9 @@ use ql_wire::{ }; use crate::{ - session::{SessionFsm, SessionFsmConfig, StreamParity}, + session::{SessionConfig, SessionFsm, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, - FsmTime, NoPeerError, OutboundWrite, QlFsm, QlFsmConfig, QlFsmEvent, SessionWriteId, + Event, FsmTime, NoPeerError, OutboundWrite, QlFsm, QlFsmConfig, WriteId, }; type TestCrypto = SoftwareCrypto; @@ -46,7 +46,7 @@ struct Harness { struct DecodedSessionWrite { record: Vec, - write_id: Option, + write_id: Option, header: ql_wire::SessionHeader, frames: Vec>>, } @@ -219,12 +219,14 @@ impl Harness { fsm.receive(time, record, crypto).unwrap(); } - fn confirm_write(&mut self, side: Side, write_id: SessionWriteId) { + fn confirm_write(&mut self, side: Side, write_id: WriteId) { let time = self.time(); - self.node_mut(side).fsm.confirm_session_write(time, write_id); + self.node_mut(side) + .fsm + .confirm_session_write(time, write_id); } - fn reject_write(&mut self, side: Side, write_id: SessionWriteId) { + fn reject_write(&mut self, side: Side, write_id: WriteId) { self.node_mut(side).fsm.reject_session_write(write_id); } @@ -249,11 +251,11 @@ impl Harness { self.node_mut(side).fsm.on_timer(time); } - fn take_event(&mut self, side: Side) -> Option { + fn take_event(&mut self, side: Side) -> Option { self.node_mut(side).fsm.poll_event() } - fn drain_events(&mut self, side: Side) -> Vec { + fn drain_events(&mut self, side: Side) -> Vec { let mut events = Vec::new(); while let Some(event) = self.take_event(side) { events.push(event); @@ -288,7 +290,7 @@ fn pairing_token(byte: u8) -> PairingToken { PairingToken([byte; PairingToken::SIZE]) } -fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { +fn session_config(harness: &Harness, a: bool) -> SessionConfig { let (local, peer, config) = if a { ( harness.a.fsm.identity.xid, @@ -303,7 +305,7 @@ fn session_config(harness: &Harness, a: bool) -> SessionFsmConfig { ) }; - SessionFsmConfig { + SessionConfig { local_parity: StreamParity::for_local(local, peer), record_max_size: config.session_record_max_size, ack_delay: config.session_record_ack_delay, diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 04ce4bf4..fa997a9c 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -14,7 +14,7 @@ use super::*; fn test_route_id() -> ql_wire::RouteId { ql_wire::RouteId(ql_wire::VarInt::from_u32(1)) } -use crate::{state::LinkState, PeerStatus, QlFsmEvent, ReceiveError, SessionWriteId}; +use crate::{state::LinkState, Event, PeerStatus, ReceiveError, WriteId}; const SLOT_COUNT: usize = 4; @@ -110,7 +110,7 @@ impl Action { #[derive(Clone, Debug)] struct TakenWrite { record: Vec, - write_id: Option, + write_id: Option, } #[derive(Default)] @@ -402,14 +402,14 @@ impl Runner { } } - fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { + fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { for event in events { match event { - QlFsmEvent::NewPeer => {} - QlFsmEvent::PeerStatusChanged(status) => { + Event::NewPeer => {} + Event::PeerStatusChanged(status) => { self.events[side.idx()].note_peer_status(status); } - QlFsmEvent::Opened { stream_id, .. } => { + Event::Opened { stream_id, .. } => { prop_assert!( self.known_streams.contains(&stream_id), "side {side:?} emitted Opened for unknown stream {stream_id:?}" @@ -419,13 +419,13 @@ impl Runner { "side {side:?} emitted duplicate Opened for {stream_id:?}" ); } - QlFsmEvent::Readable(stream_id) | QlFsmEvent::Writable(stream_id) => { + Event::Readable(stream_id) | Event::Writable(stream_id) => { prop_assert!( self.known_streams.contains(&stream_id), "side {side:?} emitted readiness for unknown stream {stream_id:?}" ); } - QlFsmEvent::Finished(stream_id) => { + Event::Finished(stream_id) => { prop_assert!( self.known_streams.contains(&stream_id), "side {side:?} emitted Finished for unknown stream {stream_id:?}" @@ -439,7 +439,7 @@ impl Runner { "side {side:?} emitted Finished after Closed for {stream_id:?}" ); } - QlFsmEvent::Closed(frame) => { + Event::Closed(frame) => { prop_assert!( self.known_streams.contains(&frame.stream_id), "side {side:?} emitted Closed for unknown stream {:?}", @@ -451,7 +451,7 @@ impl Runner { frame.stream_id ); } - QlFsmEvent::WritableClosed(frame) => { + Event::WritableClosed(frame) => { let stream_id = frame.stream_id; prop_assert!( self.known_streams.contains(&stream_id), @@ -462,7 +462,7 @@ impl Runner { "side {side:?} emitted duplicate WritableClosed for {stream_id:?}" ); } - QlFsmEvent::SessionClosed(_) => { + Event::SessionClosed(_) => { let state = &mut self.events[side.idx()]; prop_assert!( state.session_epoch > 0, diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 2c4fbb0e..229f3fe3 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -5,7 +5,7 @@ use ql_wire::{RouteId, SessionClose, StreamId, VarInt}; use super::*; use crate::{ - state::LinkState, CommitReadError, NoSessionError, PeerStatus, QlFsmEvent, StreamError, + state::LinkState, CommitReadError, NoSessionError, PeerStatus, Event, StreamError, }; fn stream_id(value: u32) -> StreamId { @@ -16,8 +16,8 @@ fn route_id(value: u32) -> RouteId { RouteId(VarInt::from_u32(value)) } -fn opened(stream_id: StreamId) -> QlFsmEvent { - QlFsmEvent::Opened { +fn opened(stream_id: StreamId) -> Event { + Event::Opened { stream_id, route_id: route_id(1), } @@ -82,7 +82,7 @@ fn connected_fsms_deliver_stream_data() { assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); assert_eq!( harness.take_event(Side::B), - Some(QlFsmEvent::Readable(stream_id)) + Some(Event::Readable(stream_id)) ); assert_eq!( read_stream_all(&mut harness.b.fsm, stream_id), @@ -90,7 +90,7 @@ fn connected_fsms_deliver_stream_data() { ); assert_eq!( harness.take_event(Side::B), - Some(QlFsmEvent::Finished(stream_id)) + Some(Event::Finished(stream_id)) ); } @@ -124,7 +124,7 @@ fn session_retransmit_uses_new_record_seq() { assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); assert_eq!( harness.take_event(Side::B), - Some(QlFsmEvent::Readable(stream_id)) + Some(Event::Readable(stream_id)) ); assert_eq!( read_stream_all(&mut harness.b.fsm, stream_id), @@ -167,7 +167,7 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { assert_eq!(harness.take_event(Side::A), Some(opened(stream_id_b))); assert_eq!( harness.take_event(Side::A), - Some(QlFsmEvent::Readable(stream_id_b)) + Some(Event::Readable(stream_id_b)) ); assert_eq!( read_stream_all(&mut harness.a.fsm, stream_id_b), @@ -176,7 +176,7 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { assert_eq!(harness.take_event(Side::B), Some(opened(stream_id_a))); assert_eq!( harness.take_event(Side::B), - Some(QlFsmEvent::Readable(stream_id_a)) + Some(Event::Readable(stream_id_a)) ); assert_eq!( read_stream_all(&mut harness.b.fsm, stream_id_a), @@ -275,7 +275,7 @@ fn returned_session_write_is_reissued_with_new_record_seq() { assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); assert_eq!( harness.take_event(Side::B), - Some(QlFsmEvent::Readable(stream_id)) + Some(Event::Readable(stream_id)) ); assert_eq!( read_stream_all(&mut harness.b.fsm, stream_id), @@ -338,7 +338,7 @@ fn ack_frame_releases_stream_capacity_and_emits_writable() { assert_eq!( harness.take_event(Side::A), - Some(QlFsmEvent::Writable(stream_id)) + Some(Event::Writable(stream_id)) ); } @@ -351,7 +351,7 @@ fn close_session_disconnects_locally() { .fsm .close_session(ql_wire::SessionCloseCode::CANCELLED); - assert!(matches!(harness.take_event(Side::A), Some(QlFsmEvent::SessionClosed(SessionClose { + assert!(matches!(harness.take_event(Side::A), Some(Event::SessionClosed(SessionClose { code: ql_wire::SessionCloseCode::CANCELLED, })))); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); @@ -367,7 +367,7 @@ fn close_session_disconnects_locally() { assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert_eq!( harness.take_event(Side::A), - Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) ); } @@ -438,7 +438,7 @@ fn session_timeout_emits_close_before_disconnect() { assert_eq!( harness.drain_events(Side::A), - vec![QlFsmEvent::SessionClosed(SessionClose { + vec![Event::SessionClosed(SessionClose { code: ql_wire::SessionCloseCode::TIMEOUT, })] ); @@ -447,6 +447,6 @@ fn session_timeout_emits_close_before_disconnect() { assert!(matches!(close.frames.as_slice(), [ql_wire::SessionFrame::Close(_)])); assert_eq!( harness.take_event(Side::A), - Some(QlFsmEvent::PeerStatusChanged(PeerStatus::Disconnected)) + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) ); } diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 928e9815..f6137ce5 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -14,7 +14,7 @@ use std::{ }; use futures_lite::future::poll_fn; -use ql_fsm::{FsmTime, QlFsm, QlFsmEvent, SessionWriteId}; +use ql_fsm::{Event, FsmTime, QlFsm, WriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; @@ -83,7 +83,7 @@ impl Runtime

{ } struct InFlightWrite { - session_write_id: Option, + session_write_id: Option, future: F, } @@ -228,11 +228,7 @@ impl DriverState { } } - fn drive_write_completed( - fsm: &mut QlFsm, - session_write_id: Option, - success: bool, - ) { + fn drive_write_completed(fsm: &mut QlFsm, session_write_id: Option, success: bool) { if let Some(write_id) = session_write_id { if success { fsm.confirm_session_write(now(), write_id); @@ -255,45 +251,44 @@ impl DriverState { output } - fn process_fsm_event( - &mut self, - fsm: &mut QlFsm, - platform: &P, - event: QlFsmEvent, - ) { + fn process_fsm_event(&mut self, fsm: &mut QlFsm, platform: &P, event: Event) { match event { - QlFsmEvent::NewPeer => { + Event::NewPeer => { if let Some(peer) = fsm.peer().cloned() { platform.persist_peer(peer); } } - QlFsmEvent::PeerStatusChanged(status) => { + Event::PeerStatusChanged(status) => { if let Some(peer) = fsm.peer().map(|peer| peer.xid) { platform.handle_peer_status(peer, status); } } - QlFsmEvent::Opened { + Event::Opened { stream_id, route_id, } => { self.handle_opened_stream(fsm, platform, stream_id, route_id); } - QlFsmEvent::Readable(stream_id) => { + Event::Readable(stream_id) => { self.handle_inbound_readable(fsm, stream_id); } - QlFsmEvent::Writable(stream_id) => { + Event::Writable(stream_id) => { self.poll_stream(fsm, stream_id); } - QlFsmEvent::Finished(stream_id) => { + Event::Finished(stream_id) => { self.handle_inbound_finished(fsm, stream_id); } - QlFsmEvent::Closed(frame) => { + Event::Closed(frame) => { self.handle_closed_stream(&frame); } - QlFsmEvent::WritableClosed(frame) => { + Event::WritableClosed(frame) => { self.handle_writable_closed(&frame); } - QlFsmEvent::SessionClosed(_) => self.fail_all_streams(), + Event::SessionClosed(_) => { + for (_, mut stream) in self.streams.drain() { + stream.fail_all(); + } + } } } @@ -444,13 +439,6 @@ impl DriverState { Self::try_reap_stream(entry); } - fn fail_all_streams(&mut self) { - for stream in self.streams.values_mut() { - stream.fail_all(); - } - self.streams.clear(); - } - fn fill_write_slots<'a, P: QlPlatform + 'a>( &self, fsm: &mut QlFsm, From 5ab6e19a68c0e7c64b7132bd4882a857e938586a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 22:10:26 -0400 Subject: [PATCH 179/304] ql-fsm: clean up time state --- ql-fsm/src/fsm.rs | 2 +- ql-fsm/src/lib.rs | 3 ++- ql-fsm/src/session/mod.rs | 47 +++++++++++++++--------------------- ql-fsm/src/session/state.rs | 1 - ql-fsm/src/tests/mod.rs | 3 ++- ql-runtime/src/driver/mod.rs | 2 +- 6 files changed, 25 insertions(+), 33 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index ec999d03..0912d745 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -135,7 +135,7 @@ pub fn confirm_session_write(fsm: &mut QlFsm, write_id: WriteId) { pub fn reject_session_write(fsm: &mut QlFsm, write_id: WriteId) { if let Some(state) = fsm.state.link.connected_mut() { - state.session.reject_write(write_id.0); + state.session.reject_write(fsm.state.now.instant, write_id.0); } } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 2cb32683..9401379f 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -266,7 +266,8 @@ impl QlFsm { /// reports that a `SessionWriteId` from `take_next_write` was not accepted /// /// call this at most once for each returned `SessionWriteId` - pub fn reject_session_write(&mut self, write_id: WriteId) { + pub fn reject_session_write(&mut self, now: FsmTime, write_id: WriteId) { + self.state.now = now; fsm::reject_session_write(self, write_id); } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 243b0d51..c17083d5 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -92,7 +92,6 @@ impl SessionFsm { Self { config, state: SessionState { - now, last_activity_at: now, last_inbound_at: now, phase: SessionPhase::Open, @@ -171,21 +170,20 @@ impl SessionFsm { ) where I: IntoIterator, WireError>>, { - self.state.now = now; - self.state.last_activity_at = self.state.now; - self.state.last_inbound_at = self.state.now; + self.state.last_activity_at = now; + self.state.last_inbound_at = now; if self.state.phase != SessionPhase::Open { return; } - self.collect_timeouts(); + self.collect_timeouts(now); let mut received_records = self.state.received_records.clone(); let out_of_order = match received_records.insert(seq) { ReceiveOutcome::TooOld => return, ReceiveOutcome::Duplicate => { - self.schedule_ack(true); + self.schedule_ack(now, true); return; } ReceiveOutcome::New { out_of_order } => out_of_order, @@ -232,12 +230,11 @@ impl SessionFsm { } if ack_eliciting { - self.schedule_ack(out_of_order); + self.schedule_ack(now, out_of_order); } } pub fn confirm_write(&mut self, now: Instant, write_id: u64) { - self.state.now = now; if !self.state.phase.is_open() { return; } @@ -251,7 +248,7 @@ impl SessionFsm { record.sent_at = Some(now); } - pub fn reject_write(&mut self, write_id: u64) { + pub fn reject_write(&mut self, now: Instant, write_id: u64) { if !self.state.phase.is_open() { return; } @@ -267,7 +264,7 @@ impl SessionFsm { return; }; restore_tracked_record( - self.state.now, + now, &mut self.state.ack_state, &mut self.state.pending_ping, &mut self.state.streams, @@ -276,20 +273,19 @@ impl SessionFsm { } pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { - self.state.now = now; if !self.state.phase.is_open() { return; } - self.collect_timeouts(); + self.collect_timeouts(now); if !self.config.peer_timeout.is_zero() - && self.state.last_inbound_at + self.config.peer_timeout <= self.state.now + && self.state.last_inbound_at + self.config.peer_timeout <= now { self.close(SessionCloseCode::TIMEOUT, &mut emit); return; } if self.state.phase == SessionPhase::Open && !self.config.keepalive_interval.is_zero() - && self.state.last_activity_at + self.config.keepalive_interval <= self.state.now + && self.state.last_activity_at + self.config.keepalive_interval <= now { self.state.pending_ping = true; } @@ -332,7 +328,6 @@ impl SessionFsm { } pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { - self.state.now = now; match &self.state.phase { SessionPhase::Closing(close) => { let seq = self.state.next_record_seq; @@ -347,9 +342,9 @@ impl SessionFsm { } SessionPhase::Open => {} } - self.collect_timeouts(); + self.collect_timeouts(now); - let (builder, outbound) = self.build_next_record()?; + let (builder, outbound) = self.build_next_record(now)?; let should_track = outbound.ping_included || !outbound.window_updates.is_empty() @@ -365,7 +360,7 @@ impl SessionFsm { Some((write_id, builder)) } - fn build_next_record(&mut self) -> Option<(SessionRecordBuilder, TrackedRecord)> { + fn build_next_record(&mut self, now: Instant) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); let mut outbound = TrackedRecord { @@ -389,7 +384,7 @@ impl SessionFsm { self.push_next_stream_data(&mut builder, &mut outbound); if let Some((ack, due_at)) = self.pending_ack() { - if (!builder.is_empty() || due_at <= self.state.now) && builder.push_ack(&ack) { + if (!builder.is_empty() || due_at <= now) && builder.push_ack(&ack) { outbound.ack_included = true; self.state.ack_state = AckState::Idle; } @@ -549,14 +544,10 @@ impl SessionFsm { self.reap_reapable_streams(); } - fn schedule_ack(&mut self, immediate: bool) { + fn schedule_ack(&mut self, now: Instant, immediate: bool) { schedule_ack( &mut self.state.ack_state, - if immediate { - self.state.now - } else { - self.state.now + self.config.ack_delay - }, + if immediate { now } else { now + self.config.ack_delay }, ); } @@ -569,15 +560,15 @@ impl SessionFsm { } } - fn collect_timeouts(&mut self) { + fn collect_timeouts(&mut self, now: Instant) { let retransmit_timeout = self.config.retransmit_timeout; for (_, record) in self.state.tracked_records.extract_if(.., |_, record| { record .sent_at - .is_some_and(|sent_at| sent_at + retransmit_timeout <= self.state.now) + .is_some_and(|sent_at| sent_at + retransmit_timeout <= now) }) { restore_tracked_record( - self.state.now, + now, &mut self.state.ack_state, &mut self.state.pending_ping, &mut self.state.streams, diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 65a3950f..21443f77 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -9,7 +9,6 @@ use super::{ }; pub struct SessionState { - pub now: Instant, pub last_activity_at: Instant, pub last_inbound_at: Instant, pub phase: SessionPhase, diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index c2daffff..1cde6750 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -227,7 +227,8 @@ impl Harness { } fn reject_write(&mut self, side: Side, write_id: WriteId) { - self.node_mut(side).fsm.reject_session_write(write_id); + let time = self.time(); + self.node_mut(side).fsm.reject_session_write(time, write_id); } fn decode_session_write(&self, write: OutboundWrite, side: Side) -> DecodedSessionWrite { diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index f6137ce5..92dcffc9 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -233,7 +233,7 @@ impl DriverState { if success { fsm.confirm_session_write(now(), write_id); } else { - fsm.reject_session_write(write_id); + fsm.reject_session_write(now(), write_id); } } } From 01428a875fac09b9d72402e7160f4936bce84635 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 22:19:52 -0400 Subject: [PATCH 180/304] ql: clean up complete_write --- ql-fsm/src/fsm.rs | 14 +++------ ql-fsm/src/lib.rs | 33 +++++++------------ ql-fsm/src/session/mod.rs | 61 +++++++++++++++++------------------- ql-fsm/src/session/tests.rs | 10 ++++-- ql-fsm/src/tests/mod.rs | 14 ++++----- ql-fsm/src/tests/proptest.rs | 2 +- ql-runtime/src/driver/mod.rs | 8 ++--- 7 files changed, 62 insertions(+), 80 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 0912d745..7f77af69 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -105,7 +105,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Option, - /// write handle that must be confirmed or rejected - pub session_write_id: Option, + /// write handle that must be completed exactly once + pub write_id: Option, } /// timing and buffering knobs for `QlFsm` @@ -224,17 +224,17 @@ impl QlFsm { fsm::receive(self, bytes, crypto) } + /// returns the next queued event, if any + pub fn poll_event(&mut self) -> Option { + self.pending_events.pop_front() + } + /// advances time-based state pub fn on_timer(&mut self, now: FsmTime) { self.state.now = now; fsm::on_timer(self); } - /// returns the next queued event, if any - pub fn poll_event(&mut self) -> Option { - self.pending_events.pop_front() - } - /// returns the next timer deadline, if any pub fn next_deadline(&self) -> Option { fsm::next_deadline(self) @@ -242,8 +242,7 @@ impl QlFsm { /// returns the next outbound record /// - /// if `session_write_id` is `Some`, call exactly one of - /// `confirm_session_write` or `reject_session_write` + /// if `write_id` is `Some`, call `complete_write` exactly once /// /// if it is `None`, the record is fire-and-forget pub fn take_next_write( @@ -255,20 +254,12 @@ impl QlFsm { fsm::take_next_write(self, crypto) } - /// marks a `SessionWriteId` from `take_next_write` as handed to the transport - /// - /// call this at most once for each returned `SessionWriteId` - pub fn confirm_session_write(&mut self, now: FsmTime, write_id: WriteId) { - self.state.now = now; - fsm::confirm_session_write(self, write_id); - } - - /// reports that a `SessionWriteId` from `take_next_write` was not accepted + /// completes a `SessionWriteId` from `take_next_write` with the transport outcome /// /// call this at most once for each returned `SessionWriteId` - pub fn reject_session_write(&mut self, now: FsmTime, write_id: WriteId) { + pub fn complete_write(&mut self, now: FsmTime, write_id: WriteId, success: bool) { self.state.now = now; - fsm::reject_session_write(self, write_id); + fsm::complete_write(self, write_id, success); } /// closes the current encrypted session locally diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index c17083d5..2d617ba1 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -234,42 +234,39 @@ impl SessionFsm { } } - pub fn confirm_write(&mut self, now: Instant, write_id: u64) { + pub fn complete_write(&mut self, now: Instant, write_id: u64, success: bool) { if !self.state.phase.is_open() { return; } - let Some(record) = self.state.tracked_records.get_mut(&write_id) else { - return; - }; - if record.sent_at.is_some() { - return; - } - self.state.last_activity_at = now; - record.sent_at = Some(now); - } - - pub fn reject_write(&mut self, now: Instant, write_id: u64) { - if !self.state.phase.is_open() { - return; - } - if self - .state - .tracked_records - .get(&write_id) - .is_some_and(|record| record.sent_at.is_some()) - { - return; + if success { + let Some(record) = self.state.tracked_records.get_mut(&write_id) else { + return; + }; + if record.sent_at.is_some() { + return; + } + self.state.last_activity_at = now; + record.sent_at = Some(now); + } else { + if self + .state + .tracked_records + .get(&write_id) + .is_some_and(|record| record.sent_at.is_some()) + { + return; + } + let Some(record) = self.state.tracked_records.shift_remove(&write_id) else { + return; + }; + restore_tracked_record( + now, + &mut self.state.ack_state, + &mut self.state.pending_ping, + &mut self.state.streams, + record, + ); } - let Some(record) = self.state.tracked_records.shift_remove(&write_id) else { - return; - }; - restore_tracked_record( - now, - &mut self.state.ack_state, - &mut self.state.pending_ping, - &mut self.state.streams, - record, - ); } pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 2665696c..ff208654 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -7,7 +7,7 @@ use ql_wire::{ StreamId, VarInt, XID, }; -use super::{SessionEvent, SessionFsm, SessionConfig}; +use super::{SessionConfig, SessionEvent, SessionFsm}; use crate::session::stream_parity::StreamParity; fn seq(value: u64) -> RecordSeq { @@ -63,7 +63,7 @@ fn next_outbound( ) -> Option<(RecordSeq, Vec>>)> { let (write_id, builder) = fsm.take_next_write(now)?; if let Some(write_id) = write_id { - fsm.confirm_write(now, write_id); + fsm.complete_write(now, write_id, true); } Some(( builder.seq(), @@ -303,7 +303,11 @@ fn remote_stream_close_is_reliable_and_retried() { .close(CloseTarget::Both, StreamCloseCode(0)); let (write_id, builder) = fsm.take_next_write(now).unwrap(); - fsm.confirm_write(now, write_id.expect("stream close should be tracked")); + fsm.complete_write( + now, + write_id.expect("stream close should be tracked"), + true, + ); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(matches!( first.as_slice(), diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 1cde6750..3ce2c1f7 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -170,7 +170,7 @@ impl Harness { fn next_outbound(&mut self, side: Side) -> Option> { let write = self.next_write(side)?; - if let Some(id) = write.session_write_id { + if let Some(id) = write.write_id { self.confirm_write(side, id); } Some(write.record) @@ -184,7 +184,7 @@ impl Harness { fn next_decoded_outbound(&mut self, side: Side) -> Option { let write = self.next_write(side)?; - if let Some(id) = write.session_write_id { + if let Some(id) = write.write_id { self.confirm_write(side, id); } Some(self.decode_session_write(write, side)) @@ -221,14 +221,14 @@ impl Harness { fn confirm_write(&mut self, side: Side, write_id: WriteId) { let time = self.time(); - self.node_mut(side) - .fsm - .confirm_session_write(time, write_id); + self.node_mut(side).fsm.complete_write(time, write_id, true); } fn reject_write(&mut self, side: Side, write_id: WriteId) { let time = self.time(); - self.node_mut(side).fsm.reject_session_write(time, write_id); + self.node_mut(side) + .fsm + .complete_write(time, write_id, false); } fn decode_session_write(&self, write: OutboundWrite, side: Side) -> DecodedSessionWrite { @@ -241,7 +241,7 @@ impl Harness { let (header, frames) = decrypt_record(crypto, &write.record, session_key); DecodedSessionWrite { record: write.record, - write_id: write.session_write_id, + write_id: write.write_id, header, frames, } diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index fa997a9c..fc6d702b 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -666,7 +666,7 @@ fn take_unconfirmed_outbound(harness: &mut Harness, side: Side) -> Option, success: bool) { if let Some(write_id) = session_write_id { - if success { - fsm.confirm_session_write(now(), write_id); - } else { - fsm.reject_session_write(now(), write_id); - } + fsm.complete_write(now(), write_id, success); } } @@ -450,7 +446,7 @@ impl DriverState { break; }; in_flight.push(InFlightWrite { - session_write_id: write.session_write_id, + session_write_id: write.write_id, future: platform.write_message(write.record), }); } From 71bdf12be714d1833e054dd9e816a7f5c26ae433 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 9 Apr 2026 22:54:34 -0400 Subject: [PATCH 181/304] ql-fsm: dirty_timer state --- ql-fsm/src/fsm.rs | 48 +++++++++++++++++++++++++++++++++--- ql-fsm/src/handshake/mod.rs | 16 +++++++----- ql-fsm/src/lib.rs | 16 ++++++------ ql-fsm/src/session/mod.rs | 34 ++++++++++++++++++++++--- ql-fsm/src/state.rs | 7 ++++++ ql-fsm/src/tests/mod.rs | 7 +++++- ql-fsm/src/tests/proptest.rs | 1 + 7 files changed, 106 insertions(+), 23 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 7f77af69..ee4770c4 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -4,14 +4,38 @@ use bytes::Bytes; use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; use crate::{ - handshake, session::SessionEvent, state::LinkState, Event, NoSessionError, OutboundWrite, - QlFsm, ReceiveError, StreamError, StreamOps, WriteId, + handshake, session::SessionEvent, state::LinkState, Event, NoPeerError, NoSessionError, + OutboundWrite, QlFsm, ReceiveError, StreamError, StreamOps, WriteId, }; pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { fsm.state.handshake = None; fsm.state.link = LinkState::Idle; fsm.state.peer = Some(peer); + fsm.state.mark_timer_dirty(); +} + +pub fn handle_disarm_pairing(fsm: &mut QlFsm) { + fsm.state.armed_pairing_token = None; + handshake::handle_disarm_pairing(fsm); + fsm.state.mark_timer_dirty(); +} + +pub fn handle_connect_xx(fsm: &mut QlFsm, token: ql_wire::PairingToken, crypto: &impl QlCrypto) { + handshake::handle_connect_xx(fsm, token, crypto); + fsm.state.mark_timer_dirty(); +} + +pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + handshake::handle_connect_ik(fsm, crypto).inspect(|_| { + fsm.state.mark_timer_dirty(); + }) +} + +pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + handshake::handle_connect_kk(fsm, crypto).inspect(|_| { + fsm.state.mark_timer_dirty(); + }) } pub fn receive( @@ -29,7 +53,9 @@ pub fn receive( match header.record_type { wire::RecordType::Handshake => { let record = wire::QlHandshakeRecord::decode(&mut reader)?; - handshake::handle_handshake_record(fsm, crypto, &record) + handshake::handle_handshake_record(fsm, crypto, &record).inspect(|_| { + fsm.state.mark_timer_dirty(); + }) } wire::RecordType::Session => { let state = fsm @@ -65,13 +91,16 @@ pub fn receive( if state.session.is_closed() { apply_session_closed(fsm); } + fsm.state.mark_timer_dirty(); Ok(()) } } } pub fn on_timer(fsm: &mut QlFsm) { - handshake::handle_timer(fsm); + if handshake::handle_timer(fsm) { + fsm.state.mark_timer_dirty(); + } let Some(state) = fsm.state.link.connected_mut() else { return; @@ -158,6 +187,16 @@ pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { state.session.queue_ping() } +pub fn poll_event(fsm: &mut QlFsm) -> Option { + fsm.pending_events.pop_front().or_else(|| { + let mut timer_dirty = std::mem::take(&mut fsm.state.timer_dirty); + if let Some(state) = fsm.state.link.connected_mut() { + timer_dirty |= state.session.take_timer_dirty(); + } + timer_dirty.then_some(Event::TimerDirty) + }) +} + pub fn emit_peer_status(fsm: &mut QlFsm) { if fsm.state.peer.is_some() { fsm.pending_events @@ -203,6 +242,7 @@ fn forward_session_event( fn apply_session_closed(fsm: &mut QlFsm) { if matches!(fsm.state.link, crate::state::LinkState::Connected(_)) { fsm.state.link = crate::state::LinkState::Idle; + fsm.state.mark_timer_dirty(); emit_peer_status(fsm); } } diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index 8146be08..f0c5125c 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -84,17 +84,21 @@ pub fn handle_handshake_record( } } -pub fn handle_timer(fsm: &mut QlFsm) { - let Some(deadline) = fsm.state.link.handshake_deadline() else { - return; - }; - if deadline > fsm.state.now.instant { - return; +pub fn handle_timer(fsm: &mut QlFsm) -> bool { + let expired = fsm + .state + .link + .handshake_deadline() + .is_some_and(|d| d <= fsm.state.now.instant); + + if !expired { + return false; } fsm.state.link = LinkState::Idle; fsm.state.handshake = None; emit_peer_status(fsm); + true } pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index f2eb563c..3be7c5ac 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -68,6 +68,8 @@ pub enum PeerStatus { /// events emitted by `QlFsm` #[derive(Debug, Clone, PartialEq, Eq)] pub enum Event { + /// timer-related state changed; recompute the next deadline + TimerDirty, /// a peer was learned during handshake completion NewPeer, /// the peer changed connection state @@ -147,9 +149,7 @@ impl Default for QlFsmConfig { /// synchronous driver for peer binding, handshake, and encrypted streams pub struct QlFsm { - /// active configuration config: QlFsmConfig, - /// local identity and private keys identity: QlIdentity, state: QlFsmState, pending_events: VecDeque, @@ -169,6 +169,7 @@ impl QlFsm { handshake: None, link: LinkState::Idle, now, + timer_dirty: false, }, pending_events: VecDeque::new(), } @@ -191,26 +192,25 @@ impl QlFsm { /// disarms inbound xx pairing and rejects any in-flight inbound xx responder state pub fn disarm_pairing(&mut self) { - self.state.armed_pairing_token = None; - handshake::handle_disarm_pairing(self); + fsm::handle_disarm_pairing(self); } /// starts an outbound xx handshake using the supplied pairing token pub fn connect_xx(&mut self, now: FsmTime, token: PairingToken, crypto: &impl QlCrypto) { self.state.now = now; - handshake::handle_connect_xx(self, token, crypto); + fsm::handle_connect_xx(self, token, crypto); } /// starts an IK handshake with the currently bound peer pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { self.state.now = now; - handshake::handle_connect_ik(self, crypto) + fsm::handle_connect_ik(self, crypto) } /// starts a KK handshake with the currently bound peer pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { self.state.now = now; - handshake::handle_connect_kk(self, crypto) + fsm::handle_connect_kk(self, crypto) } /// handles one inbound wire message @@ -226,7 +226,7 @@ impl QlFsm { /// returns the next queued event, if any pub fn poll_event(&mut self) -> Option { - self.pending_events.pop_front() + fsm::poll_event(self) } /// advances time-based state diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 2d617ba1..c00fcddd 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -80,6 +80,7 @@ pub enum SessionEvent { pub struct SessionFsm { config: SessionConfig, state: SessionState, + timer_dirty: bool, } impl SessionFsm { @@ -106,6 +107,7 @@ impl SessionFsm { next_stream_index: 0, remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), }, + timer_dirty: false, } } @@ -140,7 +142,10 @@ impl SessionFsm { pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { self.ensure_session_open()?; - self.state.pending_ping = true; + if !self.state.pending_ping { + self.state.pending_ping = true; + self.timer_dirty = true; + } Ok(()) } @@ -154,6 +159,7 @@ impl SessionFsm { self.state.tracked_records.clear(); self.state.ack_state = AckState::Idle; self.clear_streams(); + self.timer_dirty = true; emit(SessionEvent::SessionClosed(close)); } @@ -177,6 +183,7 @@ impl SessionFsm { return; } + self.timer_dirty = true; self.collect_timeouts(now); let mut received_records = self.state.received_records.clone(); @@ -238,6 +245,7 @@ impl SessionFsm { if !self.state.phase.is_open() { return; } + self.timer_dirty = true; if success { let Some(record) = self.state.tracked_records.get_mut(&write_id) else { return; @@ -273,6 +281,7 @@ impl SessionFsm { if !self.state.phase.is_open() { return; } + self.timer_dirty = true; self.collect_timeouts(now); if !self.config.peer_timeout.is_zero() && self.state.last_inbound_at + self.config.peer_timeout <= now @@ -332,6 +341,7 @@ impl SessionFsm { let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); assert!(builder.push_close(&close), "builder has capacity"); self.state.phase = SessionPhase::Closed; + self.timer_dirty = true; return Some((None, builder)); } SessionPhase::Closed => { @@ -339,13 +349,14 @@ impl SessionFsm { } SessionPhase::Open => {} } - self.collect_timeouts(now); + let timeouts_changed = self.collect_timeouts(now); let (builder, outbound) = self.build_next_record(now)?; let should_track = outbound.ping_included || !outbound.window_updates.is_empty() || !outbound.frames.is_empty(); + let timer_changed = timeouts_changed || outbound.ack_included || outbound.ping_included; let write_id = should_track.then(|| { let write_id = self.state.next_write_id; @@ -354,9 +365,17 @@ impl SessionFsm { write_id }); + if timer_changed { + self.timer_dirty = true; + } + Some((write_id, builder)) } + pub fn take_timer_dirty(&mut self) -> bool { + std::mem::take(&mut self.timer_dirty) + } + fn build_next_record(&mut self, now: Instant) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); @@ -544,7 +563,11 @@ impl SessionFsm { fn schedule_ack(&mut self, now: Instant, immediate: bool) { schedule_ack( &mut self.state.ack_state, - if immediate { now } else { now + self.config.ack_delay }, + if immediate { + now + } else { + now + self.config.ack_delay + }, ); } @@ -557,13 +580,15 @@ impl SessionFsm { } } - fn collect_timeouts(&mut self, now: Instant) { + fn collect_timeouts(&mut self, now: Instant) -> bool { let retransmit_timeout = self.config.retransmit_timeout; + let mut changed = false; for (_, record) in self.state.tracked_records.extract_if(.., |_, record| { record .sent_at .is_some_and(|sent_at| sent_at + retransmit_timeout <= now) }) { + changed = true; restore_tracked_record( now, &mut self.state.ack_state, @@ -572,6 +597,7 @@ impl SessionFsm { record, ); } + changed } fn handle_stream_data( diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 79a8c5ee..db9711b9 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -15,6 +15,13 @@ pub struct QlFsmState { pub handshake: Option, pub link: LinkState, pub now: FsmTime, + pub timer_dirty: bool, +} + +impl QlFsmState { + pub fn mark_timer_dirty(&mut self) { + self.timer_dirty = true; + } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 3ce2c1f7..93e4a3a7 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -253,7 +253,12 @@ impl Harness { } fn take_event(&mut self, side: Side) -> Option { - self.node_mut(side).fsm.poll_event() + loop { + match self.node_mut(side).fsm.poll_event() { + Some(Event::TimerDirty) => {} + event => return event, + } + } } fn drain_events(&mut self, side: Side) -> Vec { diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index fc6d702b..cb41664d 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -405,6 +405,7 @@ impl Runner { fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { for event in events { match event { + Event::TimerDirty => {} Event::NewPeer => {} Event::PeerStatusChanged(status) => { self.events[side.idx()].note_peer_status(status); From 4d0392ccb388a929f7b07a40ada804008b34fe31 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 05:10:45 -0400 Subject: [PATCH 182/304] ql-runtime: driver simplification --- ql-runtime/src/driver/mod.rs | 119 ++++++++++++++++++----------------- 1 file changed, 62 insertions(+), 57 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 7e0ce09d..17f46985 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -54,8 +54,7 @@ impl Runtime

{ let mut recv_future = pin!(recv_future); loop { - state.fill_write_slots(&mut fsm, &platform, &mut in_flight); - timer.set_deadline(fsm.next_deadline()); + state.drain_events(&mut fsm, &platform, &mut timer, &mut in_flight); match next_driver_event(recv_future.as_mut(), &mut timer, &mut in_flight).await { DriverEvent::Command(command) => { @@ -63,14 +62,10 @@ impl Runtime

{ } DriverEvent::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); - state.with_fsm_events(&mut fsm, &platform, |fsm| { - DriverState::drive_write_completed(fsm, write.session_write_id, success) - }) + DriverState::drive_write_completed(&mut fsm, write.session_write_id, success); } DriverEvent::TimerExpired => { - state.with_fsm_events(&mut fsm, &platform, |fsm| { - fsm.on_timer(now()); - }); + fsm.on_timer(now()); } DriverEvent::CommandsClosed => { if in_flight.is_empty() { @@ -118,7 +113,7 @@ where recv_future .as_mut() .poll(cx) - .map(|res| res.map_or_else(|_| DriverEvent::CommandsClosed, DriverEvent::Command)) + .map(|res| res.map_or(DriverEvent::CommandsClosed, DriverEvent::Command)) }) .await } @@ -135,7 +130,7 @@ impl DriverState { fsm.bind_peer(peer); } RuntimeCommand::Connect => { - let _ = self.with_fsm_events(fsm, platform, |fsm| fsm.connect_ik(now(), platform)); + let _ = fsm.connect_ik(now(), platform); } RuntimeCommand::ArmPairing { token } => { fsm.arm_pairing(token); @@ -144,12 +139,10 @@ impl DriverState { fsm.disarm_pairing(); } RuntimeCommand::StartPairing { token } => { - self.with_fsm_events(fsm, platform, |fsm| fsm.connect_xx(now(), token, platform)); + fsm.connect_xx(now(), token, platform); } RuntimeCommand::Receive(bytes) => { - if let Err(e) = - self.with_fsm_events(fsm, platform, |fsm| fsm.receive(now(), bytes, platform)) - { + if let Err(e) = fsm.receive(now(), bytes, platform) { platform.handle_recv_error(e); } } @@ -195,6 +188,7 @@ impl DriverState { stream.outbound_close(); } stream_ops.close(CloseTarget::Both, StreamCloseCode(0)); + drop(stream_ops); return; } drop(stream_ops); @@ -234,58 +228,66 @@ impl DriverState { } } - fn with_fsm_events( + fn drain_events<'a, P: QlPlatform + 'a>( &mut self, fsm: &mut QlFsm, - platform: &P, - run: impl FnOnce(&mut QlFsm) -> T, - ) -> T { - let output = run(fsm); - while let Some(event) = fsm.poll_event() { - self.process_fsm_event(fsm, platform, event); + platform: &'a P, + timer: &mut impl QlTimer, + in_flight: &mut Vec>>, + ) { + let mut timer_dirty = self.drain_fsm_events(fsm, platform); + if self.fill_write_slots(fsm, platform, in_flight) { + timer_dirty |= self.drain_fsm_events(fsm, platform); + } + if timer_dirty { + timer.set_deadline(fsm.next_deadline()); } - output } - fn process_fsm_event(&mut self, fsm: &mut QlFsm, platform: &P, event: Event) { - match event { - Event::NewPeer => { - if let Some(peer) = fsm.peer().cloned() { - platform.persist_peer(peer); + fn drain_fsm_events(&mut self, fsm: &mut QlFsm, platform: &P) -> bool { + let mut timer_dirty = false; + while let Some(event) = fsm.poll_event() { + match event { + Event::TimerDirty => timer_dirty = true, + Event::NewPeer => { + if let Some(peer) = fsm.peer().cloned() { + platform.persist_peer(peer); + } } - } - Event::PeerStatusChanged(status) => { - if let Some(peer) = fsm.peer().map(|peer| peer.xid) { - platform.handle_peer_status(peer, status); + Event::PeerStatusChanged(status) => { + if let Some(peer) = fsm.peer().map(|peer| peer.xid) { + platform.handle_peer_status(peer, status); + } } - } - Event::Opened { - stream_id, - route_id, - } => { - self.handle_opened_stream(fsm, platform, stream_id, route_id); - } - Event::Readable(stream_id) => { - self.handle_inbound_readable(fsm, stream_id); - } - Event::Writable(stream_id) => { - self.poll_stream(fsm, stream_id); - } - Event::Finished(stream_id) => { - self.handle_inbound_finished(fsm, stream_id); - } - Event::Closed(frame) => { - self.handle_closed_stream(&frame); - } - Event::WritableClosed(frame) => { - self.handle_writable_closed(&frame); - } - Event::SessionClosed(_) => { - for (_, mut stream) in self.streams.drain() { - stream.fail_all(); + Event::Opened { + stream_id, + route_id, + } => { + self.handle_opened_stream(fsm, platform, stream_id, route_id); + } + Event::Readable(stream_id) => { + self.handle_inbound_readable(fsm, stream_id); + } + Event::Writable(stream_id) => { + self.poll_stream(fsm, stream_id); + } + Event::Finished(stream_id) => { + self.handle_inbound_finished(fsm, stream_id); + } + Event::Closed(frame) => { + self.handle_closed_stream(&frame); + } + Event::WritableClosed(frame) => { + self.handle_writable_closed(&frame); + } + Event::SessionClosed(_) => { + for (_, mut stream) in self.streams.drain() { + stream.fail_all(); + } } } } + timer_dirty } fn handle_opened_stream( @@ -440,16 +442,19 @@ impl DriverState { fsm: &mut QlFsm, platform: &'a P, in_flight: &mut Vec>>, - ) { + ) -> bool { + let mut filled = false; while in_flight.len() < self.max_concurrent_message_writes { let Some(write) = fsm.take_next_write(now(), platform) else { break; }; + filled = true; in_flight.push(InFlightWrite { session_write_id: write.write_id, future: platform.write_message(write.record), }); } + filled } fn poll_stream(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { From d75fad39c2a530c33f71dc6da20ea5212235fbac Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 05:45:45 -0400 Subject: [PATCH 183/304] ql-fsm: dirty bit cleanup --- ql-fsm/src/fsm.rs | 123 ++++++++++++++++++------------------ ql-fsm/src/handshake/mod.rs | 2 +- ql-fsm/src/lib.rs | 4 +- ql-fsm/src/session/mod.rs | 38 ++++++----- ql-fsm/src/session/tests.rs | 16 +++-- 5 files changed, 98 insertions(+), 85 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index ee4770c4..e93f215c 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -1,4 +1,7 @@ -use std::time::{Duration, Instant}; +use std::{ + collections::VecDeque, + time::{Duration, Instant}, +}; use bytes::Bytes; use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; @@ -58,21 +61,18 @@ pub fn receive( }) } wire::RecordType::Session => { - let state = fsm - .state - .link - .connected_mut() - .ok_or(ReceiveError::NoSession)?; + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut().ok_or(ReceiveError::NoSession)?; let (decrypt_len, seq) = { let record = wire::QlSessionRecord::decode(&mut reader)?; - if record.header.connection_id != state.transport.rx_connection_id { + if record.header.connection_id != conn.transport.rx_connection_id { return Err(ReceiveError::InvalidPayload); } let payload = wire::decrypt_record( crypto, &record.header, record.payload, - &state.transport.rx_key, + &conn.transport.rx_key, )?; (payload.len(), record.header.seq) }; @@ -81,14 +81,12 @@ pub fn receive( let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); let frames = wire::parse_session_frames(plaintext); - state.session.receive(fsm.state.now.instant, seq, frames, { - let pending_events = &mut fsm.pending_events; - |event| { - forward_session_event(event, pending_events); - } - }); + conn.session + .receive(state.now.instant, seq, frames, |event| { + forward_session_event(event, events); + }); - if state.session.is_closed() { + if conn.session.is_closed() { apply_session_closed(fsm); } fsm.state.mark_timer_dirty(); @@ -102,16 +100,16 @@ pub fn on_timer(fsm: &mut QlFsm) { fsm.state.mark_timer_dirty(); } - let Some(state) = fsm.state.link.connected_mut() else { + let QlFsm { state, events, .. } = fsm; + let Some(conn) = state.link.connected_mut() else { return; }; - let pending_events = &mut fsm.pending_events; - state.session.on_timer(fsm.state.now.instant, |event| { - forward_session_event(event, pending_events); + conn.session.on_timer(state.now.instant, |event| { + forward_session_event(event, events); }); - if state.session.is_closed() { + if conn.session.is_closed() { apply_session_closed(fsm); } } @@ -138,14 +136,18 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Option Result, NoSessionError> { - let state = fsm.state.link.connected_mut_or_err()?; - state.session.open_stream(route_id) + let conn = fsm.state.link.connected_mut_or_err()?; + conn.session.open_stream(route_id) } pub fn stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { - let state = fsm.state.link.connected_mut_or_err()?; - state.session.stream(stream_id) + let conn = fsm.state.link.connected_mut_or_err()?; + conn.session.stream(stream_id) } pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { - let state = fsm.state.link.connected_mut_or_err()?; - state.session.queue_ping() + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + conn.session + .queue_ping(|event| forward_session_event(event, events))?; + Ok(()) } pub fn poll_event(fsm: &mut QlFsm) -> Option { - fsm.pending_events.pop_front().or_else(|| { - let mut timer_dirty = std::mem::take(&mut fsm.state.timer_dirty); - if let Some(state) = fsm.state.link.connected_mut() { - timer_dirty |= state.session.take_timer_dirty(); - } - timer_dirty.then_some(Event::TimerDirty) - }) + fsm.events + .pop_front() + .or_else(|| std::mem::take(&mut fsm.state.timer_dirty).then_some(Event::TimerDirty)) } pub fn emit_peer_status(fsm: &mut QlFsm) { if fsm.state.peer.is_some() { - fsm.pending_events + fsm.events .push_back(Event::PeerStatusChanged(fsm.state.link.status())); } } -fn forward_session_event( - event: SessionEvent, - pending_events: &mut std::collections::VecDeque, -) { +fn forward_session_event(event: SessionEvent, events: &mut VecDeque) { match event { + SessionEvent::TimerDirty => { + events.push_back(Event::TimerDirty); + } SessionEvent::Opened { stream_id, route_id, } => { - pending_events.push_back(Event::Opened { + events.push_back(Event::Opened { stream_id, route_id, }); } SessionEvent::Readable(stream_id) => { - pending_events.push_back(Event::Readable(stream_id)); + events.push_back(Event::Readable(stream_id)); } SessionEvent::Writable(stream_id) => { - pending_events.push_back(Event::Writable(stream_id)); + events.push_back(Event::Writable(stream_id)); } SessionEvent::Finished(stream_id) => { - pending_events.push_back(Event::Finished(stream_id)); + events.push_back(Event::Finished(stream_id)); } SessionEvent::Closed(frame) => { - pending_events.push_back(Event::Closed(frame)); + events.push_back(Event::Closed(frame)); } SessionEvent::WritableClosed(frame) => { - pending_events.push_back(Event::WritableClosed(frame)); + events.push_back(Event::WritableClosed(frame)); } SessionEvent::SessionClosed(close) => { - pending_events.push_back(Event::SessionClosed(close)); + events.push_back(Event::SessionClosed(close)); } } } fn apply_session_closed(fsm: &mut QlFsm) { - if matches!(fsm.state.link, crate::state::LinkState::Connected(_)) { - fsm.state.link = crate::state::LinkState::Idle; + if matches!(fsm.state.link, LinkState::Connected(_)) { + fsm.state.link = LinkState::Idle; fsm.state.mark_timer_dirty(); emit_peer_status(fsm); } diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index f0c5125c..8f33f67b 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -117,7 +117,7 @@ pub fn finish_handshake( } } else { fsm.state.peer = Some(remote_bundle); - fsm.pending_events.push_back(Event::NewPeer); + fsm.events.push_back(Event::NewPeer); } let config = &fsm.config; diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 3be7c5ac..72e42746 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -152,7 +152,7 @@ pub struct QlFsm { config: QlFsmConfig, identity: QlIdentity, state: QlFsmState, - pending_events: VecDeque, + events: VecDeque, } impl QlFsm { @@ -171,7 +171,7 @@ impl QlFsm { now, timer_dirty: false, }, - pending_events: VecDeque::new(), + events: VecDeque::new(), } } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index c00fcddd..670f3f63 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -65,6 +65,7 @@ impl Default for SessionConfig { #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionEvent { + TimerDirty, Opened { stream_id: StreamId, route_id: RouteId, @@ -80,7 +81,6 @@ pub enum SessionEvent { pub struct SessionFsm { config: SessionConfig, state: SessionState, - timer_dirty: bool, } impl SessionFsm { @@ -107,7 +107,6 @@ impl SessionFsm { next_stream_index: 0, remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), }, - timer_dirty: false, } } @@ -140,11 +139,11 @@ impl SessionFsm { Ok(StreamOps::new(self, stream_id, stream_index)) } - pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { + pub fn queue_ping(&mut self, mut emit: impl FnMut(SessionEvent)) -> Result<(), NoSessionError> { self.ensure_session_open()?; if !self.state.pending_ping { self.state.pending_ping = true; - self.timer_dirty = true; + emit(SessionEvent::TimerDirty); } Ok(()) } @@ -159,7 +158,7 @@ impl SessionFsm { self.state.tracked_records.clear(); self.state.ack_state = AckState::Idle; self.clear_streams(); - self.timer_dirty = true; + emit(SessionEvent::TimerDirty); emit(SessionEvent::SessionClosed(close)); } @@ -183,7 +182,7 @@ impl SessionFsm { return; } - self.timer_dirty = true; + emit(SessionEvent::TimerDirty); self.collect_timeouts(now); let mut received_records = self.state.received_records.clone(); @@ -241,11 +240,16 @@ impl SessionFsm { } } - pub fn complete_write(&mut self, now: Instant, write_id: u64, success: bool) { + pub fn complete_write( + &mut self, + now: Instant, + write_id: u64, + success: bool, + mut emit: impl FnMut(SessionEvent), + ) { if !self.state.phase.is_open() { return; } - self.timer_dirty = true; if success { let Some(record) = self.state.tracked_records.get_mut(&write_id) else { return; @@ -255,6 +259,7 @@ impl SessionFsm { } self.state.last_activity_at = now; record.sent_at = Some(now); + emit(SessionEvent::TimerDirty); } else { if self .state @@ -274,6 +279,7 @@ impl SessionFsm { &mut self.state.streams, record, ); + emit(SessionEvent::TimerDirty); } } @@ -281,7 +287,7 @@ impl SessionFsm { if !self.state.phase.is_open() { return; } - self.timer_dirty = true; + emit(SessionEvent::TimerDirty); self.collect_timeouts(now); if !self.config.peer_timeout.is_zero() && self.state.last_inbound_at + self.config.peer_timeout <= now @@ -333,7 +339,11 @@ impl SessionFsm { .min() } - pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { + pub fn take_next_write( + &mut self, + now: Instant, + mut emit: impl FnMut(SessionEvent), + ) -> Option<(Option, SessionRecordBuilder)> { match &self.state.phase { SessionPhase::Closing(close) => { let seq = self.state.next_record_seq; @@ -341,7 +351,7 @@ impl SessionFsm { let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); assert!(builder.push_close(&close), "builder has capacity"); self.state.phase = SessionPhase::Closed; - self.timer_dirty = true; + emit(SessionEvent::TimerDirty); return Some((None, builder)); } SessionPhase::Closed => { @@ -366,16 +376,12 @@ impl SessionFsm { }); if timer_changed { - self.timer_dirty = true; + emit(SessionEvent::TimerDirty); } Some((write_id, builder)) } - pub fn take_timer_dirty(&mut self) -> bool { - std::mem::take(&mut self.timer_dirty) - } - fn build_next_record(&mut self, now: Instant) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index ff208654..15e2e364 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -61,9 +61,9 @@ fn next_outbound( fsm: &mut SessionFsm, now: Instant, ) -> Option<(RecordSeq, Vec>>)> { - let (write_id, builder) = fsm.take_next_write(now)?; + let (write_id, builder) = fsm.take_next_write(now, |_| {})?; if let Some(write_id) = write_id { - fsm.complete_write(now, write_id, true); + fsm.complete_write(now, write_id, true, |_| {}); } Some(( builder.seq(), @@ -86,6 +86,9 @@ fn receive_events( let mut events = Vec::new(); fsm.receive(now, seq, frames, |event| events.push(event)); events + .into_iter() + .filter(|event| !matches!(event, SessionEvent::TimerDirty)) + .collect() } #[test] @@ -213,7 +216,7 @@ fn commit_stream_read_is_what_advances_stream_window() { vec![opened(stream_id), SessionEvent::Readable(stream_id)] ); - let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1), |_| {}).unwrap(); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); @@ -256,14 +259,14 @@ fn pure_ack_only_records_are_fire_and_forget() { let _ = receive_events(&mut fsm, now, seq(7), &record); - let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1), |_| {}).unwrap(); let ack = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); fsm.on_timer(now + retransmit_timeout + Duration::from_millis(1), |_| {}); assert!(fsm - .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) + .take_next_write(now + retransmit_timeout + Duration::from_millis(1), |_| {}) .is_none()); } @@ -302,11 +305,12 @@ fn remote_stream_close_is_reliable_and_retried() { .unwrap() .close(CloseTarget::Both, StreamCloseCode(0)); - let (write_id, builder) = fsm.take_next_write(now).unwrap(); + let (write_id, builder) = fsm.take_next_write(now, |_| {}).unwrap(); fsm.complete_write( now, write_id.expect("stream close should be tracked"), true, + |_| {}, ); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(matches!( From fb2b7617ab6fe781babd2109038787299ce47ad2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 06:17:28 -0400 Subject: [PATCH 184/304] ql-runtime: DriverEvent -> DriverStep --- ql-runtime/src/driver/mod.rs | 51 ++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 17f46985..81c29975 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -9,10 +9,11 @@ use std::{ }, future::Future, pin::{pin, Pin}, - task::Poll, + task::{Context, Poll}, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; +use async_channel::Recv; use futures_lite::future::poll_fn; use ql_fsm::{Event, FsmTime, QlFsm, WriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; @@ -56,18 +57,21 @@ impl Runtime

{ loop { state.drain_events(&mut fsm, &platform, &mut timer, &mut in_flight); - match next_driver_event(recv_future.as_mut(), &mut timer, &mut in_flight).await { - DriverEvent::Command(command) => { + let step = + poll_fn(|cx| next_step(cx, recv_future.as_mut(), &mut timer, &mut in_flight)).await; + + match step { + DriverStep::Command(command) => { state.drive_command(&mut fsm, command, &platform); } - DriverEvent::WriteCompleted { index, success } => { + DriverStep::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); DriverState::drive_write_completed(&mut fsm, write.session_write_id, success); } - DriverEvent::TimerExpired => { + DriverStep::TimerExpired => { fsm.on_timer(now()); } - DriverEvent::CommandsClosed => { + DriverStep::CommandsClosed => { if in_flight.is_empty() { break; } @@ -82,40 +86,37 @@ struct InFlightWrite { future: F, } -enum DriverEvent { +enum DriverStep { Command(RuntimeCommand), WriteCompleted { index: usize, success: bool }, TimerExpired, CommandsClosed, } -#[allow(clippy::future_not_send)] -async fn next_driver_event( - mut recv_future: Pin<&mut async_channel::Recv<'_, RuntimeCommand>>, +fn next_step( + cx: &mut Context<'_>, + mut recv_future: Pin<&mut Recv<'_, RuntimeCommand>>, timer: &mut T, in_flight: &mut [InFlightWrite], -) -> DriverEvent +) -> Poll where T: QlTimer, F: Future + Unpin, { - poll_fn(|cx| { - for (index, write) in in_flight.iter_mut().enumerate() { - if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { - return Poll::Ready(DriverEvent::WriteCompleted { index, success }); - } + for (index, write) in in_flight.iter_mut().enumerate() { + if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { + return Poll::Ready(DriverStep::WriteCompleted { index, success }); } + } - if timer.poll_wait(cx) == Poll::Ready(()) { - return Poll::Ready(DriverEvent::TimerExpired); - } + if timer.poll_wait(cx) == Poll::Ready(()) { + return Poll::Ready(DriverStep::TimerExpired); + } - recv_future - .as_mut() - .poll(cx) - .map(|res| res.map_or(DriverEvent::CommandsClosed, DriverEvent::Command)) - }) - .await + recv_future + .as_mut() + .poll(cx) + .map(|res| res.map_or(DriverStep::CommandsClosed, DriverStep::Command)) } impl DriverState { From 3b3f8c2bb98fd933c2df5013512ec5dbcbd62782 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 06:54:34 -0400 Subject: [PATCH 185/304] ql: revert dirty timer bit --- ql-fsm/src/fsm.rs | 47 +++++++----------------------- ql-fsm/src/handshake/mod.rs | 16 ++++------- ql-fsm/src/lib.rs | 3 -- ql-fsm/src/session/mod.rs | 54 ++++++++--------------------------- ql-fsm/src/session/tests.rs | 22 +++++--------- ql-fsm/src/state.rs | 7 ----- ql-fsm/src/tests/handshake.rs | 2 +- ql-fsm/src/tests/mod.rs | 7 +---- ql-fsm/src/tests/proptest.rs | 1 - ql-fsm/src/tests/session.rs | 23 +++++++++------ ql-runtime/src/driver/mod.rs | 33 ++++++--------------- 11 files changed, 62 insertions(+), 153 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index e93f215c..60b1074a 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -15,30 +15,23 @@ pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { fsm.state.handshake = None; fsm.state.link = LinkState::Idle; fsm.state.peer = Some(peer); - fsm.state.mark_timer_dirty(); } pub fn handle_disarm_pairing(fsm: &mut QlFsm) { fsm.state.armed_pairing_token = None; handshake::handle_disarm_pairing(fsm); - fsm.state.mark_timer_dirty(); } pub fn handle_connect_xx(fsm: &mut QlFsm, token: ql_wire::PairingToken, crypto: &impl QlCrypto) { handshake::handle_connect_xx(fsm, token, crypto); - fsm.state.mark_timer_dirty(); } pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { - handshake::handle_connect_ik(fsm, crypto).inspect(|_| { - fsm.state.mark_timer_dirty(); - }) + handshake::handle_connect_ik(fsm, crypto) } pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { - handshake::handle_connect_kk(fsm, crypto).inspect(|_| { - fsm.state.mark_timer_dirty(); - }) + handshake::handle_connect_kk(fsm, crypto) } pub fn receive( @@ -56,9 +49,7 @@ pub fn receive( match header.record_type { wire::RecordType::Handshake => { let record = wire::QlHandshakeRecord::decode(&mut reader)?; - handshake::handle_handshake_record(fsm, crypto, &record).inspect(|_| { - fsm.state.mark_timer_dirty(); - }) + handshake::handle_handshake_record(fsm, crypto, &record) } wire::RecordType::Session => { let QlFsm { state, events, .. } = fsm; @@ -89,16 +80,13 @@ pub fn receive( if conn.session.is_closed() { apply_session_closed(fsm); } - fsm.state.mark_timer_dirty(); Ok(()) } } } pub fn on_timer(fsm: &mut QlFsm) { - if handshake::handle_timer(fsm) { - fsm.state.mark_timer_dirty(); - } + handshake::handle_timer(fsm); let QlFsm { state, events, .. } = fsm; let Some(conn) = state.link.connected_mut() else { @@ -136,12 +124,10 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Option Result, Str } pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { - let QlFsm { state, events, .. } = fsm; - let conn = state.link.connected_mut_or_err()?; - conn.session - .queue_ping(|event| forward_session_event(event, events))?; - Ok(()) + let conn = fsm.state.link.connected_mut_or_err()?; + conn.session.queue_ping() } pub fn poll_event(fsm: &mut QlFsm) -> Option { - fsm.events - .pop_front() - .or_else(|| std::mem::take(&mut fsm.state.timer_dirty).then_some(Event::TimerDirty)) + fsm.events.pop_front() } pub fn emit_peer_status(fsm: &mut QlFsm) { @@ -209,9 +188,6 @@ pub fn emit_peer_status(fsm: &mut QlFsm) { fn forward_session_event(event: SessionEvent, events: &mut VecDeque) { match event { - SessionEvent::TimerDirty => { - events.push_back(Event::TimerDirty); - } SessionEvent::Opened { stream_id, route_id, @@ -245,7 +221,6 @@ fn forward_session_event(event: SessionEvent, events: &mut VecDeque) { fn apply_session_closed(fsm: &mut QlFsm) { if matches!(fsm.state.link, LinkState::Connected(_)) { fsm.state.link = LinkState::Idle; - fsm.state.mark_timer_dirty(); emit_peer_status(fsm); } } diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index 8f33f67b..a4b05e92 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -84,21 +84,17 @@ pub fn handle_handshake_record( } } -pub fn handle_timer(fsm: &mut QlFsm) -> bool { - let expired = fsm - .state - .link - .handshake_deadline() - .is_some_and(|d| d <= fsm.state.now.instant); - - if !expired { - return false; +pub fn handle_timer(fsm: &mut QlFsm) { + let Some(deadline) = fsm.state.link.handshake_deadline() else { + return; + }; + if deadline > fsm.state.now.instant { + return; } fsm.state.link = LinkState::Idle; fsm.state.handshake = None; emit_peer_status(fsm); - true } pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 72e42746..58ad12c7 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -68,8 +68,6 @@ pub enum PeerStatus { /// events emitted by `QlFsm` #[derive(Debug, Clone, PartialEq, Eq)] pub enum Event { - /// timer-related state changed; recompute the next deadline - TimerDirty, /// a peer was learned during handshake completion NewPeer, /// the peer changed connection state @@ -169,7 +167,6 @@ impl QlFsm { handshake: None, link: LinkState::Idle, now, - timer_dirty: false, }, events: VecDeque::new(), } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 670f3f63..8c18cc46 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -65,7 +65,6 @@ impl Default for SessionConfig { #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionEvent { - TimerDirty, Opened { stream_id: StreamId, route_id: RouteId, @@ -139,12 +138,9 @@ impl SessionFsm { Ok(StreamOps::new(self, stream_id, stream_index)) } - pub fn queue_ping(&mut self, mut emit: impl FnMut(SessionEvent)) -> Result<(), NoSessionError> { + pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { self.ensure_session_open()?; - if !self.state.pending_ping { - self.state.pending_ping = true; - emit(SessionEvent::TimerDirty); - } + self.state.pending_ping = true; Ok(()) } @@ -158,7 +154,6 @@ impl SessionFsm { self.state.tracked_records.clear(); self.state.ack_state = AckState::Idle; self.clear_streams(); - emit(SessionEvent::TimerDirty); emit(SessionEvent::SessionClosed(close)); } @@ -182,7 +177,6 @@ impl SessionFsm { return; } - emit(SessionEvent::TimerDirty); self.collect_timeouts(now); let mut received_records = self.state.received_records.clone(); @@ -240,13 +234,7 @@ impl SessionFsm { } } - pub fn complete_write( - &mut self, - now: Instant, - write_id: u64, - success: bool, - mut emit: impl FnMut(SessionEvent), - ) { + pub fn complete_write(&mut self, now: Instant, write_id: u64, success: bool) { if !self.state.phase.is_open() { return; } @@ -259,7 +247,6 @@ impl SessionFsm { } self.state.last_activity_at = now; record.sent_at = Some(now); - emit(SessionEvent::TimerDirty); } else { if self .state @@ -279,7 +266,6 @@ impl SessionFsm { &mut self.state.streams, record, ); - emit(SessionEvent::TimerDirty); } } @@ -287,7 +273,6 @@ impl SessionFsm { if !self.state.phase.is_open() { return; } - emit(SessionEvent::TimerDirty); self.collect_timeouts(now); if !self.config.peer_timeout.is_zero() && self.state.last_inbound_at + self.config.peer_timeout <= now @@ -321,13 +306,12 @@ impl SessionFsm { .map(|sent_at| sent_at + self.config.retransmit_timeout) }) .min(); - let keepalive_deadline = (self.state.phase == SessionPhase::Open - && !self.config.keepalive_interval.is_zero() - && !self.state.pending_ping) - .then_some(self.state.last_activity_at + self.config.keepalive_interval); - let peer_timeout_deadline = (self.state.phase == SessionPhase::Open - && !self.config.peer_timeout.is_zero()) - .then_some(self.state.last_inbound_at + self.config.peer_timeout); + let is_open = self.state.phase.is_open(); + let keepalive_deadline = + (is_open && !self.config.keepalive_interval.is_zero() && !self.state.pending_ping) + .then_some(self.state.last_activity_at + self.config.keepalive_interval); + let peer_timeout_deadline = (is_open && !self.config.peer_timeout.is_zero()) + .then_some(self.state.last_inbound_at + self.config.peer_timeout); [ ack_deadline, retransmit_deadline, @@ -339,11 +323,7 @@ impl SessionFsm { .min() } - pub fn take_next_write( - &mut self, - now: Instant, - mut emit: impl FnMut(SessionEvent), - ) -> Option<(Option, SessionRecordBuilder)> { + pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { match &self.state.phase { SessionPhase::Closing(close) => { let seq = self.state.next_record_seq; @@ -351,7 +331,6 @@ impl SessionFsm { let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); assert!(builder.push_close(&close), "builder has capacity"); self.state.phase = SessionPhase::Closed; - emit(SessionEvent::TimerDirty); return Some((None, builder)); } SessionPhase::Closed => { @@ -359,15 +338,13 @@ impl SessionFsm { } SessionPhase::Open => {} } - let timeouts_changed = self.collect_timeouts(now); + self.collect_timeouts(now); let (builder, outbound) = self.build_next_record(now)?; let should_track = outbound.ping_included || !outbound.window_updates.is_empty() || !outbound.frames.is_empty(); - let timer_changed = timeouts_changed || outbound.ack_included || outbound.ping_included; - let write_id = should_track.then(|| { let write_id = self.state.next_write_id; self.state.next_write_id = self.state.next_write_id.wrapping_add(1); @@ -375,10 +352,6 @@ impl SessionFsm { write_id }); - if timer_changed { - emit(SessionEvent::TimerDirty); - } - Some((write_id, builder)) } @@ -586,15 +559,13 @@ impl SessionFsm { } } - fn collect_timeouts(&mut self, now: Instant) -> bool { + fn collect_timeouts(&mut self, now: Instant) { let retransmit_timeout = self.config.retransmit_timeout; - let mut changed = false; for (_, record) in self.state.tracked_records.extract_if(.., |_, record| { record .sent_at .is_some_and(|sent_at| sent_at + retransmit_timeout <= now) }) { - changed = true; restore_tracked_record( now, &mut self.state.ack_state, @@ -603,7 +574,6 @@ impl SessionFsm { record, ); } - changed } fn handle_stream_data( diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 15e2e364..017bd91d 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -61,9 +61,9 @@ fn next_outbound( fsm: &mut SessionFsm, now: Instant, ) -> Option<(RecordSeq, Vec>>)> { - let (write_id, builder) = fsm.take_next_write(now, |_| {})?; + let (write_id, builder) = fsm.take_next_write(now)?; if let Some(write_id) = write_id { - fsm.complete_write(now, write_id, true, |_| {}); + fsm.complete_write(now, write_id, true); } Some(( builder.seq(), @@ -86,9 +86,6 @@ fn receive_events( let mut events = Vec::new(); fsm.receive(now, seq, frames, |event| events.push(event)); events - .into_iter() - .filter(|event| !matches!(event, SessionEvent::TimerDirty)) - .collect() } #[test] @@ -216,7 +213,7 @@ fn commit_stream_read_is_what_advances_stream_window() { vec![opened(stream_id), SessionEvent::Readable(stream_id)] ); - let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1), |_| {}).unwrap(); + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); @@ -259,14 +256,14 @@ fn pure_ack_only_records_are_fire_and_forget() { let _ = receive_events(&mut fsm, now, seq(7), &record); - let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1), |_| {}).unwrap(); + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); let ack = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); fsm.on_timer(now + retransmit_timeout + Duration::from_millis(1), |_| {}); assert!(fsm - .take_next_write(now + retransmit_timeout + Duration::from_millis(1), |_| {}) + .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) .is_none()); } @@ -305,13 +302,8 @@ fn remote_stream_close_is_reliable_and_retried() { .unwrap() .close(CloseTarget::Both, StreamCloseCode(0)); - let (write_id, builder) = fsm.take_next_write(now, |_| {}).unwrap(); - fsm.complete_write( - now, - write_id.expect("stream close should be tracked"), - true, - |_| {}, - ); + let (write_id, builder) = fsm.take_next_write(now).unwrap(); + fsm.complete_write(now, write_id.expect("stream close should be tracked"), true); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(matches!( first.as_slice(), diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index db9711b9..79a8c5ee 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -15,13 +15,6 @@ pub struct QlFsmState { pub handshake: Option, pub link: LinkState, pub now: FsmTime, - pub timer_dirty: bool, -} - -impl QlFsmState { - pub fn mark_timer_dirty(&mut self) { - self.timer_dirty = true; - } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index edf02a03..374558fe 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ql_wire::QlHandshakeRecord; use super::*; -use crate::{state::LinkState, NoPeerError, PeerStatus, Event}; +use crate::{state::LinkState, Event, NoPeerError, PeerStatus}; #[test] fn ik_connect_round_trip_establishes_transport() { diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 93e4a3a7..3ce2c1f7 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -253,12 +253,7 @@ impl Harness { } fn take_event(&mut self, side: Side) -> Option { - loop { - match self.node_mut(side).fsm.poll_event() { - Some(Event::TimerDirty) => {} - event => return event, - } - } + self.node_mut(side).fsm.poll_event() } fn drain_events(&mut self, side: Side) -> Vec { diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index cb41664d..fc6d702b 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -405,7 +405,6 @@ impl Runner { fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { for event in events { match event { - Event::TimerDirty => {} Event::NewPeer => {} Event::PeerStatusChanged(status) => { self.events[side.idx()].note_peer_status(status); diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 229f3fe3..36faee20 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -4,9 +4,7 @@ use bytes::Bytes; use ql_wire::{RouteId, SessionClose, StreamId, VarInt}; use super::*; -use crate::{ - state::LinkState, CommitReadError, NoSessionError, PeerStatus, Event, StreamError, -}; +use crate::{state::LinkState, CommitReadError, Event, NoSessionError, PeerStatus, StreamError}; fn stream_id(value: u32) -> StreamId { StreamId(VarInt::from_u32(value)) @@ -351,9 +349,12 @@ fn close_session_disconnects_locally() { .fsm .close_session(ql_wire::SessionCloseCode::CANCELLED); - assert!(matches!(harness.take_event(Side::A), Some(Event::SessionClosed(SessionClose { - code: ql_wire::SessionCloseCode::CANCELLED, - })))); + assert!(matches!( + harness.take_event(Side::A), + Some(Event::SessionClosed(SessionClose { + code: ql_wire::SessionCloseCode::CANCELLED, + })) + )); assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); assert!(matches!( harness.a.fsm.open_stream(route_id(1)), @@ -362,7 +363,10 @@ fn close_session_disconnects_locally() { assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); let close = harness.next_decoded_outbound(Side::A).unwrap(); - assert!(matches!(close.frames.as_slice(), [ql_wire::SessionFrame::Close(_)])); + assert!(matches!( + close.frames.as_slice(), + [ql_wire::SessionFrame::Close(_)] + )); assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); assert_eq!( @@ -444,7 +448,10 @@ fn session_timeout_emits_close_before_disconnect() { ); let close = harness.next_decoded_outbound(Side::A).unwrap(); - assert!(matches!(close.frames.as_slice(), [ql_wire::SessionFrame::Close(_)])); + assert!(matches!( + close.frames.as_slice(), + [ql_wire::SessionFrame::Close(_)] + )); assert_eq!( harness.take_event(Side::A), Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 81c29975..a070dedd 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -55,7 +55,11 @@ impl Runtime

{ let mut recv_future = pin!(recv_future); loop { - state.drain_events(&mut fsm, &platform, &mut timer, &mut in_flight); + state.drain_fsm_events(&mut fsm, &platform); + if state.fill_write_slots(&mut fsm, &platform, &mut in_flight) { + state.drain_fsm_events(&mut fsm, &platform); + } + timer.set_deadline(fsm.next_deadline()); let step = poll_fn(|cx| next_step(cx, recv_future.as_mut(), &mut timer, &mut in_flight)).await; @@ -71,7 +75,7 @@ impl Runtime

{ DriverStep::TimerExpired => { fsm.on_timer(now()); } - DriverStep::CommandsClosed => { + DriverStep::Closed => { if in_flight.is_empty() { break; } @@ -90,7 +94,7 @@ enum DriverStep { Command(RuntimeCommand), WriteCompleted { index: usize, success: bool }, TimerExpired, - CommandsClosed, + Closed, } fn next_step( @@ -116,7 +120,7 @@ where recv_future .as_mut() .poll(cx) - .map(|res| res.map_or(DriverStep::CommandsClosed, DriverStep::Command)) + .map(|res| res.map_or(DriverStep::Closed, DriverStep::Command)) } impl DriverState { @@ -229,27 +233,9 @@ impl DriverState { } } - fn drain_events<'a, P: QlPlatform + 'a>( - &mut self, - fsm: &mut QlFsm, - platform: &'a P, - timer: &mut impl QlTimer, - in_flight: &mut Vec>>, - ) { - let mut timer_dirty = self.drain_fsm_events(fsm, platform); - if self.fill_write_slots(fsm, platform, in_flight) { - timer_dirty |= self.drain_fsm_events(fsm, platform); - } - if timer_dirty { - timer.set_deadline(fsm.next_deadline()); - } - } - - fn drain_fsm_events(&mut self, fsm: &mut QlFsm, platform: &P) -> bool { - let mut timer_dirty = false; + fn drain_fsm_events(&mut self, fsm: &mut QlFsm, platform: &P) { while let Some(event) = fsm.poll_event() { match event { - Event::TimerDirty => timer_dirty = true, Event::NewPeer => { if let Some(peer) = fsm.peer().cloned() { platform.persist_peer(peer); @@ -288,7 +274,6 @@ impl DriverState { } } } - timer_dirty } fn handle_opened_stream( From f67e2b8bf833a8265579f23737fa16487338c1c5 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 09:18:26 -0400 Subject: [PATCH 186/304] ql-fsm: granular recv errors --- ql-fsm/src/error.rs | 4 ++++ ql-fsm/src/handshake/xx.rs | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 8470cd1f..b0681ca4 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -13,6 +13,8 @@ pub enum ReceiveError { DecryptFailed, InvalidXid, NoSession, + InvalidPairingToken, + Replay, } impl Display for ReceiveError { @@ -24,6 +26,8 @@ impl Display for ReceiveError { Self::DecryptFailed => "decryption failed", Self::InvalidXid => "invalid xid", Self::NoSession => "no active session", + Self::InvalidPairingToken => "invalid pairing token", + Self::Replay => "replay", }; f.write_str(message) } diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index 5e1de43f..329e6090 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -38,10 +38,10 @@ pub fn handle_xx1( return Ok(()); } if is_replayed_handshake_start(fsm, message.meta) { - return Ok(()); + return Err(ReceiveError::Replay); } if fsm.state.armed_pairing_token != Some(message.header.pairing_token) { - return Ok(()); + return Err(ReceiveError::InvalidPairingToken); } reset_connected_session_if_needed(fsm); From 9d5a9fa03fe8ddb0914812eeff2a028ad2c4ad02 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 09:58:44 -0400 Subject: [PATCH 187/304] ql-fsm: pairing_token api --- ql-fsm/src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 58ad12c7..799d13df 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -187,6 +187,10 @@ impl QlFsm { self.state.armed_pairing_token = Some(token); } + pub fn pairing_token(&self) -> Option<&PairingToken> { + self.state.armed_pairing_token.as_ref() + } + /// disarms inbound xx pairing and rejects any in-flight inbound xx responder state pub fn disarm_pairing(&mut self) { fsm::handle_disarm_pairing(self); From e6c35386a5737e7bbe36ba9137f8457a9896dbce Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 10:08:04 -0400 Subject: [PATCH 188/304] surface pairing token errors --- ql-fsm/src/error.rs | 35 ++++++++++++++++++++++------------- ql-fsm/src/handshake/xx.rs | 11 +++++++++-- ql-fsm/src/tests/handshake.rs | 29 ++++++++++++++++++++++++++--- ql-wire/src/handshake/mod.rs | 10 ++++++++++ 4 files changed, 67 insertions(+), 18 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index b0681ca4..88563ea8 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -3,7 +3,7 @@ use std::{ fmt::{Display, Formatter}, }; -use ql_wire::WireError; +use ql_wire::{PairingToken, WireError}; #[derive(Debug, Clone, PartialEq, Eq)] pub enum ReceiveError { @@ -13,23 +13,32 @@ pub enum ReceiveError { DecryptFailed, InvalidXid, NoSession, - InvalidPairingToken, + NotPairingMode, + InvalidPairingToken { + expected: PairingToken, + actual: PairingToken, + }, Replay, } impl Display for ReceiveError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let message = match self { - Self::InvalidPayload => "invalid payload", - Self::InvalidState => "invalid state", - Self::Expired => "expired", - Self::DecryptFailed => "decryption failed", - Self::InvalidXid => "invalid xid", - Self::NoSession => "no active session", - Self::InvalidPairingToken => "invalid pairing token", - Self::Replay => "replay", - }; - f.write_str(message) + match self { + Self::InvalidPayload => f.write_str("invalid payload"), + Self::InvalidState => f.write_str("invalid state"), + Self::Expired => f.write_str("expired"), + Self::DecryptFailed => f.write_str("decryption failed"), + Self::InvalidXid => f.write_str("invalid xid"), + Self::NoSession => f.write_str("no active session"), + Self::NotPairingMode => f.write_str("not in pairing mode"), + Self::InvalidPairingToken { expected, actual } => { + write!( + f, + "invalid pairing token: expected {expected}, actual {actual}" + ) + } + Self::Replay => f.write_str("replay"), + } } } diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index 329e6090..b95879a3 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -40,8 +40,15 @@ pub fn handle_xx1( if is_replayed_handshake_start(fsm, message.meta) { return Err(ReceiveError::Replay); } - if fsm.state.armed_pairing_token != Some(message.header.pairing_token) { - return Err(ReceiveError::InvalidPairingToken); + match fsm.state.armed_pairing_token { + Some(expected) if expected != message.header.pairing_token => { + return Err(ReceiveError::InvalidPairingToken { + expected, + actual: message.header.pairing_token, + }); + } + Some(_) => {} + None => return Err(ReceiveError::NotPairingMode), } reset_connected_session_if_needed(fsm); diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 374558fe..d492757c 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -3,7 +3,7 @@ use std::time::Duration; use ql_wire::QlHandshakeRecord; use super::*; -use crate::{state::LinkState, Event, NoPeerError, PeerStatus}; +use crate::{state::LinkState, Event, NoPeerError, PeerStatus, ReceiveError}; #[test] fn ik_connect_round_trip_establishes_transport() { @@ -119,19 +119,42 @@ fn connect_ik_emits_initiator_status() { } #[test] -fn inbound_xx1_ignored_when_pairing_token_not_armed() { +fn inbound_xx1_rejects_when_not_in_pairing_mode() { let mut harness = Harness::paired(QlFsmConfig::default(), false, false); let token = pairing_token(3); harness.connect_xx(Side::A, token); let xx1 = harness.next_outbound(Side::A).unwrap(); - harness.deliver(Side::B, xx1); + let time = harness.time(); + let Node { fsm, crypto } = &mut harness.b; + let err = fsm.receive(time, xx1, crypto); + assert_eq!(err, Err(ReceiveError::NotPairingMode)); assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); assert!(harness.drain_events(Side::B).is_empty()); assert!(harness.next_outbound(Side::B).is_none()); } +#[test] +fn inbound_xx1_rejects_mismatched_pairing_token_with_expected_and_actual() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let expected = pairing_token(4); + let actual = pairing_token(7); + + harness.b.fsm.arm_pairing(expected); + harness.connect_xx(Side::A, actual); + let xx1 = harness.next_outbound(Side::A).unwrap(); + + let time = harness.time(); + let Node { fsm, crypto } = &mut harness.b; + let err = fsm.receive(time, xx1, crypto); + + assert_eq!( + err, + Err(ReceiveError::InvalidPairingToken { expected, actual }) + ); +} + #[test] fn disarm_pairing_rejects_inflight_inbound_xx_responder() { let mut harness = Harness::paired(QlFsmConfig::default(), false, false); diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index fbb1db1e..9311617e 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -61,6 +61,15 @@ impl PairingToken { pub const SIZE: usize = 16; } +impl std::fmt::Display for PairingToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for byte in self.0 { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + impl WireEncode for PairingToken { fn encoded_len(&self) -> usize { Self::SIZE @@ -77,6 +86,7 @@ impl codec::WireDecode for PairingToken { } } +// TODO: this should not be exposed #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct XxHeader { pub pairing_token: PairingToken, From fa510803311d51d9b1e71c6ce3d1228962a40cfb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 10:35:25 -0400 Subject: [PATCH 189/304] disable expiration --- ql-wire/src/handshake/meta.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs index 52e3e870..4987b03c 100644 --- a/ql-wire/src/handshake/meta.rs +++ b/ql-wire/src/handshake/meta.rs @@ -29,12 +29,9 @@ impl WireEncode for HandshakeId { impl HandshakeMeta { pub const WIRE_SIZE: usize = size_of::() + size_of::(); - pub fn ensure_not_expired(&self, now_seconds: u64) -> Result<(), WireError> { - if now_seconds > self.valid_until { - Err(WireError::Expired) - } else { - Ok(()) - } + // TODO: re-think expiration + pub fn ensure_not_expired(&self, _now_seconds: u64) -> Result<(), WireError> { + Ok(()) } } From 478a0bbb7c2cad1d83109cf930572c568995d3f2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 13:07:09 -0400 Subject: [PATCH 190/304] ql-runtime: add open stream lane test --- ql-runtime/src/tests/stream.rs | 41 ++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 4c114f2e..85f63c8c 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -236,6 +236,47 @@ async fn dropping_inbound_reader_cancels_remote_writer() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn closing_initiator_reader_preserves_initiator_writer() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let (done_tx, done_rx) = async_channel::bounded(1); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let request = read_all(stream.reader).await.unwrap(); + done_tx.send(request).await.unwrap(); + }); + + let stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + let mut writer = stream.writer; + stream.reader.close(StreamCloseCode(0)); + + writer.write(Bytes::from_static(&[1, 2])).await.unwrap(); + writer.write(Bytes::from_static(&[3, 4])).await.unwrap(); + writer.finish(); + + let request = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(request, vec![1, 2, 3, 4]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn max_concurrent_message_writes_is_respected() { run_local_test(async { From bb3aa34a90bc889eb82f0da2c35ff68bba1323f2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 15:35:50 -0400 Subject: [PATCH 191/304] ql-runtime: fix event rpc --- ql-runtime/src/rpc/mod.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index aa684fd0..cee8908d 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -10,7 +10,6 @@ use ql_rpc::{ request::{self, Request as RequestRpc}, request_with_progress::{self as rpc_request_with_progress, RequestWithProgress}, subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, - RpcError, }; use ql_wire::{RouteId, VarInt}; @@ -29,13 +28,11 @@ impl RpcHandle { { let mut payload = Vec::new(); notification::encode_event::(event, &mut payload).map_err(RpcCallError::Codec)?; - let response = self.start_request(M::METHOD, payload).await?; - let response = read_all(response).await?; - if response.is_empty() { - Ok(()) - } else { - Err(RpcCallError::Rpc(RpcError::TrailingBytes)) - } + let route_id = RouteId(VarInt::from_u32(M::METHOD.0)); + let mut stream = self.inner.open_stream(route_id).await?; + stream.reader.close(ql_wire::StreamCloseCode(0)); + stream.writer.write(Bytes::from(payload)).await?; + Ok(()) } pub async fn request( From 8717b35347183de4f33ed2409a881fe08eabc99a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 10 Apr 2026 16:11:45 -0400 Subject: [PATCH 192/304] ql-rpc: incremental decoding --- ql-rpc/src/codec.rs | 108 ++++++++++++++++++++++- ql-rpc/src/lib.rs | 1 + ql-rpc/src/rpc/notification.rs | 55 ++---------- ql-rpc/src/rpc/request.rs | 78 +++-------------- ql-rpc/src/rpc/request_with_progress.rs | 112 +++--------------------- ql-rpc/src/rpc/subscription.rs | 66 ++------------ ql-runtime/src/rpc/mod.rs | 27 ++++-- ql-runtime/src/tests/mod.rs | 4 + ql-runtime/src/tests/rpc.rs | 41 +++++---- 9 files changed, 190 insertions(+), 302 deletions(-) diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index f0caece3..94b63875 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -1,8 +1,8 @@ -use std::collections::VecDeque; +use std::{collections::VecDeque, marker::PhantomData}; use bytes::{Buf, BufMut, Bytes}; -use crate::{RpcCodec, RpcError}; +use crate::{RpcCodec, RpcCodecError, RpcError}; const LENGTH_SIZE: usize = 8; @@ -16,6 +16,50 @@ pub fn encode_value_part>( Ok(()) } +pub enum ReadValueStep { + NeedMore(ValueReader), + Value(T), +} + +pub struct ValueReader { + bytes: ChunkQueue, + marker: PhantomData T>, +} + +impl Default for ValueReader { + fn default() -> Self { + Self::new() + } +} + +impl ValueReader { + pub fn new() -> Self { + Self { + bytes: ChunkQueue::new(), + marker: PhantomData, + } + } + + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); + self + } + + pub fn advance(self) -> Result, RpcCodecError> { + let mut this = self; + let Some(mut body) = this.bytes.try_take_part().map_err(RpcCodecError::Rpc)? else { + return Ok(ReadValueStep::NeedMore(this)); + }; + + let value = T::decode_value(&mut body).map_err(RpcCodecError::Codec)?; + drop(body); + if this.bytes.remaining() > 0 { + return Err(RpcCodecError::Rpc(RpcError::TrailingBytes)); + } + Ok(ReadValueStep::Value(value)) + } +} + #[derive(Debug, Default)] pub struct ChunkQueue { chunks: VecDeque, @@ -224,3 +268,63 @@ pub fn backpatch_length + ?Sized>(out: &mut B, start: usize) { let payload_len = u64::try_from(payload_len).expect("rpc payload exceeds u64 length framing"); out[start..payload_start].copy_from_slice(&payload_len.to_le_bytes()); } + +#[cfg(test)] +mod tests { + use bytes::{Buf, BufMut, Bytes}; + + use super::{encode_value_part, ReadValueStep, ValueReader}; + use crate::RpcCodec; + + #[derive(Debug, Clone, PartialEq, Eq)] + struct BytesValue(Vec); + + impl RpcCodec for BytesValue { + type Error = core::convert::Infallible; + + fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + out.put_slice(&self.0); + Ok(()) + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) + } + } + + #[test] + fn value_reader_round_trips_framed_values() { + let mut encoded = Vec::new(); + encode_value_part(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); + + match ValueReader::::new() + .push(Bytes::from(encoded)) + .advance() + .unwrap() + { + ReadValueStep::Value(value) => assert_eq!(value, BytesValue(b"hello".to_vec())), + _ => unreachable!(), + } + } + + #[test] + fn value_reader_waits_for_complete_frame() { + let mut encoded = Vec::new(); + encode_value_part(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); + let encoded = Bytes::from(encoded); + + let reader = match ValueReader::::new() + .push(encoded.slice(..4)) + .advance() + .unwrap() + { + ReadValueStep::NeedMore(next) => next, + _ => unreachable!(), + }; + + match reader.push(encoded.slice(4..)).advance().unwrap() { + ReadValueStep::Value(value) => assert_eq!(value, BytesValue(b"hello".to_vec())), + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index 994e473d..ef27f843 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -6,6 +6,7 @@ pub(crate) mod codec; mod error; pub mod rpc; +pub use codec::{ReadValueStep, ValueReader}; pub use error::*; pub use rpc::*; diff --git a/ql-rpc/src/rpc/notification.rs b/ql-rpc/src/rpc/notification.rs index ae452d3d..0e09a8fc 100644 --- a/ql-rpc/src/rpc/notification.rs +++ b/ql-rpc/src/rpc/notification.rs @@ -1,6 +1,6 @@ use bytes::BufMut; -use crate::{MethodId, RpcCodec}; +use crate::{codec, MethodId, ReadValueStep, RpcCodec, ValueReader}; pub trait Notification { const METHOD: MethodId; @@ -8,55 +8,12 @@ pub trait Notification { type Event: RpcCodec; } +pub type EventReader = ValueReader<::Event>; +pub type EventReadStep = ReadValueStep<::Event>; + pub fn encode_event( event: &M::Event, - out: &mut impl BufMut, + out: &mut (impl BufMut + AsMut<[u8]>), ) -> Result<(), M::Error> { - event.encode_value(out) -} - -pub fn decode_event(mut body: &[u8]) -> Result { - M::Event::decode_value(&mut body) -} - -#[cfg(test)] -mod tests { - use bytes::{Buf, BufMut}; - - use super::{decode_event, encode_event, Notification}; - use crate::{MethodId, RpcCodec}; - - #[derive(Debug, Clone, PartialEq, Eq)] - struct BytesValue(Vec); - - impl RpcCodec for BytesValue { - type Error = core::convert::Infallible; - - fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { - out.put_slice(&self.0); - Ok(()) - } - - fn decode_value(bytes: &mut B) -> Result { - Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) - } - } - - struct Notify; - - impl Notification for Notify { - const METHOD: MethodId = MethodId(13); - type Error = core::convert::Infallible; - type Event = BytesValue; - } - - #[test] - fn event_round_trip_preserves_payload() { - let mut encoded = Vec::new(); - encode_event::(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); - assert_eq!( - decode_event::(&encoded).unwrap(), - BytesValue(b"hello".to_vec()) - ); - } + codec::encode_value_part(event, out) } diff --git a/ql-rpc/src/rpc/request.rs b/ql-rpc/src/rpc/request.rs index 0483dd2d..aae44664 100644 --- a/ql-rpc/src/rpc/request.rs +++ b/ql-rpc/src/rpc/request.rs @@ -1,6 +1,6 @@ use bytes::BufMut; -use crate::{MethodId, RpcCodec}; +use crate::{codec, MethodId, ReadValueStep, RpcCodec, ValueReader}; pub trait Request { const METHOD: MethodId; @@ -9,79 +9,21 @@ pub trait Request { type Response: RpcCodec; } +pub type RequestReader = ValueReader<::Request>; +pub type RequestReadStep = ReadValueStep<::Request>; +pub type ResponseReader = ValueReader<::Response>; +pub type ResponseReadStep = ReadValueStep<::Response>; + pub fn encode_request( request: &M::Request, - out: &mut impl BufMut, + out: &mut (impl BufMut + AsMut<[u8]>), ) -> Result<(), M::Error> { - request.encode_value(out) -} - -pub fn decode_request(body: &[u8]) -> Result { - let mut body = body; - M::Request::decode_value(&mut body) + codec::encode_value_part(request, out) } pub fn encode_response( response: &M::Response, - out: &mut impl BufMut, + out: &mut (impl BufMut + AsMut<[u8]>), ) -> Result<(), M::Error> { - response.encode_value(out) -} - -pub fn decode_response(bytes: &[u8]) -> Result { - let mut bytes = bytes; - M::Response::decode_value(&mut bytes) -} - -#[cfg(test)] -mod tests { - use bytes::{Buf, BufMut}; - - use super::*; - use crate::{MethodId, RpcCodec}; - - #[derive(Debug, Clone, PartialEq, Eq)] - struct BytesValue(Vec); - - impl RpcCodec for BytesValue { - type Error = core::convert::Infallible; - - fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { - out.put_slice(&self.0); - Ok(()) - } - - fn decode_value(bytes: &mut B) -> Result { - Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) - } - } - - struct Echo; - - impl Request for Echo { - const METHOD: MethodId = MethodId(7); - type Error = core::convert::Infallible; - type Request = BytesValue; - type Response = BytesValue; - } - - #[test] - fn request_round_trip_preserves_payload() { - let mut encoded = Vec::new(); - encode_request::(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); - assert_eq!( - decode_request::(&encoded).unwrap(), - BytesValue(b"hello".to_vec()) - ); - } - - #[test] - fn response_round_trip_preserves_payload() { - let mut encoded = Vec::new(); - encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); - assert_eq!( - decode_response::(&encoded).unwrap(), - BytesValue(b"done".to_vec()) - ); - } + codec::encode_value_part(response, out) } diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index 03d39fd6..8631f37d 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{codec, MethodId, RpcCodec, RpcCodecError, RpcError}; +use crate::{codec, MethodId, ReadValueStep, RpcCodec, RpcCodecError, RpcError, ValueReader}; pub trait RequestWithProgress { const METHOD: MethodId; @@ -12,6 +12,9 @@ pub trait RequestWithProgress { type Response: RpcCodec; } +pub type RequestReader = ValueReader<::Request>; +pub type RequestReadStep = ReadValueStep<::Request>; + pub enum ReadStep { NeedMore(ResponseReader), Progress { @@ -90,13 +93,9 @@ enum FrameKind { pub fn encode_request( request: &M::Request, - out: &mut impl BufMut, + out: &mut (impl BufMut + AsMut<[u8]>), ) -> Result<(), M::Error> { - request.encode_value(out) -} - -pub fn decode_request(mut body: &[u8]) -> Result { - M::Request::decode_value(&mut body) + codec::encode_value_part(request, out) } pub fn encode_progress( @@ -129,11 +128,8 @@ fn encode_tagged_value_part>( mod tests { use bytes::{Buf, BufMut, Bytes}; - use super::{ - decode_request, encode_progress, encode_request, encode_response, ReadStep, - RequestWithProgress, ResponseReader, - }; - use crate::{MethodId, RpcCodec, RpcCodecError, RpcError}; + use super::{encode_progress, encode_response, ReadStep, RequestWithProgress, ResponseReader}; + use crate::{MethodId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -162,95 +158,7 @@ mod tests { } #[test] - fn request_round_trip_preserves_payload() { - let mut encoded = Vec::new(); - encode_request::(&BytesValue(b"watch".to_vec()), &mut encoded).unwrap(); - assert_eq!( - decode_request::(&encoded).unwrap(), - BytesValue(b"watch".to_vec()) - ); - } - - #[test] - fn response_with_progress_requires_terminal_response() { - let mut encoded = Vec::new(); - encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); - - let reader = match ResponseReader::::new() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { - ReadStep::Progress { value, next } => { - assert_eq!(value, BytesValue(b"10%".to_vec())); - next - } - _ => unreachable!(), - }; - let reader = match reader.advance().unwrap() { - ReadStep::NeedMore(next) => next, - _ => unreachable!(), - }; - let _ = reader; - } - - #[test] - fn response_with_progress_rejects_bytes_after_response() { - let mut encoded = Vec::new(); - encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); - encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); - encode_progress::(&BytesValue(b"late".to_vec()), &mut encoded).unwrap(); - - let reader = match ResponseReader::::new() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { - ReadStep::Progress { next, .. } => next, - _ => unreachable!(), - }; - match reader.advance() { - Err(RpcCodecError::Rpc(RpcError::TrailingBytes)) => {} - _ => unreachable!(), - } - } - - #[test] - fn response_reader_emits_typed_events() { - let mut encoded = Vec::new(); - encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); - encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); - - let encoded = Bytes::from(encoded); - let reader = ResponseReader::::new().push(encoded.slice(..4)); - let reader = match reader.advance().unwrap() { - ReadStep::NeedMore(next) => next, - _ => unreachable!(), - }; - let reader = reader.push(encoded.slice(4..encoded.len() - 2)); - let reader = match reader.advance().unwrap() { - ReadStep::Progress { - value: BytesValue(bytes), - next, - } => { - assert_eq!(bytes, b"10%".to_vec()); - next - } - _ => unreachable!(), - }; - let reader = match reader.advance().unwrap() { - ReadStep::NeedMore(next) => next, - _ => unreachable!(), - }; - let reader = reader.push(encoded.slice(encoded.len() - 2..)); - match reader.advance().unwrap() { - ReadStep::Response(value) => assert_eq!(value, BytesValue(b"done".to_vec())), - _ => unreachable!(), - } - } - - #[test] - fn response_progress_then_response_round_trips() { + fn response_reader_emits_progress_then_response() { let mut encoded = Vec::new(); encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); @@ -273,7 +181,7 @@ mod tests { } #[test] - fn response_can_be_encoded_without_progress() { + fn response_reader_handles_response_only() { let mut encoded = Vec::new(); encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 78f398e4..2a9a0155 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::{Buf, BufMut, Bytes}; -use crate::{codec, MethodId, RpcCodec, RpcCodecError, RpcError}; +use crate::{codec, MethodId, ReadValueStep, RpcCodec, RpcCodecError, RpcError, ValueReader}; pub trait Subscription { const METHOD: MethodId; @@ -11,6 +11,9 @@ pub trait Subscription { type Event: RpcCodec; } +pub type RequestReader = ValueReader<::Request>; +pub type RequestReadStep = ReadValueStep<::Request>; + pub enum ReadStep { NeedMore(ResponseReader), Item { @@ -72,13 +75,9 @@ impl ResponseReader { pub fn encode_request( request: &M::Request, - out: &mut impl BufMut, + out: &mut (impl BufMut + AsMut<[u8]>), ) -> Result<(), M::Error> { - request.encode_value(out) -} - -pub fn decode_request(mut body: &[u8]) -> Result { - M::Request::decode_value(&mut body) + codec::encode_value_part(request, out) } pub fn encode_item( @@ -96,10 +95,7 @@ pub fn encode_end(out: &mut impl BufMut) { mod tests { use bytes::{Buf, BufMut, Bytes}; - use super::{ - decode_request, encode_end, encode_item, encode_request, ReadStep, ResponseReader, - Subscription, - }; + use super::{encode_end, encode_item, ReadStep, ResponseReader, Subscription}; use crate::{MethodId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] @@ -128,17 +124,7 @@ mod tests { } #[test] - fn request_round_trip_preserves_payload() { - let mut encoded = Vec::new(); - encode_request::(&BytesValue(b"watch".to_vec()), &mut encoded).unwrap(); - assert_eq!( - decode_request::(&encoded).unwrap(), - BytesValue(b"watch".to_vec()) - ); - } - - #[test] - fn decode_item_stream_reads_all_items() { + fn response_reader_streams_items_until_end() { let mut encoded = Vec::new(); encode_item::(&BytesValue(b"one".to_vec()), &mut encoded).unwrap(); encode_item::(&BytesValue(b"two".to_vec()), &mut encoded).unwrap(); @@ -166,40 +152,4 @@ mod tests { assert!(matches!(reader.advance().unwrap(), ReadStep::End)); } - - #[test] - fn response_reader_emits_items_as_chunks_arrive() { - let mut encoded = Vec::new(); - encode_item::(&BytesValue(b"one".to_vec()), &mut encoded).unwrap(); - encode_item::(&BytesValue(b"two".to_vec()), &mut encoded).unwrap(); - encode_end(&mut encoded); - - let all = Bytes::from(encoded); - let reader = match ResponseReader::::new() - .push(all.slice(..5)) - .advance() - .unwrap() - { - ReadStep::NeedMore(next) => next, - _ => unreachable!(), - }; - - let reader = match reader.push(all.slice(5..)).advance().unwrap() { - ReadStep::Item { value, next } => { - assert_eq!(value, BytesValue(b"one".to_vec())); - next - } - _ => unreachable!(), - }; - - let reader = match reader.advance().unwrap() { - ReadStep::Item { value, next } => { - assert_eq!(value, BytesValue(b"two".to_vec())); - next - } - _ => unreachable!(), - }; - - assert!(matches!(reader.advance().unwrap(), ReadStep::End)); - } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index cee8908d..e55a5c71 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -10,11 +10,12 @@ use ql_rpc::{ request::{self, Request as RequestRpc}, request_with_progress::{self as rpc_request_with_progress, RequestWithProgress}, subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, + ReadValueStep, RpcCodec, RpcError, ValueReader, }; use ql_wire::{RouteId, VarInt}; pub use self::{error::*, request_with_progress::*, subscription::*}; -use crate::{ByteReader, QlStreamError, RuntimeHandle}; +use crate::{ByteReader, RuntimeHandle}; #[derive(Clone)] pub struct RpcHandle { @@ -45,8 +46,7 @@ impl RpcHandle { let mut payload = Vec::new(); request::encode_request::(request, &mut payload).map_err(RpcCallError::Codec)?; let response = self.start_request(M::METHOD, payload).await?; - let response = read_all(response).await?; - request::decode_response::(&response).map_err(RpcCallError::Codec) + read_value::(response).await } pub async fn subscribe( @@ -97,10 +97,21 @@ impl RpcHandle { } } -async fn read_all(mut reader: ByteReader) -> Result, QlStreamError> { - let mut bytes = Vec::new(); - while let Some(chunk) = poll_fn(|cx| reader.poll_read_chunk(cx)).await? { - bytes.extend_from_slice(&chunk); +async fn read_value(mut reader: ByteReader) -> Result> +where + T: RpcCodec, +{ + let mut value_reader = ValueReader::::new(); + + loop { + match value_reader.advance().map_err(RpcCallError::from)? { + ReadValueStep::Value(value) => return Ok(value), + ReadValueStep::NeedMore(next) => value_reader = next, + } + + match poll_fn(|cx| reader.poll_read_chunk(cx)).await? { + Some(chunk) => value_reader = value_reader.push(chunk), + None => return Err(RpcError::Truncated.into()), + } } - Ok(bytes) } diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index fc82bcda..3e8799a7 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -214,6 +214,10 @@ impl TestPair { .await; } + fn handle(&self, side: Side) -> RuntimeHandle { + self.side(side).handle.clone() + } + fn take_inbound(&mut self, side: Side) -> Receiver { let replacement = async_channel::unbounded().1; std::mem::replace(&mut self.side_mut(side).inbound, replacement) diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index d6b8952a..b6e1a059 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -59,15 +59,12 @@ async fn rpc_request_round_trips() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request = read_all(inbound.reader).await.unwrap(); + let request: BytesValue = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, route_id(::METHOD) ); - assert_eq!( - ql_rpc::request::decode_request::(&request).unwrap(), - BytesValue(b"hello".to_vec()) - ); + assert_eq!(request, BytesValue(b"hello".to_vec())); let mut encoded = Vec::new(); ql_rpc::request::encode_response::(&BytesValue(b"world".to_vec()), &mut encoded) @@ -101,15 +98,12 @@ async fn rpc_subscription_streams_events() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request = read_all(inbound.reader).await.unwrap(); + let request: BytesValue = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, route_id(::METHOD) ); - assert_eq!( - ql_rpc::subscription::decode_request::(&request).unwrap(), - BytesValue(b"watch".to_vec()) - ); + assert_eq!(request, BytesValue(b"watch".to_vec())); let mut encoded = Vec::new(); ql_rpc::subscription::encode_item::(&BytesValue(b"one".to_vec()), &mut encoded) @@ -155,15 +149,12 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request = read_all(inbound.reader).await.unwrap(); + let request: BytesValue = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, route_id(::METHOD) ); - assert_eq!( - ql_rpc::request_with_progress::decode_request::(&request).unwrap(), - BytesValue(b"logo".to_vec()) - ); + assert_eq!(request, BytesValue(b"logo".to_vec())); let mut encoded = Vec::new(); ql_rpc::request_with_progress::encode_progress::( @@ -209,3 +200,23 @@ async fn rpc_request_with_progress_supports_progress_then_await() { fn route_id(method: ql_rpc::MethodId) -> RouteId { RouteId(ql_wire::VarInt::from_u32(method.0)) } + +async fn read_rpc_value(mut reader: crate::ByteReader) -> T +where + T: ql_rpc::RpcCodec, + T::Error: std::fmt::Debug, +{ + let mut value_reader = ql_rpc::ValueReader::::new(); + + loop { + match value_reader.advance().unwrap() { + ql_rpc::ReadValueStep::Value(value) => return value, + ql_rpc::ReadValueStep::NeedMore(next) => value_reader = next, + } + + match reader.read_chunk().await.unwrap() { + Some(chunk) => value_reader = value_reader.push(chunk), + None => panic!("truncated rpc value"), + } + } +} From eb9fcfbb6ec1838246918e99a791f9338001eff2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 11 Apr 2026 09:39:48 -0400 Subject: [PATCH 193/304] ql-rpc: renaming --- ql-rpc/src/codec.rs | 20 +++++------ ql-rpc/src/error.rs | 26 +++++++------- ql-rpc/src/rpc/request_with_progress.rs | 19 ++++------ ql-rpc/src/rpc/subscription.rs | 10 +++--- ql-runtime/src/rpc/error.rs | 40 ++++++++++----------- ql-runtime/src/rpc/mod.rs | 30 +++++++--------- ql-runtime/src/rpc/request_with_progress.rs | 12 +++---- ql-runtime/src/rpc/subscription.rs | 10 +++--- 8 files changed, 79 insertions(+), 88 deletions(-) diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index 94b63875..56306e44 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -2,7 +2,7 @@ use std::{collections::VecDeque, marker::PhantomData}; use bytes::{Buf, BufMut, Bytes}; -use crate::{RpcCodec, RpcCodecError, RpcError}; +use crate::{CodecError, Error, RpcCodec}; const LENGTH_SIZE: usize = 8; @@ -45,16 +45,16 @@ impl ValueReader { self } - pub fn advance(self) -> Result, RpcCodecError> { + pub fn advance(self) -> Result, CodecError> { let mut this = self; - let Some(mut body) = this.bytes.try_take_part().map_err(RpcCodecError::Rpc)? else { + let Some(mut body) = this.bytes.try_take_part().map_err(CodecError::Rpc)? else { return Ok(ReadValueStep::NeedMore(this)); }; - let value = T::decode_value(&mut body).map_err(RpcCodecError::Codec)?; + let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; drop(body); if this.bytes.remaining() > 0 { - return Err(RpcCodecError::Rpc(RpcError::TrailingBytes)); + return Err(CodecError::Rpc(Error::TrailingBytes)); } Ok(ReadValueStep::Value(value)) } @@ -83,7 +83,7 @@ impl ChunkQueue { self.remaining } - pub fn try_take_part(&mut self) -> Result>, RpcError> { + pub fn try_take_part(&mut self) -> Result>, Error> { let Some(len) = self.peek_next_part_len()? else { return Ok(None); }; @@ -91,7 +91,7 @@ impl ChunkQueue { Ok(Some(DrainBuf::new(self, len))) } - pub fn try_take_tagged_part(&mut self) -> Result)>, RpcError> { + pub fn try_take_tagged_part(&mut self) -> Result)>, Error> { let mut bytes = self.peek(); let Ok(kind) = bytes.try_get_u8() else { return Ok(None); @@ -104,7 +104,7 @@ impl ChunkQueue { Ok(Some((kind, DrainBuf::new(self, len)))) } - fn peek_next_part_len(&self) -> Result, RpcError> { + fn peek_next_part_len(&self) -> Result, Error> { let mut bytes = self.peek(); read_next_part_len(&mut bytes) } @@ -239,11 +239,11 @@ impl Drop for DrainBuf<'_> { } } -fn read_next_part_len(bytes: &mut B) -> Result, RpcError> { +fn read_next_part_len(bytes: &mut B) -> Result, Error> { let Ok(len) = bytes.try_get_u64_le() else { return Ok(None); }; - let len: usize = len.try_into().map_err(|_| RpcError::LengthOverflow)?; + let len: usize = len.try_into().map_err(|_| Error::LengthOverflow)?; if bytes.remaining() < len { return Ok(None); } diff --git a/ql-rpc/src/error.rs b/ql-rpc/src/error.rs index cc236898..c4d7d6d9 100644 --- a/ql-rpc/src/error.rs +++ b/ql-rpc/src/error.rs @@ -1,5 +1,5 @@ #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum RpcError { +pub enum Error { Truncated, LengthOverflow, UnexpectedFrameKind(u8), @@ -7,7 +7,7 @@ pub enum RpcError { TrailingBytes, } -impl std::fmt::Display for RpcError { +impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Truncated => f.write_str("truncated rpc payload"), @@ -19,22 +19,22 @@ impl std::fmt::Display for RpcError { } } -impl std::error::Error for RpcError {} +impl std::error::Error for Error {} #[derive(Debug, Clone, PartialEq, Eq)] -pub enum RpcCodecError { - Rpc(RpcError), +pub enum CodecError { + Rpc(Error), Codec(E), } -impl std::error::Error for RpcCodecError +impl std::error::Error for CodecError where E: std::error::Error + 'static, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - RpcCodecError::Rpc(e) => Some(e), - RpcCodecError::Codec(e) => Some(e), + CodecError::Rpc(e) => Some(e), + CodecError::Codec(e) => Some(e), } } @@ -43,20 +43,20 @@ where } } -impl std::fmt::Display for RpcCodecError +impl std::fmt::Display for CodecError where E: std::fmt::Display, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - RpcCodecError::Rpc(e) => write!(f, "{e}"), - RpcCodecError::Codec(e) => write!(f, "{e}"), + CodecError::Rpc(e) => write!(f, "{e}"), + CodecError::Codec(e) => write!(f, "{e}"), } } } -impl From for RpcCodecError { - fn from(error: RpcError) -> Self { +impl From for CodecError { + fn from(error: Error) -> Self { Self::Rpc(error) } } diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index 8631f37d..9617c349 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{codec, MethodId, ReadValueStep, RpcCodec, RpcCodecError, RpcError, ValueReader}; +use crate::{codec, CodecError, Error, MethodId, ReadValueStep, RpcCodec, ValueReader}; pub trait RequestWithProgress { const METHOD: MethodId; @@ -48,13 +48,10 @@ impl ResponseReader { self } - pub fn advance(self) -> Result, RpcCodecError> { + pub fn advance(self) -> Result, CodecError> { let mut this = self; - let Some((kind, mut body)) = this - .bytes - .try_take_tagged_part() - .map_err(RpcCodecError::Rpc)? + let Some((kind, mut body)) = this.bytes.try_take_tagged_part().map_err(CodecError::Rpc)? else { return Ok(ReadStep::NeedMore(this)); }; @@ -62,24 +59,22 @@ impl ResponseReader { match kind { x if x == FrameKind::Progress as u8 => { let value = { - let value = - M::Progress::decode_value(&mut body).map_err(RpcCodecError::Codec)?; + let value = M::Progress::decode_value(&mut body).map_err(CodecError::Codec)?; drop(body); value }; Ok(ReadStep::Progress { value, next: this }) } x if x == FrameKind::Response as u8 => { - let response = - M::Response::decode_value(&mut body).map_err(RpcCodecError::Codec)?; + let response = M::Response::decode_value(&mut body).map_err(CodecError::Codec)?; drop(body); if this.bytes.remaining() > 0 { - Err(RpcCodecError::Rpc(RpcError::TrailingBytes)) + Err(CodecError::Rpc(Error::TrailingBytes)) } else { Ok(ReadStep::Response(response)) } } - other => Err(RpcCodecError::Rpc(RpcError::UnexpectedFrameKind(other))), + other => Err(CodecError::Rpc(Error::UnexpectedFrameKind(other))), } } } diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 2a9a0155..e87439f8 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::{Buf, BufMut, Bytes}; -use crate::{codec, MethodId, ReadValueStep, RpcCodec, RpcCodecError, RpcError, ValueReader}; +use crate::{codec, CodecError, Error, MethodId, ReadValueStep, RpcCodec, ValueReader}; pub trait Subscription { const METHOD: MethodId; @@ -47,9 +47,9 @@ impl ResponseReader { self } - pub fn advance(self) -> Result, RpcCodecError> { + pub fn advance(self) -> Result, CodecError> { let mut this = self; - let Some(mut body) = this.bytes.try_take_part().map_err(RpcCodecError::Rpc)? else { + let Some(mut body) = this.bytes.try_take_part().map_err(CodecError::Rpc)? else { return Ok(ReadStep::NeedMore(this)); }; @@ -58,11 +58,11 @@ impl ResponseReader { if this.bytes.remaining() == 0 { return Ok(ReadStep::End); } - return Err(RpcCodecError::Rpc(RpcError::TrailingBytes)); + return Err(CodecError::Rpc(Error::TrailingBytes)); } let item = { - let item = M::Event::decode_value(&mut body).map_err(RpcCodecError::Codec)?; + let item = M::Event::decode_value(&mut body).map_err(CodecError::Codec)?; drop(body); item }; diff --git a/ql-runtime/src/rpc/error.rs b/ql-runtime/src/rpc/error.rs index c82a30ee..9c30ad07 100644 --- a/ql-runtime/src/rpc/error.rs +++ b/ql-runtime/src/rpc/error.rs @@ -4,67 +4,67 @@ use ql_wire::StreamCloseCode; use crate::QlStreamError; #[derive(Debug)] -pub enum RpcCallError { +pub enum RpcError { NoSession, - StreamClosed(StreamCloseCode), - Rpc(ql_rpc::RpcError), + Closed(StreamCloseCode), + Protocol(ql_rpc::Error), Codec(E), } -impl From for RpcCallError { +impl From for RpcError { fn from(_: NoSessionError) -> Self { Self::NoSession } } -impl From for RpcCallError { +impl From for RpcError { fn from(error: QlStreamError) -> Self { match error { - QlStreamError::StreamClosed { code } => Self::StreamClosed(code), + QlStreamError::StreamClosed { code } => Self::Closed(code), QlStreamError::NoSession => Self::NoSession, } } } -impl From for RpcCallError { - fn from(error: ql_rpc::RpcError) -> Self { - Self::Rpc(error) +impl From for RpcError { + fn from(error: ql_rpc::Error) -> Self { + Self::Protocol(error) } } -impl From> for RpcCallError { - fn from(error: ql_rpc::RpcCodecError) -> Self { +impl From> for RpcError { + fn from(error: ql_rpc::CodecError) -> Self { match error { - ql_rpc::RpcCodecError::Rpc(error) => Self::Rpc(error), - ql_rpc::RpcCodecError::Codec(error) => Self::Codec(error), + ql_rpc::CodecError::Rpc(error) => Self::Protocol(error), + ql_rpc::CodecError::Codec(error) => Self::Codec(error), } } } -impl std::fmt::Display for RpcCallError +impl std::fmt::Display for RpcError where E: std::fmt::Display, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::NoSession => write!(f, "no session"), - Self::StreamClosed(code) => write!(f, "stream closed {code:?}"), - Self::Rpc(error) => write!(f, "{error}"), + Self::Closed(code) => write!(f, "stream closed {code:?}"), + Self::Protocol(error) => write!(f, "{error}"), Self::Codec(error) => write!(f, "{error}"), } } } -impl std::error::Error for RpcCallError +impl std::error::Error for RpcError where E: std::error::Error + 'static, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - Self::Rpc(error) => Some(error), + Self::Protocol(error) => Some(error), Self::Codec(error) => Some(error), - RpcCallError::NoSession => None, - RpcCallError::StreamClosed(_) => None, + RpcError::NoSession => None, + RpcError::Closed(_) => None, } } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index e55a5c71..f2cf57cc 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -10,7 +10,7 @@ use ql_rpc::{ request::{self, Request as RequestRpc}, request_with_progress::{self as rpc_request_with_progress, RequestWithProgress}, subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, - ReadValueStep, RpcCodec, RpcError, ValueReader, + Error, ReadValueStep, RpcCodec, ValueReader, }; use ql_wire::{RouteId, VarInt}; @@ -23,12 +23,12 @@ pub struct RpcHandle { } impl RpcHandle { - pub async fn event(&self, event: &M::Event) -> Result<(), RpcCallError> + pub async fn event(&self, event: &M::Event) -> Result<(), RpcError> where M: Notification, { let mut payload = Vec::new(); - notification::encode_event::(event, &mut payload).map_err(RpcCallError::Codec)?; + notification::encode_event::(event, &mut payload).map_err(RpcError::Codec)?; let route_id = RouteId(VarInt::from_u32(M::METHOD.0)); let mut stream = self.inner.open_stream(route_id).await?; stream.reader.close(ql_wire::StreamCloseCode(0)); @@ -36,15 +36,12 @@ impl RpcHandle { Ok(()) } - pub async fn request( - &self, - request: &M::Request, - ) -> Result> + pub async fn request(&self, request: &M::Request) -> Result> where M: RequestRpc, { let mut payload = Vec::new(); - request::encode_request::(request, &mut payload).map_err(RpcCallError::Codec)?; + request::encode_request::(request, &mut payload).map_err(RpcError::Codec)?; let response = self.start_request(M::METHOD, payload).await?; read_value::(response).await } @@ -52,13 +49,12 @@ impl RpcHandle { pub async fn subscribe( &self, request: &M::Request, - ) -> Result, RpcCallError> + ) -> Result, RpcError> where M: SubscriptionRpc, { let mut payload = Vec::new(); - rpc_subscription::encode_request::(request, &mut payload) - .map_err(RpcCallError::Codec)?; + rpc_subscription::encode_request::(request, &mut payload).map_err(RpcError::Codec)?; let response = self.start_request(M::METHOD, payload).await?; Ok(Subscription { stream: response, @@ -69,13 +65,13 @@ impl RpcHandle { pub async fn request_with_progress( &self, request: &M::Request, - ) -> Result, RpcCallError> + ) -> Result, RpcError> where M: RequestWithProgress, { let mut payload = Vec::new(); rpc_request_with_progress::encode_request::(request, &mut payload) - .map_err(RpcCallError::Codec)?; + .map_err(RpcError::Codec)?; let response = self.start_request(M::METHOD, payload).await?; Ok(ProgressCall { stream: response, @@ -88,7 +84,7 @@ impl RpcHandle { &self, method: ql_rpc::MethodId, payload: Vec, - ) -> Result> { + ) -> Result> { let route_id = RouteId(VarInt::from_u32(method.0)); let mut stream = self.inner.open_stream(route_id).await?; stream.writer.write(Bytes::from(payload)).await?; @@ -97,21 +93,21 @@ impl RpcHandle { } } -async fn read_value(mut reader: ByteReader) -> Result> +async fn read_value(mut reader: ByteReader) -> Result> where T: RpcCodec, { let mut value_reader = ValueReader::::new(); loop { - match value_reader.advance().map_err(RpcCallError::from)? { + match value_reader.advance().map_err(RpcError::from)? { ReadValueStep::Value(value) => return Ok(value), ReadValueStep::NeedMore(next) => value_reader = next, } match poll_fn(|cx| reader.poll_read_chunk(cx)).await? { Some(chunk) => value_reader = value_reader.push(chunk), - None => return Err(RpcError::Truncated.into()), + None => return Err(Error::Truncated.into()), } } } diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs index cda35908..fe2e78e7 100644 --- a/ql-runtime/src/rpc/request_with_progress.rs +++ b/ql-runtime/src/rpc/request_with_progress.rs @@ -7,16 +7,16 @@ use std::{ use futures_lite::{future::poll_fn, Stream}; use ql_rpc::{ request_with_progress::{ReadStep, RequestWithProgress}, - RpcError, + Error, }; -use super::RpcCallError; +use super::RpcError; use crate::ByteReader; pub struct ProgressCall { pub(super) stream: ByteReader, pub(super) reader: Option>, - pub(super) terminal: Option>>, + pub(super) terminal: Option>>, } impl Unpin for ProgressCall where M: RequestWithProgress {} @@ -70,7 +70,7 @@ where } Poll::Ready(Ok(None)) => { this.reader = None; - this.terminal = Some(Err(RpcError::MissingResponse.into())); + this.terminal = Some(Err(Error::MissingResponse.into())); return Poll::Ready(None); } Poll::Ready(Err(error)) => { @@ -88,7 +88,7 @@ impl Future for ProgressCall where M: RequestWithProgress, { - type Output = Result>; + type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); @@ -122,7 +122,7 @@ where } Poll::Ready(Ok(None)) => { this.reader = None; - return Poll::Ready(Err(RpcError::MissingResponse.into())); + return Poll::Ready(Err(Error::MissingResponse.into())); } Poll::Ready(Err(error)) => { this.reader = None; diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 10648172..40f00537 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -6,10 +6,10 @@ use std::{ use futures_lite::{future::poll_fn, Stream}; use ql_rpc::{ subscription::{ReadStep, Subscription as SubscriptionRpc}, - RpcError, + Error, }; -use super::RpcCallError; +use super::RpcError; use crate::ByteReader; pub struct Subscription { @@ -23,7 +23,7 @@ impl Subscription where M: SubscriptionRpc, { - pub async fn next_event(&mut self) -> Option>> { + pub async fn next_event(&mut self) -> Option>> { poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await } } @@ -32,7 +32,7 @@ impl Stream for Subscription where M: SubscriptionRpc, { - type Item = Result>; + type Item = Result>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -61,7 +61,7 @@ where } Poll::Ready(Ok(None)) => { this.reader = None; - return Poll::Ready(Some(Err(RpcError::Truncated.into()))); + return Poll::Ready(Some(Err(Error::Truncated.into()))); } Poll::Ready(Err(error)) => { this.reader = None; From 9f833cca2fbd1886d13affb80adbf6e7a11680d2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 11 Apr 2026 11:02:17 -0400 Subject: [PATCH 194/304] ql: use err method --- ql-fsm/src/fsm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 60b1074a..9c0bb978 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -53,7 +53,7 @@ pub fn receive( } wire::RecordType::Session => { let QlFsm { state, events, .. } = fsm; - let conn = state.link.connected_mut().ok_or(ReceiveError::NoSession)?; + let conn = state.link.connected_mut_or_err()?; let (decrypt_len, seq) = { let record = wire::QlSessionRecord::decode(&mut reader)?; if record.header.connection_id != conn.transport.rx_connection_id { From 3e26dc95619a53b5c5c71d1f54c432974efe05f1 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 11 Apr 2026 11:10:41 -0400 Subject: [PATCH 195/304] ql: bufview --- ql-fsm/src/session/stream_tx.rs | 86 +++++++---- ql-wire/src/bytes.rs | 205 ++++++++++----------------- ql-wire/src/encrypted/builder.rs | 6 +- ql-wire/src/encrypted/mod.rs | 4 +- ql-wire/src/encrypted/stream_data.rs | 16 ++- 5 files changed, 149 insertions(+), 168 deletions(-) diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index 793fea3c..d697e4f8 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -1,7 +1,7 @@ use std::{collections::VecDeque, ops::Range}; use bytes::{Buf, Bytes}; -use ql_wire::ByteChunks; +use ql_wire::BufView; use super::range_set::RangeSet; @@ -44,42 +44,39 @@ pub struct StreamTxBytes<'a> { len: usize, } -pub struct StreamTxBytesIter<'a> { +pub struct StreamTxBuf<'a> { inner: std::collections::vec_deque::Iter<'a, Bytes>, skip: usize, remaining: usize, + current: &'a [u8], } -impl ByteChunks for StreamTxBytes<'_> { - type Chunks<'a> - = StreamTxBytesIter<'a> +impl BufView for StreamTxBytes<'_> { + type Buf<'a> + = StreamTxBuf<'a> where Self: 'a; - fn len(&self) -> usize { - self.inner - .iter() - .map(Bytes::len) - .sum::() - .saturating_sub(self.offset) - .min(self.len) - } - - fn chunks(&self) -> Self::Chunks<'_> { - StreamTxBytesIter { + fn buf(&self) -> Self::Buf<'_> { + let mut buf = StreamTxBuf { inner: self.inner.iter(), skip: self.offset, - remaining: self.len(), - } + remaining: self.len, + current: &[], + }; + buf.refill(); + buf } } -impl<'a> Iterator for StreamTxBytesIter<'a> { - type Item = &'a [u8]; +impl<'a> StreamTxBuf<'a> { + fn refill(&mut self) { + if self.remaining == 0 { + self.current = &[]; + return; + } - fn next(&mut self) -> Option { - while self.remaining > 0 { - let chunk = self.inner.next()?; + while let Some(chunk) = self.inner.next() { if self.skip >= chunk.len() { self.skip -= chunk.len(); continue; @@ -92,11 +89,45 @@ impl<'a> Iterator for StreamTxBytesIter<'a> { } let len = chunk.len().min(self.remaining); - self.remaining -= len; - return Some(&chunk[..len]); + self.current = &chunk[..len]; + return; } - None + self.current = &[]; + } +} + +impl Buf for StreamTxBuf<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.current + } + + fn advance(&mut self, cnt: usize) { + let remaining = self.remaining; + assert!( + cnt <= remaining, + "cannot advance past remaining bytes: {cnt} > {remaining}", + ); + + self.remaining -= cnt; + let mut cnt = cnt; + while cnt > 0 { + if cnt < self.current.len() { + self.current = &self.current[cnt..]; + return; + } + + cnt -= self.current.len(); + self.refill(); + } + + if self.remaining == 0 { + self.current = &[]; + } } } @@ -197,10 +228,11 @@ impl StreamTx { pub fn ranged_bytes(&self, range: StreamTxRange) -> StreamTxBytes<'_> { let offset = usize::try_from(range.offset - self.base_offset).unwrap(); + let len = range.len.min(self.buffered_len.saturating_sub(offset)); StreamTxBytes { inner: &self.chunks, offset, - len: range.len, + len, } } diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index c8243e12..9fecf5ea 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -1,10 +1,6 @@ -use core::{ - iter::{once, Chain, Once}, - ops::{Deref, DerefMut}, -}; -use std::collections::VecDeque; +use core::ops::{Deref, DerefMut}; -use bytes::Bytes; +use bytes::{Buf, Bytes}; /// A mutable or immutable byte slice owner used by the wire parser. pub trait ByteSlice: Deref + Sized { @@ -17,167 +13,125 @@ pub trait ByteSlice: Deref + Sized { /// A mutable reference to bytes. pub trait ByteSliceMut: ByteSlice + DerefMut {} -/// A byte container that can be encoded from one or more chunks. -pub trait ByteChunks { - type Chunks<'a>: Iterator - where - Self: 'a; - - fn len(&self) -> usize; - - fn chunks(&self) -> Self::Chunks<'_>; +impl ByteSliceMut for B where B: ByteSlice + DerefMut {} - fn is_empty(&self) -> bool { - self.len() == 0 +impl ByteSlice for &[u8] { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok(<[u8]>::split_at(self, mid)) + } else { + Err(self) + } } } -impl ByteSliceMut for B where B: ByteSlice + DerefMut {} - -impl ByteChunks for &T { - type Chunks<'a> - = T::Chunks<'a> - where - Self: 'a; - - fn len(&self) -> usize { - (*self).len() +impl ByteSlice for &mut [u8] { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok(<[u8]>::split_at_mut(self, mid)) + } else { + Err(self) + } } +} - fn chunks(&self) -> Self::Chunks<'_> { - (*self).chunks() +impl ByteSlice for Bytes { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok((self.slice(..mid), self.slice(mid..))) + } else { + Err(self) + } } } -impl ByteChunks for &mut T { - type Chunks<'a> - = T::Chunks<'a> +/// A byte container that can expose a replayable [`Buf`] view for encoding. +pub trait BufView { + type Buf<'a>: Buf where Self: 'a; - fn len(&self) -> usize { - (**self).len() - } + fn buf(&self) -> Self::Buf<'_>; - fn chunks(&self) -> Self::Chunks<'_> { - (**self).chunks() + fn is_empty(&self) -> bool { + self.buf().remaining() == 0 } } -impl ByteChunks for [u8] { - type Chunks<'a> - = Once<&'a [u8]> +impl BufView for &T { + type Buf<'a> + = T::Buf<'a> where Self: 'a; - fn len(&self) -> usize { - <[u8]>::len(self) - } - - fn chunks(&self) -> Self::Chunks<'_> { - once(self) + fn buf(&self) -> Self::Buf<'_> { + (*self).buf() } } -impl ByteChunks for [u8; N] { - type Chunks<'a> - = Once<&'a [u8]> +impl BufView for &mut T { + type Buf<'a> + = T::Buf<'a> where Self: 'a; - fn len(&self) -> usize { - N - } - - fn chunks(&self) -> Self::Chunks<'_> { - once(self.as_slice()) + fn buf(&self) -> Self::Buf<'_> { + (**self).buf() } } -impl ByteChunks for Vec { - type Chunks<'a> - = Once<&'a [u8]> +impl BufView for [u8] { + type Buf<'a> + = &'a [u8] where Self: 'a; - fn len(&self) -> usize { - Self::len(self) - } - - fn chunks(&self) -> Self::Chunks<'_> { - once(self.as_slice()) + fn buf(&self) -> Self::Buf<'_> { + self } } -impl ByteChunks for Bytes { - type Chunks<'a> - = Once<&'a [u8]> +impl BufView for [u8; N] { + type Buf<'a> + = &'a [u8] where Self: 'a; - fn len(&self) -> usize { - Self::len(self) - } - - fn chunks(&self) -> Self::Chunks<'_> { - once(self.as_ref()) + fn buf(&self) -> Self::Buf<'_> { + self.as_slice() } } -impl ByteChunks for VecDeque { - type Chunks<'a> - = Chain, Once<&'a [u8]>> +impl BufView for Vec { + type Buf<'a> + = &'a [u8] where Self: 'a; - fn len(&self) -> usize { - Self::len(self) - } - - fn chunks(&self) -> Self::Chunks<'_> { - let (first, second) = self.as_slices(); - once(first).chain(once(second)) + fn buf(&self) -> Self::Buf<'_> { + self.as_slice() } } -impl ByteSlice for &[u8] { - #[inline] - fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { - if mid <= self.len() { - Ok(<[u8]>::split_at(self, mid)) - } else { - Err(self) - } - } -} - -impl ByteSlice for &mut [u8] { - #[inline] - fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { - if mid <= self.len() { - Ok(<[u8]>::split_at_mut(self, mid)) - } else { - Err(self) - } - } -} +impl BufView for Bytes { + type Buf<'a> + = &'a [u8] + where + Self: 'a; -impl ByteSlice for Bytes { - #[inline] - fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { - if mid <= self.len() { - Ok((self.slice(..mid), self.slice(mid..))) - } else { - Err(self) - } + fn buf(&self) -> Self::Buf<'_> { + self.as_ref() } } #[cfg(test)] mod tests { - use std::collections::VecDeque; + use bytes::Buf; - use super::{ByteChunks, ByteSlice, ByteSliceMut}; + use super::{BufView, ByteSlice, ByteSliceMut}; #[test] fn shared_slice_split_at() { @@ -210,23 +164,12 @@ mod tests { } #[test] - fn slice_byte_chunks_are_contiguous() { + fn slice_buf_view_is_contiguous() { let bytes: &[u8] = b"abcdef"; - let chunks = ByteChunks::chunks(&bytes).collect::>(); - assert_eq!(bytes.len(), 6); - assert_eq!(chunks, vec![b"abcdef".as_slice()]); - } - - #[test] - fn vec_deque_byte_chunks_preserve_split_storage() { - let mut bytes = VecDeque::with_capacity(8); - bytes.extend(b"abcd".iter().copied()); - bytes.drain(..2); - bytes.extend(b"efgh".iter().copied()); - - let chunks = ByteChunks::chunks(&bytes).collect::>(); - assert_eq!(bytes.len(), 6); - assert_eq!(chunks.concat(), b"cdefgh"); - assert!(!chunks.is_empty()); + let mut buf = bytes.buf(); + assert_eq!(buf.remaining(), 6); + assert_eq!(buf.chunk(), b"abcdef"); + buf.advance(6); + assert!(!buf.has_remaining()); } } diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 3d926ece..65ad053a 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -2,7 +2,7 @@ use bytes::BufMut; use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; use crate::{ - ByteChunks, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, SessionKey, + BufView, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, SessionKey, VarInt, WireEncode, QL_WIRE_VERSION, }; @@ -70,7 +70,7 @@ impl SessionRecordBuilder { self.push_frame_payload(super::SessionFrameKind::Ack, ack) } - pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { + pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { self.push_len_prefixed_frame(super::SessionFrameKind::StreamData, frame) } @@ -86,7 +86,7 @@ impl SessionRecordBuilder { self.push_frame_payload(super::SessionFrameKind::Close, close) } - pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { + pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { match frame { SessionFrame::Ping => self.push_ping(), SessionFrame::Ack(frame) => self.push_ack(frame), diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index bb8b148f..a25e854d 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,5 +1,5 @@ use crate::{ - codec, encrypted_message::EncryptedMessage, ByteChunks, ByteSlice, Nonce, QlCrypto, Reader, + codec, encrypted_message::EncryptedMessage, BufView, ByteSlice, Nonce, QlCrypto, Reader, SessionHeader, SessionKey, VarInt, WireDecode, WireEncode, WireError, }; @@ -77,7 +77,7 @@ impl SessionFrame { } } -impl WireEncode for SessionFrame { +impl WireEncode for SessionFrame { fn encoded_len(&self) -> usize { 1 + match self { Self::Ping => 0, diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index bc825247..2ffc480c 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -1,5 +1,7 @@ +use bytes::Buf; + use super::{RouteId, StreamId}; -use crate::{codec, ByteChunks, ByteSlice, VarInt, WireDecode, WireEncode, WireError}; +use crate::{codec, BufView, ByteSlice, VarInt, WireDecode, WireEncode, WireError}; /// carries bytes for a stream and may finish that sending direction. #[derive(Debug, Clone, PartialEq, Eq)] @@ -55,13 +57,14 @@ impl StreamData { } } -impl WireEncode for StreamData { +impl WireEncode for StreamData { fn encoded_len(&self) -> usize { + let bytes = self.bytes.buf(); self.stream_id.encoded_len() + self.offset.encoded_len() + size_of::() + self.header.as_ref().map_or(0, WireEncode::encoded_len) - + self.bytes.len() + + bytes.remaining() } fn encode(&self, out: &mut W) { @@ -83,8 +86,11 @@ impl WireEncode for StreamData { if let Some(header) = &self.header { header.encode(out); } - for chunk in self.bytes.chunks() { - chunk.encode(out); + let mut bytes = self.bytes.buf(); + while bytes.has_remaining() { + let chunk = bytes.chunk(); + out.put_slice(chunk); + bytes.advance(chunk.len()); } } } From effb4448d6c9155d2841fedb68aeb2041ef92215 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 11 Apr 2026 11:49:16 -0400 Subject: [PATCH 196/304] ql: steam close code constants --- ql-fsm/src/session/tests.rs | 14 +++++++------- ql-fsm/src/tests/proptest.rs | 2 +- ql-fsm/src/tests/session.rs | 7 ++++++- ql-runtime/src/driver/mod.rs | 6 +++--- ql-runtime/src/driver/test.rs | 4 ++-- ql-runtime/src/handle/reader.rs | 2 +- ql-runtime/src/handle/writer.rs | 2 +- ql-runtime/src/rpc/mod.rs | 2 +- ql-runtime/src/tests/stream.rs | 4 ++-- ql-wire/src/encrypted/stream_close.rs | 13 +++++++++++++ ql-wire/src/tests.rs | 2 +- 11 files changed, 38 insertions(+), 20 deletions(-) diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 017bd91d..ccb7da0d 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -300,7 +300,7 @@ fn remote_stream_close_is_reliable_and_retried() { fsm.stream(stream_id) .unwrap() - .close(CloseTarget::Both, StreamCloseCode(0)); + .close(CloseTarget::Both, StreamCloseCode::CANCELLED); let (write_id, builder) = fsm.take_next_write(now).unwrap(); fsm.complete_write(now, write_id.expect("stream close should be tracked"), true); @@ -491,12 +491,12 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { let close3 = vec![SessionFrame::StreamClose(StreamClose { stream_id: stream_id(3), target: CloseTarget::Both, - code: StreamCloseCode(1), + code: StreamCloseCode::REFUSED, })]; let close1 = vec![SessionFrame::StreamClose(StreamClose { stream_id: stream_id(1), target: CloseTarget::Both, - code: StreamCloseCode(2), + code: StreamCloseCode::TIMEOUT, })]; let first = receive_events(&mut fsm, now, seq(1), &close3); @@ -506,12 +506,12 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { SessionEvent::Closed(StreamClose { stream_id: stream_id(3), target: CloseTarget::Both, - code: StreamCloseCode(1), + code: StreamCloseCode::REFUSED, }), SessionEvent::WritableClosed(StreamClose { stream_id: stream_id(3), target: CloseTarget::Both, - code: StreamCloseCode(1), + code: StreamCloseCode::REFUSED, }), ] ); @@ -523,12 +523,12 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { SessionEvent::Closed(StreamClose { stream_id: stream_id(1), target: CloseTarget::Both, - code: StreamCloseCode(2), + code: StreamCloseCode::TIMEOUT, }), SessionEvent::WritableClosed(StreamClose { stream_id: stream_id(1), target: CloseTarget::Both, - code: StreamCloseCode(2), + code: StreamCloseCode::TIMEOUT, }), ] ); diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index fc6d702b..dbffebb5 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -334,7 +334,7 @@ impl Runner { let closed = if let Ok(mut stream) = self.harness.node_mut(*side).fsm.stream(stream_id) { - stream.close(CloseTarget::Both, StreamCloseCode(0)); + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); true } else { false diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 36faee20..8fd9b9e6 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -208,7 +208,12 @@ fn disconnected_stream_operations_fail_with_no_session() { .a .fsm .stream(missing) - .map(|mut stream| stream.close(ql_wire::CloseTarget::Both, ql_wire::StreamCloseCode(0))), + .map(|mut stream| { + stream.close( + ql_wire::CloseTarget::Both, + ql_wire::StreamCloseCode::CANCELLED, + ) + }), Err(StreamError::NoSession) ); assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index a070dedd..a46059d3 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -192,7 +192,7 @@ impl DriverState { stream.inbound_close(); stream.outbound_close(); } - stream_ops.close(CloseTarget::Both, StreamCloseCode(0)); + stream_ops.close(CloseTarget::Both, StreamCloseCode::CANCELLED); drop(stream_ops); return; } @@ -285,7 +285,7 @@ impl DriverState { ) { let Some(runtime_tx) = self.runtime_tx.upgrade() else { if let Ok(mut stream) = fsm.stream(stream_id) { - stream.close(CloseTarget::Both, StreamCloseCode(0)); + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); } return; }; @@ -362,7 +362,7 @@ impl DriverState { stream_ops.commit_read(accepted).unwrap(); } if peer_closed { - stream_ops.close(target, StreamCloseCode(0)); + stream_ops.close(target, StreamCloseCode::CANCELLED); if let Entry::Occupied(entry) = self.streams.entry(stream_id) { Self::try_reap_stream(entry); } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 644091c7..9d4eacea 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -99,7 +99,7 @@ fn handle_closed_stream_reaps_when_both_halves_close() { state.handle_closed_stream(&StreamClose { stream_id, target: CloseTarget::Both, - code: StreamCloseCode(0), + code: StreamCloseCode::CANCELLED, }); assert!(!state.streams.contains_key(&stream_id)); @@ -148,7 +148,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { RuntimeCommand::CloseStream { stream_id, target: CloseTarget::Origin, - code: StreamCloseCode(0), + code: StreamCloseCode::CANCELLED, }, &NoopCrypto, ); diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index cec47ae5..504fdb4e 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -148,7 +148,7 @@ impl Drop for ByteReader { self.handle.send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, - code: StreamCloseCode(0), + code: StreamCloseCode::CANCELLED, }); } } diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index c0b8b333..ff1375a5 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -79,7 +79,7 @@ impl ByteWriter { impl Drop for ByteWriter { fn drop(&mut self) { - self.close_inner(StreamCloseCode(0)); + self.close_inner(StreamCloseCode::CANCELLED); } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index f2cf57cc..51b9c58c 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -31,7 +31,7 @@ impl RpcHandle { notification::encode_event::(event, &mut payload).map_err(RpcError::Codec)?; let route_id = RouteId(VarInt::from_u32(M::METHOD.0)); let mut stream = self.inner.open_stream(route_id).await?; - stream.reader.close(ql_wire::StreamCloseCode(0)); + stream.reader.close(ql_wire::StreamCloseCode::CANCELLED); stream.writer.write(Bytes::from(payload)).await?; Ok(()) } diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 85f63c8c..a6f960d3 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -181,7 +181,7 @@ async fn dropping_responder_closes_initiator_response() { let err = next_chunk(&mut stream.reader).await.unwrap_err(); assert!(matches!( err, - QlStreamError::StreamClosed { code } if code == StreamCloseCode(0) + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED )); tokio::time::timeout(Duration::from_secs(2), responder) @@ -257,7 +257,7 @@ async fn closing_initiator_reader_preserves_initiator_writer() { .await .unwrap(); let mut writer = stream.writer; - stream.reader.close(StreamCloseCode(0)); + stream.reader.close(StreamCloseCode::CANCELLED); writer.write(Bytes::from_static(&[1, 2])).await.unwrap(); writer.write(Bytes::from_static(&[3, 4])).await.unwrap(); diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 2fc03eb9..66721b01 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -88,6 +88,19 @@ impl codec::WireDecode for CloseTarget { #[repr(transparent)] pub struct StreamCloseCode(pub u16); +impl StreamCloseCode { + /// the stream was aborted intentionally before graceful completion + pub const CANCELLED: Self = Self(0); + /// the peer declined to service the stream + pub const REFUSED: Self = Self(1); + /// the stream was aborted because progress took too long + pub const TIMEOUT: Self = Self(2); + /// the stream exceeded a size, quota, or buffering limit + pub const LIMIT: Self = Self(3); + /// the stream route was not recognized by the peer + pub const UNKNOWN_ROUTE: Self = Self(4); +} + impl codec::WireDecode for StreamCloseCode { fn decode(reader: &mut codec::Reader) -> Result { Ok(Self(reader.decode()?)) diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 713f3189..52184a0e 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -670,7 +670,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { SessionFrame::StreamClose(StreamClose { stream_id: stream_id(9), target: CloseTarget::Both, - code: StreamCloseCode(0), + code: StreamCloseCode::CANCELLED, }), SessionFrame::Close(SessionClose { code: SessionCloseCode::TIMEOUT, From 2a02eff4c3ad384ba5c25fec1416b7efb7846ed8 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sun, 12 Apr 2026 09:44:31 -0400 Subject: [PATCH 197/304] ql-rpc: router --- ql-fsm/src/session/tests.rs | 17 +-- ql-fsm/src/tests/proptest.rs | 2 +- ql-fsm/src/tests/session.rs | 18 ++- ql-rpc/src/codec.rs | 24 ++-- ql-rpc/src/lib.rs | 41 +++++-- ql-rpc/src/router/builder.rs | 62 +++++++++++ ql-rpc/src/router/config.rs | 12 ++ ql-rpc/src/router/mod.rs | 45 ++++++++ ql-rpc/src/router/request.rs | 110 ++++++++++++++++++ ql-rpc/src/router/stream.rs | 59 ++++++++++ ql-rpc/src/rpc/notification.rs | 12 +- ql-rpc/src/rpc/request.rs | 19 +--- ql-rpc/src/rpc/request_with_progress.rs | 31 +++--- ql-rpc/src/rpc/subscription.rs | 25 ++--- ql-runtime/src/chunk_slot.rs | 56 ++++++---- ql-runtime/src/handle/writer.rs | 71 +++++++++--- ql-runtime/src/rpc/adapter.rs | 73 ++++++++++++ ql-runtime/src/rpc/error.rs | 5 +- ql-runtime/src/rpc/mod.rs | 28 +++-- ql-runtime/src/tests/mod.rs | 4 +- ql-runtime/src/tests/rpc.rs | 141 +++++++++++++++++++----- ql-wire/src/encrypted/route_id.rs | 22 +++- ql-wire/src/encrypted/stream_close.rs | 8 -- 23 files changed, 696 insertions(+), 189 deletions(-) create mode 100644 ql-rpc/src/router/builder.rs create mode 100644 ql-rpc/src/router/config.rs create mode 100644 ql-rpc/src/router/mod.rs create mode 100644 ql-rpc/src/router/request.rs create mode 100644 ql-rpc/src/router/stream.rs create mode 100644 ql-runtime/src/rpc/adapter.rs diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index ccb7da0d..ad700276 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -23,9 +23,12 @@ fn offset(value: u64) -> VarInt { } fn route_id(value: u64) -> RouteId { - RouteId(VarInt::from_u64(value).unwrap()) + RouteId::from_u64(value).unwrap() } +const REFUSED: StreamCloseCode = StreamCloseCode(1); +const TIMEOUT: StreamCloseCode = StreamCloseCode(2); + fn header(value: u64) -> Option { Some(StreamHeader { route_id: route_id(value), @@ -491,12 +494,12 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { let close3 = vec![SessionFrame::StreamClose(StreamClose { stream_id: stream_id(3), target: CloseTarget::Both, - code: StreamCloseCode::REFUSED, + code: REFUSED, })]; let close1 = vec![SessionFrame::StreamClose(StreamClose { stream_id: stream_id(1), target: CloseTarget::Both, - code: StreamCloseCode::TIMEOUT, + code: TIMEOUT, })]; let first = receive_events(&mut fsm, now, seq(1), &close3); @@ -506,12 +509,12 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { SessionEvent::Closed(StreamClose { stream_id: stream_id(3), target: CloseTarget::Both, - code: StreamCloseCode::REFUSED, + code: REFUSED, }), SessionEvent::WritableClosed(StreamClose { stream_id: stream_id(3), target: CloseTarget::Both, - code: StreamCloseCode::REFUSED, + code: REFUSED, }), ] ); @@ -523,12 +526,12 @@ fn out_of_order_remote_stream_first_observations_still_open_once_each() { SessionEvent::Closed(StreamClose { stream_id: stream_id(1), target: CloseTarget::Both, - code: StreamCloseCode::TIMEOUT, + code: TIMEOUT, }), SessionEvent::WritableClosed(StreamClose { stream_id: stream_id(1), target: CloseTarget::Both, - code: StreamCloseCode::TIMEOUT, + code: TIMEOUT, }), ] ); diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index dbffebb5..c06002d9 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -12,7 +12,7 @@ use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use super::*; fn test_route_id() -> ql_wire::RouteId { - ql_wire::RouteId(ql_wire::VarInt::from_u32(1)) + ql_wire::RouteId::from_u32(1) } use crate::{state::LinkState, Event, PeerStatus, ReceiveError, WriteId}; diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 8fd9b9e6..793409b2 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -11,7 +11,7 @@ fn stream_id(value: u32) -> StreamId { } fn route_id(value: u32) -> RouteId { - RouteId(VarInt::from_u32(value)) + RouteId::from_u32(value) } fn opened(stream_id: StreamId) -> Event { @@ -204,16 +204,12 @@ fn disconnected_stream_operations_fail_with_no_session() { Err(StreamError::NoSession) ); assert_eq!( - harness - .a - .fsm - .stream(missing) - .map(|mut stream| { - stream.close( - ql_wire::CloseTarget::Both, - ql_wire::StreamCloseCode::CANCELLED, - ) - }), + harness.a.fsm.stream(missing).map(|mut stream| { + stream.close( + ql_wire::CloseTarget::Both, + ql_wire::StreamCloseCode::CANCELLED, + ) + }), Err(StreamError::NoSession) ); assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index 56306e44..dab369c6 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -2,18 +2,21 @@ use std::{collections::VecDeque, marker::PhantomData}; use bytes::{Buf, BufMut, Bytes}; -use crate::{CodecError, Error, RpcCodec}; +use crate::{CodecError, Error}; + +pub trait RpcCodec: Sized { + type Error; + + fn encode_value(&self, out: &mut B); + fn decode_value(bytes: &mut B) -> Result; +} const LENGTH_SIZE: usize = 8; -pub fn encode_value_part>( - value: &T, - out: &mut B, -) -> Result<(), T::Error> { +pub fn encode_value_part>(value: &T, out: &mut B) { let payload_start = reserve_length(out); - value.encode_value(out)?; + value.encode_value(out); backpatch_length(out, payload_start); - Ok(()) } pub enum ReadValueStep { @@ -282,9 +285,8 @@ mod tests { impl RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + fn encode_value(&self, out: &mut B) { out.put_slice(&self.0); - Ok(()) } fn decode_value(bytes: &mut B) -> Result { @@ -295,7 +297,7 @@ mod tests { #[test] fn value_reader_round_trips_framed_values() { let mut encoded = Vec::new(); - encode_value_part(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); + encode_value_part(&BytesValue(b"hello".to_vec()), &mut encoded); match ValueReader::::new() .push(Bytes::from(encoded)) @@ -310,7 +312,7 @@ mod tests { #[test] fn value_reader_waits_for_complete_frame() { let mut encoded = Vec::new(); - encode_value_part(&BytesValue(b"hello".to_vec()), &mut encoded).unwrap(); + encode_value_part(&BytesValue(b"hello".to_vec()), &mut encoded); let encoded = Bytes::from(encoded); let reader = match ValueReader::::new() diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index ef27f843..e5277e08 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -1,22 +1,47 @@ //! quantum link rpc protocol traits and framing helpers. -use bytes::{Buf, BufMut}; - pub(crate) mod codec; mod error; +mod router; pub mod rpc; -pub use codec::{ReadValueStep, ValueReader}; +pub use codec::{ReadValueStep, RpcCodec, ValueReader}; pub use error::*; +pub use router::*; pub use rpc::*; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct MethodId(pub u32); +pub struct RouteId(pub u32); + +impl RouteId { + pub const fn from_u32(value: u32) -> Self { + Self(value) + } + + pub const fn into_inner(self) -> u32 { + self.0 + } +} + +impl From for RouteId { + fn from(value: u32) -> Self { + Self::from_u32(value) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct StreamCloseCode(pub u16); -pub trait RpcCodec: Sized { - type Error; +impl StreamCloseCode { + pub const CANCELLED: Self = Self(0); + pub const REFUSED: Self = Self(1); + pub const TIMEOUT: Self = Self(2); + pub const LIMIT: Self = Self(3); + pub const UNKNOWN_ROUTE: Self = Self(4); - fn encode_value(&self, out: &mut B) -> Result<(), Self::Error>; - fn decode_value(bytes: &mut B) -> Result; + pub const fn into_inner(self) -> u16 { + self.0 + } } diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs new file mode 100644 index 00000000..6a9b8ba0 --- /dev/null +++ b/ql-rpc/src/router/builder.rs @@ -0,0 +1,62 @@ +use std::collections::HashMap; + +use super::{ + request::{handle_request, RequestHandler}, + Router, RouterConfig, RpcStream, +}; +use crate::{request::Request as RequestRpc, router::RouteFn, RouteId}; + +pub struct RouterBuilder { + config: RouterConfig, + routes: HashMap>, +} + +impl Default for RouterBuilder { + fn default() -> Self { + Self::new() + } +} + +impl RouterBuilder { + pub fn new() -> Self { + Self { + config: RouterConfig::default(), + routes: std::collections::HashMap::new(), + } + } + + pub fn config(mut self, config: RouterConfig) -> Self { + self.config = config; + self + } + + pub fn max_request_bytes(mut self, max_request_bytes: usize) -> Self { + self.config.max_request_bytes = max_request_bytes; + self + } + + pub fn request(self) -> Self + where + M: RequestRpc, + S: RequestHandler, + St: RpcStream + 'static, + { + self.add_route(M::METHOD, handle_request::) + } + + pub fn build(mut self, state: S) -> Router { + self.routes.shrink_to_fit(); + Router { + config: self.config, + state, + routes: self.routes, + } + } + + fn add_route(mut self, route_id: crate::RouteId, route: super::RouteFn) -> Self { + if self.routes.insert(route_id, route).is_some() { + panic!("duplicate rpc route {}", route_id.into_inner()); + } + self + } +} diff --git a/ql-rpc/src/router/config.rs b/ql-rpc/src/router/config.rs new file mode 100644 index 00000000..d6fb048f --- /dev/null +++ b/ql-rpc/src/router/config.rs @@ -0,0 +1,12 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RouterConfig { + pub max_request_bytes: usize, +} + +impl Default for RouterConfig { + fn default() -> Self { + Self { + max_request_bytes: usize::MAX, + } + } +} diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs new file mode 100644 index 00000000..9dbec217 --- /dev/null +++ b/ql-rpc/src/router/mod.rs @@ -0,0 +1,45 @@ +use std::{collections::HashMap, future::Future, pin::Pin}; + +use crate::{RouteId, StreamCloseCode}; + +mod builder; +mod config; +mod request; +mod stream; + +pub use self::{ + builder::RouterBuilder, + config::RouterConfig, + request::RequestHandler, + stream::{RpcRead, RpcStream, RpcWrite}, +}; + +type RouteFuture<'a> = Pin + 'a>>; +type RouteFn = for<'a> fn(&'a S, RouterConfig, St) -> RouteFuture<'a>; + +pub struct Router { + config: RouterConfig, + state: S, + routes: HashMap>, +} + +impl Router +where + St: RpcStream, +{ + pub fn builder() -> RouterBuilder { + RouterBuilder::::new() + } + + pub async fn handle(&self, stream: St) { + let Some(route_id) = stream.route_id() else { + stream::close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); + return; + }; + let Some(route) = self.routes.get(&route_id).copied() else { + stream::close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); + return; + }; + route(&self.state, self.config, stream).await; + } +} diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/router/request.rs new file mode 100644 index 00000000..07889442 --- /dev/null +++ b/ql-rpc/src/router/request.rs @@ -0,0 +1,110 @@ +use std::future::Future; + +use bytes::Bytes; + +use super::{ + stream::{read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite}, + RouteFuture, RouterConfig, +}; +use crate::{ + request::{self, Request as RequestRpc}, + ReadValueStep, RpcCodec, StreamCloseCode, ValueReader, +}; + +pub trait RequestHandler +where + M: RequestRpc, +{ + type Future<'a>: Future> + 'a + where + Self: 'a; + + fn handle<'a>(&'a self, request: M::Request) -> Self::Future<'a>; +} + +pub(super) fn handle_request( + state: &S, + config: RouterConfig, + stream: St, +) -> RouteFuture<'_> +where + M: RequestRpc, + S: RequestHandler, + St: RpcStream + 'static, +{ + Box::pin(async move { + let (mut reader, mut writer) = stream.split(); + + let request = match read_value_and_eof::(&mut reader, config).await { + Ok(request) => request, + Err(code) => { + reader.close(code); + writer.close(code); + return; + } + }; + + let response = match state.handle(request).await { + Ok(response) => response, + Err(code) => { + writer.close(code); + return; + } + }; + + let mut encoded = Vec::new(); + request::encode_response::(&response, &mut encoded); + + if write_bytes(&mut writer, Bytes::from(encoded)) + .await + .is_err() + { + return; + } + writer.finish(); + }) +} + +async fn read_value_and_eof( + reader: &mut R, + config: RouterConfig, +) -> Result +where + T: RpcCodec, + R: RpcRead, +{ + let mut value_reader = ValueReader::::new(); + let mut total_read = 0usize; + + let value = loop { + match value_reader.advance() { + Ok(ReadValueStep::Value(value)) => break value, + Ok(ReadValueStep::NeedMore(next)) => value_reader = next, + Err(crate::CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED), + Err(crate::CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED), + } + + let remaining = config.max_request_bytes.saturating_sub(total_read); + if remaining == 0 { + return Err(StreamCloseCode::LIMIT); + } + + match read_bytes(reader, remaining).await { + Ok(Some(chunk)) => { + total_read += chunk.len(); + value_reader = value_reader.push(chunk); + } + Ok(None) => return Err(StreamCloseCode::REFUSED), + Err(code) => return Err(code), + } + }; + + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(None) => Ok(value), + Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT), + Ok(Some(_)) => Err(StreamCloseCode::REFUSED), + Err(code) => Err(code), + } +} diff --git a/ql-rpc/src/router/stream.rs b/ql-rpc/src/router/stream.rs new file mode 100644 index 00000000..dcaabbf8 --- /dev/null +++ b/ql-rpc/src/router/stream.rs @@ -0,0 +1,59 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use crate::{RouteId, StreamCloseCode}; + +pub trait RpcStream { + type Reader: RpcRead; + type Writer: RpcWrite; + + fn route_id(&self) -> Option; + fn split(self) -> (Self::Reader, Self::Writer); +} + +pub trait RpcRead { + fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, StreamCloseCode>>; + fn close(self, code: StreamCloseCode); +} + +pub trait RpcWrite { + fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll>; + fn finish(self); + fn close(self, code: StreamCloseCode); +} + +pub async fn read_bytes(reader: &mut R, max_len: usize) -> Result, StreamCloseCode> +where + R: RpcRead, +{ + poll_fn(|cx| reader.poll_read(max_len, cx)).await +} + +pub async fn write_bytes(writer: &mut W, bytes: Bytes) -> Result<(), StreamCloseCode> +where + W: RpcWrite, +{ + let mut bytes = bytes; + poll_fn(|cx| writer.poll_write(&mut bytes, cx)).await +} + +pub fn close_stream(stream: St, code: StreamCloseCode) +where + St: RpcStream, +{ + let (reader, writer) = stream.split(); + reader.close(code); + writer.close(code); +} diff --git a/ql-rpc/src/rpc/notification.rs b/ql-rpc/src/rpc/notification.rs index 0e09a8fc..8bc3e069 100644 --- a/ql-rpc/src/rpc/notification.rs +++ b/ql-rpc/src/rpc/notification.rs @@ -1,19 +1,13 @@ use bytes::BufMut; -use crate::{codec, MethodId, ReadValueStep, RpcCodec, ValueReader}; +use crate::{codec, RouteId, RpcCodec}; pub trait Notification { - const METHOD: MethodId; + const METHOD: RouteId; type Error; type Event: RpcCodec; } -pub type EventReader = ValueReader<::Event>; -pub type EventReadStep = ReadValueStep<::Event>; - -pub fn encode_event( - event: &M::Event, - out: &mut (impl BufMut + AsMut<[u8]>), -) -> Result<(), M::Error> { +pub fn encode_event(event: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { codec::encode_value_part(event, out) } diff --git a/ql-rpc/src/rpc/request.rs b/ql-rpc/src/rpc/request.rs index aae44664..82a4f180 100644 --- a/ql-rpc/src/rpc/request.rs +++ b/ql-rpc/src/rpc/request.rs @@ -1,29 +1,18 @@ use bytes::BufMut; -use crate::{codec, MethodId, ReadValueStep, RpcCodec, ValueReader}; +use crate::{codec, RouteId, RpcCodec}; pub trait Request { - const METHOD: MethodId; + const METHOD: RouteId; type Error; type Request: RpcCodec; type Response: RpcCodec; } -pub type RequestReader = ValueReader<::Request>; -pub type RequestReadStep = ReadValueStep<::Request>; -pub type ResponseReader = ValueReader<::Response>; -pub type ResponseReadStep = ReadValueStep<::Response>; - -pub fn encode_request( - request: &M::Request, - out: &mut (impl BufMut + AsMut<[u8]>), -) -> Result<(), M::Error> { +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { codec::encode_value_part(request, out) } -pub fn encode_response( - response: &M::Response, - out: &mut (impl BufMut + AsMut<[u8]>), -) -> Result<(), M::Error> { +pub fn encode_response(response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>)) { codec::encode_value_part(response, out) } diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index 9617c349..159271a4 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -2,19 +2,16 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{codec, CodecError, Error, MethodId, ReadValueStep, RpcCodec, ValueReader}; +use crate::{codec, CodecError, Error, RouteId, RpcCodec}; pub trait RequestWithProgress { - const METHOD: MethodId; + const METHOD: RouteId; type Error; type Request: RpcCodec; type Progress: RpcCodec; type Response: RpcCodec; } -pub type RequestReader = ValueReader<::Request>; -pub type RequestReadStep = ReadValueStep<::Request>; - pub enum ReadStep { NeedMore(ResponseReader), Progress { @@ -89,21 +86,21 @@ enum FrameKind { pub fn encode_request( request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>), -) -> Result<(), M::Error> { +) { codec::encode_value_part(request, out) } pub fn encode_progress( progress: &M::Progress, out: &mut (impl BufMut + AsMut<[u8]>), -) -> Result<(), M::Error> { +) { encode_tagged_value_part(FrameKind::Progress, progress, out) } pub fn encode_response( response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>), -) -> Result<(), M::Error> { +) { encode_tagged_value_part(FrameKind::Response, response, out) } @@ -111,12 +108,11 @@ fn encode_tagged_value_part>( kind: FrameKind, value: &T, out: &mut B, -) -> Result<(), T::Error> { +) { out.put_u8(kind as u8); let payload_start = codec::reserve_length(out); - value.encode_value(out)?; + value.encode_value(out); codec::backpatch_length(out, payload_start); - Ok(()) } #[cfg(test)] @@ -124,7 +120,7 @@ mod tests { use bytes::{Buf, BufMut, Bytes}; use super::{encode_progress, encode_response, ReadStep, RequestWithProgress, ResponseReader}; - use crate::{MethodId, RpcCodec}; + use crate::{RouteId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -132,9 +128,8 @@ mod tests { impl RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + fn encode_value(&self, out: &mut B) { out.put_slice(&self.0); - Ok(()) } fn decode_value(bytes: &mut B) -> Result { @@ -145,7 +140,7 @@ mod tests { struct Watch; impl RequestWithProgress for Watch { - const METHOD: MethodId = MethodId(11); + const METHOD: RouteId = RouteId::from_u32(11); type Error = core::convert::Infallible; type Request = BytesValue; type Progress = BytesValue; @@ -155,8 +150,8 @@ mod tests { #[test] fn response_reader_emits_progress_then_response() { let mut encoded = Vec::new(); - encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded).unwrap(); - encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); + encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded); + encode_response::(&BytesValue(b"done".to_vec()), &mut encoded); let reader = match ResponseReader::::new() .push(Bytes::from(encoded)) @@ -178,7 +173,7 @@ mod tests { #[test] fn response_reader_handles_response_only() { let mut encoded = Vec::new(); - encode_response::(&BytesValue(b"done".to_vec()), &mut encoded).unwrap(); + encode_response::(&BytesValue(b"done".to_vec()), &mut encoded); match ResponseReader::::new() .push(Bytes::from(encoded)) diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index e87439f8..460012cf 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -2,18 +2,15 @@ use std::marker::PhantomData; use bytes::{Buf, BufMut, Bytes}; -use crate::{codec, CodecError, Error, MethodId, ReadValueStep, RpcCodec, ValueReader}; +use crate::{codec, CodecError, Error, RouteId, RpcCodec}; pub trait Subscription { - const METHOD: MethodId; + const METHOD: RouteId; type Error; type Request: RpcCodec; type Event: RpcCodec; } -pub type RequestReader = ValueReader<::Request>; -pub type RequestReadStep = ReadValueStep<::Request>; - pub enum ReadStep { NeedMore(ResponseReader), Item { @@ -76,14 +73,11 @@ impl ResponseReader { pub fn encode_request( request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>), -) -> Result<(), M::Error> { +) { codec::encode_value_part(request, out) } -pub fn encode_item( - item: &M::Event, - out: &mut (impl BufMut + AsMut<[u8]>), -) -> Result<(), ::Error> { +pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { codec::encode_value_part(item, out) } @@ -96,7 +90,7 @@ mod tests { use bytes::{Buf, BufMut, Bytes}; use super::{encode_end, encode_item, ReadStep, ResponseReader, Subscription}; - use crate::{MethodId, RpcCodec}; + use crate::{RouteId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -104,9 +98,8 @@ mod tests { impl RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + fn encode_value(&self, out: &mut B) { out.put_slice(&self.0); - Ok(()) } fn decode_value(bytes: &mut B) -> Result { @@ -117,7 +110,7 @@ mod tests { struct Feed; impl Subscription for Feed { - const METHOD: MethodId = MethodId(17); + const METHOD: RouteId = RouteId::from_u32(17); type Error = core::convert::Infallible; type Request = BytesValue; type Event = BytesValue; @@ -126,8 +119,8 @@ mod tests { #[test] fn response_reader_streams_items_until_end() { let mut encoded = Vec::new(); - encode_item::(&BytesValue(b"one".to_vec()), &mut encoded).unwrap(); - encode_item::(&BytesValue(b"two".to_vec()), &mut encoded).unwrap(); + encode_item::(&BytesValue(b"one".to_vec()), &mut encoded); + encode_item::(&BytesValue(b"two".to_vec()), &mut encoded); encode_end(&mut encoded); let reader = match ResponseReader::::new() diff --git a/ql-runtime/src/chunk_slot.rs b/ql-runtime/src/chunk_slot.rs index 7d1ba990..d536bc33 100644 --- a/ql-runtime/src/chunk_slot.rs +++ b/ql-runtime/src/chunk_slot.rs @@ -121,10 +121,39 @@ impl ChunkSlotTx { self.inner.try_send(bytes) } + pub fn poll_send( + &self, + bytes: &mut Bytes, + listener: &mut Option, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + let chunk = std::mem::take(bytes); + + match self.try_send(chunk) { + Ok(()) => return Poll::Ready(Ok(())), + Err(TrySendError::Closed(chunk)) => { + *bytes = chunk.clone(); + return Poll::Ready(Err(SendClosed(chunk))); + } + Err(TrySendError::Full(chunk)) => *bytes = chunk, + } + + if let Some(active_listener) = listener.as_mut() { + match Pin::new(active_listener).poll(cx) { + Poll::Ready(()) => *listener = None, + Poll::Pending => return Poll::Pending, + } + } else { + *listener = Some(self.inner.changed.listen()); + } + } + } + pub fn send(&self, bytes: Bytes) -> Send<'_> { Send { tx: self, - bytes: Some(bytes), + bytes, listener: None, } } @@ -160,7 +189,7 @@ impl Future for Recv<'_> { pub struct Send<'a> { tx: &'a ChunkSlotTx, - bytes: Option, + bytes: Bytes, listener: Option, } @@ -168,27 +197,8 @@ impl Future for Send<'_> { type Output = Result<(), SendClosed>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let bytes = self - .bytes - .take() - .expect("send future polled after completion"); - - match self.tx.try_send(bytes) { - Ok(()) => return Poll::Ready(Ok(())), - Err(TrySendError::Closed(bytes)) => return Poll::Ready(Err(SendClosed(bytes))), - Err(TrySendError::Full(bytes)) => self.bytes = Some(bytes), - } - - if let Some(listener) = self.listener.as_mut() { - match Pin::new(listener).poll(cx) { - Poll::Ready(()) => self.listener = None, - Poll::Pending => return Poll::Pending, - } - } else { - self.listener = Some(self.tx.inner.changed.listen()); - } - } + let this = self.as_mut().get_mut(); + this.tx.poll_send(&mut this.bytes, &mut this.listener, cx) } } diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index ff1375a5..7ec6ac9d 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -1,12 +1,24 @@ +use std::{ + future::{poll_fn, Future}, + pin::Pin, + task::{Context, Poll}, +}; + use bytes::Bytes; +use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotTx, command::RuntimeCommand, QlStreamError, RuntimeHandle}; +use crate::{ + chunk_slot::{ChunkSlotTx, SendClosed}, + command::RuntimeCommand, + QlStreamError, RuntimeHandle, +}; pub struct ByteWriter { stream_id: StreamId, target: CloseTarget, writer: Option, + listener: Option, terminal: WriteTerminalState, handle: RuntimeHandle, } @@ -38,6 +50,7 @@ impl ByteWriter { stream_id, target, writer: Some(writer), + listener: None, terminal: WriteTerminalState::Armed(terminal), handle, } @@ -49,19 +62,37 @@ impl ByteWriter { }); } - pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + pub fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll> { if bytes.is_empty() { - return Ok(()); + return Poll::Ready(Ok(())); } + let Some(writer) = self.writer.as_ref() else { - return Err(self.terminal_error().await); + return self.poll_terminal_error(cx).map(Err); }; - if writer.send(bytes).await.is_err() { - self.writer.take(); - return Err(self.terminal_error().await); + + match writer.poll_send(bytes, &mut self.listener, cx) { + Poll::Ready(Ok(())) => { + self.listener = None; + self.poll_runtime(); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(SendClosed(_bytes))) => { + self.writer.take(); + self.listener = None; + self.poll_terminal_error(cx).map(Err) + } + Poll::Pending => Poll::Pending, } - self.poll_runtime(); - Ok(()) + } + + pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + let mut bytes = bytes; + poll_fn(|cx| self.poll_write(&mut bytes, cx)).await } pub fn finish(mut self) { @@ -84,16 +115,19 @@ impl Drop for ByteWriter { } impl ByteWriter { - async fn terminal_error(&mut self) -> QlStreamError { + fn poll_terminal_error(&mut self, cx: &mut Context<'_>) -> Poll { match &mut self.terminal { - WriteTerminalState::Terminal(error) => error.clone(), - WriteTerminalState::Armed(receiver) => { - let error = receiver - .await - .expect("byte writer terminal dropped before sending a terminal state"); - self.terminal = WriteTerminalState::Terminal(error.clone()); - error - } + WriteTerminalState::Terminal(error) => Poll::Ready(error.clone()), + WriteTerminalState::Armed(receiver) => match Pin::new(receiver).poll(cx) { + Poll::Ready(Ok(error)) => { + self.terminal = WriteTerminalState::Terminal(error.clone()); + Poll::Ready(error) + } + Poll::Ready(Err(_)) => { + panic!("byte writer terminal dropped before sending a terminal state") + } + Poll::Pending => Poll::Pending, + }, } } @@ -101,6 +135,7 @@ impl ByteWriter { if self.writer.take().is_none() { return; } + self.listener = None; self.handle.send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs new file mode 100644 index 00000000..2b3235b8 --- /dev/null +++ b/ql-runtime/src/rpc/adapter.rs @@ -0,0 +1,73 @@ +use std::task::{Context, Poll}; + +use bytes::Bytes; +pub use ql_rpc::{RequestHandler, RouteId, RouterConfig, StreamCloseCode}; +use ql_rpc::{RpcRead, RpcStream, RpcWrite}; +use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; + +use crate::{ByteReader, ByteWriter, QlStream, QlStreamError}; + +pub type Router = ql_rpc::Router; +pub type RouterBuilder = ql_rpc::RouterBuilder; + +impl RpcStream for QlStream { + type Reader = ByteReader; + type Writer = ByteWriter; + + fn route_id(&self) -> Option { + let route_id = u32::try_from(self.route_id.into_inner()).ok()?; + Some(RouteId::from_u32(route_id)) + } + + fn split(self) -> (Self::Reader, Self::Writer) { + (self.reader, self.writer) + } +} + +impl RpcRead for ByteReader { + fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, StreamCloseCode>> { + ByteReader::poll_read(self, max_len, cx).map(|result| result.map_err(from_stream_error)) + } + + fn close(self, code: StreamCloseCode) { + ByteReader::close(self, to_wire_close_code(code)); + } +} + +impl RpcWrite for ByteWriter { + fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll> { + ByteWriter::poll_write(self, bytes, cx).map(|result| result.map_err(from_stream_error)) + } + + fn finish(self) { + ByteWriter::finish(self); + } + + fn close(self, code: StreamCloseCode) { + ByteWriter::close(self, to_wire_close_code(code)); + } +} + +pub(super) fn to_wire_route_id(route_id: RouteId) -> WireRouteId { + WireRouteId::from_u32(route_id.into_inner()) +} + +pub(super) fn to_wire_close_code(code: StreamCloseCode) -> WireStreamCloseCode { + WireStreamCloseCode(code.into_inner()) +} + +fn from_stream_error(error: QlStreamError) -> StreamCloseCode { + let code = match error { + QlStreamError::StreamClosed { code } => code, + QlStreamError::NoSession => WireStreamCloseCode::CANCELLED, + }; + StreamCloseCode(code.0) +} diff --git a/ql-runtime/src/rpc/error.rs b/ql-runtime/src/rpc/error.rs index 9c30ad07..1e3e03e9 100644 --- a/ql-runtime/src/rpc/error.rs +++ b/ql-runtime/src/rpc/error.rs @@ -1,12 +1,11 @@ use ql_fsm::NoSessionError; -use ql_wire::StreamCloseCode; use crate::QlStreamError; #[derive(Debug)] pub enum RpcError { NoSession, - Closed(StreamCloseCode), + Closed(ql_rpc::StreamCloseCode), Protocol(ql_rpc::Error), Codec(E), } @@ -20,7 +19,7 @@ impl From for RpcError { impl From for RpcError { fn from(error: QlStreamError) -> Self { match error { - QlStreamError::StreamClosed { code } => Self::Closed(code), + QlStreamError::StreamClosed { code } => Self::Closed(ql_rpc::StreamCloseCode(code.0)), QlStreamError::NoSession => Self::NoSession, } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 51b9c58c..576eb682 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -1,3 +1,4 @@ +mod adapter; mod error; mod request_with_progress; mod subscription; @@ -12,9 +13,8 @@ use ql_rpc::{ subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, Error, ReadValueStep, RpcCodec, ValueReader, }; -use ql_wire::{RouteId, VarInt}; -pub use self::{error::*, request_with_progress::*, subscription::*}; +pub use self::{adapter::*, error::*, request_with_progress::*, subscription::*}; use crate::{ByteReader, RuntimeHandle}; #[derive(Clone)] @@ -28,11 +28,14 @@ impl RpcHandle { M: Notification, { let mut payload = Vec::new(); - notification::encode_event::(event, &mut payload).map_err(RpcError::Codec)?; - let route_id = RouteId(VarInt::from_u32(M::METHOD.0)); - let mut stream = self.inner.open_stream(route_id).await?; + notification::encode_event::(event, &mut payload); + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::METHOD)) + .await?; stream.reader.close(ql_wire::StreamCloseCode::CANCELLED); stream.writer.write(Bytes::from(payload)).await?; + stream.writer.finish(); Ok(()) } @@ -41,7 +44,7 @@ impl RpcHandle { M: RequestRpc, { let mut payload = Vec::new(); - request::encode_request::(request, &mut payload).map_err(RpcError::Codec)?; + request::encode_request::(request, &mut payload); let response = self.start_request(M::METHOD, payload).await?; read_value::(response).await } @@ -54,7 +57,7 @@ impl RpcHandle { M: SubscriptionRpc, { let mut payload = Vec::new(); - rpc_subscription::encode_request::(request, &mut payload).map_err(RpcError::Codec)?; + rpc_subscription::encode_request::(request, &mut payload); let response = self.start_request(M::METHOD, payload).await?; Ok(Subscription { stream: response, @@ -70,8 +73,7 @@ impl RpcHandle { M: RequestWithProgress, { let mut payload = Vec::new(); - rpc_request_with_progress::encode_request::(request, &mut payload) - .map_err(RpcError::Codec)?; + rpc_request_with_progress::encode_request::(request, &mut payload); let response = self.start_request(M::METHOD, payload).await?; Ok(ProgressCall { stream: response, @@ -82,11 +84,13 @@ impl RpcHandle { async fn start_request( &self, - method: ql_rpc::MethodId, + route_id: ql_rpc::RouteId, payload: Vec, ) -> Result> { - let route_id = RouteId(VarInt::from_u32(method.0)); - let mut stream = self.inner.open_stream(route_id).await?; + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(route_id)) + .await?; stream.writer.write(Bytes::from(payload)).await?; stream.writer.finish(); Ok(stream.reader) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 3e8799a7..d23fd445 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -14,7 +14,7 @@ use ql_fsm::PeerStatus; use ql_wire::{ test_identities, test_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, - RecordType, RouteId, SessionKey, SoftwareCrypto, VarInt, WireDecode, XID, + RecordType, RouteId, SessionKey, SoftwareCrypto, WireDecode, XID, }; use tokio::{task::LocalSet, time::Sleep}; @@ -52,7 +52,7 @@ impl Side { } fn test_route_id() -> RouteId { - RouteId(VarInt::from_u32(1)) + RouteId::from_u32(1) } #[derive(Debug, Clone)] diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index b6e1a059..3161975f 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -1,8 +1,9 @@ -use std::time::Duration; +use std::{cell::RefCell, future::Ready, rc::Rc, time::Duration}; use bytes::{Buf, BufMut, Bytes}; use futures_lite::StreamExt; -use ql_wire::RouteId; +use ql_rpc::{RouteId, StreamCloseCode}; +use ql_wire::RouteId as WireRouteId; use super::*; @@ -12,9 +13,8 @@ struct BytesValue(Vec); impl ql_rpc::RpcCodec for BytesValue { type Error = core::convert::Infallible; - fn encode_value(&self, out: &mut B) -> Result<(), Self::Error> { + fn encode_value(&self, out: &mut B) { out.put_slice(&self.0); - Ok(()) } fn decode_value(bytes: &mut B) -> Result { @@ -25,7 +25,7 @@ impl ql_rpc::RpcCodec for BytesValue { struct Echo; impl ql_rpc::request::Request for Echo { - const METHOD: ql_rpc::MethodId = ql_rpc::MethodId(51); + const METHOD: RouteId = RouteId::from_u32(51); type Error = core::convert::Infallible; type Request = BytesValue; type Response = BytesValue; @@ -34,7 +34,7 @@ impl ql_rpc::request::Request for Echo { struct Feed; impl ql_rpc::subscription::Subscription for Feed { - const METHOD: ql_rpc::MethodId = ql_rpc::MethodId(52); + const METHOD: RouteId = RouteId::from_u32(52); type Error = core::convert::Infallible; type Request = BytesValue; type Event = BytesValue; @@ -43,7 +43,7 @@ impl ql_rpc::subscription::Subscription for Feed { struct Download; impl ql_rpc::request_with_progress::RequestWithProgress for Download { - const METHOD: ql_rpc::MethodId = ql_rpc::MethodId(53); + const METHOD: RouteId = RouteId::from_u32(53); type Error = core::convert::Infallible; type Request = BytesValue; type Progress = BytesValue; @@ -62,13 +62,12 @@ async fn rpc_request_round_trips() { let request: BytesValue = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, - route_id(::METHOD) + to_wire_route_id(::METHOD) ); assert_eq!(request, BytesValue(b"hello".to_vec())); let mut encoded = Vec::new(); - ql_rpc::request::encode_response::(&BytesValue(b"world".to_vec()), &mut encoded) - .unwrap(); + ql_rpc::request::encode_response::(&BytesValue(b"world".to_vec()), &mut encoded); let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); writer.finish(); @@ -89,6 +88,99 @@ async fn rpc_request_round_trips() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn rpc_router_handles_request() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl crate::rpc::RequestHandler for RouterState { + type Future<'a> + = Ready> + where + Self: 'a; + + fn handle<'a>(&'a self, request: BytesValue) -> Self::Future<'a> { + self.seen.borrow_mut().push(request.0); + std::future::ready(Ok(BytesValue(b"world".to_vec()))) + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + let router = crate::rpc::Router::builder() + .request::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + router.handle(inbound).await; + }); + + let rpc = pair.handle(Side::A).rpc(); + let response = rpc + .request::(&BytesValue(b"hello".to_vec())) + .await + .unwrap(); + assert_eq!(response, BytesValue(b"world".to_vec())); + assert_eq!(&*seen.borrow(), &[b"hello".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_router_enforces_max_request_bytes() { + struct LimitedState; + + impl crate::rpc::RequestHandler for LimitedState { + type Future<'a> + = Ready> + where + Self: 'a; + + fn handle<'a>(&'a self, request: BytesValue) -> Self::Future<'a> { + std::future::ready(Ok(request)) + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let router = crate::rpc::Router::builder() + .max_request_bytes(4) + .request::() + .build(LimitedState); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + router.handle(inbound).await; + }); + + let rpc = pair.handle(Side::A).rpc(); + let response = rpc.request::(&BytesValue(b"hello".to_vec())).await; + assert!(matches!( + response, + Err(crate::rpc::RpcError::Closed(code)) if code == StreamCloseCode::LIMIT + )); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn rpc_subscription_streams_events() { run_local_test(async { @@ -101,15 +193,13 @@ async fn rpc_subscription_streams_events() { let request: BytesValue = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, - route_id(::METHOD) + to_wire_route_id(::METHOD) ); assert_eq!(request, BytesValue(b"watch".to_vec())); let mut encoded = Vec::new(); - ql_rpc::subscription::encode_item::(&BytesValue(b"one".to_vec()), &mut encoded) - .unwrap(); - ql_rpc::subscription::encode_item::(&BytesValue(b"two".to_vec()), &mut encoded) - .unwrap(); + ql_rpc::subscription::encode_item::(&BytesValue(b"one".to_vec()), &mut encoded); + ql_rpc::subscription::encode_item::(&BytesValue(b"two".to_vec()), &mut encoded); ql_rpc::subscription::encode_end(&mut encoded); let mut writer = inbound.writer; @@ -152,7 +242,9 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let request: BytesValue = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, - route_id(::METHOD) + to_wire_route_id( + ::METHOD + ) ); assert_eq!(request, BytesValue(b"logo".to_vec())); @@ -160,18 +252,15 @@ async fn rpc_request_with_progress_supports_progress_then_await() { ql_rpc::request_with_progress::encode_progress::( &BytesValue(b"10".to_vec()), &mut encoded, - ) - .unwrap(); + ); ql_rpc::request_with_progress::encode_progress::( &BytesValue(b"90".to_vec()), &mut encoded, - ) - .unwrap(); + ); ql_rpc::request_with_progress::encode_response::( &BytesValue(b"done".to_vec()), &mut encoded, - ) - .unwrap(); + ); let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); @@ -197,10 +286,6 @@ async fn rpc_request_with_progress_supports_progress_then_await() { .await; } -fn route_id(method: ql_rpc::MethodId) -> RouteId { - RouteId(ql_wire::VarInt::from_u32(method.0)) -} - async fn read_rpc_value(mut reader: crate::ByteReader) -> T where T: ql_rpc::RpcCodec, @@ -220,3 +305,7 @@ where } } } + +fn to_wire_route_id(route_id: RouteId) -> WireRouteId { + WireRouteId::from_u32(route_id.into_inner()) +} diff --git a/ql-wire/src/encrypted/route_id.rs b/ql-wire/src/encrypted/route_id.rs index 2338e634..f7b51999 100644 --- a/ql-wire/src/encrypted/route_id.rs +++ b/ql-wire/src/encrypted/route_id.rs @@ -1,4 +1,4 @@ -use crate::{ByteSlice, Reader, VarInt, WireDecode, WireEncode, WireError}; +use crate::{ByteSlice, Reader, VarInt, VarIntBoundsExceeded, WireDecode, WireEncode, WireError}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] @@ -7,6 +7,14 @@ pub struct RouteId(pub VarInt); impl RouteId { pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + pub const fn from_u32(value: u32) -> Self { + Self(VarInt::from_u32(value)) + } + + pub fn from_u64(value: u64) -> Result { + Ok(Self(VarInt::from_u64(value)?)) + } + pub const fn into_inner(self) -> u64 { self.0.into_inner() } @@ -27,3 +35,15 @@ impl WireDecode for RouteId { Ok(Self(reader.decode()?)) } } + +impl From for RouteId { + fn from(value: VarInt) -> Self { + Self(value) + } +} + +impl From for RouteId { + fn from(value: u32) -> Self { + Self::from_u32(value) + } +} diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 66721b01..2885eaa0 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -91,14 +91,6 @@ pub struct StreamCloseCode(pub u16); impl StreamCloseCode { /// the stream was aborted intentionally before graceful completion pub const CANCELLED: Self = Self(0); - /// the peer declined to service the stream - pub const REFUSED: Self = Self(1); - /// the stream was aborted because progress took too long - pub const TIMEOUT: Self = Self(2); - /// the stream exceeded a size, quota, or buffering limit - pub const LIMIT: Self = Self(3); - /// the stream route was not recognized by the peer - pub const UNKNOWN_ROUTE: Self = Self(4); } impl codec::WireDecode for StreamCloseCode { From 5852e0d49478df14ab95279b929beffb30724b9d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 06:20:45 -0400 Subject: [PATCH 198/304] ql-rpc: return fut from router --- ql-rpc/src/router/mod.rs | 18 ++++++++++-------- ql-runtime/src/tests/rpc.rs | 8 ++++++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 9dbec217..d7c6f696 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -31,15 +31,17 @@ where RouterBuilder::::new() } - pub async fn handle(&self, stream: St) { - let Some(route_id) = stream.route_id() else { + pub fn handle(&self, stream: St) -> Option<(RouteId, RouteFuture<'_>)> { + let route_id = stream.route_id()?; + let Some(route) = stream + .route_id() + .and_then(|route_id| self.routes.get(&route_id)) + .copied() + else { stream::close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); - return; + return None; }; - let Some(route) = self.routes.get(&route_id).copied() else { - stream::close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); - return; - }; - route(&self.state, self.config, stream).await; + let fut = route(&self.state, self.config, stream); + Some((route_id, fut)) } } diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 3161975f..2bb787b5 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -118,7 +118,9 @@ async fn rpc_router_handles_request() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - router.handle(inbound).await; + if let Some((_, fut)) = router.handle(inbound) { + fut.await + } }); let rpc = pair.handle(Side::A).rpc(); @@ -163,7 +165,9 @@ async fn rpc_router_enforces_max_request_bytes() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - router.handle(inbound).await; + if let Some((_, fut)) = router.handle(inbound) { + fut.await + } }); let rpc = pair.handle(Side::A).rpc(); From a829d9bfe931c663777bff4792876b32a15c7fae Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 07:16:04 -0400 Subject: [PATCH 199/304] send router --- ql-rpc/src/router/builder.rs | 52 ++++++++---- ql-rpc/src/router/mod.rs | 35 ++++---- ql-rpc/src/router/mode.rs | 25 ++++++ ql-rpc/src/router/request.rs | 150 ++++++++++++++++++++++++---------- ql-runtime/src/rpc/adapter.rs | 6 +- ql-runtime/src/tests/rpc.rs | 104 +++++++++++++++++++---- 6 files changed, 276 insertions(+), 96 deletions(-) create mode 100644 ql-rpc/src/router/mode.rs diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index 6a9b8ba0..ababa61c 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -1,23 +1,32 @@ use std::collections::HashMap; use super::{ - request::{handle_request, RequestHandler}, - Router, RouterConfig, RpcStream, + request::{RequestHandler, RequestRouteMode}, + LocalMode, RouteMode, Router, RouterConfig, RpcStream, }; use crate::{request::Request as RequestRpc, router::RouteFn, RouteId}; -pub struct RouterBuilder { +pub struct RouterBuilder +where + Mode: RouteMode, +{ config: RouterConfig, - routes: HashMap>, + routes: HashMap>, } -impl Default for RouterBuilder { +impl Default for RouterBuilder +where + Mode: RouteMode, +{ fn default() -> Self { Self::new() } } -impl RouterBuilder { +impl RouterBuilder +where + Mode: RouteMode, +{ pub fn new() -> Self { Self { config: RouterConfig::default(), @@ -35,16 +44,7 @@ impl RouterBuilder { self } - pub fn request(self) -> Self - where - M: RequestRpc, - S: RequestHandler, - St: RpcStream + 'static, - { - self.add_route(M::METHOD, handle_request::) - } - - pub fn build(mut self, state: S) -> Router { + pub fn build(mut self, state: S) -> Router { self.routes.shrink_to_fit(); Router { config: self.config, @@ -53,10 +53,28 @@ impl RouterBuilder { } } - fn add_route(mut self, route_id: crate::RouteId, route: super::RouteFn) -> Self { + fn add_route(mut self, route_id: crate::RouteId, route: super::RouteFn) -> Self { if self.routes.insert(route_id, route).is_some() { panic!("duplicate rpc route {}", route_id.into_inner()); } self } } + +impl RouterBuilder +where + Mode: RouteMode, +{ + pub fn request(self) -> Self + where + M: RequestRpc + 'static, + S: RequestHandler + 'static, + St: RpcStream + 'static, + Mode: RequestRouteMode, + { + self.add_route( + M::METHOD, + >::handle_request, + ) + } +} diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index d7c6f696..61fc9eb9 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -1,47 +1,46 @@ -use std::{collections::HashMap, future::Future, pin::Pin}; +use std::collections::HashMap; use crate::{RouteId, StreamCloseCode}; mod builder; mod config; +mod mode; mod request; mod stream; pub use self::{ builder::RouterBuilder, config::RouterConfig, - request::RequestHandler, + mode::*, + request::{RequestHandler, Response}, stream::{RpcRead, RpcStream, RpcWrite}, }; -type RouteFuture<'a> = Pin + 'a>>; -type RouteFn = for<'a> fn(&'a S, RouterConfig, St) -> RouteFuture<'a>; - -pub struct Router { +pub struct Router +where + Mode: RouteMode, +{ config: RouterConfig, state: S, - routes: HashMap>, + routes: HashMap>, } -impl Router +impl Router where + S: Clone + 'static, St: RpcStream, + Mode: RouteMode, { - pub fn builder() -> RouterBuilder { - RouterBuilder::::new() + pub fn builder() -> RouterBuilder { + RouterBuilder::::new() } - pub fn handle(&self, stream: St) -> Option<(RouteId, RouteFuture<'_>)> { + pub fn handle(&self, stream: St) -> Option<(RouteId, Mode::RouteFuture)> { let route_id = stream.route_id()?; - let Some(route) = stream - .route_id() - .and_then(|route_id| self.routes.get(&route_id)) - .copied() - else { + let Some(route) = self.routes.get(&route_id).copied() else { stream::close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); return None; }; - let fut = route(&self.state, self.config, stream); - Some((route_id, fut)) + Some((route_id, route(self.state.clone(), self.config, stream))) } } diff --git a/ql-rpc/src/router/mode.rs b/ql-rpc/src/router/mode.rs new file mode 100644 index 00000000..b0a27b1b --- /dev/null +++ b/ql-rpc/src/router/mode.rs @@ -0,0 +1,25 @@ +use std::{future::Future, pin::Pin}; + +use crate::RouterConfig; + +pub trait RouteMode { + type RouteFuture: Future + 'static; +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct LocalMode; + +#[derive(Debug, Clone, Copy, Default)] +pub struct SendMode; + +pub type RouteFn = fn(S, RouterConfig, St) -> ::RouteFuture; +pub type LocalFuture = Pin + 'static>>; +pub type SendFuture = Pin + Send + 'static>>; + +impl RouteMode for LocalMode { + type RouteFuture = LocalFuture; +} + +impl RouteMode for SendMode { + type RouteFuture = SendFuture; +} diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/router/request.rs index 07889442..077e9676 100644 --- a/ql-rpc/src/router/request.rs +++ b/ql-rpc/src/router/request.rs @@ -1,68 +1,134 @@ -use std::future::Future; +use std::marker::PhantomData; use bytes::Bytes; use super::{ stream::{read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite}, - RouteFuture, RouterConfig, + LocalMode, RouteMode, RouterConfig, SendMode, }; use crate::{ - request::{self, Request as RequestRpc}, - ReadValueStep, RpcCodec, StreamCloseCode, ValueReader, + codec, request::Request as RequestRpc, ReadValueStep, RpcCodec, StreamCloseCode, ValueReader, }; -pub trait RequestHandler +pub trait RequestHandler where M: RequestRpc, + St: RpcStream, { - type Future<'a>: Future> + 'a - where - Self: 'a; + fn handle(self, message: M::Request, responder: Response); +} - fn handle<'a>(&'a self, request: M::Request) -> Self::Future<'a>; +pub struct Response +where + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, } -pub(super) fn handle_request( - state: &S, - config: RouterConfig, - stream: St, -) -> RouteFuture<'_> +impl Response where - M: RequestRpc, - S: RequestHandler, + T: RpcCodec, + W: RpcWrite, +{ + fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn respond(mut self, response: T) -> Result<(), StreamCloseCode> { + let mut writer = self.writer.take().expect("response writer exists"); + let mut encoded = Vec::new(); + codec::encode_value_part(&response, &mut encoded); + if let Err(code) = write_bytes(&mut writer, Bytes::from(encoded)).await { + writer.close(code); + return Err(code); + } + writer.finish(); + Ok(()) + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for Response +where + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +#[doc(hidden)] +pub trait RequestRouteMode: RouteMode +where + M: RequestRpc + 'static, + S: RequestHandler + 'static, St: RpcStream + 'static, { - Box::pin(async move { - let (mut reader, mut writer) = stream.split(); - - let request = match read_value_and_eof::(&mut reader, config).await { - Ok(request) => request, - Err(code) => { - reader.close(code); - writer.close(code); - return; - } - }; + fn handle_request(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture; +} - let response = match state.handle(request).await { - Ok(response) => response, - Err(code) => { - writer.close(code); - return; - } - }; +impl RequestRouteMode for LocalMode +where + M: RequestRpc + 'static, + S: RequestHandler + 'static, + St: RpcStream + 'static, +{ + fn handle_request(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture { + let (reader, writer) = stream.split(); + Box::pin(handle_request_inner::( + state, config, reader, writer, + )) + } +} - let mut encoded = Vec::new(); - request::encode_response::(&response, &mut encoded); +impl RequestRouteMode for SendMode +where + M: RequestRpc + 'static, + M::Request: Send + 'static, + S: RequestHandler + Send + 'static, + St: RpcStream + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, +{ + fn handle_request(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture { + let (reader, writer) = stream.split(); + Box::pin(handle_request_inner::( + state, config, reader, writer, + )) + } +} - if write_bytes(&mut writer, Bytes::from(encoded)) - .await - .is_err() - { +async fn handle_request_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, +) where + M: RequestRpc + 'static, + S: RequestHandler + 'static, + St: RpcStream + 'static, +{ + let request = match read_value_and_eof::(&mut reader, config).await { + Ok(request) => request, + Err(code) => { + reader.close(code); + writer.close(code); return; } - writer.finish(); - }) + }; + + state.handle(request, Response::new(writer)); } async fn read_value_and_eof( diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs index 2b3235b8..a391f7d2 100644 --- a/ql-runtime/src/rpc/adapter.rs +++ b/ql-runtime/src/rpc/adapter.rs @@ -1,7 +1,9 @@ use std::task::{Context, Poll}; use bytes::Bytes; -pub use ql_rpc::{RequestHandler, RouteId, RouterConfig, StreamCloseCode}; +pub use ql_rpc::{ + LocalMode, RequestHandler, Response, RouteId, RouterConfig, SendMode, StreamCloseCode, +}; use ql_rpc::{RpcRead, RpcStream, RpcWrite}; use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; @@ -9,6 +11,8 @@ use crate::{ByteReader, ByteWriter, QlStream, QlStreamError}; pub type Router = ql_rpc::Router; pub type RouterBuilder = ql_rpc::RouterBuilder; +pub type SendRouter = ql_rpc::Router; +pub type SendRouterBuilder = ql_rpc::RouterBuilder; impl RpcStream for QlStream { type Reader = ByteReader; diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 2bb787b5..36292e4e 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -1,4 +1,9 @@ -use std::{cell::RefCell, future::Ready, rc::Rc, time::Duration}; +use std::{ + cell::RefCell, + rc::Rc, + sync::{Arc, Mutex}, + time::Duration, +}; use bytes::{Buf, BufMut, Bytes}; use futures_lite::StreamExt; @@ -6,6 +11,7 @@ use ql_rpc::{RouteId, StreamCloseCode}; use ql_wire::RouteId as WireRouteId; use super::*; +use crate::ByteWriter; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -50,6 +56,10 @@ impl ql_rpc::request_with_progress::RequestWithProgress for Download { type Response = BytesValue; } +fn assert_send(value: T) -> T { + value +} + #[tokio::test(flavor = "current_thread")] async fn rpc_request_round_trips() { run_local_test(async { @@ -95,15 +105,17 @@ async fn rpc_router_handles_request() { seen: Rc>>>, } - impl crate::rpc::RequestHandler for RouterState { - type Future<'a> - = Ready> - where - Self: 'a; - - fn handle<'a>(&'a self, request: BytesValue) -> Self::Future<'a> { - self.seen.borrow_mut().push(request.0); - std::future::ready(Ok(BytesValue(b"world".to_vec()))) + impl crate::rpc::RequestHandler for RouterState { + fn handle( + self, + request: BytesValue, + response: crate::rpc::Response, + ) { + let seen = self.seen.clone(); + tokio::task::spawn_local(async move { + seen.borrow_mut().push(request.0); + let _ = response.respond(BytesValue(b"world".to_vec())).await; + }); } } @@ -139,18 +151,74 @@ async fn rpc_router_handles_request() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn rpc_send_router_handles_request() { + #[derive(Clone)] + struct RouterState { + seen: Arc>>>, + } + + impl crate::rpc::RequestHandler for RouterState { + fn handle( + self, + request: BytesValue, + response: crate::rpc::Response, + ) { + let seen = self.seen.clone(); + tokio::task::spawn(async move { + seen.lock().unwrap().push(request.0); + let _ = response.respond(BytesValue(b"world".to_vec())).await; + }); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Arc::new(Mutex::new(Vec::new())); + let router = crate::rpc::SendRouter::builder() + .request::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + let fut = assert_send(fut); + fut.await + } + }); + + let rpc = pair.handle(Side::A).rpc(); + let response = rpc + .request::(&BytesValue(b"hello".to_vec())) + .await + .unwrap(); + assert_eq!(response, BytesValue(b"world".to_vec())); + assert_eq!(&*seen.lock().unwrap(), &[b"hello".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn rpc_router_enforces_max_request_bytes() { + #[derive(Clone)] struct LimitedState; - impl crate::rpc::RequestHandler for LimitedState { - type Future<'a> - = Ready> - where - Self: 'a; - - fn handle<'a>(&'a self, request: BytesValue) -> Self::Future<'a> { - std::future::ready(Ok(request)) + impl crate::rpc::RequestHandler for LimitedState { + fn handle( + self, + request: BytesValue, + response: crate::rpc::Response, + ) { + tokio::task::spawn_local(async move { + let _ = response.respond(request).await; + }); } } From 6784f127a38deda3dc1bfd949286d91467696dc2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 08:22:33 -0400 Subject: [PATCH 200/304] ql-rpc: default codec --- ql-rpc/src/codec.rs | 38 ++++++++++++++++++- ql-runtime/src/tests/rpc.rs | 75 +++++++++++++++---------------------- 2 files changed, 68 insertions(+), 45 deletions(-) diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index dab369c6..fa4e8b25 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -1,4 +1,4 @@ -use std::{collections::VecDeque, marker::PhantomData}; +use std::{collections::VecDeque, convert::Infallible, marker::PhantomData, str::Utf8Error}; use bytes::{Buf, BufMut, Bytes}; @@ -11,6 +11,42 @@ pub trait RpcCodec: Sized { fn decode_value(bytes: &mut B) -> Result; } +impl RpcCodec for String { + type Error = Utf8Error; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_bytes()); + } + + fn decode_value(bytes: &mut B) -> Result { + let len = bytes.remaining(); + if bytes.chunk().len() == len { + let s = std::str::from_utf8(bytes.chunk())?.to_owned(); + bytes.advance(len); + Ok(s) + } else { + let mut buf = vec![0; len]; + bytes.copy_to_slice(&mut buf); + String::from_utf8(buf).map_err(|err| err.utf8_error()) + } + } +} + +impl RpcCodec for Vec { + type Error = Infallible; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_slice()); + } + + fn decode_value(bytes: &mut B) -> Result { + let len = bytes.remaining(); + let mut buf = vec![0; len]; + bytes.copy_to_slice(&mut buf); + Ok(buf) + } +} + const LENGTH_SIZE: usize = 8; pub fn encode_value_part>(value: &T, out: &mut B) { diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 36292e4e..afb2bd3b 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -1,17 +1,18 @@ use std::{ cell::RefCell, rc::Rc, + str::Utf8Error, sync::{Arc, Mutex}, time::Duration, }; use bytes::{Buf, BufMut, Bytes}; use futures_lite::StreamExt; -use ql_rpc::{RouteId, StreamCloseCode}; +use ql_rpc::{Response, RouteId, StreamCloseCode}; use ql_wire::RouteId as WireRouteId; use super::*; -use crate::ByteWriter; +use crate::{rpc::Router, ByteWriter}; #[derive(Debug, Clone, PartialEq, Eq)] struct BytesValue(Vec); @@ -32,9 +33,11 @@ struct Echo; impl ql_rpc::request::Request for Echo { const METHOD: RouteId = RouteId::from_u32(51); - type Error = core::convert::Infallible; - type Request = BytesValue; - type Response = BytesValue; + + type Error = Utf8Error; + + type Request = String; + type Response = String; } struct Feed; @@ -77,18 +80,15 @@ async fn rpc_request_round_trips() { assert_eq!(request, BytesValue(b"hello".to_vec())); let mut encoded = Vec::new(); - ql_rpc::request::encode_response::(&BytesValue(b"world".to_vec()), &mut encoded); + ql_rpc::request::encode_response::(&"world".into(), &mut encoded); let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); writer.finish(); }); let rpc = pair.handle(Side::A).rpc(); - let response = rpc - .request::(&BytesValue(b"hello".to_vec())) - .await - .unwrap(); - assert_eq!(response, BytesValue(b"world".to_vec())); + let response = rpc.request::(&"hello".into()).await.unwrap(); + assert_eq!(response, "world"); tokio::time::timeout(Duration::from_secs(2), responder) .await @@ -102,19 +102,15 @@ async fn rpc_request_round_trips() { async fn rpc_router_handles_request() { #[derive(Clone)] struct RouterState { - seen: Rc>>>, + seen: Rc>>, } - impl crate::rpc::RequestHandler for RouterState { - fn handle( - self, - request: BytesValue, - response: crate::rpc::Response, - ) { + impl crate::rpc::RequestHandler for RouterState { + fn handle(self, request: String, response: Response) { let seen = self.seen.clone(); tokio::task::spawn_local(async move { - seen.borrow_mut().push(request.0); - let _ = response.respond(BytesValue(b"world".to_vec())).await; + seen.borrow_mut().push(request); + let _ = response.respond("world".into()).await; }); } } @@ -124,7 +120,8 @@ async fn rpc_router_handles_request() { pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = crate::rpc::Router::builder() + + let router = Router::builder() .request::() .build(RouterState { seen: seen.clone() }); @@ -136,12 +133,9 @@ async fn rpc_router_handles_request() { }); let rpc = pair.handle(Side::A).rpc(); - let response = rpc - .request::(&BytesValue(b"hello".to_vec())) - .await - .unwrap(); - assert_eq!(response, BytesValue(b"world".to_vec())); - assert_eq!(&*seen.borrow(), &[b"hello".to_vec()]); + let response = rpc.request::(&"hello".into()).await.unwrap(); + assert_eq!(response, "world"); + assert_eq!(&*seen.borrow(), &["hello".to_string()]); tokio::time::timeout(Duration::from_secs(2), responder) .await @@ -155,19 +149,15 @@ async fn rpc_router_handles_request() { async fn rpc_send_router_handles_request() { #[derive(Clone)] struct RouterState { - seen: Arc>>>, + seen: Arc>>, } impl crate::rpc::RequestHandler for RouterState { - fn handle( - self, - request: BytesValue, - response: crate::rpc::Response, - ) { + fn handle(self, request: String, response: crate::rpc::Response) { let seen = self.seen.clone(); tokio::task::spawn(async move { - seen.lock().unwrap().push(request.0); - let _ = response.respond(BytesValue(b"world".to_vec())).await; + seen.lock().unwrap().push(request); + let _ = response.respond("world".into()).await; }); } } @@ -190,12 +180,9 @@ async fn rpc_send_router_handles_request() { }); let rpc = pair.handle(Side::A).rpc(); - let response = rpc - .request::(&BytesValue(b"hello".to_vec())) - .await - .unwrap(); - assert_eq!(response, BytesValue(b"world".to_vec())); - assert_eq!(&*seen.lock().unwrap(), &[b"hello".to_vec()]); + let response = rpc.request::(&"hello".into()).await.unwrap(); + assert_eq!(response, "world"); + assert_eq!(&*seen.lock().unwrap(), &["hello".to_string()]); tokio::time::timeout(Duration::from_secs(2), responder) .await @@ -213,8 +200,8 @@ async fn rpc_router_enforces_max_request_bytes() { impl crate::rpc::RequestHandler for LimitedState { fn handle( self, - request: BytesValue, - response: crate::rpc::Response, + request: String, + response: crate::rpc::Response, ) { tokio::task::spawn_local(async move { let _ = response.respond(request).await; @@ -239,7 +226,7 @@ async fn rpc_router_enforces_max_request_bytes() { }); let rpc = pair.handle(Side::A).rpc(); - let response = rpc.request::(&BytesValue(b"hello".to_vec())).await; + let response = rpc.request::(&"hello".to_string()).await; assert!(matches!( response, Err(crate::rpc::RpcError::Closed(code)) if code == StreamCloseCode::LIMIT From 17f893561f13e88c43d1365cb34d3838bb035d39 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 08:37:17 -0400 Subject: [PATCH 201/304] ql: rename constant to ROUTE --- ql-rpc/src/router/builder.rs | 2 +- ql-rpc/src/rpc/notification.rs | 2 +- ql-rpc/src/rpc/request.rs | 2 +- ql-rpc/src/rpc/request_with_progress.rs | 4 ++-- ql-rpc/src/rpc/subscription.rs | 4 ++-- ql-runtime/src/rpc/mod.rs | 8 ++++---- ql-runtime/src/tests/rpc.rs | 12 ++++++------ 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index ababa61c..8206010e 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -73,7 +73,7 @@ where Mode: RequestRouteMode, { self.add_route( - M::METHOD, + M::ROUTE, >::handle_request, ) } diff --git a/ql-rpc/src/rpc/notification.rs b/ql-rpc/src/rpc/notification.rs index 8bc3e069..7db5656d 100644 --- a/ql-rpc/src/rpc/notification.rs +++ b/ql-rpc/src/rpc/notification.rs @@ -3,7 +3,7 @@ use bytes::BufMut; use crate::{codec, RouteId, RpcCodec}; pub trait Notification { - const METHOD: RouteId; + const ROUTE: RouteId; type Error; type Event: RpcCodec; } diff --git a/ql-rpc/src/rpc/request.rs b/ql-rpc/src/rpc/request.rs index 82a4f180..8190aa3e 100644 --- a/ql-rpc/src/rpc/request.rs +++ b/ql-rpc/src/rpc/request.rs @@ -3,7 +3,7 @@ use bytes::BufMut; use crate::{codec, RouteId, RpcCodec}; pub trait Request { - const METHOD: RouteId; + const ROUTE: RouteId; type Error; type Request: RpcCodec; type Response: RpcCodec; diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index 159271a4..c46b97c0 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -5,7 +5,7 @@ use bytes::{BufMut, Bytes}; use crate::{codec, CodecError, Error, RouteId, RpcCodec}; pub trait RequestWithProgress { - const METHOD: RouteId; + const ROUTE: RouteId; type Error; type Request: RpcCodec; type Progress: RpcCodec; @@ -140,7 +140,7 @@ mod tests { struct Watch; impl RequestWithProgress for Watch { - const METHOD: RouteId = RouteId::from_u32(11); + const ROUTE: RouteId = RouteId::from_u32(11); type Error = core::convert::Infallible; type Request = BytesValue; type Progress = BytesValue; diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 460012cf..d957ce9b 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -5,7 +5,7 @@ use bytes::{Buf, BufMut, Bytes}; use crate::{codec, CodecError, Error, RouteId, RpcCodec}; pub trait Subscription { - const METHOD: RouteId; + const ROUTE: RouteId; type Error; type Request: RpcCodec; type Event: RpcCodec; @@ -110,7 +110,7 @@ mod tests { struct Feed; impl Subscription for Feed { - const METHOD: RouteId = RouteId::from_u32(17); + const ROUTE: RouteId = RouteId::from_u32(17); type Error = core::convert::Infallible; type Request = BytesValue; type Event = BytesValue; diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 576eb682..5c16d18c 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -31,7 +31,7 @@ impl RpcHandle { notification::encode_event::(event, &mut payload); let mut stream = self .inner - .open_stream(adapter::to_wire_route_id(M::METHOD)) + .open_stream(adapter::to_wire_route_id(M::ROUTE)) .await?; stream.reader.close(ql_wire::StreamCloseCode::CANCELLED); stream.writer.write(Bytes::from(payload)).await?; @@ -45,7 +45,7 @@ impl RpcHandle { { let mut payload = Vec::new(); request::encode_request::(request, &mut payload); - let response = self.start_request(M::METHOD, payload).await?; + let response = self.start_request(M::ROUTE, payload).await?; read_value::(response).await } @@ -58,7 +58,7 @@ impl RpcHandle { { let mut payload = Vec::new(); rpc_subscription::encode_request::(request, &mut payload); - let response = self.start_request(M::METHOD, payload).await?; + let response = self.start_request(M::ROUTE, payload).await?; Ok(Subscription { stream: response, reader: Some(rpc_subscription::ResponseReader::new()), @@ -74,7 +74,7 @@ impl RpcHandle { { let mut payload = Vec::new(); rpc_request_with_progress::encode_request::(request, &mut payload); - let response = self.start_request(M::METHOD, payload).await?; + let response = self.start_request(M::ROUTE, payload).await?; Ok(ProgressCall { stream: response, reader: Some(rpc_request_with_progress::ResponseReader::new()), diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index afb2bd3b..5a257807 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -32,7 +32,7 @@ impl ql_rpc::RpcCodec for BytesValue { struct Echo; impl ql_rpc::request::Request for Echo { - const METHOD: RouteId = RouteId::from_u32(51); + const ROUTE: RouteId = RouteId::from_u32(51); type Error = Utf8Error; @@ -43,7 +43,7 @@ impl ql_rpc::request::Request for Echo { struct Feed; impl ql_rpc::subscription::Subscription for Feed { - const METHOD: RouteId = RouteId::from_u32(52); + const ROUTE: RouteId = RouteId::from_u32(52); type Error = core::convert::Infallible; type Request = BytesValue; type Event = BytesValue; @@ -52,7 +52,7 @@ impl ql_rpc::subscription::Subscription for Feed { struct Download; impl ql_rpc::request_with_progress::RequestWithProgress for Download { - const METHOD: RouteId = RouteId::from_u32(53); + const ROUTE: RouteId = RouteId::from_u32(53); type Error = core::convert::Infallible; type Request = BytesValue; type Progress = BytesValue; @@ -75,7 +75,7 @@ async fn rpc_request_round_trips() { let request: BytesValue = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, - to_wire_route_id(::METHOD) + to_wire_route_id(::ROUTE) ); assert_eq!(request, BytesValue(b"hello".to_vec())); @@ -252,7 +252,7 @@ async fn rpc_subscription_streams_events() { let request: BytesValue = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, - to_wire_route_id(::METHOD) + to_wire_route_id(::ROUTE) ); assert_eq!(request, BytesValue(b"watch".to_vec())); @@ -302,7 +302,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { assert_eq!( inbound.route_id, to_wire_route_id( - ::METHOD + ::ROUTE ) ); assert_eq!(request, BytesValue(b"logo".to_vec())); From 72bd045e694a2a6ff33b17af1066986fbe2ac292 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 10:01:17 -0400 Subject: [PATCH 202/304] ql: unsafe sync impl --- ql-runtime/src/handle/reader.rs | 5 +++++ ql-runtime/src/handle/writer.rs | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 504fdb4e..0b4b709f 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -25,6 +25,11 @@ enum TerminalState { Delivered, } +// Safety: `ByteReader` contains a `oneshot::Receiver`, which is `!Sync`, but that receiver is +// fully encapsulated. No safe API accesses it through `&self`; all access requires `&mut self` +// or ownership, so shared references cannot race the receiver state across threads. +unsafe impl Sync for ByteReader {} + impl std::fmt::Debug for ByteReader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("InboundByteStream") diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 7ec6ac9d..0331a2ed 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -28,6 +28,11 @@ enum WriteTerminalState { Terminal(QlStreamError), } +// Safety: `ByteWriter` contains a `oneshot::Receiver`, which is `!Sync`, but that receiver is +// fully encapsulated. No safe API accesses it through `&self`; all access requires `&mut self` +// or ownership, so shared references cannot race the receiver state across threads. +unsafe impl Sync for ByteWriter {} + impl std::fmt::Debug for ByteWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OutboundByteStream") From 2d2f343f18a64d977f77ab23e78aea6a2c622afe Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 12:34:40 -0400 Subject: [PATCH 203/304] ql: subscription --- ql-rpc/src/codec.rs | 5 -- ql-rpc/src/router/builder.rs | 19 +++- ql-rpc/src/router/mod.rs | 2 + ql-rpc/src/router/request.rs | 2 +- ql-rpc/src/router/subscription.rs | 137 +++++++++++++++++++++++++++++ ql-rpc/src/rpc/subscription.rs | 72 +++++++++++---- ql-runtime/src/rpc/adapter.rs | 1 + ql-runtime/src/rpc/subscription.rs | 6 +- ql-runtime/src/tests/rpc.rs | 67 +++++++++++++- 9 files changed, 282 insertions(+), 29 deletions(-) create mode 100644 ql-rpc/src/router/subscription.rs diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index fa4e8b25..65b359f5 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -289,11 +289,6 @@ fn read_next_part_len(bytes: &mut B) -> Result, Error> { Ok(Some(len)) } -pub fn push_length(out: &mut B, len: usize) { - let len = u64::try_from(len).expect("rpc payload exceeds u64 length framing"); - out.put_u64_le(len); -} - pub fn reserve_length>(out: &mut B) -> usize { let start = out.as_mut().len(); out.put_u64_le(0); diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index 8206010e..c88fffb8 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -2,9 +2,13 @@ use std::collections::HashMap; use super::{ request::{RequestHandler, RequestRouteMode}, + subscription::{SubscriptionHandler, SubscriptionRouteMode}, LocalMode, RouteMode, Router, RouterConfig, RpcStream, }; -use crate::{request::Request as RequestRpc, router::RouteFn, RouteId}; +use crate::{ + request::Request as RequestRpc, router::RouteFn, subscription::Subscription as SubscriptionRpc, + RouteId, +}; pub struct RouterBuilder where @@ -77,4 +81,17 @@ where >::handle_request, ) } + + pub fn subscription(self) -> Self + where + M: SubscriptionRpc + 'static, + S: SubscriptionHandler + 'static, + St: RpcStream + 'static, + Mode: SubscriptionRouteMode, + { + self.add_route( + M::ROUTE, + >::handle_subscription, + ) + } } diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 61fc9eb9..bccf5ccb 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -7,6 +7,7 @@ mod config; mod mode; mod request; mod stream; +mod subscription; pub use self::{ builder::RouterBuilder, @@ -14,6 +15,7 @@ pub use self::{ mode::*, request::{RequestHandler, Response}, stream::{RpcRead, RpcStream, RpcWrite}, + subscription::{SubscriptionHandler, SubscriptionResponder}, }; pub struct Router diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/router/request.rs index 077e9676..66882edf 100644 --- a/ql-rpc/src/router/request.rs +++ b/ql-rpc/src/router/request.rs @@ -131,7 +131,7 @@ async fn handle_request_inner( state.handle(request, Response::new(writer)); } -async fn read_value_and_eof( +pub(super) async fn read_value_and_eof( reader: &mut R, config: RouterConfig, ) -> Result diff --git a/ql-rpc/src/router/subscription.rs b/ql-rpc/src/router/subscription.rs new file mode 100644 index 00000000..3bd5f3a6 --- /dev/null +++ b/ql-rpc/src/router/subscription.rs @@ -0,0 +1,137 @@ +use std::marker::PhantomData; + +use bytes::Bytes; + +use super::{ + request::read_value_and_eof, + stream::{write_bytes, RpcRead, RpcStream, RpcWrite}, + LocalMode, RouteMode, RouterConfig, SendMode, +}; +use crate::{codec, subscription::Subscription as SubscriptionRpc, RpcCodec, StreamCloseCode}; + +pub trait SubscriptionHandler +where + M: SubscriptionRpc, + St: RpcStream, +{ + fn handle(self, message: M::Request, responder: SubscriptionResponder); +} + +pub struct SubscriptionResponder +where + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +impl SubscriptionResponder +where + T: RpcCodec, + W: RpcWrite, +{ + fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, event: T) -> Result<(), StreamCloseCode> { + let writer = self.writer.as_mut().expect("subscription writer exists"); + let mut encoded = Vec::new(); + codec::encode_value_part(&event, &mut encoded); + if let Err(code) = write_bytes(writer, Bytes::from(encoded)).await { + let writer = self.writer.take().expect("subscription writer exists"); + writer.close(code); + return Err(code); + } + Ok(()) + } + + pub async fn finish(mut self) -> Result<(), StreamCloseCode> { + let writer = self.writer.take().expect("subscription writer exists"); + writer.finish(); + Ok(()) + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for SubscriptionResponder +where + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +#[doc(hidden)] +pub trait SubscriptionRouteMode: RouteMode +where + M: SubscriptionRpc + 'static, + S: SubscriptionHandler + 'static, + St: RpcStream + 'static, +{ + fn handle_subscription(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture; +} + +impl SubscriptionRouteMode for LocalMode +where + M: SubscriptionRpc + 'static, + S: SubscriptionHandler + 'static, + St: RpcStream + 'static, +{ + fn handle_subscription(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture { + let (reader, writer) = stream.split(); + Box::pin(handle_subscription_inner::( + state, config, reader, writer, + )) + } +} + +impl SubscriptionRouteMode for SendMode +where + M: SubscriptionRpc + 'static, + M::Request: Send + 'static, + S: SubscriptionHandler + Send + 'static, + St: RpcStream + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, +{ + fn handle_subscription(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture { + let (reader, writer) = stream.split(); + Box::pin(handle_subscription_inner::( + state, config, reader, writer, + )) + } +} + +async fn handle_subscription_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, +) where + M: SubscriptionRpc + 'static, + S: SubscriptionHandler + 'static, + St: RpcStream + 'static, +{ + let request = match read_value_and_eof::(&mut reader, config).await { + Ok(request) => request, + Err(code) => { + reader.close(code); + writer.close(code); + return; + } + }; + + state.handle(request, SubscriptionResponder::new(writer)); +} diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index d957ce9b..7b0c1e67 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; -use bytes::{Buf, BufMut, Bytes}; +use bytes::{BufMut, Bytes}; -use crate::{codec, CodecError, Error, RouteId, RpcCodec}; +use crate::{codec, CodecError, RouteId, RpcCodec}; pub trait Subscription { const ROUTE: RouteId; @@ -17,7 +17,6 @@ pub enum ReadStep { value: M::Event, next: ResponseReader, }, - End, } pub struct ResponseReader { @@ -44,20 +43,16 @@ impl ResponseReader { self } + pub fn is_empty(&self) -> bool { + self.bytes.remaining() == 0 + } + pub fn advance(self) -> Result, CodecError> { let mut this = self; let Some(mut body) = this.bytes.try_take_part().map_err(CodecError::Rpc)? else { return Ok(ReadStep::NeedMore(this)); }; - if body.remaining() == 0 { - drop(body); - if this.bytes.remaining() == 0 { - return Ok(ReadStep::End); - } - return Err(CodecError::Rpc(Error::TrailingBytes)); - } - let item = { let item = M::Event::decode_value(&mut body).map_err(CodecError::Codec)?; drop(body); @@ -81,15 +76,11 @@ pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + As codec::encode_value_part(item, out) } -pub fn encode_end(out: &mut impl BufMut) { - codec::push_length(out, 0); -} - #[cfg(test)] mod tests { use bytes::{Buf, BufMut, Bytes}; - use super::{encode_end, encode_item, ReadStep, ResponseReader, Subscription}; + use super::{encode_item, ReadStep, ResponseReader, Subscription}; use crate::{RouteId, RpcCodec}; #[derive(Debug, Clone, PartialEq, Eq)] @@ -121,7 +112,6 @@ mod tests { let mut encoded = Vec::new(); encode_item::(&BytesValue(b"one".to_vec()), &mut encoded); encode_item::(&BytesValue(b"two".to_vec()), &mut encoded); - encode_end(&mut encoded); let reader = match ResponseReader::::new() .push(Bytes::from(encoded)) @@ -143,6 +133,52 @@ mod tests { _ => unreachable!(), }; - assert!(matches!(reader.advance().unwrap(), ReadStep::End)); + match reader.advance().unwrap() { + ReadStep::NeedMore(next) => assert!(next.is_empty()), + _ => unreachable!(), + } + } + + #[test] + fn response_reader_waits_for_transport_eof_when_no_end_frame_is_present() { + let mut encoded = Vec::new(); + encode_item::(&BytesValue(b"one".to_vec()), &mut encoded); + + let reader = match ResponseReader::::new() + .push(Bytes::from(encoded)) + .advance() + .unwrap() + { + ReadStep::Item { value, next } => { + assert_eq!(value, BytesValue(b"one".to_vec())); + next + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + ReadStep::NeedMore(next) => assert!(next.is_empty()), + _ => unreachable!(), + } + } + + #[test] + fn response_reader_allows_empty_event_payloads() { + let mut encoded = Vec::new(); + encode_item::(&BytesValue(Vec::new()), &mut encoded); + + match ResponseReader::::new() + .push(Bytes::from(encoded)) + .advance() + .unwrap() + { + ReadStep::Item { value, next } => { + assert_eq!(value, BytesValue(Vec::new())); + assert!( + matches!(next.advance().unwrap(), ReadStep::NeedMore(reader) if reader.is_empty()) + ); + } + _ => unreachable!(), + } } } diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs index a391f7d2..780708a0 100644 --- a/ql-runtime/src/rpc/adapter.rs +++ b/ql-runtime/src/rpc/adapter.rs @@ -3,6 +3,7 @@ use std::task::{Context, Poll}; use bytes::Bytes; pub use ql_rpc::{ LocalMode, RequestHandler, Response, RouteId, RouterConfig, SendMode, StreamCloseCode, + SubscriptionHandler, SubscriptionResponder, }; use ql_rpc::{RpcRead, RpcStream, RpcWrite}; use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 40f00537..dc74afa0 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -47,7 +47,6 @@ where this.reader = Some(next); return Poll::Ready(Some(Ok(value))); } - Ok(ReadStep::End) => return Poll::Ready(None), Ok(ReadStep::NeedMore(next)) => { this.reader = Some(next); } @@ -60,7 +59,10 @@ where this.reader = Some(reader.push(chunk)); } Poll::Ready(Ok(None)) => { - this.reader = None; + let reader = this.reader.take().expect("subscription reader is present"); + if reader.is_empty() { + return Poll::Ready(None); + } return Poll::Ready(Some(Err(Error::Truncated.into()))); } Poll::Ready(Err(error)) => { diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 5a257807..2a0c279f 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -8,7 +8,7 @@ use std::{ use bytes::{Buf, BufMut, Bytes}; use futures_lite::StreamExt; -use ql_rpc::{Response, RouteId, StreamCloseCode}; +use ql_rpc::{Response, RouteId, StreamCloseCode, SubscriptionResponder}; use ql_wire::RouteId as WireRouteId; use super::*; @@ -145,6 +145,70 @@ async fn rpc_router_handles_request() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn rpc_router_handles_subscription() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>, + } + + impl crate::rpc::SubscriptionHandler for RouterState { + fn handle( + self, + request: BytesValue, + mut response: SubscriptionResponder, + ) { + let seen = self.seen.clone(); + tokio::task::spawn_local(async move { + seen.borrow_mut().push(request); + let _ = response.send(BytesValue(b"one".to_vec())).await; + let _ = response.send(BytesValue(b"two".to_vec())).await; + let _ = response.finish().await; + }); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let seen = Rc::new(RefCell::new(Vec::new())); + let router = crate::rpc::Router::builder() + .subscription::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await; + } + }); + + let rpc = pair.handle(Side::A).rpc(); + let mut subscription = rpc + .subscribe::(&BytesValue(b"watch".to_vec())) + .await + .unwrap(); + assert_eq!( + subscription.next().await.unwrap().unwrap(), + BytesValue(b"one".to_vec()) + ); + assert_eq!( + subscription.next().await.unwrap().unwrap(), + BytesValue(b"two".to_vec()) + ); + assert!(subscription.next().await.is_none()); + assert_eq!(seen.borrow().as_slice(), &[BytesValue(b"watch".to_vec())]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn rpc_send_router_handles_request() { #[derive(Clone)] @@ -259,7 +323,6 @@ async fn rpc_subscription_streams_events() { let mut encoded = Vec::new(); ql_rpc::subscription::encode_item::(&BytesValue(b"one".to_vec()), &mut encoded); ql_rpc::subscription::encode_item::(&BytesValue(b"two".to_vec()), &mut encoded); - ql_rpc::subscription::encode_end(&mut encoded); let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); From a96cb5d54a8d27b34955a6b0fbdd18ec7550aede Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 12:48:03 -0400 Subject: [PATCH 204/304] ql: codec for bytes --- ql-rpc/src/codec.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index 65b359f5..3a8cb8e9 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -47,6 +47,18 @@ impl RpcCodec for Vec { } } +impl RpcCodec for Bytes { + type Error = Infallible; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_ref()); + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(bytes.copy_to_bytes(bytes.remaining())) + } +} + const LENGTH_SIZE: usize = 8; pub fn encode_value_part>(value: &T, out: &mut B) { From ae90c4a943deedcd60df04624622c082eeabb888 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 12:51:19 -0400 Subject: [PATCH 205/304] ql: test cleanup --- ql-rpc/src/codec.rs | 32 ++------ ql-rpc/src/rpc/request_with_progress.rs | 41 ++++------ ql-rpc/src/rpc/subscription.rs | 43 ++++------- ql-runtime/src/tests/rpc.rs | 99 +++++++++---------------- 4 files changed, 68 insertions(+), 147 deletions(-) diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index 3a8cb8e9..87375c4b 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -98,7 +98,7 @@ impl ValueReader { pub fn advance(self) -> Result, CodecError> { let mut this = self; - let Some(mut body) = this.bytes.try_take_part().map_err(CodecError::Rpc)? else { + let Some(mut body) = this.bytes.try_take_part()? else { return Ok(ReadValueStep::NeedMore(this)); }; @@ -317,37 +317,21 @@ pub fn backpatch_length + ?Sized>(out: &mut B, start: usize) { #[cfg(test)] mod tests { - use bytes::{Buf, BufMut, Bytes}; + use bytes::Bytes; use super::{encode_value_part, ReadValueStep, ValueReader}; - use crate::RpcCodec; - - #[derive(Debug, Clone, PartialEq, Eq)] - struct BytesValue(Vec); - - impl RpcCodec for BytesValue { - type Error = core::convert::Infallible; - - fn encode_value(&self, out: &mut B) { - out.put_slice(&self.0); - } - - fn decode_value(bytes: &mut B) -> Result { - Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) - } - } #[test] fn value_reader_round_trips_framed_values() { let mut encoded = Vec::new(); - encode_value_part(&BytesValue(b"hello".to_vec()), &mut encoded); + encode_value_part(&b"hello".to_vec(), &mut encoded); - match ValueReader::::new() + match ValueReader::>::new() .push(Bytes::from(encoded)) .advance() .unwrap() { - ReadValueStep::Value(value) => assert_eq!(value, BytesValue(b"hello".to_vec())), + ReadValueStep::Value(value) => assert_eq!(value, b"hello".to_vec()), _ => unreachable!(), } } @@ -355,10 +339,10 @@ mod tests { #[test] fn value_reader_waits_for_complete_frame() { let mut encoded = Vec::new(); - encode_value_part(&BytesValue(b"hello".to_vec()), &mut encoded); + encode_value_part(&b"hello".to_vec(), &mut encoded); let encoded = Bytes::from(encoded); - let reader = match ValueReader::::new() + let reader = match ValueReader::>::new() .push(encoded.slice(..4)) .advance() .unwrap() @@ -368,7 +352,7 @@ mod tests { }; match reader.push(encoded.slice(4..)).advance().unwrap() { - ReadValueStep::Value(value) => assert_eq!(value, BytesValue(b"hello".to_vec())), + ReadValueStep::Value(value) => assert_eq!(value, b"hello".to_vec()), _ => unreachable!(), } } diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index c46b97c0..c1119840 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{codec, CodecError, Error, RouteId, RpcCodec}; +use crate::{CodecError, Error, RouteId, RpcCodec, codec}; pub trait RequestWithProgress { const ROUTE: RouteId; @@ -117,41 +117,26 @@ fn encode_tagged_value_part>( #[cfg(test)] mod tests { - use bytes::{Buf, BufMut, Bytes}; + use bytes::Bytes; - use super::{encode_progress, encode_response, ReadStep, RequestWithProgress, ResponseReader}; - use crate::{RouteId, RpcCodec}; - - #[derive(Debug, Clone, PartialEq, Eq)] - struct BytesValue(Vec); - - impl RpcCodec for BytesValue { - type Error = core::convert::Infallible; - - fn encode_value(&self, out: &mut B) { - out.put_slice(&self.0); - } - - fn decode_value(bytes: &mut B) -> Result { - Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) - } - } + use super::{ReadStep, RequestWithProgress, ResponseReader, encode_progress, encode_response}; + use crate::RouteId; struct Watch; impl RequestWithProgress for Watch { const ROUTE: RouteId = RouteId::from_u32(11); type Error = core::convert::Infallible; - type Request = BytesValue; - type Progress = BytesValue; - type Response = BytesValue; + type Request = Vec; + type Progress = Vec; + type Response = Vec; } #[test] fn response_reader_emits_progress_then_response() { let mut encoded = Vec::new(); - encode_progress::(&BytesValue(b"10%".to_vec()), &mut encoded); - encode_response::(&BytesValue(b"done".to_vec()), &mut encoded); + encode_progress::(&b"10%".to_vec(), &mut encoded); + encode_response::(&b"done".to_vec(), &mut encoded); let reader = match ResponseReader::::new() .push(Bytes::from(encoded)) @@ -159,13 +144,13 @@ mod tests { .unwrap() { ReadStep::Progress { value, next } => { - assert_eq!(value, BytesValue(b"10%".to_vec())); + assert_eq!(value, b"10%".to_vec()); next } _ => unreachable!(), }; match reader.advance().unwrap() { - ReadStep::Response(value) => assert_eq!(value, BytesValue(b"done".to_vec())), + ReadStep::Response(value) => assert_eq!(value, b"done".to_vec()), _ => unreachable!(), } } @@ -173,14 +158,14 @@ mod tests { #[test] fn response_reader_handles_response_only() { let mut encoded = Vec::new(); - encode_response::(&BytesValue(b"done".to_vec()), &mut encoded); + encode_response::(&b"done".to_vec(), &mut encoded); match ResponseReader::::new() .push(Bytes::from(encoded)) .advance() .unwrap() { - ReadStep::Response(value) => assert_eq!(value, BytesValue(b"done".to_vec())), + ReadStep::Response(value) => assert_eq!(value, b"done".to_vec()), _ => unreachable!(), } } diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 7b0c1e67..bf96d45f 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{codec, CodecError, RouteId, RpcCodec}; +use crate::{CodecError, RouteId, RpcCodec, codec}; pub trait Subscription { const ROUTE: RouteId; @@ -78,40 +78,25 @@ pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + As #[cfg(test)] mod tests { - use bytes::{Buf, BufMut, Bytes}; + use bytes::Bytes; - use super::{encode_item, ReadStep, ResponseReader, Subscription}; - use crate::{RouteId, RpcCodec}; - - #[derive(Debug, Clone, PartialEq, Eq)] - struct BytesValue(Vec); - - impl RpcCodec for BytesValue { - type Error = core::convert::Infallible; - - fn encode_value(&self, out: &mut B) { - out.put_slice(&self.0); - } - - fn decode_value(bytes: &mut B) -> Result { - Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) - } - } + use super::{ReadStep, ResponseReader, Subscription, encode_item}; + use crate::RouteId; struct Feed; impl Subscription for Feed { const ROUTE: RouteId = RouteId::from_u32(17); type Error = core::convert::Infallible; - type Request = BytesValue; - type Event = BytesValue; + type Request = Vec; + type Event = Vec; } #[test] fn response_reader_streams_items_until_end() { let mut encoded = Vec::new(); - encode_item::(&BytesValue(b"one".to_vec()), &mut encoded); - encode_item::(&BytesValue(b"two".to_vec()), &mut encoded); + encode_item::(&b"one".to_vec(), &mut encoded); + encode_item::(&b"two".to_vec(), &mut encoded); let reader = match ResponseReader::::new() .push(Bytes::from(encoded)) @@ -119,7 +104,7 @@ mod tests { .unwrap() { ReadStep::Item { value, next } => { - assert_eq!(value, BytesValue(b"one".to_vec())); + assert_eq!(value, b"one".to_vec()); next } _ => unreachable!(), @@ -127,7 +112,7 @@ mod tests { let reader = match reader.advance().unwrap() { ReadStep::Item { value, next } => { - assert_eq!(value, BytesValue(b"two".to_vec())); + assert_eq!(value, b"two".to_vec()); next } _ => unreachable!(), @@ -142,7 +127,7 @@ mod tests { #[test] fn response_reader_waits_for_transport_eof_when_no_end_frame_is_present() { let mut encoded = Vec::new(); - encode_item::(&BytesValue(b"one".to_vec()), &mut encoded); + encode_item::(&b"one".to_vec(), &mut encoded); let reader = match ResponseReader::::new() .push(Bytes::from(encoded)) @@ -150,7 +135,7 @@ mod tests { .unwrap() { ReadStep::Item { value, next } => { - assert_eq!(value, BytesValue(b"one".to_vec())); + assert_eq!(value, b"one".to_vec()); next } _ => unreachable!(), @@ -165,7 +150,7 @@ mod tests { #[test] fn response_reader_allows_empty_event_payloads() { let mut encoded = Vec::new(); - encode_item::(&BytesValue(Vec::new()), &mut encoded); + encode_item::(&Vec::new(), &mut encoded); match ResponseReader::::new() .push(Bytes::from(encoded)) @@ -173,7 +158,7 @@ mod tests { .unwrap() { ReadStep::Item { value, next } => { - assert_eq!(value, BytesValue(Vec::new())); + assert_eq!(value, Vec::::new()); assert!( matches!(next.advance().unwrap(), ReadStep::NeedMore(reader) if reader.is_empty()) ); diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 2a0c279f..6c9e7321 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -6,7 +6,7 @@ use std::{ time::Duration, }; -use bytes::{Buf, BufMut, Bytes}; +use bytes::Bytes; use futures_lite::StreamExt; use ql_rpc::{Response, RouteId, StreamCloseCode, SubscriptionResponder}; use ql_wire::RouteId as WireRouteId; @@ -14,21 +14,6 @@ use ql_wire::RouteId as WireRouteId; use super::*; use crate::{rpc::Router, ByteWriter}; -#[derive(Debug, Clone, PartialEq, Eq)] -struct BytesValue(Vec); - -impl ql_rpc::RpcCodec for BytesValue { - type Error = core::convert::Infallible; - - fn encode_value(&self, out: &mut B) { - out.put_slice(&self.0); - } - - fn decode_value(bytes: &mut B) -> Result { - Ok(Self(bytes.copy_to_bytes(bytes.remaining()).to_vec())) - } -} - struct Echo; impl ql_rpc::request::Request for Echo { @@ -45,8 +30,8 @@ struct Feed; impl ql_rpc::subscription::Subscription for Feed { const ROUTE: RouteId = RouteId::from_u32(52); type Error = core::convert::Infallible; - type Request = BytesValue; - type Event = BytesValue; + type Request = Vec; + type Event = Vec; } struct Download; @@ -54,9 +39,9 @@ struct Download; impl ql_rpc::request_with_progress::RequestWithProgress for Download { const ROUTE: RouteId = RouteId::from_u32(53); type Error = core::convert::Infallible; - type Request = BytesValue; - type Progress = BytesValue; - type Response = BytesValue; + type Request = Vec; + type Progress = Vec; + type Response = Vec; } fn assert_send(value: T) -> T { @@ -72,12 +57,12 @@ async fn rpc_request_round_trips() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request: BytesValue = read_rpc_value(inbound.reader).await; + let request: Vec = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, to_wire_route_id(::ROUTE) ); - assert_eq!(request, BytesValue(b"hello".to_vec())); + assert_eq!(request, b"hello".to_vec()); let mut encoded = Vec::new(); ql_rpc::request::encode_response::(&"world".into(), &mut encoded); @@ -149,20 +134,20 @@ async fn rpc_router_handles_request() { async fn rpc_router_handles_subscription() { #[derive(Clone)] struct RouterState { - seen: Rc>>, + seen: Rc>>>, } impl crate::rpc::SubscriptionHandler for RouterState { fn handle( self, - request: BytesValue, - mut response: SubscriptionResponder, + request: Vec, + mut response: SubscriptionResponder, ByteWriter>, ) { let seen = self.seen.clone(); tokio::task::spawn_local(async move { seen.borrow_mut().push(request); - let _ = response.send(BytesValue(b"one".to_vec())).await; - let _ = response.send(BytesValue(b"two".to_vec())).await; + let _ = response.send(b"one".to_vec()).await; + let _ = response.send(b"two".to_vec()).await; let _ = response.finish().await; }); } @@ -186,20 +171,11 @@ async fn rpc_router_handles_subscription() { }); let rpc = pair.handle(Side::A).rpc(); - let mut subscription = rpc - .subscribe::(&BytesValue(b"watch".to_vec())) - .await - .unwrap(); - assert_eq!( - subscription.next().await.unwrap().unwrap(), - BytesValue(b"one".to_vec()) - ); - assert_eq!( - subscription.next().await.unwrap().unwrap(), - BytesValue(b"two".to_vec()) - ); + let mut subscription = rpc.subscribe::(&b"watch".to_vec()).await.unwrap(); + assert_eq!(subscription.next().await.unwrap().unwrap(), b"one".to_vec()); + assert_eq!(subscription.next().await.unwrap().unwrap(), b"two".to_vec()); assert!(subscription.next().await.is_none()); - assert_eq!(seen.borrow().as_slice(), &[BytesValue(b"watch".to_vec())]); + assert_eq!(seen.borrow().as_slice(), &[b"watch".to_vec()]); tokio::time::timeout(Duration::from_secs(2), responder) .await @@ -313,16 +289,16 @@ async fn rpc_subscription_streams_events() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request: BytesValue = read_rpc_value(inbound.reader).await; + let request: Vec = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, to_wire_route_id(::ROUTE) ); - assert_eq!(request, BytesValue(b"watch".to_vec())); + assert_eq!(request, b"watch".to_vec()); let mut encoded = Vec::new(); - ql_rpc::subscription::encode_item::(&BytesValue(b"one".to_vec()), &mut encoded); - ql_rpc::subscription::encode_item::(&BytesValue(b"two".to_vec()), &mut encoded); + ql_rpc::subscription::encode_item::(&b"one".to_vec(), &mut encoded); + ql_rpc::subscription::encode_item::(&b"two".to_vec(), &mut encoded); let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); @@ -330,18 +306,9 @@ async fn rpc_subscription_streams_events() { }); let rpc = pair.handle(Side::A).rpc(); - let mut subscription = rpc - .subscribe::(&BytesValue(b"watch".to_vec())) - .await - .unwrap(); - assert_eq!( - subscription.next().await.unwrap().unwrap(), - BytesValue(b"one".to_vec()) - ); - assert_eq!( - subscription.next().await.unwrap().unwrap(), - BytesValue(b"two".to_vec()) - ); + let mut subscription = rpc.subscribe::(&b"watch".to_vec()).await.unwrap(); + assert_eq!(subscription.next().await.unwrap().unwrap(), b"one".to_vec()); + assert_eq!(subscription.next().await.unwrap().unwrap(), b"two".to_vec()); assert!(subscription.next().await.is_none()); tokio::time::timeout(Duration::from_secs(2), responder) @@ -361,26 +328,26 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request: BytesValue = read_rpc_value(inbound.reader).await; + let request: Vec = read_rpc_value(inbound.reader).await; assert_eq!( inbound.route_id, to_wire_route_id( ::ROUTE ) ); - assert_eq!(request, BytesValue(b"logo".to_vec())); + assert_eq!(request, b"logo".to_vec()); let mut encoded = Vec::new(); ql_rpc::request_with_progress::encode_progress::( - &BytesValue(b"10".to_vec()), + &b"10".to_vec(), &mut encoded, ); ql_rpc::request_with_progress::encode_progress::( - &BytesValue(b"90".to_vec()), + &b"90".to_vec(), &mut encoded, ); ql_rpc::request_with_progress::encode_response::( - &BytesValue(b"done".to_vec()), + &b"done".to_vec(), &mut encoded, ); @@ -391,14 +358,14 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let rpc = pair.handle(Side::A).rpc(); let mut download = rpc - .request_with_progress::(&BytesValue(b"logo".to_vec())) + .request_with_progress::(&b"logo".to_vec()) .await .unwrap(); - assert_eq!(download.progress().await, Some(BytesValue(b"10".to_vec()))); - assert_eq!(download.progress().await, Some(BytesValue(b"90".to_vec()))); + assert_eq!(download.progress().await, Some(b"10".to_vec())); + assert_eq!(download.progress().await, Some(b"90".to_vec())); assert_eq!(download.progress().await, None); - assert_eq!(download.await.unwrap(), BytesValue(b"done".to_vec())); + assert_eq!(download.await.unwrap(), b"done".to_vec()); tokio::time::timeout(Duration::from_secs(2), responder) .await From 1b176d9af36cfc858b57291736f561ee7a9f67eb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 13 Apr 2026 14:40:59 -0400 Subject: [PATCH 206/304] ql: finish non async --- ql-rpc/src/router/subscription.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ql-rpc/src/router/subscription.rs b/ql-rpc/src/router/subscription.rs index 3bd5f3a6..5badbbea 100644 --- a/ql-rpc/src/router/subscription.rs +++ b/ql-rpc/src/router/subscription.rs @@ -49,7 +49,7 @@ where Ok(()) } - pub async fn finish(mut self) -> Result<(), StreamCloseCode> { + pub fn finish(mut self) -> Result<(), StreamCloseCode> { let writer = self.writer.take().expect("subscription writer exists"); writer.finish(); Ok(()) From 15ddb2b8a4f588d09f62825243f3ac63789cbbe5 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 10:53:08 -0400 Subject: [PATCH 207/304] ql-wire: add ack test --- ql-wire/src/tests.rs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 52184a0e..022518ea 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -835,13 +835,25 @@ fn protocol_record_size_breakdown() { &session.tx_key, &[SessionFrame::Ping], ); - let session_stream_empty = encrypt_record( + let session_ack = encrypt_record( &crypto, SessionHeader { connection_id: session.tx_connection_id, seq: record_seq(2), }, &session.tx_key, + &[SessionFrame::Ack(RecordAck { + base_seq: record_seq(1), + bits: (1u64 << 0) | (1u64 << 1) | (1u64 << 5), + })], + ); + let session_stream_empty = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(3), + }, + &session.tx_key, &[SessionFrame::StreamData(StreamData { stream_id: stream_id(1), offset: varint(0), @@ -854,7 +866,7 @@ fn protocol_record_size_breakdown() { &crypto, SessionHeader { connection_id: session.tx_connection_id, - seq: record_seq(3), + seq: record_seq(4), }, &session.tx_key, &[SessionFrame::Close(SessionClose { @@ -874,6 +886,7 @@ fn protocol_record_size_breakdown() { print_size("ql-wire pq xx3", xx3.encode_vec().len()); print_size("ql-wire pq xx4", xx4.encode_vec().len()); print_size("ql-wire session ping", session_ping.encode_vec().len()); + print_size("ql-wire session ack", session_ack.encode_vec().len()); print_size( "ql-wire session stream empty", session_stream_empty.encode_vec().len(), From 0f970bf1e038a65246d2ad2a474dc736258657b5 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 10:58:05 -0400 Subject: [PATCH 208/304] ql-wire: reader clone --- ql-wire/src/codec.rs | 1 + ql-wire/src/encrypted/mod.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs index c2e4ba3c..0245ef6d 100644 --- a/ql-wire/src/codec.rs +++ b/ql-wire/src/codec.rs @@ -197,6 +197,7 @@ impl> WireDecode for Option { } } +#[derive(Clone)] pub struct Reader { remaining: Option, } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index a25e854d..d5be4f5d 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -159,6 +159,7 @@ pub fn decode_session_frames(bytes: &[u8]) -> Result>>, .collect() } +#[derive(Clone)] pub struct SessionFrameIter { reader: Reader, } From 238ca43a55614e31f709ed9840a772ad1d45a4ed Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 13:23:24 -0400 Subject: [PATCH 209/304] ql-wire: ack range --- ql-wire/src/encrypted/ack.rs | 354 ++++++++++++++++++++++++++++++++--- ql-wire/src/tests.rs | 25 ++- 2 files changed, 335 insertions(+), 44 deletions(-) diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index adceb0d7..c3e90c27 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -1,86 +1,378 @@ -use crate::{codec, ByteSlice, RecordSeq, WireEncode, WireError}; +use std::{fmt, ops::RangeInclusive}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +use crate::{codec, ByteSlice, RecordSeq, VarInt, WireEncode, WireError}; + +#[derive(Debug, Clone, PartialEq, Eq)] pub struct RecordAck { - pub base_seq: RecordSeq, - pub bits: u64, + largest_acked: RecordSeq, + first_range_len: VarInt, + blocks: Box<[RecordAckBlock]>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecordAckBlock { + gap: VarInt, + range_len: VarInt, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RecordAckRangeError { + Empty, + InvertedRange, + NotCanonical, } impl RecordAck { - pub const BITMAP_BITS: usize = u64::BITS as usize; + /// Build a record ACK from canonical ranges ordered from highest to lowest sequence number. + /// + /// Ranges must be: + /// - non-empty + /// - individually valid (`start <= end`) + /// - strictly descending + /// - separated by at least one missing sequence number + pub fn from_ranges(ranges: I) -> Result + where + I: IntoIterator>, + { + let mut ranges = ranges.into_iter(); + let Some(first_range) = ranges.next() else { + return Err(RecordAckRangeError::Empty); + }; + + let first_start = first_range.start().into_inner(); + let first_end = first_range.end().into_inner(); + if first_start > first_end { + return Err(RecordAckRangeError::InvertedRange); + } + + let mut prev_start = first_start; + let mut prev_end = first_end; + let mut blocks = Vec::new(); + + for range in ranges { + let start = range.start().into_inner(); + let end = range.end().into_inner(); + if start > end { + return Err(RecordAckRangeError::InvertedRange); + } + if end >= prev_end || end.saturating_add(1) >= prev_start { + return Err(RecordAckRangeError::NotCanonical); + } + + let gap = prev_start + .checked_sub(end) + .and_then(|delta| delta.checked_sub(2)) + .expect("canonical ack ranges stay separated by at least one sequence"); + blocks.push(RecordAckBlock { + gap: VarInt::from_u64(gap).expect("record ack gap must fit varint"), + range_len: VarInt::from_u64(end - start) + .expect("record ack range length must fit varint"), + }); + prev_start = start; + prev_end = end; + } + + Ok(Self { + largest_acked: RecordSeq::from_u64(first_end) + .expect("record ack range upper bound must fit record sequence"), + first_range_len: VarInt::from_u64(first_end - first_start) + .expect("record ack first range length must fit varint"), + blocks: blocks.into_boxed_slice(), + }) + } + + pub fn largest_acked(&self) -> RecordSeq { + self.largest_acked + } + + pub fn first_range_len(&self) -> VarInt { + self.first_range_len + } + + pub fn blocks(&self) -> &[RecordAckBlock] { + &self.blocks + } + + pub fn range_count(&self) -> usize { + 1 + self.blocks.len() + } + + pub fn ranges(&self) -> RecordAckRangeIter<'_> { + RecordAckRangeIter { + largest_acked: self.largest_acked.into_inner(), + first_range_len: Some(self.first_range_len), + previous_start: None, + blocks: self.blocks.iter(), + } + } pub fn contains(&self, seq: u64) -> bool { - if seq < self.base_seq.into_inner() { + let Ok(seq) = RecordSeq::from_u64(seq) else { return false; + }; + self.ranges().any(|range| range.contains(&seq)) + } + + fn validate(&self) -> Result<(), WireError> { + let mut previous_start = self + .largest_acked + .into_inner() + .checked_sub(self.first_range_len.into_inner()) + .ok_or(WireError::InvalidPayload)?; + + for block in self.blocks.iter() { + let end = previous_start + .checked_sub( + block + .gap + .into_inner() + .checked_add(2) + .ok_or(WireError::InvalidPayload)?, + ) + .ok_or(WireError::InvalidPayload)?; + previous_start = end + .checked_sub(block.range_len.into_inner()) + .ok_or(WireError::InvalidPayload)?; } - let offset = seq - self.base_seq.into_inner(); - if offset >= Self::BITMAP_BITS as u64 { - return false; + Ok(()) + } +} + +impl RecordAckBlock { + pub fn gap(&self) -> VarInt { + self.gap + } + + pub fn range_len(&self) -> VarInt { + self.range_len + } +} + +impl fmt::Display for RecordAckRangeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Empty => f.write_str("record ack requires at least one acknowledged range"), + Self::InvertedRange => { + f.write_str("record ack range start must be less than or equal to end") + } + Self::NotCanonical => f.write_str( + "record ack ranges must be passed in descending, disjoint order with a gap between adjacent ranges", + ), } + } +} + +impl std::error::Error for RecordAckRangeError {} - (self.bits & (1u64 << offset)) != 0 +pub struct RecordAckRangeIter<'a> { + largest_acked: u64, + first_range_len: Option, + previous_start: Option, + blocks: std::slice::Iter<'a, RecordAckBlock>, +} + +impl Iterator for RecordAckRangeIter<'_> { + type Item = RangeInclusive; + + fn next(&mut self) -> Option { + if let Some(first_range_len) = self.first_range_len.take() { + let end = self.largest_acked; + let start = end - first_range_len.into_inner(); + self.previous_start = Some(start); + return Some(RecordSeq::from_u64(start).unwrap()..=RecordSeq::from_u64(end).unwrap()); + } + + let block = self.blocks.next()?; + let previous_start = self + .previous_start + .expect("first ack range is always yielded"); + let end = previous_start - block.gap.into_inner() - 2; + let start = end - block.range_len.into_inner(); + self.previous_start = Some(start); + Some(RecordSeq::from_u64(start).unwrap()..=RecordSeq::from_u64(end).unwrap()) } } impl WireEncode for RecordAck { fn encoded_len(&self) -> usize { - self.base_seq.encoded_len() + size_of::() + self.largest_acked.encoded_len() + + VarInt::try_from(self.blocks.len()).unwrap().encoded_len() + + self.first_range_len.encoded_len() + + self + .blocks + .iter() + .map(|block| block.gap.encoded_len() + block.range_len.encoded_len()) + .sum::() } fn encode(&self, out: &mut W) { - self.base_seq.encode(out); - self.bits.encode(out); + self.largest_acked.encode(out); + VarInt::try_from(self.blocks.len()).unwrap().encode(out); + self.first_range_len.encode(out); + for block in self.blocks.iter() { + block.gap.encode(out); + block.range_len.encode(out); + } } } impl codec::WireDecode for RecordAck { fn decode(reader: &mut codec::Reader) -> Result { - Ok(Self { - base_seq: reader.decode()?, - bits: reader.decode()?, - }) + let largest_acked = reader.decode()?; + let block_count = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + let first_range_len = reader.decode::()?; + let mut blocks = Vec::with_capacity(block_count); + for _ in 0..block_count { + blocks.push(RecordAckBlock { + gap: reader.decode::()?, + range_len: reader.decode::()?, + }); + } + + let ack = Self { + largest_acked, + first_range_len, + blocks: blocks.into_boxed_slice(), + }; + ack.validate()?; + Ok(ack) } } #[cfg(test)] mod tests { - use super::RecordAck; - use crate::{RecordSeq, WireDecode, WireEncode, WireError}; + use std::ops::RangeInclusive; + + use super::{RecordAck, RecordAckBlock, RecordAckRangeError}; + use crate::{RecordSeq, VarInt, WireDecode, WireEncode, WireError}; + + fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() + } + + fn ack_range(start: u64, end: u64) -> RangeInclusive { + seq(start)..=seq(end) + } + + fn varint(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() + } #[test] fn encode_decode_round_trip() { - let ack = RecordAck { - base_seq: RecordSeq::from_u32(42), - bits: (1u64 << 0) | (1u64 << 17) | (1u64 << 63), - }; + let ack = + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap(); let encoded = ack.encode_vec(); assert_eq!(RecordAck::decode_exact(encoded.as_slice()).unwrap(), ack); } #[test] - fn contains_matches_bit_membership() { - let ack = RecordAck { - base_seq: RecordSeq::from_u32(100), - bits: (1u64 << 0) | (1u64 << 5) | (1u64 << 63), - }; + fn wire_fields_match_gap_encoding() { + let ack = + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap(); + + assert_eq!(ack.largest_acked(), seq(100)); + assert_eq!(ack.first_range_len(), varint(5)); + assert_eq!( + ack.blocks(), + &[ + RecordAckBlock { + gap: varint(1), + range_len: varint(2), + }, + RecordAckBlock { + gap: varint(8), + range_len: varint(0), + } + ] + ); + } + + #[test] + fn rejects_unsorted_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(90, 92), ack_range(95, 100)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn rejects_touching_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(10, 12), ack_range(7, 9)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn rejects_overlapping_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(10, 12), ack_range(8, 11)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn contains_matches_range_membership() { + let ack = RecordAck::from_ranges([ + ack_range(150, 163), + ack_range(105, 110), + ack_range(100, 100), + ]) + .unwrap(); assert!(ack.contains(100)); - assert!(ack.contains(105)); + assert!(ack.contains(107)); assert!(ack.contains(163)); assert!(!ack.contains(99)); - assert!(!ack.contains(101)); + assert!(!ack.contains(104)); assert!(!ack.contains(164)); } + #[test] + fn empty_ack_is_rejected() { + assert_eq!(RecordAck::from_ranges([]), Err(RecordAckRangeError::Empty)); + } + + #[test] + fn inverted_range_is_rejected() { + assert_eq!( + RecordAck::from_ranges([ack_range(5, 4)]), + Err(RecordAckRangeError::InvertedRange) + ); + } + + #[test] + fn decode_rejects_underflowing_ack_blocks() { + let encoded = vec![ + 42, // largest_acked + 1, // block_count + 0, // first_range_len + 41, // gap: implies a missing run larger than largest_acked + 0, // range_len + ]; + + assert_eq!( + RecordAck::decode_exact(encoded.as_slice()), + Err(WireError::InvalidPayload) + ); + } + #[test] fn decode_rejects_truncated_payload() { assert_eq!( RecordAck::decode_exact(&[][..]), Err(WireError::InvalidPayload) ); - let encoded = vec![0; RecordSeq::from_u32(0).encoded_len() + size_of::()]; + + let encoded = RecordAck::from_ranges([ack_range(42, 42)]) + .unwrap() + .encode_vec(); assert_eq!( RecordAck::decode_exact(&encoded[..encoded.len() - 1]), Err(WireError::InvalidPayload) diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 022518ea..94379580 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -1,3 +1,5 @@ +use std::ops::RangeInclusive; + use super::*; fn decode_handshake_record(bytes: &[u8]) -> QlHandshakeRecord { @@ -21,6 +23,10 @@ fn record_seq(value: u64) -> RecordSeq { RecordSeq(varint(value)) } +fn record_ack_range(start: u64, end: u64) -> RangeInclusive { + record_seq(start)..=record_seq(end) +} + fn stream_id(value: u64) -> StreamId { StreamId(varint(value)) } @@ -647,15 +653,9 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { }; let body = vec![ SessionFrame::Ping, - SessionFrame::Ack(RecordAck { - base_seq: record_seq(12), - bits: (1u64 << 0) - | (1u64 << 1) - | (1u64 << 8) - | (1u64 << 9) - | (1u64 << 10) - | (1u64 << 11), - }), + SessionFrame::Ack( + RecordAck::from_ranges([record_ack_range(20, 23), record_ack_range(12, 13)]).unwrap(), + ), SessionFrame::StreamWindow(StreamWindow { stream_id: stream_id(9), maximum_offset: varint(65_536), @@ -842,10 +842,9 @@ fn protocol_record_size_breakdown() { seq: record_seq(2), }, &session.tx_key, - &[SessionFrame::Ack(RecordAck { - base_seq: record_seq(1), - bits: (1u64 << 0) | (1u64 << 1) | (1u64 << 5), - })], + &[SessionFrame::Ack( + RecordAck::from_ranges([record_ack_range(6, 6), record_ack_range(1, 2)]).unwrap(), + )], ); let session_stream_empty = encrypt_record( &crypto, From a5b485ddef30fc3194eec7d0f08548d1e9d27f19 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 13:27:20 -0400 Subject: [PATCH 210/304] ql-wire: len inside of streamdata frame --- ql-wire/src/encrypted/builder.rs | 20 ++------------------ ql-wire/src/encrypted/mod.rs | 25 ++++--------------------- ql-wire/src/encrypted/stream_data.rs | 23 +++++++++++++++-------- 3 files changed, 21 insertions(+), 47 deletions(-) diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 65ad053a..711d923b 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -3,7 +3,7 @@ use bytes::BufMut; use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; use crate::{ BufView, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, SessionKey, - VarInt, WireEncode, QL_WIRE_VERSION, + WireEncode, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -71,7 +71,7 @@ impl SessionRecordBuilder { } pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { - self.push_len_prefixed_frame(super::SessionFrameKind::StreamData, frame) + self.push_frame_payload(super::SessionFrameKind::StreamData, frame) } pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { @@ -154,22 +154,6 @@ impl SessionRecordBuilder { }) } - fn push_len_prefixed_frame( - &mut self, - kind: super::SessionFrameKind, - payload: &T, - ) -> bool { - let payload_wire_size = payload.encoded_len(); - let Ok(prefix_len) = VarInt::try_from(payload_wire_size) else { - return false; - }; - self.push_wire_size(1 + prefix_len.encoded_len() + payload_wire_size, |out| { - out.put_u8(kind as u8); - prefix_len.encode(out); - payload.encode(out); - }) - } - fn can_push_len(&self, len: usize) -> bool { len <= self.remaining_capacity() } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index d5be4f5d..0c0e6338 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,6 +1,6 @@ use crate::{ codec, encrypted_message::EncryptedMessage, BufView, ByteSlice, Nonce, QlCrypto, Reader, - SessionHeader, SessionKey, VarInt, WireDecode, WireEncode, WireError, + SessionHeader, SessionKey, WireDecode, WireEncode, WireError, }; mod ack; @@ -37,12 +37,7 @@ impl WireDecode for SessionFrame { let frame = match kind { SessionFrameKind::Ping => Self::Ping, SessionFrameKind::Ack => Self::Ack(reader.decode::()?), - SessionFrameKind::StreamData => { - let len = usize::try_from(reader.decode::()?.into_inner()) - .map_err(|_| WireError::InvalidPayload)?; - let frame = reader.take_bytes(len)?; - Self::StreamData(StreamData::decode_exact(frame)?) - } + SessionFrameKind::StreamData => Self::StreamData(reader.decode::>()?), SessionFrameKind::StreamWindow => Self::StreamWindow(reader.decode::()?), SessionFrameKind::StreamClose => Self::StreamClose(reader.decode::()?), SessionFrameKind::Close => Self::Close(reader.decode::()?), @@ -82,13 +77,7 @@ impl WireEncode for SessionFrame { 1 + match self { Self::Ping => 0, Self::Ack(frame) => frame.encoded_len(), - Self::StreamData(frame) => { - let payload_len = frame.encoded_len(); - VarInt::try_from(payload_len) - .unwrap_or(VarInt::MAX) - .encoded_len() - + payload_len - } + Self::StreamData(frame) => frame.encoded_len(), Self::StreamWindow(frame) => frame.encoded_len(), Self::StreamClose(frame) => frame.encoded_len(), Self::Close(frame) => frame.encoded_len(), @@ -100,13 +89,7 @@ impl WireEncode for SessionFrame { match self { Self::Ping => {} Self::Ack(frame) => frame.encode(out), - Self::StreamData(frame) => { - let payload_len = frame.encoded_len(); - let payload_len = VarInt::try_from(payload_len) - .expect("stream data frame length must fit ql-wire varint"); - payload_len.encode(out); - frame.encode(out); - } + Self::StreamData(frame) => frame.encode(out), Self::StreamWindow(frame) => frame.encode(out), Self::StreamClose(frame) => frame.encode(out), Self::Close(frame) => frame.encode(out), diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs index 2ffc480c..9174fe5a 100644 --- a/ql-wire/src/encrypted/stream_data.rs +++ b/ql-wire/src/encrypted/stream_data.rs @@ -17,7 +17,8 @@ impl StreamData { pub const MIN_WIRE_SIZE: usize = StreamId::MAX_ENCODED_LEN + VarInt::MAX_SIZE + size_of::() - + StreamHeader::MAX_WIRE_SIZE; + + StreamHeader::MAX_WIRE_SIZE + + VarInt::MAX_SIZE; } impl WireDecode for StreamData { @@ -27,17 +28,20 @@ impl WireDecode for StreamData { let flags = reader.decode::()?; let fin = (flags & flag::FIN) != 0; let has_header = (flags & flag::HEADER) != 0; + let header = if has_header { + Some(reader.decode()?) + } else { + None + }; + let bytes_len = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; Ok(Self { stream_id, offset, - header: if has_header { - Some(reader.decode()?) - } else { - None - }, + header, fin, - bytes: reader.take_rest(), + bytes: reader.take_bytes(bytes_len)?, }) } } @@ -60,11 +64,13 @@ impl StreamData { impl WireEncode for StreamData { fn encoded_len(&self) -> usize { let bytes = self.bytes.buf(); + let bytes_len = bytes.remaining(); self.stream_id.encoded_len() + self.offset.encoded_len() + size_of::() + self.header.as_ref().map_or(0, WireEncode::encoded_len) - + bytes.remaining() + + VarInt::try_from(bytes_len).unwrap().encoded_len() + + bytes_len } fn encode(&self, out: &mut W) { @@ -87,6 +93,7 @@ impl WireEncode for StreamData { header.encode(out); } let mut bytes = self.bytes.buf(); + VarInt::try_from(bytes.remaining()).unwrap().encode(out); while bytes.has_remaining() { let chunk = bytes.chunk(); out.put_slice(chunk); From abe6b9116d19dd6edf7fbd658c4122c3e602872f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 14:41:37 -0400 Subject: [PATCH 211/304] port ql-fsm --- ql-fsm/src/handshake/mod.rs | 2 + ql-fsm/src/lib.rs | 6 + ql-fsm/src/session/mod.rs | 115 ++++---- ql-fsm/src/session/range_set.rs | 72 +++++ ql-fsm/src/session/received_records.rs | 357 ++++++++++++++++--------- ql-fsm/src/session/state.rs | 13 +- ql-fsm/src/session/tests.rs | 104 ++++++- ql-fsm/src/session/tracked.rs | 4 +- ql-fsm/src/tests/mod.rs | 2 + 9 files changed, 471 insertions(+), 204 deletions(-) diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index a4b05e92..983c633f 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -127,6 +127,8 @@ pub fn finish_handshake( peer_timeout: config.session_peer_timeout, stream_send_buffer_size: config.session_stream_send_buffer_size, stream_receive_buffer_size: config.session_stream_receive_buffer_size, + accepted_record_window: config.session_accepted_record_window, + pending_ack_range_limit: config.session_pending_ack_range_limit, initial_peer_stream_receive_window: transport .remote_transport_params .initial_stream_receive_window, diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 799d13df..9dcd2c65 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -127,6 +127,10 @@ pub struct QlFsmConfig { pub session_stream_send_buffer_size: usize, /// maximum bytes buffered locally for one stream receive side pub session_stream_receive_buffer_size: u32, + /// how many accepted record sequence numbers to retain for duplicate detection + pub session_accepted_record_window: u64, + /// maximum disjoint pending ACK ranges to retain before dropping the oldest low ranges + pub session_pending_ack_range_limit: usize, } impl Default for QlFsmConfig { @@ -141,6 +145,8 @@ impl Default for QlFsmConfig { session_record_max_size: s.record_max_size, session_stream_send_buffer_size: s.stream_send_buffer_size, session_stream_receive_buffer_size: s.stream_receive_buffer_size, + session_accepted_record_window: s.accepted_record_window, + session_pending_ack_range_limit: s.pending_ack_range_limit, } } } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 8c18cc46..88a2f564 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -24,11 +24,9 @@ use ql_wire::{ }; use self::{ - received_records::{ReceiveOutcome, ReceivedRecords}, + received_records::{PendingAck, ReceiveOutcome, RecordRxState}, remote_stream_history::RemoteStreamHistory, - state::{ - AckState, InboundState, OutboundState, SessionPhase, SessionState, StreamRole, StreamState, - }, + state::{InboundState, OutboundState, SessionPhase, SessionState, StreamRole, StreamState}, stream_tx::StreamTxRange, tracked::{TrackedFrame, TrackedRecord, TrackedStreamData}, }; @@ -45,6 +43,8 @@ pub struct SessionConfig { pub stream_send_buffer_size: usize, pub stream_receive_buffer_size: u32, pub initial_peer_stream_receive_window: u32, + pub accepted_record_window: u64, + pub pending_ack_range_limit: usize, } impl Default for SessionConfig { @@ -59,6 +59,8 @@ impl Default for SessionConfig { stream_send_buffer_size: 16 * 1024, stream_receive_buffer_size: 16 * 1024, initial_peer_stream_receive_window: 16 * 1024, + accepted_record_window: 4096, + pending_ack_range_limit: 64, } } } @@ -89,6 +91,8 @@ impl SessionFsm { .max(SessionRecordBuilder::MIN_CAPACITY); config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); + config.accepted_record_window = config.accepted_record_window.max(1); + config.pending_ack_range_limit = config.pending_ack_range_limit.max(1); Self { config, state: SessionState { @@ -99,8 +103,10 @@ impl SessionFsm { next_record_seq: RecordSeq::from_u32(0), next_write_id: 0, tracked_records: Default::default(), - received_records: ReceivedRecords::default(), - ack_state: AckState::Idle, + record_rx: RecordRxState::new( + config.accepted_record_window, + config.pending_ack_range_limit, + ), pending_ping: false, streams: Default::default(), next_stream_index: 0, @@ -152,7 +158,7 @@ impl SessionFsm { let close = SessionClose { code }; self.state.phase = SessionPhase::Closing(close.clone()); self.state.tracked_records.clear(); - self.state.ack_state = AckState::Idle; + self.state.record_rx.clear_ack_state(); self.clear_streams(); emit(SessionEvent::SessionClosed(close)); } @@ -179,14 +185,13 @@ impl SessionFsm { self.collect_timeouts(now); - let mut received_records = self.state.received_records.clone(); - let out_of_order = match received_records.insert(seq) { + match self.state.record_rx.insert(seq) { ReceiveOutcome::TooOld => return, ReceiveOutcome::Duplicate => { self.schedule_ack(now, true); return; } - ReceiveOutcome::New { out_of_order } => out_of_order, + ReceiveOutcome::New => {} }; let mut ack_eliciting = false; @@ -222,15 +227,12 @@ impl SessionFsm { } } - // commit after processing - self.state.received_records = received_records; - if handled_close { return; } if ack_eliciting { - self.schedule_ack(now, out_of_order); + self.schedule_ack(now, false); } } @@ -261,7 +263,7 @@ impl SessionFsm { }; restore_tracked_record( now, - &mut self.state.ack_state, + &mut self.state.record_rx, &mut self.state.pending_ping, &mut self.state.streams, record, @@ -292,10 +294,7 @@ impl SessionFsm { if !self.state.phase.is_open() { return None; } - let ack_deadline = match self.state.ack_state { - AckState::Idle => None, - AckState::Dirty { due_at } => Some(due_at), - }; + let ack_deadline = self.state.record_rx.ack_deadline(); let retransmit_deadline = self .state .tracked_records @@ -361,7 +360,7 @@ impl SessionFsm { let mut outbound = TrackedRecord { seq, frames: Vec::new(), - ack_included: false, + ack: None, ping_included: false, window_updates: Vec::new(), sent_at: None, @@ -378,10 +377,12 @@ impl SessionFsm { self.push_next_stream_data(&mut builder, &mut outbound); - if let Some((ack, due_at)) = self.pending_ack() { - if (!builder.is_empty() || due_at <= now) && builder.push_ack(&ack) { - outbound.ack_included = true; - self.state.ack_state = AckState::Idle; + if let Some(pending_ack) = self.pending_ack(builder.remaining_capacity()) { + if !builder.is_empty() || pending_ack.due_at <= now { + if builder.push_ack(&pending_ack.ack) { + self.state.record_rx.on_ack_emitted(&pending_ack); + outbound.ack = Some(pending_ack.ack); + } } } @@ -458,7 +459,7 @@ impl SessionFsm { builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, ) { - const OVERHEAD: usize = 1 + VarInt::MAX_SIZE + StreamData::>::MIN_WIRE_SIZE; + const OVERHEAD: usize = 1 + StreamData::>::MIN_WIRE_SIZE; let len = self.state.streams.len(); if len == 0 { @@ -525,38 +526,39 @@ impl SessionFsm { fn process_record_ack(&mut self, ack: &RecordAck, emit: &mut impl FnMut(SessionEvent)) { let stream_send_buffer_size = self.config.stream_send_buffer_size; - { - let tracked_records = &mut self.state.tracked_records; - let streams = &mut self.state.streams; - for (_, record) in tracked_records.extract_if(.., |_, record| { + let acked_records = self + .state + .tracked_records + .extract_if(.., |_, record| { record.sent_at.is_some() && ack.contains(record.seq.into_inner()) - }) { - for frame in &record.frames { - acknowledge_tracked_frame(streams, stream_send_buffer_size, frame, emit); - } + }) + .map(|(_, record)| record) + .collect::>(); + + for record in acked_records { + for frame in &record.frames { + acknowledge_tracked_frame( + &mut self.state.streams, + stream_send_buffer_size, + frame, + emit, + ); } } self.reap_reapable_streams(); } fn schedule_ack(&mut self, now: Instant, immediate: bool) { - schedule_ack( - &mut self.state.ack_state, - if immediate { - now - } else { - now + self.config.ack_delay - }, - ); + self.state.record_rx.schedule_ack(if immediate { + now + } else { + now + self.config.ack_delay + }); } - fn pending_ack(&self) -> Option<(RecordAck, Instant)> { - match self.state.ack_state { - AckState::Idle => None, - AckState::Dirty { due_at } => { - self.state.received_records.ack().map(|ack| (ack, due_at)) - } - } + fn pending_ack(&self, remaining_capacity: usize) -> Option { + let max_ack_wire_size = remaining_capacity.checked_sub(1)?; + self.state.record_rx.pending_ack(max_ack_wire_size) } fn collect_timeouts(&mut self, now: Instant) { @@ -568,7 +570,7 @@ impl SessionFsm { }) { restore_tracked_record( now, - &mut self.state.ack_state, + &mut self.state.record_rx, &mut self.state.pending_ping, &mut self.state.streams, record, @@ -854,15 +856,6 @@ impl SessionFsm { } } -fn schedule_ack(ack_state: &mut AckState, due_at: Instant) { - *ack_state = match *ack_state { - AckState::Dirty { due_at: old } => AckState::Dirty { - due_at: due_at.min(old), - }, - AckState::Idle => AckState::Dirty { due_at }, - }; -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum MissingStreamAction { Create, @@ -905,13 +898,13 @@ fn local_stream_was_opened( fn restore_tracked_record( now: Instant, - ack_state: &mut AckState, + record_rx: &mut RecordRxState, pending_ping: &mut bool, streams: &mut IndexMap, record: TrackedRecord, ) { - if record.ack_included { - schedule_ack(ack_state, now); + if let Some(ack) = &record.ack { + record_rx.restore_acked_ranges(ack, now); } if record.ping_included { *pending_ping = true; diff --git a/ql-fsm/src/session/range_set.rs b/ql-fsm/src/session/range_set.rs index ac39f23f..53d66269 100644 --- a/ql-fsm/src/session/range_set.rs +++ b/ql-fsm/src/session/range_set.rs @@ -84,10 +84,28 @@ impl RangeSet { self.0.first_key_value().map(|(&start, _)| start) } + pub fn max(&self) -> Option { + self.0 + .last_key_value() + .map(|(_, &end)| end.checked_sub(1).unwrap()) + } + + pub fn contains(&self, x: u64) -> bool { + self.before(x).is_some_and(|(_, end)| end > x) + } + + pub fn range_count(&self) -> usize { + self.0.len() + } + pub fn iter(&self) -> Iter<'_> { Iter(self.0.iter()) } + pub fn iter_rev(&self) -> RevIter<'_> { + RevIter(self.0.iter().rev()) + } + pub fn peek_min(&self) -> Option> { let (&start, &end) = self.0.iter().next()?; Some(start..end) @@ -99,6 +117,19 @@ impl RangeSet { Some(result) } + #[cfg(test)] + pub fn peek_max(&self) -> Option> { + let (&start, &end) = self.0.iter().next_back()?; + Some(start..end) + } + + #[cfg(test)] + pub fn pop_max(&mut self) -> Option> { + let result = self.peek_max()?; + self.0.remove(&result.start); + Some(result) + } + /// find closest range to `x` that begins at or before it fn before(&self, x: u64) -> Option<(u64, u64)> { self.0 @@ -126,6 +157,16 @@ impl Iterator for Iter<'_> { } } +pub struct RevIter<'a>(std::iter::Rev>); + +impl Iterator for RevIter<'_> { + type Item = Range; + + fn next(&mut self) -> Option { + self.0.next().map(|(&start, &end)| start..end) + } +} + #[cfg(test)] mod tests { use super::RangeSet; @@ -146,4 +187,35 @@ mod tests { assert!(set.remove(20..30)); assert_eq!(set.iter().collect::>(), vec![10..20, 30..40]); } + + #[test] + fn reverse_iteration_visits_highest_range_first() { + let mut set = RangeSet::new(); + set.insert(10..20); + set.insert(30..40); + set.insert(50..60); + + assert_eq!( + set.iter_rev().collect::>(), + vec![50..60, 30..40, 10..20] + ); + assert_eq!(set.peek_max(), Some(50..60)); + assert_eq!(set.pop_max(), Some(50..60)); + assert_eq!(set.iter().collect::>(), vec![10..20, 30..40]); + } + + #[test] + fn contains_and_max_reflect_current_membership() { + let mut set = RangeSet::new(); + set.insert(10..20); + set.insert(30..31); + + assert!(!set.contains(9)); + assert!(set.contains(10)); + assert!(set.contains(19)); + assert!(!set.contains(20)); + assert_eq!(set.min(), Some(10)); + assert_eq!(set.max(), Some(30)); + assert_eq!(set.range_count(), 2); + } } diff --git a/ql-fsm/src/session/received_records.rs b/ql-fsm/src/session/received_records.rs index fe0a58f5..98f23e7c 100644 --- a/ql-fsm/src/session/received_records.rs +++ b/ql-fsm/src/session/received_records.rs @@ -1,178 +1,283 @@ -use ql_wire::{RecordAck, RecordSeq}; +use std::{ops::RangeInclusive, time::Instant}; -#[derive(Debug, Clone, Default)] -pub struct ReceivedRecords { - seen: u64, - base: u64, +use ql_wire::{RecordAck, RecordSeq, WireEncode}; + +use super::range_set::RangeSet; + +#[derive(Debug, Clone)] +pub struct RecordRxState { + accepted_records: RangeSet, + pending_ack_ranges: RangeSet, + ack_state: AckState, + accepted_record_window: u64, + pending_ack_range_limit: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PendingAck { + pub ack: RecordAck, + pub due_at: Instant, + pub includes_all_pending: bool, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ReceiveOutcome { - New { out_of_order: bool }, + New, Duplicate, TooOld, } -impl ReceivedRecords { - const TRACKED_LEN: u64 = RecordAck::BITMAP_BITS as u64; - const TRACKED_WINDOW: u64 = Self::TRACKED_LEN - 1; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AckState { + Idle, + Dirty { due_at: Instant }, +} - pub fn insert(&mut self, seq: RecordSeq) -> ReceiveOutcome { - let seq = seq.into_inner(); - if self.seen == 0 { - self.base = seq; - self.seen = 1; - return ReceiveOutcome::New { - out_of_order: false, - }; +impl RecordRxState { + pub fn new(accepted_record_window: u64, pending_ack_range_limit: usize) -> Self { + Self { + accepted_records: RangeSet::new(), + pending_ack_ranges: RangeSet::new(), + ack_state: AckState::Idle, + accepted_record_window: accepted_record_window.max(1), + pending_ack_range_limit: pending_ack_range_limit.max(1), } + } - if seq < self.base { + pub fn insert(&mut self, seq: RecordSeq) -> ReceiveOutcome { + let seq = seq.into_inner(); + let largest_accepted = self.accepted_records.max(); + if largest_accepted.is_some_and(|largest| seq < self.accepted_cutoff(largest)) { return ReceiveOutcome::TooOld; } - - let base = self.base.max(seq.saturating_sub(Self::TRACKED_WINDOW)); - let seen = self.rebased_seen(base); - let next_seen = seen | (1u64 << (seq - base)); - if next_seen == seen { + if self.accepted_records.contains(seq) { + self.pending_ack_ranges.insert(singleton_range(seq)); + self.trim_pending_ack_ranges(); return ReceiveOutcome::Duplicate; } - let out_of_order = seq - != self - .base - .saturating_add(u64::from(u64::BITS - 1 - self.seen.leading_zeros())) - .saturating_add(1); - self.base = base; - self.seen = next_seen; - ReceiveOutcome::New { out_of_order } + self.accepted_records.insert(singleton_range(seq)); + self.trim_accepted_records(); + + self.pending_ack_ranges.insert(singleton_range(seq)); + self.trim_pending_ack_ranges(); + + ReceiveOutcome::New } - pub fn ack(&self) -> Option { - (self.seen != 0).then_some(RecordAck { - base_seq: RecordSeq::from_u64(self.base).expect("tracked record seq must fit varint"), - bits: self.seen, + #[cfg(test)] + pub fn contains(&self, seq: RecordSeq) -> bool { + self.accepted_records.contains(seq.into_inner()) + } + + #[cfg(test)] + pub fn largest_accepted(&self) -> Option { + self.accepted_records + .max() + .map(|largest| RecordSeq::from_u64(largest).unwrap()) + } + + pub fn ack_deadline(&self) -> Option { + match self.ack_state { + AckState::Idle => None, + AckState::Dirty { due_at } => Some(due_at), + } + } + + pub fn schedule_ack(&mut self, due_at: Instant) { + self.ack_state = match self.ack_state { + AckState::Dirty { due_at: old } => AckState::Dirty { + due_at: due_at.min(old), + }, + AckState::Idle => AckState::Dirty { due_at }, + }; + } + + pub fn pending_ack(&self, max_wire_size: usize) -> Option { + let due_at = self.ack_deadline()?; + if max_wire_size == 0 || self.pending_ack_ranges.range_count() == 0 { + return None; + } + + let total_range_count = self.pending_ack_ranges.range_count(); + let mut included_range_count = 0usize; + let mut ranges = Vec::new(); + let mut ack = None; + + for range in self.pending_ack_ranges.iter_rev() { + ranges.push(to_ack_range(range)); + let candidate = RecordAck::from_ranges(ranges.iter().cloned()).unwrap(); + if candidate.encoded_len() > max_wire_size { + ranges.pop(); + break; + } + + included_range_count += 1; + ack = Some(candidate); + } + + ack.map(|ack| PendingAck { + ack, + due_at, + includes_all_pending: included_range_count == total_range_count, }) } - fn rebased_seen(&self, new_base: u64) -> u64 { - if new_base <= self.base { - return self.seen; + pub fn on_ack_emitted(&mut self, pending_ack: &PendingAck) { + self.retire_acked_ranges(&pending_ack.ack); + if pending_ack.includes_all_pending || self.pending_ack_ranges.range_count() == 0 { + self.ack_state = AckState::Idle; + } + } + + pub fn retire_acked_ranges(&mut self, ack: &RecordAck) { + for range in ack.ranges() { + self.pending_ack_ranges.remove(from_ack_range(range)); + } + if self.pending_ack_ranges.range_count() == 0 { + self.ack_state = AckState::Idle; + } + } + + pub fn clear_ack_state(&mut self) { + self.ack_state = AckState::Idle; + } + + pub fn restore_acked_ranges(&mut self, ack: &RecordAck, due_at: Instant) { + for range in ack.ranges() { + self.pending_ack_ranges.insert(from_ack_range(range)); } + self.trim_pending_ack_ranges(); + self.schedule_ack(due_at); + } + + fn accepted_cutoff(&self, largest_accepted: u64) -> u64 { + largest_accepted + .saturating_add(1) + .saturating_sub(self.accepted_record_window) + } + + fn trim_accepted_records(&mut self) { + let Some(largest_accepted) = self.accepted_records.max() else { + return; + }; + let cutoff = self.accepted_cutoff(largest_accepted); + self.accepted_records.remove(0..cutoff); + } - let shift = new_base - self.base; - if shift >= Self::TRACKED_LEN { - 0 - } else { - self.seen >> shift + fn trim_pending_ack_ranges(&mut self) { + while self.pending_ack_ranges.range_count() > self.pending_ack_range_limit { + self.pending_ack_ranges.pop_min(); } } } +fn singleton_range(seq: u64) -> std::ops::Range { + seq..seq.checked_add(1).unwrap() +} + +fn to_ack_range(range: std::ops::Range) -> RangeInclusive { + let end = range.end.checked_sub(1).unwrap(); + RecordSeq::from_u64(range.start).unwrap()..=RecordSeq::from_u64(end).unwrap() +} + +fn from_ack_range(range: RangeInclusive) -> std::ops::Range { + let start = range.start().into_inner(); + let end = range.end().into_inner().checked_add(1).unwrap(); + start..end +} + #[cfg(test)] mod tests { - use ql_wire::{RecordAck, RecordSeq}; + use std::time::{Duration, Instant}; + + use ql_wire::RecordSeq; - use super::{ReceiveOutcome, ReceivedRecords}; + use super::{PendingAck, ReceiveOutcome, RecordRxState}; fn seq(value: u64) -> RecordSeq { RecordSeq::from_u64(value).unwrap() } + fn ack_ranges(pending_ack: PendingAck) -> Vec<(u64, u64)> { + pending_ack + .ack + .ranges() + .map(|range| (range.start().into_inner(), range.end().into_inner())) + .collect() + } + #[test] - fn inserts_pack_contiguous_bits() { - let mut received = ReceivedRecords::default(); + fn contiguous_records_emit_one_ack_range() { + let now = Instant::now(); + let mut record_rx = RecordRxState::new(128, 8); - assert_eq!( - received.insert(seq(10)), - ReceiveOutcome::New { - out_of_order: false - } - ); - assert_eq!( - received.insert(seq(12)), - ReceiveOutcome::New { out_of_order: true } - ); - assert_eq!( - received.insert(seq(11)), - ReceiveOutcome::New { out_of_order: true } - ); - - let ack = received.ack().unwrap(); - assert_eq!( - ack, - RecordAck { - base_seq: seq(10), - bits: 0b111, - } - ); + assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(11)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(12)), ReceiveOutcome::New); + + record_rx.schedule_ack(now); + let pending_ack = record_rx.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(pending_ack), vec![(10, 12)]); } #[test] - fn old_records_fall_out_of_fixed_window() { - let mut received = ReceivedRecords::default(); + fn sparse_records_emit_descending_ack_ranges() { + let now = Instant::now(); + let mut record_rx = RecordRxState::new(128, 8); - assert_eq!( - received.insert(seq(0)), - ReceiveOutcome::New { - out_of_order: false - } - ); - assert_eq!( - received.insert(seq(300)), - ReceiveOutcome::New { out_of_order: true } - ); - assert_eq!(received.insert(seq(0)), ReceiveOutcome::TooOld); - - let ack = received.ack().unwrap(); - assert_eq!( - ack, - RecordAck { - base_seq: seq(237), - bits: 1u64 << 63, - } - ); + assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(15)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(16)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(12)), ReceiveOutcome::New); + + record_rx.schedule_ack(now + Duration::from_millis(5)); + let pending_ack = record_rx.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(pending_ack), vec![(15, 16), (12, 12), (10, 10)]); } #[test] - fn duplicate_in_window_is_rejected() { - let mut received = ReceivedRecords::default(); + fn accepted_record_window_evicts_old_sequences() { + let mut record_rx = RecordRxState::new(4, 8); - assert_eq!( - received.insert(seq(7)), - ReceiveOutcome::New { - out_of_order: false - } - ); - assert_eq!(received.insert(seq(7)), ReceiveOutcome::Duplicate); + assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(15)), ReceiveOutcome::New); + + assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::TooOld); + assert!(!record_rx.contains(seq(10))); + assert_eq!(record_rx.largest_accepted(), Some(seq(15))); } #[test] - fn sliding_window_preserves_relative_bits() { - let mut received = ReceivedRecords::default(); + fn pending_ack_range_limit_drops_oldest_low_ranges() { + let now = Instant::now(); + let mut record_rx = RecordRxState::new(128, 2); - assert_eq!( - received.insert(seq(10)), - ReceiveOutcome::New { - out_of_order: false - } - ); - assert_eq!( - received.insert(seq(12)), - ReceiveOutcome::New { out_of_order: true } - ); - assert_eq!( - received.insert(seq(70)), - ReceiveOutcome::New { out_of_order: true } - ); - - let ack = received.ack().unwrap(); - assert_eq!( - ack, - RecordAck { - base_seq: seq(10), - bits: (1u64 << 0) | (1u64 << 2) | (1u64 << 60), - } - ); + assert_eq!(record_rx.insert(seq(1)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(3)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(5)), ReceiveOutcome::New); + + record_rx.schedule_ack(now); + let pending_ack = record_rx.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(pending_ack), vec![(5, 5), (3, 3)]); + } + + #[test] + fn retire_acked_ranges_removes_only_exact_snapshot() { + let now = Instant::now(); + let mut record_rx = RecordRxState::new(128, 8); + + assert_eq!(record_rx.insert(seq(1)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(3)), ReceiveOutcome::New); + assert_eq!(record_rx.insert(seq(5)), ReceiveOutcome::New); + record_rx.schedule_ack(now); + + let first_ack = record_rx.pending_ack(4).unwrap(); + assert_eq!(ack_ranges(first_ack.clone()), vec![(5, 5)]); + record_rx.on_ack_emitted(&first_ack); + record_rx.retire_acked_ranges(&first_ack.ack); + + let second_ack = record_rx.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(second_ack), vec![(3, 3), (1, 1)]); } } diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index 21443f77..f9fba199 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -4,7 +4,7 @@ use indexmap::IndexMap; use ql_wire::{CloseTarget, RecordSeq, RouteId, SessionClose, StreamClose, StreamId}; use super::{ - received_records::ReceivedRecords, remote_stream_history::RemoteStreamHistory, + received_records::RecordRxState, remote_stream_history::RemoteStreamHistory, stream_rx::StreamRx, stream_tx::StreamTx, tracked::TrackedRecord, }; @@ -16,8 +16,7 @@ pub struct SessionState { pub next_record_seq: RecordSeq, pub next_write_id: u64, pub tracked_records: IndexMap, - pub received_records: ReceivedRecords, - pub ack_state: AckState, + pub record_rx: RecordRxState, pub pending_ping: bool, pub streams: IndexMap, pub next_stream_index: usize, @@ -133,11 +132,3 @@ pub enum InboundState { Closed(StreamClose), Discarding, } - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum AckState { - // ack state is not dirty - Idle, - // ack is dirty. we can wait to piggy back on an outgoing record until this time - Dirty { due_at: Instant }, -} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index ad700276..e0ab577d 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -26,6 +26,10 @@ fn route_id(value: u64) -> RouteId { RouteId::from_u64(value).unwrap() } +fn record_ack(seq: RecordSeq) -> RecordAck { + RecordAck::from_ranges([seq..=seq]).unwrap() +} + const REFUSED: StreamCloseCode = StreamCloseCode(1); const TIMEOUT: StreamCloseCode = StreamCloseCode(2); @@ -74,6 +78,22 @@ fn next_outbound( )) } +fn drain_outbound( + fsm: &mut SessionFsm, + now: Instant, + limit: usize, +) -> Vec<(RecordSeq, Vec>>)> { + let mut records = Vec::new(); + for _ in 0..limit { + let Some(record) = next_outbound(fsm, now) else { + return records; + }; + records.push(record); + } + + panic!("session did not quiesce within outbound limit"); +} + fn receive_events( fsm: &mut SessionFsm, now: Instant, @@ -180,10 +200,7 @@ fn ack_reopens_write_capacity() { fsm.receive( now + Duration::from_millis(1), seq(9), - std::iter::once(Ok(SessionFrame::Ack(RecordAck { - base_seq: record_seq, - bits: 1u64, - }))), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), |event| events.push(event), ); @@ -270,6 +287,85 @@ fn pure_ack_only_records_are_fire_and_forget() { .is_none()); } +#[test] +fn sparse_out_of_order_ack_ranges_page_and_quiesce() { + let now = Instant::now(); + let sender_config = SessionConfig { + local_parity: StreamParity::Even, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 40, + ack_delay: Duration::from_millis(5), + retransmit_timeout: Duration::from_millis(25), + stream_send_buffer_size: 8 * 1024, + initial_peer_stream_receive_window: 8 * 1024, + ..SessionConfig::default() + }; + let receiver_config = SessionConfig { + local_parity: StreamParity::Odd, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 10, + ack_delay: Duration::from_millis(1), + retransmit_timeout: Duration::from_millis(25), + pending_ack_range_limit: 512, + initial_peer_stream_receive_window: 8 * 1024, + ..SessionConfig::default() + }; + let mut sender = SessionFsm::new(sender_config, now); + let mut receiver = SessionFsm::new(receiver_config, now); + + let stream_id = open_stream_id(&mut sender); + let payload = vec![b'x'; 2048]; + assert_eq!( + write_stream_bytes(&mut sender, stream_id, &payload), + payload.len() + ); + + let originals = drain_outbound(&mut sender, now, 4096); + assert!(originals.len() >= 64); + + for (seq, record) in originals + .iter() + .filter(|(seq, _)| seq.into_inner() % 2 == 1) + { + let _ = receive_events(&mut receiver, now, *seq, record); + } + + let first_ack_time = now + receiver_config.ack_delay; + let first_acks = drain_outbound(&mut receiver, first_ack_time, originals.len()); + assert!(first_acks.len() > 1); + assert!(first_acks + .iter() + .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); + + for (seq, record) in &first_acks { + let _ = receive_events(&mut sender, first_ack_time, *seq, record); + } + + let retransmit_time = now + sender_config.retransmit_timeout + Duration::from_millis(1); + sender.on_timer(retransmit_time, |_| {}); + let retransmits = drain_outbound(&mut sender, retransmit_time, originals.len()); + assert!(!retransmits.is_empty()); + + for (seq, record) in &retransmits { + let _ = receive_events(&mut receiver, retransmit_time, *seq, record); + } + + let second_ack_time = retransmit_time + receiver_config.ack_delay; + let second_acks = drain_outbound(&mut receiver, second_ack_time, retransmits.len() + 16); + assert!(!second_acks.is_empty()); + assert!(second_acks + .iter() + .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); + + for (seq, record) in &second_acks { + let _ = receive_events(&mut sender, second_ack_time, *seq, record); + } + + let final_now = second_ack_time + sender_config.retransmit_timeout + Duration::from_millis(1); + sender.on_timer(final_now, |_| {}); + receiver.on_timer(final_now, |_| {}); + assert!(next_outbound(&mut sender, final_now).is_none()); + assert!(next_outbound(&mut receiver, final_now).is_none()); +} + #[test] fn inbound_stream_data_emits_opened_and_readable() { let now = Instant::now(); diff --git a/ql-fsm/src/session/tracked.rs b/ql-fsm/src/session/tracked.rs index fa97a77b..84317951 100644 --- a/ql-fsm/src/session/tracked.rs +++ b/ql-fsm/src/session/tracked.rs @@ -2,13 +2,13 @@ use std::time::Instant; -use ql_wire::{RecordSeq, StreamClose, StreamId}; +use ql_wire::{RecordAck, RecordSeq, StreamClose, StreamId}; #[derive(Debug, Clone)] pub struct TrackedRecord { pub seq: RecordSeq, pub frames: Vec, - pub ack_included: bool, + pub ack: Option, pub ping_included: bool, pub window_updates: Vec<(StreamId, u64)>, pub sent_at: Option, diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 3ce2c1f7..bba4b599 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -315,6 +315,8 @@ fn session_config(harness: &Harness, a: bool) -> SessionConfig { peer_timeout: config.session_peer_timeout, stream_send_buffer_size: config.session_stream_send_buffer_size, stream_receive_buffer_size: config.session_stream_receive_buffer_size, + accepted_record_window: config.session_accepted_record_window, + pending_ack_range_limit: config.session_pending_ack_range_limit, initial_peer_stream_receive_window: if a { harness.b.fsm.config.session_stream_receive_buffer_size } else { From b554034d81f140d12da10734a5e4c8a473499681 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 15:04:27 -0400 Subject: [PATCH 212/304] recordack builder --- ql-fsm/src/session/received_records.rs | 39 ++-- ql-wire/src/encrypted/ack.rs | 264 ++++++++++++++++--------- 2 files changed, 179 insertions(+), 124 deletions(-) diff --git a/ql-fsm/src/session/received_records.rs b/ql-fsm/src/session/received_records.rs index 98f23e7c..b36f99a6 100644 --- a/ql-fsm/src/session/received_records.rs +++ b/ql-fsm/src/session/received_records.rs @@ -1,6 +1,6 @@ use std::{ops::RangeInclusive, time::Instant}; -use ql_wire::{RecordAck, RecordSeq, WireEncode}; +use ql_wire::{RecordAck, RecordAckBuilder, RecordSeq}; use super::range_set::RangeSet; @@ -65,18 +65,6 @@ impl RecordRxState { ReceiveOutcome::New } - #[cfg(test)] - pub fn contains(&self, seq: RecordSeq) -> bool { - self.accepted_records.contains(seq.into_inner()) - } - - #[cfg(test)] - pub fn largest_accepted(&self) -> Option { - self.accepted_records - .max() - .map(|largest| RecordSeq::from_u64(largest).unwrap()) - } - pub fn ack_deadline(&self) -> Option { match self.ack_state { AckState::Idle => None, @@ -100,26 +88,23 @@ impl RecordRxState { } let total_range_count = self.pending_ack_ranges.range_count(); - let mut included_range_count = 0usize; - let mut ranges = Vec::new(); - let mut ack = None; + let mut ack = RecordAckBuilder::new(); + let mut selected_range_count = 0usize; for range in self.pending_ack_ranges.iter_rev() { - ranges.push(to_ack_range(range)); - let candidate = RecordAck::from_ranges(ranges.iter().cloned()).unwrap(); - if candidate.encoded_len() > max_wire_size { - ranges.pop(); + let pushed = ack + .try_push_range(to_ack_range(range), max_wire_size) + .unwrap(); + if !pushed { break; } - - included_range_count += 1; - ack = Some(candidate); + selected_range_count += 1; } - ack.map(|ack| PendingAck { - ack, + (selected_range_count != 0).then(|| PendingAck { + ack: ack.build().unwrap(), due_at, - includes_all_pending: included_range_count == total_range_count, + includes_all_pending: total_range_count == selected_range_count, }) } @@ -244,8 +229,6 @@ mod tests { assert_eq!(record_rx.insert(seq(15)), ReceiveOutcome::New); assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::TooOld); - assert!(!record_rx.contains(seq(10))); - assert_eq!(record_rx.largest_accepted(), Some(seq(15))); } #[test] diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index c3e90c27..6869a6e8 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -11,8 +11,8 @@ pub struct RecordAck { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RecordAckBlock { - gap: VarInt, - range_len: VarInt, + pub gap: VarInt, + pub range_len: VarInt, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -34,67 +34,14 @@ impl RecordAck { where I: IntoIterator>, { - let mut ranges = ranges.into_iter(); - let Some(first_range) = ranges.next() else { - return Err(RecordAckRangeError::Empty); - }; - - let first_start = first_range.start().into_inner(); - let first_end = first_range.end().into_inner(); - if first_start > first_end { - return Err(RecordAckRangeError::InvertedRange); - } - - let mut prev_start = first_start; - let mut prev_end = first_end; - let mut blocks = Vec::new(); - + let mut builder = RecordAckBuilder::new(); for range in ranges { - let start = range.start().into_inner(); - let end = range.end().into_inner(); - if start > end { - return Err(RecordAckRangeError::InvertedRange); - } - if end >= prev_end || end.saturating_add(1) >= prev_start { - return Err(RecordAckRangeError::NotCanonical); + let pushed = builder.try_push_range(range, usize::MAX)?; + if !pushed { + unreachable!("record ack should fit inside usize::MAX"); } - - let gap = prev_start - .checked_sub(end) - .and_then(|delta| delta.checked_sub(2)) - .expect("canonical ack ranges stay separated by at least one sequence"); - blocks.push(RecordAckBlock { - gap: VarInt::from_u64(gap).expect("record ack gap must fit varint"), - range_len: VarInt::from_u64(end - start) - .expect("record ack range length must fit varint"), - }); - prev_start = start; - prev_end = end; } - - Ok(Self { - largest_acked: RecordSeq::from_u64(first_end) - .expect("record ack range upper bound must fit record sequence"), - first_range_len: VarInt::from_u64(first_end - first_start) - .expect("record ack first range length must fit varint"), - blocks: blocks.into_boxed_slice(), - }) - } - - pub fn largest_acked(&self) -> RecordSeq { - self.largest_acked - } - - pub fn first_range_len(&self) -> VarInt { - self.first_range_len - } - - pub fn blocks(&self) -> &[RecordAckBlock] { - &self.blocks - } - - pub fn range_count(&self) -> usize { - 1 + self.blocks.len() + builder.build() } pub fn ranges(&self) -> RecordAckRangeIter<'_> { @@ -113,39 +60,14 @@ impl RecordAck { self.ranges().any(|range| range.contains(&seq)) } - fn validate(&self) -> Result<(), WireError> { - let mut previous_start = self - .largest_acked - .into_inner() - .checked_sub(self.first_range_len.into_inner()) - .ok_or(WireError::InvalidPayload)?; - - for block in self.blocks.iter() { - let end = previous_start - .checked_sub( - block - .gap - .into_inner() - .checked_add(2) - .ok_or(WireError::InvalidPayload)?, - ) - .ok_or(WireError::InvalidPayload)?; - previous_start = end - .checked_sub(block.range_len.into_inner()) - .ok_or(WireError::InvalidPayload)?; - } - - Ok(()) + fn block_count_len(block_count: usize) -> usize { + VarInt::try_from(block_count).unwrap().encoded_len() } } impl RecordAckBlock { - pub fn gap(&self) -> VarInt { - self.gap - } - - pub fn range_len(&self) -> VarInt { - self.range_len + fn encoded_len(&self) -> usize { + self.gap.encoded_len() + self.range_len.encoded_len() } } @@ -187,6 +109,7 @@ impl Iterator for RecordAckRangeIter<'_> { let previous_start = self .previous_start .expect("first ack range is always yielded"); + // gap is encoded as missing_count - 1, so decoding steps back by gap + 2. let end = previous_start - block.gap.into_inner() - 2; let start = end - block.range_len.into_inner(); self.previous_start = Some(start); @@ -197,12 +120,12 @@ impl Iterator for RecordAckRangeIter<'_> { impl WireEncode for RecordAck { fn encoded_len(&self) -> usize { self.largest_acked.encoded_len() - + VarInt::try_from(self.blocks.len()).unwrap().encoded_len() + + Self::block_count_len(self.blocks.len()) + self.first_range_len.encoded_len() + self .blocks .iter() - .map(|block| block.gap.encoded_len() + block.range_len.encoded_len()) + .map(RecordAckBlock::encoded_len) .sum::() } @@ -236,16 +159,119 @@ impl codec::WireDecode for RecordAck { first_range_len, blocks: blocks.into_boxed_slice(), }; - ack.validate()?; + + // validate + { + let mut previous_start = ack + .largest_acked + .into_inner() + .checked_sub(ack.first_range_len.into_inner()) + .ok_or(WireError::InvalidPayload)?; + + for block in ack.blocks.iter() { + let end = previous_start + .checked_sub( + block + .gap + .into_inner() + .checked_add(2) + .ok_or(WireError::InvalidPayload)?, + ) + .ok_or(WireError::InvalidPayload)?; + previous_start = end + .checked_sub(block.range_len.into_inner()) + .ok_or(WireError::InvalidPayload)?; + } + } Ok(ack) } } +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct RecordAckBuilder { + largest_acked: Option, + first_range_len: Option, + blocks: Vec, + previous_start: Option, + wire_len: usize, +} + +impl RecordAckBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn try_push_range( + &mut self, + range: RangeInclusive, + max_wire_size: usize, + ) -> Result { + let start = range.start().into_inner(); + let end = range.end().into_inner(); + if start > end { + return Err(RecordAckRangeError::InvertedRange); + } + + let range_len = VarInt::from_u64(end - start).unwrap(); + if let Some(previous_start) = self.previous_start { + if end.saturating_add(1) >= previous_start { + return Err(RecordAckRangeError::NotCanonical); + } + + let gap = previous_start + .checked_sub(end) + .and_then(|delta| delta.checked_sub(2)) + .expect("canonical ack ranges stay separated by at least one sequence"); + let block = RecordAckBlock { + gap: VarInt::from_u64(gap).unwrap(), + range_len, + }; + let current_block_count_len = RecordAck::block_count_len(self.blocks.len()); + let next_block_count_len = RecordAck::block_count_len(self.blocks.len() + 1); + let next_wire_len = self.wire_len + + (next_block_count_len - current_block_count_len) + + block.encoded_len(); + if next_wire_len > max_wire_size { + return Ok(false); + } + + self.previous_start = Some(start); + self.wire_len = next_wire_len; + self.blocks.push(block); + return Ok(true); + } + + let largest_acked = RecordSeq::from_u64(end).unwrap(); + let wire_len = + largest_acked.encoded_len() + RecordAck::block_count_len(0) + range_len.encoded_len(); + if wire_len > max_wire_size { + return Ok(false); + } + + self.largest_acked = Some(largest_acked); + self.first_range_len = Some(range_len); + self.previous_start = Some(start); + self.wire_len = wire_len; + Ok(true) + } + + pub fn build(self) -> Result { + let Some(largest_acked) = self.largest_acked else { + return Err(RecordAckRangeError::Empty); + }; + + Ok(RecordAck { + largest_acked, + first_range_len: self.first_range_len.unwrap(), + blocks: self.blocks.into_boxed_slice(), + }) + } +} #[cfg(test)] mod tests { use std::ops::RangeInclusive; - use super::{RecordAck, RecordAckBlock, RecordAckRangeError}; + use super::{RecordAck, RecordAckBlock, RecordAckBuilder, RecordAckRangeError}; use crate::{RecordSeq, VarInt, WireDecode, WireEncode, WireError}; fn seq(value: u64) -> RecordSeq { @@ -276,10 +302,10 @@ mod tests { RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) .unwrap(); - assert_eq!(ack.largest_acked(), seq(100)); - assert_eq!(ack.first_range_len(), varint(5)); + assert_eq!(ack.largest_acked, seq(100)); + assert_eq!(ack.first_range_len, varint(5)); assert_eq!( - ack.blocks(), + ack.blocks.as_ref(), &[ RecordAckBlock { gap: varint(1), @@ -293,6 +319,52 @@ mod tests { ); } + #[test] + fn builder_matches_from_ranges() { + let mut builder = RecordAckBuilder::new(); + assert!(builder + .try_push_range(ack_range(95, 100), usize::MAX) + .unwrap()); + assert!(builder + .try_push_range(ack_range(90, 92), usize::MAX) + .unwrap()); + assert!(builder + .try_push_range(ack_range(80, 80), usize::MAX) + .unwrap()); + + assert_eq!( + builder.build().unwrap(), + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap() + ); + } + + #[test] + fn builder_stops_when_budget_is_exhausted() { + let first_only = RecordAck::from_ranges([ack_range(95, 100)]).unwrap(); + let mut builder = RecordAckBuilder::new(); + + assert!(builder + .try_push_range(ack_range(95, 100), first_only.encoded_len()) + .unwrap()); + assert!(!builder + .try_push_range(ack_range(90, 92), first_only.encoded_len()) + .unwrap()); + assert_eq!(builder.build().unwrap(), first_only); + } + + #[test] + fn builder_rejects_non_canonical_ranges() { + let mut builder = RecordAckBuilder::new(); + assert!(builder + .try_push_range(ack_range(95, 100), usize::MAX) + .unwrap()); + assert_eq!( + builder.try_push_range(ack_range(90, 95), usize::MAX), + Err(RecordAckRangeError::NotCanonical) + ); + } + #[test] fn rejects_unsorted_ranges() { assert_eq!( From a0da822cf2e13a0cc4f654e4bb0df6c1f75bc371 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 21:51:06 -0400 Subject: [PATCH 213/304] ack tracker --- .../{received_records.rs => ack_tracker.rs} | 100 +++++++++--------- ql-fsm/src/session/mod.rs | 26 ++--- ql-fsm/src/session/state.rs | 6 +- 3 files changed, 66 insertions(+), 66 deletions(-) rename ql-fsm/src/session/{received_records.rs => ack_tracker.rs} (65%) diff --git a/ql-fsm/src/session/received_records.rs b/ql-fsm/src/session/ack_tracker.rs similarity index 65% rename from ql-fsm/src/session/received_records.rs rename to ql-fsm/src/session/ack_tracker.rs index b36f99a6..240095fc 100644 --- a/ql-fsm/src/session/received_records.rs +++ b/ql-fsm/src/session/ack_tracker.rs @@ -5,9 +5,9 @@ use ql_wire::{RecordAck, RecordAckBuilder, RecordSeq}; use super::range_set::RangeSet; #[derive(Debug, Clone)] -pub struct RecordRxState { +pub struct AckTracker { accepted_records: RangeSet, - pending_ack_ranges: RangeSet, + pending_ack: RangeSet, ack_state: AckState, accepted_record_window: u64, pending_ack_range_limit: usize, @@ -33,11 +33,11 @@ enum AckState { Dirty { due_at: Instant }, } -impl RecordRxState { +impl AckTracker { pub fn new(accepted_record_window: u64, pending_ack_range_limit: usize) -> Self { Self { accepted_records: RangeSet::new(), - pending_ack_ranges: RangeSet::new(), + pending_ack: RangeSet::new(), ack_state: AckState::Idle, accepted_record_window: accepted_record_window.max(1), pending_ack_range_limit: pending_ack_range_limit.max(1), @@ -51,15 +51,15 @@ impl RecordRxState { return ReceiveOutcome::TooOld; } if self.accepted_records.contains(seq) { - self.pending_ack_ranges.insert(singleton_range(seq)); + self.pending_ack.insert(single_range(seq)); self.trim_pending_ack_ranges(); return ReceiveOutcome::Duplicate; } - self.accepted_records.insert(singleton_range(seq)); + self.accepted_records.insert(single_range(seq)); self.trim_accepted_records(); - self.pending_ack_ranges.insert(singleton_range(seq)); + self.pending_ack.insert(single_range(seq)); self.trim_pending_ack_ranges(); ReceiveOutcome::New @@ -83,15 +83,15 @@ impl RecordRxState { pub fn pending_ack(&self, max_wire_size: usize) -> Option { let due_at = self.ack_deadline()?; - if max_wire_size == 0 || self.pending_ack_ranges.range_count() == 0 { + if max_wire_size == 0 || self.pending_ack.range_count() == 0 { return None; } - let total_range_count = self.pending_ack_ranges.range_count(); + let total_range_count = self.pending_ack.range_count(); let mut ack = RecordAckBuilder::new(); let mut selected_range_count = 0usize; - for range in self.pending_ack_ranges.iter_rev() { + for range in self.pending_ack.iter_rev() { let pushed = ack .try_push_range(to_ack_range(range), max_wire_size) .unwrap(); @@ -110,16 +110,16 @@ impl RecordRxState { pub fn on_ack_emitted(&mut self, pending_ack: &PendingAck) { self.retire_acked_ranges(&pending_ack.ack); - if pending_ack.includes_all_pending || self.pending_ack_ranges.range_count() == 0 { + if pending_ack.includes_all_pending || self.pending_ack.range_count() == 0 { self.ack_state = AckState::Idle; } } pub fn retire_acked_ranges(&mut self, ack: &RecordAck) { for range in ack.ranges() { - self.pending_ack_ranges.remove(from_ack_range(range)); + self.pending_ack.remove(from_ack_range(range)); } - if self.pending_ack_ranges.range_count() == 0 { + if self.pending_ack.range_count() == 0 { self.ack_state = AckState::Idle; } } @@ -130,7 +130,7 @@ impl RecordRxState { pub fn restore_acked_ranges(&mut self, ack: &RecordAck, due_at: Instant) { for range in ack.ranges() { - self.pending_ack_ranges.insert(from_ack_range(range)); + self.pending_ack.insert(from_ack_range(range)); } self.trim_pending_ack_ranges(); self.schedule_ack(due_at); @@ -151,13 +151,13 @@ impl RecordRxState { } fn trim_pending_ack_ranges(&mut self) { - while self.pending_ack_ranges.range_count() > self.pending_ack_range_limit { - self.pending_ack_ranges.pop_min(); + while self.pending_ack.range_count() > self.pending_ack_range_limit { + self.pending_ack.pop_min(); } } } -fn singleton_range(seq: u64) -> std::ops::Range { +fn single_range(seq: u64) -> std::ops::Range { seq..seq.checked_add(1).unwrap() } @@ -178,7 +178,7 @@ mod tests { use ql_wire::RecordSeq; - use super::{PendingAck, ReceiveOutcome, RecordRxState}; + use super::{AckTracker, PendingAck, ReceiveOutcome}; fn seq(value: u64) -> RecordSeq { RecordSeq::from_u64(value).unwrap() @@ -195,72 +195,72 @@ mod tests { #[test] fn contiguous_records_emit_one_ack_range() { let now = Instant::now(); - let mut record_rx = RecordRxState::new(128, 8); + let mut ack_tracker = AckTracker::new(128, 8); - assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(11)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(12)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(11)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(12)), ReceiveOutcome::New); - record_rx.schedule_ack(now); - let pending_ack = record_rx.pending_ack(usize::MAX).unwrap(); + ack_tracker.schedule_ack(now); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); assert_eq!(ack_ranges(pending_ack), vec![(10, 12)]); } #[test] fn sparse_records_emit_descending_ack_ranges() { let now = Instant::now(); - let mut record_rx = RecordRxState::new(128, 8); + let mut ack_tracker = AckTracker::new(128, 8); - assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(15)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(16)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(12)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(15)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(16)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(12)), ReceiveOutcome::New); - record_rx.schedule_ack(now + Duration::from_millis(5)); - let pending_ack = record_rx.pending_ack(usize::MAX).unwrap(); + ack_tracker.schedule_ack(now + Duration::from_millis(5)); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); assert_eq!(ack_ranges(pending_ack), vec![(15, 16), (12, 12), (10, 10)]); } #[test] fn accepted_record_window_evicts_old_sequences() { - let mut record_rx = RecordRxState::new(4, 8); + let mut ack_tracker = AckTracker::new(4, 8); - assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(15)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(15)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(10)), ReceiveOutcome::TooOld); + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::TooOld); } #[test] fn pending_ack_range_limit_drops_oldest_low_ranges() { let now = Instant::now(); - let mut record_rx = RecordRxState::new(128, 2); + let mut ack_tracker = AckTracker::new(128, 2); - assert_eq!(record_rx.insert(seq(1)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(3)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(5)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(1)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(3)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(5)), ReceiveOutcome::New); - record_rx.schedule_ack(now); - let pending_ack = record_rx.pending_ack(usize::MAX).unwrap(); + ack_tracker.schedule_ack(now); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); assert_eq!(ack_ranges(pending_ack), vec![(5, 5), (3, 3)]); } #[test] fn retire_acked_ranges_removes_only_exact_snapshot() { let now = Instant::now(); - let mut record_rx = RecordRxState::new(128, 8); + let mut ack_tracker = AckTracker::new(128, 8); - assert_eq!(record_rx.insert(seq(1)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(3)), ReceiveOutcome::New); - assert_eq!(record_rx.insert(seq(5)), ReceiveOutcome::New); - record_rx.schedule_ack(now); + assert_eq!(ack_tracker.insert(seq(1)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(3)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(5)), ReceiveOutcome::New); + ack_tracker.schedule_ack(now); - let first_ack = record_rx.pending_ack(4).unwrap(); + let first_ack = ack_tracker.pending_ack(4).unwrap(); assert_eq!(ack_ranges(first_ack.clone()), vec![(5, 5)]); - record_rx.on_ack_emitted(&first_ack); - record_rx.retire_acked_ranges(&first_ack.ack); + ack_tracker.on_ack_emitted(&first_ack); + ack_tracker.retire_acked_ranges(&first_ack.ack); - let second_ack = record_rx.pending_ack(usize::MAX).unwrap(); + let second_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); assert_eq!(ack_ranges(second_ack), vec![(3, 3), (1, 1)]); } } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 88a2f564..ad87f532 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,7 +1,7 @@ pub use self::{stream_ops::*, stream_parity::*, stream_rx::*}; +mod ack_tracker; mod range_set; -mod received_records; mod remote_stream_history; mod state; mod stream_ops; @@ -24,7 +24,7 @@ use ql_wire::{ }; use self::{ - received_records::{PendingAck, ReceiveOutcome, RecordRxState}, + ack_tracker::{AckTracker, PendingAck, ReceiveOutcome}, remote_stream_history::RemoteStreamHistory, state::{InboundState, OutboundState, SessionPhase, SessionState, StreamRole, StreamState}, stream_tx::StreamTxRange, @@ -103,7 +103,7 @@ impl SessionFsm { next_record_seq: RecordSeq::from_u32(0), next_write_id: 0, tracked_records: Default::default(), - record_rx: RecordRxState::new( + ack_tracker: AckTracker::new( config.accepted_record_window, config.pending_ack_range_limit, ), @@ -158,7 +158,7 @@ impl SessionFsm { let close = SessionClose { code }; self.state.phase = SessionPhase::Closing(close.clone()); self.state.tracked_records.clear(); - self.state.record_rx.clear_ack_state(); + self.state.ack_tracker.clear_ack_state(); self.clear_streams(); emit(SessionEvent::SessionClosed(close)); } @@ -185,7 +185,7 @@ impl SessionFsm { self.collect_timeouts(now); - match self.state.record_rx.insert(seq) { + match self.state.ack_tracker.insert(seq) { ReceiveOutcome::TooOld => return, ReceiveOutcome::Duplicate => { self.schedule_ack(now, true); @@ -263,7 +263,7 @@ impl SessionFsm { }; restore_tracked_record( now, - &mut self.state.record_rx, + &mut self.state.ack_tracker, &mut self.state.pending_ping, &mut self.state.streams, record, @@ -294,7 +294,7 @@ impl SessionFsm { if !self.state.phase.is_open() { return None; } - let ack_deadline = self.state.record_rx.ack_deadline(); + let ack_deadline = self.state.ack_tracker.ack_deadline(); let retransmit_deadline = self .state .tracked_records @@ -380,7 +380,7 @@ impl SessionFsm { if let Some(pending_ack) = self.pending_ack(builder.remaining_capacity()) { if !builder.is_empty() || pending_ack.due_at <= now { if builder.push_ack(&pending_ack.ack) { - self.state.record_rx.on_ack_emitted(&pending_ack); + self.state.ack_tracker.on_ack_emitted(&pending_ack); outbound.ack = Some(pending_ack.ack); } } @@ -549,7 +549,7 @@ impl SessionFsm { } fn schedule_ack(&mut self, now: Instant, immediate: bool) { - self.state.record_rx.schedule_ack(if immediate { + self.state.ack_tracker.schedule_ack(if immediate { now } else { now + self.config.ack_delay @@ -558,7 +558,7 @@ impl SessionFsm { fn pending_ack(&self, remaining_capacity: usize) -> Option { let max_ack_wire_size = remaining_capacity.checked_sub(1)?; - self.state.record_rx.pending_ack(max_ack_wire_size) + self.state.ack_tracker.pending_ack(max_ack_wire_size) } fn collect_timeouts(&mut self, now: Instant) { @@ -570,7 +570,7 @@ impl SessionFsm { }) { restore_tracked_record( now, - &mut self.state.record_rx, + &mut self.state.ack_tracker, &mut self.state.pending_ping, &mut self.state.streams, record, @@ -898,13 +898,13 @@ fn local_stream_was_opened( fn restore_tracked_record( now: Instant, - record_rx: &mut RecordRxState, + ack_tracker: &mut AckTracker, pending_ping: &mut bool, streams: &mut IndexMap, record: TrackedRecord, ) { if let Some(ack) = &record.ack { - record_rx.restore_acked_ranges(ack, now); + ack_tracker.restore_acked_ranges(ack, now); } if record.ping_included { *pending_ping = true; diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index f9fba199..c9e7a8ca 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -4,8 +4,8 @@ use indexmap::IndexMap; use ql_wire::{CloseTarget, RecordSeq, RouteId, SessionClose, StreamClose, StreamId}; use super::{ - received_records::RecordRxState, remote_stream_history::RemoteStreamHistory, - stream_rx::StreamRx, stream_tx::StreamTx, tracked::TrackedRecord, + ack_tracker::AckTracker, remote_stream_history::RemoteStreamHistory, stream_rx::StreamRx, + stream_tx::StreamTx, tracked::TrackedRecord, }; pub struct SessionState { @@ -16,7 +16,7 @@ pub struct SessionState { pub next_record_seq: RecordSeq, pub next_write_id: u64, pub tracked_records: IndexMap, - pub record_rx: RecordRxState, + pub ack_tracker: AckTracker, pub pending_ping: bool, pub streams: IndexMap, pub next_stream_index: usize, From 5143d9cee57c250338cadd3cf4b8235b5d0e0d15 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 22:04:46 -0400 Subject: [PATCH 214/304] tests --- ql-fsm/src/session/tests.rs | 158 +++++++++++++++++------------------ ql-fsm/src/tests/proptest.rs | 69 +++++++++++++++ 2 files changed, 148 insertions(+), 79 deletions(-) diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index e0ab577d..88741280 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -287,85 +287,6 @@ fn pure_ack_only_records_are_fire_and_forget() { .is_none()); } -#[test] -fn sparse_out_of_order_ack_ranges_page_and_quiesce() { - let now = Instant::now(); - let sender_config = SessionConfig { - local_parity: StreamParity::Even, - record_max_size: SessionRecordBuilder::MIN_CAPACITY + 40, - ack_delay: Duration::from_millis(5), - retransmit_timeout: Duration::from_millis(25), - stream_send_buffer_size: 8 * 1024, - initial_peer_stream_receive_window: 8 * 1024, - ..SessionConfig::default() - }; - let receiver_config = SessionConfig { - local_parity: StreamParity::Odd, - record_max_size: SessionRecordBuilder::MIN_CAPACITY + 10, - ack_delay: Duration::from_millis(1), - retransmit_timeout: Duration::from_millis(25), - pending_ack_range_limit: 512, - initial_peer_stream_receive_window: 8 * 1024, - ..SessionConfig::default() - }; - let mut sender = SessionFsm::new(sender_config, now); - let mut receiver = SessionFsm::new(receiver_config, now); - - let stream_id = open_stream_id(&mut sender); - let payload = vec![b'x'; 2048]; - assert_eq!( - write_stream_bytes(&mut sender, stream_id, &payload), - payload.len() - ); - - let originals = drain_outbound(&mut sender, now, 4096); - assert!(originals.len() >= 64); - - for (seq, record) in originals - .iter() - .filter(|(seq, _)| seq.into_inner() % 2 == 1) - { - let _ = receive_events(&mut receiver, now, *seq, record); - } - - let first_ack_time = now + receiver_config.ack_delay; - let first_acks = drain_outbound(&mut receiver, first_ack_time, originals.len()); - assert!(first_acks.len() > 1); - assert!(first_acks - .iter() - .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); - - for (seq, record) in &first_acks { - let _ = receive_events(&mut sender, first_ack_time, *seq, record); - } - - let retransmit_time = now + sender_config.retransmit_timeout + Duration::from_millis(1); - sender.on_timer(retransmit_time, |_| {}); - let retransmits = drain_outbound(&mut sender, retransmit_time, originals.len()); - assert!(!retransmits.is_empty()); - - for (seq, record) in &retransmits { - let _ = receive_events(&mut receiver, retransmit_time, *seq, record); - } - - let second_ack_time = retransmit_time + receiver_config.ack_delay; - let second_acks = drain_outbound(&mut receiver, second_ack_time, retransmits.len() + 16); - assert!(!second_acks.is_empty()); - assert!(second_acks - .iter() - .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); - - for (seq, record) in &second_acks { - let _ = receive_events(&mut sender, second_ack_time, *seq, record); - } - - let final_now = second_ack_time + sender_config.retransmit_timeout + Duration::from_millis(1); - sender.on_timer(final_now, |_| {}); - receiver.on_timer(final_now, |_| {}); - assert!(next_outbound(&mut sender, final_now).is_none()); - assert!(next_outbound(&mut receiver, final_now).is_none()); -} - #[test] fn inbound_stream_data_emits_opened_and_readable() { let now = Instant::now(); @@ -736,3 +657,82 @@ fn initial_peer_stream_receive_window_limits_first_send() { ) })); } + +#[test] +fn sparse_out_of_order_ack_ranges_page_and_quiesce() { + let now = Instant::now(); + let sender_config = SessionConfig { + local_parity: StreamParity::Even, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 40, + ack_delay: Duration::from_millis(5), + retransmit_timeout: Duration::from_millis(25), + stream_send_buffer_size: 8 * 1024, + initial_peer_stream_receive_window: 8 * 1024, + ..SessionConfig::default() + }; + let receiver_config = SessionConfig { + local_parity: StreamParity::Odd, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 10, + ack_delay: Duration::from_millis(1), + retransmit_timeout: Duration::from_millis(25), + pending_ack_range_limit: 512, + initial_peer_stream_receive_window: 8 * 1024, + ..SessionConfig::default() + }; + let mut sender = SessionFsm::new(sender_config, now); + let mut receiver = SessionFsm::new(receiver_config, now); + + let stream_id = open_stream_id(&mut sender); + let payload = vec![b'x'; 2048]; + assert_eq!( + write_stream_bytes(&mut sender, stream_id, &payload), + payload.len() + ); + + let originals = drain_outbound(&mut sender, now, 4096); + assert!(originals.len() >= 64); + + for (seq, record) in originals + .iter() + .filter(|(seq, _)| seq.into_inner() % 2 == 1) + { + let _ = receive_events(&mut receiver, now, *seq, record); + } + + let first_ack_time = now + receiver_config.ack_delay; + let first_acks = drain_outbound(&mut receiver, first_ack_time, originals.len()); + assert!(first_acks.len() > 1); + assert!(first_acks + .iter() + .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); + + for (seq, record) in &first_acks { + let _ = receive_events(&mut sender, first_ack_time, *seq, record); + } + + let retransmit_time = now + sender_config.retransmit_timeout + Duration::from_millis(1); + sender.on_timer(retransmit_time, |_| {}); + let retransmits = drain_outbound(&mut sender, retransmit_time, originals.len()); + assert!(!retransmits.is_empty()); + + for (seq, record) in &retransmits { + let _ = receive_events(&mut receiver, retransmit_time, *seq, record); + } + + let second_ack_time = retransmit_time + receiver_config.ack_delay; + let second_acks = drain_outbound(&mut receiver, second_ack_time, retransmits.len() + 16); + assert!(!second_acks.is_empty()); + assert!(second_acks + .iter() + .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); + + for (seq, record) in &second_acks { + let _ = receive_events(&mut sender, second_ack_time, *seq, record); + } + + let final_now = second_ack_time + sender_config.retransmit_timeout + Duration::from_millis(1); + sender.on_timer(final_now, |_| {}); + receiver.on_timer(final_now, |_| {}); + assert!(next_outbound(&mut sender, final_now).is_none()); + assert!(next_outbound(&mut receiver, final_now).is_none()); +} diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index c06002d9..a9e756c7 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -181,6 +181,10 @@ impl Runner { session_peer_timeout: Duration::from_secs(5), ..QlFsmConfig::default() }; + Self::connected_with_config(config) + } + + fn connected_with_config(config: QlFsmConfig) -> Self { let connected_events = || SideEventState { last_peer_status: Some(PeerStatus::Connected), session_epoch: 1, @@ -579,6 +583,23 @@ impl Runner { Ok(()) } + fn assert_expected_delivered(&self, side: Side) -> TestCaseResult { + for (stream_id, expected) in &self.expected[side.idx()] { + let received = self.received[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert_eq!( + received, + expected, + "side {:?} did not receive full payload for {:?}", + side, + stream_id + ); + } + + Ok(()) + } + fn assert_no_stream_events(&self) -> TestCaseResult { prop_assert!( self.known_streams.is_empty() @@ -845,6 +866,19 @@ fn write_tracking_action_strategy() -> impl Strategy { ] } +fn packet_loss_recovery_action_strategy() -> impl Strategy { + let queue_index = 0usize..16; + prop_oneof![ + (0u8..20).prop_map(Action::AdvanceMs), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + Just(Action::Pump), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + ] +} + fn terminal_action_strategy() -> impl Strategy { let bytes = vec(any::(), 0..16); let slot = 0usize..SLOT_COUNT; @@ -894,6 +928,41 @@ proptest_crate::proptest! { runner.assert_no_taken_writes()?; } + #[test] + fn randomized_session_packet_loss_recovers( + payload in vec(any::(), 512..2048), + actions in vec(packet_loss_recovery_action_strategy(), 1..96), + ) { + let config = QlFsmConfig { + session_record_ack_delay: Duration::from_millis(1), + session_record_retransmit_timeout: Duration::from_millis(10), + session_record_max_size: 96, + session_pending_ack_range_limit: 512, + ..QlFsmConfig::default() + }; + let mut runner = Runner::connected_with_config(config); + + runner.apply(&Action::open_stream(Side::A, 0)); + runner.observe_and_assert()?; + + runner.apply(&Action::write(Side::A, 0, payload)); + runner.observe_and_assert()?; + + runner.apply(&Action::finish(Side::A, 0)); + runner.observe_and_assert()?; + + for action in &actions { + runner.apply(action); + runner.observe_and_assert()?; + } + + runner.cleanup()?; + runner.observe_and_assert()?; + runner.assert_expected_delivered(Side::B)?; + runner.assert_terminal_semantics()?; + runner.assert_quiesced()?; + } + #[test] fn randomized_terminal_actions_preserve_terminal_semantics(actions in vec(terminal_action_strategy(), 1..80)) { let mut runner = Runner::connected(); From c4cf333476742336e9e2205f7367673626b67bae Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 14 Apr 2026 23:57:26 -0400 Subject: [PATCH 215/304] poll fairness --- ql-runtime/src/driver/mod.rs | 58 +++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index a46059d3..ce43c56b 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -14,7 +14,7 @@ use std::{ }; use async_channel::Recv; -use futures_lite::future::poll_fn; +use futures_lite::future::{poll_fn, yield_now}; use ql_fsm::{Event, FsmTime, QlFsm, WriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; @@ -53,6 +53,7 @@ impl Runtime

{ let mut timer = platform.timer(); let recv_future = rx.recv(); let mut recv_future = pin!(recv_future); + let mut poll_cursor = 0usize; loop { state.drain_fsm_events(&mut fsm, &platform); @@ -61,8 +62,17 @@ impl Runtime

{ } timer.set_deadline(fsm.next_deadline()); - let step = - poll_fn(|cx| next_step(cx, recv_future.as_mut(), &mut timer, &mut in_flight)).await; + let step = poll_fn(|cx| { + next_step( + cx, + recv_future.as_mut(), + &mut timer, + &mut in_flight, + poll_cursor, + ) + }) + .await; + poll_cursor = (poll_cursor + 1) % STEP_COUNT; match step { DriverStep::Command(command) => { @@ -71,6 +81,7 @@ impl Runtime

{ DriverStep::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); DriverState::drive_write_completed(&mut fsm, write.session_write_id, success); + yield_now().await; } DriverStep::TimerExpired => { fsm.on_timer(now()); @@ -97,30 +108,49 @@ enum DriverStep { Closed, } +const STEP_COUNT: usize = 3; + fn next_step( cx: &mut Context<'_>, mut recv_future: Pin<&mut Recv<'_, RuntimeCommand>>, timer: &mut T, in_flight: &mut [InFlightWrite], + start: usize, ) -> Poll where T: QlTimer, F: Future + Unpin, { - for (index, write) in in_flight.iter_mut().enumerate() { - if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { - return Poll::Ready(DriverStep::WriteCompleted { index, success }); + for offset in 0..STEP_COUNT { + let step = (start + offset) % STEP_COUNT; + let poll = match step { + 0 => recv_future + .as_mut() + .poll(cx) + .map(|res| res.map_or(DriverStep::Closed, DriverStep::Command)), + 1 => { + for (index, write) in in_flight.iter_mut().enumerate() { + if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { + return Poll::Ready(DriverStep::WriteCompleted { index, success }); + } + } + Poll::Pending + } + 2 => { + if timer.poll_wait(cx) == Poll::Ready(()) { + Poll::Ready(DriverStep::TimerExpired) + } else { + Poll::Pending + } + } + _ => unreachable!(), + }; + if poll.is_ready() { + return poll; } } - if timer.poll_wait(cx) == Poll::Ready(()) { - return Poll::Ready(DriverStep::TimerExpired); - } - - recv_future - .as_mut() - .poll(cx) - .map(|res| res.map_or(DriverStep::Closed, DriverStep::Command)) + Poll::Pending } impl DriverState { From 78181c7faa75e082f4621d77b22d07882e37c611 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 00:02:14 -0400 Subject: [PATCH 216/304] stress tests --- ql-runtime/src/tests/mod.rs | 164 +++++++++++++++++++++--- ql-runtime/src/tests/stream.rs | 228 ++++++++++++++++++++++++++++++--- 2 files changed, 356 insertions(+), 36 deletions(-) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index d23fd445..237e6302 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -3,7 +3,7 @@ use std::{ pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, + Arc, Mutex, }, task::{Context, Poll}, time::Duration, @@ -152,11 +152,63 @@ struct TestPair { b: TestSide, } +#[derive(Debug, Clone, Copy, Default)] +struct LinkBehavior { + base_delay: Duration, + drop_encrypted_every: Option, + duplicate_encrypted_every: Option, + delay_encrypted_every: Option<(usize, Duration)>, +} + +#[derive(Clone, Default)] +struct LinkController { + behavior: Arc>, +} + +impl LinkController { + fn new(behavior: LinkBehavior) -> Self { + Self { + behavior: Arc::new(Mutex::new(behavior)), + } + } + + fn load(&self) -> LinkBehavior { + *self.behavior.lock().unwrap() + } + + fn store(&self, behavior: LinkBehavior) { + *self.behavior.lock().unwrap() = behavior; + } +} + +#[derive(Clone)] +struct ControlledLinks { + a_to_b: LinkController, + b_to_a: LinkController, +} + impl TestPair { fn new(config: RuntimeConfig) -> Self { + Self::new_with_links(config, LinkBehavior::default(), LinkBehavior::default()) + } + + fn new_with_links(config: RuntimeConfig, a_to_b: LinkBehavior, b_to_a: LinkBehavior) -> Self { + let (pair, _links) = Self::new_with_controlled_links(config, a_to_b, b_to_a); + pair + } + + fn new_with_controlled_links( + config: RuntimeConfig, + a_to_b: LinkBehavior, + b_to_a: LinkBehavior, + ) -> (Self, ControlledLinks) { let (platform_a, outbound_a, status_a, inbound_a) = TestPlatform::new_with_inbound(); let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let links = ControlledLinks { + a_to_b: LinkController::new(a_to_b), + b_to_a: LinkController::new(b_to_a), + }; let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config.clone()); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); @@ -164,24 +216,27 @@ impl TestPair { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); + spawn_simulated_forwarder(outbound_a, handle_b.clone(), links.a_to_b.clone()); + spawn_simulated_forwarder(outbound_b, handle_a.clone(), links.b_to_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - Self { - a: TestSide { - handle: handle_a, - status: status_a, - peer: identity_a.xid, - inbound: inbound_a, - }, - b: TestSide { - handle: handle_b, - status: status_b, - peer: identity_b.xid, - inbound: inbound_b, + ( + Self { + a: TestSide { + handle: handle_a, + status: status_a, + peer: identity_a.xid, + inbound: inbound_a, + }, + b: TestSide { + handle: handle_b, + status: status_b, + peer: identity_b.xid, + inbound: inbound_b, + }, }, - } + links, + ) } fn side(&self, side: Side) -> &TestSide { @@ -214,10 +269,6 @@ impl TestPair { .await; } - fn handle(&self, side: Side) -> RuntimeHandle { - self.side(side).handle.clone() - } - fn take_inbound(&mut self, side: Side) -> Receiver { let replacement = async_channel::unbounded().1; std::mem::replace(&mut self.side_mut(side).inbound, replacement) @@ -384,9 +435,70 @@ fn register_peers( } fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { + spawn_simulated_forwarder( + outbound, + handle, + LinkController::new(LinkBehavior::default()), + ); +} + +fn spawn_simulated_forwarder( + outbound: Receiver>, + handle: RuntimeHandle, + controller: LinkController, +) { tokio::task::spawn_local(async move { + let mut encrypted_count = 0usize; while let Ok(bytes) = outbound.recv().await { - handle.receive(bytes); + let behavior = controller.load(); + let encrypted = is_encrypted_payload(&bytes); + let ordinal = if encrypted { + encrypted_count = encrypted_count.saturating_add(1); + Some(encrypted_count) + } else { + None + }; + + if ordinal.is_some_and(|count| { + behavior + .drop_encrypted_every + .is_some_and(|nth| nth != 0 && count % nth == 0) + }) { + continue; + } + + let mut delay = behavior.base_delay; + if let Some(count) = ordinal { + if let Some((nth, extra_delay)) = behavior.delay_encrypted_every { + if nth != 0 && count % nth == 0 { + delay += extra_delay; + } + } + } + + let primary = bytes.clone(); + let primary_handle = handle.clone(); + tokio::task::spawn_local(async move { + if !delay.is_zero() { + tokio::time::sleep(delay).await; + } + primary_handle.receive(primary); + }); + + if ordinal.is_some_and(|count| { + behavior + .duplicate_encrypted_every + .is_some_and(|nth| nth != 0 && count % nth == 0) + }) { + let duplicate_handle = handle.clone(); + tokio::task::spawn_local(async move { + let duplicate_delay = delay + Duration::from_millis(1); + if !duplicate_delay.is_zero() { + tokio::time::sleep(duplicate_delay).await; + } + duplicate_handle.receive(bytes); + }); + } } }); } @@ -434,6 +546,16 @@ where local.run_until(future).await; } +#[allow(clippy::future_not_send)] +async fn run_local_test_timeout(duration: Duration, future: F) +where + F: Future, +{ + tokio::time::timeout(duration, run_local_test(future)) + .await + .unwrap_or_else(|_| panic!("local runtime test exceeded {:?}", duration)); +} + async fn await_status(receiver: &Receiver, peer: XID, stage: PeerStatus) { tokio::time::timeout(Duration::from_secs(2), async { loop { diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index a6f960d3..cc0e0be4 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -60,7 +60,7 @@ async fn open_stream_duplex_happy_path() { } #[tokio::test(flavor = "current_thread")] -async fn reader_exposes_bounded_chunk_reads() { +async fn reader_respects_max_len() { run_local_test(async { let mut pair = TestPair::new(default_runtime_config()); pair.connect_and_wait(Side::A).await; @@ -70,18 +70,12 @@ async fn reader_exposes_bounded_chunk_reads() { let inbound = inbound_b.recv().await.unwrap(); let mut reader = inbound.reader; - assert_eq!( - next_chunk_max(&mut reader, 2).await.unwrap(), - Some(vec![1, 2]) - ); + assert_eq!(next_chunk_max(&mut reader, 2).await.unwrap(), Some(vec![1, 2])); assert_eq!( next_chunk_max(&mut reader, 2).await.unwrap(), Some(vec![3, 4]) ); - assert_eq!( - next_chunk_max(&mut reader, 2).await.unwrap(), - Some(vec![5, 6]) - ); + assert_eq!(next_chunk_max(&mut reader, 2).await.unwrap(), Some(vec![5, 6])); assert_eq!(next_chunk(&mut reader).await.unwrap(), None); inbound.writer.finish(); @@ -95,12 +89,7 @@ async fn reader_exposes_bounded_chunk_reads() { .unwrap(); stream .writer - .write(Bytes::from_static(&[1, 2, 3, 4])) - .await - .unwrap(); - stream - .writer - .write(Bytes::from_static(&[5, 6])) + .write(Bytes::from_static(&[1, 2, 3, 4, 5, 6])) .await .unwrap(); stream.writer.finish(); @@ -412,3 +401,212 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { }) .await; } + +#[tokio::test(flavor = "current_thread")] +async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { + run_local_test_timeout(Duration::from_secs(5), async { + let payload_len = 2 * 1024 * 1024; + let chunk_len = 16 * 1024; + let payload: Vec = (0..payload_len) + .map(|i| u8::try_from(i % 251).unwrap()) + .collect(); + let expected = payload.clone(); + let config = RuntimeConfig { + fsm: QlFsmConfig { + session_record_max_size: 16 * 1024, + session_record_ack_delay: Duration::from_millis(2), + session_record_retransmit_timeout: Duration::from_millis(25), + session_stream_send_buffer_size: 4 * 1024 * 1024, + session_stream_receive_buffer_size: 4 * 1024 * 1024, + session_accepted_record_window: 16 * 1024, + session_pending_ack_range_limit: 4 * 1024, + ..default_runtime_config().fsm + }, + stream_send_buffer_bytes: 4 * 1024 * 1024, + ..default_runtime_config() + }; + let (mut pair, links) = TestPair::new_with_controlled_links( + config, + LinkBehavior { + base_delay: Duration::from_millis(1), + drop_encrypted_every: Some(41), + delay_encrypted_every: Some((13, Duration::from_millis(12))), + ..LinkBehavior::default() + }, + LinkBehavior { + base_delay: Duration::from_millis(1), + ..LinkBehavior::default() + }, + ); + pair.connect_and_wait(Side::A).await; + links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(3), + drop_encrypted_every: Some(7), + duplicate_encrypted_every: Some(19), + delay_encrypted_every: Some((3, Duration::from_millis(25))), + }); + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + eprintln!("responder accepted inbound stream"); + let mut reader = stream.reader; + let mut received = Vec::new(); + while let Some(chunk) = next_chunk(&mut reader).await.unwrap() { + if received.len() >= 36 * chunk_len { + eprintln!("responder received chunk of {} bytes", chunk.len()); + } + received.extend_from_slice(&chunk); + if received.len() % (256 * 1024) == 0 { + eprintln!("responder received {} bytes", received.len()); + } + } + stream.writer.finish(); + received + }); + + let recovery_links = links.clone(); + let recovery = tokio::task::spawn_local(async move { + tokio::time::sleep(Duration::from_millis(300)).await; + eprintln!("restoring reverse path"); + recovery_links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(1), + delay_encrypted_every: Some((17, Duration::from_millis(8))), + ..LinkBehavior::default() + }); + }); + + let writer = tokio::task::spawn_local(async move { + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + for (index, chunk) in payload.chunks(chunk_len).enumerate() { + if index + 1 >= 40 { + eprintln!("writer attempting chunk {}", index + 1); + } + stream + .writer + .write(Bytes::copy_from_slice(chunk)) + .await + .unwrap(); + if index + 1 >= 40 { + eprintln!("writer queued chunk {}", index + 1); + } + if index % 16 == 15 { + eprintln!("writer queued {} chunks", index + 1); + } + } + eprintln!("writer finished queueing"); + stream.writer.finish(); + eprintln!("writer waiting for eof"); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + eprintln!("writer observed eof"); + }); + + tokio::time::timeout(Duration::from_secs(30), writer) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(2), recovery) + .await + .unwrap() + .unwrap(); + let received = tokio::time::timeout(Duration::from_secs(30), responder) + .await + .unwrap() + .unwrap(); + assert_eq!(received, expected); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn reproducer_writer_stalls_after_reverse_path_impairment() { + run_local_test_timeout(Duration::from_secs(5), async { + let payload_len = 2 * 1024 * 1024; + let chunk_len = 16 * 1024; + let payload: Vec = (0..payload_len) + .map(|i| u8::try_from(i % 251).unwrap()) + .collect(); + let config = RuntimeConfig { + fsm: QlFsmConfig { + session_record_max_size: 16 * 1024, + session_record_ack_delay: Duration::from_millis(2), + session_record_retransmit_timeout: Duration::from_millis(25), + session_stream_send_buffer_size: 4 * 1024 * 1024, + session_stream_receive_buffer_size: 4 * 1024 * 1024, + session_accepted_record_window: 16 * 1024, + session_pending_ack_range_limit: 4 * 1024, + ..default_runtime_config().fsm + }, + stream_send_buffer_bytes: 4 * 1024 * 1024, + ..default_runtime_config() + }; + let (mut pair, links) = TestPair::new_with_controlled_links( + config, + LinkBehavior { + base_delay: Duration::from_millis(1), + drop_encrypted_every: Some(41), + delay_encrypted_every: Some((13, Duration::from_millis(12))), + ..LinkBehavior::default() + }, + LinkBehavior { + base_delay: Duration::from_millis(1), + ..LinkBehavior::default() + }, + ); + pair.connect_and_wait(Side::A).await; + links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(3), + drop_encrypted_every: Some(7), + duplicate_encrypted_every: Some(19), + delay_encrypted_every: Some((3, Duration::from_millis(25))), + }); + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + let mut received = Vec::new(); + while let Some(chunk) = next_chunk(&mut reader).await.unwrap() { + received.extend_from_slice(&chunk); + } + }); + + let recovery_links = links.clone(); + let recovery = tokio::task::spawn_local(async move { + tokio::time::sleep(Duration::from_millis(300)).await; + recovery_links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(1), + delay_encrypted_every: Some((17, Duration::from_millis(8))), + ..LinkBehavior::default() + }); + }); + + let writer = tokio::task::spawn_local(async move { + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + for chunk in payload.chunks(chunk_len) { + stream + .writer + .write(Bytes::copy_from_slice(chunk)) + .await + .unwrap(); + } + stream.writer.finish(); + let _ = next_chunk(&mut stream.reader).await; + }); + + let _ = tokio::time::timeout(Duration::from_secs(15), writer).await; + recovery.abort(); + responder.abort(); + }) + .await; +} From 85edf0dce02c8adb805b34bcf006e561d6e8a6b6 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 00:08:22 -0400 Subject: [PATCH 217/304] rpc cleanup tests --- ql-runtime/src/tests/mod.rs | 4 ++++ ql-runtime/src/tests/rpc.rs | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 237e6302..158f6387 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -246,6 +246,10 @@ impl TestPair { } } + fn handle(&self, side: Side) -> &RuntimeHandle { + &self.side(side).handle + } + fn side_mut(&mut self, side: Side) -> &mut TestSide { match side { Side::A => &mut self.a, diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 6c9e7321..19e97b91 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -148,7 +148,7 @@ async fn rpc_router_handles_subscription() { seen.borrow_mut().push(request); let _ = response.send(b"one".to_vec()).await; let _ = response.send(b"two".to_vec()).await; - let _ = response.finish().await; + let _ = response.finish(); }); } } From 541ebff18d7ebc7f86c644a2d9d0f66c4495bd73 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 09:17:53 -0400 Subject: [PATCH 218/304] ql-runtime: debug logs for close --- ql-runtime/Cargo.toml | 1 + ql-runtime/src/driver/mod.rs | 32 ++++++++++++++++++++++++++++++ ql-runtime/src/handle/reader.rs | 35 +++++++++++++++++++++++++++++++++ ql-runtime/src/handle/writer.rs | 28 ++++++++++++++++++++++++++ 4 files changed, 96 insertions(+) diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index 208e5325..58009478 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -14,6 +14,7 @@ async-channel = { version = "2.5" } bytes = "1" event-listener = "5.4" futures-lite = { version = "2.5" } +log = "0.4" oneshot = { version = "0.1.11" } ql-fsm = { path = "../ql-fsm" } ql-rpc = { path = "../ql-rpc", optional = true } diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index ce43c56b..9c1274a1 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -15,6 +15,7 @@ use std::{ use async_channel::Recv; use futures_lite::future::{poll_fn, yield_now}; +use log::debug; use ql_fsm::{Event, FsmTime, QlFsm, WriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; @@ -240,6 +241,12 @@ impl DriverState { target, code, } => { + debug!( + "runtime close stream command: stream_id={:?} target={:?} code={:?}", + stream_id, + target, + code + ); if let Entry::Occupied(mut entry) = self.streams.entry(stream_id) { let stream = entry.get_mut(); if target == CloseTarget::Both || target == stream.inbound_target() { @@ -403,6 +410,7 @@ impl DriverState { } fn handle_inbound_finished(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + debug!("runtime inbound finished event: stream_id={:?}", stream_id); let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; @@ -425,11 +433,21 @@ impl DriverState { return; } + debug!( + "runtime delivering clean inbound finish: stream_id={:?}", + stream_id + ); stream.inbound_finish(); Self::try_reap_stream(entry); } fn handle_closed_stream(&mut self, frame: &ql_wire::StreamClose) { + debug!( + "runtime inbound close frame: stream_id={:?} target={:?} code={:?}", + frame.stream_id, + frame.target, + frame.code + ); let Entry::Occupied(mut entry) = self.streams.entry(frame.stream_id) else { return; }; @@ -445,6 +463,12 @@ impl DriverState { } fn handle_writable_closed(&mut self, frame: &ql_wire::StreamClose) { + debug!( + "runtime writable close frame: stream_id={:?} target={:?} code={:?}", + frame.stream_id, + frame.target, + frame.code + ); let Entry::Occupied(mut entry) = self.streams.entry(frame.stream_id) else { return; }; @@ -483,6 +507,10 @@ impl DriverState { }; if reader.is_finished() { + debug!( + "runtime observed outbound reader finished before write: stream_id={:?}", + stream_id + ); if let Ok(mut stream_ops) = fsm.stream(stream_id) { if let Some(writer) = stream_ops.writer() { writer.finish(); @@ -510,6 +538,10 @@ impl DriverState { } if reader.is_finished() { + debug!( + "runtime observed outbound reader finished after write: stream_id={:?}", + stream_id + ); writer.finish(); stream.outbound_close(); if stream.is_closed() { diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 0b4b709f..636097d4 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -6,6 +6,7 @@ use std::{ use bytes::Bytes; use event_listener::EventListener; +use log::{debug, trace}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, QlStreamError, RuntimeHandle}; @@ -73,12 +74,23 @@ impl ByteReader { if let Some(reader) = self.reader.as_ref() { match reader.poll_recv(max_len, &mut self.listener, cx) { Poll::Ready(Ok(bytes)) => { + trace!( + "byte reader received chunk: stream_id={:?} target={:?} len={}", + self.stream_id, + self.target, + bytes.len() + ); self.handle.send(RuntimeCommand::PollInbound { stream_id: self.stream_id, }); return Poll::Ready(Ok(Some(bytes))); } Poll::Ready(Err(_)) => { + debug!( + "byte reader channel closed: stream_id={:?} target={:?}", + self.stream_id, + self.target + ); self.reader = None; self.listener = None; } @@ -102,11 +114,22 @@ impl ByteReader { match &self.terminal { TerminalState::Armed(_) => Poll::Pending, TerminalState::Terminal(Ok(())) => { + debug!( + "byte reader delivered clean eof: stream_id={:?} target={:?}", + self.stream_id, + self.target + ); self.terminal = TerminalState::Delivered; Poll::Ready(Ok(None)) } TerminalState::Terminal(Err(error)) => { let error = error.clone(); + debug!( + "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", + self.stream_id, + self.target, + error + ); self.terminal = TerminalState::Delivered; Poll::Ready(Err(error)) } @@ -134,6 +157,12 @@ impl ByteReader { if matches!(self.terminal, TerminalState::Delivered) { return; } + debug!( + "byte reader explicit close: stream_id={:?} target={:?} code={:?}", + self.stream_id, + self.target, + code + ); self.reader.take(); self.listener = None; self.terminal = TerminalState::Delivered; @@ -150,6 +179,12 @@ impl Drop for ByteReader { if matches!(self.terminal, TerminalState::Delivered) { return; } + debug!( + "byte reader drop close: stream_id={:?} target={:?} code={:?}", + self.stream_id, + self.target, + StreamCloseCode::CANCELLED + ); self.handle.send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 0331a2ed..1060f7c9 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -6,6 +6,7 @@ use std::{ use bytes::Bytes; use event_listener::EventListener; +use log::{debug, trace}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use crate::{ @@ -82,11 +83,21 @@ impl ByteWriter { match writer.poll_send(bytes, &mut self.listener, cx) { Poll::Ready(Ok(())) => { + trace!( + "byte writer accepted chunk: stream_id={:?} target={:?}", + self.stream_id, + self.target + ); self.listener = None; self.poll_runtime(); Poll::Ready(Ok(())) } Poll::Ready(Err(SendClosed(_bytes))) => { + debug!( + "byte writer send closed: stream_id={:?} target={:?}", + self.stream_id, + self.target + ); self.writer.take(); self.listener = None; self.poll_terminal_error(cx).map(Err) @@ -104,6 +115,11 @@ impl ByteWriter { let Some(writer) = self.writer.take() else { return; }; + debug!( + "byte writer finish: stream_id={:?} target={:?}", + self.stream_id, + self.target + ); writer.close(); self.poll_runtime(); } @@ -115,6 +131,12 @@ impl ByteWriter { impl Drop for ByteWriter { fn drop(&mut self) { + debug!( + "byte writer drop close requested: stream_id={:?} target={:?} code={:?}", + self.stream_id, + self.target, + StreamCloseCode::CANCELLED + ); self.close_inner(StreamCloseCode::CANCELLED); } } @@ -140,6 +162,12 @@ impl ByteWriter { if self.writer.take().is_none() { return; } + debug!( + "byte writer close: stream_id={:?} target={:?} code={:?}", + self.stream_id, + self.target, + code + ); self.listener = None; self.handle.send(RuntimeCommand::CloseStream { stream_id: self.stream_id, From 9d63c167943e1d78c64d957e197275e0bccfee36 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 10:11:49 -0400 Subject: [PATCH 219/304] ql-runtime more logs --- ql-runtime/src/driver/mod.rs | 5 +++++ ql-runtime/src/handle/writer.rs | 6 ------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 9c1274a1..0922f188 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -388,6 +388,11 @@ impl DriverState { break; } InboundWriteResult::Closed => { + debug!( + "runtime inbound consumer closed; sending CANCELLED: stream_id={:?} target={:?}", + stream_id, + target + ); peer_closed = true; break; } diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 1060f7c9..b00811b9 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -131,12 +131,6 @@ impl ByteWriter { impl Drop for ByteWriter { fn drop(&mut self) { - debug!( - "byte writer drop close requested: stream_id={:?} target={:?} code={:?}", - self.stream_id, - self.target, - StreamCloseCode::CANCELLED - ); self.close_inner(StreamCloseCode::CANCELLED); } } From 7e3884205749f208b6ca7d9d088c1e1845aaf5ef Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 10:16:08 -0400 Subject: [PATCH 220/304] ql-rpc: add internal error variant --- ql-rpc/src/lib.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index e5277e08..7d968277 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -36,10 +36,11 @@ pub struct StreamCloseCode(pub u16); impl StreamCloseCode { pub const CANCELLED: Self = Self(0); - pub const REFUSED: Self = Self(1); - pub const TIMEOUT: Self = Self(2); - pub const LIMIT: Self = Self(3); - pub const UNKNOWN_ROUTE: Self = Self(4); + pub const INTERNAL: Self = Self(1); + pub const REFUSED: Self = Self(2); + pub const TIMEOUT: Self = Self(3); + pub const LIMIT: Self = Self(4); + pub const UNKNOWN_ROUTE: Self = Self(5); pub const fn into_inner(self) -> u16 { self.0 From 54122375e7820e148c9e6c62fab61d3a14f85822 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 10:55:43 -0400 Subject: [PATCH 221/304] add docs --- ql-rpc/src/lib.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index 7d968277..b10d96d6 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -35,11 +35,17 @@ impl From for RouteId { pub struct StreamCloseCode(pub u16); impl StreamCloseCode { + /// operation was cancelled pub const CANCELLED: Self = Self(0); + /// local internal error pub const INTERNAL: Self = Self(1); + /// request was refused pub const REFUSED: Self = Self(2); + /// operation timed out pub const TIMEOUT: Self = Self(3); + /// configured limit was exceeded pub const LIMIT: Self = Self(4); + /// route identifier was unknown pub const UNKNOWN_ROUTE: Self = Self(5); pub const fn into_inner(self) -> u16 { From 910c6e2f253347ac2b3b8cb17a8bb30420595a67 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 11:24:10 -0400 Subject: [PATCH 222/304] ql-rpc: custom stream errors --- ql-rpc/src/router/mod.rs | 2 +- ql-rpc/src/router/request.rs | 39 +++++++++++++++++-------------- ql-rpc/src/router/stream.rs | 27 ++++++++++++++++----- ql-rpc/src/router/subscription.rs | 24 ++++++++++--------- ql-runtime/src/rpc/adapter.rs | 36 +++++++++++++++++++--------- 5 files changed, 81 insertions(+), 47 deletions(-) diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index bccf5ccb..88d23d51 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -14,7 +14,7 @@ pub use self::{ config::RouterConfig, mode::*, request::{RequestHandler, Response}, - stream::{RpcRead, RpcStream, RpcWrite}, + stream::{RpcRead, RpcStream, RpcWrite, StreamError}, subscription::{SubscriptionHandler, SubscriptionResponder}, }; diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/router/request.rs index 66882edf..55fc55c6 100644 --- a/ql-rpc/src/router/request.rs +++ b/ql-rpc/src/router/request.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use bytes::Bytes; use super::{ - stream::{read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite}, + stream::{read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError}, LocalMode, RouteMode, RouterConfig, SendMode, }; use crate::{ @@ -16,6 +16,8 @@ where St: RpcStream, { fn handle(self, message: M::Request, responder: Response); + + fn handle_transport_error(&self, _error: &St::Error) {} } pub struct Response @@ -38,14 +40,11 @@ where } } - pub async fn respond(mut self, response: T) -> Result<(), StreamCloseCode> { + pub async fn respond(mut self, response: T) -> Result<(), W::Error> { let mut writer = self.writer.take().expect("response writer exists"); let mut encoded = Vec::new(); codec::encode_value_part(&response, &mut encoded); - if let Err(code) = write_bytes(&mut writer, Bytes::from(encoded)).await { - writer.close(code); - return Err(code); - } + write_bytes(&mut writer, Bytes::from(encoded)).await?; writer.finish(); Ok(()) } @@ -121,9 +120,13 @@ async fn handle_request_inner( { let request = match read_value_and_eof::(&mut reader, config).await { Ok(request) => request, - Err(code) => { - reader.close(code); - writer.close(code); + Err(error) => { + let code = error.close_code(); + state.handle_transport_error(&error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } return; } }; @@ -134,7 +137,7 @@ async fn handle_request_inner( pub(super) async fn read_value_and_eof( reader: &mut R, config: RouterConfig, -) -> Result +) -> Result where T: RpcCodec, R: RpcRead, @@ -146,13 +149,13 @@ where match value_reader.advance() { Ok(ReadValueStep::Value(value)) => break value, Ok(ReadValueStep::NeedMore(next)) => value_reader = next, - Err(crate::CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED), - Err(crate::CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED), + Err(crate::CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(crate::CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), } let remaining = config.max_request_bytes.saturating_sub(total_read); if remaining == 0 { - return Err(StreamCloseCode::LIMIT); + return Err(StreamCloseCode::LIMIT.into()); } match read_bytes(reader, remaining).await { @@ -160,8 +163,8 @@ where total_read += chunk.len(); value_reader = value_reader.push(chunk); } - Ok(None) => return Err(StreamCloseCode::REFUSED), - Err(code) => return Err(code), + Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Err(error) => return Err(error), } }; @@ -169,8 +172,8 @@ where let probe = remaining.max(1); match read_bytes(reader, probe).await { Ok(None) => Ok(value), - Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT), - Ok(Some(_)) => Err(StreamCloseCode::REFUSED), - Err(code) => Err(code), + Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT.into()), + Ok(Some(_)) => Err(StreamCloseCode::REFUSED.into()), + Err(error) => Err(error), } } diff --git a/ql-rpc/src/router/stream.rs b/ql-rpc/src/router/stream.rs index dcaabbf8..7b531374 100644 --- a/ql-rpc/src/router/stream.rs +++ b/ql-rpc/src/router/stream.rs @@ -8,40 +8,55 @@ use bytes::Bytes; use crate::{RouteId, StreamCloseCode}; pub trait RpcStream { - type Reader: RpcRead; - type Writer: RpcWrite; + type Error: StreamError; + type Reader: RpcRead; + type Writer: RpcWrite; fn route_id(&self) -> Option; fn split(self) -> (Self::Reader, Self::Writer); } pub trait RpcRead { + type Error: StreamError; + fn poll_read( &mut self, max_len: usize, cx: &mut Context<'_>, - ) -> Poll, StreamCloseCode>>; + ) -> Poll, Self::Error>>; fn close(self, code: StreamCloseCode); } pub trait RpcWrite { + type Error: StreamError; + fn poll_write( &mut self, bytes: &mut Bytes, cx: &mut Context<'_>, - ) -> Poll>; + ) -> Poll>; fn finish(self); fn close(self, code: StreamCloseCode); } -pub async fn read_bytes(reader: &mut R, max_len: usize) -> Result, StreamCloseCode> +pub trait StreamError: From { + fn close_code(&self) -> Option; +} + +impl StreamError for StreamCloseCode { + fn close_code(&self) -> Option { + Some(*self) + } +} + +pub async fn read_bytes(reader: &mut R, max_len: usize) -> Result, R::Error> where R: RpcRead, { poll_fn(|cx| reader.poll_read(max_len, cx)).await } -pub async fn write_bytes(writer: &mut W, bytes: Bytes) -> Result<(), StreamCloseCode> +pub async fn write_bytes(writer: &mut W, bytes: Bytes) -> Result<(), W::Error> where W: RpcWrite, { diff --git a/ql-rpc/src/router/subscription.rs b/ql-rpc/src/router/subscription.rs index 5badbbea..bb137b6b 100644 --- a/ql-rpc/src/router/subscription.rs +++ b/ql-rpc/src/router/subscription.rs @@ -4,7 +4,7 @@ use bytes::Bytes; use super::{ request::read_value_and_eof, - stream::{write_bytes, RpcRead, RpcStream, RpcWrite}, + stream::{write_bytes, RpcRead, RpcStream, RpcWrite, StreamError}, LocalMode, RouteMode, RouterConfig, SendMode, }; use crate::{codec, subscription::Subscription as SubscriptionRpc, RpcCodec, StreamCloseCode}; @@ -15,6 +15,8 @@ where St: RpcStream, { fn handle(self, message: M::Request, responder: SubscriptionResponder); + + fn handle_transport_error(&self, _error: &St::Error) {} } pub struct SubscriptionResponder @@ -37,19 +39,15 @@ where } } - pub async fn send(&mut self, event: T) -> Result<(), StreamCloseCode> { + pub async fn send(&mut self, event: T) -> Result<(), W::Error> { let writer = self.writer.as_mut().expect("subscription writer exists"); let mut encoded = Vec::new(); codec::encode_value_part(&event, &mut encoded); - if let Err(code) = write_bytes(writer, Bytes::from(encoded)).await { - let writer = self.writer.take().expect("subscription writer exists"); - writer.close(code); - return Err(code); - } + write_bytes(writer, Bytes::from(encoded)).await?; Ok(()) } - pub fn finish(mut self) -> Result<(), StreamCloseCode> { + pub fn finish(mut self) -> Result<(), W::Error> { let writer = self.writer.take().expect("subscription writer exists"); writer.finish(); Ok(()) @@ -126,9 +124,13 @@ async fn handle_subscription_inner( { let request = match read_value_and_eof::(&mut reader, config).await { Ok(request) => request, - Err(code) => { - reader.close(code); - writer.close(code); + Err(error) => { + let code = error.close_code(); + state.handle_transport_error(&error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } return; } }; diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs index 780708a0..729f1077 100644 --- a/ql-runtime/src/rpc/adapter.rs +++ b/ql-runtime/src/rpc/adapter.rs @@ -5,7 +5,7 @@ pub use ql_rpc::{ LocalMode, RequestHandler, Response, RouteId, RouterConfig, SendMode, StreamCloseCode, SubscriptionHandler, SubscriptionResponder, }; -use ql_rpc::{RpcRead, RpcStream, RpcWrite}; +use ql_rpc::{RpcRead, RpcStream, RpcWrite, StreamError}; use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; use crate::{ByteReader, ByteWriter, QlStream, QlStreamError}; @@ -16,6 +16,7 @@ pub type SendRouter = ql_rpc::Router; pub type SendRouterBuilder = ql_rpc::RouterBuilder; impl RpcStream for QlStream { + type Error = QlStreamError; type Reader = ByteReader; type Writer = ByteWriter; @@ -30,12 +31,14 @@ impl RpcStream for QlStream { } impl RpcRead for ByteReader { + type Error = QlStreamError; + fn poll_read( &mut self, max_len: usize, cx: &mut Context<'_>, - ) -> Poll, StreamCloseCode>> { - ByteReader::poll_read(self, max_len, cx).map(|result| result.map_err(from_stream_error)) + ) -> Poll, QlStreamError>> { + ByteReader::poll_read(self, max_len, cx) } fn close(self, code: StreamCloseCode) { @@ -44,12 +47,14 @@ impl RpcRead for ByteReader { } impl RpcWrite for ByteWriter { + type Error = QlStreamError; + fn poll_write( &mut self, bytes: &mut Bytes, cx: &mut Context<'_>, - ) -> Poll> { - ByteWriter::poll_write(self, bytes, cx).map(|result| result.map_err(from_stream_error)) + ) -> Poll> { + ByteWriter::poll_write(self, bytes, cx) } fn finish(self) { @@ -69,10 +74,19 @@ pub(super) fn to_wire_close_code(code: StreamCloseCode) -> WireStreamCloseCode { WireStreamCloseCode(code.into_inner()) } -fn from_stream_error(error: QlStreamError) -> StreamCloseCode { - let code = match error { - QlStreamError::StreamClosed { code } => code, - QlStreamError::NoSession => WireStreamCloseCode::CANCELLED, - }; - StreamCloseCode(code.0) +impl From for QlStreamError { + fn from(code: StreamCloseCode) -> Self { + Self::StreamClosed { + code: WireStreamCloseCode(code.into_inner()), + } + } +} + +impl StreamError for QlStreamError { + fn close_code(&self) -> Option { + match self { + QlStreamError::StreamClosed { code } => Some(StreamCloseCode(code.0)), + QlStreamError::NoSession => None, + } + } } From 0e89ae7e92d7b93291a53da4f688b1bf51cbb7d8 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 12:04:36 -0400 Subject: [PATCH 223/304] fix cargo lock --- Cargo.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.lock b/Cargo.lock index 21cee875..fd232744 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2238,6 +2238,7 @@ dependencies = [ "bytes", "event-listener", "futures-lite", + "log", "loom", "oneshot", "ql-fsm", From 75bcf00a2f448aa3ef03819820b09afcd8e1c706 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 12:46:58 -0400 Subject: [PATCH 224/304] ql-fsm: outboundfinished --- ql-fsm/src/fsm.rs | 3 +++ ql-fsm/src/lib.rs | 2 ++ ql-fsm/src/session/mod.rs | 5 +++++ ql-fsm/src/session/stream_tx.rs | 5 +++++ ql-fsm/src/session/tests.rs | 37 +++++++++++++++++++++++++++++++++ ql-fsm/src/tests/proptest.rs | 12 +++++++++++ ql-fsm/src/tests/session.rs | 7 +++++++ 7 files changed, 71 insertions(+) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 9c0bb978..5c2874e6 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -206,6 +206,9 @@ fn forward_session_event(event: SessionEvent, events: &mut VecDeque) { SessionEvent::Finished(stream_id) => { events.push_back(Event::Finished(stream_id)); } + SessionEvent::OutboundFinished(stream_id) => { + events.push_back(Event::OutboundFinished(stream_id)); + } SessionEvent::Closed(frame) => { events.push_back(Event::Closed(frame)); } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 9dcd2c65..5a0e9eaa 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -83,6 +83,8 @@ pub enum Event { Writable(StreamId), /// the peer finished writing this stream Finished(StreamId), + /// our local FIN was acknowledged by the peer at the session layer + OutboundFinished(StreamId), /// a stream was closed Closed(StreamClose), /// local writes on this stream are closed diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index ad87f532..6e8d46e3 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -74,6 +74,7 @@ pub enum SessionEvent { Readable(StreamId), Writable(StreamId), Finished(StreamId), + OutboundFinished(StreamId), Closed(StreamClose), WritableClosed(StreamClose), SessionClosed(SessionClose), @@ -962,6 +963,7 @@ fn acknowledge_tracked_frame( let stream_id = frame.stream_id; if let Some(stream) = streams.get_mut(&stream_id) { let was_full = stream.send_capacity(stream_send_buffer_size) == 0; + let had_unacked_fin = frame.fin && stream.tx.has_unacked_fin(); stream.tx.ack(StreamTxRange { offset: frame.offset, len: frame.len, @@ -970,6 +972,9 @@ fn acknowledge_tracked_frame( if was_full && stream.send_capacity(stream_send_buffer_size) > 0 { emit(SessionEvent::Writable(stream_id)); } + if had_unacked_fin && !stream.tx.has_unacked_fin() { + emit(SessionEvent::OutboundFinished(stream_id)); + } } } } diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index d697e4f8..eb12ef26 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -172,6 +172,11 @@ impl StreamTx { }); } + pub fn has_unacked_fin(&self) -> bool { + self.final_offset + .is_some_and(|final_offset| final_offset.state != SendState::Acked) + } + pub fn poll_transmit( &mut self, max_payload: usize, diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 88741280..3168399e 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -208,6 +208,43 @@ fn ack_reopens_write_capacity() { assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"z"), 1); } +#[test] +fn ack_of_fin_emits_outbound_finished_once() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"done"), 4); + fsm.stream(stream_id).unwrap().writer().unwrap().finish(); + + let (record_seq, record) = next_outbound(&mut fsm, now).unwrap(); + assert!(matches!( + record.as_slice(), + [SessionFrame::StreamData(StreamData { + stream_id: id, + fin: true, + .. + })] if *id == stream_id + )); + + let mut events = Vec::new(); + fsm.receive( + now + Duration::from_millis(1), + seq(9), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + |event| events.push(event), + ); + assert_eq!(events, vec![SessionEvent::OutboundFinished(stream_id)]); + + fsm.receive( + now + Duration::from_millis(2), + seq(10), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + |event| events.push(event), + ); + assert_eq!(events, vec![SessionEvent::OutboundFinished(stream_id)]); +} + #[test] fn commit_stream_read_is_what_advances_stream_window() { let now = Instant::now(); diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index a9e756c7..6bd196ef 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -117,6 +117,7 @@ struct TakenWrite { struct SideEventState { opened: BTreeSet, finished: BTreeSet, + outbound_finished: BTreeSet, writable_closed: BTreeSet, closed: BTreeSet, peer_statuses: Vec, @@ -443,6 +444,16 @@ impl Runner { "side {side:?} emitted Finished after Closed for {stream_id:?}" ); } + Event::OutboundFinished(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted OutboundFinished for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].outbound_finished.insert(stream_id), + "side {side:?} emitted duplicate OutboundFinished for {stream_id:?}" + ); + } Event::Closed(frame) => { prop_assert!( self.known_streams.contains(&frame.stream_id), @@ -606,6 +617,7 @@ impl Runner { && self.events.iter().all(|events| { events.opened.is_empty() && events.finished.is_empty() + && events.outbound_finished.is_empty() && events.closed.is_empty() && events.writable_closed.is_empty() }), diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 793409b2..cc4b79e1 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -90,6 +90,13 @@ fn connected_fsms_deliver_stream_data() { harness.take_event(Side::B), Some(Event::Finished(stream_id)) ); + harness.advance(QlFsmConfig::default().session_record_ack_delay); + harness.on_timer(Side::B); + harness.pump(); + assert_eq!( + harness.take_event(Side::A), + Some(Event::OutboundFinished(stream_id)) + ); } #[test] From 13795a41e7f8c3f701bcb8c91bf18fcee46dd008 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 13:06:35 -0400 Subject: [PATCH 225/304] ql-runtime: await finish --- ql-rpc/src/router/request.rs | 4 +- ql-rpc/src/router/stream.rs | 9 ++- ql-rpc/src/router/subscription.rs | 9 ++- ql-runtime/src/command.rs | 2 +- ql-runtime/src/driver/mod.rs | 23 +++++++- ql-runtime/src/driver/state.rs | 37 +++++++++++-- ql-runtime/src/driver/test.rs | 6 +- ql-runtime/src/handle/mod.rs | 4 ++ ql-runtime/src/handle/reader.rs | 20 +++---- ql-runtime/src/handle/writer.rs | 91 +++++++++++++++++-------------- ql-runtime/src/rpc/adapter.rs | 4 +- ql-runtime/src/rpc/mod.rs | 4 +- ql-runtime/src/tests/handshake.rs | 4 +- ql-runtime/src/tests/heartbeat.rs | 6 +- ql-runtime/src/tests/rpc.rs | 8 +-- ql-runtime/src/tests/stream.rs | 43 +++++++++------ 16 files changed, 172 insertions(+), 102 deletions(-) diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/router/request.rs index 55fc55c6..e22d4d56 100644 --- a/ql-rpc/src/router/request.rs +++ b/ql-rpc/src/router/request.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use bytes::Bytes; use super::{ - stream::{read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError}, + stream::{finish_bytes, read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError}, LocalMode, RouteMode, RouterConfig, SendMode, }; use crate::{ @@ -45,7 +45,7 @@ where let mut encoded = Vec::new(); codec::encode_value_part(&response, &mut encoded); write_bytes(&mut writer, Bytes::from(encoded)).await?; - writer.finish(); + finish_bytes(&mut writer).await?; Ok(()) } diff --git a/ql-rpc/src/router/stream.rs b/ql-rpc/src/router/stream.rs index 7b531374..5755a26b 100644 --- a/ql-rpc/src/router/stream.rs +++ b/ql-rpc/src/router/stream.rs @@ -35,7 +35,7 @@ pub trait RpcWrite { bytes: &mut Bytes, cx: &mut Context<'_>, ) -> Poll>; - fn finish(self); + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll>; fn close(self, code: StreamCloseCode); } @@ -64,6 +64,13 @@ where poll_fn(|cx| writer.poll_write(&mut bytes, cx)).await } +pub async fn finish_bytes(writer: &mut W) -> Result<(), W::Error> +where + W: RpcWrite, +{ + poll_fn(|cx| writer.poll_finish(cx)).await +} + pub fn close_stream(stream: St, code: StreamCloseCode) where St: RpcStream, diff --git a/ql-rpc/src/router/subscription.rs b/ql-rpc/src/router/subscription.rs index bb137b6b..14171211 100644 --- a/ql-rpc/src/router/subscription.rs +++ b/ql-rpc/src/router/subscription.rs @@ -4,7 +4,7 @@ use bytes::Bytes; use super::{ request::read_value_and_eof, - stream::{write_bytes, RpcRead, RpcStream, RpcWrite, StreamError}, + stream::{finish_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError}, LocalMode, RouteMode, RouterConfig, SendMode, }; use crate::{codec, subscription::Subscription as SubscriptionRpc, RpcCodec, StreamCloseCode}; @@ -47,10 +47,9 @@ where Ok(()) } - pub fn finish(mut self) -> Result<(), W::Error> { - let writer = self.writer.take().expect("subscription writer exists"); - writer.finish(); - Ok(()) + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().expect("subscription writer exists"); + finish_bytes(&mut writer).await } pub fn close(mut self, code: StreamCloseCode) { diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index deccec52..71b03879 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -18,7 +18,7 @@ pub(crate) enum RuntimeCommand { OpenStream { route_id: RouteId, request_reader: ChunkSlotRx, - request_terminal: oneshot::Sender, + request_terminal: oneshot::Sender>, start: oneshot::Sender>, }, PollInbound { diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 0922f188..cb35a538 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -298,6 +298,9 @@ impl DriverState { Event::Finished(stream_id) => { self.handle_inbound_finished(fsm, stream_id); } + Event::OutboundFinished(stream_id) => { + self.handle_outbound_finished(stream_id); + } Event::Closed(frame) => { self.handle_closed_stream(&frame); } @@ -482,6 +485,22 @@ impl DriverState { Self::try_reap_stream(entry); } + fn handle_outbound_finished(&mut self, stream_id: StreamId) { + debug!( + "runtime outbound finish acknowledged: stream_id={:?}", + stream_id + ); + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { + return; + }; + let stream = entry.get_mut(); + if !stream.outbound_finish_pending() { + return; + } + stream.outbound_finish(); + Self::try_reap_stream(entry); + } + fn fill_write_slots<'a, P: QlPlatform + 'a>( &self, fsm: &mut QlFsm, @@ -521,7 +540,7 @@ impl DriverState { writer.finish(); } } - stream.outbound_close(); + stream.outbound_queue_finish(); if stream.is_closed() { entry.remove(); } @@ -548,7 +567,7 @@ impl DriverState { stream_id ); writer.finish(); - stream.outbound_close(); + stream.outbound_queue_finish(); if stream.is_closed() { entry.remove(); } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 8e07368f..9d4776da 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -63,16 +63,39 @@ impl DriverStreamIo { self.outbound = None; } + pub fn outbound_finish(&mut self) { + if let Some(mut outbound) = self.outbound.take() { + if let Some(terminal) = outbound.terminal.take() { + let _ = terminal.send(Ok(())); + } + } + } + pub fn outbound_fail(&mut self, error: QlStreamError) { if let Some(mut outbound) = self.outbound.take() { if let Some(terminal) = outbound.terminal.take() { - let _ = terminal.send(error); + let _ = terminal.send(Err(error)); } } } pub fn outbound_reader_mut(&mut self) -> Option<&mut ChunkSlotRx> { - self.outbound.as_mut().map(|outbound| &mut outbound.reader) + self.outbound + .as_mut() + .and_then(|outbound| outbound.reader.as_mut()) + } + + pub fn outbound_queue_finish(&mut self) { + if let Some(outbound) = self.outbound.as_mut() { + outbound.reader = None; + outbound.finish_pending = true; + } + } + + pub fn outbound_finish_pending(&self) -> bool { + self.outbound + .as_ref() + .is_some_and(|outbound| outbound.finish_pending) } pub fn inbound_close(&mut self) { @@ -127,15 +150,17 @@ impl DriverStreamIo { } pub struct OutboundIo { - reader: ChunkSlotRx, - terminal: Option>, + reader: Option, + terminal: Option>>, + finish_pending: bool, } impl OutboundIo { - pub fn new(reader: ChunkSlotRx, terminal: oneshot::Sender) -> Self { + pub fn new(reader: ChunkSlotRx, terminal: oneshot::Sender>) -> Self { Self { - reader, + reader: Some(reader), terminal: Some(terminal), + finish_pending: false, } } } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 9d4eacea..eefeaa48 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -106,7 +106,7 @@ fn handle_closed_stream_reaps_when_both_halves_close() { } #[test] -fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { +fn poll_stream_keeps_outbound_pending_after_local_finish_when_inbound_is_closed() { let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); let (request_reader, request_writer) = chunk_slot::new(); @@ -124,7 +124,9 @@ fn poll_stream_reaps_after_local_finish_when_inbound_is_closed() { state.poll_stream(&mut fsm, stream_id); - assert!(!state.streams.contains_key(&stream_id)); + let stream = state.streams.get(&stream_id).unwrap(); + assert!(stream.outbound_finish_pending()); + assert!(!stream.is_closed()); } #[test] diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 20c5581a..49c7e820 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -99,4 +99,8 @@ impl RuntimeHandle { pub(crate) fn send(&self, cmd: RuntimeCommand) { self.tx.try_send(cmd).expect("runtime is alive"); } + + pub(crate) fn try_send(&self, cmd: RuntimeCommand) -> bool { + self.tx.try_send(cmd).is_ok() + } } diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 636097d4..997e45b1 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -80,7 +80,7 @@ impl ByteReader { self.target, bytes.len() ); - self.handle.send(RuntimeCommand::PollInbound { + self.handle.try_send(RuntimeCommand::PollInbound { stream_id: self.stream_id, }); return Poll::Ready(Ok(Some(bytes))); @@ -88,8 +88,7 @@ impl ByteReader { Poll::Ready(Err(_)) => { debug!( "byte reader channel closed: stream_id={:?} target={:?}", - self.stream_id, - self.target + self.stream_id, self.target ); self.reader = None; self.listener = None; @@ -116,8 +115,7 @@ impl ByteReader { TerminalState::Terminal(Ok(())) => { debug!( "byte reader delivered clean eof: stream_id={:?} target={:?}", - self.stream_id, - self.target + self.stream_id, self.target ); self.terminal = TerminalState::Delivered; Poll::Ready(Ok(None)) @@ -126,9 +124,7 @@ impl ByteReader { let error = error.clone(); debug!( "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", - self.stream_id, - self.target, - error + self.stream_id, self.target, error ); self.terminal = TerminalState::Delivered; Poll::Ready(Err(error)) @@ -159,14 +155,12 @@ impl ByteReader { } debug!( "byte reader explicit close: stream_id={:?} target={:?} code={:?}", - self.stream_id, - self.target, - code + self.stream_id, self.target, code ); self.reader.take(); self.listener = None; self.terminal = TerminalState::Delivered; - self.handle.send(RuntimeCommand::CloseStream { + self.handle.try_send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, code, @@ -185,7 +179,7 @@ impl Drop for ByteReader { self.target, StreamCloseCode::CANCELLED ); - self.handle.send(RuntimeCommand::CloseStream { + self.handle.try_send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, code: StreamCloseCode::CANCELLED, diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index b00811b9..dc5c4b3c 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -25,8 +25,8 @@ pub struct ByteWriter { } enum WriteTerminalState { - Armed(oneshot::Receiver), - Terminal(QlStreamError), + Armed(oneshot::Receiver>), + Terminal(Result<(), QlStreamError>), } // Safety: `ByteWriter` contains a `oneshot::Receiver`, which is `!Sync`, but that receiver is @@ -45,29 +45,6 @@ impl std::fmt::Debug for ByteWriter { } impl ByteWriter { - pub(crate) fn new( - stream_id: StreamId, - target: CloseTarget, - writer: ChunkSlotTx, - terminal: oneshot::Receiver, - handle: RuntimeHandle, - ) -> Self { - Self { - stream_id, - target, - writer: Some(writer), - listener: None, - terminal: WriteTerminalState::Armed(terminal), - handle, - } - } - - fn poll_runtime(&self) { - self.handle.send(RuntimeCommand::PollStream { - stream_id: self.stream_id, - }); - } - pub fn poll_write( &mut self, bytes: &mut Bytes, @@ -78,7 +55,7 @@ impl ByteWriter { } let Some(writer) = self.writer.as_ref() else { - return self.poll_terminal_error(cx).map(Err); + return self.poll_terminal(cx); }; match writer.poll_send(bytes, &mut self.listener, cx) { @@ -95,12 +72,11 @@ impl ByteWriter { Poll::Ready(Err(SendClosed(_bytes))) => { debug!( "byte writer send closed: stream_id={:?} target={:?}", - self.stream_id, - self.target + self.stream_id, self.target ); self.writer.take(); self.listener = None; - self.poll_terminal_error(cx).map(Err) + self.poll_terminal(cx) } Poll::Pending => Poll::Pending, } @@ -111,19 +87,31 @@ impl ByteWriter { poll_fn(|cx| self.poll_write(&mut bytes, cx)).await } - pub fn finish(mut self) { + pub fn queue_finish(&mut self) { let Some(writer) = self.writer.take() else { return; }; debug!( "byte writer finish: stream_id={:?} target={:?}", - self.stream_id, - self.target + self.stream_id, self.target ); writer.close(); + self.listener = None; self.poll_runtime(); } + pub async fn finish(mut self) -> Result<(), QlStreamError> { + self.queue_finish(); + poll_fn(|cx| self.poll_terminal(cx)).await + } + + pub fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.writer.is_some() { + self.queue_finish(); + } + self.poll_terminal(cx) + } + pub fn close(mut self, code: StreamCloseCode) { self.close_inner(code); } @@ -136,13 +124,36 @@ impl Drop for ByteWriter { } impl ByteWriter { - fn poll_terminal_error(&mut self, cx: &mut Context<'_>) -> Poll { + pub(crate) fn new( + stream_id: StreamId, + target: CloseTarget, + writer: ChunkSlotTx, + terminal: oneshot::Receiver>, + handle: RuntimeHandle, + ) -> Self { + Self { + stream_id, + target, + writer: Some(writer), + listener: None, + terminal: WriteTerminalState::Armed(terminal), + handle, + } + } + + fn poll_runtime(&self) { + self.handle.try_send(RuntimeCommand::PollStream { + stream_id: self.stream_id, + }); + } + + fn poll_terminal(&mut self, cx: &mut Context<'_>) -> Poll> { match &mut self.terminal { - WriteTerminalState::Terminal(error) => Poll::Ready(error.clone()), + WriteTerminalState::Terminal(result) => Poll::Ready(result.clone()), WriteTerminalState::Armed(receiver) => match Pin::new(receiver).poll(cx) { - Poll::Ready(Ok(error)) => { - self.terminal = WriteTerminalState::Terminal(error.clone()); - Poll::Ready(error) + Poll::Ready(Ok(result)) => { + self.terminal = WriteTerminalState::Terminal(result.clone()); + Poll::Ready(result) } Poll::Ready(Err(_)) => { panic!("byte writer terminal dropped before sending a terminal state") @@ -158,12 +169,10 @@ impl ByteWriter { } debug!( "byte writer close: stream_id={:?} target={:?} code={:?}", - self.stream_id, - self.target, - code + self.stream_id, self.target, code ); self.listener = None; - self.handle.send(RuntimeCommand::CloseStream { + self.handle.try_send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, code, diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs index 729f1077..7449eed2 100644 --- a/ql-runtime/src/rpc/adapter.rs +++ b/ql-runtime/src/rpc/adapter.rs @@ -57,8 +57,8 @@ impl RpcWrite for ByteWriter { ByteWriter::poll_write(self, bytes, cx) } - fn finish(self) { - ByteWriter::finish(self); + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + ByteWriter::poll_finish(self, cx) } fn close(self, code: StreamCloseCode) { diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 5c16d18c..98c1a125 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -35,7 +35,7 @@ impl RpcHandle { .await?; stream.reader.close(ql_wire::StreamCloseCode::CANCELLED); stream.writer.write(Bytes::from(payload)).await?; - stream.writer.finish(); + stream.writer.finish().await?; Ok(()) } @@ -92,7 +92,7 @@ impl RpcHandle { .open_stream(adapter::to_wire_route_id(route_id)) .await?; stream.writer.write(Bytes::from(payload)).await?; - stream.writer.finish(); + stream.writer.finish().await?; Ok(stream.reader) } } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 20629a99..c785f048 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -79,7 +79,7 @@ async fn rejected_session_write_is_reissued() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); let request = read_all(stream.reader).await.unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); request }); @@ -89,7 +89,7 @@ async fn rejected_session_write_is_reissued() { .write(Bytes::from_static(b"retry")) .await .unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); assert_eq!( diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index 2fd10383..f9c7f2be 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -44,13 +44,15 @@ async fn session_timeout_disconnects_and_fails_pending_open() { let responder_task = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); let _ = read_all(stream.reader).await; - stream.writer.finish(); + let err = stream.writer.finish().await.unwrap_err(); + assert!(matches!(err, QlStreamError::NoSession)); }); drop_flag.store(true, Ordering::Relaxed); let mut pending = handle_a.open_stream(test_route_id()).await.unwrap(); - pending.writer.finish(); + let err = pending.writer.finish().await.unwrap_err(); + assert!(matches!(err, QlStreamError::NoSession)); await_status(&status_a, identity_b.xid, PeerStatus::Disconnected).await; diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 19e97b91..383d9ed4 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -68,7 +68,7 @@ async fn rpc_request_round_trips() { ql_rpc::request::encode_response::(&"world".into(), &mut encoded); let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish(); + writer.finish().await.unwrap(); }); let rpc = pair.handle(Side::A).rpc(); @@ -148,7 +148,7 @@ async fn rpc_router_handles_subscription() { seen.borrow_mut().push(request); let _ = response.send(b"one".to_vec()).await; let _ = response.send(b"two".to_vec()).await; - let _ = response.finish(); + let _ = response.finish().await; }); } } @@ -302,7 +302,7 @@ async fn rpc_subscription_streams_events() { let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish(); + writer.finish().await.unwrap(); }); let rpc = pair.handle(Side::A).rpc(); @@ -353,7 +353,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { let mut writer = inbound.writer; writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish(); + writer.finish().await.unwrap(); }); let rpc = pair.handle(Side::A).rpc(); diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index cc0e0be4..811ec32f 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -24,7 +24,7 @@ async fn open_stream_duplex_happy_path() { assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![3, 4])); writer.write(Bytes::from_static(&[8, 7])).await.unwrap(); assert_eq!(next_chunk(&mut reader).await.unwrap(), None); - writer.finish(); + writer.finish().await.unwrap(); }); let mut stream = pair @@ -44,7 +44,7 @@ async fn open_stream_duplex_happy_path() { .write(Bytes::from_static(&[3, 4])) .await .unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); assert_eq!( next_chunk(&mut stream.reader).await.unwrap(), Some(vec![8, 7]) @@ -78,7 +78,7 @@ async fn reader_respects_max_len() { assert_eq!(next_chunk_max(&mut reader, 2).await.unwrap(), Some(vec![5, 6])); assert_eq!(next_chunk(&mut reader).await.unwrap(), None); - inbound.writer.finish(); + inbound.writer.finish().await.unwrap(); }); let mut stream = pair @@ -92,7 +92,7 @@ async fn reader_respects_max_len() { .write(Bytes::from_static(&[1, 2, 3, 4, 5, 6])) .await .unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); tokio::time::timeout(Duration::from_secs(2), responder) @@ -115,7 +115,7 @@ async fn large_stream_payload_round_trips() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); let request_data = read_all(stream.reader).await.unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); done_tx.send(request_data).await.unwrap(); }); @@ -130,7 +130,7 @@ async fn large_stream_payload_round_trips() { .write(Bytes::from(payload.clone())) .await .unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); let received = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) @@ -165,7 +165,11 @@ async fn dropping_responder_closes_initiator_response() { .open_stream(test_route_id()) .await .unwrap(); - stream.writer.finish(); + let err = stream.writer.finish().await.unwrap_err(); + assert!(matches!( + err, + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED + )); let err = next_chunk(&mut stream.reader).await.unwrap_err(); assert!(matches!( @@ -200,7 +204,11 @@ async fn dropping_inbound_reader_cancels_remote_writer() { .unwrap(); go_rx.recv().await.unwrap(); let _ = writer.write(Bytes::from(vec![5; 64])).await; - writer.finish(); + let err = writer.finish().await.unwrap_err(); + assert!(matches!( + err, + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED + )); }); let mut stream = pair @@ -209,7 +217,7 @@ async fn dropping_inbound_reader_cancels_remote_writer() { .open_stream(test_route_id()) .await .unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); assert_eq!( next_chunk(&mut stream.reader).await.unwrap(), Some(vec![1, 2, 3, 4]) @@ -250,7 +258,7 @@ async fn closing_initiator_reader_preserves_initiator_writer() { writer.write(Bytes::from_static(&[1, 2])).await.unwrap(); writer.write(Bytes::from_static(&[3, 4])).await.unwrap(); - writer.finish(); + writer.finish().await.unwrap(); let request = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) .await @@ -298,7 +306,8 @@ async fn max_concurrent_message_writes_is_respected() { for _ in 0..4 { let stream = inbound_b.recv().await.unwrap(); let _ = read_all(stream.reader).await; - stream.writer.finish(); + let mut writer = stream.writer; + writer.queue_finish(); } }); @@ -308,7 +317,7 @@ async fn max_concurrent_message_writes_is_respected() { tasks.push(tokio::task::spawn_local(async move { let mut stream = handle.open_stream(test_route_id()).await.unwrap(); stream.writer.write(Bytes::from(vec![i; 8])).await.unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); })); } @@ -375,7 +384,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { .write(Bytes::from(response_payload.clone())) .await .unwrap(); - writer.finish(); + writer.finish().await.unwrap(); received_request }); @@ -385,7 +394,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { .write(Bytes::from(request_payload.clone())) .await .unwrap(); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); let mut received_response = Vec::new(); while let Some(chunk) = next_chunk(&mut stream.reader).await.unwrap() { @@ -461,7 +470,7 @@ async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { eprintln!("responder received {} bytes", received.len()); } } - stream.writer.finish(); + stream.writer.finish().await.unwrap(); received }); @@ -500,7 +509,7 @@ async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { } } eprintln!("writer finished queueing"); - stream.writer.finish(); + stream.writer.finish().await.unwrap(); eprintln!("writer waiting for eof"); assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); eprintln!("writer observed eof"); @@ -600,7 +609,7 @@ async fn reproducer_writer_stalls_after_reverse_path_impairment() { .await .unwrap(); } - stream.writer.finish(); + stream.writer.queue_finish(); let _ = next_chunk(&mut stream.reader).await; }); From 63b224f43b30acb877b9a0bdea3752d403b9392b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 14:28:29 -0400 Subject: [PATCH 226/304] rpc stream docs --- ql-rpc/src/router/stream.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ql-rpc/src/router/stream.rs b/ql-rpc/src/router/stream.rs index 5755a26b..f6174efd 100644 --- a/ql-rpc/src/router/stream.rs +++ b/ql-rpc/src/router/stream.rs @@ -19,23 +19,31 @@ pub trait RpcStream { pub trait RpcRead { type Error: StreamError; + /// reads inbound bytes until eof or error fn poll_read( &mut self, max_len: usize, cx: &mut Context<'_>, ) -> Poll, Self::Error>>; + + /// aborts the read side fn close(self, code: StreamCloseCode); } pub trait RpcWrite { type Error: StreamError; + /// writes outbound bytes before finish or close fn poll_write( &mut self, bytes: &mut Bytes, cx: &mut Context<'_>, ) -> Poll>; + + /// completes the write side and must be polled until ready without further write or close calls fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll>; + + /// aborts the write side before finish fn close(self, code: StreamCloseCode); } From 8edfbcf232864c7e12a6a171c6d74024d14e0ffd Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 14:31:36 -0400 Subject: [PATCH 227/304] ql-rpc: refactor --- ql-rpc/src/lib.rs | 2 ++ ql-rpc/src/router/mod.rs | 6 +++--- ql-rpc/src/router/request.rs | 2 +- ql-rpc/src/router/subscription.rs | 6 ++++-- ql-rpc/src/{router => }/stream.rs | 0 5 files changed, 10 insertions(+), 6 deletions(-) rename ql-rpc/src/{router => }/stream.rs (100%) diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index b10d96d6..ded00d2e 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -4,11 +4,13 @@ pub(crate) mod codec; mod error; mod router; pub mod rpc; +mod stream; pub use codec::{ReadValueStep, RpcCodec, ValueReader}; pub use error::*; pub use router::*; pub use rpc::*; +pub use stream::*; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 88d23d51..feb30fe6 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -6,7 +6,6 @@ mod builder; mod config; mod mode; mod request; -mod stream; mod subscription; pub use self::{ @@ -14,10 +13,11 @@ pub use self::{ config::RouterConfig, mode::*, request::{RequestHandler, Response}, - stream::{RpcRead, RpcStream, RpcWrite, StreamError}, subscription::{SubscriptionHandler, SubscriptionResponder}, }; +use crate::{close_stream, RpcStream}; + pub struct Router where Mode: RouteMode, @@ -40,7 +40,7 @@ where pub fn handle(&self, stream: St) -> Option<(RouteId, Mode::RouteFuture)> { let route_id = stream.route_id()?; let Some(route) = self.routes.get(&route_id).copied() else { - stream::close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); + close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); return None; }; Some((route_id, route(self.state.clone(), self.config, stream))) diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/router/request.rs index e22d4d56..fb8af33a 100644 --- a/ql-rpc/src/router/request.rs +++ b/ql-rpc/src/router/request.rs @@ -3,10 +3,10 @@ use std::marker::PhantomData; use bytes::Bytes; use super::{ - stream::{finish_bytes, read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError}, LocalMode, RouteMode, RouterConfig, SendMode, }; use crate::{ + finish_bytes, read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError, codec, request::Request as RequestRpc, ReadValueStep, RpcCodec, StreamCloseCode, ValueReader, }; diff --git a/ql-rpc/src/router/subscription.rs b/ql-rpc/src/router/subscription.rs index 14171211..f479fc1a 100644 --- a/ql-rpc/src/router/subscription.rs +++ b/ql-rpc/src/router/subscription.rs @@ -4,10 +4,12 @@ use bytes::Bytes; use super::{ request::read_value_and_eof, - stream::{finish_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError}, LocalMode, RouteMode, RouterConfig, SendMode, }; -use crate::{codec, subscription::Subscription as SubscriptionRpc, RpcCodec, StreamCloseCode}; +use crate::{ + codec, finish_bytes, subscription::Subscription as SubscriptionRpc, write_bytes, RpcCodec, + RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; pub trait SubscriptionHandler where diff --git a/ql-rpc/src/router/stream.rs b/ql-rpc/src/stream.rs similarity index 100% rename from ql-rpc/src/router/stream.rs rename to ql-rpc/src/stream.rs From 563e5b2576b60c41d3a6eca8f47729b43edd1c07 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 14:57:07 -0400 Subject: [PATCH 228/304] ql-rpc: router spawner --- ql-rpc/src/router/builder.rs | 103 +++++++++++++++++++----------- ql-rpc/src/router/mod.rs | 22 ++++--- ql-rpc/src/router/mode.rs | 52 +++++++++++---- ql-rpc/src/router/request.rs | 47 +------------- ql-rpc/src/router/subscription.rs | 45 +------------ ql-runtime/src/rpc/adapter.rs | 7 +- ql-runtime/src/tests/rpc.rs | 18 ++++-- 7 files changed, 138 insertions(+), 156 deletions(-) diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index c88fffb8..52e27926 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -1,39 +1,32 @@ use std::collections::HashMap; use super::{ - request::{RequestHandler, RequestRouteMode}, - subscription::{SubscriptionHandler, SubscriptionRouteMode}, - LocalMode, RouteMode, Router, RouterConfig, RpcStream, + request::{handle_request_inner, RequestHandler}, + subscription::{handle_subscription_inner, SubscriptionHandler}, + LocalSpawn, LocalSpawner, RouteFn, Router, RouterConfig, RpcStream, SendSpawn, SendSpawner, + Spawner, }; use crate::{ - request::Request as RequestRpc, router::RouteFn, subscription::Subscription as SubscriptionRpc, - RouteId, + request::Request as RequestRpc, subscription::Subscription as SubscriptionRpc, RouteId, }; -pub struct RouterBuilder +pub struct RouterBuilder where - Mode: RouteMode, + Sp: Spawner, { config: RouterConfig, - routes: HashMap>, + spawner: Sp, + routes: HashMap>, } -impl Default for RouterBuilder +impl RouterBuilder where - Mode: RouteMode, + Sp: Spawner, { - fn default() -> Self { - Self::new() - } -} - -impl RouterBuilder -where - Mode: RouteMode, -{ - pub fn new() -> Self { + pub fn new(spawner: Sp) -> Self { Self { config: RouterConfig::default(), + spawner, routes: std::collections::HashMap::new(), } } @@ -48,16 +41,17 @@ where self } - pub fn build(mut self, state: S) -> Router { + pub fn build(mut self, state: S) -> Router { self.routes.shrink_to_fit(); Router { config: self.config, state, + spawner: self.spawner, routes: self.routes, } } - fn add_route(mut self, route_id: crate::RouteId, route: super::RouteFn) -> Self { + fn add_route(mut self, route_id: crate::RouteId, route: RouteFn) -> Self { if self.routes.insert(route_id, route).is_some() { panic!("duplicate rpc route {}", route_id.into_inner()); } @@ -65,33 +59,70 @@ where } } -impl RouterBuilder +impl RouterBuilder where - Mode: RouteMode, + St: RpcStream + 'static, { pub fn request(self) -> Self where M: RequestRpc + 'static, S: RequestHandler + 'static, - St: RpcStream + 'static, - Mode: RequestRouteMode, { - self.add_route( - M::ROUTE, - >::handle_request, - ) + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_request_inner::( + state, config, reader, writer, + )) + }) } pub fn subscription(self) -> Self where M: SubscriptionRpc + 'static, S: SubscriptionHandler + 'static, - St: RpcStream + 'static, - Mode: SubscriptionRouteMode, { - self.add_route( - M::ROUTE, - >::handle_subscription, - ) + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_subscription_inner::( + state, config, reader, writer, + )) + }) + } +} + +impl RouterBuilder +where + St: RpcStream + 'static, +{ + pub fn request(self) -> Self + where + M: RequestRpc + 'static, + M::Request: Send + 'static, + S: RequestHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_request_inner::( + state, config, reader, writer, + )) + }) + } + + pub fn subscription(self) -> Self + where + M: SubscriptionRpc + 'static, + M::Request: Send + 'static, + S: SubscriptionHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_subscription_inner::( + state, config, reader, writer, + )) + }) } } diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index feb30fe6..709ef5b9 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -18,31 +18,35 @@ pub use self::{ use crate::{close_stream, RpcStream}; -pub struct Router +pub struct Router where - Mode: RouteMode, + Sp: Spawner, { config: RouterConfig, state: S, - routes: HashMap>, + spawner: Sp, + routes: HashMap>, } -impl Router +impl Router where S: Clone + 'static, St: RpcStream, - Mode: RouteMode, + Sp: Spawner, { - pub fn builder() -> RouterBuilder { - RouterBuilder::::new() + pub fn builder(spawner: Sp) -> RouterBuilder { + RouterBuilder::::new(spawner) } - pub fn handle(&self, stream: St) -> Option<(RouteId, Mode::RouteFuture)> { + pub fn handle(&self, stream: St) -> Option<(RouteId, Sp::Handle)> { let route_id = stream.route_id()?; let Some(route) = self.routes.get(&route_id).copied() else { close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); return None; }; - Some((route_id, route(self.state.clone(), self.config, stream))) + Some(( + route_id, + route(&self.spawner, self.state.clone(), self.config, stream), + )) } } diff --git a/ql-rpc/src/router/mode.rs b/ql-rpc/src/router/mode.rs index b0a27b1b..5d22d706 100644 --- a/ql-rpc/src/router/mode.rs +++ b/ql-rpc/src/router/mode.rs @@ -2,24 +2,52 @@ use std::{future::Future, pin::Pin}; use crate::RouterConfig; -pub trait RouteMode { - type RouteFuture: Future + 'static; +pub type RouteFn = fn(&Sp, S, RouterConfig, St) -> ::Handle; + +pub trait Spawner { + type Handle: Future + 'static; } -#[derive(Debug, Clone, Copy, Default)] -pub struct LocalMode; +pub trait LocalSpawner: Spawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + 'static; +} + +pub trait SendSpawner: Spawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + Send + 'static; +} #[derive(Debug, Clone, Copy, Default)] -pub struct SendMode; +pub struct LocalSpawn; -pub type RouteFn = fn(S, RouterConfig, St) -> ::RouteFuture; -pub type LocalFuture = Pin + 'static>>; -pub type SendFuture = Pin + Send + 'static>>; +impl Spawner for LocalSpawn { + type Handle = Pin + 'static>>; +} + +impl LocalSpawner for LocalSpawn { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + 'static, + { + Box::pin(fut) + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct SendSpawn; -impl RouteMode for LocalMode { - type RouteFuture = LocalFuture; +impl Spawner for SendSpawn { + type Handle = Pin + Send + 'static>>; } -impl RouteMode for SendMode { - type RouteFuture = SendFuture; +impl SendSpawner for SendSpawn { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + Send + 'static, + { + Box::pin(fut) + } } diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/router/request.rs index fb8af33a..a395ac7f 100644 --- a/ql-rpc/src/router/request.rs +++ b/ql-rpc/src/router/request.rs @@ -2,9 +2,7 @@ use std::marker::PhantomData; use bytes::Bytes; -use super::{ - LocalMode, RouteMode, RouterConfig, SendMode, -}; +use super::RouterConfig; use crate::{ finish_bytes, read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError, codec, request::Request as RequestRpc, ReadValueStep, RpcCodec, StreamCloseCode, ValueReader, @@ -67,48 +65,7 @@ where } } -#[doc(hidden)] -pub trait RequestRouteMode: RouteMode -where - M: RequestRpc + 'static, - S: RequestHandler + 'static, - St: RpcStream + 'static, -{ - fn handle_request(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture; -} - -impl RequestRouteMode for LocalMode -where - M: RequestRpc + 'static, - S: RequestHandler + 'static, - St: RpcStream + 'static, -{ - fn handle_request(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture { - let (reader, writer) = stream.split(); - Box::pin(handle_request_inner::( - state, config, reader, writer, - )) - } -} - -impl RequestRouteMode for SendMode -where - M: RequestRpc + 'static, - M::Request: Send + 'static, - S: RequestHandler + Send + 'static, - St: RpcStream + 'static, - St::Reader: Send + 'static, - St::Writer: Send + 'static, -{ - fn handle_request(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture { - let (reader, writer) = stream.split(); - Box::pin(handle_request_inner::( - state, config, reader, writer, - )) - } -} - -async fn handle_request_inner( +pub(super) async fn handle_request_inner( state: S, config: RouterConfig, mut reader: St::Reader, diff --git a/ql-rpc/src/router/subscription.rs b/ql-rpc/src/router/subscription.rs index f479fc1a..dee51c59 100644 --- a/ql-rpc/src/router/subscription.rs +++ b/ql-rpc/src/router/subscription.rs @@ -4,7 +4,7 @@ use bytes::Bytes; use super::{ request::read_value_and_eof, - LocalMode, RouteMode, RouterConfig, SendMode, + RouterConfig, }; use crate::{ codec, finish_bytes, subscription::Subscription as SubscriptionRpc, write_bytes, RpcCodec, @@ -72,48 +72,7 @@ where } } -#[doc(hidden)] -pub trait SubscriptionRouteMode: RouteMode -where - M: SubscriptionRpc + 'static, - S: SubscriptionHandler + 'static, - St: RpcStream + 'static, -{ - fn handle_subscription(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture; -} - -impl SubscriptionRouteMode for LocalMode -where - M: SubscriptionRpc + 'static, - S: SubscriptionHandler + 'static, - St: RpcStream + 'static, -{ - fn handle_subscription(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture { - let (reader, writer) = stream.split(); - Box::pin(handle_subscription_inner::( - state, config, reader, writer, - )) - } -} - -impl SubscriptionRouteMode for SendMode -where - M: SubscriptionRpc + 'static, - M::Request: Send + 'static, - S: SubscriptionHandler + Send + 'static, - St: RpcStream + 'static, - St::Reader: Send + 'static, - St::Writer: Send + 'static, -{ - fn handle_subscription(state: S, config: RouterConfig, stream: St) -> Self::RouteFuture { - let (reader, writer) = stream.split(); - Box::pin(handle_subscription_inner::( - state, config, reader, writer, - )) - } -} - -async fn handle_subscription_inner( +pub(super) async fn handle_subscription_inner( state: S, config: RouterConfig, mut reader: St::Reader, diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs index 7449eed2..45976ec1 100644 --- a/ql-runtime/src/rpc/adapter.rs +++ b/ql-runtime/src/rpc/adapter.rs @@ -2,7 +2,7 @@ use std::task::{Context, Poll}; use bytes::Bytes; pub use ql_rpc::{ - LocalMode, RequestHandler, Response, RouteId, RouterConfig, SendMode, StreamCloseCode, + LocalSpawn, RequestHandler, Response, RouteId, RouterConfig, SendSpawn, StreamCloseCode, SubscriptionHandler, SubscriptionResponder, }; use ql_rpc::{RpcRead, RpcStream, RpcWrite, StreamError}; @@ -10,11 +10,6 @@ use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; use crate::{ByteReader, ByteWriter, QlStream, QlStreamError}; -pub type Router = ql_rpc::Router; -pub type RouterBuilder = ql_rpc::RouterBuilder; -pub type SendRouter = ql_rpc::Router; -pub type SendRouterBuilder = ql_rpc::RouterBuilder; - impl RpcStream for QlStream { type Error = QlStreamError; type Reader = ByteReader; diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 383d9ed4..400e9f94 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -12,7 +12,7 @@ use ql_rpc::{Response, RouteId, StreamCloseCode, SubscriptionResponder}; use ql_wire::RouteId as WireRouteId; use super::*; -use crate::{rpc::Router, ByteWriter}; +use crate::{ByteWriter, QlStream}; struct Echo; @@ -106,7 +106,9 @@ async fn rpc_router_handles_request() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = Router::builder() + let router = ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder( + crate::rpc::LocalSpawn, + ) .request::() .build(RouterState { seen: seen.clone() }); @@ -159,7 +161,9 @@ async fn rpc_router_handles_subscription() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = crate::rpc::Router::builder() + let router = ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder( + crate::rpc::LocalSpawn, + ) .subscription::() .build(RouterState { seen: seen.clone() }); @@ -207,7 +211,9 @@ async fn rpc_send_router_handles_request() { pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); let seen = Arc::new(Mutex::new(Vec::new())); - let router = crate::rpc::SendRouter::builder() + let router = ql_rpc::Router::<_, QlStream, crate::rpc::SendSpawn>::builder( + crate::rpc::SendSpawn, + ) .request::() .build(RouterState { seen: seen.clone() }); @@ -253,7 +259,9 @@ async fn rpc_router_enforces_max_request_bytes() { let mut pair = TestPair::new(default_runtime_config()); pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); - let router = crate::rpc::Router::builder() + let router = ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder( + crate::rpc::LocalSpawn, + ) .max_request_bytes(4) .request::() .build(LimitedState); From b80f9f1a46e9680383ccb529ce868ed30af93589 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 15:39:47 -0400 Subject: [PATCH 229/304] fix clippy --- ql-fsm/src/session/mod.rs | 20 ++++++++-------- ql-fsm/src/session/stream_tx.rs | 4 ++-- ql-rpc/src/router/mod.rs | 1 - ql-rpc/src/router/request.rs | 4 ++-- ql-rpc/src/router/subscription.rs | 5 +--- ql-rpc/src/rpc/request_with_progress.rs | 4 ++-- ql-rpc/src/rpc/subscription.rs | 4 ++-- ql-runtime/src/driver/mod.rs | 31 ++++++------------------- ql-runtime/src/tests/mod.rs | 12 ++++------ ql-runtime/src/tests/stream.rs | 16 ++++++++----- ql-wire/src/encrypted/ack.rs | 4 ++-- 11 files changed, 42 insertions(+), 63 deletions(-) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 6e8d46e3..ef9d43a9 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -193,7 +193,7 @@ impl SessionFsm { return; } ReceiveOutcome::New => {} - }; + } let mut ack_eliciting = false; let mut handled_close = false; @@ -329,7 +329,7 @@ impl SessionFsm { let seq = self.state.next_record_seq; next_seq(&mut self.state.next_record_seq); let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); - assert!(builder.push_close(&close), "builder has capacity"); + assert!(builder.push_close(close), "builder has capacity"); self.state.phase = SessionPhase::Closed; return Some((None, builder)); } @@ -379,11 +379,11 @@ impl SessionFsm { self.push_next_stream_data(&mut builder, &mut outbound); if let Some(pending_ack) = self.pending_ack(builder.remaining_capacity()) { - if !builder.is_empty() || pending_ack.due_at <= now { - if builder.push_ack(&pending_ack.ack) { - self.state.ack_tracker.on_ack_emitted(&pending_ack); - outbound.ack = Some(pending_ack.ack); - } + if (!builder.is_empty() || pending_ack.due_at <= now) + && builder.push_ack(&pending_ack.ack) + { + self.state.ack_tracker.on_ack_emitted(&pending_ack); + outbound.ack = Some(pending_ack.ack); } } @@ -518,10 +518,10 @@ impl SessionFsm { } fn ensure_session_open(&self) -> Result<(), NoSessionError> { - if self.state.phase != SessionPhase::Open { - Err(NoSessionError) - } else { + if self.state.phase == SessionPhase::Open { Ok(()) + } else { + Err(NoSessionError) } } diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index eb12ef26..e4d2d3b6 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -69,14 +69,14 @@ impl BufView for StreamTxBytes<'_> { } } -impl<'a> StreamTxBuf<'a> { +impl StreamTxBuf<'_> { fn refill(&mut self) { if self.remaining == 0 { self.current = &[]; return; } - while let Some(chunk) = self.inner.next() { + for chunk in self.inner.by_ref() { if self.skip >= chunk.len() { self.skip -= chunk.len(); continue; diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 709ef5b9..4e98ef0e 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -15,7 +15,6 @@ pub use self::{ request::{RequestHandler, Response}, subscription::{SubscriptionHandler, SubscriptionResponder}, }; - use crate::{close_stream, RpcStream}; pub struct Router diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/router/request.rs index a395ac7f..452d7880 100644 --- a/ql-rpc/src/router/request.rs +++ b/ql-rpc/src/router/request.rs @@ -4,8 +4,8 @@ use bytes::Bytes; use super::RouterConfig; use crate::{ - finish_bytes, read_bytes, write_bytes, RpcRead, RpcStream, RpcWrite, StreamError, - codec, request::Request as RequestRpc, ReadValueStep, RpcCodec, StreamCloseCode, ValueReader, + codec, finish_bytes, read_bytes, request::Request as RequestRpc, write_bytes, ReadValueStep, + RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, ValueReader, }; pub trait RequestHandler diff --git a/ql-rpc/src/router/subscription.rs b/ql-rpc/src/router/subscription.rs index dee51c59..9216680b 100644 --- a/ql-rpc/src/router/subscription.rs +++ b/ql-rpc/src/router/subscription.rs @@ -2,10 +2,7 @@ use std::marker::PhantomData; use bytes::Bytes; -use super::{ - request::read_value_and_eof, - RouterConfig, -}; +use super::{request::read_value_and_eof, RouterConfig}; use crate::{ codec, finish_bytes, subscription::Subscription as SubscriptionRpc, write_bytes, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index c1119840..e24ddfd2 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{CodecError, Error, RouteId, RpcCodec, codec}; +use crate::{codec, CodecError, Error, RouteId, RpcCodec}; pub trait RequestWithProgress { const ROUTE: RouteId; @@ -119,7 +119,7 @@ fn encode_tagged_value_part>( mod tests { use bytes::Bytes; - use super::{ReadStep, RequestWithProgress, ResponseReader, encode_progress, encode_response}; + use super::{encode_progress, encode_response, ReadStep, RequestWithProgress, ResponseReader}; use crate::RouteId; struct Watch; diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index bf96d45f..70e45280 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{CodecError, RouteId, RpcCodec, codec}; +use crate::{codec, CodecError, RouteId, RpcCodec}; pub trait Subscription { const ROUTE: RouteId; @@ -80,7 +80,7 @@ pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + As mod tests { use bytes::Bytes; - use super::{ReadStep, ResponseReader, Subscription, encode_item}; + use super::{encode_item, ReadStep, ResponseReader, Subscription}; use crate::RouteId; struct Feed; diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index cb35a538..4d4f910e 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -242,10 +242,7 @@ impl DriverState { code, } => { debug!( - "runtime close stream command: stream_id={:?} target={:?} code={:?}", - stream_id, - target, - code + "runtime close stream command: stream_id={stream_id:?} target={target:?} code={code:?}" ); if let Entry::Occupied(mut entry) = self.streams.entry(stream_id) { let stream = entry.get_mut(); @@ -392,9 +389,7 @@ impl DriverState { } InboundWriteResult::Closed => { debug!( - "runtime inbound consumer closed; sending CANCELLED: stream_id={:?} target={:?}", - stream_id, - target + "runtime inbound consumer closed; sending CANCELLED: stream_id={stream_id:?} target={target:?}" ); peer_closed = true; break; @@ -418,7 +413,7 @@ impl DriverState { } fn handle_inbound_finished(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { - debug!("runtime inbound finished event: stream_id={:?}", stream_id); + debug!("runtime inbound finished event: stream_id={stream_id:?}"); let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; @@ -441,10 +436,7 @@ impl DriverState { return; } - debug!( - "runtime delivering clean inbound finish: stream_id={:?}", - stream_id - ); + debug!("runtime delivering clean inbound finish: stream_id={stream_id:?}"); stream.inbound_finish(); Self::try_reap_stream(entry); } @@ -486,10 +478,7 @@ impl DriverState { } fn handle_outbound_finished(&mut self, stream_id: StreamId) { - debug!( - "runtime outbound finish acknowledged: stream_id={:?}", - stream_id - ); + debug!("runtime outbound finish acknowledged: stream_id={stream_id:?}"); let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { return; }; @@ -531,10 +520,7 @@ impl DriverState { }; if reader.is_finished() { - debug!( - "runtime observed outbound reader finished before write: stream_id={:?}", - stream_id - ); + debug!("runtime observed outbound reader finished before write: stream_id={stream_id:?}"); if let Ok(mut stream_ops) = fsm.stream(stream_id) { if let Some(writer) = stream_ops.writer() { writer.finish(); @@ -562,10 +548,7 @@ impl DriverState { } if reader.is_finished() { - debug!( - "runtime observed outbound reader finished after write: stream_id={:?}", - stream_id - ); + debug!("runtime observed outbound reader finished after write: stream_id={stream_id:?}"); writer.finish(); stream.outbound_queue_finish(); if stream.is_closed() { diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 158f6387..8c802d68 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -45,8 +45,8 @@ enum Side { impl Side { fn opposite(self) -> Self { match self { - Side::A => Side::B, - Side::B => Side::A, + Self::A => Self::B, + Self::B => Self::A, } } } @@ -210,7 +210,7 @@ impl TestPair { b_to_a: LinkController::new(b_to_a), }; - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config.clone()); + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); tokio::task::spawn_local(async move { runtime_a.run().await }); @@ -246,10 +246,6 @@ impl TestPair { } } - fn handle(&self, side: Side) -> &RuntimeHandle { - &self.side(side).handle - } - fn side_mut(&mut self, side: Side) -> &mut TestSide { match side { Side::A => &mut self.a, @@ -557,7 +553,7 @@ where { tokio::time::timeout(duration, run_local_test(future)) .await - .unwrap_or_else(|_| panic!("local runtime test exceeded {:?}", duration)); + .unwrap_or_else(|_| panic!("local runtime test exceeded {duration:?}")); } async fn await_status(receiver: &Receiver, peer: XID, stage: PeerStatus) { diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 811ec32f..b3ea0779 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -70,12 +70,18 @@ async fn reader_respects_max_len() { let inbound = inbound_b.recv().await.unwrap(); let mut reader = inbound.reader; - assert_eq!(next_chunk_max(&mut reader, 2).await.unwrap(), Some(vec![1, 2])); + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![1, 2]) + ); assert_eq!( next_chunk_max(&mut reader, 2).await.unwrap(), Some(vec![3, 4]) ); - assert_eq!(next_chunk_max(&mut reader, 2).await.unwrap(), Some(vec![5, 6])); + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![5, 6]) + ); assert_eq!(next_chunk(&mut reader).await.unwrap(), None); inbound.writer.finish().await.unwrap(); @@ -411,6 +417,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { .await; } +#[allow(clippy::too_many_lines)] #[tokio::test(flavor = "current_thread")] async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { run_local_test_timeout(Duration::from_secs(5), async { @@ -579,10 +586,7 @@ async fn reproducer_writer_stalls_after_reverse_path_impairment() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); let mut reader = stream.reader; - let mut received = Vec::new(); - while let Some(chunk) = next_chunk(&mut reader).await.unwrap() { - received.extend_from_slice(&chunk); - } + while let Some(_) = next_chunk(&mut reader).await.unwrap() {} }); let recovery_links = links.clone(); diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs index 6869a6e8..2eb34b33 100644 --- a/ql-wire/src/encrypted/ack.rs +++ b/ql-wire/src/encrypted/ack.rs @@ -133,7 +133,7 @@ impl WireEncode for RecordAck { self.largest_acked.encode(out); VarInt::try_from(self.blocks.len()).unwrap().encode(out); self.first_range_len.encode(out); - for block in self.blocks.iter() { + for block in &self.blocks { block.gap.encode(out); block.range_len.encode(out); } @@ -168,7 +168,7 @@ impl codec::WireDecode for RecordAck { .checked_sub(ack.first_range_len.into_inner()) .ok_or(WireError::InvalidPayload)?; - for block in ack.blocks.iter() { + for block in &ack.blocks { let end = previous_start .checked_sub( block From 7e8c0f16d2303b62de68d07d70ecfb2665ead302 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 15:53:36 -0400 Subject: [PATCH 230/304] log feature --- ql-runtime/Cargo.toml | 3 +- ql-runtime/src/driver/mod.rs | 24 ++++++++------ ql-runtime/src/handle/reader.rs | 29 +++++++++------- ql-runtime/src/handle/writer.rs | 21 +++++++----- ql-runtime/src/lib.rs | 1 + ql-runtime/src/log.rs | 59 +++++++++++++++++++++++++++++++++ ql-runtime/src/tests/rpc.rs | 52 ++++++++++++++--------------- 7 files changed, 129 insertions(+), 60 deletions(-) create mode 100644 ql-runtime/src/log.rs diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index 58009478..50f7db19 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -7,6 +7,7 @@ license = "Proprietary" [features] default = [] +log = ["dep:log"] rpc = ["dep:ql-rpc"] [dependencies] @@ -14,7 +15,7 @@ async-channel = { version = "2.5" } bytes = "1" event-listener = "5.4" futures-lite = { version = "2.5" } -log = "0.4" +log = { version = "0.4", optional = true } oneshot = { version = "0.1.11" } ql-fsm = { path = "../ql-fsm" } ql-rpc = { path = "../ql-rpc", optional = true } diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 4d4f910e..5ecf89e9 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -15,7 +15,6 @@ use std::{ use async_channel::Recv; use futures_lite::future::{poll_fn, yield_now}; -use log::debug; use ql_fsm::{Event, FsmTime, QlFsm, WriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; @@ -24,6 +23,7 @@ use crate::{ chunk_slot, command::RuntimeCommand, handle::{ByteReader, ByteWriter, QlStream}, + log, platform::{QlPlatform, QlTimer}, QlStreamError, Runtime, RuntimeHandle, }; @@ -241,7 +241,7 @@ impl DriverState { target, code, } => { - debug!( + log::debug!( "runtime close stream command: stream_id={stream_id:?} target={target:?} code={code:?}" ); if let Entry::Occupied(mut entry) = self.streams.entry(stream_id) { @@ -388,7 +388,7 @@ impl DriverState { break; } InboundWriteResult::Closed => { - debug!( + log::debug!( "runtime inbound consumer closed; sending CANCELLED: stream_id={stream_id:?} target={target:?}" ); peer_closed = true; @@ -413,7 +413,7 @@ impl DriverState { } fn handle_inbound_finished(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { - debug!("runtime inbound finished event: stream_id={stream_id:?}"); + log::debug!("runtime inbound finished event: stream_id={stream_id:?}"); let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; @@ -436,13 +436,13 @@ impl DriverState { return; } - debug!("runtime delivering clean inbound finish: stream_id={stream_id:?}"); + log::debug!("runtime delivering clean inbound finish: stream_id={stream_id:?}"); stream.inbound_finish(); Self::try_reap_stream(entry); } fn handle_closed_stream(&mut self, frame: &ql_wire::StreamClose) { - debug!( + log::debug!( "runtime inbound close frame: stream_id={:?} target={:?} code={:?}", frame.stream_id, frame.target, @@ -463,7 +463,7 @@ impl DriverState { } fn handle_writable_closed(&mut self, frame: &ql_wire::StreamClose) { - debug!( + log::debug!( "runtime writable close frame: stream_id={:?} target={:?} code={:?}", frame.stream_id, frame.target, @@ -478,7 +478,7 @@ impl DriverState { } fn handle_outbound_finished(&mut self, stream_id: StreamId) { - debug!("runtime outbound finish acknowledged: stream_id={stream_id:?}"); + log::debug!("runtime outbound finish acknowledged: stream_id={stream_id:?}"); let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { return; }; @@ -520,7 +520,9 @@ impl DriverState { }; if reader.is_finished() { - debug!("runtime observed outbound reader finished before write: stream_id={stream_id:?}"); + log::debug!( + "runtime observed outbound reader finished before write: stream_id={stream_id:?}" + ); if let Ok(mut stream_ops) = fsm.stream(stream_id) { if let Some(writer) = stream_ops.writer() { writer.finish(); @@ -548,7 +550,9 @@ impl DriverState { } if reader.is_finished() { - debug!("runtime observed outbound reader finished after write: stream_id={stream_id:?}"); + log::debug!( + "runtime observed outbound reader finished after write: stream_id={stream_id:?}" + ); writer.finish(); stream.outbound_queue_finish(); if stream.is_closed() { diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 997e45b1..e2b176ec 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -6,10 +6,9 @@ use std::{ use bytes::Bytes; use event_listener::EventListener; -use log::{debug, trace}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, QlStreamError, RuntimeHandle}; +use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, log, QlStreamError, RuntimeHandle}; pub struct ByteReader { stream_id: StreamId, @@ -74,7 +73,7 @@ impl ByteReader { if let Some(reader) = self.reader.as_ref() { match reader.poll_recv(max_len, &mut self.listener, cx) { Poll::Ready(Ok(bytes)) => { - trace!( + log::trace!( "byte reader received chunk: stream_id={:?} target={:?} len={}", self.stream_id, self.target, @@ -86,9 +85,10 @@ impl ByteReader { return Poll::Ready(Ok(Some(bytes))); } Poll::Ready(Err(_)) => { - debug!( + log::debug!( "byte reader channel closed: stream_id={:?} target={:?}", - self.stream_id, self.target + self.stream_id, + self.target ); self.reader = None; self.listener = None; @@ -113,18 +113,21 @@ impl ByteReader { match &self.terminal { TerminalState::Armed(_) => Poll::Pending, TerminalState::Terminal(Ok(())) => { - debug!( + log::debug!( "byte reader delivered clean eof: stream_id={:?} target={:?}", - self.stream_id, self.target + self.stream_id, + self.target ); self.terminal = TerminalState::Delivered; Poll::Ready(Ok(None)) } TerminalState::Terminal(Err(error)) => { let error = error.clone(); - debug!( + log::debug!( "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", - self.stream_id, self.target, error + self.stream_id, + self.target, + error ); self.terminal = TerminalState::Delivered; Poll::Ready(Err(error)) @@ -153,9 +156,11 @@ impl ByteReader { if matches!(self.terminal, TerminalState::Delivered) { return; } - debug!( + log::debug!( "byte reader explicit close: stream_id={:?} target={:?} code={:?}", - self.stream_id, self.target, code + self.stream_id, + self.target, + code ); self.reader.take(); self.listener = None; @@ -173,7 +178,7 @@ impl Drop for ByteReader { if matches!(self.terminal, TerminalState::Delivered) { return; } - debug!( + log::debug!( "byte reader drop close: stream_id={:?} target={:?} code={:?}", self.stream_id, self.target, diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index dc5c4b3c..19858443 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -6,13 +6,12 @@ use std::{ use bytes::Bytes; use event_listener::EventListener; -use log::{debug, trace}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use crate::{ chunk_slot::{ChunkSlotTx, SendClosed}, command::RuntimeCommand, - QlStreamError, RuntimeHandle, + log, QlStreamError, RuntimeHandle, }; pub struct ByteWriter { @@ -60,7 +59,7 @@ impl ByteWriter { match writer.poll_send(bytes, &mut self.listener, cx) { Poll::Ready(Ok(())) => { - trace!( + log::trace!( "byte writer accepted chunk: stream_id={:?} target={:?}", self.stream_id, self.target @@ -70,9 +69,10 @@ impl ByteWriter { Poll::Ready(Ok(())) } Poll::Ready(Err(SendClosed(_bytes))) => { - debug!( + log::debug!( "byte writer send closed: stream_id={:?} target={:?}", - self.stream_id, self.target + self.stream_id, + self.target ); self.writer.take(); self.listener = None; @@ -91,9 +91,10 @@ impl ByteWriter { let Some(writer) = self.writer.take() else { return; }; - debug!( + log::debug!( "byte writer finish: stream_id={:?} target={:?}", - self.stream_id, self.target + self.stream_id, + self.target ); writer.close(); self.listener = None; @@ -167,9 +168,11 @@ impl ByteWriter { if self.writer.take().is_none() { return; } - debug!( + log::debug!( "byte writer close: stream_id={:?} target={:?} code={:?}", - self.stream_id, self.target, code + self.stream_id, + self.target, + code ); self.listener = None; self.handle.try_send(RuntimeCommand::CloseStream { diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index d423adb0..ea965d64 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -7,6 +7,7 @@ pub(crate) mod command; pub(crate) mod driver; mod error; pub mod handle; +pub mod log; pub mod platform; #[cfg(feature = "rpc")] pub mod rpc; diff --git a/ql-runtime/src/log.rs b/ql-runtime/src/log.rs new file mode 100644 index 00000000..943ff26a --- /dev/null +++ b/ql-runtime/src/log.rs @@ -0,0 +1,59 @@ +#![allow(unused_imports, unused_macros)] + +#[cfg(feature = "log")] +macro_rules! with_log { + ($($tt:tt)*) => { + $($tt)* + }; +} + +#[cfg(not(feature = "log"))] +macro_rules! with_log { + ($($tt:tt)*) => {}; +} + +macro_rules! log { + ($level:ident, $($arg:tt)*) => { + $crate::log::with_log! { + ::log::log!(::log::Level::$level, $($arg)*) + } + }; +} + +macro_rules! trace { + ($($arg:tt)*) => { + $crate::log::log!(Trace, $($arg)*) + }; +} + +macro_rules! debug { + ($($arg:tt)*) => { + $crate::log::log!(Debug, $($arg)*) + }; +} + +macro_rules! info { + ($($arg:tt)*) => { + $crate::log::log!(Info, $($arg)*) + }; +} + +macro_rules! warn_ { + ($($arg:tt)*) => { + $crate::log::log!(Warn, $($arg)*) + }; +} + +macro_rules! error { + ($($arg:tt)*) => { + $crate::log::log!(Error, $($arg)*) + }; +} + +pub(crate) use debug; +pub(crate) use error; +pub(crate) use info; +pub(crate) use log; +pub(crate) use trace; +pub(crate) use warn_ as warn; +pub(crate) use with_log; diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 400e9f94..6482a4ef 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -71,7 +71,7 @@ async fn rpc_request_round_trips() { writer.finish().await.unwrap(); }); - let rpc = pair.handle(Side::A).rpc(); + let rpc = pair.side_mut(Side::A).handle.rpc(); let response = rpc.request::(&"hello".into()).await.unwrap(); assert_eq!(response, "world"); @@ -106,11 +106,10 @@ async fn rpc_router_handles_request() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder( - crate::rpc::LocalSpawn, - ) - .request::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) + .request::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -119,7 +118,7 @@ async fn rpc_router_handles_request() { } }); - let rpc = pair.handle(Side::A).rpc(); + let rpc = pair.side_mut(Side::A).handle.rpc(); let response = rpc.request::(&"hello".into()).await.unwrap(); assert_eq!(response, "world"); assert_eq!(&*seen.borrow(), &["hello".to_string()]); @@ -161,11 +160,10 @@ async fn rpc_router_handles_subscription() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder( - crate::rpc::LocalSpawn, - ) - .subscription::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) + .subscription::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -174,7 +172,7 @@ async fn rpc_router_handles_subscription() { } }); - let rpc = pair.handle(Side::A).rpc(); + let rpc = pair.side_mut(Side::A).handle.rpc(); let mut subscription = rpc.subscribe::(&b"watch".to_vec()).await.unwrap(); assert_eq!(subscription.next().await.unwrap().unwrap(), b"one".to_vec()); assert_eq!(subscription.next().await.unwrap().unwrap(), b"two".to_vec()); @@ -211,11 +209,10 @@ async fn rpc_send_router_handles_request() { pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); let seen = Arc::new(Mutex::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, crate::rpc::SendSpawn>::builder( - crate::rpc::SendSpawn, - ) - .request::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, crate::rpc::SendSpawn>::builder(crate::rpc::SendSpawn) + .request::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -225,7 +222,7 @@ async fn rpc_send_router_handles_request() { } }); - let rpc = pair.handle(Side::A).rpc(); + let rpc = pair.side_mut(Side::A).handle.rpc(); let response = rpc.request::(&"hello".into()).await.unwrap(); assert_eq!(response, "world"); assert_eq!(&*seen.lock().unwrap(), &["hello".to_string()]); @@ -259,12 +256,11 @@ async fn rpc_router_enforces_max_request_bytes() { let mut pair = TestPair::new(default_runtime_config()); pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); - let router = ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder( - crate::rpc::LocalSpawn, - ) - .max_request_bytes(4) - .request::() - .build(LimitedState); + let router = + ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) + .max_request_bytes(4) + .request::() + .build(LimitedState); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -273,7 +269,7 @@ async fn rpc_router_enforces_max_request_bytes() { } }); - let rpc = pair.handle(Side::A).rpc(); + let rpc = pair.side_mut(Side::A).handle.rpc(); let response = rpc.request::(&"hello".to_string()).await; assert!(matches!( response, @@ -313,7 +309,7 @@ async fn rpc_subscription_streams_events() { writer.finish().await.unwrap(); }); - let rpc = pair.handle(Side::A).rpc(); + let rpc = pair.side_mut(Side::A).handle.rpc(); let mut subscription = rpc.subscribe::(&b"watch".to_vec()).await.unwrap(); assert_eq!(subscription.next().await.unwrap().unwrap(), b"one".to_vec()); assert_eq!(subscription.next().await.unwrap().unwrap(), b"two".to_vec()); @@ -364,7 +360,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { writer.finish().await.unwrap(); }); - let rpc = pair.handle(Side::A).rpc(); + let rpc = pair.side_mut(Side::A).handle.rpc(); let mut download = rpc .request_with_progress::(&b"logo".to_vec()) .await From 44331e075c848bcc24803aab74b57b49309efeb7 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 17:33:35 -0400 Subject: [PATCH 231/304] ql-wire: impl display --- ql-runtime/src/command.rs | 19 ++++++++++++++++++- ql-wire/src/encrypted/route_id.rs | 6 ++++++ ql-wire/src/encrypted/stream_close.rs | 6 ++++++ ql-wire/src/encrypted/stream_id.rs | 6 ++++++ 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 71b03879..95784845 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -3,7 +3,7 @@ use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamCloseCode, S use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlStreamError}; -pub(crate) enum RuntimeCommand { +pub enum RuntimeCommand { BindPeer { peer: PeerBundle, }, @@ -34,3 +34,20 @@ pub(crate) enum RuntimeCommand { }, Receive(Vec), } + +impl RuntimeCommand { + pub fn kind(&self) -> &'static str { + match self { + Self::BindPeer { .. } => "BindPeer", + Self::Connect => "Connect", + Self::ArmPairing { .. } => "ArmPairing", + Self::DisarmPairing => "DisarmPairing", + Self::StartPairing { .. } => "StartPairing", + Self::OpenStream { .. } => "OpenStream", + Self::PollInbound { .. } => "PollInbound", + Self::PollStream { .. } => "PollStream", + Self::CloseStream { .. } => "CloseStream", + Self::Receive(_) => "Receive", + } + } +} diff --git a/ql-wire/src/encrypted/route_id.rs b/ql-wire/src/encrypted/route_id.rs index f7b51999..6b91a521 100644 --- a/ql-wire/src/encrypted/route_id.rs +++ b/ql-wire/src/encrypted/route_id.rs @@ -47,3 +47,9 @@ impl From for RouteId { Self::from_u32(value) } } + +impl std::fmt::Display for RouteId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs index 2885eaa0..20ddb879 100644 --- a/ql-wire/src/encrypted/stream_close.rs +++ b/ql-wire/src/encrypted/stream_close.rs @@ -108,3 +108,9 @@ impl WireEncode for StreamCloseCode { self.0.encode(out); } } + +impl std::fmt::Display for StreamCloseCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/ql-wire/src/encrypted/stream_id.rs b/ql-wire/src/encrypted/stream_id.rs index fdbf564d..07002259 100644 --- a/ql-wire/src/encrypted/stream_id.rs +++ b/ql-wire/src/encrypted/stream_id.rs @@ -27,3 +27,9 @@ impl WireDecode for StreamId { Ok(Self(reader.decode()?)) } } + +impl std::fmt::Display for StreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} From 3eb9c6600d49658160e6f8797d997f67bc09b542 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 17:49:27 -0400 Subject: [PATCH 232/304] ql-runtime: add more logs --- ql-runtime/src/driver/mod.rs | 99 ++++++++++++++++++++++++++++-------- 1 file changed, 77 insertions(+), 22 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 5ecf89e9..d5e5cd2a 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -77,28 +77,39 @@ impl Runtime

{ match step { DriverStep::Command(command) => { + log::trace!("processing command: kind={}", command.kind()); state.drive_command(&mut fsm, command, &platform); } DriverStep::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); - DriverState::drive_write_completed(&mut fsm, write.session_write_id, success); + let write_id = write.write_id; + log::trace!( + "write completed: success={success} index={index} write_id={write_id:?}", + ); + DriverState::drive_write_completed(&mut fsm, write_id, success); yield_now().await; } DriverStep::TimerExpired => { + log::trace!("timer expired"); fsm.on_timer(now()); } DriverStep::Closed => { + log::debug!( + "command channel closed: in_flight_writes={}", + in_flight.len() + ); if in_flight.is_empty() { break; } } } } + log::info!("runtime stopped"); } } struct InFlightWrite { - session_write_id: Option, + write_id: Option, future: F, } @@ -163,22 +174,31 @@ impl DriverState { ) { match command { RuntimeCommand::BindPeer { peer } => { + log::info!("binding peer"); fsm.bind_peer(peer); } RuntimeCommand::Connect => { - let _ = fsm.connect_ik(now(), platform); + log::info!("starting IK connect"); + if fsm.connect_ik(now(), platform).is_err() { + log::warn!("IK connect ignored: no bound peer"); + } } RuntimeCommand::ArmPairing { token } => { + log::info!("arming inbound pairing"); fsm.arm_pairing(token); } RuntimeCommand::DisarmPairing => { + log::info!("disarming inbound pairing"); fsm.disarm_pairing(); } RuntimeCommand::StartPairing { token } => { + log::info!(" starting XX pairing"); fsm.connect_xx(now(), token, platform); } RuntimeCommand::Receive(bytes) => { + log::trace!("received transport frame: len={}", bytes.len()); if let Err(e) = fsm.receive(now(), bytes, platform) { + log::info!("receive rejected frame: error={e:?}"); platform.handle_recv_error(e); } } @@ -188,7 +208,9 @@ impl DriverState { request_terminal, start, } => { + log::info!("open stream requested: route_id={route_id}"); let Some(runtime_tx) = self.runtime_tx.upgrade() else { + log::warn!("open stream aborted: runtime channel unavailable"); let _ = start.send(Err(ql_fsm::NoSessionError)); return; }; @@ -196,11 +218,13 @@ impl DriverState { let mut stream_ops = match fsm.open_stream(route_id) { Ok(stream_ops) => stream_ops, Err(error) => { + log::warn!("open stream failed: route_id={route_id}"); let _ = start.send(Err(error)); return; } }; let stream_id = stream_ops.stream_id(); + log::info!("open stream allocated: route_id={route_id} stream_id={stream_id}"); let (response_reader, response_writer) = chunk_slot::new(); let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); self.streams.insert( @@ -219,6 +243,7 @@ impl DriverState { RuntimeHandle::new(runtime_tx), ); if start.send(Ok((stream_id, reader))).is_err() { + log::warn!("open stream cancelled before delivery: stream_id={stream_id}"); if let Some(stream) = self.streams.get_mut(&stream_id) { stream.inbound_close(); stream.outbound_close(); @@ -231,9 +256,11 @@ impl DriverState { self.poll_stream(fsm, stream_id); } RuntimeCommand::PollInbound { stream_id } => { + log::trace!("poll inbound requested: stream_id={stream_id}"); self.handle_inbound_readable(fsm, stream_id); } RuntimeCommand::PollStream { stream_id } => { + log::trace!("poll stream requested: stream_id={stream_id}"); self.poll_stream(fsm, stream_id); } RuntimeCommand::CloseStream { @@ -242,7 +269,7 @@ impl DriverState { code, } => { log::debug!( - "runtime close stream command: stream_id={stream_id:?} target={target:?} code={code:?}" + "close stream command: stream_id={stream_id} target={target:?} code={code:?}" ); if let Entry::Occupied(mut entry) = self.streams.entry(stream_id) { let stream = entry.get_mut(); @@ -269,13 +296,16 @@ impl DriverState { fn drain_fsm_events(&mut self, fsm: &mut QlFsm, platform: &P) { while let Some(event) = fsm.poll_event() { + log::trace!("polled FSM event: event={event:?}"); match event { Event::NewPeer => { + log::info!("new ql peer"); if let Some(peer) = fsm.peer().cloned() { platform.persist_peer(peer); } } Event::PeerStatusChanged(status) => { + log::info!("peer status changed: status={status:?}"); if let Some(peer) = fsm.peer().map(|peer| peer.xid) { platform.handle_peer_status(peer, status); } @@ -284,18 +314,23 @@ impl DriverState { stream_id, route_id, } => { + log::info!("inbound stream opened: stream_id={stream_id} route_id={route_id}"); self.handle_opened_stream(fsm, platform, stream_id, route_id); } Event::Readable(stream_id) => { + log::trace!("stream readable: stream_id={stream_id}"); self.handle_inbound_readable(fsm, stream_id); } Event::Writable(stream_id) => { + log::trace!("stream writable: stream_id={stream_id}"); self.poll_stream(fsm, stream_id); } Event::Finished(stream_id) => { + log::info!("peer finished stream writes: stream_id={stream_id}"); self.handle_inbound_finished(fsm, stream_id); } Event::OutboundFinished(stream_id) => { + log::info!("outbound finish acknowledged: stream_id={stream_id}"); self.handle_outbound_finished(stream_id); } Event::Closed(frame) => { @@ -304,7 +339,8 @@ impl DriverState { Event::WritableClosed(frame) => { self.handle_writable_closed(&frame); } - Event::SessionClosed(_) => { + Event::SessionClosed(_close) => { + log::info!("session closed: frame={_close:?}"); for (_, mut stream) in self.streams.drain() { stream.fail_all(); } @@ -321,6 +357,9 @@ impl DriverState { route_id: ql_wire::RouteId, ) { let Some(runtime_tx) = self.runtime_tx.upgrade() else { + log::warn!( + "dropping inbound stream because handle channel is unavailable: stream_id={stream_id}" + ); if let Ok(mut stream) = fsm.stream(stream_id) { stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); } @@ -341,6 +380,9 @@ impl DriverState { ), ); + log::info!( + "delivering inbound stream to platform: stream_id={stream_id} route_id={route_id}" + ); platform.handle_inbound(QlStream { stream_id, route_id, @@ -363,11 +405,14 @@ impl DriverState { fn handle_inbound_readable(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { let Ok(mut stream_ops) = fsm.stream(stream_id) else { + log::info!("inbound readable for unknown stream: stream_id={stream_id}"); return; }; - if stream_ops.readable_bytes() == 0 { + let readable = stream_ops.readable_bytes(); + if readable == 0 { return; } + log::trace!("draining inbound bytes: stream_id={stream_id} readable={readable}"); let mut accepted = 0usize; let mut peer_closed = false; let target; @@ -385,11 +430,14 @@ impl DriverState { accepted += n; } InboundWriteResult::Full => { + log::debug!( + "inbound backpressure: stream_id={stream_id} accepted={accepted}" + ); break; } InboundWriteResult::Closed => { - log::debug!( - "runtime inbound consumer closed; sending CANCELLED: stream_id={stream_id:?} target={target:?}" + log::warn!( + "inbound consumer closed; sending CANCELLED: stream_id={stream_id} target={target:?}" ); peer_closed = true; break; @@ -399,6 +447,7 @@ impl DriverState { } if accepted > 0 { + log::trace!("committed inbound bytes: stream_id={stream_id:?} accepted={accepted}"); stream_ops.commit_read(accepted).unwrap(); } if peer_closed { @@ -413,7 +462,7 @@ impl DriverState { } fn handle_inbound_finished(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { - log::debug!("runtime inbound finished event: stream_id={stream_id:?}"); + log::info!("inbound finished event: stream_id={stream_id}"); let Some(stream) = self.streams.get_mut(&stream_id) else { return; }; @@ -436,14 +485,14 @@ impl DriverState { return; } - log::debug!("runtime delivering clean inbound finish: stream_id={stream_id:?}"); + log::info!("delivering clean inbound finish: stream_id={stream_id}"); stream.inbound_finish(); Self::try_reap_stream(entry); } fn handle_closed_stream(&mut self, frame: &ql_wire::StreamClose) { - log::debug!( - "runtime inbound close frame: stream_id={:?} target={:?} code={:?}", + log::info!( + "inbound close frame: stream_id={} target={:?} code={}", frame.stream_id, frame.target, frame.code @@ -463,8 +512,8 @@ impl DriverState { } fn handle_writable_closed(&mut self, frame: &ql_wire::StreamClose) { - log::debug!( - "runtime writable close frame: stream_id={:?} target={:?} code={:?}", + log::info!( + "writable close frame: stream_id={} target={:?} code={}", frame.stream_id, frame.target, frame.code @@ -478,7 +527,7 @@ impl DriverState { } fn handle_outbound_finished(&mut self, stream_id: StreamId) { - log::debug!("runtime outbound finish acknowledged: stream_id={stream_id:?}"); + log::info!("outbound finish acknowledged: stream_id={stream_id}"); let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { return; }; @@ -502,8 +551,13 @@ impl DriverState { break; }; filled = true; + log::trace!( + "queueing transport write: bytes={} write_id={:?}", + write.record.len(), + write.write_id + ); in_flight.push(InFlightWrite { - session_write_id: write.write_id, + write_id: write.write_id, future: platform.write_message(write.record), }); } @@ -516,13 +570,12 @@ impl DriverState { }; let stream = entry.get_mut(); let Some(reader) = stream.outbound_reader_mut() else { + log::trace!("poll stream skipped without outbound reader: stream_id={stream_id}"); return; }; if reader.is_finished() { - log::debug!( - "runtime observed outbound reader finished before write: stream_id={stream_id:?}" - ); + log::info!("observed outbound reader finished before write: stream_id={stream_id}"); if let Ok(mut stream_ops) = fsm.stream(stream_id) { if let Some(writer) = stream_ops.writer() { writer.finish(); @@ -539,20 +592,22 @@ impl DriverState { return; }; let Some(mut writer) = stream_ops.writer() else { + log::trace!("poll stream skipped without session writer: stream_id={stream_id}"); return; }; let capacity = writer.capacity(); + log::trace!("stream write capacity: stream_id={stream_id} capacity={capacity}"); if capacity > 0 { if let Ok(Some(mut bytes)) = reader.try_recv(capacity) { + let _len = bytes.len(); + log::trace!("writing stream bytes: stream_id={stream_id} len={_len}"); let _ = writer.write(&mut bytes); } } if reader.is_finished() { - log::debug!( - "runtime observed outbound reader finished after write: stream_id={stream_id:?}" - ); + log::info!("observed outbound reader finished after write: stream_id={stream_id}"); writer.finish(); stream.outbound_queue_finish(); if stream.is_closed() { From 2aef3a9e31403fb8dd55203adc704e0003083c2f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 18:35:11 -0400 Subject: [PATCH 233/304] better chunk slot --- Cargo.lock | 1 + ql-runtime/Cargo.toml | 1 + ql-runtime/src/chunk_slot.rs | 294 ++++++++++++-------------------- ql-runtime/src/driver/mod.rs | 10 +- ql-runtime/src/handle/reader.rs | 18 +- ql-runtime/src/handle/writer.rs | 14 +- 6 files changed, 137 insertions(+), 201 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fd232744..dea9f37f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2236,6 +2236,7 @@ version = "0.1.0" dependencies = [ "async-channel", "bytes", + "concurrent-queue", "event-listener", "futures-lite", "log", diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index 50f7db19..077faf10 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -13,6 +13,7 @@ rpc = ["dep:ql-rpc"] [dependencies] async-channel = { version = "2.5" } bytes = "1" +concurrent-queue = { version = "2.5" } event-listener = "5.4" futures-lite = { version = "2.5" } log = { version = "0.4", optional = true } diff --git a/ql-runtime/src/chunk_slot.rs b/ql-runtime/src/chunk_slot.rs index d536bc33..d446859d 100644 --- a/ql-runtime/src/chunk_slot.rs +++ b/ql-runtime/src/chunk_slot.rs @@ -5,47 +5,48 @@ use std::{ }; use bytes::Bytes; +use concurrent_queue::{ConcurrentQueue, PopError, PushError}; use event_listener::{Event, EventListener}; mod sync { #[cfg(not(all(test, loom)))] - pub use std::sync::atomic::{AtomicU8, Ordering}; - #[cfg(not(all(test, loom)))] - pub use std::sync::{Arc, Mutex}; + pub use std::sync::Arc; #[cfg(all(test, loom))] - pub use loom::sync::atomic::{AtomicU8, Ordering}; - #[cfg(all(test, loom))] - pub use loom::sync::{Arc, Mutex}; + pub use loom::sync::Arc; } -use sync::{Arc, AtomicU8, Mutex, Ordering}; - -const OCCUPIED: u8 = 1 << 0; -const TX_CLOSED: u8 = 1 << 1; -const RX_CLOSED: u8 = 1 << 2; +use sync::*; +/// creates a single-chunk handoff pair +/// receiver-side partial reads keep the remainder locally pub fn new() -> (ChunkSlotRx, ChunkSlotTx) { - let inner = Arc::new(Inner { - chunk: Mutex::new(None), - state: AtomicU8::new(0), + let shared = Arc::new(Shared { + queue: ConcurrentQueue::bounded(1), changed: Event::new(), }); ( ChunkSlotRx { - inner: inner.clone(), + shared: Arc::clone(&shared), + pending: Bytes::new(), }, - ChunkSlotTx { inner }, + ChunkSlotTx { shared }, ) } pub struct ChunkSlotRx { - inner: Arc, + shared: Arc, + pending: Bytes, } pub struct ChunkSlotTx { - inner: Arc, + shared: Arc, +} + +struct Shared { + queue: ConcurrentQueue, + changed: Event, } #[derive(Debug)] @@ -61,21 +62,47 @@ pub enum TrySendError { pub struct RecvClosed; impl ChunkSlotRx { - pub fn try_recv(&self, max_len: usize) -> Result, RecvClosed> { - self.inner.try_recv(max_len) + pub fn try_recv(&mut self, max_len: usize) -> Result { + if !self.pending.is_empty() { + let pending = &mut self.pending; + let bytes = if pending.len() <= max_len { + std::mem::take(pending) + } else { + pending.split_to(max_len) + }; + return Ok(bytes); + } + + match self.shared.queue.pop() { + Ok(mut bytes) => { + self.shared.changed.notify(usize::MAX); + let pending = &mut self.pending; + + let bytes = if bytes.len() <= max_len { + bytes + } else { + let head = bytes.split_to(max_len); + *pending = bytes; + head + }; + Ok(bytes) + } + Err(PopError::Empty) => Ok(Bytes::new()), + Err(PopError::Closed) => Err(RecvClosed), + } } pub fn poll_recv( - &self, + &mut self, max_len: usize, listener: &mut Option, cx: &mut Context<'_>, ) -> Poll> { loop { match self.try_recv(max_len) { - Ok(Some(bytes)) => return Poll::Ready(Ok(bytes)), + Ok(bytes) if !bytes.is_empty() => return Poll::Ready(Ok(bytes)), Err(closed) => return Poll::Ready(Err(closed)), - Ok(None) => {} + Ok(_) => {} } if let Some(active_listener) = listener.as_mut() { @@ -84,12 +111,12 @@ impl ChunkSlotRx { Poll::Pending => return Poll::Pending, } } else { - *listener = Some(self.inner.changed.listen()); + *listener = Some(self.shared.changed.listen()); } } } - pub fn recv(&self, max_len: usize) -> Recv<'_> { + pub fn recv(&mut self, max_len: usize) -> Recv<'_> { Recv { rx: self, max_len, @@ -98,27 +125,38 @@ impl ChunkSlotRx { } pub fn is_finished(&self) -> bool { - self.inner.snapshot(Ordering::Acquire).is_finished() + self.pending.is_empty() && self.shared.queue.is_closed() && self.shared.queue.is_empty() } pub fn is_empty(&self) -> bool { - !self.inner.snapshot(Ordering::Relaxed).is_occupied() + self.pending.is_empty() && self.shared.queue.is_empty() } pub fn close(self) { - self.inner.close_rx(); + if self.shared.queue.close() { + self.shared.changed.notify(usize::MAX); + } } } impl Drop for ChunkSlotRx { fn drop(&mut self) { - self.inner.close_rx(); + if self.shared.queue.close() { + self.shared.changed.notify(usize::MAX); + } } } impl ChunkSlotTx { pub fn try_send(&self, bytes: Bytes) -> Result<(), TrySendError> { - self.inner.try_send(bytes) + match self.shared.queue.push(bytes) { + Ok(()) => { + self.shared.changed.notify(usize::MAX); + Ok(()) + } + Err(PushError::Full(bytes)) => Err(TrySendError::Full(bytes)), + Err(PushError::Closed(bytes)) => Err(TrySendError::Closed(bytes)), + } } pub fn poll_send( @@ -129,7 +167,6 @@ impl ChunkSlotTx { ) -> Poll> { loop { let chunk = std::mem::take(bytes); - match self.try_send(chunk) { Ok(()) => return Poll::Ready(Ok(())), Err(TrySendError::Closed(chunk)) => { @@ -145,7 +182,7 @@ impl ChunkSlotTx { Poll::Pending => return Poll::Pending, } } else { - *listener = Some(self.inner.changed.listen()); + *listener = Some(self.shared.changed.listen()); } } } @@ -159,22 +196,26 @@ impl ChunkSlotTx { } pub fn is_closed(&self) -> bool { - self.inner.snapshot(Ordering::Acquire).is_closed() + self.shared.queue.is_closed() } pub fn close(self) { - self.inner.close_tx(); + if self.shared.queue.close() { + self.shared.changed.notify(usize::MAX); + } } } impl Drop for ChunkSlotTx { fn drop(&mut self) { - self.inner.close_tx(); + if self.shared.queue.close() { + self.shared.changed.notify(usize::MAX); + } } } pub struct Recv<'a> { - rx: &'a ChunkSlotRx, + rx: &'a mut ChunkSlotRx, max_len: usize, listener: Option, } @@ -183,7 +224,8 @@ impl Future for Recv<'_> { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.rx.poll_recv(self.max_len, &mut self.listener, cx) + let this = self.as_mut().get_mut(); + this.rx.poll_recv(this.max_len, &mut this.listener, cx) } } @@ -202,170 +244,47 @@ impl Future for Send<'_> { } } -struct Inner { - chunk: Mutex>, - state: AtomicU8, - changed: Event, -} - -#[derive(Clone, Copy)] -struct StateSnapshot(u8); - -impl StateSnapshot { - fn has_any(self, bits: u8) -> bool { - self.0 & bits != 0 - } - - fn is_occupied(self) -> bool { - self.has_any(OCCUPIED) - } - - fn is_closed(self) -> bool { - self.has_any(TX_CLOSED | RX_CLOSED) - } - - fn is_finished(self) -> bool { - self.has_any(TX_CLOSED) && !self.is_occupied() - } -} - -impl Inner { - fn snapshot(&self, ordering: Ordering) -> StateSnapshot { - StateSnapshot(self.state.load(ordering)) - } - - fn mark_occupied(&self) { - self.state.fetch_or(OCCUPIED, Ordering::Release); - } - - fn clear_occupied(&self) { - self.state.fetch_and(!OCCUPIED, Ordering::Release); - } - - fn close_rx(&self) { - if !StateSnapshot(self.state.fetch_or(RX_CLOSED, Ordering::Release)).has_any(RX_CLOSED) { - self.changed.notify(usize::MAX); - } - } - - fn close_tx(&self) { - if !StateSnapshot(self.state.fetch_or(TX_CLOSED, Ordering::Release)).has_any(TX_CLOSED) { - self.changed.notify(usize::MAX); - } - } - - fn try_recv(&self, max_len: usize) -> Result, RecvClosed> { - let snapshot = self.snapshot(Ordering::Acquire); - if max_len == 0 || !snapshot.is_occupied() { - return if snapshot.is_closed() { - Err(RecvClosed) - } else { - Ok(None) - }; - } - - let (bytes, became_empty) = { - let Ok(mut chunk) = self.chunk.try_lock() else { - return Ok(None); - }; - let Some(result) = take_chunk(&mut chunk, max_len) else { - return Ok(None); - }; - result - }; - - if became_empty { - self.clear_occupied(); - self.changed.notify(usize::MAX); - } - - Ok(Some(bytes)) - } - - fn try_send(&self, bytes: Bytes) -> Result<(), TrySendError> { - let snapshot = self.snapshot(Ordering::Acquire); - if snapshot.is_closed() { - return Err(TrySendError::Closed(bytes)); - } - if snapshot.is_occupied() { - return Err(TrySendError::Full(bytes)); - } - - let result = { - let Ok(mut chunk) = self.chunk.try_lock() else { - return Err(TrySendError::Full(bytes)); - }; - if self.snapshot(Ordering::Relaxed).is_closed() { - Err(TrySendError::Closed(bytes)) - } else if chunk.is_some() { - Err(TrySendError::Full(bytes)) - } else { - *chunk = Some(bytes); - Ok(()) - } - }; - - if result.is_ok() { - self.mark_occupied(); - self.changed.notify(usize::MAX); - } - - result - } -} - -fn take_chunk(chunk: &mut Option, max_len: usize) -> Option<(Bytes, bool)> { - let bytes = chunk.as_mut()?; - if bytes.len() <= max_len { - Some((chunk.take().unwrap(), true)) - } else { - Some((bytes.split_to(max_len), false)) - } -} - #[cfg(test)] mod tests { use std::time::Duration; use bytes::Bytes; - use super::{new, TrySendError}; + use super::{new, RecvClosed}; #[test] fn try_send_and_take_round_trip() { - let (rx, tx) = new(); + let (mut rx, tx) = new(); tx.try_send(Bytes::from_static(b"hello")).unwrap(); - assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"hello")))); - assert_eq!(rx.try_recv(8), Ok(None)); + assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"hello"))); + assert_eq!(rx.try_recv(8), Ok(Bytes::new())); } #[test] - fn read_splits_without_freeing_slot() { - let (rx, tx) = new(); + fn read_splits_moves_remainder_to_receiver() { + let (mut rx, tx) = new(); tx.try_send(Bytes::from_static(b"hello")).unwrap(); - assert_eq!(rx.try_recv(2), Ok(Some(Bytes::from_static(b"he")))); - assert_eq!( - tx.try_send(Bytes::from_static(b"!")), - Err(TrySendError::Full(Bytes::from_static(b"!"))) - ); - assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"llo")))); + assert_eq!(rx.try_recv(2), Ok(Bytes::from_static(b"he"))); + tx.try_send(Bytes::from_static(b"!")).unwrap(); + assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"llo"))); + assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"!"))); } #[test] fn read_drains_slot_when_limit_covers_chunk() { - let (rx, tx) = new(); + let (mut rx, tx) = new(); tx.try_send(Bytes::from_static(b"hello")).unwrap(); - assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"hello")))); + assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"hello"))); tx.try_send(Bytes::from_static(b"!")).unwrap(); - assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"!")))); + assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"!"))); } #[tokio::test(flavor = "current_thread")] async fn send_waits_until_slot_clears() { - let (rx, tx) = new(); + let (mut rx, tx) = new(); tx.try_send(Bytes::from_static(b"a")).unwrap(); @@ -374,7 +293,7 @@ mod tests { }); tokio::time::sleep(Duration::from_millis(10)).await; - assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"a")))); + assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"a"))); tokio::time::timeout(Duration::from_secs(1), sender) .await @@ -384,13 +303,13 @@ mod tests { #[tokio::test(flavor = "current_thread")] async fn finish_yields_eof_after_buffered_chunk() { - let (rx, tx) = new(); + let (mut rx, tx) = new(); tx.send(Bytes::from_static(b"abc")).await.unwrap(); tx.close(); assert_eq!(rx.recv(8).await, Ok(Bytes::from_static(b"abc"))); - assert_eq!(rx.recv(8).await, Err(super::RecvClosed)); + assert_eq!(rx.recv(8).await, Err(RecvClosed)); assert!(rx.is_finished()); } @@ -403,6 +322,15 @@ mod tests { let error = tx.send(Bytes::from_static(b"abc")).await.unwrap_err(); assert_eq!(error.0, Bytes::from_static(b"abc")); } + + #[test] + fn zero_length_recv_does_not_consume_buffered_chunk() { + let (mut rx, tx) = new(); + + tx.try_send(Bytes::from_static(b"hello")).unwrap(); + assert_eq!(rx.try_recv(0), Ok(Bytes::new())); + assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"hello"))); + } } #[cfg(all(test, loom))] @@ -437,7 +365,7 @@ mod loom_tests { #[test] fn try_recv_never_reports_closed_while_open() { check_model(|| { - let (rx, tx) = new(); + let (mut rx, tx) = new(); let sender = thread::spawn(move || { let _ = tx.try_send(Bytes::from_static(b"abc")); @@ -459,7 +387,7 @@ mod loom_tests { #[test] fn recv_observes_send_after_pending() { check_model(|| { - let (rx, tx) = new(); + let (mut rx, tx) = new(); assert!(now_or_never(rx.recv(8)).is_none()); @@ -479,7 +407,7 @@ mod loom_tests { #[test] fn recv_observes_finish_as_closed() { check_model(|| { - let (rx, tx) = new(); + let (mut rx, tx) = new(); assert!(now_or_never(rx.recv(8)).is_none()); @@ -496,14 +424,14 @@ mod loom_tests { #[test] fn partial_recv_preserves_remainder_and_finished_state() { check_model(|| { - let (rx, tx) = new(); + let (mut rx, tx) = new(); tx.try_send(Bytes::from_static(b"abcd")).unwrap(); tx.close(); - assert_eq!(rx.try_recv(2), Ok(Some(Bytes::from_static(b"ab")))); + assert_eq!(rx.try_recv(2), Ok(Bytes::from_static(b"ab"))); assert!(!rx.is_finished()); - assert_eq!(rx.try_recv(8), Ok(Some(Bytes::from_static(b"cd")))); + assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"cd"))); assert_eq!(rx.try_recv(8), Err(RecvClosed)); assert!(rx.is_finished()); }); diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index d5e5cd2a..c4c48b1d 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -599,10 +599,12 @@ impl DriverState { let capacity = writer.capacity(); log::trace!("stream write capacity: stream_id={stream_id} capacity={capacity}"); if capacity > 0 { - if let Ok(Some(mut bytes)) = reader.try_recv(capacity) { - let _len = bytes.len(); - log::trace!("writing stream bytes: stream_id={stream_id} len={_len}"); - let _ = writer.write(&mut bytes); + if let Ok(mut bytes) = reader.try_recv(capacity) { + if !bytes.is_empty() { + let _len = bytes.len(); + log::trace!("writing stream bytes: stream_id={stream_id} len={_len}"); + let _ = writer.write(&mut bytes); + } } } diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index e2b176ec..e0952588 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -8,13 +8,17 @@ use bytes::Bytes; use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, log, QlStreamError, RuntimeHandle}; +use crate::{ + chunk_slot::ChunkSlotRx, + command::RuntimeCommand, + log, QlStreamError, RuntimeHandle, +}; pub struct ByteReader { stream_id: StreamId, target: CloseTarget, reader: Option, - listener: Option, + wait: Option, terminal: TerminalState, handle: RuntimeHandle, } @@ -55,7 +59,7 @@ impl ByteReader { stream_id, target, reader: Some(reader), - listener: None, + wait: None, terminal: TerminalState::Armed(terminal), handle, } @@ -70,8 +74,8 @@ impl ByteReader { return Poll::Ready(Ok(None)); } - if let Some(reader) = self.reader.as_ref() { - match reader.poll_recv(max_len, &mut self.listener, cx) { + if let Some(reader) = self.reader.as_mut() { + match reader.poll_recv(max_len, &mut self.wait, cx) { Poll::Ready(Ok(bytes)) => { log::trace!( "byte reader received chunk: stream_id={:?} target={:?} len={}", @@ -91,7 +95,7 @@ impl ByteReader { self.target ); self.reader = None; - self.listener = None; + self.wait = None; } Poll::Pending => {} } @@ -163,7 +167,7 @@ impl ByteReader { code ); self.reader.take(); - self.listener = None; + self.wait = None; self.terminal = TerminalState::Delivered; self.handle.try_send(RuntimeCommand::CloseStream { stream_id: self.stream_id, diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 19858443..66addedc 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -18,7 +18,7 @@ pub struct ByteWriter { stream_id: StreamId, target: CloseTarget, writer: Option, - listener: Option, + wait: Option, terminal: WriteTerminalState, handle: RuntimeHandle, } @@ -57,14 +57,14 @@ impl ByteWriter { return self.poll_terminal(cx); }; - match writer.poll_send(bytes, &mut self.listener, cx) { + match writer.poll_send(bytes, &mut self.wait, cx) { Poll::Ready(Ok(())) => { log::trace!( "byte writer accepted chunk: stream_id={:?} target={:?}", self.stream_id, self.target ); - self.listener = None; + self.wait = None; self.poll_runtime(); Poll::Ready(Ok(())) } @@ -75,7 +75,7 @@ impl ByteWriter { self.target ); self.writer.take(); - self.listener = None; + self.wait = None; self.poll_terminal(cx) } Poll::Pending => Poll::Pending, @@ -97,7 +97,7 @@ impl ByteWriter { self.target ); writer.close(); - self.listener = None; + self.wait = None; self.poll_runtime(); } @@ -136,7 +136,7 @@ impl ByteWriter { stream_id, target, writer: Some(writer), - listener: None, + wait: None, terminal: WriteTerminalState::Armed(terminal), handle, } @@ -174,7 +174,7 @@ impl ByteWriter { self.target, code ); - self.listener = None; + self.wait = None; self.handle.try_send(RuntimeCommand::CloseStream { stream_id: self.stream_id, target: self.target, From 332cfd7f852fa785c872884a712266a04aff6117 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 19:15:30 -0400 Subject: [PATCH 234/304] ByteReader/ByteWriter -> StreamReader/StreamWriter --- ql-runtime/src/command.rs | 4 ++-- ql-runtime/src/driver/mod.rs | 8 ++++---- ql-runtime/src/handle/mod.rs | 6 +++--- ql-runtime/src/handle/reader.rs | 16 ++++++---------- ql-runtime/src/handle/writer.rs | 12 ++++++------ ql-runtime/src/rpc/adapter.rs | 20 ++++++++++---------- ql-runtime/src/rpc/mod.rs | 6 +++--- ql-runtime/src/rpc/request_with_progress.rs | 4 ++-- ql-runtime/src/rpc/subscription.rs | 4 ++-- ql-runtime/src/tests/mod.rs | 6 +++--- ql-runtime/src/tests/rpc.rs | 12 ++++++------ 11 files changed, 47 insertions(+), 51 deletions(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 95784845..c7ea9489 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,7 +1,7 @@ use ql_fsm::NoSessionError; use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, ByteReader, QlStreamError}; +use crate::{chunk_slot::ChunkSlotRx, StreamReader, QlStreamError}; pub enum RuntimeCommand { BindPeer { @@ -19,7 +19,7 @@ pub enum RuntimeCommand { route_id: RouteId, request_reader: ChunkSlotRx, request_terminal: oneshot::Sender>, - start: oneshot::Sender>, + start: oneshot::Sender>, }, PollInbound { stream_id: StreamId, diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index c4c48b1d..014a9fe1 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -22,7 +22,7 @@ use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, Ou use crate::{ chunk_slot, command::RuntimeCommand, - handle::{ByteReader, ByteWriter, QlStream}, + handle::{QlStream, StreamReader, StreamWriter}, log, platform::{QlPlatform, QlTimer}, QlStreamError, Runtime, RuntimeHandle, @@ -235,7 +235,7 @@ impl DriverState { Some(InboundIo::new(response_writer, response_terminal_tx)), ), ); - let reader = ByteReader::new( + let reader = StreamReader::new( stream_id, CloseTarget::Return, response_reader, @@ -386,14 +386,14 @@ impl DriverState { platform.handle_inbound(QlStream { stream_id, route_id, - reader: ByteReader::new( + reader: StreamReader::new( stream_id, CloseTarget::Origin, request_reader, request_terminal_rx, RuntimeHandle::new(runtime_tx.clone()), ), - writer: ByteWriter::new( + writer: StreamWriter::new( stream_id, CloseTarget::Return, response_writer, diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 49c7e820..113b5596 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -11,8 +11,8 @@ use crate::{chunk_slot, command::RuntimeCommand}; pub struct QlStream { pub stream_id: StreamId, pub route_id: RouteId, - pub writer: ByteWriter, - pub reader: ByteReader, + pub writer: StreamWriter, + pub reader: StreamReader, } #[derive(Clone)] @@ -70,7 +70,7 @@ impl RuntimeHandle { Ok(QlStream { stream_id, route_id, - writer: ByteWriter::new( + writer: StreamWriter::new( stream_id, CloseTarget::Origin, request_writer, diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index e0952588..467bc3d6 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -8,13 +8,9 @@ use bytes::Bytes; use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{ - chunk_slot::ChunkSlotRx, - command::RuntimeCommand, - log, QlStreamError, RuntimeHandle, -}; +use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, log, QlStreamError, RuntimeHandle}; -pub struct ByteReader { +pub struct StreamReader { stream_id: StreamId, target: CloseTarget, reader: Option, @@ -32,9 +28,9 @@ enum TerminalState { // Safety: `ByteReader` contains a `oneshot::Receiver`, which is `!Sync`, but that receiver is // fully encapsulated. No safe API accesses it through `&self`; all access requires `&mut self` // or ownership, so shared references cannot race the receiver state across threads. -unsafe impl Sync for ByteReader {} +unsafe impl Sync for StreamReader {} -impl std::fmt::Debug for ByteReader { +impl std::fmt::Debug for StreamReader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("InboundByteStream") .field("stream_id", &self.stream_id) @@ -47,7 +43,7 @@ impl std::fmt::Debug for ByteReader { } } -impl ByteReader { +impl StreamReader { pub(crate) fn new( stream_id: StreamId, target: CloseTarget, @@ -177,7 +173,7 @@ impl ByteReader { } } -impl Drop for ByteReader { +impl Drop for StreamReader { fn drop(&mut self) { if matches!(self.terminal, TerminalState::Delivered) { return; diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 66addedc..2d2700e6 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -14,7 +14,7 @@ use crate::{ log, QlStreamError, RuntimeHandle, }; -pub struct ByteWriter { +pub struct StreamWriter { stream_id: StreamId, target: CloseTarget, writer: Option, @@ -31,9 +31,9 @@ enum WriteTerminalState { // Safety: `ByteWriter` contains a `oneshot::Receiver`, which is `!Sync`, but that receiver is // fully encapsulated. No safe API accesses it through `&self`; all access requires `&mut self` // or ownership, so shared references cannot race the receiver state across threads. -unsafe impl Sync for ByteWriter {} +unsafe impl Sync for StreamWriter {} -impl std::fmt::Debug for ByteWriter { +impl std::fmt::Debug for StreamWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OutboundByteStream") .field("stream_id", &self.stream_id) @@ -43,7 +43,7 @@ impl std::fmt::Debug for ByteWriter { } } -impl ByteWriter { +impl StreamWriter { pub fn poll_write( &mut self, bytes: &mut Bytes, @@ -118,13 +118,13 @@ impl ByteWriter { } } -impl Drop for ByteWriter { +impl Drop for StreamWriter { fn drop(&mut self) { self.close_inner(StreamCloseCode::CANCELLED); } } -impl ByteWriter { +impl StreamWriter { pub(crate) fn new( stream_id: StreamId, target: CloseTarget, diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs index 45976ec1..b7f3c0da 100644 --- a/ql-runtime/src/rpc/adapter.rs +++ b/ql-runtime/src/rpc/adapter.rs @@ -8,12 +8,12 @@ pub use ql_rpc::{ use ql_rpc::{RpcRead, RpcStream, RpcWrite, StreamError}; use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; -use crate::{ByteReader, ByteWriter, QlStream, QlStreamError}; +use crate::{QlStream, QlStreamError, StreamReader, StreamWriter}; impl RpcStream for QlStream { type Error = QlStreamError; - type Reader = ByteReader; - type Writer = ByteWriter; + type Reader = StreamReader; + type Writer = StreamWriter; fn route_id(&self) -> Option { let route_id = u32::try_from(self.route_id.into_inner()).ok()?; @@ -25,7 +25,7 @@ impl RpcStream for QlStream { } } -impl RpcRead for ByteReader { +impl RpcRead for StreamReader { type Error = QlStreamError; fn poll_read( @@ -33,15 +33,15 @@ impl RpcRead for ByteReader { max_len: usize, cx: &mut Context<'_>, ) -> Poll, QlStreamError>> { - ByteReader::poll_read(self, max_len, cx) + StreamReader::poll_read(self, max_len, cx) } fn close(self, code: StreamCloseCode) { - ByteReader::close(self, to_wire_close_code(code)); + StreamReader::close(self, to_wire_close_code(code)); } } -impl RpcWrite for ByteWriter { +impl RpcWrite for StreamWriter { type Error = QlStreamError; fn poll_write( @@ -49,15 +49,15 @@ impl RpcWrite for ByteWriter { bytes: &mut Bytes, cx: &mut Context<'_>, ) -> Poll> { - ByteWriter::poll_write(self, bytes, cx) + StreamWriter::poll_write(self, bytes, cx) } fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { - ByteWriter::poll_finish(self, cx) + StreamWriter::poll_finish(self, cx) } fn close(self, code: StreamCloseCode) { - ByteWriter::close(self, to_wire_close_code(code)); + StreamWriter::close(self, to_wire_close_code(code)); } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 98c1a125..81ee69a0 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -15,7 +15,7 @@ use ql_rpc::{ }; pub use self::{adapter::*, error::*, request_with_progress::*, subscription::*}; -use crate::{ByteReader, RuntimeHandle}; +use crate::{StreamReader, RuntimeHandle}; #[derive(Clone)] pub struct RpcHandle { @@ -86,7 +86,7 @@ impl RpcHandle { &self, route_id: ql_rpc::RouteId, payload: Vec, - ) -> Result> { + ) -> Result> { let mut stream = self .inner .open_stream(adapter::to_wire_route_id(route_id)) @@ -97,7 +97,7 @@ impl RpcHandle { } } -async fn read_value(mut reader: ByteReader) -> Result> +async fn read_value(mut reader: StreamReader) -> Result> where T: RpcCodec, { diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs index fe2e78e7..94944af2 100644 --- a/ql-runtime/src/rpc/request_with_progress.rs +++ b/ql-runtime/src/rpc/request_with_progress.rs @@ -11,10 +11,10 @@ use ql_rpc::{ }; use super::RpcError; -use crate::ByteReader; +use crate::StreamReader; pub struct ProgressCall { - pub(super) stream: ByteReader, + pub(super) stream: StreamReader, pub(super) reader: Option>, pub(super) terminal: Option>>, } diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index dc74afa0..00792b21 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -10,10 +10,10 @@ use ql_rpc::{ }; use super::RpcError; -use crate::ByteReader; +use crate::StreamReader; pub struct Subscription { - pub(super) stream: ByteReader, + pub(super) stream: StreamReader, pub(super) reader: Option>, } diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 8c802d68..75d91c31 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -588,7 +588,7 @@ async fn assert_no_status_for( assert!(res.is_err(), "unexpected status event: {status:?}"); } -async fn read_all(mut stream: crate::ByteReader) -> Result, QlStreamError> { +async fn read_all(mut stream: crate::StreamReader) -> Result, QlStreamError> { let mut data = Vec::new(); while let Some(chunk) = next_chunk(&mut stream).await? { data.extend_from_slice(&chunk); @@ -597,7 +597,7 @@ async fn read_all(mut stream: crate::ByteReader) -> Result, QlStreamErro } async fn next_chunk_max( - stream: &mut crate::ByteReader, + stream: &mut crate::StreamReader, max_len: usize, ) -> Result>, crate::QlStreamError> { stream @@ -606,7 +606,7 @@ async fn next_chunk_max( .map(|chunk| chunk.map(|bytes| bytes.to_vec())) } -async fn next_chunk(stream: &mut crate::ByteReader) -> Result>, QlStreamError> { +async fn next_chunk(stream: &mut crate::StreamReader) -> Result>, QlStreamError> { next_chunk_max(stream, usize::MAX).await } diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 6482a4ef..f518f7f2 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -12,7 +12,7 @@ use ql_rpc::{Response, RouteId, StreamCloseCode, SubscriptionResponder}; use ql_wire::RouteId as WireRouteId; use super::*; -use crate::{ByteWriter, QlStream}; +use crate::{QlStream, StreamWriter}; struct Echo; @@ -91,7 +91,7 @@ async fn rpc_router_handles_request() { } impl crate::rpc::RequestHandler for RouterState { - fn handle(self, request: String, response: Response) { + fn handle(self, request: String, response: Response) { let seen = self.seen.clone(); tokio::task::spawn_local(async move { seen.borrow_mut().push(request); @@ -142,7 +142,7 @@ async fn rpc_router_handles_subscription() { fn handle( self, request: Vec, - mut response: SubscriptionResponder, ByteWriter>, + mut response: SubscriptionResponder, StreamWriter>, ) { let seen = self.seen.clone(); tokio::task::spawn_local(async move { @@ -195,7 +195,7 @@ async fn rpc_send_router_handles_request() { } impl crate::rpc::RequestHandler for RouterState { - fn handle(self, request: String, response: crate::rpc::Response) { + fn handle(self, request: String, response: crate::rpc::Response) { let seen = self.seen.clone(); tokio::task::spawn(async move { seen.lock().unwrap().push(request); @@ -244,7 +244,7 @@ async fn rpc_router_enforces_max_request_bytes() { fn handle( self, request: String, - response: crate::rpc::Response, + response: crate::rpc::Response, ) { tokio::task::spawn_local(async move { let _ = response.respond(request).await; @@ -379,7 +379,7 @@ async fn rpc_request_with_progress_supports_progress_then_await() { .await; } -async fn read_rpc_value(mut reader: crate::ByteReader) -> T +async fn read_rpc_value(mut reader: crate::StreamReader) -> T where T: ql_rpc::RpcCodec, T::Error: std::fmt::Debug, From 73d0895970f29e9383346232b607070023fdaa3a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 19:20:11 -0400 Subject: [PATCH 235/304] ql-runtime: remove load_peer hook --- ql-runtime/src/driver/mod.rs | 3 --- ql-runtime/src/driver/test.rs | 5 ----- ql-runtime/src/platform.rs | 1 - ql-runtime/src/tests/mod.rs | 4 ---- 4 files changed, 13 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 014a9fe1..e60ac357 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -40,9 +40,6 @@ impl Runtime

{ } = self; let mut fsm = QlFsm::new(config.fsm, identity, now()); - if let Some(peer) = platform.load_peer().await { - fsm.bind_peer(peer); - } let mut state = DriverState { streams: HashMap::new(), diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index eefeaa48..b272e56f 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -6,7 +6,6 @@ use super::*; use crate::{ chunk_slot, driver::state::{InboundIo, OutboundIo}, - platform::PlatformFuture, }; pub struct NoopTimer; @@ -31,10 +30,6 @@ impl QlPlatform for NoopCrypto { NoopTimer } - fn load_peer(&self) -> PlatformFuture<'_, Option> { - Box::pin(async { None }) - } - fn persist_peer(&self, _peer: PeerBundle) {} fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 886504d9..e5946534 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -26,7 +26,6 @@ pub trait QlPlatform: QlCrypto { fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_>; fn timer(&self) -> Self::Timer; - fn load_peer(&self) -> PlatformFuture<'_, Option>; fn persist_peer(&self, peer: PeerBundle); fn handle_peer_status(&self, peer: XID, status: PeerStatus); diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 75d91c31..8371503a 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -393,10 +393,6 @@ impl crate::platform::QlPlatform for TestPlatform { TokioTimer::new() } - fn load_peer(&self) -> PlatformFuture<'_, Option> { - Box::pin(async { None }) - } - fn persist_peer(&self, _peer: PeerBundle) {} fn handle_peer_status(&self, peer: XID, status: PeerStatus) { From 0d9670ca1369c21f6de5730cd7375e2b14281c4c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 15 Apr 2026 20:07:55 -0400 Subject: [PATCH 236/304] ql-runtime: tests --- ql-runtime/src/tests/mod.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 8371503a..8cc7a449 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -538,8 +538,7 @@ async fn run_local_test(future: F) where F: Future, { - let local = LocalSet::new(); - local.run_until(future).await; + run_local_test_timeout(Duration::from_secs(5), future).await } #[allow(clippy::future_not_send)] @@ -547,7 +546,9 @@ async fn run_local_test_timeout(duration: Duration, future: F) where F: Future, { - tokio::time::timeout(duration, run_local_test(future)) + let local = LocalSet::new(); + let future = local.run_until(future); + tokio::time::timeout(duration, future) .await .unwrap_or_else(|_| panic!("local runtime test exceeded {duration:?}")); } From f7265743e715321069229b174cce157c86ceae68 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 07:48:03 -0400 Subject: [PATCH 237/304] ql-runtime: poll stream write loop --- ql-runtime/src/driver/mod.rs | 25 +++++++++++++-------- ql-runtime/src/tests/stream.rs | 41 ++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index e60ac357..6067614f 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -593,16 +593,23 @@ impl DriverState { return; }; - let capacity = writer.capacity(); - log::trace!("stream write capacity: stream_id={stream_id} capacity={capacity}"); - if capacity > 0 { - if let Ok(mut bytes) = reader.try_recv(capacity) { - if !bytes.is_empty() { - let _len = bytes.len(); - log::trace!("writing stream bytes: stream_id={stream_id} len={_len}"); - let _ = writer.write(&mut bytes); - } + loop { + let capacity = writer.capacity(); + log::trace!("stream write capacity: stream_id={stream_id} capacity={capacity}"); + if capacity == 0 { + break; } + + let Ok(mut bytes) = reader.try_recv(capacity) else { + break; + }; + if bytes.is_empty() { + break; + } + + let _len = bytes.len(); + log::trace!("writing stream bytes: stream_id={stream_id} len={_len}"); + let _ = writer.write(&mut bytes); } if reader.is_finished() { diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index b3ea0779..f464a90c 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -623,3 +623,44 @@ async fn reproducer_writer_stalls_after_reverse_path_impairment() { }) .await; } + +#[tokio::test(flavor = "current_thread")] +async fn responder_drains_multiple_local_chunks_per_writable_wake() { + run_local_test(async { + let chunk_len = 4104usize; + let chunk_count = 5usize; + let expected = vec![0x5a; chunk_len * chunk_count]; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let _ = read_all(inbound.reader).await.unwrap(); + + let mut writer = inbound.writer; + for _ in 0..chunk_count { + writer.write(Bytes::from(vec![0x5a; chunk_len])).await.unwrap(); + } + writer.finish().await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream.writer.write(Bytes::from_static(b"request")).await.unwrap(); + stream.writer.finish().await.unwrap(); + + let received = read_all(stream.reader).await.unwrap(); + assert_eq!(received, expected); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} From 620216e62a4d2de783c902363704331792b379d8 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 08:49:15 -0400 Subject: [PATCH 238/304] ql-fsm: stream finished event --- ql-fsm/src/fsm.rs | 118 ++++++++++++++------------ ql-fsm/src/lib.rs | 40 ++++++++- ql-fsm/src/session/mod.rs | 97 +++++++++++++-------- ql-fsm/src/session/stream_ops.rs | 38 ++++++--- ql-fsm/src/session/tests.rs | 141 ++++++++++++++++++++----------- 5 files changed, 288 insertions(+), 146 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 5c2874e6..9114a19b 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -7,10 +7,53 @@ use bytes::Bytes; use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; use crate::{ - handshake, session::SessionEvent, state::LinkState, Event, NoPeerError, NoSessionError, - OutboundWrite, QlFsm, ReceiveError, StreamError, StreamOps, WriteId, + handshake, + session::{EventSink, SessionEvent}, + state::LinkState, + Event, NoPeerError, NoSessionError, OutboundWrite, QlFsm, ReceiveError, StreamError, WriteId, }; +pub struct FsmEventEmitter<'a> { + events: &'a mut VecDeque, +} + +impl EventSink for FsmEventEmitter<'_> { + fn emit(&mut self, event: SessionEvent) { + match event { + SessionEvent::Opened { + stream_id, + route_id, + } => { + self.events.push_back(Event::Opened { + stream_id, + route_id, + }); + } + SessionEvent::Readable(stream_id) => { + self.events.push_back(Event::Readable(stream_id)); + } + SessionEvent::Writable(stream_id) => { + self.events.push_back(Event::Writable(stream_id)); + } + SessionEvent::Finished(stream_id) => { + self.events.push_back(Event::Finished(stream_id)); + } + SessionEvent::OutboundFinished(stream_id) => { + self.events.push_back(Event::OutboundFinished(stream_id)); + } + SessionEvent::Closed(frame) => { + self.events.push_back(Event::Closed(frame)); + } + SessionEvent::WritableClosed(frame) => { + self.events.push_back(Event::WritableClosed(frame)); + } + SessionEvent::SessionClosed(close) => { + self.events.push_back(Event::SessionClosed(close)); + } + } + } +} + pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { fsm.state.handshake = None; fsm.state.link = LinkState::Idle; @@ -72,10 +115,9 @@ pub fn receive( let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); let frames = wire::parse_session_frames(plaintext); + let mut emit = FsmEventEmitter { events }; conn.session - .receive(state.now.instant, seq, frames, |event| { - forward_session_event(event, events); - }); + .receive(state.now.instant, seq, frames, &mut emit); if conn.session.is_closed() { apply_session_closed(fsm); @@ -93,9 +135,8 @@ pub fn on_timer(fsm: &mut QlFsm) { return; }; - conn.session.on_timer(state.now.instant, |event| { - forward_session_event(event, events); - }); + let mut emit = FsmEventEmitter { events }; + conn.session.on_timer(state.now.instant, &mut emit); if conn.session.is_closed() { apply_session_closed(fsm); @@ -155,19 +196,27 @@ pub fn close_session(fsm: &mut QlFsm, code: SessionCloseCode) { let Some(conn) = state.link.connected_mut() else { return; }; - conn.session.close(code, |event| { - forward_session_event(event, events); - }); + let mut emit = FsmEventEmitter { events }; + conn.session.close(code, &mut emit); } -pub fn open_stream(fsm: &mut QlFsm, route_id: RouteId) -> Result, NoSessionError> { - let conn = fsm.state.link.connected_mut_or_err()?; - conn.session.open_stream(route_id) +pub fn open_stream( + fsm: &mut QlFsm, + route_id: RouteId, +) -> Result, NoSessionError> { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let inner = conn + .session + .open_stream(route_id, FsmEventEmitter { events })?; + Ok(crate::StreamOps { inner }) } -pub fn stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { - let conn = fsm.state.link.connected_mut_or_err()?; - conn.session.stream(stream_id) +pub fn stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let inner = conn.session.stream(stream_id, FsmEventEmitter { events })?; + Ok(crate::StreamOps { inner }) } pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { @@ -186,41 +235,6 @@ pub fn emit_peer_status(fsm: &mut QlFsm) { } } -fn forward_session_event(event: SessionEvent, events: &mut VecDeque) { - match event { - SessionEvent::Opened { - stream_id, - route_id, - } => { - events.push_back(Event::Opened { - stream_id, - route_id, - }); - } - SessionEvent::Readable(stream_id) => { - events.push_back(Event::Readable(stream_id)); - } - SessionEvent::Writable(stream_id) => { - events.push_back(Event::Writable(stream_id)); - } - SessionEvent::Finished(stream_id) => { - events.push_back(Event::Finished(stream_id)); - } - SessionEvent::OutboundFinished(stream_id) => { - events.push_back(Event::OutboundFinished(stream_id)); - } - SessionEvent::Closed(frame) => { - events.push_back(Event::Closed(frame)); - } - SessionEvent::WritableClosed(frame) => { - events.push_back(Event::WritableClosed(frame)); - } - SessionEvent::SessionClosed(close) => { - events.push_back(Event::SessionClosed(close)); - } - } -} - fn apply_session_closed(fsm: &mut QlFsm) { if matches!(fsm.state.link, LinkState::Connected(_)) { fsm.state.link = LinkState::Idle; diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 5a0e9eaa..b8ac956e 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -38,7 +38,7 @@ use ql_wire::{ PairingToken, PeerBundle, QlCrypto, QlIdentity, RouteId, SessionClose, SessionCloseCode, StreamClose, StreamId, }; -pub use session::{StreamOps, StreamReadIter, StreamWriter}; +pub use session::{SessionEvent, StreamReadIter, StreamWriter}; use crate::{ replay_cache::ReplayCache, @@ -81,7 +81,7 @@ pub enum Event { Readable(StreamId), /// a stream has room for more local writes Writable(StreamId), - /// the peer finished writing this stream + /// the peer finished writing this stream and no more bytes remain to read Finished(StreamId), /// our local FIN was acknowledged by the peer at the session layer OutboundFinished(StreamId), @@ -110,6 +110,42 @@ pub struct OutboundWrite { pub write_id: Option, } +pub struct StreamOps<'a> { + inner: session::StreamOps<'a, fsm::FsmEventEmitter<'a>>, +} + +impl<'a> StreamOps<'a> { + /// returns this stream's identifier + pub fn stream_id(&self) -> StreamId { + self.inner.stream_id() + } + + /// returns the readable stream bytes as owned `Bytes` views without consuming them + pub fn read(&self) -> StreamReadIter<'_> { + self.inner.read() + } + + /// returns how many bytes can be read from the stream + pub fn readable_bytes(&self) -> usize { + self.inner.readable_bytes() + } + + /// marks previously read bytes as consumed + pub fn commit_read(&mut self, len: usize) -> Result<(), CommitReadError> { + self.inner.commit_read(len) + } + + /// returns a writer if the local write side is still open + pub fn writer(&mut self) -> Option> { + self.inner.writer() + } + + /// closes the origin lane, return lane, or both lanes of the stream + pub fn close(&mut self, target: ql_wire::CloseTarget, code: ql_wire::StreamCloseCode) { + self.inner.close(target, code); + } +} + /// timing and buffering knobs for `QlFsm` #[derive(Debug, Clone, Copy)] pub struct QlFsmConfig { diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index ef9d43a9..3978bbb7 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -80,6 +80,19 @@ pub enum SessionEvent { SessionClosed(SessionClose), } +pub(crate) trait EventSink { + fn emit(&mut self, event: SessionEvent); +} + +impl EventSink for F +where + F: FnMut(SessionEvent), +{ + fn emit(&mut self, event: SessionEvent) { + self(event); + } +} + pub struct SessionFsm { config: SessionConfig, state: SessionState, @@ -116,7 +129,14 @@ impl SessionFsm { } } - pub fn open_stream(&mut self, route_id: RouteId) -> Result, NoSessionError> { + pub fn open_stream( + &mut self, + route_id: RouteId, + sink: E, + ) -> Result, NoSessionError> + where + E: EventSink, + { self.ensure_session_open()?; let stream_id = self .config @@ -133,16 +153,23 @@ impl SessionFsm { ), ); let stream_index = self.state.streams.len() - 1; - Ok(StreamOps::new(self, stream_id, stream_index)) + Ok(StreamOps::new(self, stream_id, stream_index, sink)) } - pub fn stream(&mut self, stream_id: StreamId) -> Result, StreamError> { + pub fn stream( + &mut self, + stream_id: StreamId, + sink: E, + ) -> Result, StreamError> + where + E: EventSink, + { self.ensure_session_open()?; let Some(stream_index) = self.state.streams.get_index_of(&stream_id) else { return Err(StreamError::MissingStream); }; - Ok(StreamOps::new(self, stream_id, stream_index)) + Ok(StreamOps::new(self, stream_id, stream_index, sink)) } pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { @@ -151,7 +178,7 @@ impl SessionFsm { Ok(()) } - pub(crate) fn close(&mut self, code: SessionCloseCode, mut emit: impl FnMut(SessionEvent)) { + pub(crate) fn close(&mut self, code: SessionCloseCode, sink: &mut impl EventSink) { if self.state.phase != SessionPhase::Open { return; } @@ -161,7 +188,7 @@ impl SessionFsm { self.state.tracked_records.clear(); self.state.ack_tracker.clear_ack_state(); self.clear_streams(); - emit(SessionEvent::SessionClosed(close)); + sink.emit(SessionEvent::SessionClosed(close)); } pub(crate) fn is_closed(&self) -> bool { @@ -173,7 +200,7 @@ impl SessionFsm { now: Instant, seq: RecordSeq, frames: I, - mut emit: impl FnMut(SessionEvent), + sink: &mut impl EventSink, ) where I: IntoIterator, WireError>>, { @@ -200,28 +227,28 @@ impl SessionFsm { for frame in frames { let Ok(frame) = frame else { - self.close(SessionCloseCode::PROTOCOL, &mut emit); + self.close(SessionCloseCode::PROTOCOL, sink); return; }; ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); match frame { SessionFrame::Ping => {} - SessionFrame::Ack(ack) => self.process_record_ack(&ack, &mut emit), + SessionFrame::Ack(ack) => self.process_record_ack(&ack, sink), SessionFrame::StreamData(frame) => { - if self.handle_stream_data(frame, &mut emit).is_err() { - self.close(SessionCloseCode::PROTOCOL, &mut emit); + if self.handle_stream_data(frame, sink).is_err() { + self.close(SessionCloseCode::PROTOCOL, sink); return; } } - SessionFrame::StreamWindow(frame) => self.handle_stream_window(&frame, &mut emit), + SessionFrame::StreamWindow(frame) => self.handle_stream_window(&frame, sink), SessionFrame::StreamClose(frame) => { - if self.handle_stream_close(&frame, &mut emit).is_err() { - self.close(SessionCloseCode::PROTOCOL, &mut emit); + if self.handle_stream_close(&frame, sink).is_err() { + self.close(SessionCloseCode::PROTOCOL, sink); return; } } SessionFrame::Close(close) => { - self.close(close.code, &mut emit); + self.close(close.code, sink); handled_close = true; break; } @@ -272,7 +299,7 @@ impl SessionFsm { } } - pub fn on_timer(&mut self, now: Instant, mut emit: impl FnMut(SessionEvent)) { + pub fn on_timer(&mut self, now: Instant, sink: &mut impl EventSink) { if !self.state.phase.is_open() { return; } @@ -280,7 +307,7 @@ impl SessionFsm { if !self.config.peer_timeout.is_zero() && self.state.last_inbound_at + self.config.peer_timeout <= now { - self.close(SessionCloseCode::TIMEOUT, &mut emit); + self.close(SessionCloseCode::TIMEOUT, sink); return; } if self.state.phase == SessionPhase::Open @@ -525,7 +552,7 @@ impl SessionFsm { } } - fn process_record_ack(&mut self, ack: &RecordAck, emit: &mut impl FnMut(SessionEvent)) { + fn process_record_ack(&mut self, ack: &RecordAck, sink: &mut impl EventSink) { let stream_send_buffer_size = self.config.stream_send_buffer_size; let acked_records = self .state @@ -542,7 +569,7 @@ impl SessionFsm { &mut self.state.streams, stream_send_buffer_size, frame, - emit, + sink, ); } } @@ -582,7 +609,7 @@ impl SessionFsm { fn handle_stream_data( &mut self, frame: StreamData, - emit: &mut impl FnMut(SessionEvent), + sink: &mut impl EventSink, ) -> Result<(), ()> { let StreamData { stream_id, @@ -631,14 +658,15 @@ impl SessionFsm { // within the finalized byte range and any repeated FIN lands on that same offset. if (!frame.fin || frame_end == final_offset) && frame_end <= final_offset { if let Some(route_id) = opened_route { - emit(SessionEvent::Opened { + sink.emit(SessionEvent::Opened { stream_id, route_id, }); if readable_before > 0 { - emit(SessionEvent::Readable(stream_id)); + sink.emit(SessionEvent::Readable(stream_id)); + } else { + sink.emit(SessionEvent::Finished(stream_id)); } - emit(SessionEvent::Finished(stream_id)); } return Ok(()); } @@ -654,28 +682,29 @@ impl SessionFsm { } if let Some(route_id) = opened_route { - emit(SessionEvent::Opened { + sink.emit(SessionEvent::Opened { stream_id, route_id, }); } if stream.route_id.is_some() && readable_before == 0 && stream.readable_bytes() > 0 { - emit(SessionEvent::Readable(stream_id)); + sink.emit(SessionEvent::Readable(stream_id)); } if stream.route_id.is_some() && !was_finished && matches!(stream.inbound_state, InboundState::Finished) + && stream.readable_bytes() == 0 { - emit(SessionEvent::Finished(stream_id)); + sink.emit(SessionEvent::Finished(stream_id)); } self.try_reap_stream(stream_id); Ok(()) } - fn handle_stream_window(&mut self, frame: &StreamWindow, emit: &mut impl FnMut(SessionEvent)) { + fn handle_stream_window(&mut self, frame: &StreamWindow, sink: &mut impl EventSink) { let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { return; }; @@ -686,14 +715,14 @@ impl SessionFsm { stream.peer_max_offset = maximum_offset; } if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { - emit(SessionEvent::Writable(frame.stream_id)); + sink.emit(SessionEvent::Writable(frame.stream_id)); } } fn handle_stream_close( &mut self, frame: &StreamClose, - emit: &mut impl FnMut(SessionEvent), + sink: &mut impl EventSink, ) -> Result<(), ()> { let stream_id = frame.stream_id; let stream = match self.state.streams.get_mut(&stream_id) { @@ -712,7 +741,7 @@ impl SessionFsm { { stream.inbound_state = InboundState::Closed(frame.clone()); stream.reset_recv(); - emit(SessionEvent::Closed(frame.clone())); + sink.emit(SessionEvent::Closed(frame.clone())); } if Self::target_affects_outbound(stream.role, frame.target) && !matches!(stream.outbound_state, OutboundState::Closed) @@ -720,7 +749,7 @@ impl SessionFsm { stream.outbound_state = OutboundState::Closed; stream.tx.clear(); stream.pending_close = None; - emit(SessionEvent::WritableClosed(frame.clone())); + sink.emit(SessionEvent::WritableClosed(frame.clone())); } self.try_reap_stream(frame.stream_id); Ok(()) @@ -955,7 +984,7 @@ fn acknowledge_tracked_frame( streams: &mut IndexMap, stream_send_buffer_size: usize, frame: &TrackedFrame, - emit: &mut impl FnMut(SessionEvent), + sink: &mut impl EventSink, ) { match frame { TrackedFrame::StreamClose(_) => {} @@ -970,10 +999,10 @@ fn acknowledge_tracked_frame( fin: frame.fin, }); if was_full && stream.send_capacity(stream_send_buffer_size) > 0 { - emit(SessionEvent::Writable(stream_id)); + sink.emit(SessionEvent::Writable(stream_id)); } if had_unacked_fin && !stream.tx.has_unacked_fin() { - emit(SessionEvent::OutboundFinished(stream_id)); + sink.emit(SessionEvent::OutboundFinished(stream_id)); } } } diff --git a/ql-fsm/src/session/stream_ops.rs b/ql-fsm/src/session/stream_ops.rs index 8e4f1ff7..30bd7a7c 100644 --- a/ql-fsm/src/session/stream_ops.rs +++ b/ql-fsm/src/session/stream_ops.rs @@ -1,23 +1,30 @@ use ql_wire::{CloseTarget, StreamClose, StreamCloseCode, StreamId}; -use super::{state::StreamState, stream_rx::StreamReadIter, SessionFsm}; +use super::{ + state::{InboundState, StreamState}, + stream_rx::StreamReadIter, + SessionEvent, EventSink, SessionFsm, +}; use crate::CommitReadError; -pub struct StreamOps<'a> { +pub struct StreamOps<'a, E> { session: &'a mut SessionFsm, + emit: E, stream_id: StreamId, stream_index: usize, reap_on_drop: bool, } -impl<'a> StreamOps<'a> { +impl<'a, E: EventSink> StreamOps<'a, E> { pub(super) fn new( session: &'a mut SessionFsm, stream_id: StreamId, stream_index: usize, + emit: E, ) -> Self { Self { session, + emit, stream_id, stream_index, reap_on_drop: false, @@ -41,13 +48,22 @@ impl<'a> StreamOps<'a> { /// marks previously read bytes as consumed pub fn commit_read(&mut self, len: usize) -> Result<(), CommitReadError> { - let stream = self.stream_mut(); - if len > stream.readable_bytes() { - return Err(CommitReadError); - } - stream.rx.consume(len); - if stream.recv_limit() > stream.advertised_max_offset { - stream.pending_window = true; + let stream_id = self.stream_id; + let emit_finished = { + let stream = self.stream_mut(); + if len > stream.readable_bytes() { + return Err(CommitReadError); + } + stream.rx.consume(len); + if stream.recv_limit() > stream.advertised_max_offset { + stream.pending_window = true; + } + stream.route_id.is_some() + && matches!(stream.inbound_state, InboundState::Finished) + && stream.readable_bytes() == 0 + }; + if emit_finished { + self.emit.emit(SessionEvent::Finished(stream_id)); } self.reap_on_drop = true; Ok(()) @@ -85,7 +101,7 @@ impl<'a> StreamOps<'a> { } } -impl Drop for StreamOps<'_> { +impl Drop for StreamOps<'_, E> { fn drop(&mut self) { if !self.reap_on_drop { return; diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 3168399e..681523b5 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -47,18 +47,31 @@ fn opened(stream_id: StreamId) -> SessionEvent { } fn open_stream_id(fsm: &mut SessionFsm) -> StreamId { - fsm.open_stream(route_id(1)).unwrap().stream_id() + fsm.open_stream(route_id(1), |_| {}).unwrap().stream_id() } fn write_stream_bytes(fsm: &mut SessionFsm, stream_id: StreamId, bytes: &[u8]) -> usize { let mut bytes = Bytes::copy_from_slice(bytes); - let mut stream = fsm.stream(stream_id).unwrap(); + let mut stream = fsm.stream(stream_id, |_| {}).unwrap(); let mut writer = stream.writer().unwrap(); writer.write(&mut bytes) } fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { - let mut stream = fsm.stream(stream_id).unwrap(); + let mut stream = fsm.stream(stream_id, |_| {}).unwrap(); + let out = stream.read().flatten().collect::>(); + stream.commit_read(out.len()).unwrap(); + out +} + +fn read_stream_all_with_events( + fsm: &mut SessionFsm, + stream_id: StreamId, + events: &mut Vec, +) -> Vec { + let mut stream = fsm + .stream(stream_id, |event| events.push(event)) + .unwrap(); let out = stream.read().flatten().collect::>(); stream.commit_read(out.len()).unwrap(); out @@ -107,7 +120,8 @@ fn receive_events( let bytes = Bytes::from(builder.bytes().to_vec()); let frames = parse_session_frames(bytes); let mut events = Vec::new(); - fsm.receive(now, seq, frames, |event| events.push(event)); + let mut emit = |event| events.push(event); + fsm.receive(now, seq, frames, &mut emit); events } @@ -136,7 +150,8 @@ fn retransmit_uses_new_record_seq() { assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"retry"), 5); let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); - fsm.on_timer(now + Duration::from_millis(200), |_| {}); + let mut emit = |_| {}; + fsm.on_timer(now + Duration::from_millis(200), &mut emit); let (retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); assert_ne!(first_seq, retried_seq); @@ -197,11 +212,12 @@ fn ack_reopens_write_capacity() { let (record_seq, _record) = next_outbound(&mut fsm, now).unwrap(); let mut events = Vec::new(); + let mut emit = |event| events.push(event); fsm.receive( now + Duration::from_millis(1), seq(9), std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), - |event| events.push(event), + &mut emit, ); assert!(events.contains(&SessionEvent::Writable(stream_id))); @@ -215,7 +231,7 @@ fn ack_of_fin_emits_outbound_finished_once() { let stream_id = open_stream_id(&mut fsm); assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"done"), 4); - fsm.stream(stream_id).unwrap().writer().unwrap().finish(); + fsm.stream(stream_id, |_| {}).unwrap().writer().unwrap().finish(); let (record_seq, record) = next_outbound(&mut fsm, now).unwrap(); assert!(matches!( @@ -228,20 +244,26 @@ fn ack_of_fin_emits_outbound_finished_once() { )); let mut events = Vec::new(); - fsm.receive( - now + Duration::from_millis(1), - seq(9), - std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), - |event| events.push(event), - ); + { + let mut emit = |event| events.push(event); + fsm.receive( + now + Duration::from_millis(1), + seq(9), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + &mut emit, + ); + } assert_eq!(events, vec![SessionEvent::OutboundFinished(stream_id)]); - fsm.receive( - now + Duration::from_millis(2), - seq(10), - std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), - |event| events.push(event), - ); + { + let mut emit = |event| events.push(event); + fsm.receive( + now + Duration::from_millis(2), + seq(10), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + &mut emit, + ); + } assert_eq!(events, vec![SessionEvent::OutboundFinished(stream_id)]); } @@ -276,7 +298,7 @@ fn commit_stream_read_is_what_advances_stream_window() { assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); let read = fsm - .stream(stream_id) + .stream(stream_id, |_| {}) .unwrap() .read() .map(|chunk| chunk.len()) @@ -285,7 +307,7 @@ fn commit_stream_read_is_what_advances_stream_window() { assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); - fsm.stream(stream_id).unwrap().commit_read(2).unwrap(); + fsm.stream(stream_id, |_| {}).unwrap().commit_read(2).unwrap(); let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); assert!(matches!( second.as_slice(), @@ -318,7 +340,8 @@ fn pure_ack_only_records_are_fire_and_forget() { assert!(write_id.is_none()); assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); - fsm.on_timer(now + retransmit_timeout + Duration::from_millis(1), |_| {}); + let mut emit = |_| {}; + fsm.on_timer(now + retransmit_timeout + Duration::from_millis(1), &mut emit); assert!(fsm .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) .is_none()); @@ -340,13 +363,31 @@ fn inbound_stream_data_emits_opened_and_readable() { let events = receive_events(&mut fsm, now, seq(0), &record); assert_eq!( events, - vec![ - opened(stream_id), - SessionEvent::Readable(stream_id), - SessionEvent::Finished(stream_id) - ] + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() ); - assert_eq!(read_stream_all(&mut fsm, stream_id), b"hello".to_vec()); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); +} + +#[test] +fn inbound_empty_fin_emits_finished_immediately() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: header(1), + fin: true, + bytes: Vec::new(), + })]; + + let events = receive_events(&mut fsm, now, seq(0), &record); + assert_eq!(events, vec![opened(stream_id), SessionEvent::Finished(stream_id)]); } #[test] @@ -355,7 +396,7 @@ fn remote_stream_close_is_reliable_and_retried() { let mut fsm = SessionFsm::new(SessionConfig::default(), now); let stream_id = open_stream_id(&mut fsm); - fsm.stream(stream_id) + fsm.stream(stream_id, |_| {}) .unwrap() .close(CloseTarget::Both, StreamCloseCode::CANCELLED); @@ -367,7 +408,8 @@ fn remote_stream_close_is_reliable_and_retried() { [SessionFrame::StreamClose(StreamClose { stream_id: id, .. })] if *id == stream_id )); - fsm.on_timer(now + Duration::from_millis(200), |_| {}); + let mut emit = |_| {}; + fsm.on_timer(now + Duration::from_millis(200), &mut emit); let (_retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); assert_eq!(first, retried); @@ -386,7 +428,7 @@ fn stream_ids_follow_even_odd_xid_ordering() { }, now, ) - .open_stream(route_id(1)) + .open_stream(route_id(1), |_| {}) .unwrap() .stream_id(); let odd_id = SessionFsm::new( @@ -396,7 +438,7 @@ fn stream_ids_follow_even_odd_xid_ordering() { }, now, ) - .open_stream(route_id(1)) + .open_stream(route_id(1), |_| {}) .unwrap() .stream_id(); @@ -501,13 +543,14 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { let first = receive_events(&mut fsm, now, seq(1), &record); assert_eq!( first, - vec![ - opened(stream_id), - SessionEvent::Readable(stream_id), - SessionEvent::Finished(stream_id), - ] + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() ); - assert_eq!(read_stream_all(&mut fsm, stream_id), b"hello".to_vec()); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); assert!(second.is_empty()); @@ -529,16 +572,17 @@ fn duplicate_finished_remote_data_before_read_is_ignored() { let first = receive_events(&mut fsm, now, seq(1), &record); assert_eq!( first, - vec![ - opened(stream_id), - SessionEvent::Readable(stream_id), - SessionEvent::Finished(stream_id), - ] + vec![opened(stream_id), SessionEvent::Readable(stream_id)] ); let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); assert!(second.is_empty()); - assert_eq!(read_stream_all(&mut fsm, stream_id), b"hello".to_vec()); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() + ); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); } #[test] @@ -748,7 +792,8 @@ fn sparse_out_of_order_ack_ranges_page_and_quiesce() { } let retransmit_time = now + sender_config.retransmit_timeout + Duration::from_millis(1); - sender.on_timer(retransmit_time, |_| {}); + let mut emit = |_| {}; + sender.on_timer(retransmit_time, &mut emit); let retransmits = drain_outbound(&mut sender, retransmit_time, originals.len()); assert!(!retransmits.is_empty()); @@ -768,8 +813,10 @@ fn sparse_out_of_order_ack_ranges_page_and_quiesce() { } let final_now = second_ack_time + sender_config.retransmit_timeout + Duration::from_millis(1); - sender.on_timer(final_now, |_| {}); - receiver.on_timer(final_now, |_| {}); + let mut sender_emit = |_| {}; + sender.on_timer(final_now, &mut sender_emit); + let mut receiver_emit = |_| {}; + receiver.on_timer(final_now, &mut receiver_emit); assert!(next_outbound(&mut sender, final_now).is_none()); assert!(next_outbound(&mut receiver, final_now).is_none()); } From 1187983756f50f046a80dd2e2a01d813809bcccf Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 08:56:10 -0400 Subject: [PATCH 239/304] ql-runtime: simplify stream finish logic --- ql-runtime/src/driver/mod.rs | 26 +++----------------------- ql-runtime/src/driver/state.rs | 14 -------------- ql-runtime/src/driver/test.rs | 4 ++-- 3 files changed, 5 insertions(+), 39 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 6067614f..2fddd029 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -324,7 +324,7 @@ impl DriverState { } Event::Finished(stream_id) => { log::info!("peer finished stream writes: stream_id={stream_id}"); - self.handle_inbound_finished(fsm, stream_id); + self.handle_inbound_finished(stream_id); } Event::OutboundFinished(stream_id) => { log::info!("outbound finish acknowledged: stream_id={stream_id}"); @@ -455,35 +455,15 @@ impl DriverState { } drop(stream_ops); - self.finish_inbound_if_ready(fsm, stream_id); } - fn handle_inbound_finished(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + fn handle_inbound_finished(&mut self, stream_id: StreamId) { log::info!("inbound finished event: stream_id={stream_id}"); - let Some(stream) = self.streams.get_mut(&stream_id) else { - return; - }; - stream.inbound_queue_finish(); - self.finish_inbound_if_ready(fsm, stream_id); - } - - fn finish_inbound_if_ready(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { - if let Ok(stream_ops) = fsm.stream(stream_id) { - if stream_ops.readable_bytes() != 0 { - return; - } - } - let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { return; }; - let stream = entry.get_mut(); - if !stream.inbound_finish_pending() { - return; - } - log::info!("delivering clean inbound finish: stream_id={stream_id}"); - stream.inbound_finish(); + entry.get_mut().inbound_finish(); Self::try_reap_stream(entry); } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 9d4776da..0c93e781 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -135,18 +135,6 @@ impl DriverStreamIo { } } } - - pub fn inbound_queue_finish(&mut self) { - if let Some(inbound) = self.inbound.as_mut() { - inbound.finish_pending = true; - } - } - - pub fn inbound_finish_pending(&self) -> bool { - self.inbound - .as_ref() - .is_some_and(|inbound| inbound.finish_pending) - } } pub struct OutboundIo { @@ -168,7 +156,6 @@ impl OutboundIo { pub struct InboundIo { writer: ChunkSlotTx, terminal: Option>>, - finish_pending: bool, } pub enum InboundWriteResult { @@ -182,7 +169,6 @@ impl InboundIo { Self { writer, terminal: Some(terminal), - finish_pending: false, } } } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index b272e56f..37783dc6 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -68,7 +68,7 @@ fn new_outbound_io() -> OutboundIo { #[test] fn handle_inbound_finished_reaps_closed_initiator_stream() { - let (mut state, mut fsm) = new_driver_state(); + let (mut state, _fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); state.streams.insert( @@ -76,7 +76,7 @@ fn handle_inbound_finished_reaps_closed_initiator_stream() { DriverStreamIo::new(true, None, Some(new_inbound_io(1))), ); - state.handle_inbound_finished(&mut fsm, stream_id); + state.handle_inbound_finished(stream_id); assert!(!state.streams.contains_key(&stream_id)); } From 479899dd671d6498232c6eca3a9b5bade50d4273 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 08:57:45 -0400 Subject: [PATCH 240/304] ql-runtime: feature gate kind --- ql-runtime/src/command.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index c7ea9489..918966ac 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,7 +1,7 @@ use ql_fsm::NoSessionError; use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, StreamReader, QlStreamError}; +use crate::{chunk_slot::ChunkSlotRx, QlStreamError, StreamReader}; pub enum RuntimeCommand { BindPeer { @@ -36,6 +36,7 @@ pub enum RuntimeCommand { } impl RuntimeCommand { + #[cfg(feature = "log")] pub fn kind(&self) -> &'static str { match self { Self::BindPeer { .. } => "BindPeer", From 1a5ce0044f52b8d34c7ac1fec39970c765bed71b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 09:37:59 -0400 Subject: [PATCH 241/304] ql-wire: separate pairingid and pairingtoken --- ql-wire/src/handshake/mod.rs | 52 ++++++-------------- ql-wire/src/handshake/pairing.rs | 83 ++++++++++++++++++++++++++++++++ ql-wire/src/handshake/xx.rs | 44 ++++++++++------- ql-wire/src/tests.rs | 12 +++-- 4 files changed, 133 insertions(+), 58 deletions(-) create mode 100644 ql-wire/src/handshake/pairing.rs diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 9311617e..79e0f7ad 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -7,12 +7,14 @@ use crate::{ mod ik; mod kk; mod meta; +mod pairing; mod transport_params; mod xx; pub use ik::{Ik1, Ik2, IkHandshake}; pub use kk::{Kk1, Kk2, KkHandshake}; pub use meta::{HandshakeId, HandshakeMeta}; +pub use pairing::{PairingId, PairingToken}; pub use transport_params::TransportParams; pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake}; @@ -53,47 +55,13 @@ impl codec::WireDecode for HandshakeHeader { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(transparent)] -pub struct PairingToken(pub [u8; Self::SIZE]); - -impl PairingToken { - pub const SIZE: usize = 16; -} - -impl std::fmt::Display for PairingToken { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for byte in self.0 { - write!(f, "{byte:02x}")?; - } - Ok(()) - } -} - -impl WireEncode for PairingToken { - fn encoded_len(&self) -> usize { - Self::SIZE - } - - fn encode(&self, out: &mut W) { - self.0.encode(out); - } -} - -impl codec::WireDecode for PairingToken { - fn decode(reader: &mut codec::Reader) -> Result { - Ok(Self(reader.decode()?)) - } -} - -// TODO: this should not be exposed #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct XxHeader { - pub pairing_token: PairingToken, + pub pairing_id: PairingId, } impl XxHeader { - pub const WIRE_SIZE: usize = PairingToken::SIZE; + pub const WIRE_SIZE: usize = PairingId::SIZE; } impl WireEncode for XxHeader { @@ -102,14 +70,14 @@ impl WireEncode for XxHeader { } fn encode(&self, out: &mut W) { - self.pairing_token.encode(out); + self.pairing_id.encode(out); } } impl codec::WireDecode for XxHeader { fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - pairing_token: reader.decode()?, + pairing_id: reader.decode()?, }) } } @@ -398,6 +366,14 @@ fn init_xx_symmetric(crypto: &impl QlCrypto) -> SymmetricState { SymmetricState::new(crypto, PROTOCOL_XX) } +fn mix_psk_pairing_token( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + pairing_token: PairingToken, +) { + symmetric.mix_key_and_hash(crypto, &pairing_token.psk(crypto)); +} + fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> EphemeralKeyPair { EphemeralKeyPair { mlkem: crypto.mlkem_generate_keypair(), diff --git a/ql-wire/src/handshake/pairing.rs b/ql-wire/src/handshake/pairing.rs new file mode 100644 index 00000000..237f066b --- /dev/null +++ b/ql-wire/src/handshake/pairing.rs @@ -0,0 +1,83 @@ +use std::fmt::{self, Display, Formatter}; + +use crate::{codec, ByteSlice, QlCrypto, WireEncode, WireError}; + +const PAIRING_ID_DOMAIN: &[u8] = b"ql-wire:pairing-id:v1"; +const PAIRING_PSK_DOMAIN: &[u8] = b"ql-wire:pairing-psk:v1"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct PairingToken(pub [u8; Self::SIZE]); + +impl PairingToken { + pub const SIZE: usize = 16; + + pub fn id(&self, crypto: &impl QlCrypto) -> PairingId { + let hash = crypto.sha256(&[PAIRING_ID_DOMAIN, &self.0]); + let mut id = [0u8; PairingId::SIZE]; + id.copy_from_slice(&hash[..PairingId::SIZE]); + PairingId(id) + } + + pub(super) fn psk(&self, crypto: &impl QlCrypto) -> [u8; 32] { + crypto.sha256(&[PAIRING_PSK_DOMAIN, &self.0]) + } +} + +impl Display for PairingToken { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for byte in self.0 { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl WireEncode for PairingToken { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for PairingToken { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct PairingId(pub [u8; Self::SIZE]); + +impl PairingId { + pub const SIZE: usize = 16; +} + +impl Display for PairingId { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for byte in self.0 { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl WireEncode for PairingId { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for PairingId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index 4e8d455b..7dd35943 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -2,13 +2,13 @@ use super::{ decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, finalize_handshake, generate_ephemeral_keypair, init_xx_symmetric, initialize_handshake_meta, initialize_transport_params, mix_hash_ephemeral, mix_hash_pairing_handshake, - require_handshake_meta, require_transport_params, EncryptedMlKemCiphertext, - EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, FinalizedHandshake, Role, - SymmetricState, TransportParams, XxHeader, + mix_psk_pairing_token, require_handshake_meta, require_transport_params, + EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, + FinalizedHandshake, Role, SymmetricState, TransportParams, XxHeader, }; use crate::{ - codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PairingToken, PeerBundle, - QlCrypto, QlIdentity, WireEncode, WireError, + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PairingId, PairingToken, + PeerBundle, QlCrypto, QlIdentity, WireEncode, WireError, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -254,18 +254,26 @@ impl XxHandshake { self.pairing_token } + pub fn pairing_id(&self, crypto: &impl QlCrypto) -> PairingId { + self.pairing_token.id(crypto) + } + pub fn remote_bundle(&self) -> Option<&PeerBundle> { self.remote_bundle.as_ref() } - fn header(&self) -> XxHeader { + fn header(&self, crypto: &impl QlCrypto) -> XxHeader { XxHeader { - pairing_token: self.pairing_token, + pairing_id: self.pairing_token.id(crypto), } } - fn ensure_inbound_header(&self, header: XxHeader) -> Result<(), WireError> { - if header == self.header() { + fn ensure_inbound_header( + &self, + crypto: &impl QlCrypto, + header: XxHeader, + ) -> Result<(), WireError> { + if header == self.header(crypto) { Ok(()) } else { Err(WireError::InvalidPayload) @@ -281,7 +289,7 @@ impl XxHandshake { return Err(WireError::InvalidState); } initialize_handshake_meta(&mut self.handshake_meta, meta)?; - let header = self.header(); + let header = self.header(crypto); mix_hash_pairing_handshake( &mut self.symmetric, crypto, @@ -290,6 +298,7 @@ impl XxHandshake { &meta, self.local_transport_params, ); + mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); let local_ephemeral = generate_ephemeral_keypair(crypto); let ephemeral = local_ephemeral.public(); @@ -316,7 +325,7 @@ impl XxHandshake { } message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - self.ensure_inbound_header(message.header)?; + self.ensure_inbound_header(crypto, message.header)?; mix_hash_pairing_handshake( &mut self.symmetric, crypto, @@ -325,6 +334,7 @@ impl XxHandshake { &message.meta, message.transport_params, ); + mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); self.remote_ephemeral = Some(message.ephemeral.clone()); @@ -342,7 +352,7 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), meta)?; - let header = self.header(); + let header = self.header(crypto); mix_hash_pairing_handshake( &mut self.symmetric, crypto, @@ -384,7 +394,7 @@ impl XxHandshake { } message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; - self.ensure_inbound_header(message.header)?; + self.ensure_inbound_header(crypto, message.header)?; mix_hash_pairing_handshake( &mut self.symmetric, crypto, @@ -421,7 +431,7 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), meta)?; - let header = self.header(); + let header = self.header(crypto); mix_hash_pairing_handshake( &mut self.symmetric, crypto, @@ -462,7 +472,7 @@ impl XxHandshake { } message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; - self.ensure_inbound_header(message.header)?; + self.ensure_inbound_header(crypto, message.header)?; require_transport_params( self.remote_transport_params.as_ref(), message.transport_params, @@ -498,7 +508,7 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), meta)?; - let header = self.header(); + let header = self.header(crypto); mix_hash_pairing_handshake( &mut self.symmetric, crypto, @@ -536,7 +546,7 @@ impl XxHandshake { } message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; - self.ensure_inbound_header(message.header)?; + self.ensure_inbound_header(crypto, message.header)?; require_transport_params( self.remote_transport_params.as_ref(), message.transport_params, diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 94379580..fe67a4d1 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -55,9 +55,13 @@ fn pairing_token(byte: u8) -> PairingToken { PairingToken([byte; PairingToken::SIZE]) } +fn pairing_id(byte: u8) -> PairingId { + PairingId([byte; PairingId::SIZE]) +} + fn xx_header(byte: u8) -> XxHeader { XxHeader { - pairing_token: pairing_token(byte), + pairing_id: pairing_id(byte), } } @@ -519,7 +523,7 @@ fn kk_handshake_rejects_tampered_transport_params() { } #[test] -fn xx_handshake_rejects_tampered_pairing_token() { +fn xx_handshake_rejects_tampered_pairing_id() { let crypto = SoftwareCrypto; let (initiator, responder) = test_identities(&crypto); let token = pairing_token(7); @@ -532,7 +536,7 @@ fn xx_handshake_rejects_tampered_pairing_token() { let mut m1 = initiator_state .write_1(&crypto, handshake_meta(31)) .unwrap(); - m1.header = xx_header(8); + m1.header.pairing_id = pairing_id(8); assert_eq!( responder_state.read_1(&crypto, 0, &m1), @@ -595,6 +599,8 @@ fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { assert_eq!(initiator_state.pairing_token(), token); assert_eq!(responder_state.pairing_token(), token); + assert_eq!(initiator_state.pairing_id(&crypto), token.id(&crypto)); + assert_eq!(responder_state.pairing_id(&crypto), token.id(&crypto)); assert!(initiator_state.remote_bundle().is_none()); assert!(responder_state.remote_bundle().is_none()); From d41b82b35aef2e5231edf4b936b9ae50c1f4367d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 09:42:53 -0400 Subject: [PATCH 242/304] ql-fsm: use pairing id --- ql-fsm/src/error.rs | 15 ++++------ ql-fsm/src/handshake/xx.rs | 54 +++++++++++++++++------------------ ql-fsm/src/tests/handshake.rs | 7 +++-- 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 88563ea8..82a79d7b 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -3,7 +3,7 @@ use std::{ fmt::{Display, Formatter}, }; -use ql_wire::{PairingToken, WireError}; +use ql_wire::{PairingId, WireError}; #[derive(Debug, Clone, PartialEq, Eq)] pub enum ReceiveError { @@ -14,9 +14,9 @@ pub enum ReceiveError { InvalidXid, NoSession, NotPairingMode, - InvalidPairingToken { - expected: PairingToken, - actual: PairingToken, + InvalidPairingId { + expected: PairingId, + actual: PairingId, }, Replay, } @@ -31,11 +31,8 @@ impl Display for ReceiveError { Self::InvalidXid => f.write_str("invalid xid"), Self::NoSession => f.write_str("no active session"), Self::NotPairingMode => f.write_str("not in pairing mode"), - Self::InvalidPairingToken { expected, actual } => { - write!( - f, - "invalid pairing token: expected {expected}, actual {actual}" - ) + Self::InvalidPairingId { expected, actual } => { + write!(f, "invalid pairing id: expected {expected}, actual {actual}") } Self::Replay => f.write_str("replay"), } diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index b95879a3..d15b934d 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -34,41 +34,41 @@ pub fn handle_xx1( crypto: &impl QlCrypto, message: &Xx1, ) -> Result<(), ReceiveError> { - if should_ignore_inbound(fsm, message) { + if should_ignore_inbound(fsm, crypto, message) { return Ok(()); } if is_replayed_handshake_start(fsm, message.meta) { return Err(ReceiveError::Replay); } match fsm.state.armed_pairing_token { - Some(expected) if expected != message.header.pairing_token => { - return Err(ReceiveError::InvalidPairingToken { - expected, - actual: message.header.pairing_token, + Some(expected) if expected.id(crypto) != message.header.pairing_id => { + return Err(ReceiveError::InvalidPairingId { + expected: expected.id(crypto), + actual: message.header.pairing_id, }); } - Some(_) => {} + Some(token) => { + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::XxHandshake::new_responder( + crypto, + fsm.identity.clone(), + token, + super::local_transport_params(fsm), + ); + handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; + let outbound = handshake.write_2(crypto, message.meta)?; + fsm.state.link = LinkState::XxResponder(XxResponderState { + handshake, + handshake_meta: message.meta, + deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + }); + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx2(outbound)); + Ok(()) + } None => return Err(ReceiveError::NotPairingMode), } - - reset_connected_session_if_needed(fsm); - - let mut handshake = wire::XxHandshake::new_responder( - crypto, - fsm.identity.clone(), - message.header.pairing_token, - super::local_transport_params(fsm), - ); - handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; - let outbound = handshake.write_2(crypto, message.meta)?; - fsm.state.link = LinkState::XxResponder(XxResponderState { - handshake, - handshake_meta: message.meta, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, - }); - fsm.state.handshake = None; - enqueue_handshake(fsm, QlHandshakeRecord::Xx2(outbound)); - Ok(()) } pub fn handle_xx2( @@ -158,12 +158,12 @@ pub fn disarm_pairing(fsm: &mut QlFsm) { } } -pub fn should_ignore_inbound(fsm: &QlFsm, message: &Xx1) -> bool { +pub fn should_ignore_inbound(fsm: &QlFsm, crypto: &impl QlCrypto, message: &Xx1) -> bool { match &fsm.state.link { LinkState::Idle | LinkState::Connected(_) => false, LinkState::IkInitiator(_) | LinkState::KkInitiator(_) | LinkState::XxResponder(_) => true, LinkState::XxInitiator(state) => { - if state.handshake.pairing_token() != message.header.pairing_token { + if state.handshake.pairing_id(crypto) != message.header.pairing_id { return false; } super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index d492757c..9b1d749f 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -136,7 +136,7 @@ fn inbound_xx1_rejects_when_not_in_pairing_mode() { } #[test] -fn inbound_xx1_rejects_mismatched_pairing_token_with_expected_and_actual() { +fn inbound_xx1_rejects_mismatched_pairing_id_with_expected_and_actual() { let mut harness = Harness::paired(QlFsmConfig::default(), false, false); let expected = pairing_token(4); let actual = pairing_token(7); @@ -151,7 +151,10 @@ fn inbound_xx1_rejects_mismatched_pairing_token_with_expected_and_actual() { assert_eq!( err, - Err(ReceiveError::InvalidPairingToken { expected, actual }) + Err(ReceiveError::InvalidPairingId { + expected: expected.id(&SoftwareCrypto), + actual: actual.id(&SoftwareCrypto), + }) ); } From 70e109df7758908d0fd987776e9025ac3db8bd50 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 10:21:49 -0400 Subject: [PATCH 243/304] ql-runtime: get rid of concurrent-queue --- Cargo.lock | 1 - ql-runtime/Cargo.toml | 1 - .../src/{chunk_slot.rs => chunk_slot/mod.rs} | 21 +-- ql-runtime/src/chunk_slot/queue.rs | 142 ++++++++++++++++++ ql-runtime/src/chunk_slot/sync.rs | 66 ++++++++ 5 files changed, 216 insertions(+), 15 deletions(-) rename ql-runtime/src/{chunk_slot.rs => chunk_slot/mod.rs} (95%) create mode 100644 ql-runtime/src/chunk_slot/queue.rs create mode 100644 ql-runtime/src/chunk_slot/sync.rs diff --git a/Cargo.lock b/Cargo.lock index dea9f37f..fd232744 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2236,7 +2236,6 @@ version = "0.1.0" dependencies = [ "async-channel", "bytes", - "concurrent-queue", "event-listener", "futures-lite", "log", diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index 077faf10..50f7db19 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -13,7 +13,6 @@ rpc = ["dep:ql-rpc"] [dependencies] async-channel = { version = "2.5" } bytes = "1" -concurrent-queue = { version = "2.5" } event-listener = "5.4" futures-lite = { version = "2.5" } log = { version = "0.4", optional = true } diff --git a/ql-runtime/src/chunk_slot.rs b/ql-runtime/src/chunk_slot/mod.rs similarity index 95% rename from ql-runtime/src/chunk_slot.rs rename to ql-runtime/src/chunk_slot/mod.rs index d446859d..b15aee8a 100644 --- a/ql-runtime/src/chunk_slot.rs +++ b/ql-runtime/src/chunk_slot/mod.rs @@ -5,16 +5,12 @@ use std::{ }; use bytes::Bytes; -use concurrent_queue::{ConcurrentQueue, PopError, PushError}; use event_listener::{Event, EventListener}; -mod sync { - #[cfg(not(all(test, loom)))] - pub use std::sync::Arc; +use self::queue::{PopError, PushError, Single}; - #[cfg(all(test, loom))] - pub use loom::sync::Arc; -} +mod queue; +mod sync; use sync::*; @@ -22,7 +18,7 @@ use sync::*; /// receiver-side partial reads keep the remainder locally pub fn new() -> (ChunkSlotRx, ChunkSlotTx) { let shared = Arc::new(Shared { - queue: ConcurrentQueue::bounded(1), + queue: Single::new(), changed: Event::new(), }); @@ -45,7 +41,7 @@ pub struct ChunkSlotTx { } struct Shared { - queue: ConcurrentQueue, + queue: Single, changed: Event, } @@ -319,8 +315,8 @@ mod tests { rx.close(); - let error = tx.send(Bytes::from_static(b"abc")).await.unwrap_err(); - assert_eq!(error.0, Bytes::from_static(b"abc")); + let err = tx.send(Bytes::from_static(b"abc")).await.unwrap_err(); + assert_eq!(err.0, Bytes::from_static(b"abc")); } #[test] @@ -357,8 +353,7 @@ mod loom_tests { } fn check_model(f: impl Fn() + Sync + Send + 'static) { - let mut builder = model::Builder::new(); - builder.preemption_bound = Some(3); + let builder = model::Builder::new(); builder.check(f); } diff --git a/ql-runtime/src/chunk_slot/queue.rs b/ql-runtime/src/chunk_slot/queue.rs new file mode 100644 index 00000000..2b596fa6 --- /dev/null +++ b/ql-runtime/src/chunk_slot/queue.rs @@ -0,0 +1,142 @@ +//! local single-slot queue for `chunk_slot` to avoid `ConcurrentQueue` taking 512 bytes instead of 40 +//! copied from `concurrent_queue::single::Single` in `concurrent-queue` + +use core::mem::MaybeUninit; + +use super::sync::*; + +const LOCKED: usize = 1 << 0; +const PUSHED: usize = 1 << 1; +const CLOSED: usize = 1 << 2; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PopError { + Empty, + Closed, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum PushError { + Full(T), + Closed(T), +} + +/// A single-element queue. +pub struct Single { + state: AtomicUsize, + slot: UnsafeCell>, +} + +unsafe impl Send for Single {} +unsafe impl Sync for Single {} + +impl Single { + /// Creates a new single-element queue. + pub fn new() -> Single { + Single { + state: AtomicUsize::new(0), + slot: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + /// Attempts to push an item into the queue. + pub fn push(&self, value: T) -> Result<(), PushError> { + // Lock and fill the slot. + let state = self + .state + .compare_exchange(0, LOCKED | PUSHED, Ordering::SeqCst, Ordering::SeqCst) + .unwrap_or_else(|x| x); + + if state == 0 { + // Write the value and unlock. + self.slot.with_mut(|slot| unsafe { + slot.write(MaybeUninit::new(value)); + }); + self.state.fetch_and(!LOCKED, Ordering::Release); + Ok(()) + } else if state & CLOSED != 0 { + Err(PushError::Closed(value)) + } else { + Err(PushError::Full(value)) + } + } + + /// Attempts to pop an item from the queue. + pub fn pop(&self) -> Result { + let mut state = PUSHED; + loop { + // Lock and empty the slot. + let prev = self + .state + .compare_exchange( + state, + (state | LOCKED) & !PUSHED, + Ordering::SeqCst, + Ordering::SeqCst, + ) + .unwrap_or_else(|x| x); + + if prev == state { + // Read the value and unlock. + let value = self + .slot + .with_mut(|slot| unsafe { slot.read().assume_init() }); + self.state.fetch_and(!LOCKED, Ordering::Release); + return Ok(value); + } + + if prev & PUSHED == 0 { + if prev & CLOSED == 0 { + return Err(PopError::Empty); + } else { + return Err(PopError::Closed); + } + } + + if prev & LOCKED == 0 { + state = prev; + } else { + busy_wait(); + state = prev & !LOCKED; + } + } + } + + /// Returns the number of items in the queue. + pub fn len(&self) -> usize { + usize::from(self.state.load(Ordering::SeqCst) & PUSHED != 0) + } + + /// Returns `true` if the queue is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Closes the queue. + /// + /// Returns `true` if this call closed the queue. + pub fn close(&self) -> bool { + let state = self.state.fetch_or(CLOSED, Ordering::SeqCst); + state & CLOSED == 0 + } + + /// Returns `true` if the queue is closed. + pub fn is_closed(&self) -> bool { + self.state.load(Ordering::SeqCst) & CLOSED != 0 + } +} + +impl Drop for Single { + fn drop(&mut self) { + // Drop the value in the slot. + let Self { state, slot } = self; + state.with_mut(|state| { + if *state & PUSHED != 0 { + slot.with_mut(|slot| unsafe { + let value = &mut *slot; + value.as_mut_ptr().drop_in_place(); + }); + } + }); + } +} diff --git a/ql-runtime/src/chunk_slot/sync.rs b/ql-runtime/src/chunk_slot/sync.rs new file mode 100644 index 00000000..8e7423b5 --- /dev/null +++ b/ql-runtime/src/chunk_slot/sync.rs @@ -0,0 +1,66 @@ +#[cfg(not(all(test, loom)))] +mod inner { + pub use std::{ + cell::UnsafeCell, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + }; + + pub fn busy_wait() { + std::thread::yield_now(); + } + + pub trait UnsafeCellExt { + type Value; + + fn with_mut(&self, f: F) -> R + where + F: FnOnce(*mut Self::Value) -> R; + } + + impl UnsafeCellExt for UnsafeCell { + type Value = T; + + fn with_mut(&self, f: F) -> R + where + F: FnOnce(*mut Self::Value) -> R, + { + f(self.get()) + } + } + + pub trait AtomicExt { + type Value; + + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Value) -> R; + } + + impl AtomicExt for AtomicUsize { + type Value = usize; + + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Value) -> R, + { + f(self.get_mut()) + } + } +} + +#[cfg(all(test, loom))] +mod inner { + pub use loom::{ + cell::UnsafeCell, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + thread::yield_now as busy_wait, + }; +} + +pub use inner::*; From 13bef507035ae2ec8c2cf256fab63a62130cf6c2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 10:55:10 -0400 Subject: [PATCH 244/304] ql-runtime: inbound poller + use Pin more --- ql-runtime/src/command.rs | 2 - ql-runtime/src/driver/mod.rs | 48 +++++++++-------- ql-runtime/src/driver/test.rs | 19 +++++-- ql-runtime/src/handle/mod.rs | 5 -- ql-runtime/src/platform.rs | 15 +++++- ql-runtime/src/tests/handshake.rs | 32 +++++------ ql-runtime/src/tests/heartbeat.rs | 9 ++-- ql-runtime/src/tests/mod.rs | 88 +++++++++++++++++++++---------- ql-runtime/src/tests/stream.rs | 22 ++++---- ql-runtime/src/tests/unpair.rs | 9 ++-- 10 files changed, 153 insertions(+), 96 deletions(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 918966ac..39c9ac84 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -32,7 +32,6 @@ pub enum RuntimeCommand { target: CloseTarget, code: StreamCloseCode, }, - Receive(Vec), } impl RuntimeCommand { @@ -48,7 +47,6 @@ impl RuntimeCommand { Self::PollInbound { .. } => "PollInbound", Self::PollStream { .. } => "PollStream", Self::CloseStream { .. } => "CloseStream", - Self::Receive(_) => "Receive", } } } diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 2fddd029..8be61c93 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -24,7 +24,7 @@ use crate::{ command::RuntimeCommand, handle::{QlStream, StreamReader, StreamWriter}, log, - platform::{QlPlatform, QlTimer}, + platform::{QlInbound, QlPlatform, QlTimer}, QlStreamError, Runtime, RuntimeHandle, }; @@ -33,7 +33,7 @@ impl Runtime

{ pub async fn run(self) { let Self { identity, - platform, + mut platform, config, rx, tx, @@ -48,7 +48,10 @@ impl Runtime

{ }; let mut in_flight = Vec::new(); - let mut timer = platform.timer(); + let timer = platform.timer(); + let mut timer = pin!(timer); + let inbound = platform.inbound(); + let mut inbound = pin!(inbound); let recv_future = rx.recv(); let mut recv_future = pin!(recv_future); let mut poll_cursor = 0usize; @@ -58,13 +61,14 @@ impl Runtime

{ if state.fill_write_slots(&mut fsm, &platform, &mut in_flight) { state.drain_fsm_events(&mut fsm, &platform); } - timer.set_deadline(fsm.next_deadline()); + timer.as_mut().set_deadline(fsm.next_deadline()); let step = poll_fn(|cx| { next_step( cx, recv_future.as_mut(), - &mut timer, + inbound.as_mut(), + timer.as_mut(), &mut in_flight, poll_cursor, ) @@ -77,6 +81,13 @@ impl Runtime

{ log::trace!("processing command: kind={}", command.kind()); state.drive_command(&mut fsm, command, &platform); } + DriverStep::Inbound(bytes) => { + log::trace!("received transport frame: len={}", bytes.len()); + if let Err(e) = fsm.receive(now(), bytes, &platform) { + log::info!("receive rejected frame: error={e:?}"); + platform.handle_recv_error(e); + } + } DriverStep::WriteCompleted { index, success } => { let write = in_flight.swap_remove(index); let write_id = write.write_id; @@ -112,23 +123,26 @@ struct InFlightWrite { enum DriverStep { Command(RuntimeCommand), + Inbound(Vec), WriteCompleted { index: usize, success: bool }, TimerExpired, Closed, } -const STEP_COUNT: usize = 3; +const STEP_COUNT: usize = 4; -fn next_step( +fn next_step( cx: &mut Context<'_>, mut recv_future: Pin<&mut Recv<'_, RuntimeCommand>>, - timer: &mut T, + mut inbound: Pin<&mut I>, + mut timer: Pin<&mut T>, in_flight: &mut [InFlightWrite], start: usize, ) -> Poll where T: QlTimer, F: Future + Unpin, + I: QlInbound, { for offset in 0..STEP_COUNT { let step = (start + offset) % STEP_COUNT; @@ -137,7 +151,8 @@ where .as_mut() .poll(cx) .map(|res| res.map_or(DriverStep::Closed, DriverStep::Command)), - 1 => { + 1 => inbound.as_mut().poll_recv(cx).map(DriverStep::Inbound), + 2 => { for (index, write) in in_flight.iter_mut().enumerate() { if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { return Poll::Ready(DriverStep::WriteCompleted { index, success }); @@ -145,13 +160,7 @@ where } Poll::Pending } - 2 => { - if timer.poll_wait(cx) == Poll::Ready(()) { - Poll::Ready(DriverStep::TimerExpired) - } else { - Poll::Pending - } - } + 3 => timer.as_mut().poll_wait(cx).map(|()| DriverStep::TimerExpired), _ => unreachable!(), }; if poll.is_ready() { @@ -192,13 +201,6 @@ impl DriverState { log::info!(" starting XX pairing"); fsm.connect_xx(now(), token, platform); } - RuntimeCommand::Receive(bytes) => { - log::trace!("received transport frame: len={}", bytes.len()); - if let Err(e) = fsm.receive(now(), bytes, platform) { - log::info!("receive rejected frame: error={e:?}"); - platform.handle_recv_error(e); - } - } RuntimeCommand::OpenStream { route_id, request_reader, diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 37783dc6..4900e5f7 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -1,19 +1,19 @@ -use std::task::{Context, Poll}; - use ql_wire::{test_identity, NoopCrypto, PeerBundle, SoftwareCrypto, StreamClose, XID}; use super::*; use crate::{ chunk_slot, driver::state::{InboundIo, OutboundIo}, + platform::QlInbound, }; pub struct NoopTimer; +pub struct NoopInbound; impl crate::platform::QlTimer for NoopTimer { - fn set_deadline(&mut self, _deadline: Option) {} + fn set_deadline(self: Pin<&mut Self>, _deadline: Option) {} - fn poll_wait(&mut self, _cx: &mut Context<'_>) -> Poll<()> { + fn poll_wait(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { Poll::Pending } } @@ -21,11 +21,16 @@ impl crate::platform::QlTimer for NoopTimer { impl QlPlatform for NoopCrypto { type Timer = NoopTimer; type WriteMessageFut<'a> = std::future::Ready; + type Inbound = NoopInbound; fn write_message(&self, _message: Vec) -> Self::WriteMessageFut<'_> { std::future::ready(true) } + fn inbound(&mut self) -> Self::Inbound { + NoopInbound + } + fn timer(&self) -> Self::Timer { NoopTimer } @@ -37,6 +42,12 @@ impl QlPlatform for NoopCrypto { fn handle_inbound(&self, _event: QlStream) {} } +impl QlInbound for NoopInbound { + fn poll_recv(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } +} + fn new_driver_state() -> (DriverState, QlFsm) { let (runtime_tx, _runtime_rx) = async_channel::unbounded(); ( diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 113b5596..171cd403 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -46,11 +46,6 @@ impl RuntimeHandle { self.send(RuntimeCommand::StartPairing { token }); } - /// hands inbound transport bytes to the runtime - pub fn receive(&self, bytes: Vec) { - self.send(RuntimeCommand::Receive(bytes)); - } - /// opens a new stream on the active encrypted session pub async fn open_stream(&self, route_id: RouteId) -> Result { let (request_reader, request_writer) = chunk_slot::new(); diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index e5946534..0dabfcd6 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -13,8 +13,12 @@ use crate::QlStream; pub type PlatformFuture<'a, T> = Pin + 'a>>; pub trait QlTimer { - fn set_deadline(&mut self, deadline: Option); - fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<()>; + fn set_deadline(self: Pin<&mut Self>, deadline: Option); + fn poll_wait(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()>; +} + +pub trait QlInbound { + fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; } pub trait QlPlatform: QlCrypto { @@ -22,8 +26,15 @@ pub trait QlPlatform: QlCrypto { type WriteMessageFut<'a>: Future + Unpin + 'a where Self: 'a; + type Inbound: QlInbound; fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_>; + /// Returns the platform's inbound transport poller. + /// + /// The runtime calls this once while starting the driver loop and retains the returned + /// poller for the lifetime of the runtime. Platform implementations may panic if this is + /// called more than once. + fn inbound(&mut self) -> Self::Inbound; fn timer(&self) -> Self::Timer; fn persist_peer(&self, peer: PeerBundle); diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index c785f048..2963c961 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -35,8 +35,8 @@ async fn handshake_timeout_disconnects() { }, ..default_runtime_config() }; - let (platform_a, _outbound_a, status_a) = TestPlatform::new(); - let (platform_b, _outbound_b, _status_b) = TestPlatform::new(); + let (platform_a, _outbound_a, _inbound_a, status_a) = TestPlatform::new(); + let (platform_b, _outbound_b, _inbound_b, _status_b) = TestPlatform::new(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -57,8 +57,10 @@ async fn handshake_timeout_disconnects() { async fn rejected_session_write_is_reissued() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new_with_session_write_failure(1); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = + TestPlatform::new_with_session_write_failure(1); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -67,8 +69,8 @@ async fn rejected_session_write_is_reissued() { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); @@ -115,8 +117,8 @@ async fn rejected_session_write_is_reissued() { async fn start_pairing_round_trip_connects_when_armed() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b) = TestPlatform::new(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b) = TestPlatform::new(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let token = pairing_token(7); @@ -126,8 +128,8 @@ async fn start_pairing_round_trip_connects_when_armed() { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); handle_b.arm_pairing(token); handle_a.start_pairing(token); @@ -142,19 +144,19 @@ async fn start_pairing_round_trip_connects_when_armed() { async fn start_pairing_does_not_connect_when_unarmed() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, _status_b) = TestPlatform::new(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, _status_b) = TestPlatform::new(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let token = pairing_token(8); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + let (runtime_b, _handle_b) = new_runtime(identity_b.clone(), platform_b, config); tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); handle_a.start_pairing(token); diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs index f9c7f2be..77af8b00 100644 --- a/ql-runtime/src/tests/heartbeat.rs +++ b/ql-runtime/src/tests/heartbeat.rs @@ -21,8 +21,9 @@ async fn session_timeout_disconnects_and_fails_pending_open() { ..default_runtime_config() }; let config_b = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); @@ -32,8 +33,8 @@ async fn session_timeout_disconnects_and_fails_pending_open() { tokio::task::spawn_local(async move { runtime_b.run().await }); let drop_flag = Arc::new(AtomicBool::new(false)); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_gated_forwarder(outbound_b, handle_a.clone(), drop_flag.clone()); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_gated_forwarder(outbound_b, inbound_a_tx, drop_flag.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 8cc7a449..9f48f6a5 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -10,6 +10,7 @@ use std::{ }; use async_channel::{Receiver, Sender}; +use futures_lite::Stream; use ql_fsm::PeerStatus; use ql_wire::{ test_identities, test_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, @@ -76,6 +77,8 @@ impl WriteStats { struct TestPlatform { outbound: Sender>, + _inbound_messages_tx: Sender>, + inbound_messages: Option>>, status: Sender, inbound: Option>, crypto: SoftwareCrypto, @@ -85,33 +88,38 @@ struct TestPlatform { write_stats: Option, } +struct TestInbound { + receiver: Receiver>, +} + impl TestPlatform { - fn new() -> (Self, Receiver>, Receiver) { + fn new() -> (Self, Receiver>, Sender>, Receiver) { Self::new_inner(None, None, Duration::ZERO, None) } fn new_with_inbound() -> ( Self, Receiver>, + Sender>, Receiver, Receiver, ) { let (inbound_tx, inbound_rx) = async_channel::unbounded(); - let (platform, outbound_rx, status_rx) = + let (platform, outbound_rx, inbound_messages_tx, status_rx) = Self::new_inner(Some(inbound_tx), None, Duration::ZERO, None); - (platform, outbound_rx, status_rx, inbound_rx) + (platform, outbound_rx, inbound_messages_tx, status_rx, inbound_rx) } fn new_with_session_write_failure( fail_encrypted_write_at: usize, - ) -> (Self, Receiver>, Receiver) { + ) -> (Self, Receiver>, Sender>, Receiver) { Self::new_inner(None, Some(fail_encrypted_write_at), Duration::ZERO, None) } fn new_with_delayed_writes( delay: Duration, write_stats: WriteStats, - ) -> (Self, Receiver>, Receiver) { + ) -> (Self, Receiver>, Sender>, Receiver) { Self::new_inner(None, None, delay, Some(write_stats)) } @@ -120,12 +128,15 @@ impl TestPlatform { fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, - ) -> (Self, Receiver>, Receiver) { + ) -> (Self, Receiver>, Sender>, Receiver) { let (outbound, outbound_rx) = async_channel::unbounded(); + let (inbound_messages_tx, inbound_messages_rx) = async_channel::unbounded(); let (status, status_rx) = async_channel::unbounded(); ( Self { outbound, + _inbound_messages_tx: inbound_messages_tx.clone(), + inbound_messages: Some(inbound_messages_rx), status, inbound, crypto: SoftwareCrypto, @@ -135,6 +146,7 @@ impl TestPlatform { write_stats, }, outbound_rx, + inbound_messages_tx, status_rx, ) } @@ -202,8 +214,10 @@ impl TestPair { a_to_b: LinkBehavior, b_to_a: LinkBehavior, ) -> (Self, ControlledLinks) { - let (platform_a, outbound_a, status_a, inbound_a) = TestPlatform::new_with_inbound(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (platform_a, outbound_a, inbound_a_tx, status_a, inbound_a) = + TestPlatform::new_with_inbound(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let links = ControlledLinks { a_to_b: LinkController::new(a_to_b), @@ -216,8 +230,8 @@ impl TestPair { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_simulated_forwarder(outbound_a, handle_b.clone(), links.a_to_b.clone()); - spawn_simulated_forwarder(outbound_b, handle_a.clone(), links.b_to_a.clone()); + spawn_simulated_forwarder(outbound_a, inbound_b_tx, links.a_to_b.clone()); + spawn_simulated_forwarder(outbound_b, inbound_a_tx, links.b_to_a.clone()); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); ( @@ -288,13 +302,13 @@ impl TokioTimer { } impl QlTimer for TokioTimer { - fn set_deadline(&mut self, deadline: Option) { + fn set_deadline(mut self: Pin<&mut Self>, deadline: Option) { let deadline = deadline.map_or_else(parked_deadline, tokio::time::Instant::from_std); - self.sleep.as_mut().reset(deadline); + self.as_mut().get_mut().sleep.as_mut().reset(deadline); } - fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<()> { - self.sleep.as_mut().poll(cx) + fn poll_wait(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + self.as_mut().get_mut().sleep.as_mut().poll(cx) } } @@ -351,6 +365,7 @@ impl QlKem for TestPlatform { impl crate::platform::QlPlatform for TestPlatform { type Timer = TokioTimer; type WriteMessageFut<'a> = PlatformFuture<'a, bool>; + type Inbound = TestInbound; fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_> { let outbound = self.outbound.clone(); @@ -389,6 +404,15 @@ impl crate::platform::QlPlatform for TestPlatform { }) } + fn inbound(&mut self) -> Self::Inbound { + TestInbound { + receiver: self + .inbound_messages + .take() + .expect("TestPlatform::inbound may only be called once"), + } + } + fn timer(&self) -> Self::Timer { TokioTimer::new() } @@ -406,6 +430,16 @@ impl crate::platform::QlPlatform for TestPlatform { } } +impl crate::platform::QlInbound for TestInbound { + fn poll_recv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match unsafe { self.as_mut().map_unchecked_mut(|this| &mut this.receiver) }.poll_next(cx) { + Poll::Ready(Some(bytes)) => Poll::Ready(bytes), + Poll::Ready(None) => panic!("TestInbound channel closed"), + Poll::Pending => Poll::Pending, + } + } +} + fn parked_deadline() -> tokio::time::Instant { tokio::time::Instant::now() + Duration::from_secs(60 * 60 * 24 * 365 * 100) } @@ -430,17 +464,17 @@ fn register_peers( handle_b.bind_peer(id_a.bundle()); } -fn spawn_forwarder(outbound: Receiver>, handle: RuntimeHandle) { +fn spawn_forwarder(outbound: Receiver>, inbound: Sender>) { spawn_simulated_forwarder( outbound, - handle, + inbound, LinkController::new(LinkBehavior::default()), ); } fn spawn_simulated_forwarder( outbound: Receiver>, - handle: RuntimeHandle, + inbound: Sender>, controller: LinkController, ) { tokio::task::spawn_local(async move { @@ -473,12 +507,12 @@ fn spawn_simulated_forwarder( } let primary = bytes.clone(); - let primary_handle = handle.clone(); + let primary_inbound = inbound.clone(); tokio::task::spawn_local(async move { if !delay.is_zero() { tokio::time::sleep(delay).await; } - primary_handle.receive(primary); + let _ = primary_inbound.try_send(primary); }); if ordinal.is_some_and(|count| { @@ -486,13 +520,13 @@ fn spawn_simulated_forwarder( .duplicate_encrypted_every .is_some_and(|nth| nth != 0 && count % nth == 0) }) { - let duplicate_handle = handle.clone(); + let duplicate_inbound = inbound.clone(); tokio::task::spawn_local(async move { let duplicate_delay = delay + Duration::from_millis(1); if !duplicate_delay.is_zero() { tokio::time::sleep(duplicate_delay).await; } - duplicate_handle.receive(bytes); + let _ = duplicate_inbound.try_send(bytes); }); } } @@ -501,7 +535,7 @@ fn spawn_simulated_forwarder( fn spawn_drop_every_nth_encrypted_forwarder( outbound: Receiver>, - handle: RuntimeHandle, + inbound: Sender>, nth: usize, ) { tokio::task::spawn_local(async move { @@ -513,14 +547,14 @@ fn spawn_drop_every_nth_encrypted_forwarder( continue; } } - handle.receive(bytes); + let _ = inbound.try_send(bytes); } }); } fn spawn_gated_forwarder( outbound: Receiver>, - handle: RuntimeHandle, + inbound: Sender>, drop_flag: Arc, ) { tokio::task::spawn_local(async move { @@ -528,7 +562,7 @@ fn spawn_gated_forwarder( if drop_flag.load(Ordering::Relaxed) { continue; } - handle.receive(bytes); + let _ = inbound.try_send(bytes); } }); } @@ -625,7 +659,7 @@ fn default_runtime_config() -> RuntimeConfig { fn runtime_is_send() { let config = default_runtime_config(); let identity_a = test_identity(&SoftwareCrypto); - let (platform_a, _, _) = TestPlatform::new(); + let (platform_a, _, _, _) = TestPlatform::new(); let (runtime_a, _handle) = new_runtime(identity_a, platform_a, config); std::thread::spawn(move || { tokio::runtime::Builder::new_current_thread() @@ -640,7 +674,7 @@ fn runtime_is_send() { fn runtime_exits_when_last_handle_drops() { let config = default_runtime_config(); let identity = test_identity(&SoftwareCrypto); - let (platform, _, _) = TestPlatform::new(); + let (platform, _, _, _) = TestPlatform::new(); let (runtime, handle) = new_runtime(identity, platform, config); let (done_tx, done_rx) = oneshot::channel(); diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index f464a90c..f83783b8 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -288,9 +288,10 @@ async fn max_concurrent_message_writes_is_respected() { max_concurrent_message_writes: 2, ..default_runtime_config() }; - let (platform_a, outbound_a, status_a) = + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new_with_delayed_writes(Duration::from_millis(40), stats.clone()); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -299,8 +300,8 @@ async fn max_concurrent_message_writes_is_respected() { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); @@ -359,8 +360,9 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { }, ..default_runtime_config() }; - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let request_payload: Vec = (0..32).collect(); @@ -373,8 +375,8 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_drop_every_nth_encrypted_forwarder(outbound_a, handle_b.clone(), 3); - spawn_drop_every_nth_encrypted_forwarder(outbound_b, handle_a.clone(), 3); + spawn_drop_every_nth_encrypted_forwarder(outbound_a, inbound_b_tx, 3); + spawn_drop_every_nth_encrypted_forwarder(outbound_b, inbound_a_tx, 3); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); @@ -420,7 +422,7 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { #[allow(clippy::too_many_lines)] #[tokio::test(flavor = "current_thread")] async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { - run_local_test_timeout(Duration::from_secs(5), async { + run_local_test_timeout(Duration::from_secs(10), async { let payload_len = 2 * 1024 * 1024; let chunk_len = 16 * 1024; let payload: Vec = (0..payload_len) @@ -541,7 +543,7 @@ async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { #[tokio::test(flavor = "current_thread")] async fn reproducer_writer_stalls_after_reverse_path_impairment() { - run_local_test_timeout(Duration::from_secs(5), async { + run_local_test_timeout(Duration::from_secs(10), async { let payload_len = 2 * 1024 * 1024; let chunk_len = 16 * 1024; let payload: Vec = (0..payload_len) diff --git a/ql-runtime/src/tests/unpair.rs b/ql-runtime/src/tests/unpair.rs index 93c78177..751c9a53 100644 --- a/ql-runtime/src/tests/unpair.rs +++ b/ql-runtime/src/tests/unpair.rs @@ -4,8 +4,9 @@ use super::*; async fn unpair_clears_remote_peer_and_aborts_active_stream() { run_local_test(async { let config = default_runtime_config(); - let (platform_a, outbound_a, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, status_b, inbound_b) = TestPlatform::new_with_inbound(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); @@ -14,8 +15,8 @@ async fn unpair_clears_remote_peer_and_aborts_active_stream() { tokio::task::spawn_local(async move { runtime_a.run().await }); tokio::task::spawn_local(async move { runtime_b.run().await }); - spawn_forwarder(outbound_a, handle_b.clone()); - spawn_forwarder(outbound_b, handle_a.clone()); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect().unwrap(); From e85d75fb78fa47b5947d2fec23b1960e1ced2708 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 11:25:16 -0400 Subject: [PATCH 245/304] ql-runtime: RuntimeCommand <-> Command --- ql-runtime/src/command.rs | 4 ++-- ql-runtime/src/driver/mod.rs | 36 ++++++++++++++++----------------- ql-runtime/src/driver/state.rs | 4 ++-- ql-runtime/src/driver/test.rs | 2 +- ql-runtime/src/handle/mod.rs | 22 ++++++++++---------- ql-runtime/src/handle/reader.rs | 8 ++++---- ql-runtime/src/handle/writer.rs | 6 +++--- ql-runtime/src/lib.rs | 4 ++-- 8 files changed, 42 insertions(+), 44 deletions(-) diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 39c9ac84..c41ae8a8 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -3,7 +3,7 @@ use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamCloseCode, S use crate::{chunk_slot::ChunkSlotRx, QlStreamError, StreamReader}; -pub enum RuntimeCommand { +pub enum Command { BindPeer { peer: PeerBundle, }, @@ -34,7 +34,7 @@ pub enum RuntimeCommand { }, } -impl RuntimeCommand { +impl Command { #[cfg(feature = "log")] pub fn kind(&self) -> &'static str { match self { diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 8be61c93..0216de8d 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -21,7 +21,7 @@ use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; use crate::{ chunk_slot, - command::RuntimeCommand, + command::Command, handle::{QlStream, StreamReader, StreamWriter}, log, platform::{QlInbound, QlPlatform, QlTimer}, @@ -122,7 +122,7 @@ struct InFlightWrite { } enum DriverStep { - Command(RuntimeCommand), + Command(Command), Inbound(Vec), WriteCompleted { index: usize, success: bool }, TimerExpired, @@ -133,7 +133,7 @@ const STEP_COUNT: usize = 4; fn next_step( cx: &mut Context<'_>, - mut recv_future: Pin<&mut Recv<'_, RuntimeCommand>>, + mut recv_future: Pin<&mut Recv<'_, Command>>, mut inbound: Pin<&mut I>, mut timer: Pin<&mut T>, in_flight: &mut [InFlightWrite], @@ -160,7 +160,10 @@ where } Poll::Pending } - 3 => timer.as_mut().poll_wait(cx).map(|()| DriverStep::TimerExpired), + 3 => timer + .as_mut() + .poll_wait(cx) + .map(|()| DriverStep::TimerExpired), _ => unreachable!(), }; if poll.is_ready() { @@ -172,36 +175,31 @@ where } impl DriverState { - fn drive_command( - &mut self, - fsm: &mut QlFsm, - command: RuntimeCommand, - platform: &P, - ) { + fn drive_command(&mut self, fsm: &mut QlFsm, command: Command, platform: &P) { match command { - RuntimeCommand::BindPeer { peer } => { + Command::BindPeer { peer } => { log::info!("binding peer"); fsm.bind_peer(peer); } - RuntimeCommand::Connect => { + Command::Connect => { log::info!("starting IK connect"); if fsm.connect_ik(now(), platform).is_err() { log::warn!("IK connect ignored: no bound peer"); } } - RuntimeCommand::ArmPairing { token } => { + Command::ArmPairing { token } => { log::info!("arming inbound pairing"); fsm.arm_pairing(token); } - RuntimeCommand::DisarmPairing => { + Command::DisarmPairing => { log::info!("disarming inbound pairing"); fsm.disarm_pairing(); } - RuntimeCommand::StartPairing { token } => { + Command::StartPairing { token } => { log::info!(" starting XX pairing"); fsm.connect_xx(now(), token, platform); } - RuntimeCommand::OpenStream { + Command::OpenStream { route_id, request_reader, request_terminal, @@ -254,15 +252,15 @@ impl DriverState { drop(stream_ops); self.poll_stream(fsm, stream_id); } - RuntimeCommand::PollInbound { stream_id } => { + Command::PollInbound { stream_id } => { log::trace!("poll inbound requested: stream_id={stream_id}"); self.handle_inbound_readable(fsm, stream_id); } - RuntimeCommand::PollStream { stream_id } => { + Command::PollStream { stream_id } => { log::trace!("poll stream requested: stream_id={stream_id}"); self.poll_stream(fsm, stream_id); } - RuntimeCommand::CloseStream { + Command::CloseStream { stream_id, target, code, diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 0c93e781..f5bfdd85 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -5,13 +5,13 @@ use ql_wire::{CloseTarget, StreamId}; use crate::{ chunk_slot::{ChunkSlotRx, ChunkSlotTx, TrySendError}, - command::RuntimeCommand, + command::Command, QlStreamError, }; pub struct DriverState { pub streams: HashMap, - pub runtime_tx: async_channel::WeakSender, + pub runtime_tx: async_channel::WeakSender, pub max_concurrent_message_writes: usize, } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 4900e5f7..1f3caa47 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -153,7 +153,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { state.drive_command( &mut fsm, - RuntimeCommand::CloseStream { + Command::CloseStream { stream_id, target: CloseTarget::Origin, code: StreamCloseCode::CANCELLED, diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 171cd403..27047342 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -5,7 +5,7 @@ use ql_fsm::NoSessionError; use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamId}; pub use self::{reader::*, writer::*}; -use crate::{chunk_slot, command::RuntimeCommand}; +use crate::{chunk_slot, command::Command}; #[derive(Debug)] pub struct QlStream { @@ -17,33 +17,33 @@ pub struct QlStream { #[derive(Clone)] pub struct RuntimeHandle { - tx: async_channel::Sender, + tx: async_channel::Sender, } impl RuntimeHandle { /// binds the remote peer pub fn bind_peer(&self, peer: PeerBundle) { - self.send(RuntimeCommand::BindPeer { peer }); + self.send(Command::BindPeer { peer }); } /// starts an IK handshake with the bound peer pub fn connect(&self) { - self.send(RuntimeCommand::Connect); + self.send(Command::Connect); } /// arms acceptance of inbound xx pairings for a single token pub fn arm_pairing(&self, token: PairingToken) { - self.send(RuntimeCommand::ArmPairing { token }); + self.send(Command::ArmPairing { token }); } /// disarms inbound xx pairing pub fn disarm_pairing(&self) { - self.send(RuntimeCommand::DisarmPairing); + self.send(Command::DisarmPairing); } /// starts an outbound xx handshake using the supplied pairing token pub fn start_pairing(&self, token: PairingToken) { - self.send(RuntimeCommand::StartPairing { token }); + self.send(Command::StartPairing { token }); } /// opens a new stream on the active encrypted session @@ -52,7 +52,7 @@ impl RuntimeHandle { let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); let (start_tx, start_rx) = oneshot::channel(); - self.send(RuntimeCommand::OpenStream { + self.send(Command::OpenStream { route_id, request_reader, request_terminal: request_terminal_tx, @@ -85,17 +85,17 @@ impl RuntimeHandle { } impl RuntimeHandle { - pub(crate) fn new(tx: async_channel::Sender) -> Self { + pub(crate) fn new(tx: async_channel::Sender) -> Self { Self { tx } } #[inline] #[track_caller] - pub(crate) fn send(&self, cmd: RuntimeCommand) { + pub(crate) fn send(&self, cmd: Command) { self.tx.try_send(cmd).expect("runtime is alive"); } - pub(crate) fn try_send(&self, cmd: RuntimeCommand) -> bool { + pub(crate) fn try_send(&self, cmd: Command) -> bool { self.tx.try_send(cmd).is_ok() } } diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs index 467bc3d6..c3ccf353 100644 --- a/ql-runtime/src/handle/reader.rs +++ b/ql-runtime/src/handle/reader.rs @@ -8,7 +8,7 @@ use bytes::Bytes; use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, command::RuntimeCommand, log, QlStreamError, RuntimeHandle}; +use crate::{chunk_slot::ChunkSlotRx, command::Command, log, QlStreamError, RuntimeHandle}; pub struct StreamReader { stream_id: StreamId, @@ -79,7 +79,7 @@ impl StreamReader { self.target, bytes.len() ); - self.handle.try_send(RuntimeCommand::PollInbound { + self.handle.try_send(Command::PollInbound { stream_id: self.stream_id, }); return Poll::Ready(Ok(Some(bytes))); @@ -165,7 +165,7 @@ impl StreamReader { self.reader.take(); self.wait = None; self.terminal = TerminalState::Delivered; - self.handle.try_send(RuntimeCommand::CloseStream { + self.handle.try_send(Command::CloseStream { stream_id: self.stream_id, target: self.target, code, @@ -184,7 +184,7 @@ impl Drop for StreamReader { self.target, StreamCloseCode::CANCELLED ); - self.handle.try_send(RuntimeCommand::CloseStream { + self.handle.try_send(Command::CloseStream { stream_id: self.stream_id, target: self.target, code: StreamCloseCode::CANCELLED, diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs index 2d2700e6..572bd1f3 100644 --- a/ql-runtime/src/handle/writer.rs +++ b/ql-runtime/src/handle/writer.rs @@ -10,7 +10,7 @@ use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use crate::{ chunk_slot::{ChunkSlotTx, SendClosed}, - command::RuntimeCommand, + command::Command, log, QlStreamError, RuntimeHandle, }; @@ -143,7 +143,7 @@ impl StreamWriter { } fn poll_runtime(&self) { - self.handle.try_send(RuntimeCommand::PollStream { + self.handle.try_send(Command::PollStream { stream_id: self.stream_id, }); } @@ -175,7 +175,7 @@ impl StreamWriter { code ); self.wait = None; - self.handle.try_send(RuntimeCommand::CloseStream { + self.handle.try_send(Command::CloseStream { stream_id: self.stream_id, target: self.target, code, diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index ea965d64..8f3083de 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -39,8 +39,8 @@ pub struct Runtime

{ identity: QlIdentity, platform: P, config: RuntimeConfig, - rx: async_channel::Receiver, - tx: async_channel::WeakSender, + rx: async_channel::Receiver, + tx: async_channel::WeakSender, } pub fn new_runtime

( From 2096dfe1c5f00f8b6f9f18b0106cbea08a24b93b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 11:30:19 -0400 Subject: [PATCH 246/304] ql-runtime: remove stream_send_buffer_bytes config --- ql-runtime/src/lib.rs | 2 -- ql-runtime/src/tests/stream.rs | 2 -- 2 files changed, 4 deletions(-) diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index 8f3083de..a85388ad 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -21,7 +21,6 @@ use ql_wire::QlIdentity; #[derive(Debug, Clone, Copy)] pub struct RuntimeConfig { pub fsm: QlFsmConfig, - pub stream_send_buffer_bytes: usize, pub max_concurrent_message_writes: usize, } @@ -29,7 +28,6 @@ impl Default for RuntimeConfig { fn default() -> Self { Self { fsm: QlFsmConfig::default(), - stream_send_buffer_bytes: 16 * 1024, max_concurrent_message_writes: 4, } } diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index f83783b8..5991f064 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -440,7 +440,6 @@ async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { session_pending_ack_range_limit: 4 * 1024, ..default_runtime_config().fsm }, - stream_send_buffer_bytes: 4 * 1024 * 1024, ..default_runtime_config() }; let (mut pair, links) = TestPair::new_with_controlled_links( @@ -560,7 +559,6 @@ async fn reproducer_writer_stalls_after_reverse_path_impairment() { session_pending_ack_range_limit: 4 * 1024, ..default_runtime_config().fsm }, - stream_send_buffer_bytes: 4 * 1024 * 1024, ..default_runtime_config() }; let (mut pair, links) = TestPair::new_with_controlled_links( From f605812e33a89110c326d17b3d2b8dfaf790b2c4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 11:48:31 -0400 Subject: [PATCH 247/304] ql-runtime: logs for tests --- ql-runtime/Cargo.toml | 2 ++ ql-runtime/src/command.rs | 2 +- ql-runtime/src/log.rs | 4 ++-- ql-runtime/src/tests/mod.rs | 14 +++++++++++++- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index 50f7db19..1760928e 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -22,6 +22,8 @@ ql-rpc = { path = "../ql-rpc", optional = true } ql-wire = { path = "../ql-wire" } [dev-dependencies] +env_logger = "0.11" +log = "0.4" ql-wire = { path = "../ql-wire", features = ["test-utils"] } tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index c41ae8a8..af0a9d37 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -35,7 +35,7 @@ pub enum Command { } impl Command { - #[cfg(feature = "log")] + #[cfg(any(feature = "log", test))] pub fn kind(&self) -> &'static str { match self { Self::BindPeer { .. } => "BindPeer", diff --git a/ql-runtime/src/log.rs b/ql-runtime/src/log.rs index 943ff26a..ac2264c4 100644 --- a/ql-runtime/src/log.rs +++ b/ql-runtime/src/log.rs @@ -1,13 +1,13 @@ #![allow(unused_imports, unused_macros)] -#[cfg(feature = "log")] +#[cfg(any(feature = "log", test))] macro_rules! with_log { ($($tt:tt)*) => { $($tt)* }; } -#[cfg(not(feature = "log"))] +#[cfg(not(any(feature = "log", test)))] macro_rules! with_log { ($($tt:tt)*) => {}; } diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 9f48f6a5..e600e441 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -3,7 +3,7 @@ use std::{ pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Arc, Mutex, Once, }, task::{Context, Poll}, time::Duration, @@ -31,6 +31,17 @@ mod heartbeat; mod rpc; mod stream; +fn init_test_logger() { + static INIT: Once = Once::new(); + + INIT.call_once(|| { + let env = env_logger::Env::default().default_filter_or("ql_runtime=info"); + let mut builder = env_logger::Builder::from_env(env); + builder.is_test(true); + let _ = builder.try_init(); + }); +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct StatusEvent { peer: XID, @@ -580,6 +591,7 @@ async fn run_local_test_timeout(duration: Duration, future: F) where F: Future, { + init_test_logger(); let local = LocalSet::new(); let future = local.run_until(future); tokio::time::timeout(duration, future) From 81680a5b4c3e93fbbe382b20d949363e1e95ae62 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 11:55:22 -0400 Subject: [PATCH 248/304] ql-runtime: drain pending session work before shutdown --- ql-fsm/src/lib.rs | 7 +++++++ ql-fsm/src/session/mod.rs | 4 ++++ ql-runtime/src/driver/mod.rs | 19 +++++++++++-------- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index b8ac956e..2f22cdf8 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -285,6 +285,13 @@ impl QlFsm { fsm::next_deadline(self) } + pub fn has_shutdown_work(&self) -> bool { + self.state + .link + .connected() + .is_some_and(|state| state.session.has_shutdown_work()) + } + /// returns the next outbound record /// /// if `write_id` is `Some`, call `complete_write` exactly once diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 3978bbb7..a8675f4c 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -350,6 +350,10 @@ impl SessionFsm { .min() } + pub fn has_shutdown_work(&self) -> bool { + self.state.ack_tracker.ack_deadline().is_some() || !self.state.tracked_records.is_empty() + } + pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { match &self.state.phase { SessionPhase::Closing(close) => { diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 0216de8d..26504fab 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -53,7 +53,7 @@ impl Runtime

{ let inbound = platform.inbound(); let mut inbound = pin!(inbound); let recv_future = rx.recv(); - let mut recv_future = pin!(recv_future); + let mut recv_future = Some(pin!(recv_future)); let mut poll_cursor = 0usize; loop { @@ -66,7 +66,7 @@ impl Runtime

{ let step = poll_fn(|cx| { next_step( cx, - recv_future.as_mut(), + recv_future.as_mut().map(|future| future.as_mut()), inbound.as_mut(), timer.as_mut(), &mut in_flight, @@ -106,7 +106,8 @@ impl Runtime

{ "command channel closed: in_flight_writes={}", in_flight.len() ); - if in_flight.is_empty() { + recv_future = None; + if in_flight.is_empty() && !fsm.has_shutdown_work() { break; } } @@ -133,7 +134,7 @@ const STEP_COUNT: usize = 4; fn next_step( cx: &mut Context<'_>, - mut recv_future: Pin<&mut Recv<'_, Command>>, + mut recv_future: Option>>, mut inbound: Pin<&mut I>, mut timer: Pin<&mut T>, in_flight: &mut [InFlightWrite], @@ -147,10 +148,12 @@ where for offset in 0..STEP_COUNT { let step = (start + offset) % STEP_COUNT; let poll = match step { - 0 => recv_future - .as_mut() - .poll(cx) - .map(|res| res.map_or(DriverStep::Closed, DriverStep::Command)), + 0 => recv_future.as_mut().map_or(Poll::Pending, |recv_future| { + recv_future + .as_mut() + .poll(cx) + .map(|res| res.map_or(DriverStep::Closed, DriverStep::Command)) + }), 1 => inbound.as_mut().poll_recv(cx).map(DriverStep::Inbound), 2 => { for (index, write) in in_flight.iter_mut().enumerate() { From 1bf558da5c3a5bc2792b3b9e53c3238c1b389248 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 13:58:38 -0400 Subject: [PATCH 249/304] ql: fix clippy --- Cargo.lock | 200 +++++++++++++++++++++++++---- ql-fsm/src/fsm.rs | 2 +- ql-fsm/src/handshake/xx.rs | 6 +- ql-fsm/src/lib.rs | 2 +- ql-fsm/src/session/mod.rs | 2 +- ql-fsm/src/state.rs | 1 + ql-runtime/src/chunk_slot/mod.rs | 2 +- ql-runtime/src/chunk_slot/queue.rs | 6 +- ql-runtime/src/command.rs | 1 - ql-runtime/src/driver/mod.rs | 11 +- ql-runtime/src/log.rs | 15 +-- ql-runtime/src/tests/mod.rs | 38 +++++- ql-runtime/src/tests/stream.rs | 13 +- ql-wire/src/encrypted/mod.rs | 2 +- 14 files changed, 245 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fd232744..452d0150 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,7 +72,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb4e440d04be07da1f1bf44fb4495ebd58669372fe0cffa6e48595ac5bd88a3" dependencies = [ "android_log-sys", - "env_filter", + "env_filter 0.1.3", "log", ] @@ -85,6 +85,56 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + [[package]] name = "anyhow" version = "1.0.99" @@ -506,7 +556,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -520,6 +570,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -539,7 +595,7 @@ dependencies = [ "encode_unicode", "libc", "once_cell", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -857,6 +913,29 @@ dependencies = [ "regex", ] +[[package]] +name = "env_filter" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" +dependencies = [ + "anstream", + "anstyle", + "env_filter 1.0.1", + "jiff", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -870,7 +949,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -1104,7 +1183,7 @@ dependencies = [ "libc", "log", "rustversion", - "windows-link", + "windows-link 0.1.3", "windows-result", ] @@ -1458,6 +1537,12 @@ dependencies = [ "libc", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itertools" version = "0.11.0" @@ -1473,6 +1558,30 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jiff" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "jobserver" version = "0.1.33" @@ -1627,9 +1736,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "loom" @@ -1715,7 +1824,7 @@ checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -1744,7 +1853,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -1828,6 +1937,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "oneshot" version = "0.1.11" @@ -2051,6 +2166,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.2" @@ -2236,6 +2360,7 @@ version = "0.1.0" dependencies = [ "async-channel", "bytes", + "env_logger", "event-listener", "futures-lite", "log", @@ -2385,9 +2510,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", @@ -2397,9 +2522,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", @@ -2507,7 +2632,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -2608,18 +2733,28 @@ checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -2842,7 +2977,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -3071,6 +3206,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.18.1" @@ -3207,7 +3348,7 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-link", + "windows-link 0.1.3", "windows-result", "windows-strings", ] @@ -3240,13 +3381,19 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + [[package]] name = "windows-result" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -3255,7 +3402,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -3267,6 +3414,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link 0.2.1", +] + [[package]] name = "windows-targets" version = "0.52.6" diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 9114a19b..3eae2459 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -242,7 +242,7 @@ fn apply_session_closed(fsm: &mut QlFsm) { } } -pub(super) fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { +pub fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { now_secs.saturating_add(duration_to_secs(duration)) } diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index d15b934d..a08682e0 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -42,10 +42,10 @@ pub fn handle_xx1( } match fsm.state.armed_pairing_token { Some(expected) if expected.id(crypto) != message.header.pairing_id => { - return Err(ReceiveError::InvalidPairingId { + Err(ReceiveError::InvalidPairingId { expected: expected.id(crypto), actual: message.header.pairing_id, - }); + }) } Some(token) => { reset_connected_session_if_needed(fsm); @@ -67,7 +67,7 @@ pub fn handle_xx1( enqueue_handshake(fsm, QlHandshakeRecord::Xx2(outbound)); Ok(()) } - None => return Err(ReceiveError::NotPairingMode), + None => Err(ReceiveError::NotPairingMode), } } diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 2f22cdf8..2b2fa1e6 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -114,7 +114,7 @@ pub struct StreamOps<'a> { inner: session::StreamOps<'a, fsm::FsmEventEmitter<'a>>, } -impl<'a> StreamOps<'a> { +impl StreamOps<'_> { /// returns this stream's identifier pub fn stream_id(&self) -> StreamId { self.inner.stream_id() diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index a8675f4c..9415f4e5 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -80,7 +80,7 @@ pub enum SessionEvent { SessionClosed(SessionClose), } -pub(crate) trait EventSink { +pub trait EventSink { fn emit(&mut self, event: SessionEvent); } diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 79a8c5ee..4cb403b3 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -41,6 +41,7 @@ impl SessionTransport { } } +#[allow(clippy::large_enum_variant)] pub enum LinkState { Idle, IkInitiator(IkInitiatorState), diff --git a/ql-runtime/src/chunk_slot/mod.rs b/ql-runtime/src/chunk_slot/mod.rs index b15aee8a..d7c330f6 100644 --- a/ql-runtime/src/chunk_slot/mod.rs +++ b/ql-runtime/src/chunk_slot/mod.rs @@ -12,7 +12,7 @@ use self::queue::{PopError, PushError, Single}; mod queue; mod sync; -use sync::*; +use sync::Arc; /// creates a single-chunk handoff pair /// receiver-side partial reads keep the remainder locally diff --git a/ql-runtime/src/chunk_slot/queue.rs b/ql-runtime/src/chunk_slot/queue.rs index 2b596fa6..a0325efc 100644 --- a/ql-runtime/src/chunk_slot/queue.rs +++ b/ql-runtime/src/chunk_slot/queue.rs @@ -3,6 +3,7 @@ use core::mem::MaybeUninit; +#[allow(clippy::wildcard_imports)] use super::sync::*; const LOCKED: usize = 1 << 0; @@ -27,13 +28,14 @@ pub struct Single { slot: UnsafeCell>, } +#[allow(clippy::non_send_fields_in_send_ty)] unsafe impl Send for Single {} unsafe impl Sync for Single {} impl Single { /// Creates a new single-element queue. - pub fn new() -> Single { - Single { + pub fn new() -> Self { + Self { state: AtomicUsize::new(0), slot: UnsafeCell::new(MaybeUninit::uninit()), } diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index af0a9d37..1757d843 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -35,7 +35,6 @@ pub enum Command { } impl Command { - #[cfg(any(feature = "log", test))] pub fn kind(&self) -> &'static str { match self { Self::BindPeer { .. } => "BindPeer", diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 26504fab..725ef771 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -178,6 +178,7 @@ where } impl DriverState { + #[allow(clippy::too_many_lines)] fn drive_command(&mut self, fsm: &mut QlFsm, command: Command, platform: &P) { match command { Command::BindPeer { peer } => { @@ -339,8 +340,8 @@ impl DriverState { Event::WritableClosed(frame) => { self.handle_writable_closed(&frame); } - Event::SessionClosed(_close) => { - log::info!("session closed: frame={_close:?}"); + Event::SessionClosed(close) => { + log::info!("session closed: frame={close:?}"); for (_, mut stream) in self.streams.drain() { stream.fail_all(); } @@ -590,8 +591,10 @@ impl DriverState { break; } - let _len = bytes.len(); - log::trace!("writing stream bytes: stream_id={stream_id} len={_len}"); + log::trace!( + "writing stream bytes: stream_id={stream_id} len={}", + bytes.len() + ); let _ = writer.write(&mut bytes); } diff --git a/ql-runtime/src/log.rs b/ql-runtime/src/log.rs index ac2264c4..a0908f79 100644 --- a/ql-runtime/src/log.rs +++ b/ql-runtime/src/log.rs @@ -1,21 +1,17 @@ #![allow(unused_imports, unused_macros)] #[cfg(any(feature = "log", test))] -macro_rules! with_log { - ($($tt:tt)*) => { - $($tt)* +macro_rules! log { + ($level:ident, $($arg:tt)*) => { + ::log::log!(::log::Level::$level, $($arg)*) }; } #[cfg(not(any(feature = "log", test)))] -macro_rules! with_log { - ($($tt:tt)*) => {}; -} - macro_rules! log { ($level:ident, $($arg:tt)*) => { - $crate::log::with_log! { - ::log::log!(::log::Level::$level, $($arg)*) + if false { + let _ = format_args!($($arg)*); } }; } @@ -56,4 +52,3 @@ pub(crate) use info; pub(crate) use log; pub(crate) use trace; pub(crate) use warn_ as warn; -pub(crate) use with_log; diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index e600e441..5e27d5f7 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -104,7 +104,12 @@ struct TestInbound { } impl TestPlatform { - fn new() -> (Self, Receiver>, Sender>, Receiver) { + fn new() -> ( + Self, + Receiver>, + Sender>, + Receiver, + ) { Self::new_inner(None, None, Duration::ZERO, None) } @@ -118,19 +123,35 @@ impl TestPlatform { let (inbound_tx, inbound_rx) = async_channel::unbounded(); let (platform, outbound_rx, inbound_messages_tx, status_rx) = Self::new_inner(Some(inbound_tx), None, Duration::ZERO, None); - (platform, outbound_rx, inbound_messages_tx, status_rx, inbound_rx) + ( + platform, + outbound_rx, + inbound_messages_tx, + status_rx, + inbound_rx, + ) } fn new_with_session_write_failure( fail_encrypted_write_at: usize, - ) -> (Self, Receiver>, Sender>, Receiver) { + ) -> ( + Self, + Receiver>, + Sender>, + Receiver, + ) { Self::new_inner(None, Some(fail_encrypted_write_at), Duration::ZERO, None) } fn new_with_delayed_writes( delay: Duration, write_stats: WriteStats, - ) -> (Self, Receiver>, Sender>, Receiver) { + ) -> ( + Self, + Receiver>, + Sender>, + Receiver, + ) { Self::new_inner(None, None, delay, Some(write_stats)) } @@ -139,7 +160,12 @@ impl TestPlatform { fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, - ) -> (Self, Receiver>, Sender>, Receiver) { + ) -> ( + Self, + Receiver>, + Sender>, + Receiver, + ) { let (outbound, outbound_rx) = async_channel::unbounded(); let (inbound_messages_tx, inbound_messages_rx) = async_channel::unbounded(); let (status, status_rx) = async_channel::unbounded(); @@ -583,7 +609,7 @@ async fn run_local_test(future: F) where F: Future, { - run_local_test_timeout(Duration::from_secs(5), future).await + run_local_test_timeout(Duration::from_secs(5), future).await; } #[allow(clippy::future_not_send)] diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 5991f064..b85f0229 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -586,7 +586,7 @@ async fn reproducer_writer_stalls_after_reverse_path_impairment() { let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); let mut reader = stream.reader; - while let Some(_) = next_chunk(&mut reader).await.unwrap() {} + while next_chunk(&mut reader).await.unwrap().is_some() {} }); let recovery_links = links.clone(); @@ -640,7 +640,10 @@ async fn responder_drains_multiple_local_chunks_per_writable_wake() { let mut writer = inbound.writer; for _ in 0..chunk_count { - writer.write(Bytes::from(vec![0x5a; chunk_len])).await.unwrap(); + writer + .write(Bytes::from(vec![0x5a; chunk_len])) + .await + .unwrap(); } writer.finish().await.unwrap(); }); @@ -651,7 +654,11 @@ async fn responder_drains_multiple_local_chunks_per_writable_wake() { .open_stream(test_route_id()) .await .unwrap(); - stream.writer.write(Bytes::from_static(b"request")).await.unwrap(); + stream + .writer + .write(Bytes::from_static(b"request")) + .await + .unwrap(); stream.writer.finish().await.unwrap(); let received = read_all(stream.reader).await.unwrap(); diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 0c0e6338..1b838ddd 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -99,7 +99,7 @@ impl WireEncode for SessionFrame { #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] -pub(crate) enum SessionFrameKind { +pub enum SessionFrameKind { Ping = 1, Ack = 2, StreamData = 3, From d90c816bcc333b03ad949901de93de9ed639bd3b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 14:33:42 -0400 Subject: [PATCH 250/304] ql-runtime: fix runtime future send test --- ql-runtime/src/platform.rs | 2 -- ql-runtime/src/tests/mod.rs | 23 ++++++++--------------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 0dabfcd6..2dd974b1 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -10,8 +10,6 @@ use ql_wire::{PeerBundle, QlCrypto, XID}; use crate::QlStream; -pub type PlatformFuture<'a, T> = Pin + 'a>>; - pub trait QlTimer { fn set_deadline(self: Pin<&mut Self>, deadline: Option); fn poll_wait(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()>; diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 5e27d5f7..f2cdb757 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -20,9 +20,8 @@ use ql_wire::{ use tokio::{task::LocalSet, time::Sleep}; use crate::{ - new_runtime, - platform::{PlatformFuture, QlTimer}, - NoSessionError, QlFsmConfig, QlStream, QlStreamError, RuntimeConfig, RuntimeHandle, + new_runtime, platform::QlTimer, NoSessionError, QlFsmConfig, QlStream, QlStreamError, + RuntimeConfig, RuntimeHandle, }; mod handshake; @@ -401,7 +400,7 @@ impl QlKem for TestPlatform { impl crate::platform::QlPlatform for TestPlatform { type Timer = TokioTimer; - type WriteMessageFut<'a> = PlatformFuture<'a, bool>; + type WriteMessageFut<'a> = Pin + Send + 'a>>; type Inbound = TestInbound; fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_> { @@ -692,20 +691,14 @@ fn default_runtime_config() -> RuntimeConfig { } } -// runtime is send, though the Runtime::run future itself is not +// runtime is send, if platform is send #[test] fn runtime_is_send() { let config = default_runtime_config(); - let identity_a = test_identity(&SoftwareCrypto); - let (platform_a, _, _, _) = TestPlatform::new(); - let (runtime_a, _handle) = new_runtime(identity_a, platform_a, config); - std::thread::spawn(move || { - tokio::runtime::Builder::new_current_thread() - .enable_time() - .build() - .unwrap() - .block_on(runtime_a.run()); - }); + let identity = test_identity(&SoftwareCrypto); + let (platform, _, _, _) = TestPlatform::new(); + let (runtime, _handle) = new_runtime(identity, platform, config); + let _run: Box + Send> = Box::new(runtime.run()); } #[test] From dfef7291ce738d38d2e3c67bf3b9f91819ebabcc Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 18:06:29 -0400 Subject: [PATCH 251/304] ql-rpc: download support --- ql-rpc/src/codec.rs | 15 +++ ql-rpc/src/lib.rs | 2 +- ql-rpc/src/router/builder.rs | 31 ++++++ ql-rpc/src/router/download.rs | 136 ++++++++++++++++++++++++ ql-rpc/src/router/mod.rs | 2 + ql-rpc/src/rpc/download.rs | 75 +++++++++++++ ql-rpc/src/rpc/mod.rs | 4 + ql-rpc/src/rpc/request_with_progress.rs | 12 +-- ql-rpc/src/rpc/subscription.rs | 14 +-- ql-rpc/src/rpc/upload.rs | 14 +++ 10 files changed, 287 insertions(+), 18 deletions(-) create mode 100644 ql-rpc/src/router/download.rs create mode 100644 ql-rpc/src/rpc/download.rs create mode 100644 ql-rpc/src/rpc/upload.rs diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index 87375c4b..1cad921f 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -134,6 +134,21 @@ impl ChunkQueue { self.remaining } + pub fn pop_front(&mut self, max_len: usize) -> Option { + let front = self.chunks.front_mut()?; + let chunk = if max_len >= front.len() { + self.chunks.pop_front().expect("buffered chunk is present") + } else { + front.split_to(max_len) + }; + self.remaining -= chunk.len(); + Some(chunk) + } + + pub fn pop_front_chunk(&mut self) -> Option { + self.pop_front(usize::MAX) + } + pub fn try_take_part(&mut self) -> Result>, Error> { let Some(len) = self.peek_next_part_len()? else { return Ok(None); diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index ded00d2e..e81c3014 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -6,7 +6,7 @@ mod router; pub mod rpc; mod stream; -pub use codec::{ReadValueStep, RpcCodec, ValueReader}; +pub use codec::{ChunkQueue, ReadValueStep, RpcCodec, ValueReader}; pub use error::*; pub use router::*; pub use rpc::*; diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index 52e27926..cee40df0 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -1,12 +1,14 @@ use std::collections::HashMap; use super::{ + download::{handle_download_inner, DownloadHandler}, request::{handle_request_inner, RequestHandler}, subscription::{handle_subscription_inner, SubscriptionHandler}, LocalSpawn, LocalSpawner, RouteFn, Router, RouterConfig, RpcStream, SendSpawn, SendSpawner, Spawner, }; use crate::{ + download::Download as DownloadRpc, request::Request as RequestRpc, subscription::Subscription as SubscriptionRpc, RouteId, }; @@ -76,6 +78,19 @@ where }) } + pub fn download(self) -> Self + where + M: DownloadRpc + 'static, + S: DownloadHandler + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_download_inner::( + state, config, reader, writer, + )) + }) + } + pub fn subscription(self) -> Self where M: SubscriptionRpc + 'static, @@ -110,6 +125,22 @@ where }) } + pub fn download(self) -> Self + where + M: DownloadRpc + 'static, + M::Request: Send + 'static, + S: DownloadHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_download_inner::( + state, config, reader, writer, + )) + }) + } + pub fn subscription(self) -> Self where M: SubscriptionRpc + 'static, diff --git a/ql-rpc/src/router/download.rs b/ql-rpc/src/router/download.rs new file mode 100644 index 00000000..50922783 --- /dev/null +++ b/ql-rpc/src/router/download.rs @@ -0,0 +1,136 @@ +use std::marker::PhantomData; + +use bytes::Bytes; + +use super::{request::read_value_and_eof, RouterConfig}; +use crate::{ + codec, download::Download as DownloadRpc, finish_bytes, write_bytes, RpcCodec, RpcStream, + RpcRead, RpcWrite, StreamCloseCode, StreamError, +}; + +pub trait DownloadHandler +where + M: DownloadRpc, + St: RpcStream, +{ + fn handle( + self, + message: M::Request, + responder: DownloadResponder, + ); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct DownloadResponder +where + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +pub struct DownloadWriter +where + W: RpcWrite, +{ + writer: Option, +} + +impl DownloadResponder +where + T: RpcCodec, + W: RpcWrite, +{ + fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn respond(mut self, response_header: T) -> Result, W::Error> { + let mut writer = self.writer.take().expect("download writer exists"); + let mut encoded = Vec::new(); + codec::encode_value_part(&response_header, &mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + Ok(DownloadWriter { + writer: Some(writer), + }) + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DownloadResponder +where + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +impl DownloadWriter +where + W: RpcWrite, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + let writer = self.writer.as_mut().expect("download body writer exists"); + write_bytes(writer, bytes).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().expect("download body writer exists"); + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DownloadWriter +where + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +pub(super) async fn handle_download_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, +) where + M: DownloadRpc + 'static, + S: DownloadHandler + 'static, + St: RpcStream + 'static, +{ + let request = match read_value_and_eof::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + state.handle_transport_error(&error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + state.handle(request, DownloadResponder::new(writer)); +} diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 4e98ef0e..701a03b3 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -4,6 +4,7 @@ use crate::{RouteId, StreamCloseCode}; mod builder; mod config; +mod download; mod mode; mod request; mod subscription; @@ -11,6 +12,7 @@ mod subscription; pub use self::{ builder::RouterBuilder, config::RouterConfig, + download::{DownloadHandler, DownloadResponder, DownloadWriter}, mode::*, request::{RequestHandler, Response}, subscription::{SubscriptionHandler, SubscriptionResponder}, diff --git a/ql-rpc/src/rpc/download.rs b/ql-rpc/src/rpc/download.rs new file mode 100644 index 00000000..0e6e6556 --- /dev/null +++ b/ql-rpc/src/rpc/download.rs @@ -0,0 +1,75 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, ChunkQueue, CodecError, RouteId, RpcCodec}; + +/// rpc where the responder streams a large byte body +/// the caller sends a request +/// the responder sends a typed header for the body +/// the responder streams the raw response bytes +pub trait Download { + const ROUTE: RouteId; + type Error; + /// input needed to start the download + type Request: RpcCodec; + /// details about the body before bytes arrive + type ResponseHeader: RpcCodec; +} + +pub enum ReadStep { + NeedMore(ResponseHeaderReader), + ResponseHeader { + value: M::ResponseHeader, + bytes: ChunkQueue, + }, +} + +pub struct ResponseHeaderReader { + bytes: codec::ChunkQueue, + marker: PhantomData M>, +} + +impl Default for ResponseHeaderReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::new(), + marker: PhantomData, + } + } +} + +impl ResponseHeaderReader { + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); + self + } + + pub fn advance(mut self) -> Result, CodecError> { + let Some(mut body) = self.bytes.try_take_part().map_err(CodecError::Rpc)? else { + return Ok(ReadStep::NeedMore(self)); + }; + + let value = { + let value = M::ResponseHeader::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + value + }; + + Ok(ReadStep::ResponseHeader { + value, + bytes: self.bytes, + }) + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(request, out) +} + +pub fn encode_response_header( + response_header: &M::ResponseHeader, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + codec::encode_value_part(response_header, out) +} diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index d61a88a6..bd41ce49 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -1,9 +1,13 @@ +pub mod download; pub mod notification; pub mod request; pub mod request_with_progress; pub mod subscription; +pub mod upload; +pub use download::Download; pub use notification::Notification; pub use request::Request; pub use request_with_progress::RequestWithProgress; pub use subscription::Subscription; +pub use upload::Upload; diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress.rs index e24ddfd2..12711996 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress.rs @@ -28,18 +28,14 @@ pub struct ResponseReader { impl Default for ResponseReader { fn default() -> Self { - Self::new() - } -} - -impl ResponseReader { - pub fn new() -> Self { Self { bytes: codec::ChunkQueue::new(), marker: PhantomData, } } +} +impl ResponseReader { pub fn push(mut self, chunk: Bytes) -> Self { self.bytes.push(chunk); self @@ -138,7 +134,7 @@ mod tests { encode_progress::(&b"10%".to_vec(), &mut encoded); encode_response::(&b"done".to_vec(), &mut encoded); - let reader = match ResponseReader::::new() + let reader = match ResponseReader::::default() .push(Bytes::from(encoded)) .advance() .unwrap() @@ -160,7 +156,7 @@ mod tests { let mut encoded = Vec::new(); encode_response::(&b"done".to_vec(), &mut encoded); - match ResponseReader::::new() + match ResponseReader::::default() .push(Bytes::from(encoded)) .advance() .unwrap() diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 70e45280..6d8560d0 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -26,18 +26,14 @@ pub struct ResponseReader { impl Default for ResponseReader { fn default() -> Self { - Self::new() - } -} - -impl ResponseReader { - pub fn new() -> Self { Self { bytes: codec::ChunkQueue::new(), marker: PhantomData, } } +} +impl ResponseReader { pub fn push(mut self, chunk: Bytes) -> Self { self.bytes.push(chunk); self @@ -98,7 +94,7 @@ mod tests { encode_item::(&b"one".to_vec(), &mut encoded); encode_item::(&b"two".to_vec(), &mut encoded); - let reader = match ResponseReader::::new() + let reader = match ResponseReader::::default() .push(Bytes::from(encoded)) .advance() .unwrap() @@ -129,7 +125,7 @@ mod tests { let mut encoded = Vec::new(); encode_item::(&b"one".to_vec(), &mut encoded); - let reader = match ResponseReader::::new() + let reader = match ResponseReader::::default() .push(Bytes::from(encoded)) .advance() .unwrap() @@ -152,7 +148,7 @@ mod tests { let mut encoded = Vec::new(); encode_item::(&Vec::new(), &mut encoded); - match ResponseReader::::new() + match ResponseReader::::default() .push(Bytes::from(encoded)) .advance() .unwrap() diff --git a/ql-rpc/src/rpc/upload.rs b/ql-rpc/src/rpc/upload.rs new file mode 100644 index 00000000..b534cd6c --- /dev/null +++ b/ql-rpc/src/rpc/upload.rs @@ -0,0 +1,14 @@ +use crate::{RouteId, RpcCodec}; + +/// rpc where the caller streams a large byte body +/// the caller sends a request +/// the caller streams the raw request bytes +/// the responder sends a final typed response +pub trait Upload { + const ROUTE: RouteId; + type Error; + /// input needed to accept the upload + type Request: RpcCodec; + /// final status after all bytes are read + type Response: RpcCodec; +} From bc49b7d7f378e62ad9875f77a2a2c967568672fe Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 18:16:31 -0400 Subject: [PATCH 252/304] ql-runtime: download support --- ql-runtime/src/rpc/adapter.rs | 5 +- ql-runtime/src/rpc/download.rs | 93 ++++++++++++++++++++++++++++++++++ ql-runtime/src/rpc/mod.rs | 26 ++++++++-- ql-runtime/src/tests/rpc.rs | 75 ++++++++++++++++++++++++++- 4 files changed, 191 insertions(+), 8 deletions(-) create mode 100644 ql-runtime/src/rpc/download.rs diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs index b7f3c0da..13e7eb57 100644 --- a/ql-runtime/src/rpc/adapter.rs +++ b/ql-runtime/src/rpc/adapter.rs @@ -2,8 +2,9 @@ use std::task::{Context, Poll}; use bytes::Bytes; pub use ql_rpc::{ - LocalSpawn, RequestHandler, Response, RouteId, RouterConfig, SendSpawn, StreamCloseCode, - SubscriptionHandler, SubscriptionResponder, + DownloadHandler, DownloadResponder, DownloadWriter, LocalSpawn, RequestHandler, Response, + RouteId, RouterConfig, SendSpawn, StreamCloseCode, SubscriptionHandler, + SubscriptionResponder, }; use ql_rpc::{RpcRead, RpcStream, RpcWrite, StreamError}; use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs new file mode 100644 index 00000000..55fd9aa1 --- /dev/null +++ b/ql-runtime/src/rpc/download.rs @@ -0,0 +1,93 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use ql_rpc::{ + download::{Download as DownloadRpc, ReadStep}, + Error, +}; + +use super::RpcError; +use crate::StreamReader; + +pub struct DownloadCall { + pub(super) stream: StreamReader, + pub(super) reader: Option>, +} + +impl Unpin for DownloadCall where M: DownloadRpc {} + +impl DownloadCall +where + M: DownloadRpc, +{ + pub async fn into_reader( + mut self, + ) -> Result<(M::ResponseHeader, DownloadReader), RpcError> { + loop { + let reader = self.reader.take().expect("download reader is present"); + match reader.advance()? { + ReadStep::ResponseHeader { value, bytes } => { + return Ok(( + value, + DownloadReader { + buffered: bytes, + stream: self.stream, + }, + )); + } + ReadStep::NeedMore(next) => { + self.reader = Some(next); + } + } + + match poll_fn(|cx| self.stream.poll_read_chunk(cx)).await? { + Some(chunk) => { + let reader = self.reader.take().expect("download reader is present"); + self.reader = Some(reader.push(chunk)); + } + None => return Err(Error::Truncated.into()), + } + } + } +} + +pub struct DownloadReader { + buffered: ql_rpc::ChunkQueue, + stream: StreamReader, +} + +impl DownloadReader { + pub fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, crate::QlStreamError>> { + if let Some(chunk) = self.buffered.pop_front(max_len) { + return Poll::Ready(Ok(Some(chunk))); + } + + self.stream.poll_read(max_len, cx) + } + + pub fn poll_read_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, crate::QlStreamError>> { + self.poll_read(usize::MAX, cx) + } + + pub async fn read(&mut self, max_len: usize) -> Result, crate::QlStreamError> { + poll_fn(|cx| self.poll_read(max_len, cx)).await + } + + pub async fn read_chunk(&mut self) -> Result, crate::QlStreamError> { + self.read(usize::MAX).await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.stream.close(code); + } +} diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 81ee69a0..9661256f 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -1,4 +1,5 @@ mod adapter; +mod download; mod error; mod request_with_progress; mod subscription; @@ -7,6 +8,7 @@ use std::future::poll_fn; use bytes::Bytes; use ql_rpc::{ + download::{self as rpc_download, Download as DownloadRpc}, notification::{self, Notification}, request::{self, Request as RequestRpc}, request_with_progress::{self as rpc_request_with_progress, RequestWithProgress}, @@ -14,7 +16,7 @@ use ql_rpc::{ Error, ReadValueStep, RpcCodec, ValueReader, }; -pub use self::{adapter::*, error::*, request_with_progress::*, subscription::*}; +pub use self::{adapter::*, download::*, error::*, request_with_progress::*, subscription::*}; use crate::{StreamReader, RuntimeHandle}; #[derive(Clone)] @@ -61,7 +63,23 @@ impl RpcHandle { let response = self.start_request(M::ROUTE, payload).await?; Ok(Subscription { stream: response, - reader: Some(rpc_subscription::ResponseReader::new()), + reader: Some(rpc_subscription::ResponseReader::default()), + }) + } + + pub async fn download( + &self, + request: &M::Request, + ) -> Result, RpcError> + where + M: DownloadRpc, + { + let mut payload = Vec::new(); + rpc_download::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(DownloadCall { + stream: response, + reader: Some(rpc_download::ResponseHeaderReader::default()), }) } @@ -77,7 +95,7 @@ impl RpcHandle { let response = self.start_request(M::ROUTE, payload).await?; Ok(ProgressCall { stream: response, - reader: Some(rpc_request_with_progress::ResponseReader::new()), + reader: Some(rpc_request_with_progress::ResponseReader::default()), terminal: None, }) } @@ -101,7 +119,7 @@ async fn read_value(mut reader: StreamReader) -> Result where T: RpcCodec, { - let mut value_reader = ValueReader::::new(); + let mut value_reader = ValueReader::::default(); loop { match value_reader.advance().map_err(RpcError::from)? { diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index f518f7f2..a143981c 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -8,7 +8,9 @@ use std::{ use bytes::Bytes; use futures_lite::StreamExt; -use ql_rpc::{Response, RouteId, StreamCloseCode, SubscriptionResponder}; +use ql_rpc::{ + DownloadResponder, DownloadWriter, Response, RouteId, StreamCloseCode, SubscriptionResponder, +}; use ql_wire::RouteId as WireRouteId; use super::*; @@ -44,6 +46,15 @@ impl ql_rpc::request_with_progress::RequestWithProgress for Download { type Response = Vec; } +struct BlobDownload; + +impl ql_rpc::download::Download for BlobDownload { + const ROUTE: RouteId = RouteId::from_u32(54); + type Error = core::convert::Infallible; + type Request = Vec; + type ResponseHeader = Vec; +} + fn assert_send(value: T) -> T { value } @@ -379,12 +390,72 @@ async fn rpc_request_with_progress_supports_progress_then_await() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn rpc_router_handles_download() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl crate::rpc::DownloadHandler for RouterState { + fn handle( + self, + request: Vec, + responder: DownloadResponder, StreamWriter>, + ) { + let seen = self.seen.clone(); + tokio::task::spawn_local(async move { + seen.borrow_mut().push(request); + let mut writer: DownloadWriter = + responder.respond(b"image/png".to_vec()).await.unwrap(); + writer.send(Bytes::from_static(b"abc")).await.unwrap(); + writer.send(Bytes::from_static(b"def")).await.unwrap(); + writer.finish().await.unwrap(); + }); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) + .download::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await; + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let download = rpc.download::(&b"logo".to_vec()).await.unwrap(); + let (header, mut reader) = download.into_reader().await.unwrap(); + assert_eq!(header, b"image/png".to_vec()); + assert_eq!(reader.read_chunk().await.unwrap(), Some(Bytes::from_static(b"abc"))); + assert_eq!(reader.read_chunk().await.unwrap(), Some(Bytes::from_static(b"def"))); + assert_eq!(reader.read_chunk().await.unwrap(), None); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + async fn read_rpc_value(mut reader: crate::StreamReader) -> T where T: ql_rpc::RpcCodec, T::Error: std::fmt::Debug, { - let mut value_reader = ql_rpc::ValueReader::::new(); + let mut value_reader = ql_rpc::ValueReader::::default(); loop { match value_reader.advance().unwrap() { From e0343dcbcd2d087fd4785be03a1276d33b90ffa2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 18:37:45 -0400 Subject: [PATCH 253/304] ql-rpc: centralize rpc impl --- ql-rpc/src/error.rs | 50 +++++++++++++++ ql-rpc/src/rpc/subscription.rs | 110 +++++++++++++++++++++++++++------ 2 files changed, 142 insertions(+), 18 deletions(-) diff --git a/ql-rpc/src/error.rs b/ql-rpc/src/error.rs index c4d7d6d9..7404a22e 100644 --- a/ql-rpc/src/error.rs +++ b/ql-rpc/src/error.rs @@ -60,3 +60,53 @@ impl From for CodecError { Self::Rpc(error) } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CallError { + Protocol(Error), + Codec(C), + Transport(T), +} + +impl std::fmt::Display for CallError +where + C: std::fmt::Display, + T: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Protocol(error) => write!(f, "{error}"), + Self::Codec(error) => write!(f, "{error}"), + Self::Transport(error) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for CallError +where + C: std::error::Error + 'static, + T: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + CallError::Protocol(error) => Some(error), + CallError::Codec(error) => Some(error), + CallError::Transport(error) => Some(error), + } + } +} + +impl From for CallError { + fn from(error: Error) -> Self { + Self::Protocol(error) + } +} + +impl From> for CallError { + fn from(error: CodecError) -> Self { + match error { + CodecError::Rpc(error) => Self::Protocol(error), + CodecError::Codec(error) => Self::Codec(error), + } + } +} diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs index 6d8560d0..1ad68570 100644 --- a/ql-rpc/src/rpc/subscription.rs +++ b/ql-rpc/src/rpc/subscription.rs @@ -1,8 +1,12 @@ -use std::marker::PhantomData; +use std::{ + future::poll_fn, + marker::PhantomData, + task::{Context, Poll}, +}; use bytes::{BufMut, Bytes}; -use crate::{codec, CodecError, RouteId, RpcCodec}; +use crate::{codec, CallError, CodecError, RouteId, RpcCodec, RpcRead}; pub trait Subscription { const ROUTE: RouteId; @@ -11,17 +15,85 @@ pub trait Subscription { type Event: RpcCodec; } -pub enum ReadStep { - NeedMore(ResponseReader), - Item { - value: M::Event, - next: ResponseReader, - }, +pub fn encode_request( + request: &M::Request, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + codec::encode_value_part(request, out) } -pub struct ResponseReader { - bytes: codec::ChunkQueue, - marker: PhantomData M>, +pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(item, out) +} + +pub struct SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + stream: R, + reader: Option>, +} + +impl SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream, + reader: Some(ResponseReader::default()), + } + } + + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| self.poll_next_event(cx)).await + } + + pub fn poll_next_event( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + loop { + let Some(reader) = self.reader.take() else { + return Poll::Ready(None); + }; + + let reader = match reader.advance() { + Ok(ReadStep::Item { value, next }) => { + self.reader = Some(next); + return Poll::Ready(Some(Ok(value))); + } + Ok(ReadStep::NeedMore(next)) => next, + Err(error) => return Poll::Ready(Some(Err(error.into()))), + }; + + match self.stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + self.reader = Some(reader.push(chunk)); + } + Poll::Ready(Ok(None)) => { + if reader.is_empty() { + return Poll::Ready(None); + } + return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); + } + Poll::Ready(Err(error)) => { + self.reader = None; + return Poll::Ready(Some(Err(CallError::Transport(error)))); + } + Poll::Pending => { + self.reader = Some(reader); + return Poll::Pending; + } + } + } + } + + pub fn into_inner(self) -> R { + self.stream + } } impl Default for ResponseReader { @@ -61,15 +133,17 @@ impl ResponseReader { } } -pub fn encode_request( - request: &M::Request, - out: &mut (impl BufMut + AsMut<[u8]>), -) { - codec::encode_value_part(request, out) +pub enum ReadStep { + NeedMore(ResponseReader), + Item { + value: M::Event, + next: ResponseReader, + }, } -pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { - codec::encode_value_part(item, out) +pub struct ResponseReader { + bytes: codec::ChunkQueue, + marker: PhantomData M>, } #[cfg(test)] From 33962a11fb7c7cfdf35638950ea91627cdd3c04f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 18:38:48 -0400 Subject: [PATCH 254/304] ql-runtime: use rpc SubscriptionCall --- ql-runtime/src/rpc/error.rs | 10 ++++++ ql-runtime/src/rpc/mod.rs | 3 +- ql-runtime/src/rpc/subscription.rs | 49 ++++-------------------------- 3 files changed, 17 insertions(+), 45 deletions(-) diff --git a/ql-runtime/src/rpc/error.rs b/ql-runtime/src/rpc/error.rs index 1e3e03e9..4cc9e176 100644 --- a/ql-runtime/src/rpc/error.rs +++ b/ql-runtime/src/rpc/error.rs @@ -40,6 +40,16 @@ impl From> for RpcError { } } +impl From> for RpcError { + fn from(error: ql_rpc::CallError) -> Self { + match error { + ql_rpc::CallError::Protocol(error) => Self::Protocol(error), + ql_rpc::CallError::Codec(error) => Self::Codec(error), + ql_rpc::CallError::Transport(error) => error.into(), + } + } +} + impl std::fmt::Display for RpcError where E: std::fmt::Display, diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 9661256f..021d2eec 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -62,8 +62,7 @@ impl RpcHandle { rpc_subscription::encode_request::(request, &mut payload); let response = self.start_request(M::ROUTE, payload).await?; Ok(Subscription { - stream: response, - reader: Some(rpc_subscription::ResponseReader::default()), + inner: rpc_subscription::SubscriptionCall::new(response), }) } diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 00792b21..0dfd807e 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -4,17 +4,13 @@ use std::{ }; use futures_lite::{future::poll_fn, Stream}; -use ql_rpc::{ - subscription::{ReadStep, Subscription as SubscriptionRpc}, - Error, -}; +use ql_rpc::subscription::Subscription as SubscriptionRpc; use super::RpcError; use crate::StreamReader; pub struct Subscription { - pub(super) stream: StreamReader, - pub(super) reader: Option>, + pub(super) inner: ql_rpc::subscription::SubscriptionCall, } impl Unpin for Subscription where M: SubscriptionRpc {} @@ -35,42 +31,9 @@ where type Item = Result>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - loop { - let Some(reader) = this.reader.take() else { - return Poll::Ready(None); - }; - - match reader.advance() { - Ok(ReadStep::Item { value, next }) => { - this.reader = Some(next); - return Poll::Ready(Some(Ok(value))); - } - Ok(ReadStep::NeedMore(next)) => { - this.reader = Some(next); - } - Err(error) => return Poll::Ready(Some(Err(error.into()))), - } - - match this.stream.poll_read_chunk(cx) { - Poll::Ready(Ok(Some(chunk))) => { - let reader = this.reader.take().expect("subscription reader is present"); - this.reader = Some(reader.push(chunk)); - } - Poll::Ready(Ok(None)) => { - let reader = this.reader.take().expect("subscription reader is present"); - if reader.is_empty() { - return Poll::Ready(None); - } - return Poll::Ready(Some(Err(Error::Truncated.into()))); - } - Poll::Ready(Err(error)) => { - this.reader = None; - return Poll::Ready(Some(Err(error.into()))); - } - Poll::Pending => return Poll::Pending, - } - } + self.get_mut() + .inner + .poll_next_event(cx) + .map(|item| item.map(|result| Ok(result?))) } } From 8233ba933ad971208affff6ca1fe4be5866be99a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 18:41:58 -0400 Subject: [PATCH 255/304] ql-rpc: centralize downloadcall --- ql-rpc/src/rpc/download.rs | 129 +++++++++++++++++++++++++++++++++---- 1 file changed, 116 insertions(+), 13 deletions(-) diff --git a/ql-rpc/src/rpc/download.rs b/ql-rpc/src/rpc/download.rs index 0e6e6556..272ceaa7 100644 --- a/ql-rpc/src/rpc/download.rs +++ b/ql-rpc/src/rpc/download.rs @@ -1,8 +1,12 @@ -use std::marker::PhantomData; +use std::{ + future::poll_fn, + marker::PhantomData, + task::{Context, Poll}, +}; use bytes::{BufMut, Bytes}; -use crate::{codec, ChunkQueue, CodecError, RouteId, RpcCodec}; +use crate::{codec, CallError, ChunkQueue, CodecError, RouteId, RpcCodec, RpcRead}; /// rpc where the responder streams a large byte body /// the caller sends a request @@ -17,6 +21,116 @@ pub trait Download { type ResponseHeader: RpcCodec; } +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(request, out) +} + +pub fn encode_response_header( + response_header: &M::ResponseHeader, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + codec::encode_value_part(response_header, out) +} + +pub struct DownloadCall +where + M: Download, + R: RpcRead, +{ + stream: R, + reader: Option>, +} + +pub struct DownloadReader +where + R: RpcRead, +{ + buffered: ChunkQueue, + stream: R, +} + +impl DownloadCall +where + M: Download, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream, + reader: Some(ResponseHeaderReader::default()), + } + } + + pub async fn into_reader( + mut self, + ) -> Result<(M::ResponseHeader, DownloadReader), CallError> { + loop { + let reader = self.reader.take().expect("download reader is present"); + let reader = match reader.advance() { + Ok(ReadStep::ResponseHeader { value, bytes }) => { + return Ok(( + value, + DownloadReader { + buffered: bytes, + stream: self.stream, + }, + )); + } + Ok(ReadStep::NeedMore(next)) => next, + Err(error) => return Err(error.into()), + }; + + match poll_fn(|cx| self.stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader = Some(reader.push(chunk)); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(CallError::Transport(error)), + } + } + } + + pub fn into_inner(self) -> R { + self.stream + } +} + +impl DownloadReader +where + R: RpcRead, +{ + pub fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, R::Error>> { + if let Some(chunk) = self.buffered.pop_front(max_len) { + return Poll::Ready(Ok(Some(chunk))); + } + + self.stream.poll_read(max_len, cx) + } + + pub fn poll_read_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, R::Error>> { + self.poll_read(usize::MAX, cx) + } + + pub async fn read(&mut self, max_len: usize) -> Result, R::Error> { + poll_fn(|cx| self.poll_read(max_len, cx)).await + } + + pub async fn read_chunk(&mut self) -> Result, R::Error> { + self.read(usize::MAX).await + } + + pub fn into_inner(self) -> R { + self.stream + } +} + pub enum ReadStep { NeedMore(ResponseHeaderReader), ResponseHeader { @@ -62,14 +176,3 @@ impl ResponseHeaderReader { }) } } - -pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { - codec::encode_value_part(request, out) -} - -pub fn encode_response_header( - response_header: &M::ResponseHeader, - out: &mut (impl BufMut + AsMut<[u8]>), -) { - codec::encode_value_part(response_header, out) -} From 1cc2c553ada49e2d1aa3c83105e688d8dc3a8aff Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 18:42:54 -0400 Subject: [PATCH 256/304] ql-runtime: use rpc DownloadCall --- ql-runtime/src/rpc/download.rs | 76 +++++----------------------------- ql-runtime/src/rpc/mod.rs | 3 +- 2 files changed, 12 insertions(+), 67 deletions(-) diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs index 55fd9aa1..ecca8c63 100644 --- a/ql-runtime/src/rpc/download.rs +++ b/ql-runtime/src/rpc/download.rs @@ -1,93 +1,39 @@ -use std::{ - future::poll_fn, - task::{Context, Poll}, -}; - use bytes::Bytes; -use ql_rpc::{ - download::{Download as DownloadRpc, ReadStep}, - Error, -}; +use ql_rpc::download::Download as DownloadRpc; use super::RpcError; use crate::StreamReader; pub struct DownloadCall { - pub(super) stream: StreamReader, - pub(super) reader: Option>, + pub(super) inner: ql_rpc::download::DownloadCall, } -impl Unpin for DownloadCall where M: DownloadRpc {} +pub struct DownloadReader { + pub(super) inner: ql_rpc::download::DownloadReader, +} impl DownloadCall where M: DownloadRpc, { pub async fn into_reader( - mut self, + self, ) -> Result<(M::ResponseHeader, DownloadReader), RpcError> { - loop { - let reader = self.reader.take().expect("download reader is present"); - match reader.advance()? { - ReadStep::ResponseHeader { value, bytes } => { - return Ok(( - value, - DownloadReader { - buffered: bytes, - stream: self.stream, - }, - )); - } - ReadStep::NeedMore(next) => { - self.reader = Some(next); - } - } - - match poll_fn(|cx| self.stream.poll_read_chunk(cx)).await? { - Some(chunk) => { - let reader = self.reader.take().expect("download reader is present"); - self.reader = Some(reader.push(chunk)); - } - None => return Err(Error::Truncated.into()), - } - } + let (header, inner) = self.inner.into_reader().await?; + Ok((header, DownloadReader { inner })) } } -pub struct DownloadReader { - buffered: ql_rpc::ChunkQueue, - stream: StreamReader, -} - impl DownloadReader { - pub fn poll_read( - &mut self, - max_len: usize, - cx: &mut Context<'_>, - ) -> Poll, crate::QlStreamError>> { - if let Some(chunk) = self.buffered.pop_front(max_len) { - return Poll::Ready(Ok(Some(chunk))); - } - - self.stream.poll_read(max_len, cx) - } - - pub fn poll_read_chunk( - &mut self, - cx: &mut Context<'_>, - ) -> Poll, crate::QlStreamError>> { - self.poll_read(usize::MAX, cx) - } - pub async fn read(&mut self, max_len: usize) -> Result, crate::QlStreamError> { - poll_fn(|cx| self.poll_read(max_len, cx)).await + self.inner.read(max_len).await } pub async fn read_chunk(&mut self) -> Result, crate::QlStreamError> { - self.read(usize::MAX).await + self.inner.read_chunk().await } pub fn close(self, code: ql_wire::StreamCloseCode) { - self.stream.close(code); + self.inner.into_inner().close(code); } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 021d2eec..484fcd18 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -77,8 +77,7 @@ impl RpcHandle { rpc_download::encode_request::(request, &mut payload); let response = self.start_request(M::ROUTE, payload).await?; Ok(DownloadCall { - stream: response, - reader: Some(rpc_download::ResponseHeaderReader::default()), + inner: rpc_download::DownloadCall::new(response), }) } From 38c29831dc3568b91501c3449580ebd8c2750f36 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 18:58:16 -0400 Subject: [PATCH 257/304] ql-rpc: refactor internals --- ql-rpc/src/chunk_queue.rs | 208 +++++++++++++++ ql-rpc/src/codec.rs | 237 ++--------------- ql-rpc/src/lib.rs | 28 +- ql-rpc/src/route_id.rs | 19 ++ ql-rpc/src/router/builder.rs | 6 +- ql-rpc/src/router/mod.rs | 9 +- .../rpc/{download.rs => download/client.rs} | 76 +----- ql-rpc/src/rpc/download/codec.rs | 62 +++++ ql-rpc/src/rpc/download/mod.rs | 22 ++ .../download.rs => rpc/download/server.rs} | 16 +- ql-rpc/src/rpc/mod.rs | 72 ++++++ .../codec.rs} | 8 +- ql-rpc/src/rpc/notification/mod.rs | 11 + ql-rpc/src/rpc/request.rs | 18 -- ql-rpc/src/rpc/request/client.rs | 24 ++ ql-rpc/src/rpc/request/mod.rs | 14 + .../request.rs => rpc/request/server.rs} | 55 ++-- .../codec.rs} | 16 +- ql-rpc/src/rpc/request_with_progress/mod.rs | 15 ++ ql-rpc/src/rpc/subscription.rs | 239 ------------------ ql-rpc/src/rpc/subscription/client.rs | 77 ++++++ ql-rpc/src/rpc/subscription/codec.rs | 66 +++++ ql-rpc/src/rpc/subscription/mod.rs | 16 ++ .../subscription/server.rs} | 8 +- ql-rpc/src/rpc/{upload.rs => upload/mod.rs} | 0 ql-runtime/src/rpc/mod.rs | 26 +- ql-runtime/src/tests/rpc.rs | 216 ++++------------ 27 files changed, 724 insertions(+), 840 deletions(-) create mode 100644 ql-rpc/src/chunk_queue.rs create mode 100644 ql-rpc/src/route_id.rs rename ql-rpc/src/rpc/{download.rs => download/client.rs} (55%) create mode 100644 ql-rpc/src/rpc/download/codec.rs create mode 100644 ql-rpc/src/rpc/download/mod.rs rename ql-rpc/src/{router/download.rs => rpc/download/server.rs} (85%) rename ql-rpc/src/rpc/{notification.rs => notification/codec.rs} (50%) create mode 100644 ql-rpc/src/rpc/notification/mod.rs delete mode 100644 ql-rpc/src/rpc/request.rs create mode 100644 ql-rpc/src/rpc/request/client.rs create mode 100644 ql-rpc/src/rpc/request/mod.rs rename ql-rpc/src/{router/request.rs => rpc/request/server.rs} (60%) rename ql-rpc/src/rpc/{request_with_progress.rs => request_with_progress/codec.rs} (90%) create mode 100644 ql-rpc/src/rpc/request_with_progress/mod.rs delete mode 100644 ql-rpc/src/rpc/subscription.rs create mode 100644 ql-rpc/src/rpc/subscription/client.rs create mode 100644 ql-rpc/src/rpc/subscription/codec.rs create mode 100644 ql-rpc/src/rpc/subscription/mod.rs rename ql-rpc/src/{router/subscription.rs => rpc/subscription/server.rs} (90%) rename ql-rpc/src/rpc/{upload.rs => upload/mod.rs} (100%) diff --git a/ql-rpc/src/chunk_queue.rs b/ql-rpc/src/chunk_queue.rs new file mode 100644 index 00000000..d26429a9 --- /dev/null +++ b/ql-rpc/src/chunk_queue.rs @@ -0,0 +1,208 @@ +use std::collections::VecDeque; + +use bytes::{Buf, Bytes}; + +use crate::Error; + +const LENGTH_SIZE: usize = 8; + +#[derive(Debug, Default)] +pub struct ChunkQueue { + chunks: VecDeque, + remaining: usize, +} + +impl ChunkQueue { + pub fn push(&mut self, chunk: Bytes) { + if chunk.is_empty() { + return; + } + self.remaining += chunk.len(); + self.chunks.push_back(chunk); + } + + pub fn remaining(&self) -> usize { + self.remaining + } + + pub fn pop_front(&mut self, max_len: usize) -> Option { + let front = self.chunks.front_mut()?; + let chunk = if max_len >= front.len() { + self.chunks.pop_front().expect("buffered chunk is present") + } else { + front.split_to(max_len) + }; + self.remaining -= chunk.len(); + Some(chunk) + } + + pub fn pop_front_chunk(&mut self) -> Option { + self.pop_front(usize::MAX) + } + + pub fn try_take_part(&mut self) -> Result>, Error> { + let Some(len) = self.peek_next_part_len()? else { + return Ok(None); + }; + self.advance(LENGTH_SIZE); + Ok(Some(DrainBuf::new(self, len))) + } + + pub fn try_take_tagged_part(&mut self) -> Result)>, Error> { + let mut bytes = self.peek(); + let Ok(kind) = bytes.try_get_u8() else { + return Ok(None); + }; + let Some(len) = read_next_part_len(&mut bytes)? else { + return Ok(None); + }; + + self.advance(1 + LENGTH_SIZE); + Ok(Some((kind, DrainBuf::new(self, len)))) + } + + fn peek_next_part_len(&self) -> Result, Error> { + let mut bytes = self.peek(); + read_next_part_len(&mut bytes) + } + + fn peek(&self) -> ChunkQueuePeek<'_> { + ChunkQueuePeek { + chunks: &self.chunks, + chunk_index: 0, + chunk_offset: 0, + remaining: self.remaining, + } + } + + fn front_chunk(&self, limit: usize) -> &[u8] { + let Some(chunk) = self.chunks.front() else { + return &[]; + }; + &chunk[..chunk.len().min(limit)] + } + + fn advance_inner(&mut self, mut cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.remaining -= cnt; + while cnt > 0 { + let front = self.chunks.front_mut().expect("buffered data present"); + let consumed = cnt.min(front.len()); + front.advance(consumed); + cnt -= consumed; + if front.is_empty() { + self.chunks.pop_front(); + } + } + } +} + +struct ChunkQueuePeek<'a> { + chunks: &'a VecDeque, + chunk_index: usize, + chunk_offset: usize, + remaining: usize, +} + +impl Buf for ChunkQueuePeek<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + if self.remaining == 0 { + return &[]; + } + + let Some(chunk) = self.chunks.get(self.chunk_index) else { + return &[]; + }; + &chunk[self.chunk_offset..] + } + + fn advance(&mut self, mut cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.remaining -= cnt; + + while cnt > 0 { + let chunk = self + .chunks + .get(self.chunk_index) + .expect("buffered data present"); + let available = chunk.len() - self.chunk_offset; + let step = cnt.min(available); + self.chunk_offset += step; + cnt -= step; + if self.chunk_offset == chunk.len() { + self.chunk_index += 1; + self.chunk_offset = 0; + } + } + } +} + +impl Buf for ChunkQueue { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.front_chunk(self.remaining) + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.advance_inner(cnt); + } +} + +pub struct DrainBuf<'a> { + bytes: &'a mut ChunkQueue, + remaining: usize, +} + +impl<'a> DrainBuf<'a> { + pub fn new(bytes: &'a mut ChunkQueue, len: usize) -> Self { + debug_assert!(bytes.remaining() >= len); + Self { + bytes, + remaining: len, + } + } +} + +impl Buf for DrainBuf<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.bytes.front_chunk(self.remaining) + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining(), "advanced past payload boundary"); + self.bytes.advance_inner(cnt); + self.remaining -= cnt; + } +} + +impl Drop for DrainBuf<'_> { + fn drop(&mut self) { + if self.remaining > 0 { + self.bytes.advance_inner(self.remaining); + self.remaining = 0; + } + } +} + +fn read_next_part_len(bytes: &mut B) -> Result, Error> { + let Ok(len) = bytes.try_get_u64_le() else { + return Ok(None); + }; + let len: usize = len.try_into().map_err(|_| Error::LengthOverflow)?; + if bytes.remaining() < len { + return Ok(None); + } + Ok(Some(len)) +} diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index 1cad921f..231dfee0 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -1,7 +1,8 @@ -use std::{collections::VecDeque, convert::Infallible, marker::PhantomData, str::Utf8Error}; +use std::{convert::Infallible, marker::PhantomData, str::Utf8Error}; use bytes::{Buf, BufMut, Bytes}; +pub use crate::chunk_queue::ChunkQueue; use crate::{CodecError, Error}; pub trait RpcCodec: Sized { @@ -67,30 +68,27 @@ pub fn encode_value_part>(value: &T, out: & backpatch_length(out, payload_start); } -pub enum ReadValueStep { - NeedMore(ValueReader), - Value(T), -} - -pub struct ValueReader { +/// reads one length-delimited rpc value from buffered byte chunks +pub struct FramedValueReader { bytes: ChunkQueue, marker: PhantomData T>, } -impl Default for ValueReader { - fn default() -> Self { - Self::new() - } +pub enum ReadValueStep { + NeedMore(FramedValueReader), + Value(T), } -impl ValueReader { - pub fn new() -> Self { +impl Default for FramedValueReader { + fn default() -> Self { Self { - bytes: ChunkQueue::new(), + bytes: ChunkQueue::default(), marker: PhantomData, } } +} +impl FramedValueReader { pub fn push(mut self, chunk: Bytes) -> Self { self.bytes.push(chunk); self @@ -111,211 +109,6 @@ impl ValueReader { } } -#[derive(Debug, Default)] -pub struct ChunkQueue { - chunks: VecDeque, - remaining: usize, -} - -impl ChunkQueue { - pub fn new() -> Self { - Self::default() - } - - pub fn push(&mut self, chunk: Bytes) { - if chunk.is_empty() { - return; - } - self.remaining += chunk.len(); - self.chunks.push_back(chunk); - } - - pub fn remaining(&self) -> usize { - self.remaining - } - - pub fn pop_front(&mut self, max_len: usize) -> Option { - let front = self.chunks.front_mut()?; - let chunk = if max_len >= front.len() { - self.chunks.pop_front().expect("buffered chunk is present") - } else { - front.split_to(max_len) - }; - self.remaining -= chunk.len(); - Some(chunk) - } - - pub fn pop_front_chunk(&mut self) -> Option { - self.pop_front(usize::MAX) - } - - pub fn try_take_part(&mut self) -> Result>, Error> { - let Some(len) = self.peek_next_part_len()? else { - return Ok(None); - }; - self.advance(LENGTH_SIZE); - Ok(Some(DrainBuf::new(self, len))) - } - - pub fn try_take_tagged_part(&mut self) -> Result)>, Error> { - let mut bytes = self.peek(); - let Ok(kind) = bytes.try_get_u8() else { - return Ok(None); - }; - let Some(len) = read_next_part_len(&mut bytes)? else { - return Ok(None); - }; - - self.advance(1 + LENGTH_SIZE); - Ok(Some((kind, DrainBuf::new(self, len)))) - } - - fn peek_next_part_len(&self) -> Result, Error> { - let mut bytes = self.peek(); - read_next_part_len(&mut bytes) - } - - fn peek(&self) -> ChunkQueuePeek<'_> { - ChunkQueuePeek { - chunks: &self.chunks, - chunk_index: 0, - chunk_offset: 0, - remaining: self.remaining, - } - } - - fn front_chunk(&self, limit: usize) -> &[u8] { - let Some(chunk) = self.chunks.front() else { - return &[]; - }; - &chunk[..chunk.len().min(limit)] - } - - fn advance_inner(&mut self, mut cnt: usize) { - assert!(cnt <= self.remaining, "advanced past buffered data"); - self.remaining -= cnt; - while cnt > 0 { - let front = self.chunks.front_mut().expect("buffered data present"); - let consumed = cnt.min(front.len()); - front.advance(consumed); - cnt -= consumed; - if front.is_empty() { - self.chunks.pop_front(); - } - } - } -} - -struct ChunkQueuePeek<'a> { - chunks: &'a VecDeque, - chunk_index: usize, - chunk_offset: usize, - remaining: usize, -} - -impl Buf for ChunkQueuePeek<'_> { - fn remaining(&self) -> usize { - self.remaining - } - - fn chunk(&self) -> &[u8] { - if self.remaining == 0 { - return &[]; - } - - let Some(chunk) = self.chunks.get(self.chunk_index) else { - return &[]; - }; - &chunk[self.chunk_offset..] - } - - fn advance(&mut self, mut cnt: usize) { - assert!(cnt <= self.remaining, "advanced past buffered data"); - self.remaining -= cnt; - - while cnt > 0 { - let chunk = self - .chunks - .get(self.chunk_index) - .expect("buffered data present"); - let available = chunk.len() - self.chunk_offset; - let step = cnt.min(available); - self.chunk_offset += step; - cnt -= step; - if self.chunk_offset == chunk.len() { - self.chunk_index += 1; - self.chunk_offset = 0; - } - } - } -} - -impl Buf for ChunkQueue { - fn remaining(&self) -> usize { - self.remaining - } - - fn chunk(&self) -> &[u8] { - self.front_chunk(self.remaining) - } - - fn advance(&mut self, cnt: usize) { - assert!(cnt <= self.remaining, "advanced past buffered data"); - self.advance_inner(cnt); - } -} - -pub struct DrainBuf<'a> { - bytes: &'a mut ChunkQueue, - remaining: usize, -} - -impl<'a> DrainBuf<'a> { - pub fn new(bytes: &'a mut ChunkQueue, len: usize) -> Self { - debug_assert!(bytes.remaining() >= len); - Self { - bytes, - remaining: len, - } - } -} - -impl Buf for DrainBuf<'_> { - fn remaining(&self) -> usize { - self.remaining - } - - fn chunk(&self) -> &[u8] { - self.bytes.front_chunk(self.remaining) - } - - fn advance(&mut self, cnt: usize) { - assert!(cnt <= self.remaining(), "advanced past payload boundary"); - self.bytes.advance_inner(cnt); - self.remaining -= cnt; - } -} - -impl Drop for DrainBuf<'_> { - fn drop(&mut self) { - if self.remaining > 0 { - self.bytes.advance_inner(self.remaining); - self.remaining = 0; - } - } -} - -fn read_next_part_len(bytes: &mut B) -> Result, Error> { - let Ok(len) = bytes.try_get_u64_le() else { - return Ok(None); - }; - let len: usize = len.try_into().map_err(|_| Error::LengthOverflow)?; - if bytes.remaining() < len { - return Ok(None); - } - Ok(Some(len)) -} - pub fn reserve_length>(out: &mut B) -> usize { let start = out.as_mut().len(); out.put_u64_le(0); @@ -334,14 +127,14 @@ pub fn backpatch_length + ?Sized>(out: &mut B, start: usize) { mod tests { use bytes::Bytes; - use super::{encode_value_part, ReadValueStep, ValueReader}; + use super::{encode_value_part, FramedValueReader, ReadValueStep}; #[test] fn value_reader_round_trips_framed_values() { let mut encoded = Vec::new(); encode_value_part(&b"hello".to_vec(), &mut encoded); - match ValueReader::>::new() + match FramedValueReader::>::default() .push(Bytes::from(encoded)) .advance() .unwrap() @@ -357,7 +150,7 @@ mod tests { encode_value_part(&b"hello".to_vec(), &mut encoded); let encoded = Bytes::from(encoded); - let reader = match ValueReader::>::new() + let reader = match FramedValueReader::>::default() .push(encoded.slice(..4)) .advance() .unwrap() diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index e81c3014..3b92d5ec 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -1,37 +1,21 @@ //! quantum link rpc protocol traits and framing helpers. +mod chunk_queue; pub(crate) mod codec; mod error; +mod route_id; mod router; -pub mod rpc; +mod rpc; mod stream; -pub use codec::{ChunkQueue, ReadValueStep, RpcCodec, ValueReader}; +pub use chunk_queue::ChunkQueue; +pub use codec::{FramedValueReader, ReadValueStep, RpcCodec}; pub use error::*; +pub use route_id::RouteId; pub use router::*; pub use rpc::*; pub use stream::*; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[repr(transparent)] -pub struct RouteId(pub u32); - -impl RouteId { - pub const fn from_u32(value: u32) -> Self { - Self(value) - } - - pub const fn into_inner(self) -> u32 { - self.0 - } -} - -impl From for RouteId { - fn from(value: u32) -> Self { - Self::from_u32(value) - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] pub struct StreamCloseCode(pub u16); diff --git a/ql-rpc/src/route_id.rs b/ql-rpc/src/route_id.rs new file mode 100644 index 00000000..1b054e74 --- /dev/null +++ b/ql-rpc/src/route_id.rs @@ -0,0 +1,19 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RouteId(pub u32); + +impl RouteId { + pub const fn from_u32(value: u32) -> Self { + Self(value) + } + + pub const fn into_inner(self) -> u32 { + self.0 + } +} + +impl From for RouteId { + fn from(value: u32) -> Self { + Self::from_u32(value) + } +} diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index cee40df0..0b0f6ab7 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -1,15 +1,15 @@ use std::collections::HashMap; use super::{ - download::{handle_download_inner, DownloadHandler}, - request::{handle_request_inner, RequestHandler}, - subscription::{handle_subscription_inner, SubscriptionHandler}, LocalSpawn, LocalSpawner, RouteFn, Router, RouterConfig, RpcStream, SendSpawn, SendSpawner, Spawner, }; use crate::{ download::Download as DownloadRpc, + download::server::{handle_download_inner, DownloadHandler}, request::Request as RequestRpc, subscription::Subscription as SubscriptionRpc, RouteId, + request::server::{handle_request_inner, RequestHandler}, + subscription::server::{handle_subscription_inner, SubscriptionHandler}, }; pub struct RouterBuilder diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 701a03b3..c0ee63d3 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -4,19 +4,16 @@ use crate::{RouteId, StreamCloseCode}; mod builder; mod config; -mod download; mod mode; -mod request; -mod subscription; pub use self::{ builder::RouterBuilder, config::RouterConfig, - download::{DownloadHandler, DownloadResponder, DownloadWriter}, mode::*, - request::{RequestHandler, Response}, - subscription::{SubscriptionHandler, SubscriptionResponder}, }; +pub use crate::download::{DownloadHandler, DownloadResponder, DownloadWriter}; +pub use crate::request::{RequestHandler, Response}; +pub use crate::subscription::{SubscriptionHandler, SubscriptionResponder}; use crate::{close_stream, RpcStream}; pub struct Router diff --git a/ql-rpc/src/rpc/download.rs b/ql-rpc/src/rpc/download/client.rs similarity index 55% rename from ql-rpc/src/rpc/download.rs rename to ql-rpc/src/rpc/download/client.rs index 272ceaa7..6d262a3e 100644 --- a/ql-rpc/src/rpc/download.rs +++ b/ql-rpc/src/rpc/download/client.rs @@ -1,36 +1,12 @@ use std::{ future::poll_fn, - marker::PhantomData, task::{Context, Poll}, }; -use bytes::{BufMut, Bytes}; +use bytes::Bytes; -use crate::{codec, CallError, ChunkQueue, CodecError, RouteId, RpcCodec, RpcRead}; - -/// rpc where the responder streams a large byte body -/// the caller sends a request -/// the responder sends a typed header for the body -/// the responder streams the raw response bytes -pub trait Download { - const ROUTE: RouteId; - type Error; - /// input needed to start the download - type Request: RpcCodec; - /// details about the body before bytes arrive - type ResponseHeader: RpcCodec; -} - -pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { - codec::encode_value_part(request, out) -} - -pub fn encode_response_header( - response_header: &M::ResponseHeader, - out: &mut (impl BufMut + AsMut<[u8]>), -) { - codec::encode_value_part(response_header, out) -} +use crate::{CallError, ChunkQueue, RpcRead}; +use crate::download::{Download, ReadStep, ResponseHeaderReader}; pub struct DownloadCall where @@ -130,49 +106,3 @@ where self.stream } } - -pub enum ReadStep { - NeedMore(ResponseHeaderReader), - ResponseHeader { - value: M::ResponseHeader, - bytes: ChunkQueue, - }, -} - -pub struct ResponseHeaderReader { - bytes: codec::ChunkQueue, - marker: PhantomData M>, -} - -impl Default for ResponseHeaderReader { - fn default() -> Self { - Self { - bytes: codec::ChunkQueue::new(), - marker: PhantomData, - } - } -} - -impl ResponseHeaderReader { - pub fn push(mut self, chunk: Bytes) -> Self { - self.bytes.push(chunk); - self - } - - pub fn advance(mut self) -> Result, CodecError> { - let Some(mut body) = self.bytes.try_take_part().map_err(CodecError::Rpc)? else { - return Ok(ReadStep::NeedMore(self)); - }; - - let value = { - let value = M::ResponseHeader::decode_value(&mut body).map_err(CodecError::Codec)?; - drop(body); - value - }; - - Ok(ReadStep::ResponseHeader { - value, - bytes: self.bytes, - }) - } -} diff --git a/ql-rpc/src/rpc/download/codec.rs b/ql-rpc/src/rpc/download/codec.rs new file mode 100644 index 00000000..48b2e53a --- /dev/null +++ b/ql-rpc/src/rpc/download/codec.rs @@ -0,0 +1,62 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, download::Download, ChunkQueue, CodecError, RpcCodec}; + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(request, out) +} + +pub fn encode_response_header( + response_header: &M::ResponseHeader, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + codec::encode_value_part(response_header, out) +} + +pub enum ReadStep { + NeedMore(ResponseHeaderReader), + ResponseHeader { + value: M::ResponseHeader, + bytes: ChunkQueue, + }, +} + +pub struct ResponseHeaderReader { + bytes: codec::ChunkQueue, + marker: PhantomData M>, +} + +impl Default for ResponseHeaderReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl ResponseHeaderReader { + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); + self + } + + pub fn advance(mut self) -> Result, CodecError> { + let Some(mut body) = self.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore(self)); + }; + + let value = { + let value = M::ResponseHeader::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + value + }; + + Ok(ReadStep::ResponseHeader { + value, + bytes: self.bytes, + }) + } +} diff --git a/ql-rpc/src/rpc/download/mod.rs b/ql-rpc/src/rpc/download/mod.rs new file mode 100644 index 00000000..761bf6b6 --- /dev/null +++ b/ql-rpc/src/rpc/download/mod.rs @@ -0,0 +1,22 @@ +use crate::{RouteId, RpcCodec}; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::{DownloadCall, DownloadReader}; +pub use codec::{encode_request, encode_response_header, ReadStep, ResponseHeaderReader}; +pub use server::{DownloadHandler, DownloadResponder, DownloadWriter}; + +/// rpc where the responder streams a large byte body +/// the caller sends a request +/// the responder sends a typed header for the body +/// the responder streams the raw response bytes +pub trait Download { + const ROUTE: RouteId; + type Error; + /// input needed to start the download + type Request: RpcCodec; + /// details about the body before bytes arrive + type ResponseHeader: RpcCodec; +} diff --git a/ql-rpc/src/router/download.rs b/ql-rpc/src/rpc/download/server.rs similarity index 85% rename from ql-rpc/src/router/download.rs rename to ql-rpc/src/rpc/download/server.rs index 50922783..aab41dd0 100644 --- a/ql-rpc/src/router/download.rs +++ b/ql-rpc/src/rpc/download/server.rs @@ -2,11 +2,11 @@ use std::marker::PhantomData; use bytes::Bytes; -use super::{request::read_value_and_eof, RouterConfig}; use crate::{ - codec, download::Download as DownloadRpc, finish_bytes, write_bytes, RpcCodec, RpcStream, - RpcRead, RpcWrite, StreamCloseCode, StreamError, + codec, download::Download as DownloadRpc, finish_bytes, write_bytes, RpcCodec, RpcRead, + RpcStream, RpcWrite, StreamCloseCode, StreamError, }; +use crate::{rpc::read_framed_value, RouterConfig}; pub trait DownloadHandler where @@ -42,7 +42,7 @@ where T: RpcCodec, W: RpcWrite, { - fn new(writer: W) -> Self { + pub(crate) fn new(writer: W) -> Self { Self { writer: Some(writer), marker: PhantomData, @@ -82,12 +82,12 @@ where W: RpcWrite, { pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { - let writer = self.writer.as_mut().expect("download body writer exists"); + let writer = self.writer.as_mut().expect("download writer exists"); write_bytes(writer, bytes).await } pub async fn finish(mut self) -> Result<(), W::Error> { - let mut writer = self.writer.take().expect("download body writer exists"); + let mut writer = self.writer.take().expect("download writer exists"); finish_bytes(&mut writer).await } @@ -109,7 +109,7 @@ where } } -pub(super) async fn handle_download_inner( +pub(crate) async fn handle_download_inner( state: S, config: RouterConfig, mut reader: St::Reader, @@ -119,7 +119,7 @@ pub(super) async fn handle_download_inner( S: DownloadHandler + 'static, St: RpcStream + 'static, { - let request = match read_value_and_eof::(&mut reader, config).await { + let request = match read_framed_value::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index bd41ce49..ba7823dd 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -1,3 +1,8 @@ +use crate::{ + read_bytes, CallError, ChunkQueue, FramedValueReader, ReadValueStep, RouterConfig, RpcCodec, RpcRead, + StreamCloseCode, +}; + pub mod download; pub mod notification; pub mod request; @@ -11,3 +16,70 @@ pub use request::Request; pub use request_with_progress::RequestWithProgress; pub use subscription::Subscription; pub use upload::Upload; + +/// reads one length-delimited value and rejects trailing bytes +async fn read_framed_value( + reader: &mut R, + config: RouterConfig, +) -> Result +where + T: RpcCodec, + R: RpcRead, +{ + let mut value_reader = FramedValueReader::::default(); + let mut total_read = 0usize; + + let value = loop { + match value_reader.advance() { + Ok(ReadValueStep::Value(value)) => break value, + Ok(ReadValueStep::NeedMore(next)) => value_reader = next, + Err(crate::CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(crate::CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), + } + + let remaining = config.max_request_bytes.saturating_sub(total_read); + if remaining == 0 { + return Err(StreamCloseCode::LIMIT.into()); + } + + match read_bytes(reader, remaining).await { + Ok(Some(chunk)) => { + total_read += chunk.len(); + value_reader = value_reader.push(chunk); + } + Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Err(error) => return Err(error), + } + }; + + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(None) => Ok(value), + Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT.into()), + Ok(Some(_)) => Err(StreamCloseCode::REFUSED.into()), + Err(error) => Err(error), + } +} + +/// reads one eof-delimited value and rejects trailing bytes +async fn read_whole_value(reader: &mut R) -> Result> +where + T: RpcCodec, + R: RpcRead, +{ + let mut bytes = ChunkQueue::default(); + + while let Some(chunk) = read_bytes(reader, usize::MAX) + .await + .map_err(CallError::Transport)? + { + bytes.push(chunk); + } + + let value = T::decode_value(&mut bytes).map_err(CallError::Codec)?; + if bytes.remaining() > 0 { + return Err(crate::Error::TrailingBytes.into()); + } + Ok(value) +} diff --git a/ql-rpc/src/rpc/notification.rs b/ql-rpc/src/rpc/notification/codec.rs similarity index 50% rename from ql-rpc/src/rpc/notification.rs rename to ql-rpc/src/rpc/notification/codec.rs index 7db5656d..ec33ed73 100644 --- a/ql-rpc/src/rpc/notification.rs +++ b/ql-rpc/src/rpc/notification/codec.rs @@ -1,12 +1,6 @@ use bytes::BufMut; -use crate::{codec, RouteId, RpcCodec}; - -pub trait Notification { - const ROUTE: RouteId; - type Error; - type Event: RpcCodec; -} +use crate::{codec, notification::Notification}; pub fn encode_event(event: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { codec::encode_value_part(event, out) diff --git a/ql-rpc/src/rpc/notification/mod.rs b/ql-rpc/src/rpc/notification/mod.rs new file mode 100644 index 00000000..81110173 --- /dev/null +++ b/ql-rpc/src/rpc/notification/mod.rs @@ -0,0 +1,11 @@ +use crate::{RouteId, RpcCodec}; + +pub(crate) mod codec; + +pub use codec::encode_event; + +pub trait Notification { + const ROUTE: RouteId; + type Error; + type Event: RpcCodec; +} diff --git a/ql-rpc/src/rpc/request.rs b/ql-rpc/src/rpc/request.rs deleted file mode 100644 index 8190aa3e..00000000 --- a/ql-rpc/src/rpc/request.rs +++ /dev/null @@ -1,18 +0,0 @@ -use bytes::BufMut; - -use crate::{codec, RouteId, RpcCodec}; - -pub trait Request { - const ROUTE: RouteId; - type Error; - type Request: RpcCodec; - type Response: RpcCodec; -} - -pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { - codec::encode_value_part(request, out) -} - -pub fn encode_response(response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>)) { - codec::encode_value_part(response, out) -} diff --git a/ql-rpc/src/rpc/request/client.rs b/ql-rpc/src/rpc/request/client.rs new file mode 100644 index 00000000..5d936af0 --- /dev/null +++ b/ql-rpc/src/rpc/request/client.rs @@ -0,0 +1,24 @@ +use bytes::BufMut; + +use crate::{request::Request, rpc::read_whole_value, CallError, RpcCodec, RpcRead}; + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + request.encode_value(out) +} + +pub fn encode_response( + response: &M::Response, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + response.encode_value(out) +} + +pub async fn read_response( + mut reader: R, +) -> Result> +where + M: Request, + R: RpcRead, +{ + read_whole_value::(&mut reader).await +} diff --git a/ql-rpc/src/rpc/request/mod.rs b/ql-rpc/src/rpc/request/mod.rs new file mode 100644 index 00000000..4f84a2ef --- /dev/null +++ b/ql-rpc/src/rpc/request/mod.rs @@ -0,0 +1,14 @@ +use crate::{RouteId, RpcCodec}; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::{encode_request, encode_response, read_response}; +pub use server::{RequestHandler, Response}; + +pub trait Request { + const ROUTE: RouteId; + type Error; + type Request: RpcCodec; + type Response: RpcCodec; +} diff --git a/ql-rpc/src/router/request.rs b/ql-rpc/src/rpc/request/server.rs similarity index 60% rename from ql-rpc/src/router/request.rs rename to ql-rpc/src/rpc/request/server.rs index 452d7880..b94e795c 100644 --- a/ql-rpc/src/router/request.rs +++ b/ql-rpc/src/rpc/request/server.rs @@ -2,12 +2,13 @@ use std::marker::PhantomData; use bytes::Bytes; -use super::RouterConfig; use crate::{ - codec, finish_bytes, read_bytes, request::Request as RequestRpc, write_bytes, ReadValueStep, - RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, ValueReader, + finish_bytes, read_bytes, request::Request as RequestRpc, write_bytes, ChunkQueue, + RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; +use crate::RouterConfig; + pub trait RequestHandler where M: RequestRpc, @@ -31,7 +32,7 @@ where T: RpcCodec, W: RpcWrite, { - fn new(writer: W) -> Self { + pub(crate) fn new(writer: W) -> Self { Self { writer: Some(writer), marker: PhantomData, @@ -41,7 +42,7 @@ where pub async fn respond(mut self, response: T) -> Result<(), W::Error> { let mut writer = self.writer.take().expect("response writer exists"); let mut encoded = Vec::new(); - codec::encode_value_part(&response, &mut encoded); + response.encode_value(&mut encoded); write_bytes(&mut writer, Bytes::from(encoded)).await?; finish_bytes(&mut writer).await?; Ok(()) @@ -65,7 +66,7 @@ where } } -pub(super) async fn handle_request_inner( +pub(crate) async fn handle_request_inner( state: S, config: RouterConfig, mut reader: St::Reader, @@ -75,7 +76,7 @@ pub(super) async fn handle_request_inner( S: RequestHandler + 'static, St: RpcStream + 'static, { - let request = match read_value_and_eof::(&mut reader, config).await { + let request = match read_whole_value::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); @@ -91,7 +92,7 @@ pub(super) async fn handle_request_inner( state.handle(request, Response::new(writer)); } -pub(super) async fn read_value_and_eof( +pub(crate) async fn read_whole_value( reader: &mut R, config: RouterConfig, ) -> Result @@ -99,38 +100,28 @@ where T: RpcCodec, R: RpcRead, { - let mut value_reader = ValueReader::::new(); + let mut bytes = ChunkQueue::default(); let mut total_read = 0usize; - let value = loop { - match value_reader.advance() { - Ok(ReadValueStep::Value(value)) => break value, - Ok(ReadValueStep::NeedMore(next)) => value_reader = next, - Err(crate::CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), - Err(crate::CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), - } - + loop { let remaining = config.max_request_bytes.saturating_sub(total_read); - if remaining == 0 { - return Err(StreamCloseCode::LIMIT.into()); - } - - match read_bytes(reader, remaining).await { + let probe = remaining.max(1); + match read_bytes(reader, probe).await { Ok(Some(chunk)) => { + if chunk.len() > remaining { + return Err(StreamCloseCode::LIMIT.into()); + } total_read += chunk.len(); - value_reader = value_reader.push(chunk); + bytes.push(chunk); } - Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Ok(None) => break, Err(error) => return Err(error), } - }; + } - let remaining = config.max_request_bytes.saturating_sub(total_read); - let probe = remaining.max(1); - match read_bytes(reader, probe).await { - Ok(None) => Ok(value), - Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT.into()), - Ok(Some(_)) => Err(StreamCloseCode::REFUSED.into()), - Err(error) => Err(error), + let value = T::decode_value(&mut bytes).map_err(|_error| StreamCloseCode::REFUSED)?; + if bytes.remaining() > 0 { + return Err(StreamCloseCode::REFUSED.into()); } + Ok(value) } diff --git a/ql-rpc/src/rpc/request_with_progress.rs b/ql-rpc/src/rpc/request_with_progress/codec.rs similarity index 90% rename from ql-rpc/src/rpc/request_with_progress.rs rename to ql-rpc/src/rpc/request_with_progress/codec.rs index 12711996..4d8095ee 100644 --- a/ql-rpc/src/rpc/request_with_progress.rs +++ b/ql-rpc/src/rpc/request_with_progress/codec.rs @@ -2,15 +2,7 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{codec, CodecError, Error, RouteId, RpcCodec}; - -pub trait RequestWithProgress { - const ROUTE: RouteId; - type Error; - type Request: RpcCodec; - type Progress: RpcCodec; - type Response: RpcCodec; -} +use crate::{codec, request_with_progress::RequestWithProgress, CodecError, Error, RpcCodec}; pub enum ReadStep { NeedMore(ResponseReader), @@ -29,7 +21,7 @@ pub struct ResponseReader { impl Default for ResponseReader { fn default() -> Self { Self { - bytes: codec::ChunkQueue::new(), + bytes: codec::ChunkQueue::default(), marker: PhantomData, } } @@ -115,8 +107,8 @@ fn encode_tagged_value_part>( mod tests { use bytes::Bytes; - use super::{encode_progress, encode_response, ReadStep, RequestWithProgress, ResponseReader}; - use crate::RouteId; + use super::{encode_progress, encode_response, ReadStep, ResponseReader}; + use crate::{request_with_progress::RequestWithProgress, RouteId}; struct Watch; diff --git a/ql-rpc/src/rpc/request_with_progress/mod.rs b/ql-rpc/src/rpc/request_with_progress/mod.rs new file mode 100644 index 00000000..99ed9338 --- /dev/null +++ b/ql-rpc/src/rpc/request_with_progress/mod.rs @@ -0,0 +1,15 @@ +use crate::{RouteId, RpcCodec}; + +pub(crate) mod codec; + +pub use codec::{ + encode_progress, encode_request, encode_response, ReadStep, ResponseReader, +}; + +pub trait RequestWithProgress { + const ROUTE: RouteId; + type Error; + type Request: RpcCodec; + type Progress: RpcCodec; + type Response: RpcCodec; +} diff --git a/ql-rpc/src/rpc/subscription.rs b/ql-rpc/src/rpc/subscription.rs deleted file mode 100644 index 1ad68570..00000000 --- a/ql-rpc/src/rpc/subscription.rs +++ /dev/null @@ -1,239 +0,0 @@ -use std::{ - future::poll_fn, - marker::PhantomData, - task::{Context, Poll}, -}; - -use bytes::{BufMut, Bytes}; - -use crate::{codec, CallError, CodecError, RouteId, RpcCodec, RpcRead}; - -pub trait Subscription { - const ROUTE: RouteId; - type Error; - type Request: RpcCodec; - type Event: RpcCodec; -} - -pub fn encode_request( - request: &M::Request, - out: &mut (impl BufMut + AsMut<[u8]>), -) { - codec::encode_value_part(request, out) -} - -pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { - codec::encode_value_part(item, out) -} - -pub struct SubscriptionCall -where - M: Subscription, - R: RpcRead, -{ - stream: R, - reader: Option>, -} - -impl SubscriptionCall -where - M: Subscription, - R: RpcRead, -{ - pub fn new(stream: R) -> Self { - Self { - stream, - reader: Some(ResponseReader::default()), - } - } - - pub async fn next_event(&mut self) -> Option>> { - poll_fn(|cx| self.poll_next_event(cx)).await - } - - pub fn poll_next_event( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>>> { - loop { - let Some(reader) = self.reader.take() else { - return Poll::Ready(None); - }; - - let reader = match reader.advance() { - Ok(ReadStep::Item { value, next }) => { - self.reader = Some(next); - return Poll::Ready(Some(Ok(value))); - } - Ok(ReadStep::NeedMore(next)) => next, - Err(error) => return Poll::Ready(Some(Err(error.into()))), - }; - - match self.stream.poll_read(usize::MAX, cx) { - Poll::Ready(Ok(Some(chunk))) => { - self.reader = Some(reader.push(chunk)); - } - Poll::Ready(Ok(None)) => { - if reader.is_empty() { - return Poll::Ready(None); - } - return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); - } - Poll::Ready(Err(error)) => { - self.reader = None; - return Poll::Ready(Some(Err(CallError::Transport(error)))); - } - Poll::Pending => { - self.reader = Some(reader); - return Poll::Pending; - } - } - } - } - - pub fn into_inner(self) -> R { - self.stream - } -} - -impl Default for ResponseReader { - fn default() -> Self { - Self { - bytes: codec::ChunkQueue::new(), - marker: PhantomData, - } - } -} - -impl ResponseReader { - pub fn push(mut self, chunk: Bytes) -> Self { - self.bytes.push(chunk); - self - } - - pub fn is_empty(&self) -> bool { - self.bytes.remaining() == 0 - } - - pub fn advance(self) -> Result, CodecError> { - let mut this = self; - let Some(mut body) = this.bytes.try_take_part().map_err(CodecError::Rpc)? else { - return Ok(ReadStep::NeedMore(this)); - }; - - let item = { - let item = M::Event::decode_value(&mut body).map_err(CodecError::Codec)?; - drop(body); - item - }; - Ok(ReadStep::Item { - value: item, - next: this, - }) - } -} - -pub enum ReadStep { - NeedMore(ResponseReader), - Item { - value: M::Event, - next: ResponseReader, - }, -} - -pub struct ResponseReader { - bytes: codec::ChunkQueue, - marker: PhantomData M>, -} - -#[cfg(test)] -mod tests { - use bytes::Bytes; - - use super::{encode_item, ReadStep, ResponseReader, Subscription}; - use crate::RouteId; - - struct Feed; - - impl Subscription for Feed { - const ROUTE: RouteId = RouteId::from_u32(17); - type Error = core::convert::Infallible; - type Request = Vec; - type Event = Vec; - } - - #[test] - fn response_reader_streams_items_until_end() { - let mut encoded = Vec::new(); - encode_item::(&b"one".to_vec(), &mut encoded); - encode_item::(&b"two".to_vec(), &mut encoded); - - let reader = match ResponseReader::::default() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { - ReadStep::Item { value, next } => { - assert_eq!(value, b"one".to_vec()); - next - } - _ => unreachable!(), - }; - - let reader = match reader.advance().unwrap() { - ReadStep::Item { value, next } => { - assert_eq!(value, b"two".to_vec()); - next - } - _ => unreachable!(), - }; - - match reader.advance().unwrap() { - ReadStep::NeedMore(next) => assert!(next.is_empty()), - _ => unreachable!(), - } - } - - #[test] - fn response_reader_waits_for_transport_eof_when_no_end_frame_is_present() { - let mut encoded = Vec::new(); - encode_item::(&b"one".to_vec(), &mut encoded); - - let reader = match ResponseReader::::default() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { - ReadStep::Item { value, next } => { - assert_eq!(value, b"one".to_vec()); - next - } - _ => unreachable!(), - }; - - match reader.advance().unwrap() { - ReadStep::NeedMore(next) => assert!(next.is_empty()), - _ => unreachable!(), - } - } - - #[test] - fn response_reader_allows_empty_event_payloads() { - let mut encoded = Vec::new(); - encode_item::(&Vec::new(), &mut encoded); - - match ResponseReader::::default() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { - ReadStep::Item { value, next } => { - assert_eq!(value, Vec::::new()); - assert!( - matches!(next.advance().unwrap(), ReadStep::NeedMore(reader) if reader.is_empty()) - ); - } - _ => unreachable!(), - } - } -} diff --git a/ql-rpc/src/rpc/subscription/client.rs b/ql-rpc/src/rpc/subscription/client.rs new file mode 100644 index 00000000..39cb4f0d --- /dev/null +++ b/ql-rpc/src/rpc/subscription/client.rs @@ -0,0 +1,77 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use crate::{CallError, RpcRead}; +use crate::subscription::{ReadStep, ResponseReader, Subscription}; + +pub struct SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + stream: R, + reader: Option>, +} + +impl SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream, + reader: Some(ResponseReader::default()), + } + } + + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| self.poll_next_event(cx)).await + } + + pub fn poll_next_event( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + loop { + let Some(reader) = self.reader.take() else { + return Poll::Ready(None); + }; + + let reader = match reader.advance() { + Ok(ReadStep::Item { value, next }) => { + self.reader = Some(next); + return Poll::Ready(Some(Ok(value))); + } + Ok(ReadStep::NeedMore(next)) => next, + Err(error) => return Poll::Ready(Some(Err(error.into()))), + }; + + match self.stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + self.reader = Some(reader.push(chunk)); + } + Poll::Ready(Ok(None)) => { + if reader.is_empty() { + return Poll::Ready(None); + } + return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); + } + Poll::Ready(Err(error)) => { + self.reader = None; + return Poll::Ready(Some(Err(CallError::Transport(error)))); + } + Poll::Pending => { + self.reader = Some(reader); + return Poll::Pending; + } + } + } + } + + pub fn into_inner(self) -> R { + self.stream + } +} diff --git a/ql-rpc/src/rpc/subscription/codec.rs b/ql-rpc/src/rpc/subscription/codec.rs new file mode 100644 index 00000000..45fec4dc --- /dev/null +++ b/ql-rpc/src/rpc/subscription/codec.rs @@ -0,0 +1,66 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, subscription::Subscription, CodecError, RpcCodec}; + +pub fn encode_request( + request: &M::Request, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + codec::encode_value_part(request, out) +} + +pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(item, out) +} + +pub enum ReadStep { + NeedMore(ResponseReader), + Item { + value: M::Event, + next: ResponseReader, + }, +} + +pub struct ResponseReader { + bytes: codec::ChunkQueue, + marker: PhantomData M>, +} + +impl Default for ResponseReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl ResponseReader { + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); + self + } + + pub fn is_empty(&self) -> bool { + self.bytes.remaining() == 0 + } + + pub fn advance(self) -> Result, CodecError> { + let mut this = self; + let Some(mut body) = this.bytes.try_take_part().map_err(CodecError::Rpc)? else { + return Ok(ReadStep::NeedMore(this)); + }; + + let item = { + let item = M::Event::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + item + }; + Ok(ReadStep::Item { + value: item, + next: this, + }) + } +} diff --git a/ql-rpc/src/rpc/subscription/mod.rs b/ql-rpc/src/rpc/subscription/mod.rs new file mode 100644 index 00000000..14edb791 --- /dev/null +++ b/ql-rpc/src/rpc/subscription/mod.rs @@ -0,0 +1,16 @@ +use crate::{RouteId, RpcCodec}; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::SubscriptionCall; +pub use codec::{encode_item, encode_request, ReadStep, ResponseReader}; +pub use server::{SubscriptionHandler, SubscriptionResponder}; + +pub trait Subscription { + const ROUTE: RouteId; + type Error; + type Request: RpcCodec; + type Event: RpcCodec; +} diff --git a/ql-rpc/src/router/subscription.rs b/ql-rpc/src/rpc/subscription/server.rs similarity index 90% rename from ql-rpc/src/router/subscription.rs rename to ql-rpc/src/rpc/subscription/server.rs index 9216680b..699ad284 100644 --- a/ql-rpc/src/router/subscription.rs +++ b/ql-rpc/src/rpc/subscription/server.rs @@ -2,11 +2,11 @@ use std::marker::PhantomData; use bytes::Bytes; -use super::{request::read_value_and_eof, RouterConfig}; use crate::{ codec, finish_bytes, subscription::Subscription as SubscriptionRpc, write_bytes, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; +use crate::{rpc::read_framed_value, RouterConfig}; pub trait SubscriptionHandler where @@ -31,7 +31,7 @@ where T: RpcCodec, W: RpcWrite, { - fn new(writer: W) -> Self { + pub(crate) fn new(writer: W) -> Self { Self { writer: Some(writer), marker: PhantomData, @@ -69,7 +69,7 @@ where } } -pub(super) async fn handle_subscription_inner( +pub(crate) async fn handle_subscription_inner( state: S, config: RouterConfig, mut reader: St::Reader, @@ -79,7 +79,7 @@ pub(super) async fn handle_subscription_inner( S: SubscriptionHandler + 'static, St: RpcStream + 'static, { - let request = match read_value_and_eof::(&mut reader, config).await { + let request = match read_framed_value::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); diff --git a/ql-rpc/src/rpc/upload.rs b/ql-rpc/src/rpc/upload/mod.rs similarity index 100% rename from ql-rpc/src/rpc/upload.rs rename to ql-rpc/src/rpc/upload/mod.rs diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 484fcd18..f6f36c6b 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -4,8 +4,6 @@ mod error; mod request_with_progress; mod subscription; -use std::future::poll_fn; - use bytes::Bytes; use ql_rpc::{ download::{self as rpc_download, Download as DownloadRpc}, @@ -13,11 +11,10 @@ use ql_rpc::{ request::{self, Request as RequestRpc}, request_with_progress::{self as rpc_request_with_progress, RequestWithProgress}, subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, - Error, ReadValueStep, RpcCodec, ValueReader, }; pub use self::{adapter::*, download::*, error::*, request_with_progress::*, subscription::*}; -use crate::{StreamReader, RuntimeHandle}; +use crate::{RuntimeHandle, StreamReader}; #[derive(Clone)] pub struct RpcHandle { @@ -48,7 +45,7 @@ impl RpcHandle { let mut payload = Vec::new(); request::encode_request::(request, &mut payload); let response = self.start_request(M::ROUTE, payload).await?; - read_value::(response).await + Ok(request::read_response::(response).await?) } pub async fn subscribe( @@ -112,22 +109,3 @@ impl RpcHandle { Ok(stream.reader) } } - -async fn read_value(mut reader: StreamReader) -> Result> -where - T: RpcCodec, -{ - let mut value_reader = ValueReader::::default(); - - loop { - match value_reader.advance().map_err(RpcError::from)? { - ReadValueStep::Value(value) => return Ok(value), - ReadValueStep::NeedMore(next) => value_reader = next, - } - - match poll_fn(|cx| reader.poll_read_chunk(cx)).await? { - Some(chunk) => value_reader = value_reader.push(chunk), - None => return Err(Error::Truncated.into()), - } - } -} diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index a143981c..b055ab18 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -55,57 +55,18 @@ impl ql_rpc::download::Download for BlobDownload { type ResponseHeader = Vec; } -fn assert_send(value: T) -> T { - value -} - -#[tokio::test(flavor = "current_thread")] -async fn rpc_request_round_trips() { - run_local_test(async { - let mut pair = TestPair::new(default_runtime_config()); - pair.connect_and_wait(Side::A).await; - let inbound_b = pair.take_inbound(Side::B); - - let responder = tokio::task::spawn_local(async move { - let inbound = inbound_b.recv().await.unwrap(); - let request: Vec = read_rpc_value(inbound.reader).await; - assert_eq!( - inbound.route_id, - to_wire_route_id(::ROUTE) - ); - assert_eq!(request, b"hello".to_vec()); - - let mut encoded = Vec::new(); - ql_rpc::request::encode_response::(&"world".into(), &mut encoded); - let mut writer = inbound.writer; - writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish().await.unwrap(); - }); - - let rpc = pair.side_mut(Side::A).handle.rpc(); - let response = rpc.request::(&"hello".into()).await.unwrap(); - assert_eq!(response, "world"); - - tokio::time::timeout(Duration::from_secs(2), responder) - .await - .unwrap() - .unwrap(); - }) - .await; -} - #[tokio::test(flavor = "current_thread")] -async fn rpc_router_handles_request() { +async fn rpc_request() { #[derive(Clone)] struct RouterState { - seen: Rc>>, + seen: Arc>>, } impl crate::rpc::RequestHandler for RouterState { fn handle(self, request: String, response: Response) { let seen = self.seen.clone(); - tokio::task::spawn_local(async move { - seen.borrow_mut().push(request); + tokio::task::spawn(async move { + seen.lock().unwrap().push(request); let _ = response.respond("world".into()).await; }); } @@ -115,16 +76,17 @@ async fn rpc_router_handles_request() { let mut pair = TestPair::new(default_runtime_config()); pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); - let seen = Rc::new(RefCell::new(Vec::new())); + let seen = Arc::new(Mutex::new(Vec::new())); let router = - ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) + ql_rpc::Router::<_, QlStream, crate::rpc::SendSpawn>::builder(crate::rpc::SendSpawn) .request::() .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { + let fut = assert_send(fut); fut.await } }); @@ -132,7 +94,7 @@ async fn rpc_router_handles_request() { let rpc = pair.side_mut(Side::A).handle.rpc(); let response = rpc.request::(&"hello".into()).await.unwrap(); assert_eq!(response, "world"); - assert_eq!(&*seen.borrow(), &["hello".to_string()]); + assert_eq!(&*seen.lock().unwrap(), &["hello".to_string()]); tokio::time::timeout(Duration::from_secs(2), responder) .await @@ -142,8 +104,12 @@ async fn rpc_router_handles_request() { .await; } +fn assert_send(value: T) -> T { + value +} + #[tokio::test(flavor = "current_thread")] -async fn rpc_router_handles_subscription() { +async fn rpc_subscrption() { #[derive(Clone)] struct RouterState { seen: Rc>>>, @@ -198,54 +164,6 @@ async fn rpc_router_handles_subscription() { .await; } -#[tokio::test(flavor = "current_thread")] -async fn rpc_send_router_handles_request() { - #[derive(Clone)] - struct RouterState { - seen: Arc>>, - } - - impl crate::rpc::RequestHandler for RouterState { - fn handle(self, request: String, response: crate::rpc::Response) { - let seen = self.seen.clone(); - tokio::task::spawn(async move { - seen.lock().unwrap().push(request); - let _ = response.respond("world".into()).await; - }); - } - } - - run_local_test(async { - let mut pair = TestPair::new(default_runtime_config()); - pair.connect_and_wait(Side::A).await; - let inbound_b = pair.take_inbound(Side::B); - let seen = Arc::new(Mutex::new(Vec::new())); - let router = - ql_rpc::Router::<_, QlStream, crate::rpc::SendSpawn>::builder(crate::rpc::SendSpawn) - .request::() - .build(RouterState { seen: seen.clone() }); - - let responder = tokio::task::spawn_local(async move { - let inbound = inbound_b.recv().await.unwrap(); - if let Some((_, fut)) = router.handle(inbound) { - let fut = assert_send(fut); - fut.await - } - }); - - let rpc = pair.side_mut(Side::A).handle.rpc(); - let response = rpc.request::(&"hello".into()).await.unwrap(); - assert_eq!(response, "world"); - assert_eq!(&*seen.lock().unwrap(), &["hello".to_string()]); - - tokio::time::timeout(Duration::from_secs(2), responder) - .await - .unwrap() - .unwrap(); - }) - .await; -} - #[tokio::test(flavor = "current_thread")] async fn rpc_router_enforces_max_request_bytes() { #[derive(Clone)] @@ -296,7 +214,7 @@ async fn rpc_router_enforces_max_request_bytes() { } #[tokio::test(flavor = "current_thread")] -async fn rpc_subscription_streams_events() { +async fn rpc_request_with_progress() { run_local_test(async { let mut pair = TestPair::new(default_runtime_config()); pair.connect_and_wait(Side::A).await; @@ -304,50 +222,27 @@ async fn rpc_subscription_streams_events() { let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request: Vec = read_rpc_value(inbound.reader).await; + let request: Vec = { + let mut reader = inbound.reader; + let mut value_reader = ql_rpc::FramedValueReader::>::default(); + + loop { + match value_reader.advance().unwrap() { + ql_rpc::ReadValueStep::Value(value) => break value, + ql_rpc::ReadValueStep::NeedMore(next) => value_reader = next, + } + + match reader.read_chunk().await.unwrap() { + Some(chunk) => value_reader = value_reader.push(chunk), + None => panic!("truncated rpc value"), + } + } + }; assert_eq!( inbound.route_id, - to_wire_route_id(::ROUTE) - ); - assert_eq!(request, b"watch".to_vec()); - - let mut encoded = Vec::new(); - ql_rpc::subscription::encode_item::(&b"one".to_vec(), &mut encoded); - ql_rpc::subscription::encode_item::(&b"two".to_vec(), &mut encoded); - - let mut writer = inbound.writer; - writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish().await.unwrap(); - }); - - let rpc = pair.side_mut(Side::A).handle.rpc(); - let mut subscription = rpc.subscribe::(&b"watch".to_vec()).await.unwrap(); - assert_eq!(subscription.next().await.unwrap().unwrap(), b"one".to_vec()); - assert_eq!(subscription.next().await.unwrap().unwrap(), b"two".to_vec()); - assert!(subscription.next().await.is_none()); - - tokio::time::timeout(Duration::from_secs(2), responder) - .await - .unwrap() - .unwrap(); - }) - .await; -} - -#[tokio::test(flavor = "current_thread")] -async fn rpc_request_with_progress_supports_progress_then_await() { - run_local_test(async { - let mut pair = TestPair::new(default_runtime_config()); - pair.connect_and_wait(Side::A).await; - let inbound_b = pair.take_inbound(Side::B); - - let responder = tokio::task::spawn_local(async move { - let inbound = inbound_b.recv().await.unwrap(); - let request: Vec = read_rpc_value(inbound.reader).await; - assert_eq!( - inbound.route_id, - to_wire_route_id( + WireRouteId::from_u32( ::ROUTE + .into_inner() ) ); assert_eq!(request, b"logo".to_vec()); @@ -391,18 +286,14 @@ async fn rpc_request_with_progress_supports_progress_then_await() { } #[tokio::test(flavor = "current_thread")] -async fn rpc_router_handles_download() { +async fn rpc_download() { #[derive(Clone)] struct RouterState { seen: Rc>>>, } impl crate::rpc::DownloadHandler for RouterState { - fn handle( - self, - request: Vec, - responder: DownloadResponder, StreamWriter>, - ) { + fn handle(self, request: Vec, responder: DownloadResponder, StreamWriter>) { let seen = self.seen.clone(); tokio::task::spawn_local(async move { seen.borrow_mut().push(request); @@ -434,11 +325,20 @@ async fn rpc_router_handles_download() { }); let rpc = pair.side_mut(Side::A).handle.rpc(); - let download = rpc.download::(&b"logo".to_vec()).await.unwrap(); + let download = rpc + .download::(&b"logo".to_vec()) + .await + .unwrap(); let (header, mut reader) = download.into_reader().await.unwrap(); assert_eq!(header, b"image/png".to_vec()); - assert_eq!(reader.read_chunk().await.unwrap(), Some(Bytes::from_static(b"abc"))); - assert_eq!(reader.read_chunk().await.unwrap(), Some(Bytes::from_static(b"def"))); + assert_eq!( + reader.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"abc")) + ); + assert_eq!( + reader.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"def")) + ); assert_eq!(reader.read_chunk().await.unwrap(), None); assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); @@ -449,27 +349,3 @@ async fn rpc_router_handles_download() { }) .await; } - -async fn read_rpc_value(mut reader: crate::StreamReader) -> T -where - T: ql_rpc::RpcCodec, - T::Error: std::fmt::Debug, -{ - let mut value_reader = ql_rpc::ValueReader::::default(); - - loop { - match value_reader.advance().unwrap() { - ql_rpc::ReadValueStep::Value(value) => return value, - ql_rpc::ReadValueStep::NeedMore(next) => value_reader = next, - } - - match reader.read_chunk().await.unwrap() { - Some(chunk) => value_reader = value_reader.push(chunk), - None => panic!("truncated rpc value"), - } - } -} - -fn to_wire_route_id(route_id: RouteId) -> WireRouteId { - WireRouteId::from_u32(route_id.into_inner()) -} From 1507a1dfd4da0940926d2c9e823811d0efa37096 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 21:15:40 -0400 Subject: [PATCH 258/304] ql-rpc: ProgressCall --- ql-rpc/src/router/builder.rs | 27 +++ ql-rpc/src/router/mod.rs | 1 + ql-rpc/src/rpc/mod.rs | 17 +- ql-rpc/src/rpc/progress/client.rs | 163 ++++++++++++++++++ .../codec.rs | 34 ++-- .../mod.rs | 6 +- ql-rpc/src/rpc/progress/server.rs | 101 +++++++++++ ql-runtime/src/rpc/mod.rs | 16 +- ql-runtime/src/rpc/progress.rs | 41 +++++ ql-runtime/src/rpc/request_with_progress.rs | 135 --------------- ql-runtime/src/tests/rpc.rs | 89 +++++----- 11 files changed, 408 insertions(+), 222 deletions(-) create mode 100644 ql-rpc/src/rpc/progress/client.rs rename ql-rpc/src/rpc/{request_with_progress => progress}/codec.rs (87%) rename ql-rpc/src/rpc/{request_with_progress => progress}/mod.rs (69%) create mode 100644 ql-rpc/src/rpc/progress/server.rs create mode 100644 ql-runtime/src/rpc/progress.rs delete mode 100644 ql-runtime/src/rpc/request_with_progress.rs diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index 0b0f6ab7..b3c33dc7 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -7,6 +7,8 @@ use super::{ use crate::{ download::Download as DownloadRpc, download::server::{handle_download_inner, DownloadHandler}, + progress::Progress as ProgressRpc, + progress::server::{ProgressHandler, handle_progress_inner}, request::Request as RequestRpc, subscription::Subscription as SubscriptionRpc, RouteId, request::server::{handle_request_inner, RequestHandler}, subscription::server::{handle_subscription_inner, SubscriptionHandler}, @@ -103,6 +105,17 @@ where )) }) } + + pub fn progress(self) -> Self + where + M: ProgressRpc + 'static, + S: ProgressHandler + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_progress_inner::(state, config, reader, writer)) + }) + } } impl RouterBuilder @@ -156,4 +169,18 @@ where )) }) } + + pub fn progress(self) -> Self + where + M: ProgressRpc + 'static, + M::Request: Send + 'static, + S: ProgressHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_progress_inner::(state, config, reader, writer)) + }) + } } diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index c0ee63d3..b778f542 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -12,6 +12,7 @@ pub use self::{ mode::*, }; pub use crate::download::{DownloadHandler, DownloadResponder, DownloadWriter}; +pub use crate::progress::{ProgressHandler, ProgressResponder}; pub use crate::request::{RequestHandler, Response}; pub use crate::subscription::{SubscriptionHandler, SubscriptionResponder}; use crate::{close_stream, RpcStream}; diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index ba7823dd..77e7a494 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -1,27 +1,24 @@ use crate::{ - read_bytes, CallError, ChunkQueue, FramedValueReader, ReadValueStep, RouterConfig, RpcCodec, RpcRead, - StreamCloseCode, + read_bytes, CallError, ChunkQueue, CodecError, FramedValueReader, ReadValueStep, RouterConfig, + RpcCodec, RpcRead, StreamCloseCode, }; pub mod download; pub mod notification; +pub mod progress; pub mod request; -pub mod request_with_progress; pub mod subscription; pub mod upload; pub use download::Download; pub use notification::Notification; +pub use progress::Progress; pub use request::Request; -pub use request_with_progress::RequestWithProgress; pub use subscription::Subscription; pub use upload::Upload; /// reads one length-delimited value and rejects trailing bytes -async fn read_framed_value( - reader: &mut R, - config: RouterConfig, -) -> Result +async fn read_framed_value(reader: &mut R, config: RouterConfig) -> Result where T: RpcCodec, R: RpcRead, @@ -33,8 +30,8 @@ where match value_reader.advance() { Ok(ReadValueStep::Value(value)) => break value, Ok(ReadValueStep::NeedMore(next)) => value_reader = next, - Err(crate::CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), - Err(crate::CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), } let remaining = config.max_request_bytes.saturating_sub(total_read); diff --git a/ql-rpc/src/rpc/progress/client.rs b/ql-rpc/src/rpc/progress/client.rs new file mode 100644 index 00000000..54920706 --- /dev/null +++ b/ql-rpc/src/rpc/progress/client.rs @@ -0,0 +1,163 @@ +use std::{ + future::{poll_fn, Future}, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{ + progress::{Progress, ReadStep, ResponseReader}, + CallError, Error, RpcRead, +}; + +pub struct ProgressCall +where + M: Progress, + R: RpcRead, +{ + stream: R, + state: State, +} + +enum State +where + M: Progress, +{ + Invalid, + Reading(ResponseReader), + Terminal(Result>), + Done, +} + +impl Unpin for ProgressCall +where + M: Progress, + R: RpcRead, +{ +} + +impl ProgressCall +where + M: Progress, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream, + state: State::Reading(ResponseReader::default()), + } + } + + pub async fn next_progress(&mut self) -> Option { + poll_fn(|cx| self.poll_next_progress(cx)).await + } + + pub fn poll_next_progress(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + let reader = match std::mem::replace(&mut self.state, State::Invalid) { + State::Reading(reader) => reader, + state @ (State::Terminal(_) | State::Done) => { + self.state = state; + return Poll::Ready(None); + } + State::Invalid => panic!("invalid state"), + }; + + match reader.advance() { + Ok(ReadStep::Progress { value, next }) => { + self.state = State::Reading(next); + return Poll::Ready(Some(value)); + } + Ok(ReadStep::Response(response)) => { + self.state = State::Terminal(Ok(response)); + return Poll::Ready(None); + } + Ok(ReadStep::NeedMore(next)) => { + self.state = State::Reading(next); + } + Err(error) => { + self.state = State::Terminal(Err(error.into())); + return Poll::Ready(None); + } + } + + match self.stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + let State::Reading(reader) = std::mem::replace(&mut self.state, State::Invalid) + else { + panic!("invalid state"); + }; + self.state = State::Reading(reader.push(chunk)); + } + Poll::Ready(Ok(None)) => { + self.state = State::Terminal(Err(Error::MissingResponse.into())); + return Poll::Ready(None); + } + Poll::Ready(Err(error)) => { + self.state = State::Terminal(Err(CallError::Transport(error))); + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + } + } + } +} + +impl Future for ProgressCall +where + M: Progress, + R: RpcRead, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + let reader = match std::mem::replace(&mut this.state, State::Invalid) { + State::Reading(reader) => reader, + State::Terminal(result) => { + this.state = State::Done; + return Poll::Ready(result); + } + State::Done => panic!("polled after completion"), + State::Invalid => panic!("polled during state transition"), + }; + + match reader.advance() { + Ok(ReadStep::Progress { next, .. }) => { + this.state = State::Reading(next); + } + Ok(ReadStep::Response(response)) => { + this.state = State::Done; + return Poll::Ready(Ok(response)); + } + Ok(ReadStep::NeedMore(next)) => { + this.state = State::Reading(next); + } + Err(error) => { + this.state = State::Done; + return Poll::Ready(Err(error.into())); + } + } + + match this.stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + let State::Reading(reader) = std::mem::replace(&mut this.state, State::Invalid) + else { + panic!("progress reader is not present"); + }; + this.state = State::Reading(reader.push(chunk)); + } + Poll::Ready(Ok(None)) => { + this.state = State::Done; + return Poll::Ready(Err(Error::MissingResponse.into())); + } + Poll::Ready(Err(error)) => { + this.state = State::Done; + return Poll::Ready(Err(CallError::Transport(error))); + } + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/ql-rpc/src/rpc/request_with_progress/codec.rs b/ql-rpc/src/rpc/progress/codec.rs similarity index 87% rename from ql-rpc/src/rpc/request_with_progress/codec.rs rename to ql-rpc/src/rpc/progress/codec.rs index 4d8095ee..e7dbd0d7 100644 --- a/ql-rpc/src/rpc/request_with_progress/codec.rs +++ b/ql-rpc/src/rpc/progress/codec.rs @@ -2,9 +2,9 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{codec, request_with_progress::RequestWithProgress, CodecError, Error, RpcCodec}; +use crate::{codec, progress::Progress, CodecError, Error, RpcCodec}; -pub enum ReadStep { +pub enum ReadStep { NeedMore(ResponseReader), Progress { value: M::Progress, @@ -13,12 +13,12 @@ pub enum ReadStep { Response(M::Response), } -pub struct ResponseReader { +pub struct ResponseReader { bytes: codec::ChunkQueue, marker: PhantomData M>, } -impl Default for ResponseReader { +impl Default for ResponseReader { fn default() -> Self { Self { bytes: codec::ChunkQueue::default(), @@ -27,7 +27,7 @@ impl Default for ResponseReader { } } -impl ResponseReader { +impl ResponseReader { pub fn push(mut self, chunk: Bytes) -> Self { self.bytes.push(chunk); self @@ -64,34 +64,34 @@ impl ResponseReader { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -enum FrameKind { - Progress = 1, - Response = 2, -} - -pub fn encode_request( +pub fn encode_request( request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>), ) { codec::encode_value_part(request, out) } -pub fn encode_progress( +pub fn encode_progress( progress: &M::Progress, out: &mut (impl BufMut + AsMut<[u8]>), ) { encode_tagged_value_part(FrameKind::Progress, progress, out) } -pub fn encode_response( +pub fn encode_response( response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>), ) { encode_tagged_value_part(FrameKind::Response, response, out) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +enum FrameKind { + Progress = 1, + Response = 2, +} + fn encode_tagged_value_part>( kind: FrameKind, value: &T, @@ -108,11 +108,11 @@ mod tests { use bytes::Bytes; use super::{encode_progress, encode_response, ReadStep, ResponseReader}; - use crate::{request_with_progress::RequestWithProgress, RouteId}; + use crate::{progress::Progress, RouteId}; struct Watch; - impl RequestWithProgress for Watch { + impl Progress for Watch { const ROUTE: RouteId = RouteId::from_u32(11); type Error = core::convert::Infallible; type Request = Vec; diff --git a/ql-rpc/src/rpc/request_with_progress/mod.rs b/ql-rpc/src/rpc/progress/mod.rs similarity index 69% rename from ql-rpc/src/rpc/request_with_progress/mod.rs rename to ql-rpc/src/rpc/progress/mod.rs index 99ed9338..c696e15f 100644 --- a/ql-rpc/src/rpc/request_with_progress/mod.rs +++ b/ql-rpc/src/rpc/progress/mod.rs @@ -1,12 +1,16 @@ use crate::{RouteId, RpcCodec}; +pub(crate) mod client; pub(crate) mod codec; +pub(crate) mod server; +pub use client::ProgressCall; pub use codec::{ encode_progress, encode_request, encode_response, ReadStep, ResponseReader, }; +pub use server::{ProgressHandler, ProgressResponder}; -pub trait RequestWithProgress { +pub trait Progress { const ROUTE: RouteId; type Error; type Request: RpcCodec; diff --git a/ql-rpc/src/rpc/progress/server.rs b/ql-rpc/src/rpc/progress/server.rs new file mode 100644 index 00000000..a2ff9784 --- /dev/null +++ b/ql-rpc/src/rpc/progress/server.rs @@ -0,0 +1,101 @@ +use std::marker::PhantomData; + +use bytes::Bytes; + +use crate::{ + RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, finish_bytes, + progress::{Progress, encode_progress, encode_response}, + write_bytes, +}; +use crate::{RouterConfig, rpc::read_framed_value}; + +pub trait ProgressHandler +where + M: Progress, + St: RpcStream, +{ + fn handle(self, request: M::Request, responder: ProgressResponder); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData M>, +} + +impl ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, progress: M::Progress) -> Result<(), W::Error> { + let writer = self.writer.as_mut().expect("progress writer exists"); + let mut encoded = Vec::new(); + encode_progress::(&progress, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self, response: M::Response) -> Result<(), W::Error> { + let mut writer = self.writer.take().expect("progress writer exists"); + let mut encoded = Vec::new(); + encode_response::(&response, &mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +pub(crate) async fn handle_progress_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, +) where + M: Progress + 'static, + S: ProgressHandler + 'static, + St: RpcStream + 'static, +{ + let request = match read_framed_value::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + state.handle_transport_error(&error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + state.handle(request, ProgressResponder::new(writer)); +} diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index f6f36c6b..2ede0b73 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -1,19 +1,19 @@ mod adapter; mod download; mod error; -mod request_with_progress; +mod progress; mod subscription; use bytes::Bytes; use ql_rpc::{ download::{self as rpc_download, Download as DownloadRpc}, notification::{self, Notification}, + progress::{self as rpc_progress, Progress}, request::{self, Request as RequestRpc}, - request_with_progress::{self as rpc_request_with_progress, RequestWithProgress}, subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, }; -pub use self::{adapter::*, download::*, error::*, request_with_progress::*, subscription::*}; +pub use self::{adapter::*, download::*, error::*, progress::*, subscription::*}; use crate::{RuntimeHandle, StreamReader}; #[derive(Clone)] @@ -78,20 +78,18 @@ impl RpcHandle { }) } - pub async fn request_with_progress( + pub async fn progress( &self, request: &M::Request, ) -> Result, RpcError> where - M: RequestWithProgress, + M: Progress, { let mut payload = Vec::new(); - rpc_request_with_progress::encode_request::(request, &mut payload); + rpc_progress::encode_request::(request, &mut payload); let response = self.start_request(M::ROUTE, payload).await?; Ok(ProgressCall { - stream: response, - reader: Some(rpc_request_with_progress::ResponseReader::default()), - terminal: None, + inner: rpc_progress::ProgressCall::new(response), }) } diff --git a/ql-runtime/src/rpc/progress.rs b/ql-runtime/src/rpc/progress.rs new file mode 100644 index 00000000..1c3984d4 --- /dev/null +++ b/ql-runtime/src/rpc/progress.rs @@ -0,0 +1,41 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_lite::Stream; +use ql_rpc::progress::Progress; + +use super::RpcError; +use crate::StreamReader; + +pub struct ProgressCall { + pub(super) inner: ql_rpc::progress::ProgressCall, +} + +impl Unpin for ProgressCall where M: Progress {} + +impl Stream for ProgressCall +where + M: Progress, +{ + type Item = M::Progress; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.poll_next_progress(cx) + } +} + +impl Future for ProgressCall +where + M: Progress, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.get_mut().inner) + .poll(cx) + .map(|result| result.map_err(RpcError::from)) + } +} diff --git a/ql-runtime/src/rpc/request_with_progress.rs b/ql-runtime/src/rpc/request_with_progress.rs deleted file mode 100644 index 94944af2..00000000 --- a/ql-runtime/src/rpc/request_with_progress.rs +++ /dev/null @@ -1,135 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use futures_lite::{future::poll_fn, Stream}; -use ql_rpc::{ - request_with_progress::{ReadStep, RequestWithProgress}, - Error, -}; - -use super::RpcError; -use crate::StreamReader; - -pub struct ProgressCall { - pub(super) stream: StreamReader, - pub(super) reader: Option>, - pub(super) terminal: Option>>, -} - -impl Unpin for ProgressCall where M: RequestWithProgress {} - -impl ProgressCall -where - M: RequestWithProgress, -{ - pub async fn progress(&mut self) -> Option { - poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await - } -} - -impl Stream for ProgressCall -where - M: RequestWithProgress, -{ - type Item = M::Progress; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - if this.terminal.is_some() || this.reader.is_none() { - return Poll::Ready(None); - } - - loop { - let reader = this.reader.take().expect("progress reader is present"); - match reader.advance() { - Ok(ReadStep::Progress { value, next }) => { - this.reader = Some(next); - return Poll::Ready(Some(value)); - } - Ok(ReadStep::Response(response)) => { - this.terminal = Some(Ok(response)); - return Poll::Ready(None); - } - Ok(ReadStep::NeedMore(next)) => { - this.reader = Some(next); - } - Err(error) => { - this.terminal = Some(Err(error.into())); - return Poll::Ready(None); - } - } - - match this.stream.poll_read_chunk(cx) { - Poll::Ready(Ok(Some(chunk))) => { - let reader = this.reader.take().expect("progress reader is present"); - this.reader = Some(reader.push(chunk)); - } - Poll::Ready(Ok(None)) => { - this.reader = None; - this.terminal = Some(Err(Error::MissingResponse.into())); - return Poll::Ready(None); - } - Poll::Ready(Err(error)) => { - this.reader = None; - this.terminal = Some(Err(error.into())); - return Poll::Ready(None); - } - Poll::Pending => return Poll::Pending, - } - } - } -} - -impl Future for ProgressCall -where - M: RequestWithProgress, -{ - type Output = Result>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - if let Some(result) = this.terminal.take() { - return Poll::Ready(result); - } - - loop { - let Some(reader) = this.reader.take() else { - panic!("progress call polled after completion"); - }; - - match reader.advance() { - Ok(ReadStep::Progress { next, .. }) => { - this.reader = Some(next); - } - Ok(ReadStep::Response(response)) => { - return Poll::Ready(Ok(response)); - } - Ok(ReadStep::NeedMore(next)) => { - this.reader = Some(next); - } - Err(error) => return Poll::Ready(Err(error.into())), - } - - match this.stream.poll_read_chunk(cx) { - Poll::Ready(Ok(Some(chunk))) => { - let reader = this.reader.take().expect("progress reader is present"); - this.reader = Some(reader.push(chunk)); - } - Poll::Ready(Ok(None)) => { - this.reader = None; - return Poll::Ready(Err(Error::MissingResponse.into())); - } - Poll::Ready(Err(error)) => { - this.reader = None; - return Poll::Ready(Err(error.into())); - } - Poll::Pending => return Poll::Pending, - } - } - } -} diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index b055ab18..e13e8235 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -9,9 +9,9 @@ use std::{ use bytes::Bytes; use futures_lite::StreamExt; use ql_rpc::{ - DownloadResponder, DownloadWriter, Response, RouteId, StreamCloseCode, SubscriptionResponder, + DownloadResponder, DownloadWriter, ProgressResponder, Response, RouteId, StreamCloseCode, + SubscriptionResponder, }; -use ql_wire::RouteId as WireRouteId; use super::*; use crate::{QlStream, StreamWriter}; @@ -38,7 +38,7 @@ impl ql_rpc::subscription::Subscription for Feed { struct Download; -impl ql_rpc::request_with_progress::RequestWithProgress for Download { +impl ql_rpc::progress::Progress for Download { const ROUTE: RouteId = RouteId::from_u32(53); type Error = core::convert::Infallible; type Request = Vec; @@ -214,68 +214,57 @@ async fn rpc_router_enforces_max_request_bytes() { } #[tokio::test(flavor = "current_thread")] -async fn rpc_request_with_progress() { +async fn rpc_progress() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl ql_rpc::ProgressHandler for RouterState { + fn handle( + self, + request: Vec, + mut responder: ProgressResponder, + ) { + let seen = self.seen.clone(); + tokio::task::spawn_local(async move { + seen.borrow_mut().push(request); + responder.send(b"10".to_vec()).await.unwrap(); + responder.send(b"90".to_vec()).await.unwrap(); + responder.finish(b"done".to_vec()).await.unwrap(); + }); + } + } + run_local_test(async { let mut pair = TestPair::new(default_runtime_config()); pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) + .progress::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); - let request: Vec = { - let mut reader = inbound.reader; - let mut value_reader = ql_rpc::FramedValueReader::>::default(); - - loop { - match value_reader.advance().unwrap() { - ql_rpc::ReadValueStep::Value(value) => break value, - ql_rpc::ReadValueStep::NeedMore(next) => value_reader = next, - } - - match reader.read_chunk().await.unwrap() { - Some(chunk) => value_reader = value_reader.push(chunk), - None => panic!("truncated rpc value"), - } - } - }; - assert_eq!( - inbound.route_id, - WireRouteId::from_u32( - ::ROUTE - .into_inner() - ) - ); - assert_eq!(request, b"logo".to_vec()); - - let mut encoded = Vec::new(); - ql_rpc::request_with_progress::encode_progress::( - &b"10".to_vec(), - &mut encoded, - ); - ql_rpc::request_with_progress::encode_progress::( - &b"90".to_vec(), - &mut encoded, - ); - ql_rpc::request_with_progress::encode_response::( - &b"done".to_vec(), - &mut encoded, - ); - - let mut writer = inbound.writer; - writer.write(Bytes::from(encoded)).await.unwrap(); - writer.finish().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await; + } }); let rpc = pair.side_mut(Side::A).handle.rpc(); let mut download = rpc - .request_with_progress::(&b"logo".to_vec()) + .progress::(&b"logo".to_vec()) .await .unwrap(); - assert_eq!(download.progress().await, Some(b"10".to_vec())); - assert_eq!(download.progress().await, Some(b"90".to_vec())); - assert_eq!(download.progress().await, None); + assert_eq!(download.next().await, Some(b"10".to_vec())); + assert_eq!(download.next().await, Some(b"90".to_vec())); + assert_eq!(download.next().await, None); assert_eq!(download.await.unwrap(), b"done".to_vec()); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); tokio::time::timeout(Duration::from_secs(2), responder) .await From 9f8c6f4acf849bf2323437114c48ec368260fe39 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 21:41:55 -0400 Subject: [PATCH 259/304] ql-rpc: clean up --- ql-rpc/src/rpc/download/codec.rs | 2 +- ql-rpc/src/rpc/download/server.rs | 7 ++--- ql-rpc/src/rpc/mod.rs | 34 +++++++++++++++-------- ql-rpc/src/rpc/notification/codec.rs | 5 +++- ql-rpc/src/rpc/notification/mod.rs | 4 +-- ql-rpc/src/rpc/progress/server.rs | 10 +++---- ql-rpc/src/rpc/request/client.rs | 17 ++++++++++-- ql-rpc/src/rpc/request/server.rs | 40 ++------------------------- ql-rpc/src/rpc/subscription/codec.rs | 2 +- ql-rpc/src/rpc/subscription/server.rs | 8 +++--- ql-runtime/src/rpc/mod.rs | 4 +-- 11 files changed, 61 insertions(+), 72 deletions(-) diff --git a/ql-rpc/src/rpc/download/codec.rs b/ql-rpc/src/rpc/download/codec.rs index 48b2e53a..53e54ecc 100644 --- a/ql-rpc/src/rpc/download/codec.rs +++ b/ql-rpc/src/rpc/download/codec.rs @@ -5,7 +5,7 @@ use bytes::{BufMut, Bytes}; use crate::{codec, download::Download, ChunkQueue, CodecError, RpcCodec}; pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { - codec::encode_value_part(request, out) + request.encode_value(out) } pub fn encode_response_header( diff --git a/ql-rpc/src/rpc/download/server.rs b/ql-rpc/src/rpc/download/server.rs index aab41dd0..43560c39 100644 --- a/ql-rpc/src/rpc/download/server.rs +++ b/ql-rpc/src/rpc/download/server.rs @@ -3,10 +3,9 @@ use std::marker::PhantomData; use bytes::Bytes; use crate::{ - codec, download::Download as DownloadRpc, finish_bytes, write_bytes, RpcCodec, RpcRead, - RpcStream, RpcWrite, StreamCloseCode, StreamError, + codec, download::Download as DownloadRpc, finish_bytes, rpc::read_eof_request, write_bytes, + RouterConfig, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; -use crate::{rpc::read_framed_value, RouterConfig}; pub trait DownloadHandler where @@ -119,7 +118,7 @@ pub(crate) async fn handle_download_inner( S: DownloadHandler + 'static, St: RpcStream + 'static, { - let request = match read_framed_value::(&mut reader, config).await { + let request = match read_eof_request::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index 77e7a494..d04dc34d 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -1,6 +1,6 @@ use crate::{ - read_bytes, CallError, ChunkQueue, CodecError, FramedValueReader, ReadValueStep, RouterConfig, - RpcCodec, RpcRead, StreamCloseCode, + read_bytes, ChunkQueue, CodecError, FramedValueReader, ReadValueStep, RouterConfig, RpcCodec, + RpcRead, StreamCloseCode, }; pub mod download; @@ -18,7 +18,7 @@ pub use subscription::Subscription; pub use upload::Upload; /// reads one length-delimited value and rejects trailing bytes -async fn read_framed_value(reader: &mut R, config: RouterConfig) -> Result +async fn read_framed_request(reader: &mut R, config: RouterConfig) -> Result where T: RpcCodec, R: RpcRead, @@ -59,24 +59,34 @@ where } } -/// reads one eof-delimited value and rejects trailing bytes -async fn read_whole_value(reader: &mut R) -> Result> +/// reads one eof-delimited value up to the configured request limit +async fn read_eof_request(reader: &mut R, config: RouterConfig) -> Result where T: RpcCodec, R: RpcRead, { let mut bytes = ChunkQueue::default(); + let mut total_read = 0usize; - while let Some(chunk) = read_bytes(reader, usize::MAX) - .await - .map_err(CallError::Transport)? - { - bytes.push(chunk); + loop { + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(Some(chunk)) => { + if chunk.len() > remaining { + return Err(StreamCloseCode::LIMIT.into()); + } + total_read += chunk.len(); + bytes.push(chunk); + } + Ok(None) => break, + Err(error) => return Err(error), + } } - let value = T::decode_value(&mut bytes).map_err(CallError::Codec)?; + let value = T::decode_value(&mut bytes).map_err(|_error| StreamCloseCode::REFUSED)?; if bytes.remaining() > 0 { - return Err(crate::Error::TrailingBytes.into()); + return Err(StreamCloseCode::REFUSED.into()); } Ok(value) } diff --git a/ql-rpc/src/rpc/notification/codec.rs b/ql-rpc/src/rpc/notification/codec.rs index ec33ed73..838336f3 100644 --- a/ql-rpc/src/rpc/notification/codec.rs +++ b/ql-rpc/src/rpc/notification/codec.rs @@ -2,6 +2,9 @@ use bytes::BufMut; use crate::{codec, notification::Notification}; -pub fn encode_event(event: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { +pub fn encode_notification( + event: &M::Payload, + out: &mut (impl BufMut + AsMut<[u8]>), +) { codec::encode_value_part(event, out) } diff --git a/ql-rpc/src/rpc/notification/mod.rs b/ql-rpc/src/rpc/notification/mod.rs index 81110173..a895fca8 100644 --- a/ql-rpc/src/rpc/notification/mod.rs +++ b/ql-rpc/src/rpc/notification/mod.rs @@ -2,10 +2,10 @@ use crate::{RouteId, RpcCodec}; pub(crate) mod codec; -pub use codec::encode_event; +pub use codec::encode_notification; pub trait Notification { const ROUTE: RouteId; type Error; - type Event: RpcCodec; + type Payload: RpcCodec; } diff --git a/ql-rpc/src/rpc/progress/server.rs b/ql-rpc/src/rpc/progress/server.rs index a2ff9784..4acf1553 100644 --- a/ql-rpc/src/rpc/progress/server.rs +++ b/ql-rpc/src/rpc/progress/server.rs @@ -3,11 +3,11 @@ use std::marker::PhantomData; use bytes::Bytes; use crate::{ - RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, finish_bytes, - progress::{Progress, encode_progress, encode_response}, - write_bytes, + finish_bytes, + progress::{encode_progress, encode_response, Progress}, + rpc::read_framed_request, + write_bytes, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; -use crate::{RouterConfig, rpc::read_framed_value}; pub trait ProgressHandler where @@ -84,7 +84,7 @@ pub(crate) async fn handle_progress_inner( S: ProgressHandler + 'static, St: RpcStream + 'static, { - let request = match read_framed_value::(&mut reader, config).await { + let request = match read_framed_request::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); diff --git a/ql-rpc/src/rpc/request/client.rs b/ql-rpc/src/rpc/request/client.rs index 5d936af0..f4608a4d 100644 --- a/ql-rpc/src/rpc/request/client.rs +++ b/ql-rpc/src/rpc/request/client.rs @@ -1,6 +1,6 @@ use bytes::BufMut; -use crate::{request::Request, rpc::read_whole_value, CallError, RpcCodec, RpcRead}; +use crate::{CallError, ChunkQueue, RpcCodec, RpcRead, read_bytes, request::Request}; pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { request.encode_value(out) @@ -20,5 +20,18 @@ where M: Request, R: RpcRead, { - read_whole_value::(&mut reader).await + let mut bytes = ChunkQueue::default(); + + while let Some(chunk) = read_bytes(&mut reader, usize::MAX) + .await + .map_err(CallError::Transport)? + { + bytes.push(chunk); + } + + let value = M::Response::decode_value(&mut bytes).map_err(CallError::Codec)?; + if bytes.remaining() > 0 { + return Err(crate::Error::TrailingBytes.into()); + } + Ok(value) } diff --git a/ql-rpc/src/rpc/request/server.rs b/ql-rpc/src/rpc/request/server.rs index b94e795c..949dacef 100644 --- a/ql-rpc/src/rpc/request/server.rs +++ b/ql-rpc/src/rpc/request/server.rs @@ -3,12 +3,10 @@ use std::marker::PhantomData; use bytes::Bytes; use crate::{ - finish_bytes, read_bytes, request::Request as RequestRpc, write_bytes, ChunkQueue, + finish_bytes, request::Request as RequestRpc, rpc::read_eof_request, write_bytes, RouterConfig, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; -use crate::RouterConfig; - pub trait RequestHandler where M: RequestRpc, @@ -76,7 +74,7 @@ pub(crate) async fn handle_request_inner( S: RequestHandler + 'static, St: RpcStream + 'static, { - let request = match read_whole_value::(&mut reader, config).await { + let request = match read_eof_request::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); @@ -91,37 +89,3 @@ pub(crate) async fn handle_request_inner( state.handle(request, Response::new(writer)); } - -pub(crate) async fn read_whole_value( - reader: &mut R, - config: RouterConfig, -) -> Result -where - T: RpcCodec, - R: RpcRead, -{ - let mut bytes = ChunkQueue::default(); - let mut total_read = 0usize; - - loop { - let remaining = config.max_request_bytes.saturating_sub(total_read); - let probe = remaining.max(1); - match read_bytes(reader, probe).await { - Ok(Some(chunk)) => { - if chunk.len() > remaining { - return Err(StreamCloseCode::LIMIT.into()); - } - total_read += chunk.len(); - bytes.push(chunk); - } - Ok(None) => break, - Err(error) => return Err(error), - } - } - - let value = T::decode_value(&mut bytes).map_err(|_error| StreamCloseCode::REFUSED)?; - if bytes.remaining() > 0 { - return Err(StreamCloseCode::REFUSED.into()); - } - Ok(value) -} diff --git a/ql-rpc/src/rpc/subscription/codec.rs b/ql-rpc/src/rpc/subscription/codec.rs index 45fec4dc..9fdf4d7e 100644 --- a/ql-rpc/src/rpc/subscription/codec.rs +++ b/ql-rpc/src/rpc/subscription/codec.rs @@ -8,7 +8,7 @@ pub fn encode_request( request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>), ) { - codec::encode_value_part(request, out) + request.encode_value(out) } pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { diff --git a/ql-rpc/src/rpc/subscription/server.rs b/ql-rpc/src/rpc/subscription/server.rs index 699ad284..7a193462 100644 --- a/ql-rpc/src/rpc/subscription/server.rs +++ b/ql-rpc/src/rpc/subscription/server.rs @@ -3,10 +3,10 @@ use std::marker::PhantomData; use bytes::Bytes; use crate::{ - codec, finish_bytes, subscription::Subscription as SubscriptionRpc, write_bytes, RpcCodec, - RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, + codec, finish_bytes, rpc::read_eof_request, subscription::Subscription as SubscriptionRpc, + write_bytes, RouterConfig, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, + StreamError, }; -use crate::{rpc::read_framed_value, RouterConfig}; pub trait SubscriptionHandler where @@ -79,7 +79,7 @@ pub(crate) async fn handle_subscription_inner( S: SubscriptionHandler + 'static, St: RpcStream + 'static, { - let request = match read_framed_value::(&mut reader, config).await { + let request = match read_eof_request::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 2ede0b73..7383ef12 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -22,12 +22,12 @@ pub struct RpcHandle { } impl RpcHandle { - pub async fn event(&self, event: &M::Event) -> Result<(), RpcError> + pub async fn event(&self, event: &M::Payload) -> Result<(), RpcError> where M: Notification, { let mut payload = Vec::new(); - notification::encode_event::(event, &mut payload); + notification::encode_notification::(event, &mut payload); let mut stream = self .inner .open_stream(adapter::to_wire_route_id(M::ROUTE)) From 89502c24cc0b927dd55f829052141daa19910646 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 21:58:48 -0400 Subject: [PATCH 260/304] ql-rpc: notification --- ql-rpc/src/router/builder.rs | 31 +++++ ql-rpc/src/router/mod.rs | 1 + .../rpc/notification/{codec.rs => client.rs} | 6 +- ql-rpc/src/rpc/notification/mod.rs | 6 +- ql-rpc/src/rpc/notification/server.rs | 41 +++++++ ql-runtime/src/rpc/adapter.rs | 7 +- ql-runtime/src/rpc/mod.rs | 3 +- ql-runtime/src/tests/rpc.rs | 116 ++++++++++++------ 8 files changed, 161 insertions(+), 50 deletions(-) rename ql-rpc/src/rpc/notification/{codec.rs => client.rs} (50%) create mode 100644 ql-rpc/src/rpc/notification/server.rs diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index b3c33dc7..8d4ec302 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -7,6 +7,8 @@ use super::{ use crate::{ download::Download as DownloadRpc, download::server::{handle_download_inner, DownloadHandler}, + notification::Notification as NotificationRpc, + notification::server::{NotificationHandler, handle_notification_inner}, progress::Progress as ProgressRpc, progress::server::{ProgressHandler, handle_progress_inner}, request::Request as RequestRpc, subscription::Subscription as SubscriptionRpc, RouteId, @@ -80,6 +82,19 @@ where }) } + pub fn notification(self) -> Self + where + M: NotificationRpc + 'static, + S: NotificationHandler + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_notification_inner::( + state, config, reader, writer, + )) + }) + } + pub fn download(self) -> Self where M: DownloadRpc + 'static, @@ -138,6 +153,22 @@ where }) } + pub fn notification(self) -> Self + where + M: NotificationRpc + 'static, + M::Payload: Send + 'static, + S: NotificationHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_notification_inner::( + state, config, reader, writer, + )) + }) + } + pub fn download(self) -> Self where M: DownloadRpc + 'static, diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index b778f542..1bc267c8 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -12,6 +12,7 @@ pub use self::{ mode::*, }; pub use crate::download::{DownloadHandler, DownloadResponder, DownloadWriter}; +pub use crate::notification::NotificationHandler; pub use crate::progress::{ProgressHandler, ProgressResponder}; pub use crate::request::{RequestHandler, Response}; pub use crate::subscription::{SubscriptionHandler, SubscriptionResponder}; diff --git a/ql-rpc/src/rpc/notification/codec.rs b/ql-rpc/src/rpc/notification/client.rs similarity index 50% rename from ql-rpc/src/rpc/notification/codec.rs rename to ql-rpc/src/rpc/notification/client.rs index 838336f3..72b6900a 100644 --- a/ql-rpc/src/rpc/notification/codec.rs +++ b/ql-rpc/src/rpc/notification/client.rs @@ -1,10 +1,10 @@ use bytes::BufMut; -use crate::{codec, notification::Notification}; +use crate::{notification::Notification, RpcCodec}; pub fn encode_notification( - event: &M::Payload, + payload: &M::Payload, out: &mut (impl BufMut + AsMut<[u8]>), ) { - codec::encode_value_part(event, out) + payload.encode_value(out) } diff --git a/ql-rpc/src/rpc/notification/mod.rs b/ql-rpc/src/rpc/notification/mod.rs index a895fca8..dae97c82 100644 --- a/ql-rpc/src/rpc/notification/mod.rs +++ b/ql-rpc/src/rpc/notification/mod.rs @@ -1,8 +1,10 @@ use crate::{RouteId, RpcCodec}; -pub(crate) mod codec; +pub(crate) mod client; +pub(crate) mod server; -pub use codec::encode_notification; +pub use client::encode_notification; +pub use server::NotificationHandler; pub trait Notification { const ROUTE: RouteId; diff --git a/ql-rpc/src/rpc/notification/server.rs b/ql-rpc/src/rpc/notification/server.rs new file mode 100644 index 00000000..fc98684c --- /dev/null +++ b/ql-rpc/src/rpc/notification/server.rs @@ -0,0 +1,41 @@ +use crate::{ + notification::Notification as NotificationRpc, rpc::read_eof_request, RouterConfig, RpcRead, + RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +pub trait NotificationHandler +where + M: NotificationRpc, + St: RpcStream, +{ + fn handle(self, message: M::Payload); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub(crate) async fn handle_notification_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, +) where + M: NotificationRpc + 'static, + S: NotificationHandler + 'static, + St: RpcStream + 'static, +{ + let notification = match read_eof_request::(&mut reader, config).await { + Ok(notification) => notification, + Err(error) => { + let code = error.close_code(); + state.handle_transport_error(&error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + writer.close(StreamCloseCode::CANCELLED); + state.handle(notification); +} diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs index 13e7eb57..a7347602 100644 --- a/ql-runtime/src/rpc/adapter.rs +++ b/ql-runtime/src/rpc/adapter.rs @@ -1,12 +1,7 @@ use std::task::{Context, Poll}; use bytes::Bytes; -pub use ql_rpc::{ - DownloadHandler, DownloadResponder, DownloadWriter, LocalSpawn, RequestHandler, Response, - RouteId, RouterConfig, SendSpawn, StreamCloseCode, SubscriptionHandler, - SubscriptionResponder, -}; -use ql_rpc::{RpcRead, RpcStream, RpcWrite, StreamError}; +use ql_rpc::{RouteId, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError}; use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; use crate::{QlStream, QlStreamError, StreamReader, StreamWriter}; diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 7383ef12..4fc304bf 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -1,3 +1,5 @@ +pub use self::{download::*, error::*, progress::*, subscription::*}; + mod adapter; mod download; mod error; @@ -13,7 +15,6 @@ use ql_rpc::{ subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, }; -pub use self::{adapter::*, download::*, error::*, progress::*, subscription::*}; use crate::{RuntimeHandle, StreamReader}; #[derive(Clone)] diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index e13e8235..7c5293c2 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -9,11 +9,13 @@ use std::{ use bytes::Bytes; use futures_lite::StreamExt; use ql_rpc::{ - DownloadResponder, DownloadWriter, ProgressResponder, Response, RouteId, StreamCloseCode, - SubscriptionResponder, + DownloadHandler, DownloadResponder, DownloadWriter, LocalSpawn, NotificationHandler, + ProgressHandler, ProgressResponder, RequestHandler, Response, RouteId, SendSpawn, + StreamCloseCode, SubscriptionHandler, SubscriptionResponder, }; use super::*; +use crate::rpc::RpcError; use crate::{QlStream, StreamWriter}; struct Echo; @@ -36,6 +38,14 @@ impl ql_rpc::subscription::Subscription for Feed { type Event = Vec; } +struct Notice; + +impl ql_rpc::notification::Notification for Notice { + const ROUTE: RouteId = RouteId::from_u32(521); + type Error = core::convert::Infallible; + type Payload = Vec; +} + struct Download; impl ql_rpc::progress::Progress for Download { @@ -62,7 +72,7 @@ async fn rpc_request() { seen: Arc>>, } - impl crate::rpc::RequestHandler for RouterState { + impl RequestHandler for RouterState { fn handle(self, request: String, response: Response) { let seen = self.seen.clone(); tokio::task::spawn(async move { @@ -78,10 +88,9 @@ async fn rpc_request() { let inbound_b = pair.take_inbound(Side::B); let seen = Arc::new(Mutex::new(Vec::new())); - let router = - ql_rpc::Router::<_, QlStream, crate::rpc::SendSpawn>::builder(crate::rpc::SendSpawn) - .request::() - .build(RouterState { seen: seen.clone() }); + let router = ql_rpc::Router::<_, QlStream, SendSpawn>::builder(SendSpawn) + .request::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -108,6 +117,48 @@ fn assert_send(value: T) -> T { value } +#[tokio::test(flavor = "current_thread")] +async fn rpc_notification() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl NotificationHandler for RouterState { + fn handle(self, payload: Vec) { + self.seen.borrow_mut().push(payload); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) + .notification::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await; + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + rpc.event::(&b"hello".to_vec()).await.unwrap(); + assert_eq!(seen.borrow().as_slice(), &[b"hello".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn rpc_subscrption() { #[derive(Clone)] @@ -115,7 +166,7 @@ async fn rpc_subscrption() { seen: Rc>>>, } - impl crate::rpc::SubscriptionHandler for RouterState { + impl SubscriptionHandler for RouterState { fn handle( self, request: Vec, @@ -137,10 +188,9 @@ async fn rpc_subscrption() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = - ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) - .subscription::() - .build(RouterState { seen: seen.clone() }); + let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) + .subscription::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -169,12 +219,8 @@ async fn rpc_router_enforces_max_request_bytes() { #[derive(Clone)] struct LimitedState; - impl crate::rpc::RequestHandler for LimitedState { - fn handle( - self, - request: String, - response: crate::rpc::Response, - ) { + impl RequestHandler for LimitedState { + fn handle(self, request: String, response: Response) { tokio::task::spawn_local(async move { let _ = response.respond(request).await; }); @@ -185,11 +231,10 @@ async fn rpc_router_enforces_max_request_bytes() { let mut pair = TestPair::new(default_runtime_config()); pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); - let router = - ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) - .max_request_bytes(4) - .request::() - .build(LimitedState); + let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) + .max_request_bytes(4) + .request::() + .build(LimitedState); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -202,7 +247,7 @@ async fn rpc_router_enforces_max_request_bytes() { let response = rpc.request::(&"hello".to_string()).await; assert!(matches!( response, - Err(crate::rpc::RpcError::Closed(code)) if code == StreamCloseCode::LIMIT + Err(RpcError::Closed(code)) if code == StreamCloseCode::LIMIT )); tokio::time::timeout(Duration::from_secs(2), responder) @@ -220,7 +265,7 @@ async fn rpc_progress() { seen: Rc>>>, } - impl ql_rpc::ProgressHandler for RouterState { + impl ProgressHandler for RouterState { fn handle( self, request: Vec, @@ -242,10 +287,9 @@ async fn rpc_progress() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = - ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) - .progress::() - .build(RouterState { seen: seen.clone() }); + let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) + .progress::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); @@ -255,10 +299,7 @@ async fn rpc_progress() { }); let rpc = pair.side_mut(Side::A).handle.rpc(); - let mut download = rpc - .progress::(&b"logo".to_vec()) - .await - .unwrap(); + let mut download = rpc.progress::(&b"logo".to_vec()).await.unwrap(); assert_eq!(download.next().await, Some(b"10".to_vec())); assert_eq!(download.next().await, Some(b"90".to_vec())); @@ -281,7 +322,7 @@ async fn rpc_download() { seen: Rc>>>, } - impl crate::rpc::DownloadHandler for RouterState { + impl DownloadHandler for RouterState { fn handle(self, request: Vec, responder: DownloadResponder, StreamWriter>) { let seen = self.seen.clone(); tokio::task::spawn_local(async move { @@ -301,10 +342,9 @@ async fn rpc_download() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = - ql_rpc::Router::<_, QlStream, crate::rpc::LocalSpawn>::builder(crate::rpc::LocalSpawn) - .download::() - .build(RouterState { seen: seen.clone() }); + let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) + .download::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); From 76438d2c55490a83f4707a44ccfd98093f1ef051 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 22:03:40 -0400 Subject: [PATCH 261/304] ql-runtime: clean up --- ql-runtime/src/handle/mod.rs | 4 +--- ql-runtime/src/rpc/mod.rs | 8 +++++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index 27047342..c6839f5d 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -78,9 +78,7 @@ impl RuntimeHandle { #[cfg(feature = "rpc")] pub fn rpc(&self) -> crate::rpc::RpcHandle { - crate::rpc::RpcHandle { - inner: self.clone(), - } + crate::rpc::RpcHandle::new(self.clone()) } } diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 4fc304bf..21f29491 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -19,7 +19,7 @@ use crate::{RuntimeHandle, StreamReader}; #[derive(Clone)] pub struct RpcHandle { - pub(crate) inner: RuntimeHandle, + inner: RuntimeHandle, } impl RpcHandle { @@ -93,6 +93,12 @@ impl RpcHandle { inner: rpc_progress::ProgressCall::new(response), }) } +} + +impl RpcHandle { + pub(super) fn new(inner: RuntimeHandle) -> Self { + Self { inner } + } async fn start_request( &self, From 1aaf9a00ad785f4579bd59e7897a0feb383f702d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 22:06:51 -0400 Subject: [PATCH 262/304] ql-rpc: move to utils --- ql-rpc/src/rpc/mod.rs | 81 +---------------------------------------- ql-rpc/src/rpc/utils.rs | 81 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 79 deletions(-) create mode 100644 ql-rpc/src/rpc/utils.rs diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index d04dc34d..36b96230 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -1,14 +1,10 @@ -use crate::{ - read_bytes, ChunkQueue, CodecError, FramedValueReader, ReadValueStep, RouterConfig, RpcCodec, - RpcRead, StreamCloseCode, -}; - pub mod download; pub mod notification; pub mod progress; pub mod request; pub mod subscription; pub mod upload; +mod utils; pub use download::Download; pub use notification::Notification; @@ -16,77 +12,4 @@ pub use progress::Progress; pub use request::Request; pub use subscription::Subscription; pub use upload::Upload; - -/// reads one length-delimited value and rejects trailing bytes -async fn read_framed_request(reader: &mut R, config: RouterConfig) -> Result -where - T: RpcCodec, - R: RpcRead, -{ - let mut value_reader = FramedValueReader::::default(); - let mut total_read = 0usize; - - let value = loop { - match value_reader.advance() { - Ok(ReadValueStep::Value(value)) => break value, - Ok(ReadValueStep::NeedMore(next)) => value_reader = next, - Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), - Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), - } - - let remaining = config.max_request_bytes.saturating_sub(total_read); - if remaining == 0 { - return Err(StreamCloseCode::LIMIT.into()); - } - - match read_bytes(reader, remaining).await { - Ok(Some(chunk)) => { - total_read += chunk.len(); - value_reader = value_reader.push(chunk); - } - Ok(None) => return Err(StreamCloseCode::REFUSED.into()), - Err(error) => return Err(error), - } - }; - - let remaining = config.max_request_bytes.saturating_sub(total_read); - let probe = remaining.max(1); - match read_bytes(reader, probe).await { - Ok(None) => Ok(value), - Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT.into()), - Ok(Some(_)) => Err(StreamCloseCode::REFUSED.into()), - Err(error) => Err(error), - } -} - -/// reads one eof-delimited value up to the configured request limit -async fn read_eof_request(reader: &mut R, config: RouterConfig) -> Result -where - T: RpcCodec, - R: RpcRead, -{ - let mut bytes = ChunkQueue::default(); - let mut total_read = 0usize; - - loop { - let remaining = config.max_request_bytes.saturating_sub(total_read); - let probe = remaining.max(1); - match read_bytes(reader, probe).await { - Ok(Some(chunk)) => { - if chunk.len() > remaining { - return Err(StreamCloseCode::LIMIT.into()); - } - total_read += chunk.len(); - bytes.push(chunk); - } - Ok(None) => break, - Err(error) => return Err(error), - } - } - - let value = T::decode_value(&mut bytes).map_err(|_error| StreamCloseCode::REFUSED)?; - if bytes.remaining() > 0 { - return Err(StreamCloseCode::REFUSED.into()); - } - Ok(value) -} +use utils::*; diff --git a/ql-rpc/src/rpc/utils.rs b/ql-rpc/src/rpc/utils.rs new file mode 100644 index 00000000..60a1ae5a --- /dev/null +++ b/ql-rpc/src/rpc/utils.rs @@ -0,0 +1,81 @@ +use crate::{ + read_bytes, ChunkQueue, CodecError, FramedValueReader, ReadValueStep, RouterConfig, RpcCodec, + RpcRead, StreamCloseCode, +}; + +/// reads one length-delimited value and rejects trailing bytes +pub(crate) async fn read_framed_request( + reader: &mut R, + config: RouterConfig, +) -> Result +where + T: RpcCodec, + R: RpcRead, +{ + let mut value_reader = FramedValueReader::::default(); + let mut total_read = 0usize; + + let value = loop { + match value_reader.advance() { + Ok(ReadValueStep::Value(value)) => break value, + Ok(ReadValueStep::NeedMore(next)) => value_reader = next, + Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), + } + + let remaining = config.max_request_bytes.saturating_sub(total_read); + if remaining == 0 { + return Err(StreamCloseCode::LIMIT.into()); + } + + match read_bytes(reader, remaining).await { + Ok(Some(chunk)) => { + total_read += chunk.len(); + value_reader = value_reader.push(chunk); + } + Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Err(error) => return Err(error), + } + }; + + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(None) => Ok(value), + Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT.into()), + Ok(Some(_)) => Err(StreamCloseCode::REFUSED.into()), + Err(error) => Err(error), + } +} + +/// reads one eof-delimited value up to the configured request limit +pub(crate) async fn read_eof_request(reader: &mut R, config: RouterConfig) -> Result +where + T: RpcCodec, + R: RpcRead, +{ + let mut bytes = ChunkQueue::default(); + let mut total_read = 0usize; + + loop { + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(Some(chunk)) => { + if chunk.len() > remaining { + return Err(StreamCloseCode::LIMIT.into()); + } + total_read += chunk.len(); + bytes.push(chunk); + } + Ok(None) => break, + Err(error) => return Err(error), + } + } + + let value = T::decode_value(&mut bytes).map_err(|_error| StreamCloseCode::REFUSED)?; + if bytes.remaining() > 0 { + return Err(StreamCloseCode::REFUSED.into()); + } + Ok(value) +} From 2f0e71cf9d43e2d18dcfa517e15c26ecf3fca986 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 22:14:55 -0400 Subject: [PATCH 263/304] ql-rpc: docs --- ql-rpc/src/router/builder.rs | 38 +++++++++++++++++++-------- ql-rpc/src/router/mod.rs | 18 ++++++------- ql-rpc/src/rpc/download/client.rs | 6 +++-- ql-rpc/src/rpc/download/mod.rs | 15 ++++++----- ql-rpc/src/rpc/mod.rs | 7 +++++ ql-rpc/src/rpc/notification/mod.rs | 7 +++++ ql-rpc/src/rpc/progress/codec.rs | 15 +++-------- ql-rpc/src/rpc/progress/mod.rs | 15 ++++++++--- ql-rpc/src/rpc/request/client.rs | 7 ++--- ql-rpc/src/rpc/request/mod.rs | 10 +++++++ ql-rpc/src/rpc/subscription/client.rs | 6 +++-- ql-rpc/src/rpc/subscription/codec.rs | 2 +- ql-rpc/src/rpc/subscription/mod.rs | 8 ++++++ ql-rpc/src/rpc/upload/mod.rs | 16 ++++++----- ql-rpc/src/rpc/utils.rs | 5 +++- 15 files changed, 116 insertions(+), 59 deletions(-) diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index 8d4ec302..4fb93e4c 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -5,15 +5,27 @@ use super::{ Spawner, }; use crate::{ - download::Download as DownloadRpc, - download::server::{handle_download_inner, DownloadHandler}, - notification::Notification as NotificationRpc, - notification::server::{NotificationHandler, handle_notification_inner}, - progress::Progress as ProgressRpc, - progress::server::{ProgressHandler, handle_progress_inner}, - request::Request as RequestRpc, subscription::Subscription as SubscriptionRpc, RouteId, - request::server::{handle_request_inner, RequestHandler}, - subscription::server::{handle_subscription_inner, SubscriptionHandler}, + download::{ + server::{handle_download_inner, DownloadHandler}, + Download as DownloadRpc, + }, + notification::{ + server::{handle_notification_inner, NotificationHandler}, + Notification as NotificationRpc, + }, + progress::{ + server::{handle_progress_inner, ProgressHandler}, + Progress as ProgressRpc, + }, + request::{ + server::{handle_request_inner, RequestHandler}, + Request as RequestRpc, + }, + subscription::{ + server::{handle_subscription_inner, SubscriptionHandler}, + Subscription as SubscriptionRpc, + }, + RouteId, }; pub struct RouterBuilder @@ -128,7 +140,9 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_progress_inner::(state, config, reader, writer)) + spawner.spawn(handle_progress_inner::( + state, config, reader, writer, + )) }) } } @@ -211,7 +225,9 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_progress_inner::(state, config, reader, writer)) + spawner.spawn(handle_progress_inner::( + state, config, reader, writer, + )) }) } } diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 1bc267c8..53a2ace2 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -6,17 +6,15 @@ mod builder; mod config; mod mode; -pub use self::{ - builder::RouterBuilder, - config::RouterConfig, - mode::*, -}; -pub use crate::download::{DownloadHandler, DownloadResponder, DownloadWriter}; -pub use crate::notification::NotificationHandler; -pub use crate::progress::{ProgressHandler, ProgressResponder}; -pub use crate::request::{RequestHandler, Response}; -pub use crate::subscription::{SubscriptionHandler, SubscriptionResponder}; +pub use self::{builder::RouterBuilder, config::RouterConfig, mode::*}; use crate::{close_stream, RpcStream}; +pub use crate::{ + download::{DownloadHandler, DownloadResponder, DownloadWriter}, + notification::NotificationHandler, + progress::{ProgressHandler, ProgressResponder}, + request::{RequestHandler, Response}, + subscription::{SubscriptionHandler, SubscriptionResponder}, +}; pub struct Router where diff --git a/ql-rpc/src/rpc/download/client.rs b/ql-rpc/src/rpc/download/client.rs index 6d262a3e..3156b8b8 100644 --- a/ql-rpc/src/rpc/download/client.rs +++ b/ql-rpc/src/rpc/download/client.rs @@ -5,8 +5,10 @@ use std::{ use bytes::Bytes; -use crate::{CallError, ChunkQueue, RpcRead}; -use crate::download::{Download, ReadStep, ResponseHeaderReader}; +use crate::{ + download::{Download, ReadStep, ResponseHeaderReader}, + CallError, ChunkQueue, RpcRead, +}; pub struct DownloadCall where diff --git a/ql-rpc/src/rpc/download/mod.rs b/ql-rpc/src/rpc/download/mod.rs index 761bf6b6..27da9898 100644 --- a/ql-rpc/src/rpc/download/mod.rs +++ b/ql-rpc/src/rpc/download/mod.rs @@ -8,15 +8,18 @@ pub use client::{DownloadCall, DownloadReader}; pub use codec::{encode_request, encode_response_header, ReadStep, ResponseHeaderReader}; pub use server::{DownloadHandler, DownloadResponder, DownloadWriter}; -/// rpc where the responder streams a large byte body -/// the caller sends a request -/// the responder sends a typed header for the body -/// the responder streams the raw response bytes +/// rpc where the responder returns metadata first and raw bytes after that +/// +/// the typed portion of the response ends at [`Self::ResponseHeader`] +/// after the header is decoded, the rest of the stream is exposed as raw byte +/// chunks through [`DownloadReader`] pub trait Download { + /// route used to dispatch this rpc family const ROUTE: RouteId; + /// codec error shared by request and response header values type Error; - /// input needed to start the download + /// typed input needed to start the download type Request: RpcCodec; - /// details about the body before bytes arrive + /// typed metadata available before body bytes arrive type ResponseHeader: RpcCodec; } diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index 36b96230..b3b85cf4 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -1,3 +1,10 @@ +//! rpc protocol families built on top of one stream per call +//! +//! each trait in this module names one rpc shape and the typed values that +//! travel on that stream +//! route dispatch uses [`crate::RouteId`] and the submodules provide the matching +//! client and server helpers for encoding, decoding, and handler glue + pub mod download; pub mod notification; pub mod progress; diff --git a/ql-rpc/src/rpc/notification/mod.rs b/ql-rpc/src/rpc/notification/mod.rs index dae97c82..e773378f 100644 --- a/ql-rpc/src/rpc/notification/mod.rs +++ b/ql-rpc/src/rpc/notification/mod.rs @@ -6,8 +6,15 @@ pub(crate) mod server; pub use client::encode_notification; pub use server::NotificationHandler; +/// one-way rpc that carries a single typed payload and no typed response +/// +/// the server reads [`Self::Payload`] to eof and then closes the response side +/// of the stream pub trait Notification { + /// route used to dispatch this notification const ROUTE: RouteId; + /// codec error for the notification payload type Error; + /// typed payload emitted by the caller type Payload: RpcCodec; } diff --git a/ql-rpc/src/rpc/progress/codec.rs b/ql-rpc/src/rpc/progress/codec.rs index e7dbd0d7..c56af0dd 100644 --- a/ql-rpc/src/rpc/progress/codec.rs +++ b/ql-rpc/src/rpc/progress/codec.rs @@ -64,24 +64,15 @@ impl ResponseReader { } } -pub fn encode_request( - request: &M::Request, - out: &mut (impl BufMut + AsMut<[u8]>), -) { +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { codec::encode_value_part(request, out) } -pub fn encode_progress( - progress: &M::Progress, - out: &mut (impl BufMut + AsMut<[u8]>), -) { +pub fn encode_progress(progress: &M::Progress, out: &mut (impl BufMut + AsMut<[u8]>)) { encode_tagged_value_part(FrameKind::Progress, progress, out) } -pub fn encode_response( - response: &M::Response, - out: &mut (impl BufMut + AsMut<[u8]>), -) { +pub fn encode_response(response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>)) { encode_tagged_value_part(FrameKind::Response, response, out) } diff --git a/ql-rpc/src/rpc/progress/mod.rs b/ql-rpc/src/rpc/progress/mod.rs index c696e15f..5828def0 100644 --- a/ql-rpc/src/rpc/progress/mod.rs +++ b/ql-rpc/src/rpc/progress/mod.rs @@ -5,15 +5,24 @@ pub(crate) mod codec; pub(crate) mod server; pub use client::ProgressCall; -pub use codec::{ - encode_progress, encode_request, encode_response, ReadStep, ResponseReader, -}; +pub use codec::{encode_progress, encode_request, encode_response, ReadStep, ResponseReader}; pub use server::{ProgressHandler, ProgressResponder}; +/// rpc where the responder streams progress values before a final response +/// +/// the request is length-delimited +/// response frames are tagged so the client can distinguish +/// [`Self::Progress`] items from the final [`Self::Response`] +/// reaching eof before the final response is an error pub trait Progress { + /// route used to dispatch this rpc family const ROUTE: RouteId; + /// codec error shared by request, progress, and response values type Error; + /// typed input sent by the caller type Request: RpcCodec; + /// typed progress item emitted before completion type Progress: RpcCodec; + /// typed terminal response that completes the call type Response: RpcCodec; } diff --git a/ql-rpc/src/rpc/request/client.rs b/ql-rpc/src/rpc/request/client.rs index f4608a4d..e7ffb845 100644 --- a/ql-rpc/src/rpc/request/client.rs +++ b/ql-rpc/src/rpc/request/client.rs @@ -1,15 +1,12 @@ use bytes::BufMut; -use crate::{CallError, ChunkQueue, RpcCodec, RpcRead, read_bytes, request::Request}; +use crate::{read_bytes, request::Request, CallError, ChunkQueue, RpcCodec, RpcRead}; pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { request.encode_value(out) } -pub fn encode_response( - response: &M::Response, - out: &mut (impl BufMut + AsMut<[u8]>), -) { +pub fn encode_response(response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>)) { response.encode_value(out) } diff --git a/ql-rpc/src/rpc/request/mod.rs b/ql-rpc/src/rpc/request/mod.rs index 4f84a2ef..3c690542 100644 --- a/ql-rpc/src/rpc/request/mod.rs +++ b/ql-rpc/src/rpc/request/mod.rs @@ -6,9 +6,19 @@ pub(crate) mod server; pub use client::{encode_request, encode_response, read_response}; pub use server::{RequestHandler, Response}; +/// request-response rpc with exactly one typed value in each direction +/// +/// the request is read to eof on the server side, so callers must finish the +/// request stream after encoding [`Self::Request`] +/// the response is also read to eof and rejects trailing bytes after +/// [`Self::Response`] pub trait Request { + /// route used to dispatch this rpc family const ROUTE: RouteId; + /// codec error shared by request and response values type Error; + /// typed input sent by the caller type Request: RpcCodec; + /// typed output returned by the responder type Response: RpcCodec; } diff --git a/ql-rpc/src/rpc/subscription/client.rs b/ql-rpc/src/rpc/subscription/client.rs index 39cb4f0d..fe6b3838 100644 --- a/ql-rpc/src/rpc/subscription/client.rs +++ b/ql-rpc/src/rpc/subscription/client.rs @@ -3,8 +3,10 @@ use std::{ task::{Context, Poll}, }; -use crate::{CallError, RpcRead}; -use crate::subscription::{ReadStep, ResponseReader, Subscription}; +use crate::{ + subscription::{ReadStep, ResponseReader, Subscription}, + CallError, RpcRead, +}; pub struct SubscriptionCall where diff --git a/ql-rpc/src/rpc/subscription/codec.rs b/ql-rpc/src/rpc/subscription/codec.rs index 9fdf4d7e..525e817c 100644 --- a/ql-rpc/src/rpc/subscription/codec.rs +++ b/ql-rpc/src/rpc/subscription/codec.rs @@ -49,7 +49,7 @@ impl ResponseReader { pub fn advance(self) -> Result, CodecError> { let mut this = self; - let Some(mut body) = this.bytes.try_take_part().map_err(CodecError::Rpc)? else { + let Some(mut body) = this.bytes.try_take_part()? else { return Ok(ReadStep::NeedMore(this)); }; diff --git a/ql-rpc/src/rpc/subscription/mod.rs b/ql-rpc/src/rpc/subscription/mod.rs index 14edb791..0c4790bf 100644 --- a/ql-rpc/src/rpc/subscription/mod.rs +++ b/ql-rpc/src/rpc/subscription/mod.rs @@ -8,9 +8,17 @@ pub use client::SubscriptionCall; pub use codec::{encode_item, encode_request, ReadStep, ResponseReader}; pub use server::{SubscriptionHandler, SubscriptionResponder}; +/// rpc where one request opens a stream of typed events +/// +/// event frames are length-delimited and the stream ends cleanly at eof +/// any partial trailing frame is reported as truncation on the client side pub trait Subscription { + /// route used to dispatch this rpc family const ROUTE: RouteId; + /// codec error shared by request and event values type Error; + /// typed input that starts the subscription type Request: RpcCodec; + /// typed event yielded by the responder type Event: RpcCodec; } diff --git a/ql-rpc/src/rpc/upload/mod.rs b/ql-rpc/src/rpc/upload/mod.rs index b534cd6c..abe5dbf3 100644 --- a/ql-rpc/src/rpc/upload/mod.rs +++ b/ql-rpc/src/rpc/upload/mod.rs @@ -1,14 +1,18 @@ use crate::{RouteId, RpcCodec}; -/// rpc where the caller streams a large byte body -/// the caller sends a request -/// the caller streams the raw request bytes -/// the responder sends a final typed response +/// rpc where the caller uploads raw bytes after a typed request +/// +/// the typed request usually describes how the responder should interpret the +/// following byte stream +/// once the upload reaches eof, the responder returns one typed +/// [`Self::Response`] pub trait Upload { + /// route used to dispatch this rpc family const ROUTE: RouteId; + /// codec error shared by request and response values type Error; - /// input needed to accept the upload + /// typed input needed before request body bytes arrive type Request: RpcCodec; - /// final status after all bytes are read + /// typed terminal result after the upload body is fully read type Response: RpcCodec; } diff --git a/ql-rpc/src/rpc/utils.rs b/ql-rpc/src/rpc/utils.rs index 60a1ae5a..00c823e3 100644 --- a/ql-rpc/src/rpc/utils.rs +++ b/ql-rpc/src/rpc/utils.rs @@ -49,7 +49,10 @@ where } /// reads one eof-delimited value up to the configured request limit -pub(crate) async fn read_eof_request(reader: &mut R, config: RouterConfig) -> Result +pub(crate) async fn read_eof_request( + reader: &mut R, + config: RouterConfig, +) -> Result where T: RpcCodec, R: RpcRead, From 3d87b86b05f4d7d07df0458d28309c8b09c08570 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 22:23:45 -0400 Subject: [PATCH 264/304] ql-rpc: progress client shared code --- ql-rpc/src/rpc/progress/client.rs | 62 +++++++++---------------------- 1 file changed, 18 insertions(+), 44 deletions(-) diff --git a/ql-rpc/src/rpc/progress/client.rs b/ql-rpc/src/rpc/progress/client.rs index 54920706..bbd68928 100644 --- a/ql-rpc/src/rpc/progress/client.rs +++ b/ql-rpc/src/rpc/progress/client.rs @@ -51,7 +51,7 @@ where poll_fn(|cx| self.poll_next_progress(cx)).await } - pub fn poll_next_progress(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_step(&mut self, cx: &mut Context<'_>) -> Poll> { loop { let reader = match std::mem::replace(&mut self.state, State::Invalid) { State::Reading(reader) => reader, @@ -100,6 +100,10 @@ where } } } + + pub fn poll_next_progress(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_step(cx) + } } impl Future for ProgressCall @@ -113,49 +117,19 @@ where let this = self.get_mut(); loop { - let reader = match std::mem::replace(&mut this.state, State::Invalid) { - State::Reading(reader) => reader, - State::Terminal(result) => { - this.state = State::Done; - return Poll::Ready(result); - } - State::Done => panic!("polled after completion"), - State::Invalid => panic!("polled during state transition"), - }; - - match reader.advance() { - Ok(ReadStep::Progress { next, .. }) => { - this.state = State::Reading(next); - } - Ok(ReadStep::Response(response)) => { - this.state = State::Done; - return Poll::Ready(Ok(response)); - } - Ok(ReadStep::NeedMore(next)) => { - this.state = State::Reading(next); - } - Err(error) => { - this.state = State::Done; - return Poll::Ready(Err(error.into())); - } - } - - match this.stream.poll_read(usize::MAX, cx) { - Poll::Ready(Ok(Some(chunk))) => { - let State::Reading(reader) = std::mem::replace(&mut this.state, State::Invalid) - else { - panic!("progress reader is not present"); - }; - this.state = State::Reading(reader.push(chunk)); - } - Poll::Ready(Ok(None)) => { - this.state = State::Done; - return Poll::Ready(Err(Error::MissingResponse.into())); - } - Poll::Ready(Err(error)) => { - this.state = State::Done; - return Poll::Ready(Err(CallError::Transport(error))); - } + match this.poll_step(cx) { + Poll::Ready(Some(_)) => {} + Poll::Ready(None) => match std::mem::replace(&mut this.state, State::Invalid) { + State::Terminal(result) => { + this.state = State::Done; + return Poll::Ready(result); + } + State::Done => panic!("polled after completion"), + State::Invalid => panic!("polled during state transition"), + State::Reading(_) => { + panic!("progress call reached terminal step without result") + } + }, Poll::Pending => return Poll::Pending, } } From 689cc580790bb133d75ca88f2b215eda14922b96 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 22:54:59 -0400 Subject: [PATCH 265/304] ql-rpc: use import --- ql-runtime/src/rpc/download.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs index ecca8c63..a560dd69 100644 --- a/ql-runtime/src/rpc/download.rs +++ b/ql-runtime/src/rpc/download.rs @@ -2,7 +2,7 @@ use bytes::Bytes; use ql_rpc::download::Download as DownloadRpc; use super::RpcError; -use crate::StreamReader; +use crate::{QlStreamError, StreamReader}; pub struct DownloadCall { pub(super) inner: ql_rpc::download::DownloadCall, @@ -25,11 +25,11 @@ where } impl DownloadReader { - pub async fn read(&mut self, max_len: usize) -> Result, crate::QlStreamError> { + pub async fn read(&mut self, max_len: usize) -> Result, QlStreamError> { self.inner.read(max_len).await } - pub async fn read_chunk(&mut self) -> Result, crate::QlStreamError> { + pub async fn read_chunk(&mut self) -> Result, QlStreamError> { self.inner.read_chunk().await } From 1083b34b36d87aa62a748dc127a3fbe02cbecc3c Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 23:02:10 -0400 Subject: [PATCH 266/304] ql-rpc: upload --- ql-rpc/src/codec.rs | 4 + ql-rpc/src/router/builder.rs | 33 +++++++++ ql-rpc/src/router/mod.rs | 1 + ql-rpc/src/rpc/upload/client.rs | 78 ++++++++++++++++++++ ql-rpc/src/rpc/upload/mod.rs | 7 ++ ql-rpc/src/rpc/upload/server.rs | 125 ++++++++++++++++++++++++++++++++ ql-rpc/src/rpc/utils.rs | 45 ++++++++++++ ql-runtime/src/rpc/mod.rs | 20 ++++- ql-runtime/src/rpc/upload.rs | 22 ++++++ ql-runtime/src/tests/rpc.rs | 82 ++++++++++++++++++++- 10 files changed, 415 insertions(+), 2 deletions(-) create mode 100644 ql-rpc/src/rpc/upload/client.rs create mode 100644 ql-rpc/src/rpc/upload/server.rs create mode 100644 ql-runtime/src/rpc/upload.rs diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index 231dfee0..79859f97 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -94,6 +94,10 @@ impl FramedValueReader { self } + pub fn into_bytes(self) -> ChunkQueue { + self.bytes + } + pub fn advance(self) -> Result, CodecError> { let mut this = self; let Some(mut body) = this.bytes.try_take_part()? else { diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index 4fb93e4c..9910eb5c 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -25,6 +25,10 @@ use crate::{ server::{handle_subscription_inner, SubscriptionHandler}, Subscription as SubscriptionRpc, }, + upload::{ + server::{handle_upload_inner, UploadHandler}, + Upload as UploadRpc, + }, RouteId, }; @@ -145,6 +149,19 @@ where )) }) } + + pub fn upload(self) -> Self + where + M: UploadRpc + 'static, + S: UploadHandler + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_upload_inner::( + state, config, reader, writer, + )) + }) + } } impl RouterBuilder @@ -230,4 +247,20 @@ where )) }) } + + pub fn upload(self) -> Self + where + M: UploadRpc + 'static, + M::Request: Send + 'static, + S: UploadHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_upload_inner::( + state, config, reader, writer, + )) + }) + } } diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 53a2ace2..47945b3c 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -14,6 +14,7 @@ pub use crate::{ progress::{ProgressHandler, ProgressResponder}, request::{RequestHandler, Response}, subscription::{SubscriptionHandler, SubscriptionResponder}, + upload::{UploadHandler, UploadReader, UploadResponder}, }; pub struct Router diff --git a/ql-rpc/src/rpc/upload/client.rs b/ql-rpc/src/rpc/upload/client.rs new file mode 100644 index 00000000..b86f5477 --- /dev/null +++ b/ql-rpc/src/rpc/upload/client.rs @@ -0,0 +1,78 @@ +use bytes::{BufMut, Bytes}; + +use crate::{ + finish_bytes, read_bytes, upload::Upload, write_bytes, CallError, ChunkQueue, RpcCodec, + RpcRead, RpcWrite, +}; + +pub struct UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + writer: Option, + reader: Option, + marker: std::marker::PhantomData M>, +} + +impl UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + pub fn new(writer: W, reader: R) -> Self { + Self { + writer: Some(writer), + reader: Some(reader), + marker: std::marker::PhantomData, + } + } + + pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + let writer = self.writer.as_mut().expect("upload writer exists"); + write_bytes(writer, bytes).await + } + + pub async fn finish(mut self) -> Result> { + let mut writer = self.writer.take().expect("upload writer exists"); + finish_bytes(&mut writer).await.map_err(CallError::Transport)?; + + let mut reader = self.reader.take().expect("upload reader exists"); + let mut bytes = ChunkQueue::default(); + + while let Some(chunk) = read_bytes(&mut reader, usize::MAX) + .await + .map_err(CallError::Transport)? + { + bytes.push(chunk); + } + + let value = M::Response::decode_value(&mut bytes).map_err(CallError::Codec)?; + if bytes.remaining() > 0 { + return Err(crate::Error::TrailingBytes.into()); + } + Ok(value) + } +} + +impl Drop for UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + fn drop(&mut self) { + if let Some(reader) = self.reader.take() { + reader.close(crate::StreamCloseCode::CANCELLED); + } + if let Some(writer) = self.writer.take() { + writer.close(crate::StreamCloseCode::CANCELLED); + } + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + crate::codec::encode_value_part(request, out) +} diff --git a/ql-rpc/src/rpc/upload/mod.rs b/ql-rpc/src/rpc/upload/mod.rs index abe5dbf3..be4a4b6c 100644 --- a/ql-rpc/src/rpc/upload/mod.rs +++ b/ql-rpc/src/rpc/upload/mod.rs @@ -1,9 +1,16 @@ use crate::{RouteId, RpcCodec}; +pub(crate) mod client; +pub(crate) mod server; + +pub use client::{encode_request, UploadCall}; +pub use server::{UploadHandler, UploadReader, UploadResponder}; + /// rpc where the caller uploads raw bytes after a typed request /// /// the typed request usually describes how the responder should interpret the /// following byte stream +/// the request is length-delimited so raw upload bytes can follow immediately /// once the upload reaches eof, the responder returns one typed /// [`Self::Response`] pub trait Upload { diff --git a/ql-rpc/src/rpc/upload/server.rs b/ql-rpc/src/rpc/upload/server.rs new file mode 100644 index 00000000..36399bcd --- /dev/null +++ b/ql-rpc/src/rpc/upload/server.rs @@ -0,0 +1,125 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use crate::{ + request::Response, rpc::read_framed_request_prefix, ChunkQueue, RouterConfig, RpcRead, + RpcStream, RpcWrite, StreamCloseCode, StreamError, Upload, +}; + +pub trait UploadHandler +where + M: Upload, + St: RpcStream, +{ + fn handle( + self, + request: M::Request, + upload: UploadReader, + responder: UploadResponder, + ); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct UploadReader +where + R: RpcRead, +{ + buffered: ChunkQueue, + stream: R, +} + +pub struct UploadResponder +where + W: RpcWrite, +{ + inner: Response, +} + +impl UploadReader +where + R: RpcRead, +{ + pub fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, R::Error>> { + if let Some(chunk) = self.buffered.pop_front(max_len) { + return Poll::Ready(Ok(Some(chunk))); + } + + self.stream.poll_read(max_len, cx) + } + + pub async fn read(&mut self, max_len: usize) -> Result, R::Error> { + poll_fn(|cx| self.poll_read(max_len, cx)).await + } + + pub async fn read_chunk(&mut self) -> Result, R::Error> { + self.read(usize::MAX).await + } + + pub fn into_inner(self) -> R { + self.stream + } +} + +impl UploadResponder +where + T: crate::RpcCodec, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + inner: Response::new(writer), + } + } + + pub async fn respond(self, response: T) -> Result<(), W::Error> { + self.inner.respond(response).await + } + + pub fn close(self, code: StreamCloseCode) { + self.inner.close(code); + } +} + +pub(crate) async fn handle_upload_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, +) where + M: Upload + 'static, + S: UploadHandler + 'static, + St: RpcStream + 'static, +{ + let (request, buffered) = match read_framed_request_prefix::(&mut reader, config) + .await + { + Ok(value) => value, + Err(error) => { + let code = error.close_code(); + state.handle_transport_error(&error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + state.handle( + request, + UploadReader { + buffered, + stream: reader, + }, + UploadResponder::new(writer), + ); +} diff --git a/ql-rpc/src/rpc/utils.rs b/ql-rpc/src/rpc/utils.rs index 00c823e3..d18aba68 100644 --- a/ql-rpc/src/rpc/utils.rs +++ b/ql-rpc/src/rpc/utils.rs @@ -48,6 +48,51 @@ where } } +/// reads one length-delimited value and returns any bytes already buffered +pub(crate) async fn read_framed_request_prefix( + reader: &mut R, + config: RouterConfig, +) -> Result<(T, ChunkQueue), R::Error> +where + T: RpcCodec, + R: RpcRead, +{ + let mut bytes = ChunkQueue::default(); + let mut total_read = 0usize; + + loop { + let maybe_value = { + match bytes.try_take_part() { + Ok(Some(mut body)) => { + let value = + T::decode_value(&mut body).map_err(|_error| StreamCloseCode::REFUSED)?; + drop(body); + Some(value) + } + Ok(None) => None, + Err(_error) => return Err(StreamCloseCode::REFUSED.into()), + } + }; + if let Some(value) = maybe_value { + return Ok((value, bytes)); + } + + let remaining = config.max_request_bytes.saturating_sub(total_read); + if remaining == 0 { + return Err(StreamCloseCode::LIMIT.into()); + } + + match read_bytes(reader, remaining).await { + Ok(Some(chunk)) => { + total_read += chunk.len(); + bytes.push(chunk); + } + Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Err(error) => return Err(error), + } + } +} + /// reads one eof-delimited value up to the configured request limit pub(crate) async fn read_eof_request( reader: &mut R, diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 21f29491..4e4bd77f 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -1,10 +1,11 @@ -pub use self::{download::*, error::*, progress::*, subscription::*}; +pub use self::{download::*, error::*, progress::*, subscription::*, upload::*}; mod adapter; mod download; mod error; mod progress; mod subscription; +mod upload; use bytes::Bytes; use ql_rpc::{ @@ -13,6 +14,7 @@ use ql_rpc::{ progress::{self as rpc_progress, Progress}, request::{self, Request as RequestRpc}, subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, + upload::{self as rpc_upload, Upload as UploadRpc}, }; use crate::{RuntimeHandle, StreamReader}; @@ -93,6 +95,22 @@ impl RpcHandle { inner: rpc_progress::ProgressCall::new(response), }) } + + pub async fn upload(&self, request: &M::Request) -> Result, RpcError> + where + M: UploadRpc, + { + let mut payload = Vec::new(); + rpc_upload::encode_request::(request, &mut payload); + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::ROUTE)) + .await?; + stream.writer.write(Bytes::from(payload)).await?; + Ok(UploadCall { + inner: rpc_upload::UploadCall::new(stream.writer, stream.reader), + }) + } } impl RpcHandle { diff --git a/ql-runtime/src/rpc/upload.rs b/ql-runtime/src/rpc/upload.rs new file mode 100644 index 00000000..c749e44b --- /dev/null +++ b/ql-runtime/src/rpc/upload.rs @@ -0,0 +1,22 @@ +use bytes::Bytes; +use ql_rpc::upload::Upload as UploadRpc; + +use super::RpcError; +use crate::QlStreamError; + +pub struct UploadCall { + pub(super) inner: ql_rpc::upload::UploadCall, +} + +impl UploadCall +where + M: UploadRpc, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + self.inner.send(bytes).await + } + + pub async fn finish(self) -> Result> { + self.inner.finish().await.map_err(RpcError::from) + } +} diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 7c5293c2..c2213dcd 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -11,7 +11,8 @@ use futures_lite::StreamExt; use ql_rpc::{ DownloadHandler, DownloadResponder, DownloadWriter, LocalSpawn, NotificationHandler, ProgressHandler, ProgressResponder, RequestHandler, Response, RouteId, SendSpawn, - StreamCloseCode, SubscriptionHandler, SubscriptionResponder, + StreamCloseCode, SubscriptionHandler, SubscriptionResponder, UploadHandler, UploadReader, + UploadResponder, }; use super::*; @@ -65,6 +66,15 @@ impl ql_rpc::download::Download for BlobDownload { type ResponseHeader = Vec; } +struct BlobUpload; + +impl ql_rpc::upload::Upload for BlobUpload { + const ROUTE: RouteId = RouteId::from_u32(55); + type Error = core::convert::Infallible; + type Request = Vec; + type Response = Vec; +} + #[tokio::test(flavor = "current_thread")] async fn rpc_request() { #[derive(Clone)] @@ -378,3 +388,73 @@ async fn rpc_download() { }) .await; } + +#[tokio::test(flavor = "current_thread")] +async fn rpc_upload() { + #[derive(Clone)] + struct RouterState { + requests: Rc>>>, + uploads: Rc>>>, + } + + impl UploadHandler for RouterState { + fn handle( + self, + request: Vec, + mut upload: UploadReader, + responder: UploadResponder, StreamWriter>, + ) { + let requests = self.requests.clone(); + let uploads = self.uploads.clone(); + tokio::task::spawn_local(async move { + requests.borrow_mut().push(request); + + let mut body = Vec::new(); + while let Some(chunk) = upload.read_chunk().await.unwrap() { + body.extend_from_slice(&chunk); + } + uploads.borrow_mut().push(body.clone()); + + responder.respond(body).await.unwrap(); + }); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let requests = Rc::new(RefCell::new(Vec::new())); + let uploads = Rc::new(RefCell::new(Vec::new())); + + let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) + .upload::() + .build(RouterState { + requests: requests.clone(), + uploads: uploads.clone(), + }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await; + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut upload = rpc.upload::(&b"logo".to_vec()).await.unwrap(); + upload.send(Bytes::from_static(b"abc")).await.unwrap(); + upload.send(Bytes::from_static(b"def")).await.unwrap(); + let response = upload.finish().await.unwrap(); + + assert_eq!(response, b"abcdef".to_vec()); + assert_eq!(requests.borrow().as_slice(), &[b"logo".to_vec()]); + assert_eq!(uploads.borrow().as_slice(), &[b"abcdef".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} From bb5f8ead030e98a737e0d1828d1d7417bfa24f07 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 23:06:09 -0400 Subject: [PATCH 267/304] ql-rpc: separate framedread --- ql-rpc/src/codec.rs | 92 +------------------------------------- ql-rpc/src/framed_value.rs | 90 +++++++++++++++++++++++++++++++++++++ ql-rpc/src/lib.rs | 4 +- ql-rpc/src/rpc/utils.rs | 8 ++-- 4 files changed, 99 insertions(+), 95 deletions(-) create mode 100644 ql-rpc/src/framed_value.rs diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs index 79859f97..51da527b 100644 --- a/ql-rpc/src/codec.rs +++ b/ql-rpc/src/codec.rs @@ -1,9 +1,8 @@ -use std::{convert::Infallible, marker::PhantomData, str::Utf8Error}; +use std::{convert::Infallible, str::Utf8Error}; use bytes::{Buf, BufMut, Bytes}; pub use crate::chunk_queue::ChunkQueue; -use crate::{CodecError, Error}; pub trait RpcCodec: Sized { type Error; @@ -69,53 +68,9 @@ pub fn encode_value_part>(value: &T, out: & } /// reads one length-delimited rpc value from buffered byte chunks -pub struct FramedValueReader { - bytes: ChunkQueue, - marker: PhantomData T>, -} - -pub enum ReadValueStep { - NeedMore(FramedValueReader), - Value(T), -} - -impl Default for FramedValueReader { - fn default() -> Self { - Self { - bytes: ChunkQueue::default(), - marker: PhantomData, - } - } -} - -impl FramedValueReader { - pub fn push(mut self, chunk: Bytes) -> Self { - self.bytes.push(chunk); - self - } - - pub fn into_bytes(self) -> ChunkQueue { - self.bytes - } - - pub fn advance(self) -> Result, CodecError> { - let mut this = self; - let Some(mut body) = this.bytes.try_take_part()? else { - return Ok(ReadValueStep::NeedMore(this)); - }; - - let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; - drop(body); - if this.bytes.remaining() > 0 { - return Err(CodecError::Rpc(Error::TrailingBytes)); - } - Ok(ReadValueStep::Value(value)) - } -} - pub fn reserve_length>(out: &mut B) -> usize { let start = out.as_mut().len(); - out.put_u64_le(0); + out.put_bytes(0, LENGTH_SIZE); start } @@ -126,46 +81,3 @@ pub fn backpatch_length + ?Sized>(out: &mut B, start: usize) { let payload_len = u64::try_from(payload_len).expect("rpc payload exceeds u64 length framing"); out[start..payload_start].copy_from_slice(&payload_len.to_le_bytes()); } - -#[cfg(test)] -mod tests { - use bytes::Bytes; - - use super::{encode_value_part, FramedValueReader, ReadValueStep}; - - #[test] - fn value_reader_round_trips_framed_values() { - let mut encoded = Vec::new(); - encode_value_part(&b"hello".to_vec(), &mut encoded); - - match FramedValueReader::>::default() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { - ReadValueStep::Value(value) => assert_eq!(value, b"hello".to_vec()), - _ => unreachable!(), - } - } - - #[test] - fn value_reader_waits_for_complete_frame() { - let mut encoded = Vec::new(); - encode_value_part(&b"hello".to_vec(), &mut encoded); - let encoded = Bytes::from(encoded); - - let reader = match FramedValueReader::>::default() - .push(encoded.slice(..4)) - .advance() - .unwrap() - { - ReadValueStep::NeedMore(next) => next, - _ => unreachable!(), - }; - - match reader.push(encoded.slice(4..)).advance().unwrap() { - ReadValueStep::Value(value) => assert_eq!(value, b"hello".to_vec()), - _ => unreachable!(), - } - } -} diff --git a/ql-rpc/src/framed_value.rs b/ql-rpc/src/framed_value.rs new file mode 100644 index 00000000..b76007f5 --- /dev/null +++ b/ql-rpc/src/framed_value.rs @@ -0,0 +1,90 @@ +use std::marker::PhantomData; + +use bytes::Bytes; + +use crate::{chunk_queue::ChunkQueue, CodecError, Error, RpcCodec}; + +/// reads one length-delimited rpc value from buffered byte chunks +pub struct FramedReader { + bytes: ChunkQueue, + marker: PhantomData T>, +} + +pub enum FramedReadStep { + NeedMore(FramedReader), + Value(T), +} + +impl Default for FramedReader { + fn default() -> Self { + Self { + bytes: ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl FramedReader { + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); + self + } + + pub fn advance(self) -> Result, CodecError> { + let mut this = self; + let Some(mut body) = this.bytes.try_take_part()? else { + return Ok(FramedReadStep::NeedMore(this)); + }; + + let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + if this.bytes.remaining() > 0 { + return Err(CodecError::Rpc(Error::TrailingBytes)); + } + Ok(FramedReadStep::Value(value)) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{FramedReadStep, FramedReader}; + use crate::codec::encode_value_part; + + #[test] + fn value_reader_round_trips_framed_values() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + + match FramedReader::>::default() + .push(Bytes::from(encoded)) + .advance() + .unwrap() + { + FramedReadStep::Value(value) => assert_eq!(value, b"hello".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn value_reader_waits_for_complete_frame() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + let encoded = Bytes::from(encoded); + + let reader = match FramedReader::>::default() + .push(encoded.slice(..4)) + .advance() + .unwrap() + { + FramedReadStep::NeedMore(next) => next, + _ => unreachable!(), + }; + + match reader.push(encoded.slice(4..)).advance().unwrap() { + FramedReadStep::Value(value) => assert_eq!(value, b"hello".to_vec()), + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs index 3b92d5ec..abd20eae 100644 --- a/ql-rpc/src/lib.rs +++ b/ql-rpc/src/lib.rs @@ -3,14 +3,16 @@ mod chunk_queue; pub(crate) mod codec; mod error; +mod framed_value; mod route_id; mod router; mod rpc; mod stream; pub use chunk_queue::ChunkQueue; -pub use codec::{FramedValueReader, ReadValueStep, RpcCodec}; +pub use codec::RpcCodec; pub use error::*; +use framed_value::*; pub use route_id::RouteId; pub use router::*; pub use rpc::*; diff --git a/ql-rpc/src/rpc/utils.rs b/ql-rpc/src/rpc/utils.rs index d18aba68..90c1df0f 100644 --- a/ql-rpc/src/rpc/utils.rs +++ b/ql-rpc/src/rpc/utils.rs @@ -1,5 +1,5 @@ use crate::{ - read_bytes, ChunkQueue, CodecError, FramedValueReader, ReadValueStep, RouterConfig, RpcCodec, + read_bytes, ChunkQueue, CodecError, FramedReadStep, FramedReader, RouterConfig, RpcCodec, RpcRead, StreamCloseCode, }; @@ -12,13 +12,13 @@ where T: RpcCodec, R: RpcRead, { - let mut value_reader = FramedValueReader::::default(); + let mut value_reader = FramedReader::::default(); let mut total_read = 0usize; let value = loop { match value_reader.advance() { - Ok(ReadValueStep::Value(value)) => break value, - Ok(ReadValueStep::NeedMore(next)) => value_reader = next, + Ok(FramedReadStep::Value(value)) => break value, + Ok(FramedReadStep::NeedMore(next)) => value_reader = next, Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), } From 2732fb6f945bfc5a723eb6830ba20de5151bf6f9 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 16 Apr 2026 23:11:37 -0400 Subject: [PATCH 268/304] ql-runtime: rename event -> notification --- ql-runtime/src/rpc/mod.rs | 2 +- ql-runtime/src/tests/rpc.rs | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index 4e4bd77f..c6960b17 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -25,7 +25,7 @@ pub struct RpcHandle { } impl RpcHandle { - pub async fn event(&self, event: &M::Payload) -> Result<(), RpcError> + pub async fn notification(&self, event: &M::Payload) -> Result<(), RpcError> where M: Notification, { diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index c2213dcd..5335998f 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -16,8 +16,7 @@ use ql_rpc::{ }; use super::*; -use crate::rpc::RpcError; -use crate::{QlStream, StreamWriter}; +use crate::{rpc::RpcError, QlStream, StreamWriter}; struct Echo; @@ -158,7 +157,9 @@ async fn rpc_notification() { }); let rpc = pair.side_mut(Side::A).handle.rpc(); - rpc.event::(&b"hello".to_vec()).await.unwrap(); + rpc.notification::(&b"hello".to_vec()) + .await + .unwrap(); assert_eq!(seen.borrow().as_slice(), &[b"hello".to_vec()]); tokio::time::timeout(Duration::from_secs(2), responder) From 4d4690ef18386f6714f019ac39760553b0584371 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 17 Apr 2026 07:13:33 -0400 Subject: [PATCH 269/304] ql-runtime: unified IO --- ql-runtime/src/chunk_slot/mod.rs | 434 --------------------- ql-runtime/src/command.rs | 6 +- ql-runtime/src/driver/mod.rs | 80 ++-- ql-runtime/src/driver/state.rs | 59 +-- ql-runtime/src/driver/test.rs | 57 +-- ql-runtime/src/handle/mod.rs | 23 +- ql-runtime/src/handle/reader.rs | 193 --------- ql-runtime/src/handle/writer.rs | 184 --------- ql-runtime/src/io/mod.rs | 37 ++ ql-runtime/src/{chunk_slot => io}/queue.rs | 81 +++- ql-runtime/src/io/reader.rs | 192 +++++++++ ql-runtime/src/io/shared.rs | 319 +++++++++++++++ ql-runtime/src/{chunk_slot => io}/sync.rs | 0 ql-runtime/src/io/writer.rs | 203 ++++++++++ ql-runtime/src/lib.rs | 2 +- 15 files changed, 908 insertions(+), 962 deletions(-) delete mode 100644 ql-runtime/src/chunk_slot/mod.rs delete mode 100644 ql-runtime/src/handle/reader.rs delete mode 100644 ql-runtime/src/handle/writer.rs create mode 100644 ql-runtime/src/io/mod.rs rename ql-runtime/src/{chunk_slot => io}/queue.rs (59%) create mode 100644 ql-runtime/src/io/reader.rs create mode 100644 ql-runtime/src/io/shared.rs rename ql-runtime/src/{chunk_slot => io}/sync.rs (100%) create mode 100644 ql-runtime/src/io/writer.rs diff --git a/ql-runtime/src/chunk_slot/mod.rs b/ql-runtime/src/chunk_slot/mod.rs deleted file mode 100644 index d7c330f6..00000000 --- a/ql-runtime/src/chunk_slot/mod.rs +++ /dev/null @@ -1,434 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use event_listener::{Event, EventListener}; - -use self::queue::{PopError, PushError, Single}; - -mod queue; -mod sync; - -use sync::Arc; - -/// creates a single-chunk handoff pair -/// receiver-side partial reads keep the remainder locally -pub fn new() -> (ChunkSlotRx, ChunkSlotTx) { - let shared = Arc::new(Shared { - queue: Single::new(), - changed: Event::new(), - }); - - ( - ChunkSlotRx { - shared: Arc::clone(&shared), - pending: Bytes::new(), - }, - ChunkSlotTx { shared }, - ) -} - -pub struct ChunkSlotRx { - shared: Arc, - pending: Bytes, -} - -pub struct ChunkSlotTx { - shared: Arc, -} - -struct Shared { - queue: Single, - changed: Event, -} - -#[derive(Debug)] -pub struct SendClosed(pub Bytes); - -#[derive(Debug, PartialEq, Eq)] -pub enum TrySendError { - Closed(Bytes), - Full(Bytes), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct RecvClosed; - -impl ChunkSlotRx { - pub fn try_recv(&mut self, max_len: usize) -> Result { - if !self.pending.is_empty() { - let pending = &mut self.pending; - let bytes = if pending.len() <= max_len { - std::mem::take(pending) - } else { - pending.split_to(max_len) - }; - return Ok(bytes); - } - - match self.shared.queue.pop() { - Ok(mut bytes) => { - self.shared.changed.notify(usize::MAX); - let pending = &mut self.pending; - - let bytes = if bytes.len() <= max_len { - bytes - } else { - let head = bytes.split_to(max_len); - *pending = bytes; - head - }; - Ok(bytes) - } - Err(PopError::Empty) => Ok(Bytes::new()), - Err(PopError::Closed) => Err(RecvClosed), - } - } - - pub fn poll_recv( - &mut self, - max_len: usize, - listener: &mut Option, - cx: &mut Context<'_>, - ) -> Poll> { - loop { - match self.try_recv(max_len) { - Ok(bytes) if !bytes.is_empty() => return Poll::Ready(Ok(bytes)), - Err(closed) => return Poll::Ready(Err(closed)), - Ok(_) => {} - } - - if let Some(active_listener) = listener.as_mut() { - match Pin::new(active_listener).poll(cx) { - Poll::Ready(()) => *listener = None, - Poll::Pending => return Poll::Pending, - } - } else { - *listener = Some(self.shared.changed.listen()); - } - } - } - - pub fn recv(&mut self, max_len: usize) -> Recv<'_> { - Recv { - rx: self, - max_len, - listener: None, - } - } - - pub fn is_finished(&self) -> bool { - self.pending.is_empty() && self.shared.queue.is_closed() && self.shared.queue.is_empty() - } - - pub fn is_empty(&self) -> bool { - self.pending.is_empty() && self.shared.queue.is_empty() - } - - pub fn close(self) { - if self.shared.queue.close() { - self.shared.changed.notify(usize::MAX); - } - } -} - -impl Drop for ChunkSlotRx { - fn drop(&mut self) { - if self.shared.queue.close() { - self.shared.changed.notify(usize::MAX); - } - } -} - -impl ChunkSlotTx { - pub fn try_send(&self, bytes: Bytes) -> Result<(), TrySendError> { - match self.shared.queue.push(bytes) { - Ok(()) => { - self.shared.changed.notify(usize::MAX); - Ok(()) - } - Err(PushError::Full(bytes)) => Err(TrySendError::Full(bytes)), - Err(PushError::Closed(bytes)) => Err(TrySendError::Closed(bytes)), - } - } - - pub fn poll_send( - &self, - bytes: &mut Bytes, - listener: &mut Option, - cx: &mut Context<'_>, - ) -> Poll> { - loop { - let chunk = std::mem::take(bytes); - match self.try_send(chunk) { - Ok(()) => return Poll::Ready(Ok(())), - Err(TrySendError::Closed(chunk)) => { - *bytes = chunk.clone(); - return Poll::Ready(Err(SendClosed(chunk))); - } - Err(TrySendError::Full(chunk)) => *bytes = chunk, - } - - if let Some(active_listener) = listener.as_mut() { - match Pin::new(active_listener).poll(cx) { - Poll::Ready(()) => *listener = None, - Poll::Pending => return Poll::Pending, - } - } else { - *listener = Some(self.shared.changed.listen()); - } - } - } - - pub fn send(&self, bytes: Bytes) -> Send<'_> { - Send { - tx: self, - bytes, - listener: None, - } - } - - pub fn is_closed(&self) -> bool { - self.shared.queue.is_closed() - } - - pub fn close(self) { - if self.shared.queue.close() { - self.shared.changed.notify(usize::MAX); - } - } -} - -impl Drop for ChunkSlotTx { - fn drop(&mut self) { - if self.shared.queue.close() { - self.shared.changed.notify(usize::MAX); - } - } -} - -pub struct Recv<'a> { - rx: &'a mut ChunkSlotRx, - max_len: usize, - listener: Option, -} - -impl Future for Recv<'_> { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - this.rx.poll_recv(this.max_len, &mut this.listener, cx) - } -} - -pub struct Send<'a> { - tx: &'a ChunkSlotTx, - bytes: Bytes, - listener: Option, -} - -impl Future for Send<'_> { - type Output = Result<(), SendClosed>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().get_mut(); - this.tx.poll_send(&mut this.bytes, &mut this.listener, cx) - } -} - -#[cfg(test)] -mod tests { - use std::time::Duration; - - use bytes::Bytes; - - use super::{new, RecvClosed}; - - #[test] - fn try_send_and_take_round_trip() { - let (mut rx, tx) = new(); - - tx.try_send(Bytes::from_static(b"hello")).unwrap(); - assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"hello"))); - assert_eq!(rx.try_recv(8), Ok(Bytes::new())); - } - - #[test] - fn read_splits_moves_remainder_to_receiver() { - let (mut rx, tx) = new(); - - tx.try_send(Bytes::from_static(b"hello")).unwrap(); - assert_eq!(rx.try_recv(2), Ok(Bytes::from_static(b"he"))); - tx.try_send(Bytes::from_static(b"!")).unwrap(); - assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"llo"))); - assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"!"))); - } - - #[test] - fn read_drains_slot_when_limit_covers_chunk() { - let (mut rx, tx) = new(); - - tx.try_send(Bytes::from_static(b"hello")).unwrap(); - assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"hello"))); - tx.try_send(Bytes::from_static(b"!")).unwrap(); - assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"!"))); - } - - #[tokio::test(flavor = "current_thread")] - async fn send_waits_until_slot_clears() { - let (mut rx, tx) = new(); - - tx.try_send(Bytes::from_static(b"a")).unwrap(); - - let sender = tokio::spawn(async move { - tx.send(Bytes::from_static(b"b")).await.unwrap(); - }); - - tokio::time::sleep(Duration::from_millis(10)).await; - assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"a"))); - - tokio::time::timeout(Duration::from_secs(1), sender) - .await - .unwrap() - .unwrap(); - } - - #[tokio::test(flavor = "current_thread")] - async fn finish_yields_eof_after_buffered_chunk() { - let (mut rx, tx) = new(); - - tx.send(Bytes::from_static(b"abc")).await.unwrap(); - tx.close(); - - assert_eq!(rx.recv(8).await, Ok(Bytes::from_static(b"abc"))); - assert_eq!(rx.recv(8).await, Err(RecvClosed)); - assert!(rx.is_finished()); - } - - #[tokio::test(flavor = "current_thread")] - async fn closing_receiver_returns_unsent_bytes() { - let (rx, tx) = new(); - - rx.close(); - - let err = tx.send(Bytes::from_static(b"abc")).await.unwrap_err(); - assert_eq!(err.0, Bytes::from_static(b"abc")); - } - - #[test] - fn zero_length_recv_does_not_consume_buffered_chunk() { - let (mut rx, tx) = new(); - - tx.try_send(Bytes::from_static(b"hello")).unwrap(); - assert_eq!(rx.try_recv(0), Ok(Bytes::new())); - assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"hello"))); - } -} - -#[cfg(all(test, loom))] -mod loom_tests { - use std::{ - future::Future, - pin::pin, - task::{Context, Poll, Waker}, - }; - - use bytes::Bytes; - use loom::{model, thread}; - - use super::{new, RecvClosed}; - - fn now_or_never(future: F) -> Option { - let waker = Waker::noop(); - let mut cx = Context::from_waker(waker); - let mut future = pin!(future); - match future.as_mut().poll(&mut cx) { - Poll::Ready(value) => Some(value), - Poll::Pending => None, - } - } - - fn check_model(f: impl Fn() + Sync + Send + 'static) { - let builder = model::Builder::new(); - builder.check(f); - } - - #[test] - fn try_recv_never_reports_closed_while_open() { - check_model(|| { - let (mut rx, tx) = new(); - - let sender = thread::spawn(move || { - let _ = tx.try_send(Bytes::from_static(b"abc")); - }); - - let receiver = thread::spawn(move || { - let result = rx.try_recv(1); - assert!( - !matches!(result, Err(RecvClosed)), - "open slot must not report RecvClosed" - ); - }); - - sender.join().unwrap(); - receiver.join().unwrap(); - }); - } - - #[test] - fn recv_observes_send_after_pending() { - check_model(|| { - let (mut rx, tx) = new(); - - assert!(now_or_never(rx.recv(8)).is_none()); - - let sender = thread::spawn(move || { - tx.try_send(Bytes::from_static(b"abc")).unwrap(); - }); - - sender.join().unwrap(); - - assert_eq!( - now_or_never(rx.recv(8)), - Some(Ok(Bytes::from_static(b"abc"))) - ); - }); - } - - #[test] - fn recv_observes_finish_as_closed() { - check_model(|| { - let (mut rx, tx) = new(); - - assert!(now_or_never(rx.recv(8)).is_none()); - - let finisher = thread::spawn(move || { - tx.close(); - }); - - finisher.join().unwrap(); - - assert_eq!(now_or_never(rx.recv(8)), Some(Err(RecvClosed))); - }); - } - - #[test] - fn partial_recv_preserves_remainder_and_finished_state() { - check_model(|| { - let (mut rx, tx) = new(); - - tx.try_send(Bytes::from_static(b"abcd")).unwrap(); - tx.close(); - - assert_eq!(rx.try_recv(2), Ok(Bytes::from_static(b"ab"))); - assert!(!rx.is_finished()); - assert_eq!(rx.try_recv(8), Ok(Bytes::from_static(b"cd"))); - assert_eq!(rx.try_recv(8), Err(RecvClosed)); - assert!(rx.is_finished()); - }); - } -} diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 1757d843..7fc835d5 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,7 +1,7 @@ use ql_fsm::NoSessionError; use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamCloseCode, StreamId}; -use crate::{chunk_slot::ChunkSlotRx, QlStreamError, StreamReader}; +use crate::{StreamReader, StreamWriter}; pub enum Command { BindPeer { @@ -17,9 +17,7 @@ pub enum Command { }, OpenStream { route_id: RouteId, - request_reader: ChunkSlotRx, - request_terminal: oneshot::Sender>, - start: oneshot::Sender>, + start: oneshot::Sender>, }, PollInbound { stream_id: StreamId, diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 725ef771..540e7b99 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -20,10 +20,9 @@ use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; use crate::{ - chunk_slot, command::Command, - handle::{QlStream, StreamReader, StreamWriter}, - log, + handle::QlStream, + io, log, platform::{QlInbound, QlPlatform, QlTimer}, QlStreamError, Runtime, RuntimeHandle, }; @@ -203,12 +202,7 @@ impl DriverState { log::info!(" starting XX pairing"); fsm.connect_xx(now(), token, platform); } - Command::OpenStream { - route_id, - request_reader, - request_terminal, - start, - } => { + Command::OpenStream { route_id, start } => { log::info!("open stream requested: route_id={route_id}"); let Some(runtime_tx) = self.runtime_tx.upgrade() else { log::warn!("open stream aborted: runtime channel unavailable"); @@ -226,24 +220,24 @@ impl DriverState { }; let stream_id = stream_ops.stream_id(); log::info!("open stream allocated: route_id={route_id} stream_id={stream_id}"); - let (response_reader, response_writer) = chunk_slot::new(); - let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); + let stream = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); self.streams.insert( stream_id, DriverStreamIo::new( true, - Some(OutboundIo::new(request_reader, request_terminal)), - Some(InboundIo::new(response_writer, response_terminal_tx)), + Some(OutboundIo::new(stream.writer_io)), + Some(InboundIo::new(stream.reader_io)), ), ); - let reader = StreamReader::new( - stream_id, - CloseTarget::Return, - response_reader, - response_terminal_rx, - RuntimeHandle::new(runtime_tx), - ); - if start.send(Ok((stream_id, reader))).is_err() { + if start + .send(Ok((stream_id, stream.reader, stream.writer))) + .is_err() + { log::warn!("open stream cancelled before delivery: stream_id={stream_id}"); if let Some(stream) = self.streams.get_mut(&stream_id) { stream.inbound_close(); @@ -367,17 +361,19 @@ impl DriverState { return; }; - let (request_reader, request_writer) = chunk_slot::new(); - let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); - let (response_reader, response_writer) = chunk_slot::new(); - let (response_terminal_tx, response_terminal_rx) = oneshot::channel(); + let stream = io::new_stream( + stream_id, + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); self.streams.insert( stream_id, DriverStreamIo::new( false, - Some(OutboundIo::new(response_reader, response_terminal_tx)), - Some(InboundIo::new(request_writer, request_terminal_tx)), + Some(OutboundIo::new(stream.writer_io)), + Some(InboundIo::new(stream.reader_io)), ), ); @@ -387,20 +383,8 @@ impl DriverState { platform.handle_inbound(QlStream { stream_id, route_id, - reader: StreamReader::new( - stream_id, - CloseTarget::Origin, - request_reader, - request_terminal_rx, - RuntimeHandle::new(runtime_tx.clone()), - ), - writer: StreamWriter::new( - stream_id, - CloseTarget::Return, - response_writer, - response_terminal_rx, - RuntimeHandle::new(runtime_tx), - ), + reader: stream.reader, + writer: stream.writer, }); } @@ -550,13 +534,13 @@ impl DriverState { return; }; let stream = entry.get_mut(); - let Some(reader) = stream.outbound_reader_mut() else { - log::trace!("poll stream skipped without outbound reader: stream_id={stream_id}"); + let Some(writer_io) = stream.outbound_writer_mut() else { + log::trace!("poll stream skipped without outbound writer: stream_id={stream_id}"); return; }; - if reader.is_finished() { - log::info!("observed outbound reader finished before write: stream_id={stream_id}"); + if writer_io.is_finished() { + log::info!("observed outbound writer finished before write: stream_id={stream_id}"); if let Ok(mut stream_ops) = fsm.stream(stream_id) { if let Some(writer) = stream_ops.writer() { writer.finish(); @@ -584,7 +568,7 @@ impl DriverState { break; } - let Ok(mut bytes) = reader.try_recv(capacity) else { + let Ok(mut bytes) = writer_io.try_read(capacity) else { break; }; if bytes.is_empty() { @@ -598,8 +582,8 @@ impl DriverState { let _ = writer.write(&mut bytes); } - if reader.is_finished() { - log::info!("observed outbound reader finished after write: stream_id={stream_id}"); + if writer_io.is_finished() { + log::info!("observed outbound writer finished after write: stream_id={stream_id}"); writer.finish(); stream.outbound_queue_finish(); if stream.is_closed() { diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index f5bfdd85..37694c92 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -4,8 +4,8 @@ use bytes::Bytes; use ql_wire::{CloseTarget, StreamId}; use crate::{ - chunk_slot::{ChunkSlotRx, ChunkSlotTx, TrySendError}, command::Command, + io::{PushError, ReaderIo, WriterIo}, QlStreamError, }; @@ -64,30 +64,23 @@ impl DriverStreamIo { } pub fn outbound_finish(&mut self) { - if let Some(mut outbound) = self.outbound.take() { - if let Some(terminal) = outbound.terminal.take() { - let _ = terminal.send(Ok(())); - } + if let Some(outbound) = self.outbound.take() { + outbound.writer.finish(); } } pub fn outbound_fail(&mut self, error: QlStreamError) { - if let Some(mut outbound) = self.outbound.take() { - if let Some(terminal) = outbound.terminal.take() { - let _ = terminal.send(Err(error)); - } + if let Some(outbound) = self.outbound.take() { + let _ = outbound.writer.fail(error); } } - pub fn outbound_reader_mut(&mut self) -> Option<&mut ChunkSlotRx> { - self.outbound - .as_mut() - .and_then(|outbound| outbound.reader.as_mut()) + pub fn outbound_writer_mut(&mut self) -> Option<&mut WriterIo> { + self.outbound.as_mut().map(|outbound| &mut outbound.writer) } pub fn outbound_queue_finish(&mut self) { if let Some(outbound) = self.outbound.as_mut() { - outbound.reader = None; outbound.finish_pending = true; } } @@ -108,10 +101,10 @@ impl DriverStreamIo { }; let len = bytes.len(); - match inbound.writer.try_send(bytes) { + match inbound.reader.try_write(bytes) { Ok(()) => InboundWriteResult::Accepted(len), - Err(TrySendError::Full(_)) => InboundWriteResult::Full, - Err(TrySendError::Closed(_)) => { + Err(PushError::Full(_)) => InboundWriteResult::Full, + Err(PushError::Closed(_)) => { self.inbound = None; InboundWriteResult::Closed } @@ -119,43 +112,34 @@ impl DriverStreamIo { } pub fn inbound_finish(&mut self) { - if let Some(mut inbound) = self.inbound.take() { - inbound.writer.close(); - if let Some(terminal) = inbound.terminal.take() { - let _ = terminal.send(Ok(())); - } + if let Some(inbound) = self.inbound.take() { + inbound.reader.finish(); } } pub fn inbound_fail(&mut self, error: QlStreamError) { - if let Some(mut inbound) = self.inbound.take() { - inbound.writer.close(); - if let Some(terminal) = inbound.terminal.take() { - let _ = terminal.send(Err(error)); - } + if let Some(inbound) = self.inbound.take() { + let _ = inbound.reader.fail(error); } } } pub struct OutboundIo { - reader: Option, - terminal: Option>>, + writer: WriterIo, finish_pending: bool, } impl OutboundIo { - pub fn new(reader: ChunkSlotRx, terminal: oneshot::Sender>) -> Self { + pub fn new(writer: WriterIo) -> Self { Self { - reader: Some(reader), - terminal: Some(terminal), + writer, finish_pending: false, } } } pub struct InboundIo { - writer: ChunkSlotTx, - terminal: Option>>, + reader: ReaderIo, } pub enum InboundWriteResult { @@ -165,10 +149,7 @@ pub enum InboundWriteResult { } impl InboundIo { - pub fn new(writer: ChunkSlotTx, terminal: oneshot::Sender>) -> Self { - Self { - writer, - terminal: Some(terminal), - } + pub fn new(reader: ReaderIo) -> Self { + Self { reader } } } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 1f3caa47..9ca68645 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -2,8 +2,8 @@ use ql_wire::{test_identity, NoopCrypto, PeerBundle, SoftwareCrypto, StreamClose use super::*; use crate::{ - chunk_slot, driver::state::{InboundIo, OutboundIo}, + io, platform::QlInbound, }; @@ -66,15 +66,25 @@ fn new_driver_state() -> (DriverState, QlFsm) { fn new_inbound_io(capacity: usize) -> InboundIo { let _ = capacity; - let (_reader, writer) = chunk_slot::new(); - let (terminal_tx, _terminal_rx) = oneshot::channel(); - InboundIo::new(writer, terminal_tx) + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let stream = io::new_stream( + StreamId(99u32.into()), + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); + InboundIo::new(stream.reader_io) } fn new_outbound_io() -> OutboundIo { - let (reader, _writer) = chunk_slot::new(); - let (terminal_tx, _terminal_rx) = oneshot::channel(); - OutboundIo::new(reader, terminal_tx) + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let stream = io::new_stream( + StreamId(100u32.into()), + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + OutboundIo::new(stream.writer_io) } #[test] @@ -115,17 +125,17 @@ fn handle_closed_stream_reaps_when_both_halves_close() { fn poll_stream_keeps_outbound_pending_after_local_finish_when_inbound_is_closed() { let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); - let (request_reader, request_writer) = chunk_slot::new(); - let (request_terminal_tx, _request_terminal_rx) = oneshot::channel(); - - drop(request_writer); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let mut stream = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + stream.writer.queue_finish(); state.streams.insert( stream_id, - DriverStreamIo::new( - true, - Some(OutboundIo::new(request_reader, request_terminal_tx)), - None, - ), + DriverStreamIo::new(true, Some(OutboundIo::new(stream.writer_io)), None), ); state.poll_stream(&mut fsm, stream_id); @@ -139,16 +149,17 @@ fn poll_stream_keeps_outbound_pending_after_local_finish_when_inbound_is_closed( fn local_close_command_reaps_when_other_half_is_already_closed() { let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); - let (request_reader, _request_writer) = chunk_slot::new(); - let (request_terminal_tx, _request_terminal_rx) = oneshot::channel(); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let stream = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); state.streams.insert( stream_id, - DriverStreamIo::new( - true, - Some(OutboundIo::new(request_reader, request_terminal_tx)), - None, - ), + DriverStreamIo::new(true, Some(OutboundIo::new(stream.writer_io)), None), ); state.drive_command( diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index c6839f5d..e2c45314 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -1,11 +1,8 @@ -mod reader; -mod writer; - use ql_fsm::NoSessionError; -use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamId}; +use ql_wire::{PairingToken, PeerBundle, RouteId, StreamId}; -pub use self::{reader::*, writer::*}; -use crate::{chunk_slot, command::Command}; +use crate::command::Command; +pub use crate::io::{StreamReader, StreamWriter}; #[derive(Debug)] pub struct QlStream { @@ -48,30 +45,20 @@ impl RuntimeHandle { /// opens a new stream on the active encrypted session pub async fn open_stream(&self, route_id: RouteId) -> Result { - let (request_reader, request_writer) = chunk_slot::new(); - let (request_terminal_tx, request_terminal_rx) = oneshot::channel(); let (start_tx, start_rx) = oneshot::channel(); self.send(Command::OpenStream { route_id, - request_reader, - request_terminal: request_terminal_tx, start: start_tx, }); // runtime cannot be shutdown while we have a handle - let (stream_id, reader) = start_rx.await.unwrap()?; + let (stream_id, reader, writer) = start_rx.await.unwrap()?; Ok(QlStream { stream_id, route_id, - writer: StreamWriter::new( - stream_id, - CloseTarget::Origin, - request_writer, - request_terminal_rx, - self.clone(), - ), + writer, reader, }) } diff --git a/ql-runtime/src/handle/reader.rs b/ql-runtime/src/handle/reader.rs deleted file mode 100644 index c3ccf353..00000000 --- a/ql-runtime/src/handle/reader.rs +++ /dev/null @@ -1,193 +0,0 @@ -use std::{ - future::{poll_fn, Future}, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use event_listener::EventListener; -use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; - -use crate::{chunk_slot::ChunkSlotRx, command::Command, log, QlStreamError, RuntimeHandle}; - -pub struct StreamReader { - stream_id: StreamId, - target: CloseTarget, - reader: Option, - wait: Option, - terminal: TerminalState, - handle: RuntimeHandle, -} - -enum TerminalState { - Armed(oneshot::Receiver>), - Terminal(Result<(), QlStreamError>), - Delivered, -} - -// Safety: `ByteReader` contains a `oneshot::Receiver`, which is `!Sync`, but that receiver is -// fully encapsulated. No safe API accesses it through `&self`; all access requires `&mut self` -// or ownership, so shared references cannot race the receiver state across threads. -unsafe impl Sync for StreamReader {} - -impl std::fmt::Debug for StreamReader { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InboundByteStream") - .field("stream_id", &self.stream_id) - .field("target", &self.target) - .field( - "terminal", - &matches!(self.terminal, TerminalState::Delivered), - ) - .finish_non_exhaustive() - } -} - -impl StreamReader { - pub(crate) fn new( - stream_id: StreamId, - target: CloseTarget, - reader: ChunkSlotRx, - terminal: oneshot::Receiver>, - handle: RuntimeHandle, - ) -> Self { - Self { - stream_id, - target, - reader: Some(reader), - wait: None, - terminal: TerminalState::Armed(terminal), - handle, - } - } - - pub fn poll_read( - &mut self, - max_len: usize, - cx: &mut Context<'_>, - ) -> Poll, QlStreamError>> { - if matches!(self.terminal, TerminalState::Delivered) { - return Poll::Ready(Ok(None)); - } - - if let Some(reader) = self.reader.as_mut() { - match reader.poll_recv(max_len, &mut self.wait, cx) { - Poll::Ready(Ok(bytes)) => { - log::trace!( - "byte reader received chunk: stream_id={:?} target={:?} len={}", - self.stream_id, - self.target, - bytes.len() - ); - self.handle.try_send(Command::PollInbound { - stream_id: self.stream_id, - }); - return Poll::Ready(Ok(Some(bytes))); - } - Poll::Ready(Err(_)) => { - log::debug!( - "byte reader channel closed: stream_id={:?} target={:?}", - self.stream_id, - self.target - ); - self.reader = None; - self.wait = None; - } - Poll::Pending => {} - } - } - - if let TerminalState::Armed(terminal) = &mut self.terminal { - let result = match Pin::new(terminal).poll(cx) { - Poll::Pending => None, - Poll::Ready(Ok(result)) => Some(result), - Poll::Ready(Err(_)) => { - panic!("byte reader terminal dropped before sending a terminal state") - } - }; - if let Some(result) = result { - self.terminal = TerminalState::Terminal(result); - } - } - - match &self.terminal { - TerminalState::Armed(_) => Poll::Pending, - TerminalState::Terminal(Ok(())) => { - log::debug!( - "byte reader delivered clean eof: stream_id={:?} target={:?}", - self.stream_id, - self.target - ); - self.terminal = TerminalState::Delivered; - Poll::Ready(Ok(None)) - } - TerminalState::Terminal(Err(error)) => { - let error = error.clone(); - log::debug!( - "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", - self.stream_id, - self.target, - error - ); - self.terminal = TerminalState::Delivered; - Poll::Ready(Err(error)) - } - TerminalState::Delivered => Poll::Ready(Ok(None)), - } - } - - pub fn poll_read_chunk( - &mut self, - cx: &mut Context<'_>, - ) -> Poll, QlStreamError>> { - self.poll_read(usize::MAX, cx) - } - - /// Returns `Ok(None)` on clean EOF, `Ok(Some(_))` for data, and `Err(_)` for stream failure. - pub async fn read(&mut self, max_len: usize) -> Result, QlStreamError> { - poll_fn(|cx| self.poll_read(max_len, cx)).await - } - - pub async fn read_chunk(&mut self) -> Result, QlStreamError> { - self.read(usize::MAX).await - } - - pub fn close(mut self, code: StreamCloseCode) { - if matches!(self.terminal, TerminalState::Delivered) { - return; - } - log::debug!( - "byte reader explicit close: stream_id={:?} target={:?} code={:?}", - self.stream_id, - self.target, - code - ); - self.reader.take(); - self.wait = None; - self.terminal = TerminalState::Delivered; - self.handle.try_send(Command::CloseStream { - stream_id: self.stream_id, - target: self.target, - code, - }); - } -} - -impl Drop for StreamReader { - fn drop(&mut self) { - if matches!(self.terminal, TerminalState::Delivered) { - return; - } - log::debug!( - "byte reader drop close: stream_id={:?} target={:?} code={:?}", - self.stream_id, - self.target, - StreamCloseCode::CANCELLED - ); - self.handle.try_send(Command::CloseStream { - stream_id: self.stream_id, - target: self.target, - code: StreamCloseCode::CANCELLED, - }); - } -} diff --git a/ql-runtime/src/handle/writer.rs b/ql-runtime/src/handle/writer.rs deleted file mode 100644 index 572bd1f3..00000000 --- a/ql-runtime/src/handle/writer.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::{ - future::{poll_fn, Future}, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use event_listener::EventListener; -use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; - -use crate::{ - chunk_slot::{ChunkSlotTx, SendClosed}, - command::Command, - log, QlStreamError, RuntimeHandle, -}; - -pub struct StreamWriter { - stream_id: StreamId, - target: CloseTarget, - writer: Option, - wait: Option, - terminal: WriteTerminalState, - handle: RuntimeHandle, -} - -enum WriteTerminalState { - Armed(oneshot::Receiver>), - Terminal(Result<(), QlStreamError>), -} - -// Safety: `ByteWriter` contains a `oneshot::Receiver`, which is `!Sync`, but that receiver is -// fully encapsulated. No safe API accesses it through `&self`; all access requires `&mut self` -// or ownership, so shared references cannot race the receiver state across threads. -unsafe impl Sync for StreamWriter {} - -impl std::fmt::Debug for StreamWriter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("OutboundByteStream") - .field("stream_id", &self.stream_id) - .field("target", &self.target) - .field("closed", &self.writer.is_none()) - .finish_non_exhaustive() - } -} - -impl StreamWriter { - pub fn poll_write( - &mut self, - bytes: &mut Bytes, - cx: &mut Context<'_>, - ) -> Poll> { - if bytes.is_empty() { - return Poll::Ready(Ok(())); - } - - let Some(writer) = self.writer.as_ref() else { - return self.poll_terminal(cx); - }; - - match writer.poll_send(bytes, &mut self.wait, cx) { - Poll::Ready(Ok(())) => { - log::trace!( - "byte writer accepted chunk: stream_id={:?} target={:?}", - self.stream_id, - self.target - ); - self.wait = None; - self.poll_runtime(); - Poll::Ready(Ok(())) - } - Poll::Ready(Err(SendClosed(_bytes))) => { - log::debug!( - "byte writer send closed: stream_id={:?} target={:?}", - self.stream_id, - self.target - ); - self.writer.take(); - self.wait = None; - self.poll_terminal(cx) - } - Poll::Pending => Poll::Pending, - } - } - - pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { - let mut bytes = bytes; - poll_fn(|cx| self.poll_write(&mut bytes, cx)).await - } - - pub fn queue_finish(&mut self) { - let Some(writer) = self.writer.take() else { - return; - }; - log::debug!( - "byte writer finish: stream_id={:?} target={:?}", - self.stream_id, - self.target - ); - writer.close(); - self.wait = None; - self.poll_runtime(); - } - - pub async fn finish(mut self) -> Result<(), QlStreamError> { - self.queue_finish(); - poll_fn(|cx| self.poll_terminal(cx)).await - } - - pub fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.writer.is_some() { - self.queue_finish(); - } - self.poll_terminal(cx) - } - - pub fn close(mut self, code: StreamCloseCode) { - self.close_inner(code); - } -} - -impl Drop for StreamWriter { - fn drop(&mut self) { - self.close_inner(StreamCloseCode::CANCELLED); - } -} - -impl StreamWriter { - pub(crate) fn new( - stream_id: StreamId, - target: CloseTarget, - writer: ChunkSlotTx, - terminal: oneshot::Receiver>, - handle: RuntimeHandle, - ) -> Self { - Self { - stream_id, - target, - writer: Some(writer), - wait: None, - terminal: WriteTerminalState::Armed(terminal), - handle, - } - } - - fn poll_runtime(&self) { - self.handle.try_send(Command::PollStream { - stream_id: self.stream_id, - }); - } - - fn poll_terminal(&mut self, cx: &mut Context<'_>) -> Poll> { - match &mut self.terminal { - WriteTerminalState::Terminal(result) => Poll::Ready(result.clone()), - WriteTerminalState::Armed(receiver) => match Pin::new(receiver).poll(cx) { - Poll::Ready(Ok(result)) => { - self.terminal = WriteTerminalState::Terminal(result.clone()); - Poll::Ready(result) - } - Poll::Ready(Err(_)) => { - panic!("byte writer terminal dropped before sending a terminal state") - } - Poll::Pending => Poll::Pending, - }, - } - } - - fn close_inner(&mut self, code: StreamCloseCode) { - if self.writer.take().is_none() { - return; - } - log::debug!( - "byte writer close: stream_id={:?} target={:?} code={:?}", - self.stream_id, - self.target, - code - ); - self.wait = None; - self.handle.try_send(Command::CloseStream { - stream_id: self.stream_id, - target: self.target, - code, - }); - } -} diff --git a/ql-runtime/src/io/mod.rs b/ql-runtime/src/io/mod.rs new file mode 100644 index 00000000..83c1079f --- /dev/null +++ b/ql-runtime/src/io/mod.rs @@ -0,0 +1,37 @@ +mod queue; +mod reader; +mod shared; +mod sync; +mod writer; + +use ql_wire::{CloseTarget, StreamId}; + +use self::shared::StreamShared; +pub(crate) use self::{ + queue::PushError, + shared::{ReaderIo, WriterIo}, +}; +pub use self::{reader::StreamReader, writer::StreamWriter}; +use crate::RuntimeHandle; + +pub(crate) struct StreamIo { + pub reader: StreamReader, + pub writer: StreamWriter, + pub reader_io: ReaderIo, + pub writer_io: WriterIo, +} + +pub(crate) fn new_stream( + stream_id: StreamId, + reader_target: CloseTarget, + writer_target: CloseTarget, + handle: RuntimeHandle, +) -> StreamIo { + let shared = StreamShared::new(stream_id); + StreamIo { + reader: StreamReader::new(shared.clone(), reader_target, handle.clone()), + writer: StreamWriter::new(shared.clone(), writer_target, handle), + reader_io: ReaderIo::new(shared.clone()), + writer_io: WriterIo::new(shared), + } +} diff --git a/ql-runtime/src/chunk_slot/queue.rs b/ql-runtime/src/io/queue.rs similarity index 59% rename from ql-runtime/src/chunk_slot/queue.rs rename to ql-runtime/src/io/queue.rs index a0325efc..e254b5ed 100644 --- a/ql-runtime/src/chunk_slot/queue.rs +++ b/ql-runtime/src/io/queue.rs @@ -1,4 +1,4 @@ -//! local single-slot queue for `chunk_slot` to avoid `ConcurrentQueue` taking 512 bytes instead of 40 +//! local single-slot queue for stream io //! copied from `concurrent_queue::single::Single` in `concurrent-queue` use core::mem::MaybeUninit; @@ -22,6 +22,9 @@ pub enum PushError { Closed(T), } +#[derive(Debug, PartialEq, Eq)] +pub struct ForcePushError(pub T); + /// A single-element queue. pub struct Single { state: AtomicUsize, @@ -63,6 +66,61 @@ impl Single { } } + /// Attempts to push an item into the queue, displacing another if necessary. + pub fn force_push(&self, value: T) -> Result, ForcePushError> { + // Attempt to lock the slot. + let mut state = 0; + + loop { + // Lock the slot. + let prev = self + .state + .compare_exchange(state, LOCKED | PUSHED, Ordering::SeqCst, Ordering::SeqCst) + .unwrap_or_else(|x| x); + + if prev & CLOSED != 0 { + return Err(ForcePushError(value)); + } + + if prev == state { + // If the value was pushed, swap out the value. + let prev_value = if prev & PUSHED == 0 { + // SAFETY: write is safe because we have locked the state. + self.slot.with_mut(|slot| unsafe { + slot.write(MaybeUninit::new(value)); + }); + None + } else { + // SAFETY: replace is safe because we have locked the state, and + // assume_init is safe because we have checked that the value was pushed. + self.slot.with_mut(move |slot| unsafe { + Some(std::ptr::replace(slot, MaybeUninit::new(value)).assume_init()) + }) + }; + + if let Some(prev_value) = prev_value { + // We can unlock the slot now. + self.state.fetch_and(!LOCKED, Ordering::Release); + // Return the old value. + return Ok(Some(prev_value)); + } + + // We can unlock the slot now. + self.state.fetch_and(!LOCKED, Ordering::Release); + return Ok(None); + } + + // Try to go for the current (pushed) state. + if prev & LOCKED == 0 { + state = prev; + } else { + // State is locked. + busy_wait(); + state = prev & !LOCKED; + } + } + } + /// Attempts to pop an item from the queue. pub fn pop(&self) -> Result { let mut state = PUSHED; @@ -88,11 +146,11 @@ impl Single { } if prev & PUSHED == 0 { - if prev & CLOSED == 0 { - return Err(PopError::Empty); + return if prev & CLOSED == 0 { + Err(PopError::Empty) } else { - return Err(PopError::Closed); - } + Err(PopError::Closed) + }; } if prev & LOCKED == 0 { @@ -113,19 +171,6 @@ impl Single { pub fn is_empty(&self) -> bool { self.len() == 0 } - - /// Closes the queue. - /// - /// Returns `true` if this call closed the queue. - pub fn close(&self) -> bool { - let state = self.state.fetch_or(CLOSED, Ordering::SeqCst); - state & CLOSED == 0 - } - - /// Returns `true` if the queue is closed. - pub fn is_closed(&self) -> bool { - self.state.load(Ordering::SeqCst) & CLOSED != 0 - } } impl Drop for Single { diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs new file mode 100644 index 00000000..15b1e231 --- /dev/null +++ b/ql-runtime/src/io/reader.rs @@ -0,0 +1,192 @@ +use std::{ + future::{poll_fn, Future}, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use event_listener::EventListener; +use ql_wire::{CloseTarget, StreamCloseCode}; + +use super::{ + queue::PopError, + shared::{ReaderItem, StreamShared}, + sync::Arc, +}; +use crate::{command::Command, log, QlStreamError, RuntimeHandle}; + +pub struct StreamReader { + shared: Arc, + target: CloseTarget, + pending: Bytes, + wait: Option, + terminal: ReaderTerminalState, + handle: RuntimeHandle, +} + +enum ReaderTerminalState { + Open, + Delivered, +} + +unsafe impl Sync for StreamReader {} + +impl std::fmt::Debug for StreamReader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InboundByteStream") + .field("stream_id", &self.shared.stream_id) + .field("target", &self.target) + .field( + "terminal", + &matches!(self.terminal, ReaderTerminalState::Delivered), + ) + .finish_non_exhaustive() + } +} + +impl StreamReader { + pub(crate) fn new( + shared: Arc, + target: CloseTarget, + handle: RuntimeHandle, + ) -> Self { + Self { + shared, + target, + pending: Bytes::new(), + wait: None, + terminal: ReaderTerminalState::Open, + handle, + } + } + + pub fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, QlStreamError>> { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return Poll::Ready(Ok(None)); + } + + loop { + if !self.pending.is_empty() { + let pending = &mut self.pending; + let bytes = if pending.len() <= max_len { + std::mem::take(pending) + } else { + pending.split_to(max_len) + }; + self.handle.try_send(Command::PollInbound { + stream_id: self.shared.stream_id, + }); + return Poll::Ready(Ok(Some(bytes))); + } + + match self.shared.reader.pop() { + Ok(ReaderItem::Chunk(mut bytes)) => { + log::trace!( + "byte reader received chunk: stream_id={:?} target={:?} len={}", + self.shared.stream_id, + self.target, + bytes.len() + ); + self.handle.try_send(Command::PollInbound { + stream_id: self.shared.stream_id, + }); + if bytes.len() <= max_len { + return Poll::Ready(Ok(Some(bytes))); + } + let head = bytes.split_to(max_len); + self.pending = bytes; + return Poll::Ready(Ok(Some(head))); + } + Ok(ReaderItem::Error(error)) => { + log::debug!( + "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", + self.shared.stream_id, + self.target, + error + ); + self.terminal = ReaderTerminalState::Delivered; + return Poll::Ready(Err(error)); + } + Err(PopError::Empty) => { + if self.shared.reader.is_finished() { + log::debug!( + "byte reader delivered clean eof: stream_id={:?} target={:?}", + self.shared.stream_id, + self.target + ); + self.terminal = ReaderTerminalState::Delivered; + return Poll::Ready(Ok(None)); + } + } + Err(PopError::Closed) => panic!("reader endpoint closed unexpectedly"), + } + + let active_listener = self.wait.get_or_insert_with(|| self.shared.reader.listen()); + match std::pin::Pin::new(active_listener).poll(cx) { + Poll::Ready(()) => self.wait = None, + Poll::Pending => return Poll::Pending, + } + } + } + + pub fn poll_read_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, QlStreamError>> { + self.poll_read(usize::MAX, cx) + } + + pub async fn read(&mut self, max_len: usize) -> Result, QlStreamError> { + poll_fn(|cx| self.poll_read(max_len, cx)).await + } + + pub async fn read_chunk(&mut self) -> Result, QlStreamError> { + self.read(usize::MAX).await + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return; + } + log::debug!( + "byte reader explicit close: stream_id={:?} target={:?} code={:?}", + self.shared.stream_id, + self.target, + code + ); + self.pending = Bytes::new(); + self.wait = None; + self.terminal = ReaderTerminalState::Delivered; + self.handle.try_send(Command::CloseStream { + stream_id: self.shared.stream_id, + target: self.target, + code, + }); + } +} + +impl Drop for StreamReader { + fn drop(&mut self) { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return; + } + log::debug!( + "byte reader drop close: stream_id={:?} target={:?} code={:?}", + self.shared.stream_id, + self.target, + StreamCloseCode::CANCELLED + ); + self.handle.try_send(Command::CloseStream { + stream_id: self.shared.stream_id, + target: self.target, + code: StreamCloseCode::CANCELLED, + }); + } +} diff --git a/ql-runtime/src/io/shared.rs b/ql-runtime/src/io/shared.rs new file mode 100644 index 00000000..6d2a0487 --- /dev/null +++ b/ql-runtime/src/io/shared.rs @@ -0,0 +1,319 @@ +use bytes::Bytes; +use event_listener::{Event, EventListener}; +use ql_wire::StreamId; + +use super::{ + queue::{ForcePushError, PopError, PushError, Single}, + sync::{Arc, AtomicUsize, Ordering}, +}; +use crate::QlStreamError; + +const READER_FINISHED: usize = 1 << 0; + +const WRITER_FINISH_REQUESTED: usize = 1 << 0; +const WRITER_TERMINAL_READY: usize = 1 << 1; +const WRITER_TERMINAL_OK: usize = 1 << 2; + +pub(crate) struct StreamShared { + pub stream_id: StreamId, + pub reader: ReaderShared, + pub writer: WriterShared, +} + +impl StreamShared { + pub fn new(stream_id: StreamId) -> Arc { + Arc::new(Self { + stream_id, + reader: ReaderShared::new(), + writer: WriterShared::new(), + }) + } +} + +enum SlotMsg { + Chunk(Bytes), + Error(QlStreamError), +} + +impl SlotMsg { + fn into_chunk(self) -> Option { + match self { + Self::Chunk(bytes) => Some(bytes), + Self::Error(_) => None, + } + } +} + +pub(crate) struct ReaderShared { + slot: Single, + changed: Event, + state: AtomicUsize, +} + +impl ReaderShared { + fn new() -> Self { + Self { + slot: Single::new(), + changed: Event::new(), + state: AtomicUsize::new(0), + } + } + + pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { + match self.slot.push(SlotMsg::Chunk(bytes)) { + Ok(()) => { + self.changed.notify(usize::MAX); + Ok(()) + } + Err(PushError::Closed(SlotMsg::Chunk(bytes))) => Err(PushError::Closed(bytes)), + Err(PushError::Full(SlotMsg::Chunk(bytes))) => Err(PushError::Full(bytes)), + Err(PushError::Closed(SlotMsg::Error(_))) | Err(PushError::Full(SlotMsg::Error(_))) => { + unreachable!("reader chunk write cannot recover an error payload") + } + } + } + + pub fn finish(&self) { + if self.state.fetch_or(READER_FINISHED, Ordering::SeqCst) & READER_FINISHED == 0 { + self.changed.notify(usize::MAX); + } + } + + pub fn fail( + &self, + error: QlStreamError, + ) -> Result, ForcePushError> { + match self.slot.force_push(SlotMsg::Error(error)) { + Ok(displaced) => { + self.changed.notify(usize::MAX); + Ok(displaced.and_then(SlotMsg::into_chunk)) + } + Err(ForcePushError(SlotMsg::Error(error))) => Err(ForcePushError(error)), + Err(ForcePushError(SlotMsg::Chunk(_))) => { + unreachable!("reader fail cannot recover a chunk payload") + } + } + } + + pub fn is_finished(&self) -> bool { + self.state.load(Ordering::SeqCst) & READER_FINISHED != 0 + } + + pub fn pop(&self) -> Result { + match self.slot.pop() { + Ok(SlotMsg::Chunk(bytes)) => { + self.changed.notify(usize::MAX); + Ok(ReaderItem::Chunk(bytes)) + } + Ok(SlotMsg::Error(error)) => Ok(ReaderItem::Error(error)), + Err(error) => Err(error), + } + } + + pub fn listen(&self) -> EventListener { + self.changed.listen() + } +} + +pub(crate) enum ReaderItem { + Chunk(Bytes), + Error(QlStreamError), +} + +pub(crate) struct WriterShared { + slot: Single, + changed: Event, + state: AtomicUsize, +} + +impl WriterShared { + fn new() -> Self { + Self { + slot: Single::new(), + changed: Event::new(), + state: AtomicUsize::new(0), + } + } + + pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { + if self.terminal_ready() || self.finish_requested() { + return Err(PushError::Closed(bytes)); + } + + match self.slot.push(SlotMsg::Chunk(bytes)) { + Ok(()) => { + self.changed.notify(usize::MAX); + Ok(()) + } + Err(PushError::Closed(SlotMsg::Chunk(bytes))) => Err(PushError::Closed(bytes)), + Err(PushError::Full(SlotMsg::Chunk(bytes))) => Err(PushError::Full(bytes)), + Err(PushError::Closed(SlotMsg::Error(_))) | Err(PushError::Full(SlotMsg::Error(_))) => { + unreachable!("writer chunk write cannot recover an error payload") + } + } + } + + pub fn request_finish(&self) { + if self + .state + .fetch_or(WRITER_FINISH_REQUESTED, Ordering::SeqCst) + & WRITER_FINISH_REQUESTED + == 0 + { + self.changed.notify(usize::MAX); + } + } + + pub fn finish_requested(&self) -> bool { + self.state.load(Ordering::SeqCst) & WRITER_FINISH_REQUESTED != 0 + } + + pub fn finish(&self) { + self.state + .fetch_or(WRITER_TERMINAL_READY | WRITER_TERMINAL_OK, Ordering::SeqCst); + self.changed.notify(usize::MAX); + } + + pub fn fail( + &self, + error: QlStreamError, + ) -> Result, ForcePushError> { + if self.terminal_ready() { + return Err(ForcePushError(error)); + } + + match self.slot.force_push(SlotMsg::Error(error)) { + Ok(displaced) => { + self.state.fetch_or(WRITER_TERMINAL_READY, Ordering::SeqCst); + self.changed.notify(usize::MAX); + Ok(displaced.and_then(SlotMsg::into_chunk)) + } + Err(ForcePushError(SlotMsg::Error(error))) => Err(ForcePushError(error)), + Err(ForcePushError(SlotMsg::Chunk(_))) => { + unreachable!("writer fail cannot recover a chunk payload") + } + } + } + + pub fn terminal_ready(&self) -> bool { + self.state.load(Ordering::SeqCst) & WRITER_TERMINAL_READY != 0 + } + + pub fn terminal_ok(&self) -> bool { + self.state.load(Ordering::SeqCst) & WRITER_TERMINAL_OK != 0 + } + + pub fn is_empty(&self) -> bool { + self.slot.is_empty() + } + + pub fn pop(&self) -> Result { + match self.slot.pop() { + Ok(SlotMsg::Chunk(bytes)) => { + self.changed.notify(usize::MAX); + Ok(WriterItem::Chunk(bytes)) + } + Ok(SlotMsg::Error(error)) => Ok(WriterItem::Error(error)), + Err(error) => Err(error), + } + } + + pub fn listen(&self) -> EventListener { + self.changed.listen() + } +} + +pub(crate) enum WriterItem { + Chunk(Bytes), + Error(QlStreamError), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct RecvClosed; + +pub(crate) struct ReaderIo { + shared: Arc, +} + +impl ReaderIo { + pub fn new(shared: Arc) -> Self { + Self { shared } + } + + pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { + self.shared.reader.try_write(bytes) + } + + pub fn finish(&self) { + self.shared.reader.finish(); + } + + pub fn fail( + &self, + error: QlStreamError, + ) -> Result, ForcePushError> { + self.shared.reader.fail(error) + } +} + +pub(crate) struct WriterIo { + shared: Arc, + pending: Bytes, +} + +impl WriterIo { + pub fn new(shared: Arc) -> Self { + Self { + shared, + pending: Bytes::new(), + } + } + + pub fn is_finished(&self) -> bool { + self.pending.is_empty() + && self.shared.writer.finish_requested() + && self.shared.writer.is_empty() + } + + pub fn try_read(&mut self, max_len: usize) -> Result { + if !self.pending.is_empty() { + let pending = &mut self.pending; + let bytes = if pending.len() <= max_len { + std::mem::take(pending) + } else { + pending.split_to(max_len) + }; + return Ok(bytes); + } + + if self.shared.writer.terminal_ready() { + return Err(RecvClosed); + } + + match self.shared.writer.pop() { + Ok(WriterItem::Chunk(mut bytes)) => { + if bytes.len() <= max_len { + Ok(bytes) + } else { + let head = bytes.split_to(max_len); + self.pending = bytes; + Ok(head) + } + } + Ok(WriterItem::Error(_)) => Err(RecvClosed), + Err(PopError::Empty) => Ok(Bytes::new()), + Err(PopError::Closed) => Err(RecvClosed), + } + } + + pub fn finish(&self) { + self.shared.writer.finish(); + } + + pub fn fail( + &self, + error: QlStreamError, + ) -> Result, ForcePushError> { + self.shared.writer.fail(error) + } +} diff --git a/ql-runtime/src/chunk_slot/sync.rs b/ql-runtime/src/io/sync.rs similarity index 100% rename from ql-runtime/src/chunk_slot/sync.rs rename to ql-runtime/src/io/sync.rs diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs new file mode 100644 index 00000000..369aa900 --- /dev/null +++ b/ql-runtime/src/io/writer.rs @@ -0,0 +1,203 @@ +use std::{ + future::{poll_fn, Future}, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use event_listener::EventListener; +use ql_wire::{CloseTarget, StreamCloseCode}; + +use super::{ + queue::{PopError, PushError}, + shared::{StreamShared, WriterItem}, + sync::Arc, +}; +use crate::{command::Command, log, QlStreamError, RuntimeHandle}; + +pub struct StreamWriter { + shared: Arc, + target: CloseTarget, + wait: Option, + open: bool, + terminal: WriterTerminalState, + handle: RuntimeHandle, +} + +enum WriterTerminalState { + Pending, + Terminal(Result<(), QlStreamError>), +} + +unsafe impl Sync for StreamWriter {} + +impl std::fmt::Debug for StreamWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OutboundByteStream") + .field("stream_id", &self.shared.stream_id) + .field("target", &self.target) + .field("closed", &!self.open) + .finish_non_exhaustive() + } +} + +impl StreamWriter { + pub(crate) fn new( + shared: Arc, + target: CloseTarget, + handle: RuntimeHandle, + ) -> Self { + Self { + shared, + target, + wait: None, + open: true, + terminal: WriterTerminalState::Pending, + handle, + } + } + + pub fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll> { + if bytes.is_empty() { + return Poll::Ready(Ok(())); + } + + if !self.open { + return self.poll_terminal(cx); + } + + loop { + match self.shared.writer.try_write(std::mem::take(bytes)) { + Ok(()) => { + log::trace!( + "byte writer accepted chunk: stream_id={:?} target={:?}", + self.shared.stream_id, + self.target + ); + self.wait = None; + self.poll_runtime(); + return Poll::Ready(Ok(())); + } + Err(PushError::Closed(chunk)) => { + *bytes = chunk; + self.open = false; + self.wait = None; + return self.poll_terminal(cx); + } + Err(PushError::Full(chunk)) => { + *bytes = chunk; + } + } + + let active_listener = self.wait.get_or_insert_with(|| self.shared.writer.listen()); + match std::pin::Pin::new(active_listener).poll(cx) { + Poll::Ready(()) => self.wait = None, + Poll::Pending => return Poll::Pending, + } + } + } + + pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + let mut bytes = bytes; + poll_fn(|cx| self.poll_write(&mut bytes, cx)).await + } + + pub fn queue_finish(&mut self) { + if !self.open { + return; + } + log::debug!( + "byte writer finish: stream_id={:?} target={:?}", + self.shared.stream_id, + self.target + ); + self.open = false; + self.wait = None; + self.shared.writer.request_finish(); + self.poll_runtime(); + } + + pub async fn finish(mut self) -> Result<(), QlStreamError> { + self.queue_finish(); + poll_fn(|cx| self.poll_terminal(cx)).await + } + + pub fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.open { + self.queue_finish(); + } + self.poll_terminal(cx) + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn poll_runtime(&self) { + self.handle.try_send(Command::PollStream { + stream_id: self.shared.stream_id, + }); + } + + fn poll_terminal(&mut self, cx: &mut Context<'_>) -> Poll> { + match &self.terminal { + WriterTerminalState::Terminal(result) => return Poll::Ready(result.clone()), + WriterTerminalState::Pending => {} + } + + loop { + if self.shared.writer.terminal_ready() { + if self.shared.writer.terminal_ok() { + self.terminal = WriterTerminalState::Terminal(Ok(())); + return Poll::Ready(Ok(())); + } + + match self.shared.writer.pop() { + Ok(WriterItem::Error(error)) => { + self.terminal = WriterTerminalState::Terminal(Err(error.clone())); + return Poll::Ready(Err(error)); + } + Ok(WriterItem::Chunk(_)) => { + panic!("writer terminal phase contained chunk data") + } + Err(PopError::Empty) => {} + Err(PopError::Closed) => panic!("writer endpoint closed unexpectedly"), + } + } + + let active_listener = self.wait.get_or_insert_with(|| self.shared.writer.listen()); + match std::pin::Pin::new(active_listener).poll(cx) { + Poll::Ready(()) => self.wait = None, + Poll::Pending => return Poll::Pending, + } + } + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if !self.open { + return; + } + self.open = false; + log::debug!( + "byte writer close: stream_id={:?} target={:?} code={:?}", + self.shared.stream_id, + self.target, + code + ); + self.wait = None; + self.handle.try_send(Command::CloseStream { + stream_id: self.shared.stream_id, + target: self.target, + code, + }); + } +} + +impl Drop for StreamWriter { + fn drop(&mut self) { + self.close_inner(StreamCloseCode::CANCELLED); + } +} diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index a85388ad..b4305d14 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -2,11 +2,11 @@ pub use ql_fsm::NoSessionError; pub use self::{error::QlStreamError, handle::*, platform::*}; -pub mod chunk_slot; pub(crate) mod command; pub(crate) mod driver; mod error; pub mod handle; +pub(crate) mod io; pub mod log; pub mod platform; #[cfg(feature = "rpc")] From 98e7fb42a5683fbf6da86c7b57dec05a0bdac6fb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 17 Apr 2026 07:41:02 -0400 Subject: [PATCH 270/304] ql-runtime: loom tests --- ql-runtime/src/io/reader.rs | 26 +- ql-runtime/src/io/shared.rs | 496 ++++++++++++++++++++++++++++++------ ql-runtime/src/io/sync.rs | 15 +- ql-runtime/src/io/writer.rs | 40 +-- 4 files changed, 460 insertions(+), 117 deletions(-) diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs index 15b1e231..1d9d469e 100644 --- a/ql-runtime/src/io/reader.rs +++ b/ql-runtime/src/io/reader.rs @@ -9,7 +9,7 @@ use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ queue::PopError, - shared::{ReaderItem, StreamShared}, + shared::{Item, ReaderShared, StreamShared}, sync::Arc, }; use crate::{command::Command, log, QlStreamError, RuntimeHandle}; @@ -18,7 +18,7 @@ pub struct StreamReader { shared: Arc, target: CloseTarget, pending: Bytes, - wait: Option, + listener: Option, terminal: ReaderTerminalState, handle: RuntimeHandle, } @@ -44,16 +44,12 @@ impl std::fmt::Debug for StreamReader { } impl StreamReader { - pub(crate) fn new( - shared: Arc, - target: CloseTarget, - handle: RuntimeHandle, - ) -> Self { + pub fn new(shared: Arc, target: CloseTarget, handle: RuntimeHandle) -> Self { Self { shared, target, pending: Bytes::new(), - wait: None, + listener: None, terminal: ReaderTerminalState::Open, handle, } @@ -83,7 +79,7 @@ impl StreamReader { } match self.shared.reader.pop() { - Ok(ReaderItem::Chunk(mut bytes)) => { + Ok(Item::Chunk(mut bytes)) => { log::trace!( "byte reader received chunk: stream_id={:?} target={:?} len={}", self.shared.stream_id, @@ -100,7 +96,7 @@ impl StreamReader { self.pending = bytes; return Poll::Ready(Ok(Some(head))); } - Ok(ReaderItem::Error(error)) => { + Ok(Item::Error(error)) => { log::debug!( "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", self.shared.stream_id, @@ -111,7 +107,7 @@ impl StreamReader { return Poll::Ready(Err(error)); } Err(PopError::Empty) => { - if self.shared.reader.is_finished() { + if ReaderShared::is_finished(self.shared.reader.load_state()) { log::debug!( "byte reader delivered clean eof: stream_id={:?} target={:?}", self.shared.stream_id, @@ -124,9 +120,11 @@ impl StreamReader { Err(PopError::Closed) => panic!("reader endpoint closed unexpectedly"), } - let active_listener = self.wait.get_or_insert_with(|| self.shared.reader.listen()); + let active_listener = self + .listener + .get_or_insert_with(|| self.shared.reader.listen()); match std::pin::Pin::new(active_listener).poll(cx) { - Poll::Ready(()) => self.wait = None, + Poll::Ready(()) => self.listener = None, Poll::Pending => return Poll::Pending, } } @@ -161,8 +159,6 @@ impl StreamReader { self.target, code ); - self.pending = Bytes::new(); - self.wait = None; self.terminal = ReaderTerminalState::Delivered; self.handle.try_send(Command::CloseStream { stream_id: self.shared.stream_id, diff --git a/ql-runtime/src/io/shared.rs b/ql-runtime/src/io/shared.rs index 6d2a0487..1f33ee56 100644 --- a/ql-runtime/src/io/shared.rs +++ b/ql-runtime/src/io/shared.rs @@ -4,17 +4,17 @@ use ql_wire::StreamId; use super::{ queue::{ForcePushError, PopError, PushError, Single}, - sync::{Arc, AtomicUsize, Ordering}, + sync::{Arc, AtomicU8, Ordering}, }; use crate::QlStreamError; -const READER_FINISHED: usize = 1 << 0; +const READER_FINISHED: u8 = 1 << 0; -const WRITER_FINISH_REQUESTED: usize = 1 << 0; -const WRITER_TERMINAL_READY: usize = 1 << 1; -const WRITER_TERMINAL_OK: usize = 1 << 2; +const WRITER_FINISH_REQUESTED: u8 = 1 << 0; +const WRITER_TERMINAL_READY: u8 = 1 << 1; +const WRITER_TERMINAL_OK: u8 = 1 << 2; -pub(crate) struct StreamShared { +pub struct StreamShared { pub stream_id: StreamId, pub reader: ReaderShared, pub writer: WriterShared, @@ -30,12 +30,12 @@ impl StreamShared { } } -enum SlotMsg { +pub enum Item { Chunk(Bytes), Error(QlStreamError), } -impl SlotMsg { +impl Item { fn into_chunk(self) -> Option { match self { Self::Chunk(bytes) => Some(bytes), @@ -44,10 +44,10 @@ impl SlotMsg { } } -pub(crate) struct ReaderShared { - slot: Single, +pub struct ReaderShared { + slot: Single, changed: Event, - state: AtomicUsize, + state: AtomicU8, } impl ReaderShared { @@ -55,26 +55,26 @@ impl ReaderShared { Self { slot: Single::new(), changed: Event::new(), - state: AtomicUsize::new(0), + state: AtomicU8::new(0), } } pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { - match self.slot.push(SlotMsg::Chunk(bytes)) { + match self.slot.push(Item::Chunk(bytes)) { Ok(()) => { self.changed.notify(usize::MAX); Ok(()) } - Err(PushError::Closed(SlotMsg::Chunk(bytes))) => Err(PushError::Closed(bytes)), - Err(PushError::Full(SlotMsg::Chunk(bytes))) => Err(PushError::Full(bytes)), - Err(PushError::Closed(SlotMsg::Error(_))) | Err(PushError::Full(SlotMsg::Error(_))) => { + Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), + Err(PushError::Full(Item::Chunk(bytes))) => Err(PushError::Full(bytes)), + Err(PushError::Closed(Item::Error(_))) | Err(PushError::Full(Item::Error(_))) => { unreachable!("reader chunk write cannot recover an error payload") } } } pub fn finish(&self) { - if self.state.fetch_or(READER_FINISHED, Ordering::SeqCst) & READER_FINISHED == 0 { + if self.state.fetch_or(READER_FINISHED, Ordering::Release) & READER_FINISHED == 0 { self.changed.notify(usize::MAX); } } @@ -83,29 +83,33 @@ impl ReaderShared { &self, error: QlStreamError, ) -> Result, ForcePushError> { - match self.slot.force_push(SlotMsg::Error(error)) { + match self.slot.force_push(Item::Error(error)) { Ok(displaced) => { self.changed.notify(usize::MAX); - Ok(displaced.and_then(SlotMsg::into_chunk)) + Ok(displaced.and_then(Item::into_chunk)) } - Err(ForcePushError(SlotMsg::Error(error))) => Err(ForcePushError(error)), - Err(ForcePushError(SlotMsg::Chunk(_))) => { + Err(ForcePushError(Item::Error(error))) => Err(ForcePushError(error)), + Err(ForcePushError(Item::Chunk(_))) => { unreachable!("reader fail cannot recover a chunk payload") } } } - pub fn is_finished(&self) -> bool { - self.state.load(Ordering::SeqCst) & READER_FINISHED != 0 + pub fn load_state(&self) -> u8 { + self.state.load(Ordering::Acquire) + } + + pub fn is_finished(state: u8) -> bool { + state & READER_FINISHED != 0 } - pub fn pop(&self) -> Result { + pub fn pop(&self) -> Result { match self.slot.pop() { - Ok(SlotMsg::Chunk(bytes)) => { + Ok(Item::Chunk(bytes)) => { self.changed.notify(usize::MAX); - Ok(ReaderItem::Chunk(bytes)) + Ok(Item::Chunk(bytes)) } - Ok(SlotMsg::Error(error)) => Ok(ReaderItem::Error(error)), + Ok(Item::Error(error)) => Ok(Item::Error(error)), Err(error) => Err(error), } } @@ -115,15 +119,10 @@ impl ReaderShared { } } -pub(crate) enum ReaderItem { - Chunk(Bytes), - Error(QlStreamError), -} - -pub(crate) struct WriterShared { - slot: Single, +pub struct WriterShared { + slot: Single, changed: Event, - state: AtomicUsize, + state: AtomicU8, } impl WriterShared { @@ -131,23 +130,40 @@ impl WriterShared { Self { slot: Single::new(), changed: Event::new(), - state: AtomicUsize::new(0), + state: AtomicU8::new(0), } } + pub fn load_state(&self) -> u8 { + self.state.load(Ordering::Acquire) + } + + pub fn finish_requested(state: u8) -> bool { + state & WRITER_FINISH_REQUESTED != 0 + } + + pub fn terminal_ready(state: u8) -> bool { + state & WRITER_TERMINAL_READY != 0 + } + + pub fn terminal_ok(state: u8) -> bool { + state & WRITER_TERMINAL_OK != 0 + } + pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { - if self.terminal_ready() || self.finish_requested() { + let state = self.load_state(); + if Self::terminal_ready(state) || Self::finish_requested(state) { return Err(PushError::Closed(bytes)); } - match self.slot.push(SlotMsg::Chunk(bytes)) { + match self.slot.push(Item::Chunk(bytes)) { Ok(()) => { self.changed.notify(usize::MAX); Ok(()) } - Err(PushError::Closed(SlotMsg::Chunk(bytes))) => Err(PushError::Closed(bytes)), - Err(PushError::Full(SlotMsg::Chunk(bytes))) => Err(PushError::Full(bytes)), - Err(PushError::Closed(SlotMsg::Error(_))) | Err(PushError::Full(SlotMsg::Error(_))) => { + Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), + Err(PushError::Full(Item::Chunk(bytes))) => Err(PushError::Full(bytes)), + Err(PushError::Closed(Item::Error(_))) | Err(PushError::Full(Item::Error(_))) => { unreachable!("writer chunk write cannot recover an error payload") } } @@ -156,7 +172,7 @@ impl WriterShared { pub fn request_finish(&self) { if self .state - .fetch_or(WRITER_FINISH_REQUESTED, Ordering::SeqCst) + .fetch_or(WRITER_FINISH_REQUESTED, Ordering::Release) & WRITER_FINISH_REQUESTED == 0 { @@ -164,56 +180,70 @@ impl WriterShared { } } - pub fn finish_requested(&self) -> bool { - self.state.load(Ordering::SeqCst) & WRITER_FINISH_REQUESTED != 0 - } - pub fn finish(&self) { - self.state - .fetch_or(WRITER_TERMINAL_READY | WRITER_TERMINAL_OK, Ordering::SeqCst); - self.changed.notify(usize::MAX); + let mut state = self.state.load(Ordering::Acquire); + loop { + if Self::terminal_ready(state) { + return; + } + + let new_state = state | WRITER_TERMINAL_READY | WRITER_TERMINAL_OK; + match self + .state + .compare_exchange(state, new_state, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => { + self.changed.notify(usize::MAX); + return; + } + Err(actual) => state = actual, + } + } } pub fn fail( &self, error: QlStreamError, ) -> Result, ForcePushError> { - if self.terminal_ready() { - return Err(ForcePushError(error)); + let mut state = self.state.load(Ordering::Acquire); + loop { + if Self::terminal_ready(state) { + return Err(ForcePushError(error)); + } + + let new_state = state | WRITER_TERMINAL_READY; + match self + .state + .compare_exchange(state, new_state, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => break, + Err(actual) => state = actual, + } } - match self.slot.force_push(SlotMsg::Error(error)) { + match self.slot.force_push(Item::Error(error)) { Ok(displaced) => { - self.state.fetch_or(WRITER_TERMINAL_READY, Ordering::SeqCst); self.changed.notify(usize::MAX); - Ok(displaced.and_then(SlotMsg::into_chunk)) + Ok(displaced.and_then(Item::into_chunk)) } - Err(ForcePushError(SlotMsg::Error(error))) => Err(ForcePushError(error)), - Err(ForcePushError(SlotMsg::Chunk(_))) => { + Err(ForcePushError(Item::Error(error))) => Err(ForcePushError(error)), + Err(ForcePushError(Item::Chunk(_))) => { unreachable!("writer fail cannot recover a chunk payload") } } } - pub fn terminal_ready(&self) -> bool { - self.state.load(Ordering::SeqCst) & WRITER_TERMINAL_READY != 0 - } - - pub fn terminal_ok(&self) -> bool { - self.state.load(Ordering::SeqCst) & WRITER_TERMINAL_OK != 0 - } - pub fn is_empty(&self) -> bool { self.slot.is_empty() } - pub fn pop(&self) -> Result { + pub fn pop(&self) -> Result { match self.slot.pop() { - Ok(SlotMsg::Chunk(bytes)) => { + Ok(Item::Chunk(bytes)) => { self.changed.notify(usize::MAX); - Ok(WriterItem::Chunk(bytes)) + Ok(Item::Chunk(bytes)) } - Ok(SlotMsg::Error(error)) => Ok(WriterItem::Error(error)), + Ok(Item::Error(error)) => Ok(Item::Error(error)), Err(error) => Err(error), } } @@ -223,15 +253,10 @@ impl WriterShared { } } -pub(crate) enum WriterItem { - Chunk(Bytes), - Error(QlStreamError), -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) struct RecvClosed; +pub struct RecvClosed; -pub(crate) struct ReaderIo { +pub struct ReaderIo { shared: Arc, } @@ -256,7 +281,7 @@ impl ReaderIo { } } -pub(crate) struct WriterIo { +pub struct WriterIo { shared: Arc, pending: Bytes, } @@ -270,8 +295,9 @@ impl WriterIo { } pub fn is_finished(&self) -> bool { + let state = self.shared.writer.load_state(); self.pending.is_empty() - && self.shared.writer.finish_requested() + && WriterShared::finish_requested(state) && self.shared.writer.is_empty() } @@ -286,12 +312,13 @@ impl WriterIo { return Ok(bytes); } - if self.shared.writer.terminal_ready() { + let state = self.shared.writer.load_state(); + if WriterShared::terminal_ready(state) { return Err(RecvClosed); } match self.shared.writer.pop() { - Ok(WriterItem::Chunk(mut bytes)) => { + Ok(Item::Chunk(mut bytes)) => { if bytes.len() <= max_len { Ok(bytes) } else { @@ -300,7 +327,7 @@ impl WriterIo { Ok(head) } } - Ok(WriterItem::Error(_)) => Err(RecvClosed), + Ok(Item::Error(_)) => Err(RecvClosed), Err(PopError::Empty) => Ok(Bytes::new()), Err(PopError::Closed) => Err(RecvClosed), } @@ -317,3 +344,312 @@ impl WriterIo { self.shared.writer.fail(error) } } + +#[cfg(all(test, loom))] +mod loom_tests { + use std::{ + future::Future, + pin::pin, + task::{Context, Poll, Waker}, + }; + + use bytes::Bytes; + use loom::{model, thread}; + use ql_wire::{StreamCloseCode, StreamId}; + + use super::{Item, PopError, PushError, ReaderShared, StreamShared, WriterIo, WriterShared}; + use crate::QlStreamError; + + fn check_model(f: impl Fn() + Sync + Send + 'static) { + let builder = model::Builder::new(); + builder.check(f); + } + + fn shared() -> super::super::sync::Arc { + StreamShared::new(StreamId(1u32.into())) + } + + #[test] + fn reader_listener_observes_finish_after_pending() { + check_model(|| { + let shared = shared(); + let waker = Waker::noop(); + let mut cx = Context::from_waker(waker); + let mut listener = pin!(shared.reader.listen()); + + assert!(matches!(listener.as_mut().poll(&mut cx), Poll::Pending)); + + let finisher = { + let shared = shared.clone(); + thread::spawn(move || { + shared.reader.finish(); + }) + }; + + finisher.join().unwrap(); + assert!(ReaderShared::is_finished(shared.reader.load_state())); + assert!(matches!(listener.as_mut().poll(&mut cx), Poll::Ready(()))); + }); + } + + #[test] + fn reader_chunk_remains_available_after_finish() { + check_model(|| { + let shared = shared(); + + let producer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.reader.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.reader.finish(); + }) + }; + + producer.join().unwrap(); + + match shared.reader.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + } + assert!(ReaderShared::is_finished(shared.reader.load_state())); + assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + }); + } + + #[test] + fn reader_write_races_with_finish_preserves_chunk() { + check_model(|| { + let shared = shared(); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.reader.try_write(Bytes::from_static(b"abc"))) + }; + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.reader.finish()) + }; + + assert_eq!(writer.join().unwrap(), Ok(())); + finisher.join().unwrap(); + + assert!(ReaderShared::is_finished(shared.reader.load_state())); + match shared.reader.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + } + assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + }); + } + + #[test] + fn reader_fail_racing_with_pop_preserves_terminal_outcome() { + check_model(|| { + let shared = shared(); + shared.reader.try_write(Bytes::from_static(b"abc")).unwrap(); + + let popper = { + let shared = shared.clone(); + thread::spawn(move || shared.reader.pop()) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.reader.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + let pop_result = popper.join().unwrap(); + let fail_result = failer.join().unwrap(); + + match (pop_result, fail_result) { + (Ok(Item::Chunk(bytes)), Ok(None)) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + match shared.reader.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal reader error"), + } + } + (Ok(Item::Error(QlStreamError::StreamClosed { code })), Ok(Some(bytes))) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + } + _ => panic!("unexpected reader fail/pop race outcome"), + } + }); + } + + #[test] + fn writer_is_finished_only_after_drain() { + check_model(|| { + let shared = shared(); + let mut writer_io = WriterIo::new(shared.clone()); + + shared.writer.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.writer.request_finish(); + + assert!(!writer_io.is_finished()); + assert_eq!(writer_io.try_read(2), Ok(Bytes::from_static(b"ab"))); + assert!(!writer_io.is_finished()); + assert_eq!(writer_io.try_read(8), Ok(Bytes::from_static(b"c"))); + assert!(writer_io.is_finished()); + }); + } + + #[test] + fn writer_write_races_with_request_finish() { + check_model(|| { + let shared = shared(); + let mut writer_io = WriterIo::new(shared.clone()); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.writer.try_write(Bytes::from_static(b"abc"))) + }; + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.writer.request_finish()) + }; + + let write_result = writer.join().unwrap(); + finisher.join().unwrap(); + + assert!(WriterShared::finish_requested(shared.writer.load_state())); + match write_result { + Ok(()) => { + assert_eq!(writer_io.try_read(8), Ok(Bytes::from_static(b"abc"))); + assert!(writer_io.is_finished()); + } + Err(PushError::Closed(bytes)) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(writer_io.is_finished()); + } + Err(PushError::Full(_)) => panic!("empty writer slot must not report full"), + } + }); + } + + #[test] + fn writer_fail_overwrites_buffered_chunk_and_wakes_listener() { + check_model(|| { + let shared = shared(); + shared.writer.try_write(Bytes::from_static(b"abc")).unwrap(); + + let waker = Waker::noop(); + let mut cx = Context::from_waker(waker); + let mut listener = pin!(shared.writer.listen()); + assert!(matches!(listener.as_mut().poll(&mut cx), Poll::Pending)); + + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + let displaced = shared.writer.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }); + assert_eq!(displaced.unwrap(), Some(Bytes::from_static(b"abc"))); + }) + }; + + failer.join().unwrap(); + + assert!(WriterShared::terminal_ready(shared.writer.load_state())); + assert!(matches!(listener.as_mut().poll(&mut cx), Poll::Ready(()))); + match shared.writer.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + }); + } + + #[test] + fn writer_write_races_with_fail() { + check_model(|| { + let shared = shared(); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.writer.try_write(Bytes::from_static(b"abc"))) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.writer.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + let write_result = writer.join().unwrap(); + let fail_result = failer.join().unwrap(); + + assert!(WriterShared::terminal_ready(shared.writer.load_state())); + match (&write_result, &fail_result) { + (Ok(()), Ok(Some(bytes))) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + (Err(PushError::Closed(bytes)), Ok(None)) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + (Err(PushError::Full(bytes)), Ok(None)) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + _ => panic!( + "unexpected writer fail/write race outcome: write={write_result:?} fail={fail_result:?}" + ), + } + + match shared.writer.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + }); + } + + #[test] + fn writer_finish_races_with_fail_without_masking_error() { + check_model(|| { + let shared = shared(); + + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.writer.finish()) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.writer.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + finisher.join().unwrap(); + let fail_result = failer.join().unwrap(); + + assert!(WriterShared::terminal_ready(shared.writer.load_state())); + match fail_result { + Err(_) => { + assert!(WriterShared::terminal_ok(shared.writer.load_state())); + } + Ok(_) => { + assert!(!WriterShared::terminal_ok(shared.writer.load_state())); + match shared.writer.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + } + } + }); + } +} diff --git a/ql-runtime/src/io/sync.rs b/ql-runtime/src/io/sync.rs index 8e7423b5..c058710f 100644 --- a/ql-runtime/src/io/sync.rs +++ b/ql-runtime/src/io/sync.rs @@ -3,7 +3,7 @@ mod inner { pub use std::{ cell::UnsafeCell, sync::{ - atomic::{AtomicUsize, Ordering}, + atomic::{AtomicU8, AtomicUsize, Ordering}, Arc, }, }; @@ -49,6 +49,17 @@ mod inner { f(self.get_mut()) } } + + impl AtomicExt for AtomicU8 { + type Value = u8; + + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Value) -> R, + { + f(self.get_mut()) + } + } } #[cfg(all(test, loom))] @@ -56,7 +67,7 @@ mod inner { pub use loom::{ cell::UnsafeCell, sync::{ - atomic::{AtomicUsize, Ordering}, + atomic::{AtomicU8, AtomicUsize, Ordering}, Arc, }, thread::yield_now as busy_wait, diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs index 369aa900..4fb497ce 100644 --- a/ql-runtime/src/io/writer.rs +++ b/ql-runtime/src/io/writer.rs @@ -9,7 +9,7 @@ use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ queue::{PopError, PushError}, - shared::{StreamShared, WriterItem}, + shared::{Item, StreamShared, WriterShared}, sync::Arc, }; use crate::{command::Command, log, QlStreamError, RuntimeHandle}; @@ -17,7 +17,7 @@ use crate::{command::Command, log, QlStreamError, RuntimeHandle}; pub struct StreamWriter { shared: Arc, target: CloseTarget, - wait: Option, + listener: Option, open: bool, terminal: WriterTerminalState, handle: RuntimeHandle, @@ -41,15 +41,11 @@ impl std::fmt::Debug for StreamWriter { } impl StreamWriter { - pub(crate) fn new( - shared: Arc, - target: CloseTarget, - handle: RuntimeHandle, - ) -> Self { + pub fn new(shared: Arc, target: CloseTarget, handle: RuntimeHandle) -> Self { Self { shared, target, - wait: None, + listener: None, open: true, terminal: WriterTerminalState::Pending, handle, @@ -77,14 +73,14 @@ impl StreamWriter { self.shared.stream_id, self.target ); - self.wait = None; + self.listener = None; self.poll_runtime(); return Poll::Ready(Ok(())); } Err(PushError::Closed(chunk)) => { *bytes = chunk; self.open = false; - self.wait = None; + self.listener = None; return self.poll_terminal(cx); } Err(PushError::Full(chunk)) => { @@ -92,9 +88,11 @@ impl StreamWriter { } } - let active_listener = self.wait.get_or_insert_with(|| self.shared.writer.listen()); + let active_listener = self + .listener + .get_or_insert_with(|| self.shared.writer.listen()); match std::pin::Pin::new(active_listener).poll(cx) { - Poll::Ready(()) => self.wait = None, + Poll::Ready(()) => self.listener = None, Poll::Pending => return Poll::Pending, } } @@ -115,7 +113,7 @@ impl StreamWriter { self.target ); self.open = false; - self.wait = None; + self.listener = None; self.shared.writer.request_finish(); self.poll_runtime(); } @@ -149,18 +147,19 @@ impl StreamWriter { } loop { - if self.shared.writer.terminal_ready() { - if self.shared.writer.terminal_ok() { + let state = self.shared.writer.load_state(); + if WriterShared::terminal_ready(state) { + if WriterShared::terminal_ok(state) { self.terminal = WriterTerminalState::Terminal(Ok(())); return Poll::Ready(Ok(())); } match self.shared.writer.pop() { - Ok(WriterItem::Error(error)) => { + Ok(Item::Error(error)) => { self.terminal = WriterTerminalState::Terminal(Err(error.clone())); return Poll::Ready(Err(error)); } - Ok(WriterItem::Chunk(_)) => { + Ok(Item::Chunk(_)) => { panic!("writer terminal phase contained chunk data") } Err(PopError::Empty) => {} @@ -168,9 +167,11 @@ impl StreamWriter { } } - let active_listener = self.wait.get_or_insert_with(|| self.shared.writer.listen()); + let active_listener = self + .listener + .get_or_insert_with(|| self.shared.writer.listen()); match std::pin::Pin::new(active_listener).poll(cx) { - Poll::Ready(()) => self.wait = None, + Poll::Ready(()) => self.listener = None, Poll::Pending => return Poll::Pending, } } @@ -187,7 +188,6 @@ impl StreamWriter { self.target, code ); - self.wait = None; self.handle.try_send(Command::CloseStream { stream_id: self.shared.stream_id, target: self.target, From 761f41833a8ad06451d3297699290581bbc41a55 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 17 Apr 2026 15:11:27 -0400 Subject: [PATCH 271/304] ql-runtime: more loom tests --- ql-runtime/src/io/shared.rs | 38 ++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/ql-runtime/src/io/shared.rs b/ql-runtime/src/io/shared.rs index 1f33ee56..1c3c732b 100644 --- a/ql-runtime/src/io/shared.rs +++ b/ql-runtime/src/io/shared.rs @@ -60,6 +60,10 @@ impl ReaderShared { } pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { + if Self::is_finished(self.load_state()) { + return Err(PushError::Closed(bytes)); + } + match self.slot.push(Item::Chunk(bytes)) { Ok(()) => { self.changed.notify(usize::MAX); @@ -417,7 +421,23 @@ mod loom_tests { } #[test] - fn reader_write_races_with_finish_preserves_chunk() { + fn reader_rejects_write_after_finish() { + check_model(|| { + let shared = shared(); + + shared.reader.finish(); + + assert_eq!( + shared.reader.try_write(Bytes::from_static(b"abc")), + Err(PushError::Closed(Bytes::from_static(b"abc"))) + ); + assert!(ReaderShared::is_finished(shared.reader.load_state())); + assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + }); + } + + #[test] + fn reader_write_races_with_finish_has_coherent_outcome() { check_model(|| { let shared = shared(); @@ -430,13 +450,21 @@ mod loom_tests { thread::spawn(move || shared.reader.finish()) }; - assert_eq!(writer.join().unwrap(), Ok(())); + let write_result = writer.join().unwrap(); finisher.join().unwrap(); assert!(ReaderShared::is_finished(shared.reader.load_state())); - match shared.reader.pop() { - Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), - _ => panic!("expected buffered reader chunk"), + match write_result { + Ok(()) => match shared.reader.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + }, + Err(PushError::Closed(bytes)) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + return; + } + Err(PushError::Full(_)) => panic!("empty reader slot must not report full"), } assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); }); From f94d1055d08ca62f6b4cb19119990ac449116e55 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 17 Apr 2026 15:27:05 -0400 Subject: [PATCH 272/304] diatomic waker insted of event listener --- Cargo.lock | 17 +++-- ql-runtime/Cargo.toml | 2 +- ql-runtime/src/io/reader.rs | 119 ++++++++++++++++++---------------- ql-runtime/src/io/shared.rs | 123 +++++++++++++++++++++++++----------- ql-runtime/src/io/writer.rs | 92 +++++++++++++++++---------- 5 files changed, 219 insertions(+), 134 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 452d0150..872dd247 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -788,6 +788,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "diatomic-waker" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c" + [[package]] name = "digest" version = "0.10.7" @@ -949,7 +955,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1183,7 +1189,7 @@ dependencies = [ "libc", "log", "rustversion", - "windows-link 0.1.3", + "windows-link 0.2.1", "windows-result", ] @@ -1853,7 +1859,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2360,6 +2366,7 @@ version = "0.1.0" dependencies = [ "async-channel", "bytes", + "diatomic-waker", "env_logger", "event-listener", "futures-lite", @@ -2632,7 +2639,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2977,7 +2984,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml index 1760928e..56fd327f 100644 --- a/ql-runtime/Cargo.toml +++ b/ql-runtime/Cargo.toml @@ -13,7 +13,7 @@ rpc = ["dep:ql-rpc"] [dependencies] async-channel = { version = "2.5" } bytes = "1" -event-listener = "5.4" +diatomic-waker = { version = "0.2.3", default-features = false } futures-lite = { version = "2.5" } log = { version = "0.4", optional = true } oneshot = { version = "0.1.11" } diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs index 1d9d469e..a7bacd3e 100644 --- a/ql-runtime/src/io/reader.rs +++ b/ql-runtime/src/io/reader.rs @@ -1,10 +1,9 @@ use std::{ - future::{poll_fn, Future}, + future::poll_fn, task::{Context, Poll}, }; use bytes::Bytes; -use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ @@ -18,7 +17,6 @@ pub struct StreamReader { shared: Arc, target: CloseTarget, pending: Bytes, - listener: Option, terminal: ReaderTerminalState, handle: RuntimeHandle, } @@ -49,7 +47,6 @@ impl StreamReader { shared, target, pending: Bytes::new(), - listener: None, terminal: ReaderTerminalState::Open, handle, } @@ -65,68 +62,78 @@ impl StreamReader { } loop { - if !self.pending.is_empty() { - let pending = &mut self.pending; - let bytes = if pending.len() <= max_len { - std::mem::take(pending) - } else { - pending.split_to(max_len) - }; + match self.try_read_ready(max_len) { + Poll::Ready(result) => return Poll::Ready(result), + Poll::Pending => {} + } + + self.shared.reader.register_waiter(cx.waker()); + + match self.try_read_ready(max_len) { + Poll::Ready(result) => { + self.shared.reader.unregister_waiter(); + return Poll::Ready(result); + } + Poll::Pending => return Poll::Pending, + } + } + } + + fn try_read_ready(&mut self, max_len: usize) -> Poll, QlStreamError>> { + if !self.pending.is_empty() { + let pending = &mut self.pending; + let bytes = if pending.len() <= max_len { + std::mem::take(pending) + } else { + pending.split_to(max_len) + }; + self.handle.try_send(Command::PollInbound { + stream_id: self.shared.stream_id, + }); + return Poll::Ready(Ok(Some(bytes))); + } + + match self.shared.reader.pop() { + Ok(Item::Chunk(mut bytes)) => { + log::trace!( + "byte reader received chunk: stream_id={:?} target={:?} len={}", + self.shared.stream_id, + self.target, + bytes.len() + ); self.handle.try_send(Command::PollInbound { stream_id: self.shared.stream_id, }); - return Poll::Ready(Ok(Some(bytes))); - } - - match self.shared.reader.pop() { - Ok(Item::Chunk(mut bytes)) => { - log::trace!( - "byte reader received chunk: stream_id={:?} target={:?} len={}", - self.shared.stream_id, - self.target, - bytes.len() - ); - self.handle.try_send(Command::PollInbound { - stream_id: self.shared.stream_id, - }); - if bytes.len() <= max_len { - return Poll::Ready(Ok(Some(bytes))); - } - let head = bytes.split_to(max_len); - self.pending = bytes; - return Poll::Ready(Ok(Some(head))); + if bytes.len() <= max_len { + return Poll::Ready(Ok(Some(bytes))); } - Ok(Item::Error(error)) => { + let head = bytes.split_to(max_len); + self.pending = bytes; + Poll::Ready(Ok(Some(head))) + } + Ok(Item::Error(error)) => { + log::debug!( + "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", + self.shared.stream_id, + self.target, + error + ); + self.terminal = ReaderTerminalState::Delivered; + Poll::Ready(Err(error)) + } + Err(PopError::Empty) => { + if ReaderShared::is_finished(self.shared.reader.load_state()) { log::debug!( - "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", + "byte reader delivered clean eof: stream_id={:?} target={:?}", self.shared.stream_id, - self.target, - error + self.target ); self.terminal = ReaderTerminalState::Delivered; - return Poll::Ready(Err(error)); - } - Err(PopError::Empty) => { - if ReaderShared::is_finished(self.shared.reader.load_state()) { - log::debug!( - "byte reader delivered clean eof: stream_id={:?} target={:?}", - self.shared.stream_id, - self.target - ); - self.terminal = ReaderTerminalState::Delivered; - return Poll::Ready(Ok(None)); - } + return Poll::Ready(Ok(None)); } - Err(PopError::Closed) => panic!("reader endpoint closed unexpectedly"), - } - - let active_listener = self - .listener - .get_or_insert_with(|| self.shared.reader.listen()); - match std::pin::Pin::new(active_listener).poll(cx) { - Poll::Ready(()) => self.listener = None, - Poll::Pending => return Poll::Pending, + Poll::Pending } + Err(PopError::Closed) => panic!("reader endpoint closed unexpectedly"), } } diff --git a/ql-runtime/src/io/shared.rs b/ql-runtime/src/io/shared.rs index 1c3c732b..dfa01136 100644 --- a/ql-runtime/src/io/shared.rs +++ b/ql-runtime/src/io/shared.rs @@ -1,5 +1,7 @@ +use std::task::Waker; + use bytes::Bytes; -use event_listener::{Event, EventListener}; +use diatomic_waker::DiatomicWaker; use ql_wire::StreamId; use super::{ @@ -46,7 +48,9 @@ impl Item { pub struct ReaderShared { slot: Single, - changed: Event, + // Sound because StreamShared creates exactly one StreamReader for this side, + // and that reader is the only task that registers or unregisters wakers. + changed: DiatomicWaker, state: AtomicU8, } @@ -54,7 +58,7 @@ impl ReaderShared { fn new() -> Self { Self { slot: Single::new(), - changed: Event::new(), + changed: DiatomicWaker::new(), state: AtomicU8::new(0), } } @@ -66,7 +70,7 @@ impl ReaderShared { match self.slot.push(Item::Chunk(bytes)) { Ok(()) => { - self.changed.notify(usize::MAX); + self.changed.notify(); Ok(()) } Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), @@ -79,7 +83,7 @@ impl ReaderShared { pub fn finish(&self) { if self.state.fetch_or(READER_FINISHED, Ordering::Release) & READER_FINISHED == 0 { - self.changed.notify(usize::MAX); + self.changed.notify(); } } @@ -89,7 +93,7 @@ impl ReaderShared { ) -> Result, ForcePushError> { match self.slot.force_push(Item::Error(error)) { Ok(displaced) => { - self.changed.notify(usize::MAX); + self.changed.notify(); Ok(displaced.and_then(Item::into_chunk)) } Err(ForcePushError(Item::Error(error))) => Err(ForcePushError(error)), @@ -110,7 +114,7 @@ impl ReaderShared { pub fn pop(&self) -> Result { match self.slot.pop() { Ok(Item::Chunk(bytes)) => { - self.changed.notify(usize::MAX); + self.changed.notify(); Ok(Item::Chunk(bytes)) } Ok(Item::Error(error)) => Ok(Item::Error(error)), @@ -118,14 +122,24 @@ impl ReaderShared { } } - pub fn listen(&self) -> EventListener { - self.changed.listen() + pub fn register_waiter(&self, waker: &Waker) { + // Safety: StreamReader is the only reader-side registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.register(waker) }; + } + + pub fn unregister_waiter(&self) { + // Safety: StreamReader is the only reader-side registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.unregister() }; } } pub struct WriterShared { slot: Single, - changed: Event, + // Sound because StreamShared creates exactly one StreamWriter for this side, + // and that writer is the only task that registers or unregisters wakers. + changed: DiatomicWaker, state: AtomicU8, } @@ -133,7 +147,7 @@ impl WriterShared { fn new() -> Self { Self { slot: Single::new(), - changed: Event::new(), + changed: DiatomicWaker::new(), state: AtomicU8::new(0), } } @@ -162,7 +176,7 @@ impl WriterShared { match self.slot.push(Item::Chunk(bytes)) { Ok(()) => { - self.changed.notify(usize::MAX); + self.changed.notify(); Ok(()) } Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), @@ -180,7 +194,7 @@ impl WriterShared { & WRITER_FINISH_REQUESTED == 0 { - self.changed.notify(usize::MAX); + self.changed.notify(); } } @@ -197,7 +211,7 @@ impl WriterShared { .compare_exchange(state, new_state, Ordering::AcqRel, Ordering::Acquire) { Ok(_) => { - self.changed.notify(usize::MAX); + self.changed.notify(); return; } Err(actual) => state = actual, @@ -227,7 +241,7 @@ impl WriterShared { match self.slot.force_push(Item::Error(error)) { Ok(displaced) => { - self.changed.notify(usize::MAX); + self.changed.notify(); Ok(displaced.and_then(Item::into_chunk)) } Err(ForcePushError(Item::Error(error))) => Err(ForcePushError(error)), @@ -244,7 +258,7 @@ impl WriterShared { pub fn pop(&self) -> Result { match self.slot.pop() { Ok(Item::Chunk(bytes)) => { - self.changed.notify(usize::MAX); + self.changed.notify(); Ok(Item::Chunk(bytes)) } Ok(Item::Error(error)) => Ok(Item::Error(error)), @@ -252,8 +266,16 @@ impl WriterShared { } } - pub fn listen(&self) -> EventListener { - self.changed.listen() + pub fn register_waiter(&self, waker: &Waker) { + // Safety: StreamWriter is the only writer-side registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.register(waker) }; + } + + pub fn unregister_waiter(&self) { + // Safety: StreamWriter is the only writer-side registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.unregister() }; } } @@ -351,11 +373,7 @@ impl WriterIo { #[cfg(all(test, loom))] mod loom_tests { - use std::{ - future::Future, - pin::pin, - task::{Context, Poll, Waker}, - }; + use std::task::Waker; use bytes::Bytes; use loom::{model, thread}; @@ -374,14 +392,10 @@ mod loom_tests { } #[test] - fn reader_listener_observes_finish_after_pending() { + fn reader_waiter_registration_survives_finish() { check_model(|| { let shared = shared(); - let waker = Waker::noop(); - let mut cx = Context::from_waker(waker); - let mut listener = pin!(shared.reader.listen()); - - assert!(matches!(listener.as_mut().poll(&mut cx), Poll::Pending)); + shared.reader.register_waiter(Waker::noop()); let finisher = { let shared = shared.clone(); @@ -392,7 +406,8 @@ mod loom_tests { finisher.join().unwrap(); assert!(ReaderShared::is_finished(shared.reader.load_state())); - assert!(matches!(listener.as_mut().poll(&mut cx), Poll::Ready(()))); + + shared.reader.unregister_waiter(); }); } @@ -563,15 +578,11 @@ mod loom_tests { } #[test] - fn writer_fail_overwrites_buffered_chunk_and_wakes_listener() { + fn writer_fail_overwrites_buffered_chunk_and_keeps_terminal_state_observable() { check_model(|| { let shared = shared(); shared.writer.try_write(Bytes::from_static(b"abc")).unwrap(); - - let waker = Waker::noop(); - let mut cx = Context::from_waker(waker); - let mut listener = pin!(shared.writer.listen()); - assert!(matches!(listener.as_mut().poll(&mut cx), Poll::Pending)); + shared.writer.register_waiter(Waker::noop()); let failer = { let shared = shared.clone(); @@ -586,7 +597,7 @@ mod loom_tests { failer.join().unwrap(); assert!(WriterShared::terminal_ready(shared.writer.load_state())); - assert!(matches!(listener.as_mut().poll(&mut cx), Poll::Ready(()))); + shared.writer.unregister_waiter(); match shared.writer.pop() { Ok(Item::Error(QlStreamError::StreamClosed { code })) => { assert_eq!(code, StreamCloseCode::CANCELLED); @@ -596,6 +607,44 @@ mod loom_tests { }); } + #[test] + fn reader_waiter_registration_can_be_reused_after_notification() { + check_model(|| { + let shared = shared(); + + shared.reader.register_waiter(Waker::noop()); + shared.reader.try_write(Bytes::from_static(b"abc")).unwrap(); + match shared.reader.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + } + + shared.reader.register_waiter(Waker::noop()); + shared.reader.finish(); + assert!(ReaderShared::is_finished(shared.reader.load_state())); + shared.reader.unregister_waiter(); + }); + } + + #[test] + fn writer_waiter_registration_can_be_reused_after_notification() { + check_model(|| { + let shared = shared(); + + shared.writer.register_waiter(Waker::noop()); + shared.writer.try_write(Bytes::from_static(b"abc")).unwrap(); + match shared.writer.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered writer chunk"), + } + + shared.writer.register_waiter(Waker::noop()); + shared.writer.finish(); + assert!(WriterShared::terminal_ready(shared.writer.load_state())); + shared.writer.unregister_waiter(); + }); + } + #[test] fn writer_write_races_with_fail() { check_model(|| { diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs index 4fb497ce..fdb8402c 100644 --- a/ql-runtime/src/io/writer.rs +++ b/ql-runtime/src/io/writer.rs @@ -1,10 +1,9 @@ use std::{ - future::{poll_fn, Future}, + future::poll_fn, task::{Context, Poll}, }; use bytes::Bytes; -use event_listener::EventListener; use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ @@ -17,7 +16,6 @@ use crate::{command::Command, log, QlStreamError, RuntimeHandle}; pub struct StreamWriter { shared: Arc, target: CloseTarget, - listener: Option, open: bool, terminal: WriterTerminalState, handle: RuntimeHandle, @@ -45,7 +43,6 @@ impl StreamWriter { Self { shared, target, - listener: None, open: true, terminal: WriterTerminalState::Pending, handle, @@ -73,14 +70,12 @@ impl StreamWriter { self.shared.stream_id, self.target ); - self.listener = None; self.poll_runtime(); return Poll::Ready(Ok(())); } Err(PushError::Closed(chunk)) => { *bytes = chunk; self.open = false; - self.listener = None; return self.poll_terminal(cx); } Err(PushError::Full(chunk)) => { @@ -88,12 +83,29 @@ impl StreamWriter { } } - let active_listener = self - .listener - .get_or_insert_with(|| self.shared.writer.listen()); - match std::pin::Pin::new(active_listener).poll(cx) { - Poll::Ready(()) => self.listener = None, - Poll::Pending => return Poll::Pending, + self.shared.writer.register_waiter(cx.waker()); + + match self.shared.writer.try_write(std::mem::take(bytes)) { + Ok(()) => { + self.shared.writer.unregister_waiter(); + log::trace!( + "byte writer accepted chunk: stream_id={:?} target={:?}", + self.shared.stream_id, + self.target + ); + self.poll_runtime(); + return Poll::Ready(Ok(())); + } + Err(PushError::Closed(chunk)) => { + self.shared.writer.unregister_waiter(); + *bytes = chunk; + self.open = false; + return self.poll_terminal(cx); + } + Err(PushError::Full(chunk)) => { + *bytes = chunk; + return Poll::Pending; + } } } } @@ -113,7 +125,6 @@ impl StreamWriter { self.target ); self.open = false; - self.listener = None; self.shared.writer.request_finish(); self.poll_runtime(); } @@ -147,34 +158,45 @@ impl StreamWriter { } loop { - let state = self.shared.writer.load_state(); - if WriterShared::terminal_ready(state) { - if WriterShared::terminal_ok(state) { - self.terminal = WriterTerminalState::Terminal(Ok(())); - return Poll::Ready(Ok(())); - } + match self.try_poll_terminal_ready() { + Poll::Ready(result) => return Poll::Ready(result), + Poll::Pending => {} + } - match self.shared.writer.pop() { - Ok(Item::Error(error)) => { - self.terminal = WriterTerminalState::Terminal(Err(error.clone())); - return Poll::Ready(Err(error)); - } - Ok(Item::Chunk(_)) => { - panic!("writer terminal phase contained chunk data") - } - Err(PopError::Empty) => {} - Err(PopError::Closed) => panic!("writer endpoint closed unexpectedly"), + self.shared.writer.register_waiter(cx.waker()); + + match self.try_poll_terminal_ready() { + Poll::Ready(result) => { + self.shared.writer.unregister_waiter(); + return Poll::Ready(result); } + Poll::Pending => return Poll::Pending, } + } + } - let active_listener = self - .listener - .get_or_insert_with(|| self.shared.writer.listen()); - match std::pin::Pin::new(active_listener).poll(cx) { - Poll::Ready(()) => self.listener = None, - Poll::Pending => return Poll::Pending, + fn try_poll_terminal_ready(&mut self) -> Poll> { + let state = self.shared.writer.load_state(); + if WriterShared::terminal_ready(state) { + if WriterShared::terminal_ok(state) { + self.terminal = WriterTerminalState::Terminal(Ok(())); + return Poll::Ready(Ok(())); + } + + match self.shared.writer.pop() { + Ok(Item::Error(error)) => { + self.terminal = WriterTerminalState::Terminal(Err(error.clone())); + return Poll::Ready(Err(error)); + } + Ok(Item::Chunk(_)) => { + panic!("writer terminal phase contained chunk data") + } + Err(PopError::Empty) => {} + Err(PopError::Closed) => panic!("writer endpoint closed unexpectedly"), } } + + Poll::Pending } fn close_inner(&mut self, code: StreamCloseCode) { From da87d1cfcf4cff9a2c885cf78e20399e90365057 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 17 Apr 2026 16:31:22 -0400 Subject: [PATCH 273/304] ql-runtime: io cleanup --- ql-runtime/src/driver/mod.rs | 18 +-- ql-runtime/src/driver/state.rs | 38 +++-- ql-runtime/src/driver/test.rs | 16 +- ql-runtime/src/io/{shared.rs => inner.rs} | 173 +++++++--------------- ql-runtime/src/io/mod.rs | 59 +++++--- ql-runtime/src/io/reader.rs | 38 ++--- ql-runtime/src/io/writer.rs | 48 +++--- 7 files changed, 182 insertions(+), 208 deletions(-) rename ql-runtime/src/io/{shared.rs => inner.rs} (81%) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 540e7b99..850ce7ff 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -220,7 +220,7 @@ impl DriverState { }; let stream_id = stream_ops.stream_id(); log::info!("open stream allocated: route_id={route_id} stream_id={stream_id}"); - let stream = io::new_stream( + let (reader, writer, reader_io, writer_io) = io::new_stream( stream_id, CloseTarget::Return, CloseTarget::Origin, @@ -230,12 +230,12 @@ impl DriverState { stream_id, DriverStreamIo::new( true, - Some(OutboundIo::new(stream.writer_io)), - Some(InboundIo::new(stream.reader_io)), + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), ), ); if start - .send(Ok((stream_id, stream.reader, stream.writer))) + .send(Ok((stream_id, reader, writer))) .is_err() { log::warn!("open stream cancelled before delivery: stream_id={stream_id}"); @@ -361,7 +361,7 @@ impl DriverState { return; }; - let stream = io::new_stream( + let (reader, writer, reader_io, writer_io) = io::new_stream( stream_id, CloseTarget::Origin, CloseTarget::Return, @@ -372,8 +372,8 @@ impl DriverState { stream_id, DriverStreamIo::new( false, - Some(OutboundIo::new(stream.writer_io)), - Some(InboundIo::new(stream.reader_io)), + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), ), ); @@ -383,8 +383,8 @@ impl DriverState { platform.handle_inbound(QlStream { stream_id, route_id, - reader: stream.reader, - writer: stream.writer, + reader, + writer, }); } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index 37694c92..c3a1067f 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -5,7 +5,7 @@ use ql_wire::{CloseTarget, StreamId}; use crate::{ command::Command, - io::{PushError, ReaderIo, WriterIo}, + io::{PushError, Rx, Tx}, QlStreamError, }; @@ -65,18 +65,18 @@ impl DriverStreamIo { pub fn outbound_finish(&mut self) { if let Some(outbound) = self.outbound.take() { - outbound.writer.finish(); + outbound.tx.finish(); } } pub fn outbound_fail(&mut self, error: QlStreamError) { if let Some(outbound) = self.outbound.take() { - let _ = outbound.writer.fail(error); + let _ = outbound.tx.fail(error); } } - pub fn outbound_writer_mut(&mut self) -> Option<&mut WriterIo> { - self.outbound.as_mut().map(|outbound| &mut outbound.writer) + pub fn outbound_writer_mut(&mut self) -> Option<&mut OutboundIo> { + self.outbound.as_mut() } pub fn outbound_queue_finish(&mut self) { @@ -101,7 +101,7 @@ impl DriverStreamIo { }; let len = bytes.len(); - match inbound.reader.try_write(bytes) { + match inbound.rx.try_write(bytes) { Ok(()) => InboundWriteResult::Accepted(len), Err(PushError::Full(_)) => InboundWriteResult::Full, Err(PushError::Closed(_)) => { @@ -113,33 +113,43 @@ impl DriverStreamIo { pub fn inbound_finish(&mut self) { if let Some(inbound) = self.inbound.take() { - inbound.reader.finish(); + inbound.rx.finish(); } } pub fn inbound_fail(&mut self, error: QlStreamError) { if let Some(inbound) = self.inbound.take() { - let _ = inbound.reader.fail(error); + let _ = inbound.rx.fail(error); } } } pub struct OutboundIo { - writer: WriterIo, + tx: Tx, + pending: Bytes, finish_pending: bool, } impl OutboundIo { - pub fn new(writer: WriterIo) -> Self { + pub fn new(tx: Tx) -> Self { Self { - writer, + tx, + pending: Bytes::new(), finish_pending: false, } } + + pub fn is_finished(&self) -> bool { + self.pending.is_empty() && self.tx.is_finished() + } + + pub fn try_read(&mut self, max_len: usize) -> Result { + self.tx.try_read(&mut self.pending, max_len) + } } pub struct InboundIo { - reader: ReaderIo, + rx: Rx, } pub enum InboundWriteResult { @@ -149,7 +159,7 @@ pub enum InboundWriteResult { } impl InboundIo { - pub fn new(reader: ReaderIo) -> Self { - Self { reader } + pub fn new(rx: Rx) -> Self { + Self { rx } } } diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 9ca68645..6f3840ac 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -73,7 +73,8 @@ fn new_inbound_io(capacity: usize) -> InboundIo { CloseTarget::Return, RuntimeHandle::new(runtime_tx), ); - InboundIo::new(stream.reader_io) + let (_, _, reader_io, _) = stream; + InboundIo::new(reader_io) } fn new_outbound_io() -> OutboundIo { @@ -84,7 +85,8 @@ fn new_outbound_io() -> OutboundIo { CloseTarget::Origin, RuntimeHandle::new(runtime_tx), ); - OutboundIo::new(stream.writer_io) + let (_, _, _, writer_io) = stream; + OutboundIo::new(writer_io) } #[test] @@ -126,16 +128,16 @@ fn poll_stream_keeps_outbound_pending_after_local_finish_when_inbound_is_closed( let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); let (runtime_tx, _runtime_rx) = async_channel::unbounded(); - let mut stream = io::new_stream( + let (_, mut writer, _, writer_io) = io::new_stream( stream_id, CloseTarget::Return, CloseTarget::Origin, RuntimeHandle::new(runtime_tx), ); - stream.writer.queue_finish(); + writer.queue_finish(); state.streams.insert( stream_id, - DriverStreamIo::new(true, Some(OutboundIo::new(stream.writer_io)), None), + DriverStreamIo::new(true, Some(OutboundIo::new(writer_io)), None), ); state.poll_stream(&mut fsm, stream_id); @@ -150,7 +152,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { let (mut state, mut fsm) = new_driver_state(); let stream_id = StreamId(1u32.into()); let (runtime_tx, _runtime_rx) = async_channel::unbounded(); - let stream = io::new_stream( + let (_, _, _, writer_io) = io::new_stream( stream_id, CloseTarget::Return, CloseTarget::Origin, @@ -159,7 +161,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { state.streams.insert( stream_id, - DriverStreamIo::new(true, Some(OutboundIo::new(stream.writer_io)), None), + DriverStreamIo::new(true, Some(OutboundIo::new(writer_io)), None), ); state.drive_command( diff --git a/ql-runtime/src/io/shared.rs b/ql-runtime/src/io/inner.rs similarity index 81% rename from ql-runtime/src/io/shared.rs rename to ql-runtime/src/io/inner.rs index dfa01136..879af7e5 100644 --- a/ql-runtime/src/io/shared.rs +++ b/ql-runtime/src/io/inner.rs @@ -6,7 +6,7 @@ use ql_wire::StreamId; use super::{ queue::{ForcePushError, PopError, PushError, Single}, - sync::{Arc, AtomicU8, Ordering}, + sync::{AtomicU8, Ordering}, }; use crate::QlStreamError; @@ -16,20 +16,18 @@ const WRITER_FINISH_REQUESTED: u8 = 1 << 0; const WRITER_TERMINAL_READY: u8 = 1 << 1; const WRITER_TERMINAL_OK: u8 = 1 << 2; -pub struct StreamShared { - pub stream_id: StreamId, - pub reader: ReaderShared, - pub writer: WriterShared, +pub(super) fn new(stream_id: StreamId) -> Inner { + Inner { + stream_id, + reader: RxInner::new(), + writer: TxInner::new(), + } } -impl StreamShared { - pub fn new(stream_id: StreamId) -> Arc { - Arc::new(Self { - stream_id, - reader: ReaderShared::new(), - writer: WriterShared::new(), - }) - } +pub(super) struct Inner { + pub(super) stream_id: StreamId, + pub(super) reader: RxInner, + pub(super) writer: TxInner, } pub enum Item { @@ -46,15 +44,13 @@ impl Item { } } -pub struct ReaderShared { +pub struct RxInner { slot: Single, - // Sound because StreamShared creates exactly one StreamReader for this side, - // and that reader is the only task that registers or unregisters wakers. changed: DiatomicWaker, state: AtomicU8, } -impl ReaderShared { +impl RxInner { fn new() -> Self { Self { slot: Single::new(), @@ -135,15 +131,13 @@ impl ReaderShared { } } -pub struct WriterShared { +pub struct TxInner { slot: Single, - // Sound because StreamShared creates exactly one StreamWriter for this side, - // and that writer is the only task that registers or unregisters wakers. changed: DiatomicWaker, state: AtomicU8, } -impl WriterShared { +impl TxInner { fn new() -> Self { Self { slot: Single::new(), @@ -277,98 +271,41 @@ impl WriterShared { // shared state, so register/unregister never run concurrently. unsafe { self.changed.unregister() }; } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct RecvClosed; - -pub struct ReaderIo { - shared: Arc, -} - -impl ReaderIo { - pub fn new(shared: Arc) -> Self { - Self { shared } - } - - pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { - self.shared.reader.try_write(bytes) - } - - pub fn finish(&self) { - self.shared.reader.finish(); - } - - pub fn fail( - &self, - error: QlStreamError, - ) -> Result, ForcePushError> { - self.shared.reader.fail(error) - } -} - -pub struct WriterIo { - shared: Arc, - pending: Bytes, -} - -impl WriterIo { - pub fn new(shared: Arc) -> Self { - Self { - shared, - pending: Bytes::new(), - } - } pub fn is_finished(&self) -> bool { - let state = self.shared.writer.load_state(); - self.pending.is_empty() - && WriterShared::finish_requested(state) - && self.shared.writer.is_empty() + let state = self.load_state(); + TxInner::finish_requested(state) && self.is_empty() } - pub fn try_read(&mut self, max_len: usize) -> Result { - if !self.pending.is_empty() { - let pending = &mut self.pending; - let bytes = if pending.len() <= max_len { + pub fn try_read(&self, pending: &mut Bytes, max_len: usize) -> Result { + if !pending.is_empty() { + return Ok(if pending.len() <= max_len { std::mem::take(pending) } else { pending.split_to(max_len) - }; - return Ok(bytes); + }); } - let state = self.shared.writer.load_state(); - if WriterShared::terminal_ready(state) { - return Err(RecvClosed); + let state = self.load_state(); + if TxInner::terminal_ready(state) { + return Err(()); } - match self.shared.writer.pop() { + match self.pop() { Ok(Item::Chunk(mut bytes)) => { if bytes.len() <= max_len { Ok(bytes) } else { let head = bytes.split_to(max_len); - self.pending = bytes; + *pending = bytes; Ok(head) } } - Ok(Item::Error(_)) => Err(RecvClosed), + Ok(Item::Error(_)) => Err(()), Err(PopError::Empty) => Ok(Bytes::new()), - Err(PopError::Closed) => Err(RecvClosed), + Err(PopError::Closed) => Err(()), } } - - pub fn finish(&self) { - self.shared.writer.finish(); - } - - pub fn fail( - &self, - error: QlStreamError, - ) -> Result, ForcePushError> { - self.shared.writer.fail(error) - } } #[cfg(all(test, loom))] @@ -379,16 +316,16 @@ mod loom_tests { use loom::{model, thread}; use ql_wire::{StreamCloseCode, StreamId}; - use super::{Item, PopError, PushError, ReaderShared, StreamShared, WriterIo, WriterShared}; - use crate::QlStreamError; + use super::*; + use crate::{io::Tx, QlStreamError}; fn check_model(f: impl Fn() + Sync + Send + 'static) { let builder = model::Builder::new(); builder.check(f); } - fn shared() -> super::super::sync::Arc { - StreamShared::new(StreamId(1u32.into())) + fn shared() -> super::super::sync::Arc { + super::super::sync::Arc::new(new(StreamId(1u32.into()))) } #[test] @@ -405,7 +342,7 @@ mod loom_tests { }; finisher.join().unwrap(); - assert!(ReaderShared::is_finished(shared.reader.load_state())); + assert!(RxInner::is_finished(shared.reader.load_state())); shared.reader.unregister_waiter(); }); @@ -430,7 +367,7 @@ mod loom_tests { Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), _ => panic!("expected buffered reader chunk"), } - assert!(ReaderShared::is_finished(shared.reader.load_state())); + assert!(RxInner::is_finished(shared.reader.load_state())); assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); }); } @@ -446,7 +383,7 @@ mod loom_tests { shared.reader.try_write(Bytes::from_static(b"abc")), Err(PushError::Closed(Bytes::from_static(b"abc"))) ); - assert!(ReaderShared::is_finished(shared.reader.load_state())); + assert!(RxInner::is_finished(shared.reader.load_state())); assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); }); } @@ -468,7 +405,7 @@ mod loom_tests { let write_result = writer.join().unwrap(); finisher.join().unwrap(); - assert!(ReaderShared::is_finished(shared.reader.load_state())); + assert!(RxInner::is_finished(shared.reader.load_state())); match write_result { Ok(()) => match shared.reader.pop() { Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), @@ -531,16 +468,17 @@ mod loom_tests { fn writer_is_finished_only_after_drain() { check_model(|| { let shared = shared(); - let mut writer_io = WriterIo::new(shared.clone()); + let tx = Tx(shared.clone()); + let mut pending = Bytes::new(); shared.writer.try_write(Bytes::from_static(b"abc")).unwrap(); shared.writer.request_finish(); - assert!(!writer_io.is_finished()); - assert_eq!(writer_io.try_read(2), Ok(Bytes::from_static(b"ab"))); - assert!(!writer_io.is_finished()); - assert_eq!(writer_io.try_read(8), Ok(Bytes::from_static(b"c"))); - assert!(writer_io.is_finished()); + assert!(!(pending.is_empty() && tx.is_finished())); + assert_eq!(tx.try_read(&mut pending, 2), Ok(Bytes::from_static(b"ab"))); + assert!(!(pending.is_empty() && tx.is_finished())); + assert_eq!(tx.try_read(&mut pending, 8), Ok(Bytes::from_static(b"c"))); + assert!(pending.is_empty() && tx.is_finished()); }); } @@ -548,7 +486,8 @@ mod loom_tests { fn writer_write_races_with_request_finish() { check_model(|| { let shared = shared(); - let mut writer_io = WriterIo::new(shared.clone()); + let tx = Tx(shared.clone()); + let mut pending = Bytes::new(); let writer = { let shared = shared.clone(); @@ -562,15 +501,15 @@ mod loom_tests { let write_result = writer.join().unwrap(); finisher.join().unwrap(); - assert!(WriterShared::finish_requested(shared.writer.load_state())); + assert!(TxInner::finish_requested(shared.writer.load_state())); match write_result { Ok(()) => { - assert_eq!(writer_io.try_read(8), Ok(Bytes::from_static(b"abc"))); - assert!(writer_io.is_finished()); + assert_eq!(tx.try_read(&mut pending, 8), Ok(Bytes::from_static(b"abc"))); + assert!(pending.is_empty() && tx.is_finished()); } Err(PushError::Closed(bytes)) => { assert_eq!(bytes, Bytes::from_static(b"abc")); - assert!(writer_io.is_finished()); + assert!(pending.is_empty() && tx.is_finished()); } Err(PushError::Full(_)) => panic!("empty writer slot must not report full"), } @@ -596,7 +535,7 @@ mod loom_tests { failer.join().unwrap(); - assert!(WriterShared::terminal_ready(shared.writer.load_state())); + assert!(TxInner::terminal_ready(shared.writer.load_state())); shared.writer.unregister_waiter(); match shared.writer.pop() { Ok(Item::Error(QlStreamError::StreamClosed { code })) => { @@ -621,7 +560,7 @@ mod loom_tests { shared.reader.register_waiter(Waker::noop()); shared.reader.finish(); - assert!(ReaderShared::is_finished(shared.reader.load_state())); + assert!(RxInner::is_finished(shared.reader.load_state())); shared.reader.unregister_waiter(); }); } @@ -640,7 +579,7 @@ mod loom_tests { shared.writer.register_waiter(Waker::noop()); shared.writer.finish(); - assert!(WriterShared::terminal_ready(shared.writer.load_state())); + assert!(TxInner::terminal_ready(shared.writer.load_state())); shared.writer.unregister_waiter(); }); } @@ -666,7 +605,7 @@ mod loom_tests { let write_result = writer.join().unwrap(); let fail_result = failer.join().unwrap(); - assert!(WriterShared::terminal_ready(shared.writer.load_state())); + assert!(TxInner::terminal_ready(shared.writer.load_state())); match (&write_result, &fail_result) { (Ok(()), Ok(Some(bytes))) => { assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); @@ -712,13 +651,13 @@ mod loom_tests { finisher.join().unwrap(); let fail_result = failer.join().unwrap(); - assert!(WriterShared::terminal_ready(shared.writer.load_state())); + assert!(TxInner::terminal_ready(shared.writer.load_state())); match fail_result { Err(_) => { - assert!(WriterShared::terminal_ok(shared.writer.load_state())); + assert!(TxInner::terminal_ok(shared.writer.load_state())); } Ok(_) => { - assert!(!WriterShared::terminal_ok(shared.writer.load_state())); + assert!(!TxInner::terminal_ok(shared.writer.load_state())); match shared.writer.pop() { Ok(Item::Error(QlStreamError::StreamClosed { code })) => { assert_eq!(code, StreamCloseCode::CANCELLED); diff --git a/ql-runtime/src/io/mod.rs b/ql-runtime/src/io/mod.rs index 83c1079f..148eef4f 100644 --- a/ql-runtime/src/io/mod.rs +++ b/ql-runtime/src/io/mod.rs @@ -4,21 +4,44 @@ mod shared; mod sync; mod writer; +use std::ops::Deref; + use ql_wire::{CloseTarget, StreamId}; -use self::shared::StreamShared; -pub(crate) use self::{ - queue::PushError, - shared::{ReaderIo, WriterIo}, -}; +pub(crate) use self::queue::PushError; pub use self::{reader::StreamReader, writer::StreamWriter}; use crate::RuntimeHandle; -pub(crate) struct StreamIo { - pub reader: StreamReader, - pub writer: StreamWriter, - pub reader_io: ReaderIo, - pub writer_io: WriterIo, +pub(crate) struct Rx(sync::Arc); + +impl Deref for Rx { + type Target = shared::RxInner; + + fn deref(&self) -> &Self::Target { + &self.0.reader + } +} + +impl Rx { + pub fn stream_id(&self) -> StreamId { + self.0.stream_id + } +} + +pub(crate) struct Tx(sync::Arc); + +impl Deref for Tx { + type Target = shared::TxInner; + + fn deref(&self) -> &Self::Target { + &self.0.writer + } +} + +impl Tx { + pub fn stream_id(&self) -> StreamId { + self.0.stream_id + } } pub(crate) fn new_stream( @@ -26,12 +49,12 @@ pub(crate) fn new_stream( reader_target: CloseTarget, writer_target: CloseTarget, handle: RuntimeHandle, -) -> StreamIo { - let shared = StreamShared::new(stream_id); - StreamIo { - reader: StreamReader::new(shared.clone(), reader_target, handle.clone()), - writer: StreamWriter::new(shared.clone(), writer_target, handle), - reader_io: ReaderIo::new(shared.clone()), - writer_io: WriterIo::new(shared), - } +) -> (StreamReader, StreamWriter, Rx, Tx) { + let shared = sync::Arc::new(shared::new(stream_id)); + ( + StreamReader::new(Rx(shared.clone()), reader_target, handle.clone()), + StreamWriter::new(Tx(shared.clone()), writer_target, handle), + Rx(shared.clone()), + Tx(shared), + ) } diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs index a7bacd3e..377a222d 100644 --- a/ql-runtime/src/io/reader.rs +++ b/ql-runtime/src/io/reader.rs @@ -8,13 +8,13 @@ use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ queue::PopError, - shared::{Item, ReaderShared, StreamShared}, - sync::Arc, + shared::{Item, RxInner}, + Rx, }; use crate::{command::Command, log, QlStreamError, RuntimeHandle}; pub struct StreamReader { - shared: Arc, + rx: Rx, target: CloseTarget, pending: Bytes, terminal: ReaderTerminalState, @@ -31,7 +31,7 @@ unsafe impl Sync for StreamReader {} impl std::fmt::Debug for StreamReader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("InboundByteStream") - .field("stream_id", &self.shared.stream_id) + .field("stream_id", &self.rx.stream_id()) .field("target", &self.target) .field( "terminal", @@ -42,9 +42,9 @@ impl std::fmt::Debug for StreamReader { } impl StreamReader { - pub fn new(shared: Arc, target: CloseTarget, handle: RuntimeHandle) -> Self { + pub(crate) fn new(shared: Rx, target: CloseTarget, handle: RuntimeHandle) -> Self { Self { - shared, + rx: shared, target, pending: Bytes::new(), terminal: ReaderTerminalState::Open, @@ -67,11 +67,11 @@ impl StreamReader { Poll::Pending => {} } - self.shared.reader.register_waiter(cx.waker()); + self.rx.register_waiter(cx.waker()); match self.try_read_ready(max_len) { Poll::Ready(result) => { - self.shared.reader.unregister_waiter(); + self.rx.unregister_waiter(); return Poll::Ready(result); } Poll::Pending => return Poll::Pending, @@ -88,21 +88,21 @@ impl StreamReader { pending.split_to(max_len) }; self.handle.try_send(Command::PollInbound { - stream_id: self.shared.stream_id, + stream_id: self.rx.stream_id(), }); return Poll::Ready(Ok(Some(bytes))); } - match self.shared.reader.pop() { + match self.rx.pop() { Ok(Item::Chunk(mut bytes)) => { log::trace!( "byte reader received chunk: stream_id={:?} target={:?} len={}", - self.shared.stream_id, + self.rx.stream_id(), self.target, bytes.len() ); self.handle.try_send(Command::PollInbound { - stream_id: self.shared.stream_id, + stream_id: self.rx.stream_id(), }); if bytes.len() <= max_len { return Poll::Ready(Ok(Some(bytes))); @@ -114,7 +114,7 @@ impl StreamReader { Ok(Item::Error(error)) => { log::debug!( "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", - self.shared.stream_id, + self.rx.stream_id(), self.target, error ); @@ -122,10 +122,10 @@ impl StreamReader { Poll::Ready(Err(error)) } Err(PopError::Empty) => { - if ReaderShared::is_finished(self.shared.reader.load_state()) { + if RxInner::is_finished(self.rx.load_state()) { log::debug!( "byte reader delivered clean eof: stream_id={:?} target={:?}", - self.shared.stream_id, + self.rx.stream_id(), self.target ); self.terminal = ReaderTerminalState::Delivered; @@ -162,13 +162,13 @@ impl StreamReader { } log::debug!( "byte reader explicit close: stream_id={:?} target={:?} code={:?}", - self.shared.stream_id, + self.rx.stream_id(), self.target, code ); self.terminal = ReaderTerminalState::Delivered; self.handle.try_send(Command::CloseStream { - stream_id: self.shared.stream_id, + stream_id: self.rx.stream_id(), target: self.target, code, }); @@ -182,12 +182,12 @@ impl Drop for StreamReader { } log::debug!( "byte reader drop close: stream_id={:?} target={:?} code={:?}", - self.shared.stream_id, + self.rx.stream_id(), self.target, StreamCloseCode::CANCELLED ); self.handle.try_send(Command::CloseStream { - stream_id: self.shared.stream_id, + stream_id: self.rx.stream_id(), target: self.target, code: StreamCloseCode::CANCELLED, }); diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs index fdb8402c..ce0001d1 100644 --- a/ql-runtime/src/io/writer.rs +++ b/ql-runtime/src/io/writer.rs @@ -8,13 +8,13 @@ use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ queue::{PopError, PushError}, - shared::{Item, StreamShared, WriterShared}, - sync::Arc, + shared::{Item, TxInner}, + Tx, }; use crate::{command::Command, log, QlStreamError, RuntimeHandle}; pub struct StreamWriter { - shared: Arc, + tx: Tx, target: CloseTarget, open: bool, terminal: WriterTerminalState, @@ -31,7 +31,7 @@ unsafe impl Sync for StreamWriter {} impl std::fmt::Debug for StreamWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OutboundByteStream") - .field("stream_id", &self.shared.stream_id) + .field("stream_id", &self.tx.stream_id()) .field("target", &self.target) .field("closed", &!self.open) .finish_non_exhaustive() @@ -39,9 +39,9 @@ impl std::fmt::Debug for StreamWriter { } impl StreamWriter { - pub fn new(shared: Arc, target: CloseTarget, handle: RuntimeHandle) -> Self { + pub(crate) fn new(shared: Tx, target: CloseTarget, handle: RuntimeHandle) -> Self { Self { - shared, + tx: shared, target, open: true, terminal: WriterTerminalState::Pending, @@ -63,11 +63,11 @@ impl StreamWriter { } loop { - match self.shared.writer.try_write(std::mem::take(bytes)) { + match self.tx.try_write(std::mem::take(bytes)) { Ok(()) => { log::trace!( "byte writer accepted chunk: stream_id={:?} target={:?}", - self.shared.stream_id, + self.tx.stream_id(), self.target ); self.poll_runtime(); @@ -83,21 +83,21 @@ impl StreamWriter { } } - self.shared.writer.register_waiter(cx.waker()); + self.tx.register_waiter(cx.waker()); - match self.shared.writer.try_write(std::mem::take(bytes)) { + match self.tx.try_write(std::mem::take(bytes)) { Ok(()) => { - self.shared.writer.unregister_waiter(); + self.tx.unregister_waiter(); log::trace!( "byte writer accepted chunk: stream_id={:?} target={:?}", - self.shared.stream_id, + self.tx.stream_id(), self.target ); self.poll_runtime(); return Poll::Ready(Ok(())); } Err(PushError::Closed(chunk)) => { - self.shared.writer.unregister_waiter(); + self.tx.unregister_waiter(); *bytes = chunk; self.open = false; return self.poll_terminal(cx); @@ -121,11 +121,11 @@ impl StreamWriter { } log::debug!( "byte writer finish: stream_id={:?} target={:?}", - self.shared.stream_id, + self.tx.stream_id(), self.target ); self.open = false; - self.shared.writer.request_finish(); + self.tx.request_finish(); self.poll_runtime(); } @@ -147,7 +147,7 @@ impl StreamWriter { fn poll_runtime(&self) { self.handle.try_send(Command::PollStream { - stream_id: self.shared.stream_id, + stream_id: self.tx.stream_id(), }); } @@ -163,11 +163,11 @@ impl StreamWriter { Poll::Pending => {} } - self.shared.writer.register_waiter(cx.waker()); + self.tx.register_waiter(cx.waker()); match self.try_poll_terminal_ready() { Poll::Ready(result) => { - self.shared.writer.unregister_waiter(); + self.tx.unregister_waiter(); return Poll::Ready(result); } Poll::Pending => return Poll::Pending, @@ -176,14 +176,14 @@ impl StreamWriter { } fn try_poll_terminal_ready(&mut self) -> Poll> { - let state = self.shared.writer.load_state(); - if WriterShared::terminal_ready(state) { - if WriterShared::terminal_ok(state) { + let state = self.tx.load_state(); + if TxInner::terminal_ready(state) { + if TxInner::terminal_ok(state) { self.terminal = WriterTerminalState::Terminal(Ok(())); return Poll::Ready(Ok(())); } - match self.shared.writer.pop() { + match self.tx.pop() { Ok(Item::Error(error)) => { self.terminal = WriterTerminalState::Terminal(Err(error.clone())); return Poll::Ready(Err(error)); @@ -206,12 +206,12 @@ impl StreamWriter { self.open = false; log::debug!( "byte writer close: stream_id={:?} target={:?} code={:?}", - self.shared.stream_id, + self.tx.stream_id(), self.target, code ); self.handle.try_send(Command::CloseStream { - stream_id: self.shared.stream_id, + stream_id: self.tx.stream_id(), target: self.target, code, }); From 26cb130d1778582cb7bb48b0b22baf1f31c3939e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 17 Apr 2026 22:34:11 -0400 Subject: [PATCH 274/304] ql-runtime: reader/writer tests --- ql-runtime/src/io/inner.rs | 18 +++----- ql-runtime/src/io/mod.rs | 12 +++--- ql-runtime/src/io/reader.rs | 50 +++++++++++++++++++++- ql-runtime/src/io/sync.rs | 23 ++++++++++ ql-runtime/src/io/writer.rs | 84 ++++++++++++++++++++++++++++++++++++- 5 files changed, 167 insertions(+), 20 deletions(-) diff --git a/ql-runtime/src/io/inner.rs b/ql-runtime/src/io/inner.rs index 879af7e5..a52f463c 100644 --- a/ql-runtime/src/io/inner.rs +++ b/ql-runtime/src/io/inner.rs @@ -313,20 +313,14 @@ mod loom_tests { use std::task::Waker; use bytes::Bytes; - use loom::{model, thread}; - use ql_wire::{StreamCloseCode, StreamId}; + use loom::thread; + use ql_wire::StreamCloseCode; use super::*; - use crate::{io::Tx, QlStreamError}; - - fn check_model(f: impl Fn() + Sync + Send + 'static) { - let builder = model::Builder::new(); - builder.check(f); - } - - fn shared() -> super::super::sync::Arc { - super::super::sync::Arc::new(new(StreamId(1u32.into()))) - } + use crate::{ + io::{sync::loom::*, Tx}, + QlStreamError, + }; #[test] fn reader_waiter_registration_survives_finish() { diff --git a/ql-runtime/src/io/mod.rs b/ql-runtime/src/io/mod.rs index 148eef4f..6f3f5b9f 100644 --- a/ql-runtime/src/io/mod.rs +++ b/ql-runtime/src/io/mod.rs @@ -1,6 +1,6 @@ +mod inner; mod queue; mod reader; -mod shared; mod sync; mod writer; @@ -12,10 +12,10 @@ pub(crate) use self::queue::PushError; pub use self::{reader::StreamReader, writer::StreamWriter}; use crate::RuntimeHandle; -pub(crate) struct Rx(sync::Arc); +pub(crate) struct Rx(sync::Arc); impl Deref for Rx { - type Target = shared::RxInner; + type Target = inner::RxInner; fn deref(&self) -> &Self::Target { &self.0.reader @@ -28,10 +28,10 @@ impl Rx { } } -pub(crate) struct Tx(sync::Arc); +pub(crate) struct Tx(sync::Arc); impl Deref for Tx { - type Target = shared::TxInner; + type Target = inner::TxInner; fn deref(&self) -> &Self::Target { &self.0.writer @@ -50,7 +50,7 @@ pub(crate) fn new_stream( writer_target: CloseTarget, handle: RuntimeHandle, ) -> (StreamReader, StreamWriter, Rx, Tx) { - let shared = sync::Arc::new(shared::new(stream_id)); + let shared = sync::Arc::new(inner::new(stream_id)); ( StreamReader::new(Rx(shared.clone()), reader_target, handle.clone()), StreamWriter::new(Tx(shared.clone()), writer_target, handle), diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs index 377a222d..d936a737 100644 --- a/ql-runtime/src/io/reader.rs +++ b/ql-runtime/src/io/reader.rs @@ -7,8 +7,8 @@ use bytes::Bytes; use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ + inner::{Item, RxInner}, queue::PopError, - shared::{Item, RxInner}, Rx, }; use crate::{command::Command, log, QlStreamError, RuntimeHandle}; @@ -193,3 +193,51 @@ impl Drop for StreamReader { }); } } + +#[cfg(all(test, loom))] +mod loom_tests { + use std::task::{Context, Poll, Waker}; + + use bytes::Bytes; + use loom::thread; + use ql_wire::CloseTarget; + + use super::*; + use crate::io::sync::loom::*; + + #[test] + fn poll_read_observes_chunk_racing_with_registration() { + check_model(|| { + let inner = shared(); + let mut reader = StreamReader::new( + Rx(inner.clone()), + CloseTarget::Origin, + handle(), + ); + let mut cx = Context::from_waker(Waker::noop()); + + let producer = { + let inner = inner.clone(); + thread::spawn(move || { + inner.reader.try_write(Bytes::from_static(b"abc")).unwrap(); + }) + }; + + let first = reader.poll_read(usize::MAX, &mut cx); + producer.join().unwrap(); + + match first { + Poll::Ready(Ok(Some(bytes))) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + } + Poll::Pending => { + assert_eq!( + reader.poll_read(usize::MAX, &mut cx), + Poll::Ready(Ok(Some(Bytes::from_static(b"abc")))) + ); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } +} diff --git a/ql-runtime/src/io/sync.rs b/ql-runtime/src/io/sync.rs index c058710f..0e4ffb18 100644 --- a/ql-runtime/src/io/sync.rs +++ b/ql-runtime/src/io/sync.rs @@ -75,3 +75,26 @@ mod inner { } pub use inner::*; + +#[cfg(all(test, loom))] +pub(crate) mod loom { + use loom::model; + use ql_wire::StreamId; + + use super::Arc; + use crate::{io::inner::Inner, RuntimeHandle}; + + pub(crate) fn check_model(f: impl Fn() + Sync + Send + 'static) { + let builder = model::Builder::new(); + builder.check(f); + } + + pub(crate) fn shared() -> Arc { + Arc::new(crate::io::inner::new(StreamId(1u32.into()))) + } + + pub(crate) fn handle() -> RuntimeHandle { + let (tx, _rx) = async_channel::unbounded(); + RuntimeHandle::new(tx) + } +} diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs index ce0001d1..c13bc7e1 100644 --- a/ql-runtime/src/io/writer.rs +++ b/ql-runtime/src/io/writer.rs @@ -7,8 +7,8 @@ use bytes::Bytes; use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ + inner::{Item, TxInner}, queue::{PopError, PushError}, - shared::{Item, TxInner}, Tx, }; use crate::{command::Command, log, QlStreamError, RuntimeHandle}; @@ -223,3 +223,85 @@ impl Drop for StreamWriter { self.close_inner(StreamCloseCode::CANCELLED); } } + +#[cfg(all(test, loom))] +mod loom_tests { + use std::task::{Context, Poll, Waker}; + + use bytes::Bytes; + use loom::thread; + use ql_wire::CloseTarget; + + use super::*; + use crate::io::sync::loom::*; + + #[test] + fn poll_write_observes_capacity_racing_with_registration() { + check_model(|| { + let inner = shared(); + inner.writer.try_write(Bytes::from_static(b"abc")).unwrap(); + + let mut writer = StreamWriter::new( + Tx(inner.clone()), + CloseTarget::Origin, + handle(), + ); + let mut bytes = Bytes::from_static(b"xyz"); + let mut cx = Context::from_waker(Waker::noop()); + + let drainer = { + let inner = inner.clone(); + thread::spawn(move || { + assert!(matches!(inner.writer.pop(), Ok(Item::Chunk(_)))); + }) + }; + + let first = writer.poll_write(&mut bytes, &mut cx); + drainer.join().unwrap(); + + match first { + Poll::Ready(Ok(())) => { + assert!(bytes.is_empty()); + } + Poll::Pending => { + assert_eq!(writer.poll_write(&mut bytes, &mut cx), Poll::Ready(Ok(()))); + assert!(bytes.is_empty()); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } + + #[test] + fn poll_finish_observes_terminal_racing_with_registration() { + check_model(|| { + let inner = shared(); + let mut writer = StreamWriter::new( + Tx(inner.clone()), + CloseTarget::Origin, + handle(), + ); + let mut cx = Context::from_waker(Waker::noop()); + + writer.queue_finish(); + + let finisher = { + let inner = inner.clone(); + thread::spawn(move || { + inner.writer.finish(); + }) + }; + + let first = writer.poll_finish(&mut cx); + finisher.join().unwrap(); + + match first { + Poll::Ready(Ok(())) => {} + Poll::Pending => { + assert_eq!(writer.poll_finish(&mut cx), Poll::Ready(Ok(()))); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } +} From b94198738eaabb75cf72f0263f6a384d11189585 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 17 Apr 2026 23:02:25 -0400 Subject: [PATCH 275/304] ql-runtime: reader/writer cleanup --- ql-runtime/src/io/reader.rs | 14 +++++--------- ql-runtime/src/io/writer.rs | 20 ++++++-------------- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs index d936a737..c7a8c77a 100644 --- a/ql-runtime/src/io/reader.rs +++ b/ql-runtime/src/io/reader.rs @@ -30,7 +30,7 @@ unsafe impl Sync for StreamReader {} impl std::fmt::Debug for StreamReader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InboundByteStream") + f.debug_struct("StreamReader") .field("stream_id", &self.rx.stream_id()) .field("target", &self.target) .field( @@ -96,7 +96,7 @@ impl StreamReader { match self.rx.pop() { Ok(Item::Chunk(mut bytes)) => { log::trace!( - "byte reader received chunk: stream_id={:?} target={:?} len={}", + "byte reader received chunk: stream_id={} target={:?} len={}", self.rx.stream_id(), self.target, bytes.len() @@ -113,7 +113,7 @@ impl StreamReader { } Ok(Item::Error(error)) => { log::debug!( - "byte reader delivered terminal error: stream_id={:?} target={:?} error={:?}", + "byte reader delivered terminal error: stream_id={} target={:?} error={:?}", self.rx.stream_id(), self.target, error @@ -124,7 +124,7 @@ impl StreamReader { Err(PopError::Empty) => { if RxInner::is_finished(self.rx.load_state()) { log::debug!( - "byte reader delivered clean eof: stream_id={:?} target={:?}", + "byte reader delivered clean eof: stream_id={} target={:?}", self.rx.stream_id(), self.target ); @@ -209,11 +209,7 @@ mod loom_tests { fn poll_read_observes_chunk_racing_with_registration() { check_model(|| { let inner = shared(); - let mut reader = StreamReader::new( - Rx(inner.clone()), - CloseTarget::Origin, - handle(), - ); + let mut reader = StreamReader::new(Rx(inner.clone()), CloseTarget::Origin, handle()); let mut cx = Context::from_waker(Waker::noop()); let producer = { diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs index c13bc7e1..4b1f31ca 100644 --- a/ql-runtime/src/io/writer.rs +++ b/ql-runtime/src/io/writer.rs @@ -30,7 +30,7 @@ unsafe impl Sync for StreamWriter {} impl std::fmt::Debug for StreamWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("OutboundByteStream") + f.debug_struct("StreamWriter") .field("stream_id", &self.tx.stream_id()) .field("target", &self.target) .field("closed", &!self.open) @@ -66,7 +66,7 @@ impl StreamWriter { match self.tx.try_write(std::mem::take(bytes)) { Ok(()) => { log::trace!( - "byte writer accepted chunk: stream_id={:?} target={:?}", + "byte writer accepted chunk: stream_id={} target={:?}", self.tx.stream_id(), self.target ); @@ -89,7 +89,7 @@ impl StreamWriter { Ok(()) => { self.tx.unregister_waiter(); log::trace!( - "byte writer accepted chunk: stream_id={:?} target={:?}", + "byte writer accepted chunk: stream_id={} target={:?}", self.tx.stream_id(), self.target ); @@ -120,7 +120,7 @@ impl StreamWriter { return; } log::debug!( - "byte writer finish: stream_id={:?} target={:?}", + "byte writer finish: stream_id={} target={:?}", self.tx.stream_id(), self.target ); @@ -241,11 +241,7 @@ mod loom_tests { let inner = shared(); inner.writer.try_write(Bytes::from_static(b"abc")).unwrap(); - let mut writer = StreamWriter::new( - Tx(inner.clone()), - CloseTarget::Origin, - handle(), - ); + let mut writer = StreamWriter::new(Tx(inner.clone()), CloseTarget::Origin, handle()); let mut bytes = Bytes::from_static(b"xyz"); let mut cx = Context::from_waker(Waker::noop()); @@ -276,11 +272,7 @@ mod loom_tests { fn poll_finish_observes_terminal_racing_with_registration() { check_model(|| { let inner = shared(); - let mut writer = StreamWriter::new( - Tx(inner.clone()), - CloseTarget::Origin, - handle(), - ); + let mut writer = StreamWriter::new(Tx(inner.clone()), CloseTarget::Origin, handle()); let mut cx = Context::from_waker(Waker::noop()); writer.queue_finish(); From acf55897414dfb91fa78abb5d835be5a5b9813e8 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 18 Apr 2026 05:34:53 -0400 Subject: [PATCH 276/304] ql-runtime: single atomic for rx/tx state --- ql-runtime/src/driver/mod.rs | 2 +- ql-runtime/src/driver/state.rs | 2 +- ql-runtime/src/io/inner.rs | 391 ++++++++++++++++----------------- ql-runtime/src/io/mod.rs | 16 +- ql-runtime/src/io/queue.rs | 189 ---------------- ql-runtime/src/io/reader.rs | 29 ++- ql-runtime/src/io/slot.rs | 176 +++++++++++++++ ql-runtime/src/io/sync.rs | 17 +- ql-runtime/src/io/writer.rs | 119 +++++----- 9 files changed, 445 insertions(+), 496 deletions(-) delete mode 100644 ql-runtime/src/io/queue.rs create mode 100644 ql-runtime/src/io/slot.rs diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 850ce7ff..ccf5f4bf 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -383,8 +383,8 @@ impl DriverState { platform.handle_inbound(QlStream { stream_id, route_id, - reader, writer, + reader, }); } diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs index c3a1067f..0ff8eca8 100644 --- a/ql-runtime/src/driver/state.rs +++ b/ql-runtime/src/driver/state.rs @@ -119,7 +119,7 @@ impl DriverStreamIo { pub fn inbound_fail(&mut self, error: QlStreamError) { if let Some(inbound) = self.inbound.take() { - let _ = inbound.rx.fail(error); + inbound.rx.fail(error); } } } diff --git a/ql-runtime/src/io/inner.rs b/ql-runtime/src/io/inner.rs index a52f463c..bf4c45ae 100644 --- a/ql-runtime/src/io/inner.rs +++ b/ql-runtime/src/io/inner.rs @@ -1,3 +1,7 @@ +//! per-stream shared io state +//! each lane has one slot and one waker +//! the low slot bits belong to `slot.rs` and the higher bits here carry lane-specific flags + use std::task::Waker; use bytes::Bytes; @@ -5,29 +9,23 @@ use diatomic_waker::DiatomicWaker; use ql_wire::StreamId; use super::{ - queue::{ForcePushError, PopError, PushError, Single}, - sync::{AtomicU8, Ordering}, + slot::{PopError, PushError, Slot}, + sync::Arc, }; use crate::QlStreamError; -const READER_FINISHED: u8 = 1 << 0; - -const WRITER_FINISH_REQUESTED: u8 = 1 << 0; -const WRITER_TERMINAL_READY: u8 = 1 << 1; -const WRITER_TERMINAL_OK: u8 = 1 << 2; - -pub(super) fn new(stream_id: StreamId) -> Inner { - Inner { +pub(super) fn new(stream_id: StreamId) -> Arc { + Arc::new(Inner { stream_id, - reader: RxInner::new(), - writer: TxInner::new(), - } + rx: RxInner::new(), + tx: TxInner::new(), + }) } pub(super) struct Inner { pub(super) stream_id: StreamId, - pub(super) reader: RxInner, - pub(super) writer: TxInner, + pub(super) rx: RxInner, + pub(super) tx: TxInner, } pub enum Item { @@ -35,176 +33,137 @@ pub enum Item { Error(QlStreamError), } -impl Item { - fn into_chunk(self) -> Option { - match self { - Self::Chunk(bytes) => Some(bytes), - Self::Error(_) => None, - } - } -} +#[derive(Debug, PartialEq, Eq)] +pub struct ForcePushError(pub T); +/// reader-lane shared state pub struct RxInner { - slot: Single, + slot: Slot, changed: DiatomicWaker, - state: AtomicU8, } impl RxInner { + const FINISHED: usize = 1 << 2; + fn new() -> Self { Self { - slot: Single::new(), + slot: Slot::new(), changed: DiatomicWaker::new(), - state: AtomicU8::new(0), } } pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { - if Self::is_finished(self.load_state()) { - return Err(PushError::Closed(bytes)); - } - - match self.slot.push(Item::Chunk(bytes)) { - Ok(()) => { - self.changed.notify(); - Ok(()) - } - Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), - Err(PushError::Full(Item::Chunk(bytes))) => Err(PushError::Full(bytes)), - Err(PushError::Closed(Item::Error(_))) | Err(PushError::Full(Item::Error(_))) => { - unreachable!("reader chunk write cannot recover an error payload") - } - } + try_write_chunk(&self.slot, &self.changed, bytes, Self::FINISHED) } + /// marks clean reader eof pub fn finish(&self) { - if self.state.fetch_or(READER_FINISHED, Ordering::Release) & READER_FINISHED == 0 { + if self.slot.fetch_or(Self::FINISHED) & Self::FINISHED == 0 { self.changed.notify(); } } + /// stores a terminal reader error pub fn fail( &self, error: QlStreamError, - ) -> Result, ForcePushError> { - match self.slot.force_push(Item::Error(error)) { - Ok(displaced) => { - self.changed.notify(); - Ok(displaced.and_then(Item::into_chunk)) - } - Err(ForcePushError(Item::Error(error))) => Err(ForcePushError(error)), - Err(ForcePushError(Item::Chunk(_))) => { - unreachable!("reader fail cannot recover a chunk payload") - } - } + ) -> Option { + let displaced = self.slot.force_push(Item::Error(error)); + self.changed.notify(); + displaced_bytes(displaced) } - pub fn load_state(&self) -> u8 { - self.state.load(Ordering::Acquire) + pub fn load_state(&self) -> usize { + self.slot.load_state() } - pub fn is_finished(state: u8) -> bool { - state & READER_FINISHED != 0 + pub fn is_finished(state: usize) -> bool { + state & Self::FINISHED != 0 } pub fn pop(&self) -> Result { - match self.slot.pop() { - Ok(Item::Chunk(bytes)) => { - self.changed.notify(); - Ok(Item::Chunk(bytes)) - } - Ok(Item::Error(error)) => Ok(Item::Error(error)), - Err(error) => Err(error), - } + pop_item(&self.slot, &self.changed) } + /// registers the sole reader-lane waiter pub fn register_waiter(&self, waker: &Waker) { - // Safety: StreamReader is the only reader-side registrar for this + // Safety: StreamReader is the only reader-lane registrar for this // shared state, so register/unregister never run concurrently. unsafe { self.changed.register(waker) }; } + /// unregisters the sole reader-lane waiter pub fn unregister_waiter(&self) { - // Safety: StreamReader is the only reader-side registrar for this + // Safety: StreamReader is the only reader-lane registrar for this // shared state, so register/unregister never run concurrently. unsafe { self.changed.unregister() }; } } +/// writer-lane shared state +/// +/// finish and fail race to establish the terminal result +/// terminal errors are stored in the slot pub struct TxInner { - slot: Single, + slot: Slot, changed: DiatomicWaker, - state: AtomicU8, } impl TxInner { + const FINISH_REQUESTED: usize = 1 << 2; + const TERMINAL_READY: usize = 1 << 3; + const TERMINAL_OK: usize = 1 << 4; + fn new() -> Self { Self { - slot: Single::new(), + slot: Slot::new(), changed: DiatomicWaker::new(), - state: AtomicU8::new(0), } } - pub fn load_state(&self) -> u8 { - self.state.load(Ordering::Acquire) + pub fn load_state(&self) -> usize { + self.slot.load_state() } - pub fn finish_requested(state: u8) -> bool { - state & WRITER_FINISH_REQUESTED != 0 + pub fn finish_requested(state: usize) -> bool { + state & Self::FINISH_REQUESTED != 0 } - pub fn terminal_ready(state: u8) -> bool { - state & WRITER_TERMINAL_READY != 0 + pub fn terminal_ready(state: usize) -> bool { + state & Self::TERMINAL_READY != 0 } - pub fn terminal_ok(state: u8) -> bool { - state & WRITER_TERMINAL_OK != 0 + pub fn terminal_ok(state: usize) -> bool { + state & Self::TERMINAL_OK != 0 } pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { - let state = self.load_state(); - if Self::terminal_ready(state) || Self::finish_requested(state) { - return Err(PushError::Closed(bytes)); - } - - match self.slot.push(Item::Chunk(bytes)) { - Ok(()) => { - self.changed.notify(); - Ok(()) - } - Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), - Err(PushError::Full(Item::Chunk(bytes))) => Err(PushError::Full(bytes)), - Err(PushError::Closed(Item::Error(_))) | Err(PushError::Full(Item::Error(_))) => { - unreachable!("writer chunk write cannot recover an error payload") - } - } + try_write_chunk( + &self.slot, + &self.changed, + bytes, + Self::FINISH_REQUESTED | Self::TERMINAL_READY, + ) } + /// prevents future chunk writes once observed pub fn request_finish(&self) { - if self - .state - .fetch_or(WRITER_FINISH_REQUESTED, Ordering::Release) - & WRITER_FINISH_REQUESTED - == 0 - { + if self.slot.fetch_or(Self::FINISH_REQUESTED) & Self::FINISH_REQUESTED == 0 { self.changed.notify(); } } + /// commits a clean writer eof pub fn finish(&self) { - let mut state = self.state.load(Ordering::Acquire); + let mut state = self.slot.load_state(); loop { if Self::terminal_ready(state) { return; } - let new_state = state | WRITER_TERMINAL_READY | WRITER_TERMINAL_OK; - match self - .state - .compare_exchange(state, new_state, Ordering::AcqRel, Ordering::Acquire) - { - Ok(_) => { + let new_state = state | Self::TERMINAL_READY | Self::TERMINAL_OK; + match self.slot.compare_exchange(state, new_state) { + Ok(()) => { self.changed.notify(); return; } @@ -213,68 +172,52 @@ impl TxInner { } } + /// stores a terminal writer error + /// futures calls will have no effect pub fn fail( &self, error: QlStreamError, ) -> Result, ForcePushError> { - let mut state = self.state.load(Ordering::Acquire); + let mut state = self.slot.load_state(); loop { if Self::terminal_ready(state) { return Err(ForcePushError(error)); } - let new_state = state | WRITER_TERMINAL_READY; - match self - .state - .compare_exchange(state, new_state, Ordering::AcqRel, Ordering::Acquire) - { - Ok(_) => break, + let new_state = state | Self::TERMINAL_READY; + match self.slot.compare_exchange(state, new_state) { + Ok(()) => break, Err(actual) => state = actual, } } - match self.slot.force_push(Item::Error(error)) { - Ok(displaced) => { - self.changed.notify(); - Ok(displaced.and_then(Item::into_chunk)) - } - Err(ForcePushError(Item::Error(error))) => Err(ForcePushError(error)), - Err(ForcePushError(Item::Chunk(_))) => { - unreachable!("writer fail cannot recover a chunk payload") - } - } - } - - pub fn is_empty(&self) -> bool { - self.slot.is_empty() + let displaced = self.slot.force_push(Item::Error(error)); + self.changed.notify(); + Ok(displaced_bytes(displaced)) } pub fn pop(&self) -> Result { - match self.slot.pop() { - Ok(Item::Chunk(bytes)) => { - self.changed.notify(); - Ok(Item::Chunk(bytes)) - } - Ok(Item::Error(error)) => Ok(Item::Error(error)), - Err(error) => Err(error), - } + pop_item(&self.slot, &self.changed) } + /// registers the sole writer-lane waiter pub fn register_waiter(&self, waker: &Waker) { - // Safety: StreamWriter is the only writer-side registrar for this + // Safety: StreamWriter is the only writer-lane registrar for this // shared state, so register/unregister never run concurrently. unsafe { self.changed.register(waker) }; } + /// unregisters the sole writer-lane waiter pub fn unregister_waiter(&self) { - // Safety: StreamWriter is the only writer-side registrar for this + // Safety: StreamWriter is the only writer-lane registrar for this // shared state, so register/unregister never run concurrently. unsafe { self.changed.unregister() }; } + /// returns true once finish was requested and buffered data is drained pub fn is_finished(&self) -> bool { let state = self.load_state(); - TxInner::finish_requested(state) && self.is_empty() + Self::finish_requested(state) && Slot::::is_empty_state(state) } pub fn try_read(&self, pending: &mut Bytes, max_len: usize) -> Result { @@ -287,7 +230,7 @@ impl TxInner { } let state = self.load_state(); - if TxInner::terminal_ready(state) { + if Self::terminal_ready(state) { return Err(()); } @@ -302,9 +245,47 @@ impl TxInner { } } Ok(Item::Error(_)) => Err(()), - Err(PopError::Empty) => Ok(Bytes::new()), - Err(PopError::Closed) => Err(()), + Err(PopError) => Ok(Bytes::new()), + } + } +} + +#[inline] +fn try_write_chunk( + slot: &Slot, + changed: &DiatomicWaker, + bytes: Bytes, + closed_mask: usize, +) -> Result<(), PushError> { + match slot.try_push(Item::Chunk(bytes), closed_mask) { + Ok(()) => { + changed.notify(); + Ok(()) + } + Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), + Err(PushError::Full(Item::Chunk(bytes))) => Err(PushError::Full(bytes)), + Err(PushError::Closed(Item::Error(_)) | PushError::Full(Item::Error(_))) => { + unreachable!("chunk write cannot recover an error payload") + } + } +} + +#[inline] +fn displaced_bytes(displaced: Option) -> Option { + match displaced { + Some(Item::Chunk(bytes)) => Some(bytes), + Some(Item::Error(_)) | None => None, + } +} + +#[inline] +fn pop_item(slot: &Slot, changed: &DiatomicWaker) -> Result { + match slot.pop() { + item @ Ok(Item::Chunk(_)) => { + changed.notify(); + item } + item @ (Ok(Item::Error(_)) | Err(_)) => item, } } @@ -326,19 +307,19 @@ mod loom_tests { fn reader_waiter_registration_survives_finish() { check_model(|| { let shared = shared(); - shared.reader.register_waiter(Waker::noop()); + shared.rx.register_waiter(Waker::noop()); let finisher = { let shared = shared.clone(); thread::spawn(move || { - shared.reader.finish(); + shared.rx.finish(); }) }; finisher.join().unwrap(); - assert!(RxInner::is_finished(shared.reader.load_state())); + assert!(RxInner::is_finished(shared.rx.load_state())); - shared.reader.unregister_waiter(); + shared.rx.unregister_waiter(); }); } @@ -350,19 +331,19 @@ mod loom_tests { let producer = { let shared = shared.clone(); thread::spawn(move || { - shared.reader.try_write(Bytes::from_static(b"abc")).unwrap(); - shared.reader.finish(); + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.rx.finish(); }) }; producer.join().unwrap(); - match shared.reader.pop() { + match shared.rx.pop() { Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), _ => panic!("expected buffered reader chunk"), } - assert!(RxInner::is_finished(shared.reader.load_state())); - assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + assert!(RxInner::is_finished(shared.rx.load_state())); + assert!(matches!(shared.rx.pop(), Err(PopError))); }); } @@ -371,14 +352,14 @@ mod loom_tests { check_model(|| { let shared = shared(); - shared.reader.finish(); + shared.rx.finish(); assert_eq!( - shared.reader.try_write(Bytes::from_static(b"abc")), + shared.rx.try_write(Bytes::from_static(b"abc")), Err(PushError::Closed(Bytes::from_static(b"abc"))) ); - assert!(RxInner::is_finished(shared.reader.load_state())); - assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + assert!(RxInner::is_finished(shared.rx.load_state())); + assert!(matches!(shared.rx.pop(), Err(PopError))); }); } @@ -389,30 +370,30 @@ mod loom_tests { let writer = { let shared = shared.clone(); - thread::spawn(move || shared.reader.try_write(Bytes::from_static(b"abc"))) + thread::spawn(move || shared.rx.try_write(Bytes::from_static(b"abc"))) }; let finisher = { let shared = shared.clone(); - thread::spawn(move || shared.reader.finish()) + thread::spawn(move || shared.rx.finish()) }; let write_result = writer.join().unwrap(); finisher.join().unwrap(); - assert!(RxInner::is_finished(shared.reader.load_state())); + assert!(RxInner::is_finished(shared.rx.load_state())); match write_result { - Ok(()) => match shared.reader.pop() { + Ok(()) => match shared.rx.pop() { Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), _ => panic!("expected buffered reader chunk"), }, Err(PushError::Closed(bytes)) => { assert_eq!(bytes, Bytes::from_static(b"abc")); - assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + assert!(matches!(shared.rx.pop(), Err(PopError))); return; } Err(PushError::Full(_)) => panic!("empty reader slot must not report full"), } - assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + assert!(matches!(shared.rx.pop(), Err(PopError))); }); } @@ -420,16 +401,16 @@ mod loom_tests { fn reader_fail_racing_with_pop_preserves_terminal_outcome() { check_model(|| { let shared = shared(); - shared.reader.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); let popper = { let shared = shared.clone(); - thread::spawn(move || shared.reader.pop()) + thread::spawn(move || shared.rx.pop()) }; let failer = { let shared = shared.clone(); thread::spawn(move || { - shared.reader.fail(QlStreamError::StreamClosed { + shared.rx.fail(QlStreamError::StreamClosed { code: StreamCloseCode::CANCELLED, }) }) @@ -439,19 +420,19 @@ mod loom_tests { let fail_result = failer.join().unwrap(); match (pop_result, fail_result) { - (Ok(Item::Chunk(bytes)), Ok(None)) => { + (Ok(Item::Chunk(bytes)), None) => { assert_eq!(bytes, Bytes::from_static(b"abc")); - match shared.reader.pop() { + match shared.rx.pop() { Ok(Item::Error(QlStreamError::StreamClosed { code })) => { assert_eq!(code, StreamCloseCode::CANCELLED); } _ => panic!("expected terminal reader error"), } } - (Ok(Item::Error(QlStreamError::StreamClosed { code })), Ok(Some(bytes))) => { + (Ok(Item::Error(QlStreamError::StreamClosed { code })), Some(bytes)) => { assert_eq!(code, StreamCloseCode::CANCELLED); assert_eq!(bytes, Bytes::from_static(b"abc")); - assert!(matches!(shared.reader.pop(), Err(PopError::Empty))); + assert!(matches!(shared.rx.pop(), Err(PopError))); } _ => panic!("unexpected reader fail/pop race outcome"), } @@ -465,8 +446,8 @@ mod loom_tests { let tx = Tx(shared.clone()); let mut pending = Bytes::new(); - shared.writer.try_write(Bytes::from_static(b"abc")).unwrap(); - shared.writer.request_finish(); + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.tx.request_finish(); assert!(!(pending.is_empty() && tx.is_finished())); assert_eq!(tx.try_read(&mut pending, 2), Ok(Bytes::from_static(b"ab"))); @@ -485,17 +466,17 @@ mod loom_tests { let writer = { let shared = shared.clone(); - thread::spawn(move || shared.writer.try_write(Bytes::from_static(b"abc"))) + thread::spawn(move || shared.tx.try_write(Bytes::from_static(b"abc"))) }; let finisher = { let shared = shared.clone(); - thread::spawn(move || shared.writer.request_finish()) + thread::spawn(move || shared.tx.request_finish()) }; let write_result = writer.join().unwrap(); finisher.join().unwrap(); - assert!(TxInner::finish_requested(shared.writer.load_state())); + assert!(TxInner::finish_requested(shared.tx.load_state())); match write_result { Ok(()) => { assert_eq!(tx.try_read(&mut pending, 8), Ok(Bytes::from_static(b"abc"))); @@ -514,13 +495,13 @@ mod loom_tests { fn writer_fail_overwrites_buffered_chunk_and_keeps_terminal_state_observable() { check_model(|| { let shared = shared(); - shared.writer.try_write(Bytes::from_static(b"abc")).unwrap(); - shared.writer.register_waiter(Waker::noop()); + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.tx.register_waiter(Waker::noop()); let failer = { let shared = shared.clone(); thread::spawn(move || { - let displaced = shared.writer.fail(QlStreamError::StreamClosed { + let displaced = shared.tx.fail(QlStreamError::StreamClosed { code: StreamCloseCode::CANCELLED, }); assert_eq!(displaced.unwrap(), Some(Bytes::from_static(b"abc"))); @@ -529,9 +510,9 @@ mod loom_tests { failer.join().unwrap(); - assert!(TxInner::terminal_ready(shared.writer.load_state())); - shared.writer.unregister_waiter(); - match shared.writer.pop() { + assert!(TxInner::terminal_ready(shared.tx.load_state())); + shared.tx.unregister_waiter(); + match shared.tx.pop() { Ok(Item::Error(QlStreamError::StreamClosed { code })) => { assert_eq!(code, StreamCloseCode::CANCELLED); } @@ -545,17 +526,17 @@ mod loom_tests { check_model(|| { let shared = shared(); - shared.reader.register_waiter(Waker::noop()); - shared.reader.try_write(Bytes::from_static(b"abc")).unwrap(); - match shared.reader.pop() { + shared.rx.register_waiter(Waker::noop()); + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + match shared.rx.pop() { Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), _ => panic!("expected buffered reader chunk"), } - shared.reader.register_waiter(Waker::noop()); - shared.reader.finish(); - assert!(RxInner::is_finished(shared.reader.load_state())); - shared.reader.unregister_waiter(); + shared.rx.register_waiter(Waker::noop()); + shared.rx.finish(); + assert!(RxInner::is_finished(shared.rx.load_state())); + shared.rx.unregister_waiter(); }); } @@ -564,17 +545,17 @@ mod loom_tests { check_model(|| { let shared = shared(); - shared.writer.register_waiter(Waker::noop()); - shared.writer.try_write(Bytes::from_static(b"abc")).unwrap(); - match shared.writer.pop() { + shared.tx.register_waiter(Waker::noop()); + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + match shared.tx.pop() { Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), _ => panic!("expected buffered writer chunk"), } - shared.writer.register_waiter(Waker::noop()); - shared.writer.finish(); - assert!(TxInner::terminal_ready(shared.writer.load_state())); - shared.writer.unregister_waiter(); + shared.tx.register_waiter(Waker::noop()); + shared.tx.finish(); + assert!(TxInner::terminal_ready(shared.tx.load_state())); + shared.tx.unregister_waiter(); }); } @@ -585,12 +566,12 @@ mod loom_tests { let writer = { let shared = shared.clone(); - thread::spawn(move || shared.writer.try_write(Bytes::from_static(b"abc"))) + thread::spawn(move || shared.tx.try_write(Bytes::from_static(b"abc"))) }; let failer = { let shared = shared.clone(); thread::spawn(move || { - shared.writer.fail(QlStreamError::StreamClosed { + shared.tx.fail(QlStreamError::StreamClosed { code: StreamCloseCode::CANCELLED, }) }) @@ -599,7 +580,7 @@ mod loom_tests { let write_result = writer.join().unwrap(); let fail_result = failer.join().unwrap(); - assert!(TxInner::terminal_ready(shared.writer.load_state())); + assert!(TxInner::terminal_ready(shared.tx.load_state())); match (&write_result, &fail_result) { (Ok(()), Ok(Some(bytes))) => { assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); @@ -615,7 +596,7 @@ mod loom_tests { ), } - match shared.writer.pop() { + match shared.tx.pop() { Ok(Item::Error(QlStreamError::StreamClosed { code })) => { assert_eq!(code, StreamCloseCode::CANCELLED); } @@ -631,12 +612,12 @@ mod loom_tests { let finisher = { let shared = shared.clone(); - thread::spawn(move || shared.writer.finish()) + thread::spawn(move || shared.tx.finish()) }; let failer = { let shared = shared.clone(); thread::spawn(move || { - shared.writer.fail(QlStreamError::StreamClosed { + shared.tx.fail(QlStreamError::StreamClosed { code: StreamCloseCode::CANCELLED, }) }) @@ -645,14 +626,14 @@ mod loom_tests { finisher.join().unwrap(); let fail_result = failer.join().unwrap(); - assert!(TxInner::terminal_ready(shared.writer.load_state())); + assert!(TxInner::terminal_ready(shared.tx.load_state())); match fail_result { Err(_) => { - assert!(TxInner::terminal_ok(shared.writer.load_state())); + assert!(TxInner::terminal_ok(shared.tx.load_state())); } Ok(_) => { - assert!(!TxInner::terminal_ok(shared.writer.load_state())); - match shared.writer.pop() { + assert!(!TxInner::terminal_ok(shared.tx.load_state())); + match shared.tx.pop() { Ok(Item::Error(QlStreamError::StreamClosed { code })) => { assert_eq!(code, StreamCloseCode::CANCELLED); } diff --git a/ql-runtime/src/io/mod.rs b/ql-runtime/src/io/mod.rs index 6f3f5b9f..8fa5cff4 100644 --- a/ql-runtime/src/io/mod.rs +++ b/ql-runtime/src/io/mod.rs @@ -1,6 +1,6 @@ mod inner; -mod queue; mod reader; +mod slot; mod sync; mod writer; @@ -8,17 +8,17 @@ use std::ops::Deref; use ql_wire::{CloseTarget, StreamId}; -pub(crate) use self::queue::PushError; +pub use self::slot::PushError; pub use self::{reader::StreamReader, writer::StreamWriter}; use crate::RuntimeHandle; -pub(crate) struct Rx(sync::Arc); +pub struct Rx(sync::Arc); impl Deref for Rx { type Target = inner::RxInner; fn deref(&self) -> &Self::Target { - &self.0.reader + &self.0.rx } } @@ -28,13 +28,13 @@ impl Rx { } } -pub(crate) struct Tx(sync::Arc); +pub struct Tx(sync::Arc); impl Deref for Tx { type Target = inner::TxInner; fn deref(&self) -> &Self::Target { - &self.0.writer + &self.0.tx } } @@ -44,13 +44,13 @@ impl Tx { } } -pub(crate) fn new_stream( +pub fn new_stream( stream_id: StreamId, reader_target: CloseTarget, writer_target: CloseTarget, handle: RuntimeHandle, ) -> (StreamReader, StreamWriter, Rx, Tx) { - let shared = sync::Arc::new(inner::new(stream_id)); + let shared = inner::new(stream_id); ( StreamReader::new(Rx(shared.clone()), reader_target, handle.clone()), StreamWriter::new(Tx(shared.clone()), writer_target, handle), diff --git a/ql-runtime/src/io/queue.rs b/ql-runtime/src/io/queue.rs deleted file mode 100644 index e254b5ed..00000000 --- a/ql-runtime/src/io/queue.rs +++ /dev/null @@ -1,189 +0,0 @@ -//! local single-slot queue for stream io -//! copied from `concurrent_queue::single::Single` in `concurrent-queue` - -use core::mem::MaybeUninit; - -#[allow(clippy::wildcard_imports)] -use super::sync::*; - -const LOCKED: usize = 1 << 0; -const PUSHED: usize = 1 << 1; -const CLOSED: usize = 1 << 2; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum PopError { - Empty, - Closed, -} - -#[derive(Debug, PartialEq, Eq)] -pub enum PushError { - Full(T), - Closed(T), -} - -#[derive(Debug, PartialEq, Eq)] -pub struct ForcePushError(pub T); - -/// A single-element queue. -pub struct Single { - state: AtomicUsize, - slot: UnsafeCell>, -} - -#[allow(clippy::non_send_fields_in_send_ty)] -unsafe impl Send for Single {} -unsafe impl Sync for Single {} - -impl Single { - /// Creates a new single-element queue. - pub fn new() -> Self { - Self { - state: AtomicUsize::new(0), - slot: UnsafeCell::new(MaybeUninit::uninit()), - } - } - - /// Attempts to push an item into the queue. - pub fn push(&self, value: T) -> Result<(), PushError> { - // Lock and fill the slot. - let state = self - .state - .compare_exchange(0, LOCKED | PUSHED, Ordering::SeqCst, Ordering::SeqCst) - .unwrap_or_else(|x| x); - - if state == 0 { - // Write the value and unlock. - self.slot.with_mut(|slot| unsafe { - slot.write(MaybeUninit::new(value)); - }); - self.state.fetch_and(!LOCKED, Ordering::Release); - Ok(()) - } else if state & CLOSED != 0 { - Err(PushError::Closed(value)) - } else { - Err(PushError::Full(value)) - } - } - - /// Attempts to push an item into the queue, displacing another if necessary. - pub fn force_push(&self, value: T) -> Result, ForcePushError> { - // Attempt to lock the slot. - let mut state = 0; - - loop { - // Lock the slot. - let prev = self - .state - .compare_exchange(state, LOCKED | PUSHED, Ordering::SeqCst, Ordering::SeqCst) - .unwrap_or_else(|x| x); - - if prev & CLOSED != 0 { - return Err(ForcePushError(value)); - } - - if prev == state { - // If the value was pushed, swap out the value. - let prev_value = if prev & PUSHED == 0 { - // SAFETY: write is safe because we have locked the state. - self.slot.with_mut(|slot| unsafe { - slot.write(MaybeUninit::new(value)); - }); - None - } else { - // SAFETY: replace is safe because we have locked the state, and - // assume_init is safe because we have checked that the value was pushed. - self.slot.with_mut(move |slot| unsafe { - Some(std::ptr::replace(slot, MaybeUninit::new(value)).assume_init()) - }) - }; - - if let Some(prev_value) = prev_value { - // We can unlock the slot now. - self.state.fetch_and(!LOCKED, Ordering::Release); - // Return the old value. - return Ok(Some(prev_value)); - } - - // We can unlock the slot now. - self.state.fetch_and(!LOCKED, Ordering::Release); - return Ok(None); - } - - // Try to go for the current (pushed) state. - if prev & LOCKED == 0 { - state = prev; - } else { - // State is locked. - busy_wait(); - state = prev & !LOCKED; - } - } - } - - /// Attempts to pop an item from the queue. - pub fn pop(&self) -> Result { - let mut state = PUSHED; - loop { - // Lock and empty the slot. - let prev = self - .state - .compare_exchange( - state, - (state | LOCKED) & !PUSHED, - Ordering::SeqCst, - Ordering::SeqCst, - ) - .unwrap_or_else(|x| x); - - if prev == state { - // Read the value and unlock. - let value = self - .slot - .with_mut(|slot| unsafe { slot.read().assume_init() }); - self.state.fetch_and(!LOCKED, Ordering::Release); - return Ok(value); - } - - if prev & PUSHED == 0 { - return if prev & CLOSED == 0 { - Err(PopError::Empty) - } else { - Err(PopError::Closed) - }; - } - - if prev & LOCKED == 0 { - state = prev; - } else { - busy_wait(); - state = prev & !LOCKED; - } - } - } - - /// Returns the number of items in the queue. - pub fn len(&self) -> usize { - usize::from(self.state.load(Ordering::SeqCst) & PUSHED != 0) - } - - /// Returns `true` if the queue is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl Drop for Single { - fn drop(&mut self) { - // Drop the value in the slot. - let Self { state, slot } = self; - state.with_mut(|state| { - if *state & PUSHED != 0 { - slot.with_mut(|slot| unsafe { - let value = &mut *slot; - value.as_mut_ptr().drop_in_place(); - }); - } - }); - } -} diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs index c7a8c77a..8c40ccd3 100644 --- a/ql-runtime/src/io/reader.rs +++ b/ql-runtime/src/io/reader.rs @@ -8,7 +8,7 @@ use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ inner::{Item, RxInner}, - queue::PopError, + slot::PopError, Rx, }; use crate::{command::Command, log, QlStreamError, RuntimeHandle}; @@ -61,21 +61,19 @@ impl StreamReader { return Poll::Ready(Ok(None)); } - loop { - match self.try_read_ready(max_len) { - Poll::Ready(result) => return Poll::Ready(result), - Poll::Pending => {} - } + match self.try_read_ready(max_len) { + Poll::Ready(result) => return Poll::Ready(result), + Poll::Pending => {} + } - self.rx.register_waiter(cx.waker()); + self.rx.register_waiter(cx.waker()); - match self.try_read_ready(max_len) { - Poll::Ready(result) => { - self.rx.unregister_waiter(); - return Poll::Ready(result); - } - Poll::Pending => return Poll::Pending, + match self.try_read_ready(max_len) { + Poll::Ready(result) => { + self.rx.unregister_waiter(); + Poll::Ready(result) } + Poll::Pending => Poll::Pending, } } @@ -121,7 +119,7 @@ impl StreamReader { self.terminal = ReaderTerminalState::Delivered; Poll::Ready(Err(error)) } - Err(PopError::Empty) => { + Err(PopError) => { if RxInner::is_finished(self.rx.load_state()) { log::debug!( "byte reader delivered clean eof: stream_id={} target={:?}", @@ -133,7 +131,6 @@ impl StreamReader { } Poll::Pending } - Err(PopError::Closed) => panic!("reader endpoint closed unexpectedly"), } } @@ -215,7 +212,7 @@ mod loom_tests { let producer = { let inner = inner.clone(); thread::spawn(move || { - inner.reader.try_write(Bytes::from_static(b"abc")).unwrap(); + inner.rx.try_write(Bytes::from_static(b"abc")).unwrap(); }) }; diff --git a/ql-runtime/src/io/slot.rs b/ql-runtime/src/io/slot.rs new file mode 100644 index 00000000..e73180db --- /dev/null +++ b/ql-runtime/src/io/slot.rs @@ -0,0 +1,176 @@ +//! local single-slot queue for stream io +//! copied from `concurrent_queue::single::Single` in `concurrent-queue` + +use core::mem::MaybeUninit; + +#[allow(clippy::wildcard_imports)] +use super::sync::*; + +const LOCKED: usize = 1 << 0; +const PUSHED: usize = 1 << 1; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PopError; + +#[derive(Debug, PartialEq, Eq)] +pub enum PushError { + Full(T), + Closed(T), +} + +/// A single-element queue. +pub struct Slot { + state: AtomicUsize, + value: UnsafeCell>, +} + +unsafe impl Send for Slot {} +unsafe impl Sync for Slot {} + +impl Slot { + /// Creates a new single-element queue. + pub fn new() -> Self { + Self { + state: AtomicUsize::new(0), + value: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + #[inline] + pub fn load_state(&self) -> usize { + self.state.load(Ordering::Acquire) + } + + #[inline] + pub fn fetch_or(&self, bits: usize) -> usize { + self.state.fetch_or(bits, Ordering::Release) + } + + #[inline] + pub fn compare_exchange(&self, current: usize, new: usize) -> Result<(), usize> { + self.state + .compare_exchange(current, new, Ordering::AcqRel, Ordering::Acquire) + .map(|_| ()) + } + + /// Attempts to push an item into the queue. + pub fn try_push(&self, value: T, closed_mask: usize) -> Result<(), PushError> { + let mut state = self.load_state(); + loop { + if state & closed_mask != 0 { + return Err(PushError::Closed(value)); + } + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + if state & PUSHED != 0 { + return Err(PushError::Full(value)); + } + + // Lock and fill the slot. + let new_state = state | LOCKED | PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // Write the value and unlock. + self.value.with_mut(|slot| unsafe { + slot.write(MaybeUninit::new(value)); + }); + self.state.fetch_and(!LOCKED, Ordering::Release); + return Ok(()); + } + Err(actual) => state = actual, + } + } + } + + /// Attempts to push an item into the queue, displacing another if necessary. + pub fn force_push(&self, value: T) -> Option { + // Attempt to lock the slot. + let mut state = self.load_state(); + + loop { + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + + // Lock the slot. + let new_state = state | LOCKED | PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // If the value was pushed, swap out the value. + let displaced = if state & PUSHED == 0 { + // SAFETY: write is safe because we have locked the state. + self.value.with_mut(|slot| unsafe { + slot.write(MaybeUninit::new(value)); + }); + None + } else { + // SAFETY: replace is safe because we have locked the state, and + // assume_init is safe because we have checked that the value was pushed. + self.value.with_mut(move |slot| unsafe { + Some(std::ptr::replace(slot, MaybeUninit::new(value)).assume_init()) + }) + }; + + // We can unlock the slot now. + self.state.fetch_and(!LOCKED, Ordering::Release); + return displaced; + } + Err(actual) => state = actual, + } + } + } + + /// Attempts to pop an item from the queue. + pub fn pop(&self) -> Result { + let mut state = PUSHED; + loop { + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + if state & PUSHED == 0 { + return Err(PopError); + } + + // Lock and empty the slot. + let new_state = (state | LOCKED) & !PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // Read the value and unlock. + let value = self + .value + .with_mut(|slot| unsafe { slot.read().assume_init() }); + self.state.fetch_and(!LOCKED, Ordering::Release); + return Ok(value); + } + Err(actual) => state = actual, + } + } + } + + #[inline] + pub fn is_empty_state(state: usize) -> bool { + state & PUSHED == 0 + } + +} + +impl Drop for Slot { + fn drop(&mut self) { + // Drop the value in the slot. + self.state.with_mut(|state| { + if *state & PUSHED != 0 { + self.value.with_mut(|slot| unsafe { + let value = &mut *slot; + value.as_mut_ptr().drop_in_place(); + }); + } + }); + } +} diff --git a/ql-runtime/src/io/sync.rs b/ql-runtime/src/io/sync.rs index 0e4ffb18..c5034076 100644 --- a/ql-runtime/src/io/sync.rs +++ b/ql-runtime/src/io/sync.rs @@ -3,7 +3,7 @@ mod inner { pub use std::{ cell::UnsafeCell, sync::{ - atomic::{AtomicU8, AtomicUsize, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, }, }; @@ -49,17 +49,6 @@ mod inner { f(self.get_mut()) } } - - impl AtomicExt for AtomicU8 { - type Value = u8; - - fn with_mut(&mut self, f: F) -> R - where - F: FnOnce(&mut Self::Value) -> R, - { - f(self.get_mut()) - } - } } #[cfg(all(test, loom))] @@ -67,7 +56,7 @@ mod inner { pub use loom::{ cell::UnsafeCell, sync::{ - atomic::{AtomicU8, AtomicUsize, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, }, thread::yield_now as busy_wait, @@ -90,7 +79,7 @@ pub(crate) mod loom { } pub(crate) fn shared() -> Arc { - Arc::new(crate::io::inner::new(StreamId(1u32.into()))) + crate::io::inner::new(StreamId(1u32.into())) } pub(crate) fn handle() -> RuntimeHandle { diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs index 4b1f31ca..cfad3196 100644 --- a/ql-runtime/src/io/writer.rs +++ b/ql-runtime/src/io/writer.rs @@ -8,8 +8,8 @@ use ql_wire::{CloseTarget, StreamCloseCode}; use super::{ inner::{Item, TxInner}, - queue::{PopError, PushError}, - Tx, + slot::PopError, + PushError, Tx, }; use crate::{command::Command, log, QlStreamError, RuntimeHandle}; @@ -62,50 +62,48 @@ impl StreamWriter { return self.poll_terminal(cx); } - loop { - match self.tx.try_write(std::mem::take(bytes)) { - Ok(()) => { - log::trace!( - "byte writer accepted chunk: stream_id={} target={:?}", - self.tx.stream_id(), - self.target - ); - self.poll_runtime(); - return Poll::Ready(Ok(())); - } - Err(PushError::Closed(chunk)) => { - *bytes = chunk; - self.open = false; - return self.poll_terminal(cx); - } - Err(PushError::Full(chunk)) => { - *bytes = chunk; - } + match self.tx.try_write(std::mem::take(bytes)) { + Ok(()) => { + log::trace!( + "byte writer accepted chunk: stream_id={} target={:?}", + self.tx.stream_id(), + self.target + ); + self.poll_runtime(); + return Poll::Ready(Ok(())); } + Err(PushError::Closed(chunk)) => { + *bytes = chunk; + self.open = false; + return self.poll_terminal(cx); + } + Err(PushError::Full(chunk)) => { + *bytes = chunk; + } + } - self.tx.register_waiter(cx.waker()); - - match self.tx.try_write(std::mem::take(bytes)) { - Ok(()) => { - self.tx.unregister_waiter(); - log::trace!( - "byte writer accepted chunk: stream_id={} target={:?}", - self.tx.stream_id(), - self.target - ); - self.poll_runtime(); - return Poll::Ready(Ok(())); - } - Err(PushError::Closed(chunk)) => { - self.tx.unregister_waiter(); - *bytes = chunk; - self.open = false; - return self.poll_terminal(cx); - } - Err(PushError::Full(chunk)) => { - *bytes = chunk; - return Poll::Pending; - } + self.tx.register_waiter(cx.waker()); + + match self.tx.try_write(std::mem::take(bytes)) { + Ok(()) => { + self.tx.unregister_waiter(); + log::trace!( + "byte writer accepted chunk: stream_id={} target={:?}", + self.tx.stream_id(), + self.target + ); + self.poll_runtime(); + Poll::Ready(Ok(())) + } + Err(PushError::Closed(chunk)) => { + self.tx.unregister_waiter(); + *bytes = chunk; + self.open = false; + self.poll_terminal(cx) + } + Err(PushError::Full(chunk)) => { + *bytes = chunk; + Poll::Pending } } } @@ -151,27 +149,25 @@ impl StreamWriter { }); } - fn poll_terminal(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_terminal(&mut self, cx: &Context<'_>) -> Poll> { match &self.terminal { WriterTerminalState::Terminal(result) => return Poll::Ready(result.clone()), WriterTerminalState::Pending => {} } - loop { - match self.try_poll_terminal_ready() { - Poll::Ready(result) => return Poll::Ready(result), - Poll::Pending => {} - } + match self.try_poll_terminal_ready() { + Poll::Ready(result) => return Poll::Ready(result), + Poll::Pending => {} + } - self.tx.register_waiter(cx.waker()); + self.tx.register_waiter(cx.waker()); - match self.try_poll_terminal_ready() { - Poll::Ready(result) => { - self.tx.unregister_waiter(); - return Poll::Ready(result); - } - Poll::Pending => return Poll::Pending, + match self.try_poll_terminal_ready() { + Poll::Ready(result) => { + self.tx.unregister_waiter(); + Poll::Ready(result) } + Poll::Pending => Poll::Pending, } } @@ -191,8 +187,7 @@ impl StreamWriter { Ok(Item::Chunk(_)) => { panic!("writer terminal phase contained chunk data") } - Err(PopError::Empty) => {} - Err(PopError::Closed) => panic!("writer endpoint closed unexpectedly"), + Err(PopError) => {} } } @@ -239,7 +234,7 @@ mod loom_tests { fn poll_write_observes_capacity_racing_with_registration() { check_model(|| { let inner = shared(); - inner.writer.try_write(Bytes::from_static(b"abc")).unwrap(); + inner.tx.try_write(Bytes::from_static(b"abc")).unwrap(); let mut writer = StreamWriter::new(Tx(inner.clone()), CloseTarget::Origin, handle()); let mut bytes = Bytes::from_static(b"xyz"); @@ -248,7 +243,7 @@ mod loom_tests { let drainer = { let inner = inner.clone(); thread::spawn(move || { - assert!(matches!(inner.writer.pop(), Ok(Item::Chunk(_)))); + assert!(matches!(inner.tx.pop(), Ok(Item::Chunk(_)))); }) }; @@ -280,7 +275,7 @@ mod loom_tests { let finisher = { let inner = inner.clone(); thread::spawn(move || { - inner.writer.finish(); + inner.tx.finish(); }) }; From 7a32454e253fc726d4dd41f1b2b1871e52d8f7af Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Sat, 18 Apr 2026 07:11:32 -0400 Subject: [PATCH 277/304] ql-runtime: fix clippy --- ql-runtime/src/tests/mod.rs | 51 +++++++++++++++---------------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index f2cdb757..48ab79fb 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -102,23 +102,27 @@ struct TestInbound { receiver: Receiver>, } +type TestPlatformParts = ( + TestPlatform, + Receiver>, + Sender>, + Receiver, +); + +type TestPlatformPartsWithInbound = ( + TestPlatform, + Receiver>, + Sender>, + Receiver, + Receiver, +); + impl TestPlatform { - fn new() -> ( - Self, - Receiver>, - Sender>, - Receiver, - ) { + fn new() -> TestPlatformParts { Self::new_inner(None, None, Duration::ZERO, None) } - fn new_with_inbound() -> ( - Self, - Receiver>, - Sender>, - Receiver, - Receiver, - ) { + fn new_with_inbound() -> TestPlatformPartsWithInbound { let (inbound_tx, inbound_rx) = async_channel::unbounded(); let (platform, outbound_rx, inbound_messages_tx, status_rx) = Self::new_inner(Some(inbound_tx), None, Duration::ZERO, None); @@ -133,24 +137,14 @@ impl TestPlatform { fn new_with_session_write_failure( fail_encrypted_write_at: usize, - ) -> ( - Self, - Receiver>, - Sender>, - Receiver, - ) { + ) -> TestPlatformParts { Self::new_inner(None, Some(fail_encrypted_write_at), Duration::ZERO, None) } fn new_with_delayed_writes( delay: Duration, write_stats: WriteStats, - ) -> ( - Self, - Receiver>, - Sender>, - Receiver, - ) { + ) -> TestPlatformParts { Self::new_inner(None, None, delay, Some(write_stats)) } @@ -159,12 +153,7 @@ impl TestPlatform { fail_encrypted_write_at: Option, write_delay: Duration, write_stats: Option, - ) -> ( - Self, - Receiver>, - Sender>, - Receiver, - ) { + ) -> TestPlatformParts { let (outbound, outbound_rx) = async_channel::unbounded(); let (inbound_messages_tx, inbound_messages_rx) = async_channel::unbounded(); let (status, status_rx) = async_channel::unbounded(); From 0b129ea7bd9b6b67a5ffd148a163f2278014c19d Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 20 Apr 2026 05:59:44 -0400 Subject: [PATCH 278/304] ql-wire: unpair --- ql-wire/src/encrypted/builder.rs | 5 +++++ ql-wire/src/encrypted/mod.rs | 11 +++++++++-- ql-wire/src/tests.rs | 15 +++++++++++++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 711d923b..42933235 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -66,6 +66,10 @@ impl SessionRecordBuilder { self.push_empty_frame(super::SessionFrameKind::Ping) } + pub fn push_unpair(&mut self) -> bool { + self.push_empty_frame(super::SessionFrameKind::Unpair) + } + pub fn push_ack(&mut self, ack: &RecordAck) -> bool { self.push_frame_payload(super::SessionFrameKind::Ack, ack) } @@ -89,6 +93,7 @@ impl SessionRecordBuilder { pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { match frame { SessionFrame::Ping => self.push_ping(), + SessionFrame::Unpair => self.push_unpair(), SessionFrame::Ack(frame) => self.push_ack(frame), SessionFrame::StreamData(frame) => self.push_stream_data(frame), SessionFrame::StreamWindow(frame) => self.push_stream_window(frame), diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 1b838ddd..563f9ded 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -23,7 +23,9 @@ pub use stream_window::*; #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionFrame { + // todo: do we need ping as explicit frame? Ping, + Unpair, Ack(RecordAck), StreamData(StreamData), StreamWindow(StreamWindow), @@ -36,6 +38,7 @@ impl WireDecode for SessionFrame { let kind = reader.decode::()?; let frame = match kind { SessionFrameKind::Ping => Self::Ping, + SessionFrameKind::Unpair => Self::Unpair, SessionFrameKind::Ack => Self::Ack(reader.decode::()?), SessionFrameKind::StreamData => Self::StreamData(reader.decode::>()?), SessionFrameKind::StreamWindow => Self::StreamWindow(reader.decode::()?), @@ -50,6 +53,7 @@ impl SessionFrame { fn kind(&self) -> SessionFrameKind { match self { Self::Ping => SessionFrameKind::Ping, + Self::Unpair => SessionFrameKind::Unpair, Self::Ack(_) => SessionFrameKind::Ack, Self::StreamData(_) => SessionFrameKind::StreamData, Self::StreamWindow(_) => SessionFrameKind::StreamWindow, @@ -63,6 +67,7 @@ impl SessionFrame { pub fn into_owned(self) -> SessionFrame> { match self { Self::Ping => SessionFrame::Ping, + Self::Unpair => SessionFrame::Unpair, Self::Ack(frame) => SessionFrame::Ack(frame), Self::StreamData(frame) => SessionFrame::StreamData(frame.into_owned()), Self::StreamWindow(frame) => SessionFrame::StreamWindow(frame), @@ -75,7 +80,7 @@ impl SessionFrame { impl WireEncode for SessionFrame { fn encoded_len(&self) -> usize { 1 + match self { - Self::Ping => 0, + Self::Ping | Self::Unpair => 0, Self::Ack(frame) => frame.encoded_len(), Self::StreamData(frame) => frame.encoded_len(), Self::StreamWindow(frame) => frame.encoded_len(), @@ -87,7 +92,7 @@ impl WireEncode for SessionFrame { fn encode(&self, out: &mut W) { out.put_u8(self.kind() as u8); match self { - Self::Ping => {} + Self::Ping | Self::Unpair => {} Self::Ack(frame) => frame.encode(out), Self::StreamData(frame) => frame.encode(out), Self::StreamWindow(frame) => frame.encode(out), @@ -106,6 +111,7 @@ pub enum SessionFrameKind { StreamWindow = 4, StreamClose = 5, Close = 6, + Unpair = 7, } impl TryFrom for SessionFrameKind { @@ -119,6 +125,7 @@ impl TryFrom for SessionFrameKind { 4 => Ok(Self::StreamWindow), 5 => Ok(Self::StreamClose), 6 => Ok(Self::Close), + 7 => Ok(Self::Unpair), _ => Err(WireError::InvalidPayload), } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index fe67a4d1..9bbcb9cd 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -659,6 +659,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { }; let body = vec![ SessionFrame::Ping, + SessionFrame::Unpair, SessionFrame::Ack( RecordAck::from_ranges([record_ack_range(20, 23), record_ack_range(12, 13)]).unwrap(), ), @@ -852,13 +853,22 @@ fn protocol_record_size_breakdown() { RecordAck::from_ranges([record_ack_range(6, 6), record_ack_range(1, 2)]).unwrap(), )], ); - let session_stream_empty = encrypt_record( + let session_unpair = encrypt_record( &crypto, SessionHeader { connection_id: session.tx_connection_id, seq: record_seq(3), }, &session.tx_key, + &[SessionFrame::Unpair], + ); + let session_stream_empty = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(4), + }, + &session.tx_key, &[SessionFrame::StreamData(StreamData { stream_id: stream_id(1), offset: varint(0), @@ -871,7 +881,7 @@ fn protocol_record_size_breakdown() { &crypto, SessionHeader { connection_id: session.tx_connection_id, - seq: record_seq(4), + seq: record_seq(5), }, &session.tx_key, &[SessionFrame::Close(SessionClose { @@ -892,6 +902,7 @@ fn protocol_record_size_breakdown() { print_size("ql-wire pq xx4", xx4.encode_vec().len()); print_size("ql-wire session ping", session_ping.encode_vec().len()); print_size("ql-wire session ack", session_ack.encode_vec().len()); + print_size("ql-wire session unpair", session_unpair.encode_vec().len()); print_size( "ql-wire session stream empty", session_stream_empty.encode_vec().len(), From 1bbd6c4fd70daad1668a082377d6b4c1804c171e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 20 Apr 2026 06:53:34 -0400 Subject: [PATCH 279/304] ql-fsm: unpair --- ql-fsm/src/error.rs | 5 +- ql-fsm/src/fsm.rs | 127 ++++++++++++++++++------------- ql-fsm/src/handshake/ik.rs | 2 +- ql-fsm/src/handshake/kk.rs | 2 +- ql-fsm/src/handshake/mod.rs | 4 +- ql-fsm/src/handshake/xx.rs | 2 +- ql-fsm/src/lib.rs | 18 +++-- ql-fsm/src/session/mod.rs | 75 +++++++++++------- ql-fsm/src/session/state.rs | 8 +- ql-fsm/src/session/stream_ops.rs | 2 +- ql-fsm/src/session/tests.rs | 61 +++++++++++++-- ql-fsm/src/tests/proptest.rs | 13 ++++ ql-fsm/src/tests/session.rs | 65 ++++++++++++++++ 13 files changed, 283 insertions(+), 101 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 82a79d7b..3ded1ffc 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -32,7 +32,10 @@ impl Display for ReceiveError { Self::NoSession => f.write_str("no active session"), Self::NotPairingMode => f.write_str("not in pairing mode"), Self::InvalidPairingId { expected, actual } => { - write!(f, "invalid pairing id: expected {expected}, actual {actual}") + write!( + f, + "invalid pairing id: expected {expected}, actual {actual}" + ) } Self::Replay => f.write_str("replay"), } diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 3eae2459..6a32b821 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -8,18 +8,31 @@ use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireD use crate::{ handshake, - session::{EventSink, SessionEvent}, + session::{self, SessionEvent, TerminalFrame}, state::LinkState, Event, NoPeerError, NoSessionError, OutboundWrite, QlFsm, ReceiveError, StreamError, WriteId, }; -pub struct FsmEventEmitter<'a> { +pub struct EventSink<'a> { events: &'a mut VecDeque, + termination: Option, } -impl EventSink for FsmEventEmitter<'_> { +impl<'a> EventSink<'a> { + fn new(events: &'a mut VecDeque) -> Self { + Self { + events, + termination: None, + } + } +} + +impl session::EventSink for EventSink<'_> { fn emit(&mut self, event: SessionEvent) { match event { + SessionEvent::Unpaired => { + self.termination = Some(TerminalFrame::Unpair); + } SessionEvent::Opened { stream_id, route_id, @@ -48,6 +61,7 @@ impl EventSink for FsmEventEmitter<'_> { self.events.push_back(Event::WritableClosed(frame)); } SessionEvent::SessionClosed(close) => { + self.termination = Some(TerminalFrame::Close(close.clone())); self.events.push_back(Event::SessionClosed(close)); } } @@ -60,6 +74,24 @@ pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { fsm.state.peer = Some(peer); } +pub fn unpair(fsm: &mut QlFsm) { + let had_peer = fsm.state.peer.is_some(); + fsm.state.handshake = None; + fsm.state.armed_pairing_token = None; + + if let Some(conn) = fsm.state.link.connected_mut() { + let mut emit = EventSink::new(&mut fsm.events); + conn.session.unpair(&mut emit); + } else { + fsm.state.link = LinkState::Idle; + } + + if had_peer { + emit_peer_status(fsm, crate::PeerStatus::Unpaired); + } + fsm.state.peer = None; +} + pub fn handle_disarm_pairing(fsm: &mut QlFsm) { fsm.state.armed_pairing_token = None; handshake::handle_disarm_pairing(fsm); @@ -95,32 +127,40 @@ pub fn receive( handshake::handle_handshake_record(fsm, crypto, &record) } wire::RecordType::Session => { - let QlFsm { state, events, .. } = fsm; - let conn = state.link.connected_mut_or_err()?; - let (decrypt_len, seq) = { - let record = wire::QlSessionRecord::decode(&mut reader)?; - if record.header.connection_id != conn.transport.rx_connection_id { - return Err(ReceiveError::InvalidPayload); - } - let payload = wire::decrypt_record( - crypto, - &record.header, - record.payload, - &conn.transport.rx_key, - )?; - (payload.len(), record.header.seq) + let termination = { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let (decrypt_len, seq) = { + let record = wire::QlSessionRecord::decode(&mut reader)?; + if record.header.connection_id != conn.transport.rx_connection_id { + return Err(ReceiveError::InvalidPayload); + } + let payload = wire::decrypt_record( + crypto, + &record.header, + record.payload, + &conn.transport.rx_key, + )?; + (payload.len(), record.header.seq) + }; + + let len = bytes.len(); + let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); + let frames = wire::parse_session_frames(plaintext); + + let mut emit = EventSink::new(events); + conn.session + .receive(state.now.instant, seq, frames, &mut emit); + emit.termination }; - let len = bytes.len(); - let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); - let frames = wire::parse_session_frames(plaintext); - - let mut emit = FsmEventEmitter { events }; - conn.session - .receive(state.now.instant, seq, frames, &mut emit); - - if conn.session.is_closed() { - apply_session_closed(fsm); + if matches!(termination, Some(TerminalFrame::Unpair)) { + if fsm.state.peer.is_some() { + emit_peer_status(fsm, crate::PeerStatus::Unpaired); + } + fsm.state.handshake = None; + fsm.state.armed_pairing_token = None; + fsm.state.peer = None; } Ok(()) } @@ -135,12 +175,8 @@ pub fn on_timer(fsm: &mut QlFsm) { return; }; - let mut emit = FsmEventEmitter { events }; + let mut emit = EventSink::new(events); conn.session.on_timer(state.now.instant, &mut emit); - - if conn.session.is_closed() { - apply_session_closed(fsm); - } } pub fn next_deadline(fsm: &QlFsm) -> Option { @@ -174,8 +210,9 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Result, NoSessionError> { let QlFsm { state, events, .. } = fsm; let conn = state.link.connected_mut_or_err()?; - let inner = conn - .session - .open_stream(route_id, FsmEventEmitter { events })?; + let inner = conn.session.open_stream(route_id, EventSink::new(events))?; Ok(crate::StreamOps { inner }) } pub fn stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { let QlFsm { state, events, .. } = fsm; let conn = state.link.connected_mut_or_err()?; - let inner = conn.session.stream(stream_id, FsmEventEmitter { events })?; + let inner = conn.session.stream(stream_id, EventSink::new(events))?; Ok(crate::StreamOps { inner }) } @@ -228,18 +263,8 @@ pub fn poll_event(fsm: &mut QlFsm) -> Option { fsm.events.pop_front() } -pub fn emit_peer_status(fsm: &mut QlFsm) { - if fsm.state.peer.is_some() { - fsm.events - .push_back(Event::PeerStatusChanged(fsm.state.link.status())); - } -} - -fn apply_session_closed(fsm: &mut QlFsm) { - if matches!(fsm.state.link, LinkState::Connected(_)) { - fsm.state.link = LinkState::Idle; - emit_peer_status(fsm); - } +pub fn emit_peer_status(fsm: &mut QlFsm, status: crate::PeerStatus) { + fsm.events.push_back(Event::PeerStatusChanged(status)); } pub fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { diff --git a/ql-fsm/src/handshake/ik.rs b/ql-fsm/src/handshake/ik.rs index 06292816..0566335c 100644 --- a/ql-fsm/src/handshake/ik.rs +++ b/ql-fsm/src/handshake/ik.rs @@ -26,7 +26,7 @@ pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle deadline: fsm.state.now.instant + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); - emit_peer_status(fsm); + emit_peer_status(fsm, fsm.state.link.status()); } pub fn handle_ik1( diff --git a/ql-fsm/src/handshake/kk.rs b/ql-fsm/src/handshake/kk.rs index b46f612a..b7192c5b 100644 --- a/ql-fsm/src/handshake/kk.rs +++ b/ql-fsm/src/handshake/kk.rs @@ -26,7 +26,7 @@ pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle deadline: fsm.state.now.instant + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); - emit_peer_status(fsm); + emit_peer_status(fsm, fsm.state.link.status()); } pub fn handle_kk1( diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index 983c633f..8aee523f 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -94,7 +94,7 @@ pub fn handle_timer(fsm: &mut QlFsm) { fsm.state.link = LinkState::Idle; fsm.state.handshake = None; - emit_peer_status(fsm); + emit_peer_status(fsm, fsm.state.link.status()); } pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { @@ -136,7 +136,7 @@ pub fn finish_handshake( fsm.state.now.instant, ); fsm.state.link = LinkState::Connected(ConnectedState { transport, session }); - emit_peer_status(fsm); + emit_peer_status(fsm, fsm.state.link.status()); Ok(()) } diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index a08682e0..a03594ed 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -26,7 +26,7 @@ pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, token: PairingTo deadline: fsm.state.now.instant + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); - emit_peer_status(fsm); + emit_peer_status(fsm, fsm.state.link.status()); } pub fn handle_xx1( diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 2b2fa1e6..c824b7ed 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -63,6 +63,11 @@ pub enum PeerStatus { Initiator, /// the encrypted session is up Connected, + /// the bound peer was forgotten immediately + /// + /// unpair is abortive and best-effort. the binding is removed immediately + /// and one final write may remain: a record containing only `SessionFrame::Unpair` + Unpaired, } /// events emitted by `QlFsm` @@ -70,7 +75,7 @@ pub enum PeerStatus { pub enum Event { /// a peer was learned during handshake completion NewPeer, - /// the peer changed connection state + /// the peer changed lifecycle state PeerStatusChanged(PeerStatus), /// a stream was opened Opened { @@ -111,7 +116,7 @@ pub struct OutboundWrite { } pub struct StreamOps<'a> { - inner: session::StreamOps<'a, fsm::FsmEventEmitter<'a>>, + inner: session::StreamOps<'a, fsm::EventSink<'a>>, } impl StreamOps<'_> { @@ -315,14 +320,15 @@ impl QlFsm { } /// closes the current encrypted session locally - /// - /// This transition is abortive and best-effort. It ends normal session use immediately and - /// may emit one final outbound close record, but it does not wait for the peer to acknowledge - /// that close. pub fn close_session(&mut self, code: SessionCloseCode) { fsm::close_session(self, code); } + /// forgets the bound peer locally and may emit one final outbound `SessionFrame::Unpair` + pub fn unpair(&mut self) { + fsm::unpair(self); + } + /// opens a new outgoing stream pub fn open_stream(&mut self, route_id: RouteId) -> Result, NoSessionError> { fsm::open_stream(self, route_id) diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 9415f4e5..912c7bd4 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -1,4 +1,4 @@ -pub use self::{stream_ops::*, stream_parity::*, stream_rx::*}; +pub use self::{state::TerminalFrame, stream_ops::*, stream_parity::*, stream_rx::*}; mod ack_tracker; mod range_set; @@ -78,6 +78,7 @@ pub enum SessionEvent { Closed(StreamClose), WritableClosed(StreamClose), SessionClosed(SessionClose), + Unpaired, } pub trait EventSink { @@ -178,39 +179,36 @@ impl SessionFsm { Ok(()) } - pub(crate) fn close(&mut self, code: SessionCloseCode, sink: &mut impl EventSink) { + pub fn close(&mut self, code: SessionCloseCode, sink: &mut impl EventSink) { if self.state.phase != SessionPhase::Open { return; } - let close = SessionClose { code }; - self.state.phase = SessionPhase::Closing(close.clone()); - self.state.tracked_records.clear(); - self.state.ack_tracker.clear_ack_state(); - self.clear_streams(); - sink.emit(SessionEvent::SessionClosed(close)); + self.begin_termination(TerminalFrame::Close(SessionClose { code }), sink); + } + + pub fn unpair(&mut self, sink: &mut impl EventSink) { + if self.state.phase != SessionPhase::Open { + return; + } + + self.begin_termination(TerminalFrame::Unpair, sink); } - pub(crate) fn is_closed(&self) -> bool { + pub fn is_closed(&self) -> bool { self.state.phase == SessionPhase::Closed } - pub(crate) fn receive( - &mut self, - now: Instant, - seq: RecordSeq, - frames: I, - sink: &mut impl EventSink, - ) where + pub fn receive(&mut self, now: Instant, seq: RecordSeq, frames: I, sink: &mut impl EventSink) + where I: IntoIterator, WireError>>, { - self.state.last_activity_at = now; - self.state.last_inbound_at = now; - if self.state.phase != SessionPhase::Open { return; } + self.state.last_activity_at = now; + self.state.last_inbound_at = now; self.collect_timeouts(now); match self.state.ack_tracker.insert(seq) { @@ -223,7 +221,6 @@ impl SessionFsm { } let mut ack_eliciting = false; - let mut handled_close = false; for frame in frames { let Ok(frame) = frame else { @@ -233,6 +230,10 @@ impl SessionFsm { ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); match frame { SessionFrame::Ping => {} + SessionFrame::Unpair => { + self.unpair(sink); + return; + } SessionFrame::Ack(ack) => self.process_record_ack(&ack, sink), SessionFrame::StreamData(frame) => { if self.handle_stream_data(frame, sink).is_err() { @@ -249,16 +250,11 @@ impl SessionFsm { } SessionFrame::Close(close) => { self.close(close.code, sink); - handled_close = true; - break; + return; } } } - if handled_close { - return; - } - if ack_eliciting { self.schedule_ack(now, false); } @@ -351,16 +347,25 @@ impl SessionFsm { } pub fn has_shutdown_work(&self) -> bool { - self.state.ack_tracker.ack_deadline().is_some() || !self.state.tracked_records.is_empty() + matches!(self.state.phase, SessionPhase::Terminating(_)) + || self.state.ack_tracker.ack_deadline().is_some() + || !self.state.tracked_records.is_empty() } pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { match &self.state.phase { - SessionPhase::Closing(close) => { + SessionPhase::Terminating(frame) => { let seq = self.state.next_record_seq; next_seq(&mut self.state.next_record_seq); let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); - assert!(builder.push_close(close), "builder has capacity"); + match frame { + TerminalFrame::Close(close) => { + assert!(builder.push_close(close), "builder has capacity"); + } + TerminalFrame::Unpair => { + assert!(builder.push_unpair(), "builder has capacity"); + } + } self.state.phase = SessionPhase::Closed; return Some((None, builder)); } @@ -426,6 +431,18 @@ impl SessionFsm { Some((builder, outbound)) } + fn begin_termination(&mut self, frame: TerminalFrame, sink: &mut impl EventSink) { + match &frame { + TerminalFrame::Close(close) => sink.emit(SessionEvent::SessionClosed(close.clone())), + TerminalFrame::Unpair => sink.emit(SessionEvent::Unpaired), + } + + self.state.phase = SessionPhase::Terminating(frame); + self.state.tracked_records.clear(); + self.state.ack_tracker.clear_ack_state(); + self.clear_streams(); + } + fn push_next_pending_stream_close( &mut self, builder: &mut SessionRecordBuilder, diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs index c9e7a8ca..b63140a1 100644 --- a/ql-fsm/src/session/state.rs +++ b/ql-fsm/src/session/state.rs @@ -26,7 +26,7 @@ pub struct SessionState { #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionPhase { Open, - Closing(SessionClose), + Terminating(TerminalFrame), Closed, } @@ -36,6 +36,12 @@ impl SessionPhase { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TerminalFrame { + Close(SessionClose), + Unpair, +} + #[derive(Debug)] pub struct StreamState { pub role: StreamRole, diff --git a/ql-fsm/src/session/stream_ops.rs b/ql-fsm/src/session/stream_ops.rs index 30bd7a7c..548189b7 100644 --- a/ql-fsm/src/session/stream_ops.rs +++ b/ql-fsm/src/session/stream_ops.rs @@ -3,7 +3,7 @@ use ql_wire::{CloseTarget, StreamClose, StreamCloseCode, StreamId}; use super::{ state::{InboundState, StreamState}, stream_rx::StreamReadIter, - SessionEvent, EventSink, SessionFsm, + EventSink, SessionEvent, SessionFsm, }; use crate::CommitReadError; diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 681523b5..c84df3a5 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -69,9 +69,7 @@ fn read_stream_all_with_events( stream_id: StreamId, events: &mut Vec, ) -> Vec { - let mut stream = fsm - .stream(stream_id, |event| events.push(event)) - .unwrap(); + let mut stream = fsm.stream(stream_id, |event| events.push(event)).unwrap(); let out = stream.read().flatten().collect::>(); stream.commit_read(out.len()).unwrap(); out @@ -231,7 +229,11 @@ fn ack_of_fin_emits_outbound_finished_once() { let stream_id = open_stream_id(&mut fsm); assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"done"), 4); - fsm.stream(stream_id, |_| {}).unwrap().writer().unwrap().finish(); + fsm.stream(stream_id, |_| {}) + .unwrap() + .writer() + .unwrap() + .finish(); let (record_seq, record) = next_outbound(&mut fsm, now).unwrap(); assert!(matches!( @@ -307,7 +309,10 @@ fn commit_stream_read_is_what_advances_stream_window() { assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); - fsm.stream(stream_id, |_| {}).unwrap().commit_read(2).unwrap(); + fsm.stream(stream_id, |_| {}) + .unwrap() + .commit_read(2) + .unwrap(); let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); assert!(matches!( second.as_slice(), @@ -341,7 +346,10 @@ fn pure_ack_only_records_are_fire_and_forget() { assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); let mut emit = |_| {}; - fsm.on_timer(now + retransmit_timeout + Duration::from_millis(1), &mut emit); + fsm.on_timer( + now + retransmit_timeout + Duration::from_millis(1), + &mut emit, + ); assert!(fsm .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) .is_none()); @@ -387,7 +395,10 @@ fn inbound_empty_fin_emits_finished_immediately() { })]; let events = receive_events(&mut fsm, now, seq(0), &record); - assert_eq!(events, vec![opened(stream_id), SessionEvent::Finished(stream_id)]); + assert_eq!( + events, + vec![opened(stream_id), SessionEvent::Finished(stream_id)] + ); } #[test] @@ -697,6 +708,42 @@ fn close_does_not_ack_rejected_record_seq() { assert!(matches!(outbound.as_slice(), [SessionFrame::Close(_)])); } +#[test] +fn inbound_unpair_emits_final_unpair_frame() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + + let events = receive_events(&mut fsm, now, seq(1), &[SessionFrame::Unpair]); + assert_eq!(events, vec![SessionEvent::Unpaired]); + assert!(!fsm.is_closed()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + assert!(matches!(outbound.as_slice(), [SessionFrame::Unpair])); + assert!(fsm.is_closed()); +} + +#[test] +fn terminating_session_ignores_inbound_frames() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + + let mut events = Vec::new(); + fsm.unpair(&mut |event| events.push(event)); + assert_eq!(events, vec![SessionEvent::Unpaired]); + + let ignored = receive_events( + &mut fsm, + now + Duration::from_millis(1), + seq(1), + &[SessionFrame::Ping], + ); + assert!(ignored.is_empty()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(matches!(outbound.as_slice(), [SessionFrame::Unpair])); + assert!(fsm.is_closed()); +} + #[test] fn initial_peer_stream_receive_window_limits_first_send() { let now = Instant::now(); diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 6bd196ef..c383349c 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -412,6 +412,19 @@ impl Runner { match event { Event::NewPeer => {} Event::PeerStatusChanged(status) => { + if status == PeerStatus::Unpaired { + let state = &mut self.events[side.idx()]; + prop_assert!( + state.session_epoch > 0, + "side {side:?} emitted Unpaired without a connected session" + ); + prop_assert!( + state.session_closed_epoch != Some(state.session_epoch), + "side {side:?} emitted duplicate terminal event in session epoch {}", + state.session_epoch + ); + state.session_closed_epoch = Some(state.session_epoch); + } self.events[side.idx()].note_peer_status(status); } Event::Opened { stream_id, .. } => { diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index cc4b79e1..86d0ce85 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -383,6 +383,71 @@ fn close_session_disconnects_locally() { ); } +#[test] +fn unpair_clears_bound_peer_and_emits_unpair_frame() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.a.fsm.peer().is_none()); + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); + + let unpair = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!( + unpair.frames.as_slice(), + [ql_wire::SessionFrame::Unpair] + )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); +} + +#[test] +fn inbound_unpair_clears_remote_peer_binding() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + let unpair = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, unpair); + + assert_eq!( + harness.take_event(Side::B), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.b.fsm.peer().is_none()); + assert!(matches!( + harness.b.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert!(matches!(harness.connect_ik(Side::B), Err(NoPeerError))); + + let reply_key = harness.b.fsm.state.link.transport().unwrap().tx_key.clone(); + let reply = harness.next_outbound(Side::B).unwrap(); + let (_header, frames) = decrypt_record(&harness.b.crypto, &reply, &reply_key); + assert!(matches!(frames.as_slice(), [ql_wire::SessionFrame::Unpair])); + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); +} + +#[test] +fn local_unpair_without_session_emits_unpaired_immediately() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.a.fsm.peer().is_none()); + assert_eq!(harness.take_event(Side::A), None); +} + #[test] fn session_records_contain_ack_frames_after_delivery() { let config = QlFsmConfig::default(); From 5bd44cdd1211b7c4577ec0ce9dc525249fe95494 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 20 Apr 2026 08:30:44 -0400 Subject: [PATCH 280/304] ql-runtime: expose session close and unpair --- ql-runtime/src/command.rs | 10 +- ql-runtime/src/driver/mod.rs | 23 +++- ql-runtime/src/driver/test.rs | 31 ++++- ql-runtime/src/handle/mod.rs | 12 +- ql-runtime/src/io/inner.rs | 5 +- ql-runtime/src/io/mod.rs | 3 +- ql-runtime/src/io/slot.rs | 1 - ql-runtime/src/platform.rs | 2 +- ql-runtime/src/tests/handshake.rs | 14 +- ql-runtime/src/tests/heartbeat.rs | 69 ---------- ql-runtime/src/tests/mod.rs | 23 ++-- ql-runtime/src/tests/session.rs | 213 ++++++++++++++++++++++++++++++ ql-runtime/src/tests/stream.rs | 8 +- ql-runtime/src/tests/unpair.rs | 61 --------- 14 files changed, 302 insertions(+), 173 deletions(-) delete mode 100644 ql-runtime/src/tests/heartbeat.rs create mode 100644 ql-runtime/src/tests/session.rs delete mode 100644 ql-runtime/src/tests/unpair.rs diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 7fc835d5..07d3c7a5 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,5 +1,7 @@ use ql_fsm::NoSessionError; -use ql_wire::{CloseTarget, PairingToken, PeerBundle, RouteId, StreamCloseCode, StreamId}; +use ql_wire::{ + CloseTarget, PairingToken, PeerBundle, RouteId, SessionCloseCode, StreamCloseCode, StreamId, +}; use crate::{StreamReader, StreamWriter}; @@ -25,6 +27,10 @@ pub enum Command { PollStream { stream_id: StreamId, }, + CloseSession { + code: SessionCloseCode, + }, + Unpair, CloseStream { stream_id: StreamId, target: CloseTarget, @@ -43,6 +49,8 @@ impl Command { Self::OpenStream { .. } => "OpenStream", Self::PollInbound { .. } => "PollInbound", Self::PollStream { .. } => "PollStream", + Self::CloseSession { .. } => "CloseSession", + Self::Unpair => "Unpair", Self::CloseStream { .. } => "CloseStream", } } diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index ccf5f4bf..47233493 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -202,6 +202,14 @@ impl DriverState { log::info!(" starting XX pairing"); fsm.connect_xx(now(), token, platform); } + Command::CloseSession { code } => { + log::info!("closing session: code={code:?}"); + fsm.close_session(code); + } + Command::Unpair => { + log::info!("unpairing peer"); + fsm.unpair(); + } Command::OpenStream { route_id, start } => { log::info!("open stream requested: route_id={route_id}"); let Some(runtime_tx) = self.runtime_tx.upgrade() else { @@ -234,10 +242,7 @@ impl DriverState { Some(InboundIo::new(reader_io)), ), ); - if start - .send(Ok((stream_id, reader, writer))) - .is_err() - { + if start.send(Ok((stream_id, reader, writer))).is_err() { log::warn!("open stream cancelled before delivery: stream_id={stream_id}"); if let Some(stream) = self.streams.get_mut(&stream_id) { stream.inbound_close(); @@ -300,10 +305,14 @@ impl DriverState { } } Event::PeerStatusChanged(status) => { - log::info!("peer status changed: status={status:?}"); - if let Some(peer) = fsm.peer().map(|peer| peer.xid) { - platform.handle_peer_status(peer, status); + let peer = fsm.peer().map(|peer| peer.xid); + log::info!("peer status changed: peer={peer:?} status={status:?}"); + if status == ql_fsm::PeerStatus::Unpaired { + for (_, mut stream) in self.streams.drain() { + stream.fail_all(); + } } + platform.handle_peer_status(peer, status); } Event::Opened { stream_id, diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 6f3840ac..bb93203d 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -37,7 +37,7 @@ impl QlPlatform for NoopCrypto { fn persist_peer(&self, _peer: PeerBundle) {} - fn handle_peer_status(&self, _peer: XID, _status: ql_fsm::PeerStatus) {} + fn handle_peer_status(&self, _peer: Option, _status: ql_fsm::PeerStatus) {} fn handle_inbound(&self, _event: QlStream) {} } @@ -176,3 +176,32 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { assert!(!state.streams.contains_key(&stream_id)); } + +#[test] +fn unpaired_status_fails_and_reaps_all_streams() { + let (mut state, mut fsm) = new_driver_state(); + let peer = test_identity(&SoftwareCrypto).bundle(); + let stream_id = StreamId(1u32.into()); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let (_, _, reader_io, writer_io) = io::new_stream( + stream_id, + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); + + state.streams.insert( + stream_id, + DriverStreamIo::new( + false, + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), + ), + ); + fsm.bind_peer(peer); + fsm.unpair(); + + state.drain_fsm_events(&mut fsm, &NoopCrypto); + + assert!(state.streams.is_empty()); +} diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index e2c45314..e98ce584 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -1,5 +1,5 @@ use ql_fsm::NoSessionError; -use ql_wire::{PairingToken, PeerBundle, RouteId, StreamId}; +use ql_wire::{PairingToken, PeerBundle, RouteId, SessionCloseCode, StreamId}; use crate::command::Command; pub use crate::io::{StreamReader, StreamWriter}; @@ -43,6 +43,16 @@ impl RuntimeHandle { self.send(Command::StartPairing { token }); } + /// closes the current encrypted session + pub fn close_session(&self, code: SessionCloseCode) { + self.send(Command::CloseSession { code }); + } + + /// forgets the currently bound peer and initiates session unpairing if connected + pub fn unpair(&self) { + self.send(Command::Unpair); + } + /// opens a new stream on the active encrypted session pub async fn open_stream(&self, route_id: RouteId) -> Result { let (start_tx, start_rx) = oneshot::channel(); diff --git a/ql-runtime/src/io/inner.rs b/ql-runtime/src/io/inner.rs index bf4c45ae..64df6ced 100644 --- a/ql-runtime/src/io/inner.rs +++ b/ql-runtime/src/io/inner.rs @@ -64,10 +64,7 @@ impl RxInner { } /// stores a terminal reader error - pub fn fail( - &self, - error: QlStreamError, - ) -> Option { + pub fn fail(&self, error: QlStreamError) -> Option { let displaced = self.slot.force_push(Item::Error(error)); self.changed.notify(); displaced_bytes(displaced) diff --git a/ql-runtime/src/io/mod.rs b/ql-runtime/src/io/mod.rs index 8fa5cff4..2eb7f0f0 100644 --- a/ql-runtime/src/io/mod.rs +++ b/ql-runtime/src/io/mod.rs @@ -8,8 +8,7 @@ use std::ops::Deref; use ql_wire::{CloseTarget, StreamId}; -pub use self::slot::PushError; -pub use self::{reader::StreamReader, writer::StreamWriter}; +pub use self::{reader::StreamReader, slot::PushError, writer::StreamWriter}; use crate::RuntimeHandle; pub struct Rx(sync::Arc); diff --git a/ql-runtime/src/io/slot.rs b/ql-runtime/src/io/slot.rs index e73180db..f71f1b0c 100644 --- a/ql-runtime/src/io/slot.rs +++ b/ql-runtime/src/io/slot.rs @@ -158,7 +158,6 @@ impl Slot { pub fn is_empty_state(state: usize) -> bool { state & PUSHED == 0 } - } impl Drop for Slot { diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index 2dd974b1..a0020789 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -37,7 +37,7 @@ pub trait QlPlatform: QlCrypto { fn persist_peer(&self, peer: PeerBundle); - fn handle_peer_status(&self, peer: XID, status: PeerStatus); + fn handle_peer_status(&self, peer: Option, status: PeerStatus); fn handle_inbound(&self, event: QlStream); fn handle_recv_error(&self, _error: ReceiveError) {} } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 2963c961..d641beee 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -48,7 +48,7 @@ async fn handshake_timeout_disconnects() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStatus::Disconnected).await; + await_status(&status_a, Some(identity_b.xid), PeerStatus::Disconnected).await; }) .await; } @@ -75,8 +75,8 @@ async fn rejected_session_write_is_reissued() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -104,7 +104,7 @@ async fn rejected_session_write_is_reissued() { assert_no_status_for( &status_a, - identity_b.xid, + Some(identity_b.xid), PeerStatus::Disconnected, Duration::from_millis(150), ) @@ -134,8 +134,8 @@ async fn start_pairing_round_trip_connects_when_armed() { handle_b.arm_pairing(token); handle_a.start_pairing(token); - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; }) .await; } @@ -162,7 +162,7 @@ async fn start_pairing_does_not_connect_when_unarmed() { assert_no_status_for( &status_a, - identity_b.xid, + Some(identity_b.xid), PeerStatus::Connected, Duration::from_millis(150), ) diff --git a/ql-runtime/src/tests/heartbeat.rs b/ql-runtime/src/tests/heartbeat.rs deleted file mode 100644 index 77af8b00..00000000 --- a/ql-runtime/src/tests/heartbeat.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::{ - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - time::Duration, -}; - -use super::*; -use crate::QlStreamError; - -#[tokio::test(flavor = "current_thread")] -async fn session_timeout_disconnects_and_fails_pending_open() { - run_local_test(async { - let config_a = RuntimeConfig { - fsm: QlFsmConfig { - session_keepalive_interval: Duration::from_millis(40), - session_peer_timeout: Duration::from_millis(60), - ..default_runtime_config().fsm - }, - ..default_runtime_config() - }; - let config_b = default_runtime_config(); - let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = - TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - let drop_flag = Arc::new(AtomicBool::new(false)); - spawn_forwarder(outbound_a, inbound_b_tx); - spawn_gated_forwarder(outbound_b, inbound_a_tx, drop_flag.clone()); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect(); - - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; - - let responder_task = tokio::task::spawn_local(async move { - let stream = inbound_b.recv().await.unwrap(); - let _ = read_all(stream.reader).await; - let err = stream.writer.finish().await.unwrap_err(); - assert!(matches!(err, QlStreamError::NoSession)); - }); - - drop_flag.store(true, Ordering::Relaxed); - - let mut pending = handle_a.open_stream(test_route_id()).await.unwrap(); - let err = pending.writer.finish().await.unwrap_err(); - assert!(matches!(err, QlStreamError::NoSession)); - - await_status(&status_a, identity_b.xid, PeerStatus::Disconnected).await; - - let result = - tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) - .await - .unwrap(); - assert!(matches!(result, Err(QlStreamError::NoSession))); - - responder_task.abort(); - }) - .await; -} diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 48ab79fb..066903a5 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -25,9 +25,9 @@ use crate::{ }; mod handshake; -mod heartbeat; #[cfg(feature = "rpc")] mod rpc; +mod session; mod stream; fn init_test_logger() { @@ -43,7 +43,7 @@ fn init_test_logger() { #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct StatusEvent { - peer: XID, + peer: Option, status: PeerStatus, } @@ -135,16 +135,11 @@ impl TestPlatform { ) } - fn new_with_session_write_failure( - fail_encrypted_write_at: usize, - ) -> TestPlatformParts { + fn new_with_session_write_failure(fail_encrypted_write_at: usize) -> TestPlatformParts { Self::new_inner(None, Some(fail_encrypted_write_at), Duration::ZERO, None) } - fn new_with_delayed_writes( - delay: Duration, - write_stats: WriteStats, - ) -> TestPlatformParts { + fn new_with_delayed_writes(delay: Duration, write_stats: WriteStats) -> TestPlatformParts { Self::new_inner(None, None, delay, Some(write_stats)) } @@ -296,13 +291,13 @@ impl TestPair { self.side(initiator).handle.connect(); await_status( &self.side(initiator).status, - self.side(initiator.opposite()).peer, + Some(self.side(initiator.opposite()).peer), PeerStatus::Connected, ) .await; await_status( &self.side(initiator.opposite()).status, - self.side(initiator).peer, + Some(self.side(initiator).peer), PeerStatus::Connected, ) .await; @@ -444,7 +439,7 @@ impl crate::platform::QlPlatform for TestPlatform { fn persist_peer(&self, _peer: PeerBundle) {} - fn handle_peer_status(&self, peer: XID, status: PeerStatus) { + fn handle_peer_status(&self, peer: Option, status: PeerStatus) { let _ = self.status.try_send(StatusEvent { peer, status }); } @@ -613,7 +608,7 @@ where .unwrap_or_else(|_| panic!("local runtime test exceeded {duration:?}")); } -async fn await_status(receiver: &Receiver, peer: XID, stage: PeerStatus) { +async fn await_status(receiver: &Receiver, peer: Option, stage: PeerStatus) { tokio::time::timeout(Duration::from_secs(2), async { loop { if let Ok(event) = receiver.recv().await { @@ -629,7 +624,7 @@ async fn await_status(receiver: &Receiver, peer: XID, stage: PeerSt async fn assert_no_status_for( receiver: &Receiver, - peer: XID, + peer: Option, status: PeerStatus, window: Duration, ) { diff --git a/ql-runtime/src/tests/session.rs b/ql-runtime/src/tests/session.rs new file mode 100644 index 00000000..6b185d24 --- /dev/null +++ b/ql-runtime/src/tests/session.rs @@ -0,0 +1,213 @@ +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; + +use bytes::Bytes; +use ql_wire::SessionCloseCode; + +use super::*; +use crate::QlStreamError; + +#[tokio::test(flavor = "current_thread")] +async fn close_session_aborts_active_streams_and_allows_reconnect() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); + let (received_tx, received_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + + assert_eq!( + next_chunk(&mut reader).await.unwrap(), + Some(vec![1, 2, 3, 4]) + ); + received_tx.send(()).await.unwrap(); + + let err = next_chunk(&mut reader).await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2, 3, 4])) + .await + .unwrap(); + received_rx.recv().await.unwrap(); + + pair.side(Side::A) + .handle + .close_session(SessionCloseCode::CANCELLED); + + let err = stream.writer.finish().await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + + await_status( + &pair.side(Side::A).status, + Some(pair.side(Side::B).peer), + PeerStatus::Disconnected, + ) + .await; + await_status( + &pair.side(Side::B).status, + Some(pair.side(Side::A).peer), + PeerStatus::Disconnected, + ) + .await; + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + + pair.connect_and_wait(Side::A).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn unpair_aborts_active_streams_and_prevents_reconnect() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); + let (received_tx, received_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + + assert_eq!( + next_chunk(&mut reader).await.unwrap(), + Some(vec![5, 6, 7, 8]) + ); + received_tx.send(()).await.unwrap(); + + let err = next_chunk(&mut reader).await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[5, 6, 7, 8])) + .await + .unwrap(); + received_rx.recv().await.unwrap(); + + pair.side(Side::A).handle.unpair(); + + let err = stream.writer.finish().await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + + await_status(&pair.side(Side::A).status, None, PeerStatus::Unpaired).await; + await_status(&pair.side(Side::B).status, None, PeerStatus::Unpaired).await; + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + + assert!(matches!( + pair.side(Side::A).handle.open_stream(test_route_id()).await, + Err(NoSessionError) + )); + assert!(matches!( + pair.side(Side::B).handle.open_stream(test_route_id()).await, + Err(NoSessionError) + )); + + pair.side(Side::B).handle.connect(); + assert_no_status_for( + &pair.side(Side::B).status, + None, + PeerStatus::Initiator, + Duration::from_millis(150), + ) + .await; + assert_no_status_for( + &pair.side(Side::B).status, + None, + PeerStatus::Connected, + Duration::from_millis(150), + ) + .await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn session_timeout_disconnects_and_fails_pending_open() { + run_local_test(async { + let config_a = RuntimeConfig { + fsm: QlFsmConfig { + session_keepalive_interval: Duration::from_millis(40), + session_peer_timeout: Duration::from_millis(60), + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let config_b = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let drop_flag = Arc::new(AtomicBool::new(false)); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_gated_forwarder(outbound_b, inbound_a_tx, drop_flag.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let _ = read_all(stream.reader).await; + let err = stream.writer.finish().await.unwrap_err(); + assert!(matches!(err, QlStreamError::NoSession)); + }); + + drop_flag.store(true, Ordering::Relaxed); + + let mut pending = handle_a.open_stream(test_route_id()).await.unwrap(); + let err = pending.writer.finish().await.unwrap_err(); + assert!(matches!(err, QlStreamError::NoSession)); + + await_status(&status_a, Some(identity_b.xid), PeerStatus::Disconnected).await; + + let result = + tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) + .await + .unwrap(); + assert!(matches!(result, Err(QlStreamError::NoSession))); + + responder_task.abort(); + }) + .await; +} diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index b85f0229..008a3c6c 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -306,8 +306,8 @@ async fn max_concurrent_message_writes_is_respected() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { for _ in 0..4 { @@ -381,8 +381,8 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, identity_b.xid, PeerStatus::Connected).await; - await_status(&status_b, identity_a.xid, PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); diff --git a/ql-runtime/src/tests/unpair.rs b/ql-runtime/src/tests/unpair.rs deleted file mode 100644 index 751c9a53..00000000 --- a/ql-runtime/src/tests/unpair.rs +++ /dev/null @@ -1,61 +0,0 @@ -use super::*; - -#[tokio::test(flavor = "current_thread")] -async fn unpair_clears_remote_peer_and_aborts_active_stream() { - run_local_test(async { - let config = default_runtime_config(); - let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); - let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = - TestPlatform::new_with_inbound(); - let (identity_a, identity_b) = test_identities(&SoftwareCrypto); - - let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); - let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); - - tokio::task::spawn_local(async move { runtime_a.run().await }); - tokio::task::spawn_local(async move { runtime_b.run().await }); - - spawn_forwarder(outbound_a, inbound_b_tx); - spawn_forwarder(outbound_b, inbound_a_tx); - - register_peers(&handle_a, &handle_b, &identity_a, &identity_b); - handle_a.connect().unwrap(); - - await_status(&status_a, identity_b.xid, PeerStage::Connected).await; - await_status(&status_b, identity_a.xid, PeerStage::Connected).await; - - let responder = tokio::task::spawn_local(async move { - let stream = inbound_b.recv().await.unwrap(); - let mut request = stream.request; - let _ = request.next_chunk().await; - let second = request.next_chunk().await; - assert!(matches!(second, Ok(None) | Err(QlError::Cancelled))); - }); - - let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); - stream.request.write_all(&[1, 2, 3, 4]).await.unwrap(); - - handle_a.unpair().unwrap(); - assert!(matches!( - handle_a.open_stream(test_route_id()).await, - Err(QlError::NoPeerBound) - )); - - tokio::time::timeout(std::time::Duration::from_secs(2), responder) - .await - .unwrap() - .unwrap(); - - let open_err_b = tokio::time::timeout(std::time::Duration::from_secs(2), async { - loop { - match handle_b.open_stream(test_route_id()).await { - Err(QlError::NoPeerBound) => return, - _ => tokio::time::sleep(std::time::Duration::from_millis(10)).await, - } - } - }) - .await; - assert!(open_err_b.is_ok(), "remote peer was not cleared"); - }) - .await; -} From 344b87651f2038269328441d06e768590053508e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 20 Apr 2026 09:42:44 -0400 Subject: [PATCH 281/304] ql-rpc: fmt --- ql-rpc/src/rpc/upload/client.rs | 4 +++- ql-rpc/src/rpc/upload/server.rs | 25 ++++++++++++------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/ql-rpc/src/rpc/upload/client.rs b/ql-rpc/src/rpc/upload/client.rs index b86f5477..a31a1fa6 100644 --- a/ql-rpc/src/rpc/upload/client.rs +++ b/ql-rpc/src/rpc/upload/client.rs @@ -37,7 +37,9 @@ where pub async fn finish(mut self) -> Result> { let mut writer = self.writer.take().expect("upload writer exists"); - finish_bytes(&mut writer).await.map_err(CallError::Transport)?; + finish_bytes(&mut writer) + .await + .map_err(CallError::Transport)?; let mut reader = self.reader.take().expect("upload reader exists"); let mut bytes = ChunkQueue::default(); diff --git a/ql-rpc/src/rpc/upload/server.rs b/ql-rpc/src/rpc/upload/server.rs index 36399bcd..1a702ad4 100644 --- a/ql-rpc/src/rpc/upload/server.rs +++ b/ql-rpc/src/rpc/upload/server.rs @@ -99,20 +99,19 @@ pub(crate) async fn handle_upload_inner( S: UploadHandler + 'static, St: RpcStream + 'static, { - let (request, buffered) = match read_framed_request_prefix::(&mut reader, config) - .await - { - Ok(value) => value, - Err(error) => { - let code = error.close_code(); - state.handle_transport_error(&error); - if let Some(code) = code { - reader.close(code); - writer.close(code); + let (request, buffered) = + match read_framed_request_prefix::(&mut reader, config).await { + Ok(value) => value, + Err(error) => { + let code = error.close_code(); + state.handle_transport_error(&error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; } - return; - } - }; + }; state.handle( request, From fcf1c858c80d1162bd0090cd4fa7b583550865ad Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 20 Apr 2026 11:38:36 -0400 Subject: [PATCH 282/304] ql-wire: bytebuf --- ql-wire/src/bytes.rs | 39 ++++++++++++- ql-wire/src/crypto.rs | 14 +++-- ql-wire/src/encrypted/builder.rs | 84 +++++++++++++++------------ ql-wire/src/encrypted/mod.rs | 20 +++++-- ql-wire/src/encrypted_message.rs | 97 -------------------------------- ql-wire/src/handshake/mod.rs | 45 +++++++++------ ql-wire/src/lib.rs | 2 - ql-wire/src/record.rs | 73 ++++++++---------------- ql-wire/src/testing.rs | 47 ++++++++++------ ql-wire/src/tests.rs | 84 ++++++++++++++++----------- 10 files changed, 243 insertions(+), 262 deletions(-) delete mode 100644 ql-wire/src/encrypted_message.rs diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index 9fecf5ea..cda7ec79 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -1,6 +1,6 @@ use core::ops::{Deref, DerefMut}; -use bytes::{Buf, Bytes}; +use bytes::{Buf, BufMut, Bytes}; /// A mutable or immutable byte slice owner used by the wire parser. pub trait ByteSlice: Deref + Sized { @@ -15,6 +15,43 @@ pub trait ByteSliceMut: ByteSlice + DerefMut {} impl ByteSliceMut for B where B: ByteSlice + DerefMut {} +/// An owned growable byte buffer used by outbound encoding and crypto paths. +pub trait ByteBuf: + AsRef<[u8]> + + AsMut<[u8]> + + Deref + + DerefMut + + BufMut + + Send + + Sized + + 'static +{ + fn with_capacity(capacity: usize) -> Self; + fn len(&self) -> usize; + fn capacity(&self) -> usize; + + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl ByteBuf for Vec { + #[inline] + fn with_capacity(capacity: usize) -> Self { + Self::with_capacity(capacity) + } + + #[inline] + fn len(&self) -> usize { + Self::len(self) + } + + #[inline] + fn capacity(&self) -> usize { + Self::capacity(self) + } +} + impl ByteSlice for &[u8] { #[inline] fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { diff --git a/ql-wire/src/crypto.rs b/ql-wire/src/crypto.rs index 96ace383..6617ccec 100644 --- a/ql-wire/src/crypto.rs +++ b/ql-wire/src/crypto.rs @@ -1,5 +1,5 @@ use crate::{ - MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, SessionKey, + ByteBuf, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, }; @@ -12,22 +12,26 @@ pub trait QlHash { } pub trait QlAead { + type B: ByteBuf; + fn aes256_gcm_encrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], - buffer: &mut [u8], - ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + buffer: Self::B, + range: core::ops::Range, + ) -> (Self::B, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]); fn aes256_gcm_decrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], - buffer: &mut [u8], + buffer: Self::B, + range: core::ops::Range, auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool; + ) -> Option; } pub trait QlKem { diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index 42933235..d2e8df8b 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,20 +1,18 @@ -use bytes::BufMut; - use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; use crate::{ - BufView, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, SessionKey, - WireEncode, QL_WIRE_VERSION, + BufView, ByteBuf, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, + SessionKey, WireEncode, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct SessionRecordBuilder { +pub struct SessionRecordBuilder { seq: RecordSeq, prefix_len: usize, max_capacity: usize, - bytes: Vec, + bytes: Option, } -impl SessionRecordBuilder { +impl SessionRecordBuilder { pub const MIN_CAPACITY: usize = 1 + 1 + ConnectionId::SIZE @@ -29,7 +27,7 @@ impl SessionRecordBuilder { seq, prefix_len, max_capacity, - bytes: Vec::new(), + bytes: None, } } @@ -46,7 +44,9 @@ impl SessionRecordBuilder { } pub fn len(&self) -> usize { - self.bytes.len().saturating_sub(self.prefix_len) + self.bytes + .as_ref() + .map_or(0, |bytes| bytes.len().saturating_sub(self.prefix_len)) } pub fn is_empty(&self) -> bool { @@ -55,11 +55,14 @@ impl SessionRecordBuilder { pub fn remaining_capacity(&self) -> usize { self.max_capacity - .saturating_sub(self.bytes.len().max(self.prefix_len)) + .saturating_sub(self.prefix_len.saturating_add(self.len())) } pub fn bytes(&self) -> &[u8] { - self.bytes.get(self.prefix_len..).unwrap_or_default() + self.bytes + .as_ref() + .and_then(|bytes| bytes.get(self.prefix_len..)) + .unwrap_or_default() } pub fn push_ping(&mut self) -> bool { @@ -74,7 +77,7 @@ impl SessionRecordBuilder { self.push_frame_payload(super::SessionFrameKind::Ack, ack) } - pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { + pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { self.push_frame_payload(super::SessionFrameKind::StreamData, frame) } @@ -90,7 +93,7 @@ impl SessionRecordBuilder { self.push_frame_payload(super::SessionFrameKind::Close, close) } - pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { + pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { match frame { SessionFrame::Ping => self.push_ping(), SessionFrame::Unpair => self.push_unpair(), @@ -102,44 +105,42 @@ impl SessionRecordBuilder { } } - pub fn encrypt( - mut self, - crypto: &impl QlCrypto, + pub fn encrypt>( + self, + crypto: &C, connection_id: ConnectionId, session_key: &SessionKey, - ) -> Vec { - self.ensure_prefix_capacity(0); + ) -> B { let header = SessionHeader { connection_id, seq: self.seq, }; let aad = header.aad(); let nonce = Nonce::from_counter(self.seq.into_inner()); - let auth = crypto.aes256_gcm_encrypt( - session_key, - &nonce, - &aad, - &mut self.bytes[self.prefix_len..], - ); - - let mut prefix = &mut self.bytes[..self.prefix_len]; + let prefix_len = self.prefix_len; + let bytes = self.into_bytes(0); + let body_range = prefix_len..bytes.len(); + let (mut bytes, auth) = + crypto.aes256_gcm_encrypt(session_key, &nonce, &aad, bytes, body_range); + + let mut prefix = &mut bytes[..prefix_len]; prefix[0] = QL_WIRE_VERSION; prefix[1] = RecordType::Session as u8; prefix = &mut prefix[2..]; header.encode(&mut prefix); auth.encode(&mut prefix); debug_assert!(prefix.is_empty()); - self.bytes + bytes } - fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut Vec)) -> bool { + fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut B)) -> bool { if !self.can_push_len(wire_size) { return false; } - self.ensure_prefix_capacity(wire_size); - let start = self.bytes.len(); - encode(&mut self.bytes); - debug_assert_eq!(self.bytes.len(), start + wire_size); + let bytes = self.bytes_mut(wire_size); + let start = bytes.len(); + encode(bytes); + debug_assert_eq!(bytes.len(), start + wire_size); true } @@ -163,10 +164,21 @@ impl SessionRecordBuilder { len <= self.remaining_capacity() } - fn ensure_prefix_capacity(&mut self, additional_body_len: usize) { - if self.bytes.is_empty() { - self.bytes.reserve(self.prefix_len + additional_body_len); - self.bytes.resize(self.prefix_len, 0); + fn bytes_mut(&mut self, additional_body_len: usize) -> &mut B { + self.ensure_bytes(additional_body_len); + self.bytes.as_mut().unwrap() + } + + fn into_bytes(mut self, additional_body_len: usize) -> B { + self.ensure_bytes(additional_body_len); + self.bytes.take().unwrap() + } + + fn ensure_bytes(&mut self, additional_body_len: usize) { + if self.bytes.is_none() { + let mut bytes = B::with_capacity(self.prefix_len + additional_body_len); + bytes.put_bytes(0, self.prefix_len); + self.bytes = Some(bytes); } } } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 563f9ded..0143e831 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,6 +1,6 @@ use crate::{ - codec, encrypted_message::EncryptedMessage, BufView, ByteSlice, Nonce, QlCrypto, Reader, - SessionHeader, SessionKey, WireDecode, WireEncode, WireError, + codec, BufView, ByteBuf, ByteSlice, Nonce, QlCrypto, Reader, SessionHeader, SessionKey, + WireDecode, WireEncode, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, }; mod ack; @@ -166,13 +166,21 @@ impl Iterator for SessionFrameIter { } } -pub fn decrypt_record>( - crypto: &impl QlCrypto, +pub fn decrypt_record( + crypto: &impl QlCrypto, header: &SessionHeader, - encrypted: EncryptedMessage, + buffer: B, + ciphertext_range: core::ops::Range, + auth: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], session_key: &SessionKey, ) -> Result { + assert!( + ciphertext_range.start <= ciphertext_range.end && ciphertext_range.end <= buffer.len(), + "ciphertext valid range", + ); let aad = header.aad(); let nonce = Nonce::from_counter(header.seq.into_inner()); - encrypted.decrypt_in_place(crypto, session_key, &nonce, &aad) + crypto + .aes256_gcm_decrypt(session_key, &nonce, &aad, buffer, ciphertext_range, auth) + .ok_or(WireError::DecryptFailed) } diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs deleted file mode 100644 index 9e11d3d0..00000000 --- a/ql-wire/src/encrypted_message.rs +++ /dev/null @@ -1,97 +0,0 @@ -use crate::{ - codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, - ENCRYPTED_MESSAGE_AUTH_SIZE, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct EncryptedMessage { - pub auth: [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - pub ciphertext: B, -} - -impl EncryptedMessage { - pub const AUTH_SIZE: usize = ENCRYPTED_MESSAGE_AUTH_SIZE; - pub const HEADER_LEN: usize = Self::AUTH_SIZE; - - pub fn into_owned(self) -> EncryptedMessage> - where - B: ByteSlice, - { - EncryptedMessage { - auth: self.auth, - ciphertext: self.ciphertext.to_vec(), - } - } -} - -impl WireDecode for EncryptedMessage { - fn decode(reader: &mut codec::Reader) -> Result { - Ok(Self { - auth: reader.decode()?, - ciphertext: reader.take_rest(), - }) - } -} - -impl> EncryptedMessage { - pub fn decrypt( - &self, - crypto: &impl QlCrypto, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - ) -> Result, WireError> { - let mut plaintext = self.ciphertext.as_ref().to_vec(); - if !crypto.aes256_gcm_decrypt(key, nonce, aad, &mut plaintext, &self.auth) { - return Err(WireError::DecryptFailed); - } - Ok(plaintext) - } -} - -impl> WireEncode for EncryptedMessage { - fn encoded_len(&self) -> usize { - Self::HEADER_LEN + self.ciphertext.as_ref().len() - } - - fn encode(&self, out: &mut W) { - self.auth.encode(out); - self.ciphertext.as_ref().encode(out); - } -} - -impl> EncryptedMessage { - pub fn decrypt_in_place( - mut self, - crypto: &impl QlCrypto, - key: &SessionKey, - nonce: &Nonce, - aad: &[u8], - ) -> Result { - let ciphertext = self.ciphertext.as_mut(); - if !crypto.aes256_gcm_decrypt(key, nonce, aad, ciphertext, &self.auth) { - return Err(WireError::DecryptFailed); - } - Ok(self.ciphertext) - } -} - -impl EncryptedMessage> { - pub fn encrypt( - crypto: &impl QlCrypto, - key: &SessionKey, - mut plaintext: Vec, - nonce: &Nonce, - aad: &[u8], - ) -> Self { - let auth = crypto.aes256_gcm_encrypt(key, nonce, aad, &mut plaintext); - Self { - auth, - ciphertext: plaintext, - } - } - - pub fn decode(bytes: &[u8]) -> Result { - Ok(EncryptedMessage::decode_exact(bytes)?.into_owned()) - } -} diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 79e0f7ad..87eed252 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,6 +1,8 @@ +use bytes::BufMut; + use crate::{ - codec, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, - Nonce, PeerBundle, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, + codec, ByteBuf, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, + MlKemPublicKey, Nonce, PeerBundle, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; @@ -225,43 +227,50 @@ impl CipherState { self.key.is_some() } - fn encrypt( + fn encrypt( &mut self, - crypto: &impl QlCrypto, + crypto: &C, aad: &[u8], plaintext: &[u8], ) -> Result, WireError> { let key = self.key.as_ref().ok_or(WireError::InvalidState)?; let nonce = Nonce::from_counter(self.nonce); - let mut ciphertext = Vec::with_capacity(plaintext.len() + ENCRYPTED_MESSAGE_AUTH_SIZE); - ciphertext.extend_from_slice(plaintext); - let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut ciphertext); + let mut buffer = C::B::with_capacity(plaintext.len()); + buffer.put_slice(plaintext); + let payload_len = buffer.len(); + let (ciphertext, auth) = + crypto.aes256_gcm_encrypt(key, &nonce, aad, buffer, 0..payload_len); + let mut out = Vec::with_capacity(ciphertext.len() + ENCRYPTED_MESSAGE_AUTH_SIZE); + out.extend_from_slice(&ciphertext); + out.extend_from_slice(&auth); self.nonce = self.nonce.wrapping_add(1); - ciphertext.extend_from_slice(&auth); - Ok(ciphertext) + Ok(out) } - fn decrypt( + fn decrypt( &mut self, - crypto: &impl QlCrypto, + crypto: &C, aad: &[u8], ciphertext: &[u8], ) -> Result, WireError> { if ciphertext.len() < ENCRYPTED_MESSAGE_AUTH_SIZE { return Err(WireError::InvalidPayload); } + let split = ciphertext.len() - ENCRYPTED_MESSAGE_AUTH_SIZE; let (ciphertext, auth) = ciphertext.split_at(split); - let mut plaintext = ciphertext.to_vec(); - let key = self.key.as_ref().ok_or(WireError::InvalidState)?; - let nonce = Nonce::from_counter(self.nonce); let mut auth_tag = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; auth_tag.copy_from_slice(auth); - if !crypto.aes256_gcm_decrypt(key, &nonce, aad, &mut plaintext, &auth_tag) { - return Err(WireError::DecryptFailed); - } + let key = self.key.as_ref().ok_or(WireError::InvalidState)?; + let nonce = Nonce::from_counter(self.nonce); + let mut buffer = C::B::with_capacity(ciphertext.len()); + buffer.put_slice(ciphertext); + let payload_len = buffer.len(); + let plaintext = crypto + .aes256_gcm_decrypt(key, &nonce, aad, buffer, 0..payload_len, &auth_tag) + .ok_or(WireError::DecryptFailed)?; self.nonce = self.nonce.wrapping_add(1); - Ok(plaintext) + Ok(plaintext.to_vec()) } } diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 63b5e633..821aa32b 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -8,7 +8,6 @@ mod bytes; mod codec; mod crypto; mod encrypted; -mod encrypted_message; mod error; mod handshake; mod header; @@ -25,7 +24,6 @@ pub use bytes::*; pub use codec::*; pub use crypto::*; pub use encrypted::*; -pub use encrypted_message::*; pub use error::*; pub use handshake::*; pub use header::*; diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index 163a1bff..b1d1fe00 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -1,26 +1,21 @@ use crate::{ codec, - encrypted_message::EncryptedMessage, handshake::{Ik1, Ik2, Kk1, Kk2, Xx1, Xx2, Xx3, Xx4}, - ByteSlice, SessionHeader, WireDecode, WireEncode, WireError, QL_WIRE_VERSION, + ByteBuf, ByteSlice, SessionHeader, WireDecode, WireEncode, WireError, + ENCRYPTED_MESSAGE_AUTH_SIZE, QL_WIRE_VERSION, }; -pub fn encode_record(out: &mut W, record_type: RecordType, body: &T) -where - W: bytes::BufMut + ?Sized, - T: WireEncode + ?Sized, -{ +pub fn encode_record( + record_type: RecordType, + body: &T, +) -> B { + let mut out = B::with_capacity(RecordHeader::WIRE_SIZE + body.encoded_len()); RecordHeader { version: QL_WIRE_VERSION, record_type, } - .encode(out); - body.encode(out); -} - -pub fn encode_record_vec(record_type: RecordType, body: &T) -> Vec { - let mut out = Vec::with_capacity(RecordHeader::WIRE_SIZE + body.encoded_len()); - encode_record(&mut out, record_type, body); + .encode(&mut out); + body.encode(&mut out); out } @@ -33,6 +28,21 @@ where Ok((reader.decode()?, reader.decode()?)) } +pub fn decode_session_record_prefix( + bytes: &[u8], +) -> Result<(SessionHeader, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], usize), WireError> { + let mut reader = codec::Reader::new(bytes); + let record = reader.decode::()?; + if record.version != QL_WIRE_VERSION || record.record_type != RecordType::Session { + return Err(WireError::InvalidPayload); + } + + let header = reader.decode::()?; + let auth = reader.decode()?; + let ciphertext_start = bytes.len().saturating_sub(reader.remaining_len()); + Ok((header, auth, ciphertext_start)) +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RecordHeader { pub version: u8, @@ -217,38 +227,3 @@ impl WireDecode for QlHandshakeRecord { } } } - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct QlSessionRecord { - pub header: SessionHeader, - pub payload: EncryptedMessage, -} - -impl> WireEncode for QlSessionRecord { - fn encoded_len(&self) -> usize { - self.header.encoded_len() + self.payload.encoded_len() - } - - fn encode(&self, out: &mut W) { - self.header.encode(out); - self.payload.encode(out); - } -} - -impl QlSessionRecord { - pub fn into_owned(self) -> QlSessionRecord> { - QlSessionRecord { - header: self.header, - payload: self.payload.into_owned(), - } - } -} - -impl WireDecode for QlSessionRecord { - fn decode(reader: &mut codec::Reader) -> Result { - Ok(Self { - header: reader.decode()?, - payload: reader.decode()?, - }) - } -} diff --git a/ql-wire/src/testing.rs b/ql-wire/src/testing.rs index 83b4fbde..11d7c4fc 100644 --- a/ql-wire/src/testing.rs +++ b/ql-wire/src/testing.rs @@ -38,25 +38,28 @@ impl QlHash for SoftwareCrypto { } impl QlAead for SoftwareCrypto { + type B = Vec; + fn aes256_gcm_encrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], - buffer: &mut [u8], - ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { + mut buffer: Self::B, + range: core::ops::Range, + ) -> (Self::B, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]) { let key: AesGcm256Key = (*key.data()).into(); - let plaintext = buffer.to_vec(); + let plaintext = buffer[range.clone()].to_vec(); let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; key.encrypt( - buffer, + &mut buffer[range], (&mut auth).into(), (&nonce.0).into(), aad, &plaintext, ) .unwrap(); - auth + (buffer, auth) } fn aes256_gcm_decrypt( @@ -64,13 +67,21 @@ impl QlAead for SoftwareCrypto { key: &SessionKey, nonce: &Nonce, aad: &[u8], - buffer: &mut [u8], + mut buffer: Self::B, + range: core::ops::Range, auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool { + ) -> Option { let key: AesGcm256Key = (*key.data()).into(); - let ciphertext = buffer.to_vec(); - key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) - .is_ok() + let ciphertext = buffer[range.clone()].to_vec(); + key.decrypt( + &mut buffer[range], + (&nonce.0).into(), + aad, + &ciphertext, + auth_tag.into(), + ) + .ok()?; + Some(buffer) } } @@ -129,14 +140,17 @@ impl QlHash for NoopCrypto { } impl QlAead for NoopCrypto { + type B = Vec; + fn aes256_gcm_encrypt( &self, _key: &SessionKey, _nonce: &Nonce, _aad: &[u8], - _buffer: &mut [u8], - ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { - [0; ENCRYPTED_MESSAGE_AUTH_SIZE] + buffer: Self::B, + _range: core::ops::Range, + ) -> (Self::B, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]) { + (buffer, [0; ENCRYPTED_MESSAGE_AUTH_SIZE]) } fn aes256_gcm_decrypt( @@ -144,10 +158,11 @@ impl QlAead for NoopCrypto { _key: &SessionKey, _nonce: &Nonce, _aad: &[u8], - _buffer: &mut [u8], + _buffer: Self::B, + _range: core::ops::Range, _auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> bool { - false + ) -> Option { + None } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 9bbcb9cd..2a9ab68a 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -6,9 +6,10 @@ fn decode_handshake_record(bytes: &[u8]) -> QlHandshakeRecord { decode_record(bytes).unwrap().1 } -fn decode_session_record(bytes: &[u8]) -> QlSessionRecord> { - let (_, record) = decode_record::, _>(bytes).unwrap(); - record.into_owned() +fn decode_session_record_prefix_for_test( + bytes: &[u8], +) -> (SessionHeader, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], usize) { + decode_session_record_prefix(bytes).unwrap() } fn xid(byte: u8) -> XID { @@ -66,21 +67,20 @@ fn xx_header(byte: u8) -> XxHeader { } fn encrypt_record( - crypto: &impl QlCrypto, + crypto: &impl QlCrypto>, header: SessionHeader, session_key: &SessionKey, body: &[SessionFrame>], -) -> QlSessionRecord> { - let mut builder = SessionRecordBuilder::new(header.seq, usize::MAX); +) -> Vec { + let body_len = body.iter().map(WireEncode::encoded_len).sum::(); + let max_capacity = + RecordHeader::WIRE_SIZE + header.encoded_len() + ENCRYPTED_MESSAGE_AUTH_SIZE + body_len; + let mut builder = SessionRecordBuilder::>::new(header.seq, max_capacity); for frame in body { let pushed = builder.push_frame(frame); debug_assert!(pushed); } - decode_session_record( - builder - .encrypt(crypto, header.connection_id, session_key) - .as_slice(), - ) + builder.encrypt(crypto, header.connection_id, session_key) } #[test] @@ -107,7 +107,7 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { }, static_bundle: EncryptedPeerBundle::new(Box::new([13; EncryptedPeerBundle::WIRE_SIZE])), }); - let ik_encoded = encode_record_vec(RecordType::Handshake, &ik); + let ik_encoded: Vec = encode_record(RecordType::Handshake, &ik); assert_eq!( RecordHeader::decode_bytes(ik_encoded.as_slice()).unwrap(), RecordHeader { @@ -126,7 +126,7 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { mlkem_public_key: MlKemPublicKey::new(Box::new([15; MlKemPublicKey::SIZE])), }, }); - let kk_encoded = encode_record_vec(RecordType::Handshake, &kk); + let kk_encoded: Vec = encode_record(RecordType::Handshake, &kk); assert_eq!( RecordHeader::decode_bytes(kk_encoded.as_slice()).unwrap(), RecordHeader { @@ -144,7 +144,7 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { mlkem_public_key: MlKemPublicKey::new(Box::new([17; MlKemPublicKey::SIZE])), }, }); - let xx_encoded = encode_record_vec(RecordType::Handshake, &xx); + let xx_encoded: Vec = encode_record(RecordType::Handshake, &xx); assert_eq!( RecordHeader::decode_bytes(xx_encoded.as_slice()).unwrap(), RecordHeader { @@ -686,28 +686,44 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let session_key = SessionKey::from_data([7; SessionKey::SIZE]); let record = encrypt_record(&crypto, header, &session_key, &body); - let bytes = encode_record_vec(RecordType::Session, &record); assert_eq!( - RecordHeader::decode_bytes(bytes.as_slice()).unwrap(), + RecordHeader::decode_bytes(record.as_slice()).unwrap(), RecordHeader { version: QL_WIRE_VERSION, record_type: RecordType::Session, } ); - let decoded = decode_session_record(bytes.as_slice()); - assert_eq!(decoded.header, header); - let encrypted = decoded.payload; + let (decoded_header, auth, ciphertext_start) = decode_session_record_prefix_for_test(&record); + assert_eq!(decoded_header, header); + let ciphertext_range = ciphertext_start..record.len(); - let decrypted = - encrypted::decrypt_record(&crypto, &header, encrypted.clone(), &session_key).unwrap(); - assert_eq!(decode_session_frames(&decrypted).unwrap(), body); + let decrypted = encrypted::decrypt_record( + &crypto, + &header, + record.clone(), + ciphertext_range.clone(), + &auth, + &session_key, + ) + .unwrap(); + assert_eq!( + decode_session_frames(&decrypted[ciphertext_range.clone()]).unwrap(), + body + ); let wrong_header = SessionHeader { connection_id: ConnectionId::from_data([0x99; ConnectionId::SIZE]), seq: header.seq, }; assert_eq!( - encrypted::decrypt_record(&crypto, &wrong_header, encrypted.clone(), &session_key), + encrypted::decrypt_record( + &crypto, + &wrong_header, + record.clone(), + ciphertext_range.clone(), + &auth, + &session_key, + ), Err(WireError::DecryptFailed) ); @@ -716,7 +732,14 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { seq: record_seq(header.seq.into_inner() + 1), }; assert_eq!( - encrypted::decrypt_record(&crypto, &wrong_seq_header, encrypted, &session_key), + encrypted::decrypt_record( + &crypto, + &wrong_seq_header, + record, + ciphertext_range, + &auth, + &session_key, + ), Err(WireError::DecryptFailed) ); } @@ -900,12 +923,9 @@ fn protocol_record_size_breakdown() { print_size("ql-wire pq xx2", xx2.encode_vec().len()); print_size("ql-wire pq xx3", xx3.encode_vec().len()); print_size("ql-wire pq xx4", xx4.encode_vec().len()); - print_size("ql-wire session ping", session_ping.encode_vec().len()); - print_size("ql-wire session ack", session_ack.encode_vec().len()); - print_size("ql-wire session unpair", session_unpair.encode_vec().len()); - print_size( - "ql-wire session stream empty", - session_stream_empty.encode_vec().len(), - ); - print_size("ql-wire session close", session_close.encode_vec().len()); + print_size("ql-wire session ping", session_ping.len()); + print_size("ql-wire session ack", session_ack.len()); + print_size("ql-wire session unpair", session_unpair.len()); + print_size("ql-wire session stream empty", session_stream_empty.len()); + print_size("ql-wire session close", session_close.len()); } From 2bb5e3e34abe744901e3e35b71bad8591c2a1c9a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 20 Apr 2026 12:14:21 -0400 Subject: [PATCH 283/304] ql-fsm: bytebuf --- ql-fsm/src/fsm.rs | 39 ++++++++++-------- ql-fsm/src/lib.rs | 16 ++++---- ql-fsm/src/session/ack_tracker.rs | 12 +++--- ql-fsm/src/session/mod.rs | 46 +++++++++++++-------- ql-fsm/src/session/tests.rs | 68 ++++++++++++++++++------------- ql-fsm/src/tests/mod.rs | 24 ++++++----- ql-fsm/src/tests/proptest.rs | 65 ++++++++++++++--------------- ql-fsm/src/tests/session.rs | 2 +- 8 files changed, 147 insertions(+), 125 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 6a32b821..85d8129e 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -109,12 +109,12 @@ pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), handshake::handle_connect_kk(fsm, crypto) } -pub fn receive( +pub fn receive( fsm: &mut QlFsm, - mut bytes: Vec, - crypto: &impl QlCrypto, + mut bytes: C::B, + crypto: &C, ) -> Result<(), ReceiveError> { - let mut reader = wire::Reader::new(bytes.as_mut_slice()); + let mut reader = wire::Reader::new(&mut bytes[..]); let header = wire::RecordHeader::decode(&mut reader)?; if header.version != wire::QL_WIRE_VERSION { @@ -130,27 +130,30 @@ pub fn receive( let termination = { let QlFsm { state, events, .. } = fsm; let conn = state.link.connected_mut_or_err()?; - let (decrypt_len, seq) = { - let record = wire::QlSessionRecord::decode(&mut reader)?; - if record.header.connection_id != conn.transport.rx_connection_id { + let (plaintext_range, seq) = { + let (header, auth, ciphertext_start) = + wire::decode_session_record_prefix(&bytes)?; + if header.connection_id != conn.transport.rx_connection_id { return Err(ReceiveError::InvalidPayload); } - let payload = wire::decrypt_record( + let ciphertext_range = ciphertext_start..bytes.len(); + bytes = wire::decrypt_record( crypto, - &record.header, - record.payload, + &header, + bytes, + ciphertext_range.clone(), + &auth, &conn.transport.rx_key, )?; - (payload.len(), record.header.seq) + (ciphertext_range, header.seq) }; - let len = bytes.len(); - let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); - let frames = wire::parse_session_frames(plaintext); + let bytes = Bytes::from_owner(bytes); + let plaintext = bytes.slice(plaintext_range); let mut emit = EventSink::new(events); conn.session - .receive(state.now.instant, seq, frames, &mut emit); + .receive(state.now.instant, seq, plaintext, &mut emit); emit.termination }; @@ -192,9 +195,9 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { .min() } -pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { +pub fn take_next_write(fsm: &mut QlFsm, crypto: &C) -> Option> { if let Some(record) = fsm.state.handshake.take() { - let record = wire::encode_record_vec(ql_wire::RecordType::Handshake, &record); + let record: C::B = wire::encode_record(ql_wire::RecordType::Handshake, &record); return Some(OutboundWrite { record, write_id: None, @@ -204,7 +207,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option(state.now.instant)?; let record = builder.encrypt( crypto, conn.transport.tx_connection_id, diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index c824b7ed..534290f3 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -108,9 +108,9 @@ pub struct WriteId(pub(crate) u64); /// outbound record produced by `QlFsm` #[derive(Debug, Clone, PartialEq, Eq)] -pub struct OutboundWrite { +pub struct OutboundWrite { /// wire bytes to hand to the transport - pub record: Vec, + pub record: B, /// write handle that must be completed exactly once pub write_id: Option, } @@ -264,11 +264,11 @@ impl QlFsm { } /// handles one inbound wire message - pub fn receive( + pub fn receive( &mut self, now: FsmTime, - bytes: Vec, - crypto: &impl QlCrypto, + bytes: C::B, + crypto: &C, ) -> Result<(), ReceiveError> { self.state.now = now; fsm::receive(self, bytes, crypto) @@ -302,11 +302,11 @@ impl QlFsm { /// if `write_id` is `Some`, call `complete_write` exactly once /// /// if it is `None`, the record is fire-and-forget - pub fn take_next_write( + pub fn take_next_write( &mut self, now: FsmTime, - crypto: &impl QlCrypto, - ) -> Option { + crypto: &C, + ) -> Option> { self.state.now = now; fsm::take_next_write(self, crypto) } diff --git a/ql-fsm/src/session/ack_tracker.rs b/ql-fsm/src/session/ack_tracker.rs index 240095fc..a75b5c63 100644 --- a/ql-fsm/src/session/ack_tracker.rs +++ b/ql-fsm/src/session/ack_tracker.rs @@ -184,7 +184,7 @@ mod tests { RecordSeq::from_u64(value).unwrap() } - fn ack_ranges(pending_ack: PendingAck) -> Vec<(u64, u64)> { + fn ack_ranges(pending_ack: &PendingAck) -> Vec<(u64, u64)> { pending_ack .ack .ranges() @@ -203,7 +203,7 @@ mod tests { ack_tracker.schedule_ack(now); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(pending_ack), vec![(10, 12)]); + assert_eq!(ack_ranges(&pending_ack), vec![(10, 12)]); } #[test] @@ -218,7 +218,7 @@ mod tests { ack_tracker.schedule_ack(now + Duration::from_millis(5)); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(pending_ack), vec![(15, 16), (12, 12), (10, 10)]); + assert_eq!(ack_ranges(&pending_ack), vec![(15, 16), (12, 12), (10, 10)]); } #[test] @@ -242,7 +242,7 @@ mod tests { ack_tracker.schedule_ack(now); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(pending_ack), vec![(5, 5), (3, 3)]); + assert_eq!(ack_ranges(&pending_ack), vec![(5, 5), (3, 3)]); } #[test] @@ -256,11 +256,11 @@ mod tests { ack_tracker.schedule_ack(now); let first_ack = ack_tracker.pending_ack(4).unwrap(); - assert_eq!(ack_ranges(first_ack.clone()), vec![(5, 5)]); + assert_eq!(ack_ranges(&first_ack), vec![(5, 5)]); ack_tracker.on_ack_emitted(&first_ack); ack_tracker.retire_acked_ranges(&first_ack.ack); let second_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(second_ack), vec![(3, 3), (1, 1)]); + assert_eq!(ack_ranges(&second_ack), vec![(3, 3), (1, 1)]); } } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 912c7bd4..05e788fc 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -18,9 +18,9 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use indexmap::IndexMap; use ql_wire::{ - CloseTarget, RecordAck, RecordSeq, RouteId, SessionClose, SessionCloseCode, SessionFrame, - SessionRecordBuilder, StreamClose, StreamData, StreamHeader, StreamId, StreamWindow, VarInt, - WireError, + ByteBuf, CloseTarget, RecordAck, RecordSeq, RouteId, SessionClose, SessionCloseCode, + SessionFrame, SessionRecordBuilder, StreamClose, StreamData, StreamHeader, StreamId, + StreamWindow, VarInt, }; use self::{ @@ -103,7 +103,7 @@ impl SessionFsm { pub fn new(mut config: SessionConfig, now: Instant) -> Self { config.record_max_size = config .record_max_size - .max(SessionRecordBuilder::MIN_CAPACITY); + .max(SessionRecordBuilder::>::MIN_CAPACITY); config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); config.accepted_record_window = config.accepted_record_window.max(1); @@ -117,13 +117,13 @@ impl SessionFsm { next_stream_ordinal: 0, next_record_seq: RecordSeq::from_u32(0), next_write_id: 0, - tracked_records: Default::default(), + tracked_records: IndexMap::default(), ack_tracker: AckTracker::new( config.accepted_record_window, config.pending_ack_range_limit, ), pending_ping: false, - streams: Default::default(), + streams: IndexMap::default(), next_stream_index: 0, remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), }, @@ -199,10 +199,13 @@ impl SessionFsm { self.state.phase == SessionPhase::Closed } - pub fn receive(&mut self, now: Instant, seq: RecordSeq, frames: I, sink: &mut impl EventSink) - where - I: IntoIterator, WireError>>, - { + pub fn receive( + &mut self, + now: Instant, + seq: RecordSeq, + bytes: Bytes, + sink: &mut impl EventSink, + ) { if self.state.phase != SessionPhase::Open { return; } @@ -221,6 +224,7 @@ impl SessionFsm { } let mut ack_eliciting = false; + let frames = ql_wire::parse_session_frames(bytes); for frame in frames { let Ok(frame) = frame else { @@ -352,12 +356,15 @@ impl SessionFsm { || !self.state.tracked_records.is_empty() } - pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { + pub fn take_next_write( + &mut self, + now: Instant, + ) -> Option<(Option, SessionRecordBuilder)> { match &self.state.phase { SessionPhase::Terminating(frame) => { let seq = self.state.next_record_seq; next_seq(&mut self.state.next_record_seq); - let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); + let mut builder = SessionRecordBuilder::::new(seq, self.config.record_max_size); match frame { TerminalFrame::Close(close) => { assert!(builder.push_close(close), "builder has capacity"); @@ -376,7 +383,7 @@ impl SessionFsm { } self.collect_timeouts(now); - let (builder, outbound) = self.build_next_record(now)?; + let (builder, outbound) = self.build_next_record::(now)?; let should_track = outbound.ping_included || !outbound.window_updates.is_empty() @@ -391,9 +398,12 @@ impl SessionFsm { Some((write_id, builder)) } - fn build_next_record(&mut self, now: Instant) -> Option<(SessionRecordBuilder, TrackedRecord)> { + fn build_next_record( + &mut self, + now: Instant, + ) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; - let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); + let mut builder = SessionRecordBuilder::::new(seq, self.config.record_max_size); let mut outbound = TrackedRecord { seq, frames: Vec::new(), @@ -445,7 +455,7 @@ impl SessionFsm { fn push_next_pending_stream_close( &mut self, - builder: &mut SessionRecordBuilder, + builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, ) { let len = self.state.streams.len(); @@ -472,7 +482,7 @@ impl SessionFsm { fn push_next_pending_stream_window( &mut self, - builder: &mut SessionRecordBuilder, + builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, ) { let len = self.state.streams.len(); @@ -505,7 +515,7 @@ impl SessionFsm { fn push_next_stream_data( &mut self, - builder: &mut SessionRecordBuilder, + builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, ) { const OVERHEAD: usize = 1 + StreamData::>::MIN_WIRE_SIZE; diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index c84df3a5..ff7735e7 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -2,9 +2,9 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use ql_wire::{ - decode_session_frames, parse_session_frames, CloseTarget, RecordAck, RecordSeq, RouteId, - SessionFrame, SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamHeader, - StreamId, VarInt, XID, + decode_session_frames, CloseTarget, RecordAck, RecordSeq, RouteId, SessionFrame, + SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamHeader, StreamId, + VarInt, WireEncode, XID, }; use super::{SessionConfig, SessionEvent, SessionFsm}; @@ -33,10 +33,10 @@ fn record_ack(seq: RecordSeq) -> RecordAck { const REFUSED: StreamCloseCode = StreamCloseCode(1); const TIMEOUT: StreamCloseCode = StreamCloseCode(2); -fn header(value: u64) -> Option { - Some(StreamHeader { +fn header(value: u64) -> StreamHeader { + StreamHeader { route_id: route_id(value), - }) + } } fn opened(stream_id: StreamId) -> SessionEvent { @@ -79,7 +79,7 @@ fn next_outbound( fsm: &mut SessionFsm, now: Instant, ) -> Option<(RecordSeq, Vec>>)> { - let (write_id, builder) = fsm.take_next_write(now)?; + let (write_id, builder) = fsm.take_next_write::>(now)?; if let Some(write_id) = write_id { fsm.complete_write(now, write_id, true); } @@ -111,18 +111,25 @@ fn receive_events( seq: RecordSeq, record: &[SessionFrame>], ) -> Vec { - let mut builder = SessionRecordBuilder::new(seq, usize::MAX); + let mut builder = SessionRecordBuilder::>::new(seq, usize::MAX); for frame in record { assert!(builder.push_frame(frame)); } let bytes = Bytes::from(builder.bytes().to_vec()); - let frames = parse_session_frames(bytes); let mut events = Vec::new(); let mut emit = |event| events.push(event); - fsm.receive(now, seq, frames, &mut emit); + fsm.receive(now, seq, bytes, &mut emit); events } +fn encode_frames(frames: &[SessionFrame>]) -> Bytes { + let mut out = Vec::with_capacity(frames.iter().map(WireEncode::encoded_len).sum()); + for frame in frames { + frame.encode(&mut out); + } + Bytes::from(out) +} + #[test] fn outbound_record_seq_increments_monotonically() { let now = Instant::now(); @@ -161,7 +168,7 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let now = Instant::now(); let mut fsm = SessionFsm::new( SessionConfig { - record_max_size: 80 + SessionRecordBuilder::MIN_CAPACITY, + record_max_size: 80 + SessionRecordBuilder::>::MIN_CAPACITY, ..SessionConfig::default() }, now, @@ -211,10 +218,11 @@ fn ack_reopens_write_capacity() { let mut events = Vec::new(); let mut emit = |event| events.push(event); + let ack = encode_frames(&[SessionFrame::Ack(record_ack(record_seq))]); fsm.receive( now + Duration::from_millis(1), seq(9), - std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + ack, &mut emit, ); @@ -248,10 +256,11 @@ fn ack_of_fin_emits_outbound_finished_once() { let mut events = Vec::new(); { let mut emit = |event| events.push(event); + let ack = encode_frames(&[SessionFrame::Ack(record_ack(record_seq))]); fsm.receive( now + Duration::from_millis(1), seq(9), - std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + ack, &mut emit, ); } @@ -259,10 +268,11 @@ fn ack_of_fin_emits_outbound_finished_once() { { let mut emit = |event| events.push(event); + let ack = encode_frames(&[SessionFrame::Ack(record_ack(record_seq))]); fsm.receive( now + Duration::from_millis(2), seq(10), - std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + ack, &mut emit, ); } @@ -284,7 +294,7 @@ fn commit_stream_read_is_what_advances_stream_window() { let data = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"hi".to_vec(), })]; @@ -294,7 +304,7 @@ fn commit_stream_read_is_what_advances_stream_window() { vec![opened(stream_id), SessionEvent::Readable(stream_id)] ); - let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let (write_id, builder) = fsm.take_next_write::>(now + Duration::from_millis(1)).unwrap(); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); @@ -333,14 +343,14 @@ fn pure_ack_only_records_are_fire_and_forget() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"hi".to_vec(), })]; let _ = receive_events(&mut fsm, now, seq(7), &record); - let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let (write_id, builder) = fsm.take_next_write::>(now + Duration::from_millis(1)).unwrap(); let ack = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); @@ -351,7 +361,7 @@ fn pure_ack_only_records_are_fire_and_forget() { &mut emit, ); assert!(fsm - .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) + .take_next_write::>(now + retransmit_timeout + Duration::from_millis(1)) .is_none()); } @@ -363,7 +373,7 @@ fn inbound_stream_data_emits_opened_and_readable() { let record = vec![SessionFrame::StreamData(ql_wire::StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: true, bytes: b"hello".to_vec(), })]; @@ -389,7 +399,7 @@ fn inbound_empty_fin_emits_finished_immediately() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: true, bytes: Vec::new(), })]; @@ -411,7 +421,7 @@ fn remote_stream_close_is_reliable_and_retried() { .unwrap() .close(CloseTarget::Both, StreamCloseCode::CANCELLED); - let (write_id, builder) = fsm.take_next_write(now).unwrap(); + let (write_id, builder) = fsm.take_next_write::>(now).unwrap(); fsm.complete_write(now, write_id.expect("stream close should be tracked"), true); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(matches!( @@ -465,7 +475,7 @@ fn duplicate_stream_data_is_not_redelivered() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"hi".to_vec(), })]; @@ -512,7 +522,7 @@ fn late_remote_stream_data_after_close_is_ignored() { let data = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"hello".to_vec(), })]; @@ -546,7 +556,7 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: true, bytes: b"hello".to_vec(), })]; @@ -575,7 +585,7 @@ fn duplicate_finished_remote_data_before_read_is_ignored() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: true, bytes: b"hello".to_vec(), })]; @@ -683,7 +693,7 @@ fn close_does_not_ack_rejected_record_seq() { let invalid = vec![SessionFrame::StreamData(StreamData { stream_id: stream_id(0), offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"bad".to_vec(), })]; @@ -791,7 +801,7 @@ fn sparse_out_of_order_ack_ranges_page_and_quiesce() { let now = Instant::now(); let sender_config = SessionConfig { local_parity: StreamParity::Even, - record_max_size: SessionRecordBuilder::MIN_CAPACITY + 40, + record_max_size: SessionRecordBuilder::>::MIN_CAPACITY + 40, ack_delay: Duration::from_millis(5), retransmit_timeout: Duration::from_millis(25), stream_send_buffer_size: 8 * 1024, @@ -800,7 +810,7 @@ fn sparse_out_of_order_ack_ranges_page_and_quiesce() { }; let receiver_config = SessionConfig { local_parity: StreamParity::Odd, - record_max_size: SessionRecordBuilder::MIN_CAPACITY + 10, + record_max_size: SessionRecordBuilder::>::MIN_CAPACITY + 10, ack_delay: Duration::from_millis(1), retransmit_timeout: Duration::from_millis(25), pending_ack_range_limit: 512, diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index bba4b599..8762004e 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -26,8 +26,8 @@ enum Side { impl Side { fn idx(self) -> usize { match self { - Side::A => 0, - Side::B => 1, + Self::A => 0, + Self::B => 1, } } } @@ -176,7 +176,7 @@ impl Harness { Some(write.record) } - fn next_write(&mut self, side: Side) -> Option { + fn next_write(&mut self, side: Side) -> Option>> { let time = self.time(); let Node { fsm, crypto } = self.node_mut(side); fsm.take_next_write(time, crypto) @@ -231,7 +231,7 @@ impl Harness { .complete_write(time, write_id, false); } - fn decode_session_write(&self, write: OutboundWrite, side: Side) -> DecodedSessionWrite { + fn decode_session_write(&self, write: OutboundWrite>, side: Side) -> DecodedSessionWrite { let peer = self.node(match side { Side::A => Side::B, Side::B => Side::A, @@ -326,21 +326,23 @@ fn session_config(harness: &Harness, a: bool) -> SessionConfig { } fn decrypt_record( - crypto: &impl QlCrypto, + crypto: &impl QlCrypto>, record: &[u8], session_key: &SessionKey, ) -> (ql_wire::SessionHeader, Vec>>) { - let (_header, record) = - ql_wire::decode_record::, _>(record).unwrap(); + let (header, auth, ciphertext_start) = ql_wire::decode_session_record_prefix(record).unwrap(); + let ciphertext_range = ciphertext_start..record.len(); let plaintext = ql_wire::decrypt_record( crypto, - &record.header, - record.payload.into_owned(), + &header, + record.to_vec(), + ciphertext_range.clone(), + &auth, session_key, ) .unwrap(); ( - record.header, - ql_wire::decode_session_frames(&plaintext).unwrap(), + header, + ql_wire::decode_session_frames(&plaintext[ciphertext_range]).unwrap(), ) } diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index c383349c..1cb99527 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -296,17 +296,14 @@ impl Runner { Action::Write { side, slot, bytes } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { let mut chunk = Bytes::copy_from_slice(bytes); - let accepted = if let Ok(mut stream) = - self.harness.node_mut(*side).fsm.stream(stream_id) - { - if let Some(mut writer) = stream.writer() { - writer.write(&mut chunk) - } else { - 0 - } - } else { - 0 - }; + let accepted = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .map_or(0, |mut stream| { + stream.writer().map_or(0, |mut writer| writer.write(&mut chunk)) + }); if accepted != 0 { self.expected[opposite(*side).idx()] .entry(stream_id) @@ -317,18 +314,17 @@ impl Runner { } Action::Finish { side, slot } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { - let finished = if let Ok(mut stream) = - self.harness.node_mut(*side).fsm.stream(stream_id) - { - if let Some(writer) = stream.writer() { - writer.finish(); - true - } else { - false - } - } else { - false - }; + let finished = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .is_ok_and(|mut stream| { + stream.writer().is_some_and(|writer| { + writer.finish(); + true + }) + }); if finished { self.finished_by[side.idx()].insert(stream_id); } @@ -336,14 +332,15 @@ impl Runner { } Action::Close { side, slot } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { - let closed = if let Ok(mut stream) = - self.harness.node_mut(*side).fsm.stream(stream_id) - { - stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); - true - } else { - false - }; + let closed = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .is_ok_and(|mut stream| { + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + true + }); if closed { self.closed_by[side.idx()].insert(stream_id); self.slots[side.idx()][*slot] = None; @@ -863,9 +860,9 @@ fn connected_action_strategy() -> impl Strategy { side_action(Action::DropNext), side_usize_action(queue_index.clone(), Action::deliver_queued), side_usize_action(queue_index.clone(), Action::duplicate_queued), - side_usize_action(queue_index.clone(), Action::drop_queued), + side_usize_action(queue_index, Action::drop_queued), side_usize_action(slot.clone(), Action::open_stream), - side_usize_vec_action(slot.clone(), bytes.clone(), Action::write), + side_usize_vec_action(slot.clone(), bytes, Action::write), side_usize_action(slot.clone(), Action::finish), side_usize_action(slot, Action::close), ] @@ -910,7 +907,7 @@ fn terminal_action_strategy() -> impl Strategy { let queue_index = 0usize..6; prop_oneof![ side_usize_action(slot.clone(), Action::open_stream), - side_usize_vec_action(slot.clone(), bytes.clone(), Action::write), + side_usize_vec_action(slot.clone(), bytes, Action::write), side_usize_action(slot.clone(), Action::finish), side_usize_action(slot, Action::close), side_action(Action::TakeNext), diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 86d0ce85..edd2cf44 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -215,7 +215,7 @@ fn disconnected_stream_operations_fail_with_no_session() { stream.close( ql_wire::CloseTarget::Both, ql_wire::StreamCloseCode::CANCELLED, - ) + ); }), Err(StreamError::NoSession) ); From 9e1a1802f2948b072549c9155581bdc97bfb0e6e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 20 Apr 2026 15:04:23 -0400 Subject: [PATCH 284/304] Revert "ql-fsm: bytebuf" This reverts commit d63058bac7ba1e7a411671964a6be5b3d12affad. --- ql-fsm/src/fsm.rs | 39 ++++++++---------- ql-fsm/src/lib.rs | 16 ++++---- ql-fsm/src/session/ack_tracker.rs | 12 +++--- ql-fsm/src/session/mod.rs | 46 ++++++++------------- ql-fsm/src/session/tests.rs | 68 +++++++++++++------------------ ql-fsm/src/tests/mod.rs | 24 +++++------ ql-fsm/src/tests/proptest.rs | 65 +++++++++++++++-------------- ql-fsm/src/tests/session.rs | 2 +- 8 files changed, 125 insertions(+), 147 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 85d8129e..6a32b821 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -109,12 +109,12 @@ pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), handshake::handle_connect_kk(fsm, crypto) } -pub fn receive( +pub fn receive( fsm: &mut QlFsm, - mut bytes: C::B, - crypto: &C, + mut bytes: Vec, + crypto: &impl QlCrypto, ) -> Result<(), ReceiveError> { - let mut reader = wire::Reader::new(&mut bytes[..]); + let mut reader = wire::Reader::new(bytes.as_mut_slice()); let header = wire::RecordHeader::decode(&mut reader)?; if header.version != wire::QL_WIRE_VERSION { @@ -130,30 +130,27 @@ pub fn receive( let termination = { let QlFsm { state, events, .. } = fsm; let conn = state.link.connected_mut_or_err()?; - let (plaintext_range, seq) = { - let (header, auth, ciphertext_start) = - wire::decode_session_record_prefix(&bytes)?; - if header.connection_id != conn.transport.rx_connection_id { + let (decrypt_len, seq) = { + let record = wire::QlSessionRecord::decode(&mut reader)?; + if record.header.connection_id != conn.transport.rx_connection_id { return Err(ReceiveError::InvalidPayload); } - let ciphertext_range = ciphertext_start..bytes.len(); - bytes = wire::decrypt_record( + let payload = wire::decrypt_record( crypto, - &header, - bytes, - ciphertext_range.clone(), - &auth, + &record.header, + record.payload, &conn.transport.rx_key, )?; - (ciphertext_range, header.seq) + (payload.len(), record.header.seq) }; - let bytes = Bytes::from_owner(bytes); - let plaintext = bytes.slice(plaintext_range); + let len = bytes.len(); + let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); + let frames = wire::parse_session_frames(plaintext); let mut emit = EventSink::new(events); conn.session - .receive(state.now.instant, seq, plaintext, &mut emit); + .receive(state.now.instant, seq, frames, &mut emit); emit.termination }; @@ -195,9 +192,9 @@ pub fn next_deadline(fsm: &QlFsm) -> Option { .min() } -pub fn take_next_write(fsm: &mut QlFsm, crypto: &C) -> Option> { +pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { if let Some(record) = fsm.state.handshake.take() { - let record: C::B = wire::encode_record(ql_wire::RecordType::Handshake, &record); + let record = wire::encode_record_vec(ql_wire::RecordType::Handshake, &record); return Some(OutboundWrite { record, write_id: None, @@ -207,7 +204,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &C) -> Option(state.now.instant)?; + let (write_id, builder) = conn.session.take_next_write(state.now.instant)?; let record = builder.encrypt( crypto, conn.transport.tx_connection_id, diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 534290f3..c824b7ed 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -108,9 +108,9 @@ pub struct WriteId(pub(crate) u64); /// outbound record produced by `QlFsm` #[derive(Debug, Clone, PartialEq, Eq)] -pub struct OutboundWrite { +pub struct OutboundWrite { /// wire bytes to hand to the transport - pub record: B, + pub record: Vec, /// write handle that must be completed exactly once pub write_id: Option, } @@ -264,11 +264,11 @@ impl QlFsm { } /// handles one inbound wire message - pub fn receive( + pub fn receive( &mut self, now: FsmTime, - bytes: C::B, - crypto: &C, + bytes: Vec, + crypto: &impl QlCrypto, ) -> Result<(), ReceiveError> { self.state.now = now; fsm::receive(self, bytes, crypto) @@ -302,11 +302,11 @@ impl QlFsm { /// if `write_id` is `Some`, call `complete_write` exactly once /// /// if it is `None`, the record is fire-and-forget - pub fn take_next_write( + pub fn take_next_write( &mut self, now: FsmTime, - crypto: &C, - ) -> Option> { + crypto: &impl QlCrypto, + ) -> Option { self.state.now = now; fsm::take_next_write(self, crypto) } diff --git a/ql-fsm/src/session/ack_tracker.rs b/ql-fsm/src/session/ack_tracker.rs index a75b5c63..240095fc 100644 --- a/ql-fsm/src/session/ack_tracker.rs +++ b/ql-fsm/src/session/ack_tracker.rs @@ -184,7 +184,7 @@ mod tests { RecordSeq::from_u64(value).unwrap() } - fn ack_ranges(pending_ack: &PendingAck) -> Vec<(u64, u64)> { + fn ack_ranges(pending_ack: PendingAck) -> Vec<(u64, u64)> { pending_ack .ack .ranges() @@ -203,7 +203,7 @@ mod tests { ack_tracker.schedule_ack(now); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(&pending_ack), vec![(10, 12)]); + assert_eq!(ack_ranges(pending_ack), vec![(10, 12)]); } #[test] @@ -218,7 +218,7 @@ mod tests { ack_tracker.schedule_ack(now + Duration::from_millis(5)); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(&pending_ack), vec![(15, 16), (12, 12), (10, 10)]); + assert_eq!(ack_ranges(pending_ack), vec![(15, 16), (12, 12), (10, 10)]); } #[test] @@ -242,7 +242,7 @@ mod tests { ack_tracker.schedule_ack(now); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(&pending_ack), vec![(5, 5), (3, 3)]); + assert_eq!(ack_ranges(pending_ack), vec![(5, 5), (3, 3)]); } #[test] @@ -256,11 +256,11 @@ mod tests { ack_tracker.schedule_ack(now); let first_ack = ack_tracker.pending_ack(4).unwrap(); - assert_eq!(ack_ranges(&first_ack), vec![(5, 5)]); + assert_eq!(ack_ranges(first_ack.clone()), vec![(5, 5)]); ack_tracker.on_ack_emitted(&first_ack); ack_tracker.retire_acked_ranges(&first_ack.ack); let second_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(&second_ack), vec![(3, 3), (1, 1)]); + assert_eq!(ack_ranges(second_ack), vec![(3, 3), (1, 1)]); } } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 05e788fc..912c7bd4 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -18,9 +18,9 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use indexmap::IndexMap; use ql_wire::{ - ByteBuf, CloseTarget, RecordAck, RecordSeq, RouteId, SessionClose, SessionCloseCode, - SessionFrame, SessionRecordBuilder, StreamClose, StreamData, StreamHeader, StreamId, - StreamWindow, VarInt, + CloseTarget, RecordAck, RecordSeq, RouteId, SessionClose, SessionCloseCode, SessionFrame, + SessionRecordBuilder, StreamClose, StreamData, StreamHeader, StreamId, StreamWindow, VarInt, + WireError, }; use self::{ @@ -103,7 +103,7 @@ impl SessionFsm { pub fn new(mut config: SessionConfig, now: Instant) -> Self { config.record_max_size = config .record_max_size - .max(SessionRecordBuilder::>::MIN_CAPACITY); + .max(SessionRecordBuilder::MIN_CAPACITY); config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); config.accepted_record_window = config.accepted_record_window.max(1); @@ -117,13 +117,13 @@ impl SessionFsm { next_stream_ordinal: 0, next_record_seq: RecordSeq::from_u32(0), next_write_id: 0, - tracked_records: IndexMap::default(), + tracked_records: Default::default(), ack_tracker: AckTracker::new( config.accepted_record_window, config.pending_ack_range_limit, ), pending_ping: false, - streams: IndexMap::default(), + streams: Default::default(), next_stream_index: 0, remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), }, @@ -199,13 +199,10 @@ impl SessionFsm { self.state.phase == SessionPhase::Closed } - pub fn receive( - &mut self, - now: Instant, - seq: RecordSeq, - bytes: Bytes, - sink: &mut impl EventSink, - ) { + pub fn receive(&mut self, now: Instant, seq: RecordSeq, frames: I, sink: &mut impl EventSink) + where + I: IntoIterator, WireError>>, + { if self.state.phase != SessionPhase::Open { return; } @@ -224,7 +221,6 @@ impl SessionFsm { } let mut ack_eliciting = false; - let frames = ql_wire::parse_session_frames(bytes); for frame in frames { let Ok(frame) = frame else { @@ -356,15 +352,12 @@ impl SessionFsm { || !self.state.tracked_records.is_empty() } - pub fn take_next_write( - &mut self, - now: Instant, - ) -> Option<(Option, SessionRecordBuilder)> { + pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { match &self.state.phase { SessionPhase::Terminating(frame) => { let seq = self.state.next_record_seq; next_seq(&mut self.state.next_record_seq); - let mut builder = SessionRecordBuilder::::new(seq, self.config.record_max_size); + let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); match frame { TerminalFrame::Close(close) => { assert!(builder.push_close(close), "builder has capacity"); @@ -383,7 +376,7 @@ impl SessionFsm { } self.collect_timeouts(now); - let (builder, outbound) = self.build_next_record::(now)?; + let (builder, outbound) = self.build_next_record(now)?; let should_track = outbound.ping_included || !outbound.window_updates.is_empty() @@ -398,12 +391,9 @@ impl SessionFsm { Some((write_id, builder)) } - fn build_next_record( - &mut self, - now: Instant, - ) -> Option<(SessionRecordBuilder, TrackedRecord)> { + fn build_next_record(&mut self, now: Instant) -> Option<(SessionRecordBuilder, TrackedRecord)> { let seq = self.state.next_record_seq; - let mut builder = SessionRecordBuilder::::new(seq, self.config.record_max_size); + let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); let mut outbound = TrackedRecord { seq, frames: Vec::new(), @@ -455,7 +445,7 @@ impl SessionFsm { fn push_next_pending_stream_close( &mut self, - builder: &mut SessionRecordBuilder, + builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, ) { let len = self.state.streams.len(); @@ -482,7 +472,7 @@ impl SessionFsm { fn push_next_pending_stream_window( &mut self, - builder: &mut SessionRecordBuilder, + builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, ) { let len = self.state.streams.len(); @@ -515,7 +505,7 @@ impl SessionFsm { fn push_next_stream_data( &mut self, - builder: &mut SessionRecordBuilder, + builder: &mut SessionRecordBuilder, outbound: &mut TrackedRecord, ) { const OVERHEAD: usize = 1 + StreamData::>::MIN_WIRE_SIZE; diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index ff7735e7..c84df3a5 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -2,9 +2,9 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use ql_wire::{ - decode_session_frames, CloseTarget, RecordAck, RecordSeq, RouteId, SessionFrame, - SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamHeader, StreamId, - VarInt, WireEncode, XID, + decode_session_frames, parse_session_frames, CloseTarget, RecordAck, RecordSeq, RouteId, + SessionFrame, SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamHeader, + StreamId, VarInt, XID, }; use super::{SessionConfig, SessionEvent, SessionFsm}; @@ -33,10 +33,10 @@ fn record_ack(seq: RecordSeq) -> RecordAck { const REFUSED: StreamCloseCode = StreamCloseCode(1); const TIMEOUT: StreamCloseCode = StreamCloseCode(2); -fn header(value: u64) -> StreamHeader { - StreamHeader { +fn header(value: u64) -> Option { + Some(StreamHeader { route_id: route_id(value), - } + }) } fn opened(stream_id: StreamId) -> SessionEvent { @@ -79,7 +79,7 @@ fn next_outbound( fsm: &mut SessionFsm, now: Instant, ) -> Option<(RecordSeq, Vec>>)> { - let (write_id, builder) = fsm.take_next_write::>(now)?; + let (write_id, builder) = fsm.take_next_write(now)?; if let Some(write_id) = write_id { fsm.complete_write(now, write_id, true); } @@ -111,25 +111,18 @@ fn receive_events( seq: RecordSeq, record: &[SessionFrame>], ) -> Vec { - let mut builder = SessionRecordBuilder::>::new(seq, usize::MAX); + let mut builder = SessionRecordBuilder::new(seq, usize::MAX); for frame in record { assert!(builder.push_frame(frame)); } let bytes = Bytes::from(builder.bytes().to_vec()); + let frames = parse_session_frames(bytes); let mut events = Vec::new(); let mut emit = |event| events.push(event); - fsm.receive(now, seq, bytes, &mut emit); + fsm.receive(now, seq, frames, &mut emit); events } -fn encode_frames(frames: &[SessionFrame>]) -> Bytes { - let mut out = Vec::with_capacity(frames.iter().map(WireEncode::encoded_len).sum()); - for frame in frames { - frame.encode(&mut out); - } - Bytes::from(out) -} - #[test] fn outbound_record_seq_increments_monotonically() { let now = Instant::now(); @@ -168,7 +161,7 @@ fn lost_record_on_one_stream_does_not_block_another_stream() { let now = Instant::now(); let mut fsm = SessionFsm::new( SessionConfig { - record_max_size: 80 + SessionRecordBuilder::>::MIN_CAPACITY, + record_max_size: 80 + SessionRecordBuilder::MIN_CAPACITY, ..SessionConfig::default() }, now, @@ -218,11 +211,10 @@ fn ack_reopens_write_capacity() { let mut events = Vec::new(); let mut emit = |event| events.push(event); - let ack = encode_frames(&[SessionFrame::Ack(record_ack(record_seq))]); fsm.receive( now + Duration::from_millis(1), seq(9), - ack, + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), &mut emit, ); @@ -256,11 +248,10 @@ fn ack_of_fin_emits_outbound_finished_once() { let mut events = Vec::new(); { let mut emit = |event| events.push(event); - let ack = encode_frames(&[SessionFrame::Ack(record_ack(record_seq))]); fsm.receive( now + Duration::from_millis(1), seq(9), - ack, + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), &mut emit, ); } @@ -268,11 +259,10 @@ fn ack_of_fin_emits_outbound_finished_once() { { let mut emit = |event| events.push(event); - let ack = encode_frames(&[SessionFrame::Ack(record_ack(record_seq))]); fsm.receive( now + Duration::from_millis(2), seq(10), - ack, + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), &mut emit, ); } @@ -294,7 +284,7 @@ fn commit_stream_read_is_what_advances_stream_window() { let data = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: Some(header(1)), + header: header(1), fin: false, bytes: b"hi".to_vec(), })]; @@ -304,7 +294,7 @@ fn commit_stream_read_is_what_advances_stream_window() { vec![opened(stream_id), SessionEvent::Readable(stream_id)] ); - let (write_id, builder) = fsm.take_next_write::>(now + Duration::from_millis(1)).unwrap(); + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); @@ -343,14 +333,14 @@ fn pure_ack_only_records_are_fire_and_forget() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: Some(header(1)), + header: header(1), fin: false, bytes: b"hi".to_vec(), })]; let _ = receive_events(&mut fsm, now, seq(7), &record); - let (write_id, builder) = fsm.take_next_write::>(now + Duration::from_millis(1)).unwrap(); + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); let ack = decode_session_frames(builder.bytes()).unwrap(); assert!(write_id.is_none()); assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); @@ -361,7 +351,7 @@ fn pure_ack_only_records_are_fire_and_forget() { &mut emit, ); assert!(fsm - .take_next_write::>(now + retransmit_timeout + Duration::from_millis(1)) + .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) .is_none()); } @@ -373,7 +363,7 @@ fn inbound_stream_data_emits_opened_and_readable() { let record = vec![SessionFrame::StreamData(ql_wire::StreamData { stream_id, offset: offset(0), - header: Some(header(1)), + header: header(1), fin: true, bytes: b"hello".to_vec(), })]; @@ -399,7 +389,7 @@ fn inbound_empty_fin_emits_finished_immediately() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: Some(header(1)), + header: header(1), fin: true, bytes: Vec::new(), })]; @@ -421,7 +411,7 @@ fn remote_stream_close_is_reliable_and_retried() { .unwrap() .close(CloseTarget::Both, StreamCloseCode::CANCELLED); - let (write_id, builder) = fsm.take_next_write::>(now).unwrap(); + let (write_id, builder) = fsm.take_next_write(now).unwrap(); fsm.complete_write(now, write_id.expect("stream close should be tracked"), true); let first = decode_session_frames(builder.bytes()).unwrap(); assert!(matches!( @@ -475,7 +465,7 @@ fn duplicate_stream_data_is_not_redelivered() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: Some(header(1)), + header: header(1), fin: false, bytes: b"hi".to_vec(), })]; @@ -522,7 +512,7 @@ fn late_remote_stream_data_after_close_is_ignored() { let data = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: Some(header(1)), + header: header(1), fin: false, bytes: b"hello".to_vec(), })]; @@ -556,7 +546,7 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: Some(header(1)), + header: header(1), fin: true, bytes: b"hello".to_vec(), })]; @@ -585,7 +575,7 @@ fn duplicate_finished_remote_data_before_read_is_ignored() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: Some(header(1)), + header: header(1), fin: true, bytes: b"hello".to_vec(), })]; @@ -693,7 +683,7 @@ fn close_does_not_ack_rejected_record_seq() { let invalid = vec![SessionFrame::StreamData(StreamData { stream_id: stream_id(0), offset: offset(0), - header: Some(header(1)), + header: header(1), fin: false, bytes: b"bad".to_vec(), })]; @@ -801,7 +791,7 @@ fn sparse_out_of_order_ack_ranges_page_and_quiesce() { let now = Instant::now(); let sender_config = SessionConfig { local_parity: StreamParity::Even, - record_max_size: SessionRecordBuilder::>::MIN_CAPACITY + 40, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 40, ack_delay: Duration::from_millis(5), retransmit_timeout: Duration::from_millis(25), stream_send_buffer_size: 8 * 1024, @@ -810,7 +800,7 @@ fn sparse_out_of_order_ack_ranges_page_and_quiesce() { }; let receiver_config = SessionConfig { local_parity: StreamParity::Odd, - record_max_size: SessionRecordBuilder::>::MIN_CAPACITY + 10, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 10, ack_delay: Duration::from_millis(1), retransmit_timeout: Duration::from_millis(25), pending_ack_range_limit: 512, diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 8762004e..bba4b599 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -26,8 +26,8 @@ enum Side { impl Side { fn idx(self) -> usize { match self { - Self::A => 0, - Self::B => 1, + Side::A => 0, + Side::B => 1, } } } @@ -176,7 +176,7 @@ impl Harness { Some(write.record) } - fn next_write(&mut self, side: Side) -> Option>> { + fn next_write(&mut self, side: Side) -> Option { let time = self.time(); let Node { fsm, crypto } = self.node_mut(side); fsm.take_next_write(time, crypto) @@ -231,7 +231,7 @@ impl Harness { .complete_write(time, write_id, false); } - fn decode_session_write(&self, write: OutboundWrite>, side: Side) -> DecodedSessionWrite { + fn decode_session_write(&self, write: OutboundWrite, side: Side) -> DecodedSessionWrite { let peer = self.node(match side { Side::A => Side::B, Side::B => Side::A, @@ -326,23 +326,21 @@ fn session_config(harness: &Harness, a: bool) -> SessionConfig { } fn decrypt_record( - crypto: &impl QlCrypto>, + crypto: &impl QlCrypto, record: &[u8], session_key: &SessionKey, ) -> (ql_wire::SessionHeader, Vec>>) { - let (header, auth, ciphertext_start) = ql_wire::decode_session_record_prefix(record).unwrap(); - let ciphertext_range = ciphertext_start..record.len(); + let (_header, record) = + ql_wire::decode_record::, _>(record).unwrap(); let plaintext = ql_wire::decrypt_record( crypto, - &header, - record.to_vec(), - ciphertext_range.clone(), - &auth, + &record.header, + record.payload.into_owned(), session_key, ) .unwrap(); ( - header, - ql_wire::decode_session_frames(&plaintext[ciphertext_range]).unwrap(), + record.header, + ql_wire::decode_session_frames(&plaintext).unwrap(), ) } diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index 1cb99527..c383349c 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -296,14 +296,17 @@ impl Runner { Action::Write { side, slot, bytes } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { let mut chunk = Bytes::copy_from_slice(bytes); - let accepted = self - .harness - .node_mut(*side) - .fsm - .stream(stream_id) - .map_or(0, |mut stream| { - stream.writer().map_or(0, |mut writer| writer.write(&mut chunk)) - }); + let accepted = if let Ok(mut stream) = + self.harness.node_mut(*side).fsm.stream(stream_id) + { + if let Some(mut writer) = stream.writer() { + writer.write(&mut chunk) + } else { + 0 + } + } else { + 0 + }; if accepted != 0 { self.expected[opposite(*side).idx()] .entry(stream_id) @@ -314,17 +317,18 @@ impl Runner { } Action::Finish { side, slot } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { - let finished = self - .harness - .node_mut(*side) - .fsm - .stream(stream_id) - .is_ok_and(|mut stream| { - stream.writer().is_some_and(|writer| { - writer.finish(); - true - }) - }); + let finished = if let Ok(mut stream) = + self.harness.node_mut(*side).fsm.stream(stream_id) + { + if let Some(writer) = stream.writer() { + writer.finish(); + true + } else { + false + } + } else { + false + }; if finished { self.finished_by[side.idx()].insert(stream_id); } @@ -332,15 +336,14 @@ impl Runner { } Action::Close { side, slot } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { - let closed = self - .harness - .node_mut(*side) - .fsm - .stream(stream_id) - .is_ok_and(|mut stream| { - stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); - true - }); + let closed = if let Ok(mut stream) = + self.harness.node_mut(*side).fsm.stream(stream_id) + { + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + true + } else { + false + }; if closed { self.closed_by[side.idx()].insert(stream_id); self.slots[side.idx()][*slot] = None; @@ -860,9 +863,9 @@ fn connected_action_strategy() -> impl Strategy { side_action(Action::DropNext), side_usize_action(queue_index.clone(), Action::deliver_queued), side_usize_action(queue_index.clone(), Action::duplicate_queued), - side_usize_action(queue_index, Action::drop_queued), + side_usize_action(queue_index.clone(), Action::drop_queued), side_usize_action(slot.clone(), Action::open_stream), - side_usize_vec_action(slot.clone(), bytes, Action::write), + side_usize_vec_action(slot.clone(), bytes.clone(), Action::write), side_usize_action(slot.clone(), Action::finish), side_usize_action(slot, Action::close), ] @@ -907,7 +910,7 @@ fn terminal_action_strategy() -> impl Strategy { let queue_index = 0usize..6; prop_oneof![ side_usize_action(slot.clone(), Action::open_stream), - side_usize_vec_action(slot.clone(), bytes, Action::write), + side_usize_vec_action(slot.clone(), bytes.clone(), Action::write), side_usize_action(slot.clone(), Action::finish), side_usize_action(slot, Action::close), side_action(Action::TakeNext), diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index edd2cf44..86d0ce85 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -215,7 +215,7 @@ fn disconnected_stream_operations_fail_with_no_session() { stream.close( ql_wire::CloseTarget::Both, ql_wire::StreamCloseCode::CANCELLED, - ); + ) }), Err(StreamError::NoSession) ); From 86b8c19bfbb1fb0910176e399c57194b7ab7a88f Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 20 Apr 2026 15:04:25 -0400 Subject: [PATCH 285/304] Revert "ql-wire: bytebuf" This reverts commit 7567dbb7dbca932a605f5e135421210c86387caf. --- ql-wire/src/bytes.rs | 39 +------------ ql-wire/src/crypto.rs | 14 ++--- ql-wire/src/encrypted/builder.rs | 84 ++++++++++++--------------- ql-wire/src/encrypted/mod.rs | 20 ++----- ql-wire/src/encrypted_message.rs | 97 ++++++++++++++++++++++++++++++++ ql-wire/src/handshake/mod.rs | 45 ++++++--------- ql-wire/src/lib.rs | 2 + ql-wire/src/record.rs | 73 ++++++++++++++++-------- ql-wire/src/testing.rs | 47 ++++++---------- ql-wire/src/tests.rs | 84 +++++++++++---------------- 10 files changed, 262 insertions(+), 243 deletions(-) create mode 100644 ql-wire/src/encrypted_message.rs diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs index cda7ec79..9fecf5ea 100644 --- a/ql-wire/src/bytes.rs +++ b/ql-wire/src/bytes.rs @@ -1,6 +1,6 @@ use core::ops::{Deref, DerefMut}; -use bytes::{Buf, BufMut, Bytes}; +use bytes::{Buf, Bytes}; /// A mutable or immutable byte slice owner used by the wire parser. pub trait ByteSlice: Deref + Sized { @@ -15,43 +15,6 @@ pub trait ByteSliceMut: ByteSlice + DerefMut {} impl ByteSliceMut for B where B: ByteSlice + DerefMut {} -/// An owned growable byte buffer used by outbound encoding and crypto paths. -pub trait ByteBuf: - AsRef<[u8]> - + AsMut<[u8]> - + Deref - + DerefMut - + BufMut - + Send - + Sized - + 'static -{ - fn with_capacity(capacity: usize) -> Self; - fn len(&self) -> usize; - fn capacity(&self) -> usize; - - fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl ByteBuf for Vec { - #[inline] - fn with_capacity(capacity: usize) -> Self { - Self::with_capacity(capacity) - } - - #[inline] - fn len(&self) -> usize { - Self::len(self) - } - - #[inline] - fn capacity(&self) -> usize { - Self::capacity(self) - } -} - impl ByteSlice for &[u8] { #[inline] fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { diff --git a/ql-wire/src/crypto.rs b/ql-wire/src/crypto.rs index 6617ccec..96ace383 100644 --- a/ql-wire/src/crypto.rs +++ b/ql-wire/src/crypto.rs @@ -1,5 +1,5 @@ use crate::{ - ByteBuf, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, SessionKey, + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, }; @@ -12,26 +12,22 @@ pub trait QlHash { } pub trait QlAead { - type B: ByteBuf; - fn aes256_gcm_encrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], - buffer: Self::B, - range: core::ops::Range, - ) -> (Self::B, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]); + buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; fn aes256_gcm_decrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], - buffer: Self::B, - range: core::ops::Range, + buffer: &mut [u8], auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> Option; + ) -> bool; } pub trait QlKem { diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs index d2e8df8b..42933235 100644 --- a/ql-wire/src/encrypted/builder.rs +++ b/ql-wire/src/encrypted/builder.rs @@ -1,18 +1,20 @@ +use bytes::BufMut; + use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; use crate::{ - BufView, ByteBuf, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, - SessionKey, WireEncode, QL_WIRE_VERSION, + BufView, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, SessionKey, + WireEncode, QL_WIRE_VERSION, }; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct SessionRecordBuilder { +pub struct SessionRecordBuilder { seq: RecordSeq, prefix_len: usize, max_capacity: usize, - bytes: Option, + bytes: Vec, } -impl SessionRecordBuilder { +impl SessionRecordBuilder { pub const MIN_CAPACITY: usize = 1 + 1 + ConnectionId::SIZE @@ -27,7 +29,7 @@ impl SessionRecordBuilder { seq, prefix_len, max_capacity, - bytes: None, + bytes: Vec::new(), } } @@ -44,9 +46,7 @@ impl SessionRecordBuilder { } pub fn len(&self) -> usize { - self.bytes - .as_ref() - .map_or(0, |bytes| bytes.len().saturating_sub(self.prefix_len)) + self.bytes.len().saturating_sub(self.prefix_len) } pub fn is_empty(&self) -> bool { @@ -55,14 +55,11 @@ impl SessionRecordBuilder { pub fn remaining_capacity(&self) -> usize { self.max_capacity - .saturating_sub(self.prefix_len.saturating_add(self.len())) + .saturating_sub(self.bytes.len().max(self.prefix_len)) } pub fn bytes(&self) -> &[u8] { - self.bytes - .as_ref() - .and_then(|bytes| bytes.get(self.prefix_len..)) - .unwrap_or_default() + self.bytes.get(self.prefix_len..).unwrap_or_default() } pub fn push_ping(&mut self) -> bool { @@ -77,7 +74,7 @@ impl SessionRecordBuilder { self.push_frame_payload(super::SessionFrameKind::Ack, ack) } - pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { + pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { self.push_frame_payload(super::SessionFrameKind::StreamData, frame) } @@ -93,7 +90,7 @@ impl SessionRecordBuilder { self.push_frame_payload(super::SessionFrameKind::Close, close) } - pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { + pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { match frame { SessionFrame::Ping => self.push_ping(), SessionFrame::Unpair => self.push_unpair(), @@ -105,42 +102,44 @@ impl SessionRecordBuilder { } } - pub fn encrypt>( - self, - crypto: &C, + pub fn encrypt( + mut self, + crypto: &impl QlCrypto, connection_id: ConnectionId, session_key: &SessionKey, - ) -> B { + ) -> Vec { + self.ensure_prefix_capacity(0); let header = SessionHeader { connection_id, seq: self.seq, }; let aad = header.aad(); let nonce = Nonce::from_counter(self.seq.into_inner()); - let prefix_len = self.prefix_len; - let bytes = self.into_bytes(0); - let body_range = prefix_len..bytes.len(); - let (mut bytes, auth) = - crypto.aes256_gcm_encrypt(session_key, &nonce, &aad, bytes, body_range); - - let mut prefix = &mut bytes[..prefix_len]; + let auth = crypto.aes256_gcm_encrypt( + session_key, + &nonce, + &aad, + &mut self.bytes[self.prefix_len..], + ); + + let mut prefix = &mut self.bytes[..self.prefix_len]; prefix[0] = QL_WIRE_VERSION; prefix[1] = RecordType::Session as u8; prefix = &mut prefix[2..]; header.encode(&mut prefix); auth.encode(&mut prefix); debug_assert!(prefix.is_empty()); - bytes + self.bytes } - fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut B)) -> bool { + fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut Vec)) -> bool { if !self.can_push_len(wire_size) { return false; } - let bytes = self.bytes_mut(wire_size); - let start = bytes.len(); - encode(bytes); - debug_assert_eq!(bytes.len(), start + wire_size); + self.ensure_prefix_capacity(wire_size); + let start = self.bytes.len(); + encode(&mut self.bytes); + debug_assert_eq!(self.bytes.len(), start + wire_size); true } @@ -164,21 +163,10 @@ impl SessionRecordBuilder { len <= self.remaining_capacity() } - fn bytes_mut(&mut self, additional_body_len: usize) -> &mut B { - self.ensure_bytes(additional_body_len); - self.bytes.as_mut().unwrap() - } - - fn into_bytes(mut self, additional_body_len: usize) -> B { - self.ensure_bytes(additional_body_len); - self.bytes.take().unwrap() - } - - fn ensure_bytes(&mut self, additional_body_len: usize) { - if self.bytes.is_none() { - let mut bytes = B::with_capacity(self.prefix_len + additional_body_len); - bytes.put_bytes(0, self.prefix_len); - self.bytes = Some(bytes); + fn ensure_prefix_capacity(&mut self, additional_body_len: usize) { + if self.bytes.is_empty() { + self.bytes.reserve(self.prefix_len + additional_body_len); + self.bytes.resize(self.prefix_len, 0); } } } diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs index 0143e831..563f9ded 100644 --- a/ql-wire/src/encrypted/mod.rs +++ b/ql-wire/src/encrypted/mod.rs @@ -1,6 +1,6 @@ use crate::{ - codec, BufView, ByteBuf, ByteSlice, Nonce, QlCrypto, Reader, SessionHeader, SessionKey, - WireDecode, WireEncode, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, + codec, encrypted_message::EncryptedMessage, BufView, ByteSlice, Nonce, QlCrypto, Reader, + SessionHeader, SessionKey, WireDecode, WireEncode, WireError, }; mod ack; @@ -166,21 +166,13 @@ impl Iterator for SessionFrameIter { } } -pub fn decrypt_record( - crypto: &impl QlCrypto, +pub fn decrypt_record>( + crypto: &impl QlCrypto, header: &SessionHeader, - buffer: B, - ciphertext_range: core::ops::Range, - auth: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + encrypted: EncryptedMessage, session_key: &SessionKey, ) -> Result { - assert!( - ciphertext_range.start <= ciphertext_range.end && ciphertext_range.end <= buffer.len(), - "ciphertext valid range", - ); let aad = header.aad(); let nonce = Nonce::from_counter(header.seq.into_inner()); - crypto - .aes256_gcm_decrypt(session_key, &nonce, &aad, buffer, ciphertext_range, auth) - .ok_or(WireError::DecryptFailed) + encrypted.decrypt_in_place(crypto, session_key, &nonce, &aad) } diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs new file mode 100644 index 00000000..9e11d3d0 --- /dev/null +++ b/ql-wire/src/encrypted_message.rs @@ -0,0 +1,97 @@ +use crate::{ + codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, + ENCRYPTED_MESSAGE_AUTH_SIZE, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedMessage { + pub auth: [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + pub ciphertext: B, +} + +impl EncryptedMessage { + pub const AUTH_SIZE: usize = ENCRYPTED_MESSAGE_AUTH_SIZE; + pub const HEADER_LEN: usize = Self::AUTH_SIZE; + + pub fn into_owned(self) -> EncryptedMessage> + where + B: ByteSlice, + { + EncryptedMessage { + auth: self.auth, + ciphertext: self.ciphertext.to_vec(), + } + } +} + +impl WireDecode for EncryptedMessage { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + auth: reader.decode()?, + ciphertext: reader.take_rest(), + }) + } +} + +impl> EncryptedMessage { + pub fn decrypt( + &self, + crypto: &impl QlCrypto, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + ) -> Result, WireError> { + let mut plaintext = self.ciphertext.as_ref().to_vec(); + if !crypto.aes256_gcm_decrypt(key, nonce, aad, &mut plaintext, &self.auth) { + return Err(WireError::DecryptFailed); + } + Ok(plaintext) + } +} + +impl> WireEncode for EncryptedMessage { + fn encoded_len(&self) -> usize { + Self::HEADER_LEN + self.ciphertext.as_ref().len() + } + + fn encode(&self, out: &mut W) { + self.auth.encode(out); + self.ciphertext.as_ref().encode(out); + } +} + +impl> EncryptedMessage { + pub fn decrypt_in_place( + mut self, + crypto: &impl QlCrypto, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + ) -> Result { + let ciphertext = self.ciphertext.as_mut(); + if !crypto.aes256_gcm_decrypt(key, nonce, aad, ciphertext, &self.auth) { + return Err(WireError::DecryptFailed); + } + Ok(self.ciphertext) + } +} + +impl EncryptedMessage> { + pub fn encrypt( + crypto: &impl QlCrypto, + key: &SessionKey, + mut plaintext: Vec, + nonce: &Nonce, + aad: &[u8], + ) -> Self { + let auth = crypto.aes256_gcm_encrypt(key, nonce, aad, &mut plaintext); + Self { + auth, + ciphertext: plaintext, + } + } + + pub fn decode(bytes: &[u8]) -> Result { + Ok(EncryptedMessage::decode_exact(bytes)?.into_owned()) + } +} diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 87eed252..79e0f7ad 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,8 +1,6 @@ -use bytes::BufMut; - use crate::{ - codec, ByteBuf, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, - MlKemPublicKey, Nonce, PeerBundle, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, + codec, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, + Nonce, PeerBundle, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, }; @@ -227,50 +225,43 @@ impl CipherState { self.key.is_some() } - fn encrypt( + fn encrypt( &mut self, - crypto: &C, + crypto: &impl QlCrypto, aad: &[u8], plaintext: &[u8], ) -> Result, WireError> { let key = self.key.as_ref().ok_or(WireError::InvalidState)?; let nonce = Nonce::from_counter(self.nonce); - let mut buffer = C::B::with_capacity(plaintext.len()); - buffer.put_slice(plaintext); - let payload_len = buffer.len(); - let (ciphertext, auth) = - crypto.aes256_gcm_encrypt(key, &nonce, aad, buffer, 0..payload_len); - let mut out = Vec::with_capacity(ciphertext.len() + ENCRYPTED_MESSAGE_AUTH_SIZE); - out.extend_from_slice(&ciphertext); - out.extend_from_slice(&auth); + let mut ciphertext = Vec::with_capacity(plaintext.len() + ENCRYPTED_MESSAGE_AUTH_SIZE); + ciphertext.extend_from_slice(plaintext); + let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut ciphertext); self.nonce = self.nonce.wrapping_add(1); - Ok(out) + ciphertext.extend_from_slice(&auth); + Ok(ciphertext) } - fn decrypt( + fn decrypt( &mut self, - crypto: &C, + crypto: &impl QlCrypto, aad: &[u8], ciphertext: &[u8], ) -> Result, WireError> { if ciphertext.len() < ENCRYPTED_MESSAGE_AUTH_SIZE { return Err(WireError::InvalidPayload); } - let split = ciphertext.len() - ENCRYPTED_MESSAGE_AUTH_SIZE; let (ciphertext, auth) = ciphertext.split_at(split); - let mut auth_tag = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; - auth_tag.copy_from_slice(auth); + let mut plaintext = ciphertext.to_vec(); let key = self.key.as_ref().ok_or(WireError::InvalidState)?; let nonce = Nonce::from_counter(self.nonce); - let mut buffer = C::B::with_capacity(ciphertext.len()); - buffer.put_slice(ciphertext); - let payload_len = buffer.len(); - let plaintext = crypto - .aes256_gcm_decrypt(key, &nonce, aad, buffer, 0..payload_len, &auth_tag) - .ok_or(WireError::DecryptFailed)?; + let mut auth_tag = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + auth_tag.copy_from_slice(auth); + if !crypto.aes256_gcm_decrypt(key, &nonce, aad, &mut plaintext, &auth_tag) { + return Err(WireError::DecryptFailed); + } self.nonce = self.nonce.wrapping_add(1); - Ok(plaintext.to_vec()) + Ok(plaintext) } } diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 821aa32b..63b5e633 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -8,6 +8,7 @@ mod bytes; mod codec; mod crypto; mod encrypted; +mod encrypted_message; mod error; mod handshake; mod header; @@ -24,6 +25,7 @@ pub use bytes::*; pub use codec::*; pub use crypto::*; pub use encrypted::*; +pub use encrypted_message::*; pub use error::*; pub use handshake::*; pub use header::*; diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs index b1d1fe00..163a1bff 100644 --- a/ql-wire/src/record.rs +++ b/ql-wire/src/record.rs @@ -1,21 +1,26 @@ use crate::{ codec, + encrypted_message::EncryptedMessage, handshake::{Ik1, Ik2, Kk1, Kk2, Xx1, Xx2, Xx3, Xx4}, - ByteBuf, ByteSlice, SessionHeader, WireDecode, WireEncode, WireError, - ENCRYPTED_MESSAGE_AUTH_SIZE, QL_WIRE_VERSION, + ByteSlice, SessionHeader, WireDecode, WireEncode, WireError, QL_WIRE_VERSION, }; -pub fn encode_record( - record_type: RecordType, - body: &T, -) -> B { - let mut out = B::with_capacity(RecordHeader::WIRE_SIZE + body.encoded_len()); +pub fn encode_record(out: &mut W, record_type: RecordType, body: &T) +where + W: bytes::BufMut + ?Sized, + T: WireEncode + ?Sized, +{ RecordHeader { version: QL_WIRE_VERSION, record_type, } - .encode(&mut out); - body.encode(&mut out); + .encode(out); + body.encode(out); +} + +pub fn encode_record_vec(record_type: RecordType, body: &T) -> Vec { + let mut out = Vec::with_capacity(RecordHeader::WIRE_SIZE + body.encoded_len()); + encode_record(&mut out, record_type, body); out } @@ -28,21 +33,6 @@ where Ok((reader.decode()?, reader.decode()?)) } -pub fn decode_session_record_prefix( - bytes: &[u8], -) -> Result<(SessionHeader, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], usize), WireError> { - let mut reader = codec::Reader::new(bytes); - let record = reader.decode::()?; - if record.version != QL_WIRE_VERSION || record.record_type != RecordType::Session { - return Err(WireError::InvalidPayload); - } - - let header = reader.decode::()?; - let auth = reader.decode()?; - let ciphertext_start = bytes.len().saturating_sub(reader.remaining_len()); - Ok((header, auth, ciphertext_start)) -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RecordHeader { pub version: u8, @@ -227,3 +217,38 @@ impl WireDecode for QlHandshakeRecord { } } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QlSessionRecord { + pub header: SessionHeader, + pub payload: EncryptedMessage, +} + +impl> WireEncode for QlSessionRecord { + fn encoded_len(&self) -> usize { + self.header.encoded_len() + self.payload.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.payload.encode(out); + } +} + +impl QlSessionRecord { + pub fn into_owned(self) -> QlSessionRecord> { + QlSessionRecord { + header: self.header, + payload: self.payload.into_owned(), + } + } +} + +impl WireDecode for QlSessionRecord { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + payload: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/testing.rs b/ql-wire/src/testing.rs index 11d7c4fc..83b4fbde 100644 --- a/ql-wire/src/testing.rs +++ b/ql-wire/src/testing.rs @@ -38,28 +38,25 @@ impl QlHash for SoftwareCrypto { } impl QlAead for SoftwareCrypto { - type B = Vec; - fn aes256_gcm_encrypt( &self, key: &SessionKey, nonce: &Nonce, aad: &[u8], - mut buffer: Self::B, - range: core::ops::Range, - ) -> (Self::B, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]) { + buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { let key: AesGcm256Key = (*key.data()).into(); - let plaintext = buffer[range.clone()].to_vec(); + let plaintext = buffer.to_vec(); let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; key.encrypt( - &mut buffer[range], + buffer, (&mut auth).into(), (&nonce.0).into(), aad, &plaintext, ) .unwrap(); - (buffer, auth) + auth } fn aes256_gcm_decrypt( @@ -67,21 +64,13 @@ impl QlAead for SoftwareCrypto { key: &SessionKey, nonce: &Nonce, aad: &[u8], - mut buffer: Self::B, - range: core::ops::Range, + buffer: &mut [u8], auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> Option { + ) -> bool { let key: AesGcm256Key = (*key.data()).into(); - let ciphertext = buffer[range.clone()].to_vec(); - key.decrypt( - &mut buffer[range], - (&nonce.0).into(), - aad, - &ciphertext, - auth_tag.into(), - ) - .ok()?; - Some(buffer) + let ciphertext = buffer.to_vec(); + key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) + .is_ok() } } @@ -140,17 +129,14 @@ impl QlHash for NoopCrypto { } impl QlAead for NoopCrypto { - type B = Vec; - fn aes256_gcm_encrypt( &self, _key: &SessionKey, _nonce: &Nonce, _aad: &[u8], - buffer: Self::B, - _range: core::ops::Range, - ) -> (Self::B, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]) { - (buffer, [0; ENCRYPTED_MESSAGE_AUTH_SIZE]) + _buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { + [0; ENCRYPTED_MESSAGE_AUTH_SIZE] } fn aes256_gcm_decrypt( @@ -158,11 +144,10 @@ impl QlAead for NoopCrypto { _key: &SessionKey, _nonce: &Nonce, _aad: &[u8], - _buffer: Self::B, - _range: core::ops::Range, + _buffer: &mut [u8], _auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], - ) -> Option { - None + ) -> bool { + false } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 2a9ab68a..9bbcb9cd 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -6,10 +6,9 @@ fn decode_handshake_record(bytes: &[u8]) -> QlHandshakeRecord { decode_record(bytes).unwrap().1 } -fn decode_session_record_prefix_for_test( - bytes: &[u8], -) -> (SessionHeader, [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], usize) { - decode_session_record_prefix(bytes).unwrap() +fn decode_session_record(bytes: &[u8]) -> QlSessionRecord> { + let (_, record) = decode_record::, _>(bytes).unwrap(); + record.into_owned() } fn xid(byte: u8) -> XID { @@ -67,20 +66,21 @@ fn xx_header(byte: u8) -> XxHeader { } fn encrypt_record( - crypto: &impl QlCrypto>, + crypto: &impl QlCrypto, header: SessionHeader, session_key: &SessionKey, body: &[SessionFrame>], -) -> Vec { - let body_len = body.iter().map(WireEncode::encoded_len).sum::(); - let max_capacity = - RecordHeader::WIRE_SIZE + header.encoded_len() + ENCRYPTED_MESSAGE_AUTH_SIZE + body_len; - let mut builder = SessionRecordBuilder::>::new(header.seq, max_capacity); +) -> QlSessionRecord> { + let mut builder = SessionRecordBuilder::new(header.seq, usize::MAX); for frame in body { let pushed = builder.push_frame(frame); debug_assert!(pushed); } - builder.encrypt(crypto, header.connection_id, session_key) + decode_session_record( + builder + .encrypt(crypto, header.connection_id, session_key) + .as_slice(), + ) } #[test] @@ -107,7 +107,7 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { }, static_bundle: EncryptedPeerBundle::new(Box::new([13; EncryptedPeerBundle::WIRE_SIZE])), }); - let ik_encoded: Vec = encode_record(RecordType::Handshake, &ik); + let ik_encoded = encode_record_vec(RecordType::Handshake, &ik); assert_eq!( RecordHeader::decode_bytes(ik_encoded.as_slice()).unwrap(), RecordHeader { @@ -126,7 +126,7 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { mlkem_public_key: MlKemPublicKey::new(Box::new([15; MlKemPublicKey::SIZE])), }, }); - let kk_encoded: Vec = encode_record(RecordType::Handshake, &kk); + let kk_encoded = encode_record_vec(RecordType::Handshake, &kk); assert_eq!( RecordHeader::decode_bytes(kk_encoded.as_slice()).unwrap(), RecordHeader { @@ -144,7 +144,7 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { mlkem_public_key: MlKemPublicKey::new(Box::new([17; MlKemPublicKey::SIZE])), }, }); - let xx_encoded: Vec = encode_record(RecordType::Handshake, &xx); + let xx_encoded = encode_record_vec(RecordType::Handshake, &xx); assert_eq!( RecordHeader::decode_bytes(xx_encoded.as_slice()).unwrap(), RecordHeader { @@ -686,44 +686,28 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { let session_key = SessionKey::from_data([7; SessionKey::SIZE]); let record = encrypt_record(&crypto, header, &session_key, &body); + let bytes = encode_record_vec(RecordType::Session, &record); assert_eq!( - RecordHeader::decode_bytes(record.as_slice()).unwrap(), + RecordHeader::decode_bytes(bytes.as_slice()).unwrap(), RecordHeader { version: QL_WIRE_VERSION, record_type: RecordType::Session, } ); - let (decoded_header, auth, ciphertext_start) = decode_session_record_prefix_for_test(&record); - assert_eq!(decoded_header, header); - let ciphertext_range = ciphertext_start..record.len(); + let decoded = decode_session_record(bytes.as_slice()); + assert_eq!(decoded.header, header); + let encrypted = decoded.payload; - let decrypted = encrypted::decrypt_record( - &crypto, - &header, - record.clone(), - ciphertext_range.clone(), - &auth, - &session_key, - ) - .unwrap(); - assert_eq!( - decode_session_frames(&decrypted[ciphertext_range.clone()]).unwrap(), - body - ); + let decrypted = + encrypted::decrypt_record(&crypto, &header, encrypted.clone(), &session_key).unwrap(); + assert_eq!(decode_session_frames(&decrypted).unwrap(), body); let wrong_header = SessionHeader { connection_id: ConnectionId::from_data([0x99; ConnectionId::SIZE]), seq: header.seq, }; assert_eq!( - encrypted::decrypt_record( - &crypto, - &wrong_header, - record.clone(), - ciphertext_range.clone(), - &auth, - &session_key, - ), + encrypted::decrypt_record(&crypto, &wrong_header, encrypted.clone(), &session_key), Err(WireError::DecryptFailed) ); @@ -732,14 +716,7 @@ fn encrypted_session_record_round_trip_uses_connection_id_header() { seq: record_seq(header.seq.into_inner() + 1), }; assert_eq!( - encrypted::decrypt_record( - &crypto, - &wrong_seq_header, - record, - ciphertext_range, - &auth, - &session_key, - ), + encrypted::decrypt_record(&crypto, &wrong_seq_header, encrypted, &session_key), Err(WireError::DecryptFailed) ); } @@ -923,9 +900,12 @@ fn protocol_record_size_breakdown() { print_size("ql-wire pq xx2", xx2.encode_vec().len()); print_size("ql-wire pq xx3", xx3.encode_vec().len()); print_size("ql-wire pq xx4", xx4.encode_vec().len()); - print_size("ql-wire session ping", session_ping.len()); - print_size("ql-wire session ack", session_ack.len()); - print_size("ql-wire session unpair", session_unpair.len()); - print_size("ql-wire session stream empty", session_stream_empty.len()); - print_size("ql-wire session close", session_close.len()); + print_size("ql-wire session ping", session_ping.encode_vec().len()); + print_size("ql-wire session ack", session_ack.encode_vec().len()); + print_size("ql-wire session unpair", session_unpair.encode_vec().len()); + print_size( + "ql-wire session stream empty", + session_stream_empty.encode_vec().len(), + ); + print_size("ql-wire session close", session_close.encode_vec().len()); } From 307054c27b6db6378b07590422496852714279e3 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 21 Apr 2026 08:39:30 -0400 Subject: [PATCH 286/304] remove expiration --- ql-wire/src/handshake/ik.rs | 4 -- ql-wire/src/handshake/kk.rs | 4 -- ql-wire/src/handshake/meta.rs | 10 +--- ql-wire/src/handshake/xx.rs | 8 --- ql-wire/src/tests.rs | 95 ++++++++++++----------------------- 5 files changed, 33 insertions(+), 88 deletions(-) diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 10ba0843..460b7ab7 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -286,13 +286,11 @@ impl IkHandshake { pub fn read_1( &mut self, crypto: &impl QlCrypto, - now_seconds: u64, message: &Ik1, ) -> Result<(), WireError> { if self.step != IkStep::Recv1 { return Err(WireError::InvalidState); } - message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; self.ensure_inbound_recipient(message.header)?; self.ensure_known_remote_sender(message.header)?; @@ -334,13 +332,11 @@ impl IkHandshake { pub fn read_2( &mut self, crypto: &impl QlCrypto, - now_seconds: u64, message: &Ik2, ) -> Result<(), WireError> { if self.step != IkStep::Recv2 { return Err(WireError::InvalidState); } - message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; self.ensure_inbound_recipient(message.header)?; self.ensure_known_remote_sender(message.header)?; diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index 9cc17ba4..a08e8056 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -274,13 +274,11 @@ impl KkHandshake { pub fn read_1( &mut self, crypto: &impl QlCrypto, - now_seconds: u64, message: &Kk1, ) -> Result<(), WireError> { if self.step != KkStep::Recv1 { return Err(WireError::InvalidState); } - message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; self.ensure_inbound_header(message.header)?; mix_hash_routed_handshake( @@ -308,13 +306,11 @@ impl KkHandshake { pub fn read_2( &mut self, crypto: &impl QlCrypto, - now_seconds: u64, message: &Kk2, ) -> Result<(), WireError> { if self.step != KkStep::Recv2 { return Err(WireError::InvalidState); } - message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; self.ensure_inbound_header(message.header)?; mix_hash_routed_handshake( diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs index 4987b03c..8cb0cf97 100644 --- a/ql-wire/src/handshake/meta.rs +++ b/ql-wire/src/handshake/meta.rs @@ -7,7 +7,6 @@ pub struct HandshakeId(pub u32); #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct HandshakeMeta { pub handshake_id: HandshakeId, - pub valid_until: u64, } impl codec::WireDecode for HandshakeId { @@ -27,12 +26,7 @@ impl WireEncode for HandshakeId { } impl HandshakeMeta { - pub const WIRE_SIZE: usize = size_of::() + size_of::(); - - // TODO: re-think expiration - pub fn ensure_not_expired(&self, _now_seconds: u64) -> Result<(), WireError> { - Ok(()) - } + pub const WIRE_SIZE: usize = size_of::(); } impl WireEncode for HandshakeMeta { @@ -42,7 +36,6 @@ impl WireEncode for HandshakeMeta { fn encode(&self, out: &mut W) { self.handshake_id.encode(out); - self.valid_until.encode(out); } } @@ -50,7 +43,6 @@ impl codec::WireDecode for HandshakeMeta { fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { handshake_id: reader.decode()?, - valid_until: reader.decode()?, }) } } diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index 7dd35943..c4e7294f 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -317,13 +317,11 @@ impl XxHandshake { pub fn read_1( &mut self, crypto: &impl QlCrypto, - now_seconds: u64, message: &Xx1, ) -> Result<(), WireError> { if self.step != XxStep::Recv1 { return Err(WireError::InvalidState); } - message.meta.ensure_not_expired(now_seconds)?; initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; self.ensure_inbound_header(crypto, message.header)?; mix_hash_pairing_handshake( @@ -386,13 +384,11 @@ impl XxHandshake { pub fn read_2( &mut self, crypto: &impl QlCrypto, - now_seconds: u64, message: &Xx2, ) -> Result<(), WireError> { if self.step != XxStep::Recv2 { return Err(WireError::InvalidState); } - message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; self.ensure_inbound_header(crypto, message.header)?; mix_hash_pairing_handshake( @@ -464,13 +460,11 @@ impl XxHandshake { pub fn read_3( &mut self, crypto: &impl QlCrypto, - now_seconds: u64, message: &Xx3, ) -> Result<(), WireError> { if self.step != XxStep::Recv3 { return Err(WireError::InvalidState); } - message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; self.ensure_inbound_header(crypto, message.header)?; require_transport_params( @@ -538,13 +532,11 @@ impl XxHandshake { pub fn read_4( &mut self, crypto: &impl QlCrypto, - now_seconds: u64, message: &Xx4, ) -> Result<(), WireError> { if self.step != XxStep::Recv4 { return Err(WireError::InvalidState); } - message.meta.ensure_not_expired(now_seconds)?; require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; self.ensure_inbound_header(crypto, message.header)?; require_transport_params( diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 9bbcb9cd..36d80d59 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -34,7 +34,6 @@ fn stream_id(value: u64) -> StreamId { fn handshake_meta(id: u32) -> HandshakeMeta { HandshakeMeta { handshake_id: HandshakeId(id), - valid_until: 10_000 + u64::from(id), } } @@ -172,7 +171,7 @@ fn ik_handshake_rejects_tampered_handshake_meta() { let m1 = initiator_state .write_1(&crypto, handshake_meta(77)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let mut m2 = responder_state .write_2(&crypto, handshake_meta(77)) @@ -180,7 +179,7 @@ fn ik_handshake_rejects_tampered_handshake_meta() { m2.meta.handshake_id = HandshakeId(78); assert_eq!( - initiator_state.read_2(&crypto, 0, &m2), + initiator_state.read_2(&crypto, &m2), Err(WireError::InvalidPayload) ); } @@ -206,7 +205,7 @@ fn kk_handshake_rejects_tampered_handshake_header() { let m1 = initiator_state .write_1(&crypto, handshake_meta(88)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let mut m2 = responder_state .write_2(&crypto, handshake_meta(88)) @@ -214,7 +213,7 @@ fn kk_handshake_rejects_tampered_handshake_header() { m2.header = handshake_header(9, 1); assert_eq!( - initiator_state.read_2(&crypto, 0, &m2), + initiator_state.read_2(&crypto, &m2), Err(WireError::InvalidPayload) ); } @@ -236,7 +235,7 @@ fn ik_handshake_rejects_tampered_transport_params() { let m1 = initiator_state .write_1(&crypto, handshake_meta(89)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let mut m2 = responder_state .write_2(&crypto, handshake_meta(89)) @@ -244,7 +243,7 @@ fn ik_handshake_rejects_tampered_transport_params() { m2.transport_params.initial_stream_receive_window += 1; assert_eq!( - initiator_state.read_2(&crypto, 0, &m2), + initiator_state.read_2(&crypto, &m2), Err(WireError::DecryptFailed) ); } @@ -269,7 +268,7 @@ fn ik_handshake_rejects_tampered_handshake_header() { m1.header.sender = xid(9); assert_eq!( - responder_state.read_1(&crypto, 0, &m1), + responder_state.read_1(&crypto, &m1), Err(WireError::DecryptFailed) ); } @@ -298,41 +297,11 @@ fn ik_handshake_rejects_bound_remote_bundle_mismatch() { .unwrap(); assert_eq!( - responder_state.read_1(&crypto, 0, &m1), + responder_state.read_1(&crypto, &m1), Err(WireError::InvalidPayload) ); } -#[test] -fn ik_handshake_rejects_expired_message() { - let crypto = SoftwareCrypto; - let (initiator, responder) = test_identities(&crypto); - - let mut initiator_state = IkHandshake::new_initiator( - &crypto, - initiator, - responder.bundle(), - TransportParams::default(), - ); - let mut responder_state = - IkHandshake::new_responder(&crypto, responder, None, TransportParams::default()); - - let m1 = initiator_state - .write_1( - &crypto, - HandshakeMeta { - handshake_id: HandshakeId(92), - valid_until: 5, - }, - ) - .unwrap(); - - assert_eq!( - responder_state.read_1(&crypto, 6, &m1), - Err(WireError::Expired) - ); -} - #[test] fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { let crypto = SoftwareCrypto; @@ -352,12 +321,12 @@ fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { let m1 = initiator_state .write_1(&crypto, handshake_meta(11)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let m2 = responder_state .write_2(&crypto, handshake_meta(11)) .unwrap(); - initiator_state.read_2(&crypto, 0, &m2).unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -405,12 +374,12 @@ fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { let m1 = initiator_state .write_1(&crypto, handshake_meta(12)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let m2 = responder_state .write_2(&crypto, handshake_meta(12)) .unwrap(); - initiator_state.read_2(&crypto, 0, &m2).unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -458,12 +427,12 @@ fn kk_handshake_round_trip_derives_matching_transport() { let m1 = initiator_state .write_1(&crypto, handshake_meta(21)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let m2 = responder_state .write_2(&crypto, handshake_meta(21)) .unwrap(); - initiator_state.read_2(&crypto, 0, &m2).unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -509,7 +478,7 @@ fn kk_handshake_rejects_tampered_transport_params() { let m1 = initiator_state .write_1(&crypto, handshake_meta(22)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let mut m2 = responder_state .write_2(&crypto, handshake_meta(22)) @@ -517,7 +486,7 @@ fn kk_handshake_rejects_tampered_transport_params() { m2.transport_params.initial_stream_receive_window += 1; assert_eq!( - initiator_state.read_2(&crypto, 0, &m2), + initiator_state.read_2(&crypto, &m2), Err(WireError::DecryptFailed) ); } @@ -539,7 +508,7 @@ fn xx_handshake_rejects_tampered_pairing_id() { m1.header.pairing_id = pairing_id(8); assert_eq!( - responder_state.read_1(&crypto, 0, &m1), + responder_state.read_1(&crypto, &m1), Err(WireError::InvalidPayload) ); } @@ -566,12 +535,12 @@ fn xx_handshake_rejects_repeated_transport_param_change() { let m1 = initiator_state .write_1(&crypto, handshake_meta(32)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let m2 = responder_state .write_2(&crypto, handshake_meta(32)) .unwrap(); - initiator_state.read_2(&crypto, 0, &m2).unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); let mut m3 = initiator_state .write_3(&crypto, handshake_meta(32)) @@ -579,7 +548,7 @@ fn xx_handshake_rejects_repeated_transport_param_change() { m3.transport_params.initial_stream_receive_window += 1; assert_eq!( - responder_state.read_3(&crypto, 0, &m3), + responder_state.read_3(&crypto, &m3), Err(WireError::InvalidPayload) ); } @@ -607,25 +576,25 @@ fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { let m1 = initiator_state .write_1(&crypto, handshake_meta(33)) .unwrap(); - responder_state.read_1(&crypto, 0, &m1).unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); let m2 = responder_state .write_2(&crypto, handshake_meta(33)) .unwrap(); - initiator_state.read_2(&crypto, 0, &m2).unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); assert_eq!(initiator_state.remote_bundle(), Some(&responder.bundle())); assert!(responder_state.remote_bundle().is_none()); let m3 = initiator_state .write_3(&crypto, handshake_meta(33)) .unwrap(); - responder_state.read_3(&crypto, 0, &m3).unwrap(); + responder_state.read_3(&crypto, &m3).unwrap(); assert_eq!(responder_state.remote_bundle(), Some(&initiator.bundle())); let m4 = responder_state .write_4(&crypto, handshake_meta(33)) .unwrap(); - initiator_state.read_4(&crypto, 0, &m4).unwrap(); + initiator_state.read_4(&crypto, &m4).unwrap(); let initiator_final = initiator_state.finalize(&crypto).unwrap(); let responder_final = responder_state.finalize(&crypto).unwrap(); @@ -771,10 +740,10 @@ fn protocol_record_size_breakdown() { IkHandshake::new_responder(&crypto, responder.clone(), None, TransportParams::default()); let ik1 = ik_initiator.write_1(&crypto, handshake_meta(101)).unwrap(); - ik_responder.read_1(&crypto, 0, &ik1).unwrap(); + ik_responder.read_1(&crypto, &ik1).unwrap(); let ik2 = ik_responder.write_2(&crypto, handshake_meta(101)).unwrap(); - ik_initiator.read_2(&crypto, 0, &ik2).unwrap(); + ik_initiator.read_2(&crypto, &ik2).unwrap(); let ik1 = QlHandshakeRecord::Ik1(ik1); let ik2 = QlHandshakeRecord::Ik2(ik2); @@ -793,10 +762,10 @@ fn protocol_record_size_breakdown() { ); let kk1 = kk_initiator.write_1(&crypto, handshake_meta(201)).unwrap(); - kk_responder.read_1(&crypto, 0, &kk1).unwrap(); + kk_responder.read_1(&crypto, &kk1).unwrap(); let kk2 = kk_responder.write_2(&crypto, handshake_meta(201)).unwrap(); - kk_initiator.read_2(&crypto, 0, &kk2).unwrap(); + kk_initiator.read_2(&crypto, &kk2).unwrap(); let kk1 = QlHandshakeRecord::Kk1(kk1); let kk2 = QlHandshakeRecord::Kk2(kk2); @@ -816,16 +785,16 @@ fn protocol_record_size_breakdown() { ); let xx1 = xx_initiator.write_1(&crypto, handshake_meta(301)).unwrap(); - xx_responder.read_1(&crypto, 0, &xx1).unwrap(); + xx_responder.read_1(&crypto, &xx1).unwrap(); let xx2 = xx_responder.write_2(&crypto, handshake_meta(301)).unwrap(); - xx_initiator.read_2(&crypto, 0, &xx2).unwrap(); + xx_initiator.read_2(&crypto, &xx2).unwrap(); let xx3 = xx_initiator.write_3(&crypto, handshake_meta(301)).unwrap(); - xx_responder.read_3(&crypto, 0, &xx3).unwrap(); + xx_responder.read_3(&crypto, &xx3).unwrap(); let xx4 = xx_responder.write_4(&crypto, handshake_meta(301)).unwrap(); - xx_initiator.read_4(&crypto, 0, &xx4).unwrap(); + xx_initiator.read_4(&crypto, &xx4).unwrap(); let xx1 = QlHandshakeRecord::Xx1(xx1); let xx2 = QlHandshakeRecord::Xx2(xx2); From 0e0197adecf3a65c75da3e871de9422722a4d3de Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 21 Apr 2026 08:50:58 -0400 Subject: [PATCH 287/304] ql-fsm: remove expiration and fsmtime --- ql-fsm/src/error.rs | 2 -- ql-fsm/src/fsm.rs | 20 +++++--------------- ql-fsm/src/handshake/ik.rs | 12 ++++-------- ql-fsm/src/handshake/kk.rs | 12 ++++-------- ql-fsm/src/handshake/mod.rs | 17 ++++------------- ql-fsm/src/handshake/xx.rs | 18 +++++++----------- ql-fsm/src/lib.rs | 28 ++++++++-------------------- ql-fsm/src/replay_cache.rs | 23 ----------------------- ql-fsm/src/state.rs | 5 ++--- ql-fsm/src/tests/mod.rs | 20 +++++--------------- 10 files changed, 39 insertions(+), 118 deletions(-) delete mode 100644 ql-fsm/src/replay_cache.rs diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 3ded1ffc..99b829d6 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -18,7 +18,6 @@ pub enum ReceiveError { expected: PairingId, actual: PairingId, }, - Replay, } impl Display for ReceiveError { @@ -37,7 +36,6 @@ impl Display for ReceiveError { "invalid pairing id: expected {expected}, actual {actual}" ) } - Self::Replay => f.write_str("replay"), } } } diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index 6a32b821..cc342b53 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -1,6 +1,6 @@ use std::{ collections::VecDeque, - time::{Duration, Instant}, + time::Instant, }; use bytes::Bytes; @@ -150,7 +150,7 @@ pub fn receive( let mut emit = EventSink::new(events); conn.session - .receive(state.now.instant, seq, frames, &mut emit); + .receive(state.now, seq, frames, &mut emit); emit.termination }; @@ -176,7 +176,7 @@ pub fn on_timer(fsm: &mut QlFsm) { }; let mut emit = EventSink::new(events); - conn.session.on_timer(state.now.instant, &mut emit); + conn.session.on_timer(state.now, &mut emit); } pub fn next_deadline(fsm: &QlFsm) -> Option { @@ -204,7 +204,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Option { pub fn emit_peer_status(fsm: &mut QlFsm, status: crate::PeerStatus) { fsm.events.push_back(Event::PeerStatusChanged(status)); } - -pub fn deadline_after_secs(now_secs: u64, duration: Duration) -> u64 { - now_secs.saturating_add(duration_to_secs(duration)) -} - -fn duration_to_secs(duration: Duration) -> u64 { - duration - .as_secs() - .saturating_add(u64::from(duration.subsec_nanos() > 0)) -} diff --git a/ql-fsm/src/handshake/ik.rs b/ql-fsm/src/handshake/ik.rs index 0566335c..45456ad2 100644 --- a/ql-fsm/src/handshake/ik.rs +++ b/ql-fsm/src/handshake/ik.rs @@ -1,8 +1,7 @@ use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ - emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, - reset_connected_session_if_needed, + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, }; use crate::{ state::{IkInitiatorState, LinkState, SessionTransport}, @@ -23,7 +22,7 @@ pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle handshake_id: meta.handshake_id, initial_ephemeral: message.ephemeral.clone(), handshake, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + deadline: fsm.state.now + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); emit_peer_status(fsm, fsm.state.link.status()); @@ -37,9 +36,6 @@ pub fn handle_ik1( if should_ignore_inbound(fsm, message) { return Ok(()); } - if is_replayed_handshake_start(fsm, message.meta) { - return Ok(()); - } if message.header.recipient != fsm.identity.xid { return Err(ReceiveError::InvalidXid); } @@ -57,7 +53,7 @@ pub fn handle_ik1( fsm.state.peer.clone(), super::local_transport_params(fsm), ); - handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; + handshake.read_1(crypto, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); finish_handshake(fsm, transport, remote_bundle)?; @@ -82,7 +78,7 @@ pub fn handle_ik2( state .handshake - .read_2(crypto, fsm.state.now.unix_secs, message)?; + .read_2(crypto, message)?; } let LinkState::IkInitiator(state) = fsm.state.link.take() else { diff --git a/ql-fsm/src/handshake/kk.rs b/ql-fsm/src/handshake/kk.rs index b7192c5b..7ad035a5 100644 --- a/ql-fsm/src/handshake/kk.rs +++ b/ql-fsm/src/handshake/kk.rs @@ -1,8 +1,7 @@ use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord}; use super::{ - emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, - reset_connected_session_if_needed, + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, }; use crate::{ state::{KkInitiatorState, LinkState, SessionTransport}, @@ -23,7 +22,7 @@ pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle handshake_id: meta.handshake_id, initial_ephemeral: message.ephemeral.clone(), handshake, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + deadline: fsm.state.now + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); emit_peer_status(fsm, fsm.state.link.status()); @@ -37,9 +36,6 @@ pub fn handle_kk1( if should_ignore_inbound(fsm, message) { return Ok(()); } - if is_replayed_handshake_start(fsm, message.meta) { - return Ok(()); - } let Some(peer) = fsm.state.peer.clone() else { return Err(ReceiveError::InvalidPayload); @@ -56,7 +52,7 @@ pub fn handle_kk1( peer, super::local_transport_params(fsm), ); - handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; + handshake.read_1(crypto, message)?; let outbound = handshake.write_2(crypto, message.meta)?; let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); finish_handshake(fsm, transport, remote_bundle)?; @@ -81,7 +77,7 @@ pub fn handle_kk2( state .handshake - .read_2(crypto, fsm.state.now.unix_secs, message)?; + .read_2(crypto, message)?; } let LinkState::KkInitiator(state) = fsm.state.link.take() else { diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index 8aee523f..34bc0648 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -7,7 +7,7 @@ use ql_wire::{ }; use crate::{ - fsm::{deadline_after_secs, emit_peer_status}, + fsm::emit_peer_status, session::{SessionConfig, SessionFsm, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, Event, NoPeerError, QlFsm, ReceiveError, @@ -35,10 +35,7 @@ pub fn handle_connect_xx(fsm: &mut QlFsm, token: PairingToken, crypto: &impl QlC pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { let handshake_id = wire::HandshakeId(fsm.state.next_control_id); fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); - HandshakeMeta { - handshake_id, - valid_until: deadline_after_secs(fsm.state.now.unix_secs, fsm.config.handshake_timeout), - } + HandshakeMeta { handshake_id } } pub fn enqueue_handshake(fsm: &mut QlFsm, record: QlHandshakeRecord) { @@ -61,12 +58,6 @@ pub fn prepare_for_outbound_connect(fsm: &mut QlFsm) { reset_connected_session_if_needed(fsm); } -pub fn is_replayed_handshake_start(fsm: &mut QlFsm, meta: HandshakeMeta) -> bool { - fsm.state - .replay_cache - .check_and_store_valid_until(meta, fsm.state.now.unix_secs) -} - pub fn handle_handshake_record( fsm: &mut QlFsm, crypto: &impl QlCrypto, @@ -88,7 +79,7 @@ pub fn handle_timer(fsm: &mut QlFsm) { let Some(deadline) = fsm.state.link.handshake_deadline() else { return; }; - if deadline > fsm.state.now.instant { + if deadline > fsm.state.now { return; } @@ -133,7 +124,7 @@ pub fn finish_handshake( .remote_transport_params .initial_stream_receive_window, }, - fsm.state.now.instant, + fsm.state.now, ); fsm.state.link = LinkState::Connected(ConnectedState { transport, session }); emit_peer_status(fsm, fsm.state.link.status()); diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index a03594ed..1219507b 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -1,8 +1,7 @@ use ql_wire::{self as wire, PairingToken, QlCrypto, QlHandshakeRecord, Xx1, Xx2, Xx3, Xx4}; use super::{ - emit_peer_status, enqueue_handshake, finish_handshake, is_replayed_handshake_start, - reset_connected_session_if_needed, + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, }; use crate::{ state::{LinkState, SessionTransport, XxInitiatorState, XxResponderState}, @@ -23,7 +22,7 @@ pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, token: PairingTo handshake_id: meta.handshake_id, initial_ephemeral: message.ephemeral.clone(), handshake, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + deadline: fsm.state.now + fsm.config.handshake_timeout, }); enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); emit_peer_status(fsm, fsm.state.link.status()); @@ -37,9 +36,6 @@ pub fn handle_xx1( if should_ignore_inbound(fsm, crypto, message) { return Ok(()); } - if is_replayed_handshake_start(fsm, message.meta) { - return Err(ReceiveError::Replay); - } match fsm.state.armed_pairing_token { Some(expected) if expected.id(crypto) != message.header.pairing_id => { Err(ReceiveError::InvalidPairingId { @@ -56,12 +52,12 @@ pub fn handle_xx1( token, super::local_transport_params(fsm), ); - handshake.read_1(crypto, fsm.state.now.unix_secs, message)?; + handshake.read_1(crypto, message)?; let outbound = handshake.write_2(crypto, message.meta)?; fsm.state.link = LinkState::XxResponder(XxResponderState { handshake, handshake_meta: message.meta, - deadline: fsm.state.now.instant + fsm.config.handshake_timeout, + deadline: fsm.state.now + fsm.config.handshake_timeout, }); fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Xx2(outbound)); @@ -87,7 +83,7 @@ pub fn handle_xx2( state .handshake - .read_2(crypto, fsm.state.now.unix_secs, message)?; + .read_2(crypto, message)?; let outbound = state.handshake.write_3(crypto, message.meta)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Xx3(outbound)); @@ -111,7 +107,7 @@ pub fn handle_xx3( state .handshake - .read_3(crypto, fsm.state.now.unix_secs, message)?; + .read_3(crypto, message)?; let handshake_meta = state.handshake_meta; let LinkState::XxResponder(mut state) = fsm.state.link.take() else { unreachable!("active XX responder was checked above"); @@ -140,7 +136,7 @@ pub fn handle_xx4( state .handshake - .read_4(crypto, fsm.state.now.unix_secs, message)?; + .read_4(crypto, message)?; } let LinkState::XxInitiator(state) = fsm.state.link.take() else { diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index c824b7ed..502fea0f 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -21,7 +21,6 @@ mod error; mod fsm; mod handshake; -pub(crate) mod replay_cache; mod session; pub(crate) mod state; #[cfg(test)] @@ -41,19 +40,9 @@ use ql_wire::{ pub use session::{SessionEvent, StreamReadIter, StreamWriter}; use crate::{ - replay_cache::ReplayCache, state::{LinkState, QlFsmState}, }; -/// time input for `QlFsm` -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct FsmTime { - /// monotonic time used for local deadlines - pub instant: Instant, - /// wall-clock unix time used for expiration checks - pub unix_secs: u64, -} - /// connection state for the bound peer #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PeerStatus { @@ -204,12 +193,11 @@ pub struct QlFsm { impl QlFsm { /// creates a new `QlFsm` - pub fn new(config: QlFsmConfig, identity: QlIdentity, now: FsmTime) -> Self { + pub fn new(config: QlFsmConfig, identity: QlIdentity, now: Instant) -> Self { Self { config, identity, state: QlFsmState { - replay_cache: ReplayCache::default(), next_control_id: 1, peer: None, armed_pairing_token: None, @@ -246,19 +234,19 @@ impl QlFsm { } /// starts an outbound xx handshake using the supplied pairing token - pub fn connect_xx(&mut self, now: FsmTime, token: PairingToken, crypto: &impl QlCrypto) { + pub fn connect_xx(&mut self, now: Instant, token: PairingToken, crypto: &impl QlCrypto) { self.state.now = now; fsm::handle_connect_xx(self, token, crypto); } /// starts an IK handshake with the currently bound peer - pub fn connect_ik(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + pub fn connect_ik(&mut self, now: Instant, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { self.state.now = now; fsm::handle_connect_ik(self, crypto) } /// starts a KK handshake with the currently bound peer - pub fn connect_kk(&mut self, now: FsmTime, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + pub fn connect_kk(&mut self, now: Instant, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { self.state.now = now; fsm::handle_connect_kk(self, crypto) } @@ -266,7 +254,7 @@ impl QlFsm { /// handles one inbound wire message pub fn receive( &mut self, - now: FsmTime, + now: Instant, bytes: Vec, crypto: &impl QlCrypto, ) -> Result<(), ReceiveError> { @@ -280,7 +268,7 @@ impl QlFsm { } /// advances time-based state - pub fn on_timer(&mut self, now: FsmTime) { + pub fn on_timer(&mut self, now: Instant) { self.state.now = now; fsm::on_timer(self); } @@ -304,7 +292,7 @@ impl QlFsm { /// if it is `None`, the record is fire-and-forget pub fn take_next_write( &mut self, - now: FsmTime, + now: Instant, crypto: &impl QlCrypto, ) -> Option { self.state.now = now; @@ -314,7 +302,7 @@ impl QlFsm { /// completes a `SessionWriteId` from `take_next_write` with the transport outcome /// /// call this at most once for each returned `SessionWriteId` - pub fn complete_write(&mut self, now: FsmTime, write_id: WriteId, success: bool) { + pub fn complete_write(&mut self, now: Instant, write_id: WriteId, success: bool) { self.state.now = now; fsm::complete_write(self, write_id, success); } diff --git a/ql-fsm/src/replay_cache.rs b/ql-fsm/src/replay_cache.rs deleted file mode 100644 index 547c0507..00000000 --- a/ql-fsm/src/replay_cache.rs +++ /dev/null @@ -1,23 +0,0 @@ -use std::collections::{hash_map::Entry, HashMap}; - -use ql_wire::{HandshakeId, HandshakeMeta}; - -#[derive(Debug, Default)] -pub struct ReplayCache { - valid_until_by_id: HashMap, -} - -impl ReplayCache { - pub fn check_and_store_valid_until(&mut self, meta: HandshakeMeta, now_secs: u64) -> bool { - self.valid_until_by_id - .retain(|_, stored_valid_until| *stored_valid_until > now_secs); - - match self.valid_until_by_id.entry(meta.handshake_id) { - Entry::Occupied(_) => true, - Entry::Vacant(entry) => { - entry.insert(meta.valid_until); - false - } - } - } -} diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs index 4cb403b3..8268bc16 100644 --- a/ql-fsm/src/state.rs +++ b/ql-fsm/src/state.rs @@ -5,16 +5,15 @@ use ql_wire::{ PairingToken, PeerBundle, QlHandshakeRecord, SessionKey, TransportParams, XxHandshake, }; -use crate::{replay_cache::ReplayCache, session::SessionFsm, FsmTime, NoSessionError, PeerStatus}; +use crate::{session::SessionFsm, NoSessionError, PeerStatus}; pub struct QlFsmState { - pub replay_cache: ReplayCache, pub next_control_id: u32, pub peer: Option, pub armed_pairing_token: Option, pub handshake: Option, pub link: LinkState, - pub now: FsmTime, + pub now: Instant, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index bba4b599..acb0005f 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -12,7 +12,7 @@ use ql_wire::{ use crate::{ session::{SessionConfig, SessionFsm, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, - Event, FsmTime, NoPeerError, OutboundWrite, QlFsm, QlFsmConfig, WriteId, + Event, NoPeerError, OutboundWrite, QlFsm, QlFsmConfig, WriteId, }; type TestCrypto = SoftwareCrypto; @@ -39,7 +39,6 @@ struct Node { struct Harness { now: Instant, - unix_secs: u64, a: Node, b: Node, } @@ -72,20 +71,15 @@ impl Harness { ) -> Self { let (identity_a, identity_b) = test_identities(&SoftwareCrypto); let now = Instant::now(); - let time = FsmTime { - instant: now, - unix_secs: 1_700_000_000, - }; let mut harness = Self { now, - unix_secs: time.unix_secs, a: Node { - fsm: QlFsm::new(config_a, identity_a.clone(), time), + fsm: QlFsm::new(config_a, identity_a.clone(), now), crypto: SoftwareCrypto, }, b: Node { - fsm: QlFsm::new(config_b, identity_b.clone(), time), + fsm: QlFsm::new(config_b, identity_b.clone(), now), crypto: SoftwareCrypto, }, }; @@ -142,16 +136,12 @@ impl Harness { harness } - fn time(&self) -> FsmTime { - FsmTime { - instant: self.now, - unix_secs: self.unix_secs, - } + fn time(&self) -> Instant { + self.now } fn advance(&mut self, duration: Duration) { self.now += duration; - self.unix_secs = self.unix_secs.saturating_add(duration.as_secs()); } fn node(&self, side: Side) -> &Node { From cf78c8a5558cab0ad95d1329733dd3d6ea7a6ee2 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 21 Apr 2026 08:53:09 -0400 Subject: [PATCH 288/304] ql-runtime: remove usage of FsmTime --- ql-runtime/src/driver/mod.rs | 32 +++++++++----------------------- ql-runtime/src/driver/test.rs | 2 +- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 47233493..0ae54699 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -10,12 +10,12 @@ use std::{ future::Future, pin::{pin, Pin}, task::{Context, Poll}, - time::{Duration, Instant, SystemTime, UNIX_EPOCH}, + time::Instant, }; use async_channel::Recv; use futures_lite::future::{poll_fn, yield_now}; -use ql_fsm::{Event, FsmTime, QlFsm, WriteId}; +use ql_fsm::{Event, QlFsm, WriteId}; use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; @@ -38,7 +38,7 @@ impl Runtime

{ tx, } = self; - let mut fsm = QlFsm::new(config.fsm, identity, now()); + let mut fsm = QlFsm::new(config.fsm, identity, Instant::now()); let mut state = DriverState { streams: HashMap::new(), @@ -82,7 +82,7 @@ impl Runtime

{ } DriverStep::Inbound(bytes) => { log::trace!("received transport frame: len={}", bytes.len()); - if let Err(e) = fsm.receive(now(), bytes, &platform) { + if let Err(e) = fsm.receive(Instant::now(), bytes, &platform) { log::info!("receive rejected frame: error={e:?}"); platform.handle_recv_error(e); } @@ -98,7 +98,7 @@ impl Runtime

{ } DriverStep::TimerExpired => { log::trace!("timer expired"); - fsm.on_timer(now()); + fsm.on_timer(Instant::now()); } DriverStep::Closed => { log::debug!( @@ -186,7 +186,7 @@ impl DriverState { } Command::Connect => { log::info!("starting IK connect"); - if fsm.connect_ik(now(), platform).is_err() { + if fsm.connect_ik(Instant::now(), platform).is_err() { log::warn!("IK connect ignored: no bound peer"); } } @@ -200,7 +200,7 @@ impl DriverState { } Command::StartPairing { token } => { log::info!(" starting XX pairing"); - fsm.connect_xx(now(), token, platform); + fsm.connect_xx(Instant::now(), token, platform); } Command::CloseSession { code } => { log::info!("closing session: code={code:?}"); @@ -290,7 +290,7 @@ impl DriverState { fn drive_write_completed(fsm: &mut QlFsm, session_write_id: Option, success: bool) { if let Some(write_id) = session_write_id { - fsm.complete_write(now(), write_id, success); + fsm.complete_write(Instant::now(), write_id, success); } } @@ -521,7 +521,7 @@ impl DriverState { ) -> bool { let mut filled = false; while in_flight.len() < self.max_concurrent_message_writes { - let Some(write) = fsm.take_next_write(now(), platform) else { + let Some(write) = fsm.take_next_write(Instant::now(), platform) else { break; }; filled = true; @@ -607,17 +607,3 @@ impl DriverState { } } } - -fn now() -> FsmTime { - FsmTime { - instant: Instant::now(), - unix_secs: unix_now_secs(), - } -} - -fn unix_now_secs() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or(Duration::ZERO) - .as_secs() -} diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index bb93203d..9b325c08 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -59,7 +59,7 @@ fn new_driver_state() -> (DriverState, QlFsm) { QlFsm::new( ql_fsm::QlFsmConfig::default(), test_identity(&SoftwareCrypto), - now(), + Instant::now(), ), ) } From f1ef902695c22fdb4bce99c1eced2ecdca620efc Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 24 Apr 2026 08:10:14 -0400 Subject: [PATCH 289/304] update design document --- QL_V2.md | 174 ++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 134 insertions(+), 40 deletions(-) diff --git a/QL_V2.md b/QL_V2.md index 8d91361c..32adf645 100644 --- a/QL_V2.md +++ b/QL_V2.md @@ -27,10 +27,13 @@ QLv2 is not: - `peer`: one QLv2 endpoint - `XID`: a stable 16-byte peer identifier - `peer bundle`: public peer information: `version`, `xid`, `capabilities`, and ML-KEM public key +- `pairing token`: an out-of-band secret that authorizes an `XX` pairing attempt +- `pairing_id`: the visible identifier derived from a pairing token and carried on `XX` records - `session`: one live encrypted channel with directional keys and directional connection IDs - `record`: one complete QLv2 wire unit - `frame`: one logical item inside a session record - `stream`: one duplex byte stream inside a session +- `route_id`: the application route carried once on the first initiator `StreamData` frame for a stream - `stream origin`: the peer that opened the stream - `origin lane`: bytes sent by the stream origin - `return lane`: bytes sent back toward the stream origin @@ -44,29 +47,61 @@ QLv2 has two record types: Handshake records are large because they carry ML-KEM material. Session records are small and can carry multiple frames, including frames for different streams. -Handshake records are routed by peer identity via visible `sender` and `recipient` XIDs. Session records are routed by `connection_id`. +All whole-record sizes below include the outer 2-byte record header: `version` plus `record type`. QLv2 uses QUIC-style variable-length integers for several steady-state fields. A varint is 1, 2, 4, or 8 bytes and can represent values in the range `0..2^62-1`. This keeps small values compact while allowing very large record and stream number spaces. Today, varints are used for: - session record `seq` -- `Ack.base_seq` -- `StreamData` frame length +- `Ack.largest_acked` +- `Ack.block_count` +- `Ack.first_range_len` +- `Ack.gap` +- `Ack.range_len` - `StreamData.stream_id` - `StreamData.offset` +- `StreamData.route_id` when present +- `StreamData.bytes_len` - `StreamWindow.stream_id` - `StreamWindow.maximum_offset` - `StreamClose.stream_id` ### Handshake records -| Record | Size | Used when | Purpose | -| --- | ---: | --- | --- | -| `IK1` | 4793 bytes | initiator already knows the responder bundle | start a handshake toward a known responder | -| `IK2` | 3203 bytes | second message of `IK` | finish the responder side of the handshake and establish the session | -| `KK1` | 3187 bytes | both peers already know each other | start a handshake between already-known peers | -| `KK2` | 3203 bytes | second message of `KK` | finish the responder side of the handshake and establish the session | +QLv2 has two routed known-peer handshakes and one pairing handshake: + +- `IK` and `KK` carry a visible `sender` and `recipient` XID +- `XX` carries a visible `pairing_id` + +#### IK + +Used when the initiator already knows the responder bundle. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `IK1` | 4785 bytes | start a handshake toward a known responder | +| `IK2` | 3195 bytes | complete `IK` and establish the session | + +#### KK + +Used when both peers already know each other. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `KK1` | 3179 bytes | start a handshake between already-known peers | +| `KK2` | 3195 bytes | complete `KK` and establish the session | + +#### XX + +Used when the initiator has received an out of band pairing token, and neither peer knows each other. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `XX1` | 1595 bytes | start pairing | +| `XX2` | 3201 bytes | send responder static identity and ciphertext | +| `XX3` | 3217 bytes | send initiator static identity and ciphertext | +| `XX4` | 1611 bytes | complete `XX` and establish the session | ### Session records @@ -90,34 +125,44 @@ The visible session header is authenticated as AEAD AAD but is not encrypted. | Frame | Size | Purpose | | --- | ---: | --- | | `Ping` | 1 byte | keep the session alive when idle | -| `Ack` | `10..17` bytes | acknowledge received session records | +| `Unpair` | 1 byte | forget the currently bound peer and abort the session | +| `Ack` | `4+` bytes | acknowledge received session records with ACK ranges | | `StreamWindow` | `3..17` bytes | extend per-stream send credit | | `StreamClose` | `5..12` bytes | abort one stream lane or both lanes | | `Close` | 3 bytes | close the whole session | -| `StreamData` | `4..26 + payload_len` bytes | carry stream bytes and optional `fin` | +| `StreamData` | `5..34 + payload_len` bytes | carry stream bytes, optional opener route, and optional `fin` | `StreamData` is the main steady-state frame: -`1 kind + varint(frame_len) + varint(stream_id) + varint(offset) + 1 fin + payload_len` +`1 kind + varint(stream_id) + varint(offset) + 1 flags + optional varint(route_id) + varint(bytes_len) + payload_len` + +The flags byte carries: + +- `fin` +- `header present` -Some useful minimum sizes for single-frame records: +Some useful minimum whole-record sizes for single-frame records: | Record | Size | Meaning | | --- | ---: | --- | | `Ping` only | 36 bytes | idle keepalive | +| `Unpair` only | 36 bytes | peer unpair | +| `Ack` only | 39 bytes | smallest selective ack | | `Close` only | 38 bytes | session shutdown | -| empty `StreamData` | 40 bytes | open or finish a stream lane without payload bytes | +| empty `StreamData` without route header | 40 bytes | empty data or empty `fin` on an existing stream | +| empty opener `StreamData` with a 1-byte `route_id` | 41 bytes | open a new stream without payload bytes | ## Handshake -QLv2 currently supports two 2-message Noise-style handshake patterns: +QLv2 currently supports three Noise-style handshake patterns: -- `IK`: the initiator already knows the responder bundle -- `KK`: both peers already know each other +- `IK`: 2 messages, initiator already knows the responder bundle +- `KK`: 2 messages, both peers already know each other +- `XX`: 4 messages, peers authenticate through an out-of-band pairing token and exchange static identity during the handshake -The handshake covers peer authentication and session establishment. There is no separate peer-level pairing record. +The handshake covers peer authentication and session establishment. -The handshake does five things: +Each successful handshake does five things: 1. authenticate which peer we are talking to 2. derive a fresh transmit key and receive key @@ -125,25 +170,38 @@ The handshake does five things: 4. bind transport parameters into the transcript 5. produce a `handshake_hash` for the completed exchange -First-contact identity exchange is still partly out of band. `IK` removes the need for the responder to know the initiator in advance, but the initiator still needs the responder bundle before it can start. +Today the only transport parameter is: + +- initial per-stream receive window + +Future transport parameters could include session-wide byte credit or record-size limits. + +Each handshake attempt carries: + +- `handshake_id`: identifies one attempt and lets stale replies be ignored +- transport parameters -Each handshake carries: +`valid_until` is not currently part of the wire format. Handshake attempts instead expire by local timer. -- `handshake_id`: identifies one handshake attempt -- `valid_until`: expiration time for that attempt -- transport parameters: today this is initial per-stream receive credit +### Pattern summary -Handshake rules: +- `IK` lets the responder learn the initiator during handshake completion. The initiator still needs the responder bundle before it can start. +- `KK` requires both peers to already know each other. +- `XX` requires the responder to be armed for pairing and to recognize the visible `pairing_id` derived from the expected pairing token. -- handshake start messages are replay-checked by `handshake_id` -- expired handshake messages are rejected -- simultaneous starts are resolved deterministically: `IK` beats `KK`; otherwise the initial ephemeral key breaks ties -- handshake attempts time out and are dropped rather than being retransmitted in place +### Handshake rules -Session establishment is slightly asymmetric: +- attempts are identified by `handshake_id` +- handshake messages are not retransmitted in place +- simultaneous starts must converge deterministically +- if `IK` and `KK` race, `IK` wins +- same-pattern races break ties by ordering the initial ephemeral public keys +- `XX` requires out-of-band authorization and uses visible `pairing_id` for lookup -- the responder enters the connected state when it processes message 1 and constructs message 2 -- the initiator enters the connected state when it receives message 2 +### Session establishment points + +- `IK` and `KK` complete after message 2 (1 RT) +- `XX` completes after 4 messages (2 RTT) ## Session Model @@ -173,6 +231,23 @@ An `Ack` means the peer: - decrypted it with the current session key - accepted its `seq` +The ACK wire format is range-based, not bitmap-based. It carries: + +- `largest_acked` +- `block_count` +- `first_range_len` +- zero or more `(gap, range_len)` blocks + +Ranges are encoded from highest sequence numbers down to lowest sequence numbers. + +Receivers track a recent accepted record window so they can: + +- reject duplicates +- ignore records that are too old +- emit selective ACK ranges + +Pending ACK state is also range-based. If there are too many disjoint ranges, older low ranges may be dropped. An emitted ACK may also be truncated by the remaining record budget. + Retransmission works at the frame level: - every emitted session record gets a fresh `seq` @@ -182,7 +257,9 @@ Retransmission works at the frame level: QLv2 does not resend the same logical record identity. -There is no explicit `Nack` frame. Loss is inferred either from timeout or from later selective `Ack` state that makes it clear a record was not accepted. +There is no explicit `Nack` frame. Loss is inferred from timeout or from later ACK state that no longer includes a record. + +Pure ACK-only records are fire-and-forget: they are not themselves retransmitted. Example: @@ -207,10 +284,6 @@ If `seq = 10` is considered lost, its frame is restored and packed again with a | `StreamData` | `stream_id=4 offset=0 bytes="hello"` | | `StreamData` | `stream_id=4 offset=5 bytes=" world"` | -Receivers track a recent record window so they can: -- reject duplicates -- send selective acks with `base_seq + bitmap` - ## Streams Streams are the application primitive. @@ -229,15 +302,25 @@ Important properties: - different streams can make progress independently - record loss on one stream does not block unrelated streams -A stream opens implicitly on the first valid `StreamData` or `StreamClose` for that remote stream ID. There is no separate open frame. +There is no separate open frame. + +Locally, opening a stream allocates: + +- a new `stream_id` +- an application `route_id` + +On the wire, the stream opener carries that `route_id` once, in the first initiator `StreamData` frame at `offset = 0`, using the optional `StreamHeader`. `StreamData` carries: - `stream_id` - `offset` +- optional `StreamHeader { route_id }` - `fin` - bytes +`StreamHeader` is only valid on the first initiator `StreamData` frame for a stream, at `offset = 0`. + `fin` is graceful completion of one lane. It says "no more bytes on this lane" without aborting the other lane. ## Flow Control @@ -246,13 +329,15 @@ Flow control is per stream. During the handshake, each peer advertises an initial per-stream receive window. That becomes the initial send credit the remote peer can use on each stream. -`StreamWindow` extends that credit by advertising a larger maximum offset. +`StreamWindow` extends that credit by advertising a larger absolute `maximum_offset`. In practice, a stream is writable only when both are true: - local send buffering has room - peer-advertised stream credit allows more bytes +Receive credit advances when the local application commits read bytes, not merely when bytes become readable. That is when the FSM emits a `StreamWindow` update. + ## Close And Liveness `StreamClose` aborts a stream early. Semantically it can target: @@ -263,6 +348,13 @@ In practice, a stream is writable only when both are true: `Close` aborts the whole session. +`Unpair` is stronger than `Close`: + +- it forgets the currently bound peer locally +- it aborts the active session immediately +- it may emit one final outbound `Unpair` frame +- reconnect does not resume until a peer is paired again + Idle sessions may send `Ping`. The peer does not answer with another ping; normal record acknowledgment is enough. Sessions also have local timers for: @@ -273,9 +365,11 @@ Sessions also have local timers for: - keepalive ping interval - peer silence timeout +If peer silence exceeds the configured timeout, the session closes with timeout. + ## Security Properties -The current handshake is ML-KEM-based and post-quantum focused. +The current handshake family is ML-KEM-based and post-quantum focused. Session payloads are encrypted and authenticated. The session header stays visible so the receiver can route the record, but it is still authenticated as AEAD AAD. From 21f846b65ccb6cef422b8936f03eefe30783beae Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 29 Apr 2026 08:52:04 -0400 Subject: [PATCH 290/304] ql: fix clippy --- ql-fsm/src/fsm.rs | 11 ++---- ql-fsm/src/handshake/ik.rs | 4 +- ql-fsm/src/handshake/kk.rs | 4 +- ql-fsm/src/handshake/xx.rs | 12 ++---- ql-fsm/src/lib.rs | 4 +- ql-fsm/src/session/ack_tracker.rs | 12 +++--- ql-fsm/src/session/mod.rs | 4 +- ql-fsm/src/session/tests.rs | 24 ++++++------ ql-fsm/src/tests/mod.rs | 4 +- ql-fsm/src/tests/proptest.rs | 65 +++++++++++++++---------------- ql-fsm/src/tests/session.rs | 2 +- ql-wire/src/handshake/ik.rs | 20 +++------- ql-wire/src/handshake/kk.rs | 20 +++------- ql-wire/src/handshake/mod.rs | 6 +-- ql-wire/src/handshake/xx.rs | 40 ++++++------------- 15 files changed, 90 insertions(+), 142 deletions(-) diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index cc342b53..aacdbd67 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -1,7 +1,4 @@ -use std::{ - collections::VecDeque, - time::Instant, -}; +use std::{collections::VecDeque, time::Instant}; use bytes::Bytes; use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; @@ -149,8 +146,7 @@ pub fn receive( let frames = wire::parse_session_frames(plaintext); let mut emit = EventSink::new(events); - conn.session - .receive(state.now, seq, frames, &mut emit); + conn.session.receive(state.now, seq, frames, &mut emit); emit.termination }; @@ -223,8 +219,7 @@ pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option Vec<(u64, u64)> { + fn ack_ranges(pending_ack: &PendingAck) -> Vec<(u64, u64)> { pending_ack .ack .ranges() @@ -203,7 +203,7 @@ mod tests { ack_tracker.schedule_ack(now); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(pending_ack), vec![(10, 12)]); + assert_eq!(ack_ranges(&pending_ack), vec![(10, 12)]); } #[test] @@ -218,7 +218,7 @@ mod tests { ack_tracker.schedule_ack(now + Duration::from_millis(5)); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(pending_ack), vec![(15, 16), (12, 12), (10, 10)]); + assert_eq!(ack_ranges(&pending_ack), vec![(15, 16), (12, 12), (10, 10)]); } #[test] @@ -242,7 +242,7 @@ mod tests { ack_tracker.schedule_ack(now); let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(pending_ack), vec![(5, 5), (3, 3)]); + assert_eq!(ack_ranges(&pending_ack), vec![(5, 5), (3, 3)]); } #[test] @@ -256,11 +256,11 @@ mod tests { ack_tracker.schedule_ack(now); let first_ack = ack_tracker.pending_ack(4).unwrap(); - assert_eq!(ack_ranges(first_ack.clone()), vec![(5, 5)]); + assert_eq!(ack_ranges(&first_ack), vec![(5, 5)]); ack_tracker.on_ack_emitted(&first_ack); ack_tracker.retire_acked_ranges(&first_ack.ack); let second_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); - assert_eq!(ack_ranges(second_ack), vec![(3, 3), (1, 1)]); + assert_eq!(ack_ranges(&second_ack), vec![(3, 3), (1, 1)]); } } diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs index 912c7bd4..55187757 100644 --- a/ql-fsm/src/session/mod.rs +++ b/ql-fsm/src/session/mod.rs @@ -117,13 +117,13 @@ impl SessionFsm { next_stream_ordinal: 0, next_record_seq: RecordSeq::from_u32(0), next_write_id: 0, - tracked_records: Default::default(), + tracked_records: IndexMap::default(), ack_tracker: AckTracker::new( config.accepted_record_window, config.pending_ack_range_limit, ), pending_ping: false, - streams: Default::default(), + streams: IndexMap::default(), next_stream_index: 0, remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), }, diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index c84df3a5..2226753f 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -33,10 +33,10 @@ fn record_ack(seq: RecordSeq) -> RecordAck { const REFUSED: StreamCloseCode = StreamCloseCode(1); const TIMEOUT: StreamCloseCode = StreamCloseCode(2); -fn header(value: u64) -> Option { - Some(StreamHeader { +fn header(value: u64) -> StreamHeader { + StreamHeader { route_id: route_id(value), - }) + } } fn opened(stream_id: StreamId) -> SessionEvent { @@ -284,7 +284,7 @@ fn commit_stream_read_is_what_advances_stream_window() { let data = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"hi".to_vec(), })]; @@ -333,7 +333,7 @@ fn pure_ack_only_records_are_fire_and_forget() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"hi".to_vec(), })]; @@ -363,7 +363,7 @@ fn inbound_stream_data_emits_opened_and_readable() { let record = vec![SessionFrame::StreamData(ql_wire::StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: true, bytes: b"hello".to_vec(), })]; @@ -389,7 +389,7 @@ fn inbound_empty_fin_emits_finished_immediately() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: true, bytes: Vec::new(), })]; @@ -465,7 +465,7 @@ fn duplicate_stream_data_is_not_redelivered() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"hi".to_vec(), })]; @@ -512,7 +512,7 @@ fn late_remote_stream_data_after_close_is_ignored() { let data = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"hello".to_vec(), })]; @@ -546,7 +546,7 @@ fn duplicate_finished_remote_data_after_reap_is_ignored() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: true, bytes: b"hello".to_vec(), })]; @@ -575,7 +575,7 @@ fn duplicate_finished_remote_data_before_read_is_ignored() { let record = vec![SessionFrame::StreamData(StreamData { stream_id, offset: offset(0), - header: header(1), + header: Some(header(1)), fin: true, bytes: b"hello".to_vec(), })]; @@ -683,7 +683,7 @@ fn close_does_not_ack_rejected_record_seq() { let invalid = vec![SessionFrame::StreamData(StreamData { stream_id: stream_id(0), offset: offset(0), - header: header(1), + header: Some(header(1)), fin: false, bytes: b"bad".to_vec(), })]; diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index acb0005f..662c577b 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -26,8 +26,8 @@ enum Side { impl Side { fn idx(self) -> usize { match self { - Side::A => 0, - Side::B => 1, + Self::A => 0, + Self::B => 1, } } } diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index c383349c..d95513b3 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -296,17 +296,14 @@ impl Runner { Action::Write { side, slot, bytes } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { let mut chunk = Bytes::copy_from_slice(bytes); - let accepted = if let Ok(mut stream) = - self.harness.node_mut(*side).fsm.stream(stream_id) - { - if let Some(mut writer) = stream.writer() { - writer.write(&mut chunk) - } else { - 0 - } - } else { - 0 - }; + let accepted = self.harness.node_mut(*side).fsm.stream(stream_id).map_or( + 0, + |mut stream| { + stream + .writer() + .map_or(0, |mut writer| writer.write(&mut chunk)) + }, + ); if accepted != 0 { self.expected[opposite(*side).idx()] .entry(stream_id) @@ -317,18 +314,17 @@ impl Runner { } Action::Finish { side, slot } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { - let finished = if let Ok(mut stream) = - self.harness.node_mut(*side).fsm.stream(stream_id) - { - if let Some(writer) = stream.writer() { - writer.finish(); - true - } else { - false - } - } else { - false - }; + let finished = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .is_ok_and(|mut stream| { + stream.writer().is_some_and(|writer| { + writer.finish(); + true + }) + }); if finished { self.finished_by[side.idx()].insert(stream_id); } @@ -336,14 +332,15 @@ impl Runner { } Action::Close { side, slot } => { if let Some(stream_id) = self.slots[side.idx()][*slot] { - let closed = if let Ok(mut stream) = - self.harness.node_mut(*side).fsm.stream(stream_id) - { - stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); - true - } else { - false - }; + let closed = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .is_ok_and(|mut stream| { + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + true + }); if closed { self.closed_by[side.idx()].insert(stream_id); self.slots[side.idx()][*slot] = None; @@ -863,9 +860,9 @@ fn connected_action_strategy() -> impl Strategy { side_action(Action::DropNext), side_usize_action(queue_index.clone(), Action::deliver_queued), side_usize_action(queue_index.clone(), Action::duplicate_queued), - side_usize_action(queue_index.clone(), Action::drop_queued), + side_usize_action(queue_index, Action::drop_queued), side_usize_action(slot.clone(), Action::open_stream), - side_usize_vec_action(slot.clone(), bytes.clone(), Action::write), + side_usize_vec_action(slot.clone(), bytes, Action::write), side_usize_action(slot.clone(), Action::finish), side_usize_action(slot, Action::close), ] @@ -910,7 +907,7 @@ fn terminal_action_strategy() -> impl Strategy { let queue_index = 0usize..6; prop_oneof![ side_usize_action(slot.clone(), Action::open_stream), - side_usize_vec_action(slot.clone(), bytes.clone(), Action::write), + side_usize_vec_action(slot.clone(), bytes, Action::write), side_usize_action(slot.clone(), Action::finish), side_usize_action(slot, Action::close), side_action(Action::TakeNext), diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index 86d0ce85..edd2cf44 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -215,7 +215,7 @@ fn disconnected_stream_operations_fail_with_no_session() { stream.close( ql_wire::CloseTarget::Both, ql_wire::StreamCloseCode::CANCELLED, - ) + ); }), Err(StreamError::NoSession) ); diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 460b7ab7..03bdc032 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -211,7 +211,7 @@ impl IkHandshake { crypto, header, HandshakeKind::Ik1, - &meta, + meta, self.local_transport_params, ); let (skem_ciphertext, skem_secret) = @@ -253,7 +253,7 @@ impl IkHandshake { crypto, header, HandshakeKind::Ik2, - &meta, + meta, self.local_transport_params, ); let remote_ephemeral = self @@ -283,11 +283,7 @@ impl IkHandshake { }) } - pub fn read_1( - &mut self, - crypto: &impl QlCrypto, - message: &Ik1, - ) -> Result<(), WireError> { + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Ik1) -> Result<(), WireError> { if self.step != IkStep::Recv1 { return Err(WireError::InvalidState); } @@ -299,7 +295,7 @@ impl IkHandshake { crypto, message.header, HandshakeKind::Ik1, - &message.meta, + message.meta, message.transport_params, ); self.symmetric @@ -329,11 +325,7 @@ impl IkHandshake { Ok(()) } - pub fn read_2( - &mut self, - crypto: &impl QlCrypto, - message: &Ik2, - ) -> Result<(), WireError> { + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Ik2) -> Result<(), WireError> { if self.step != IkStep::Recv2 { return Err(WireError::InvalidState); } @@ -345,7 +337,7 @@ impl IkHandshake { crypto, message.header, HandshakeKind::Ik2, - &message.meta, + message.meta, message.transport_params, ); let local_ephemeral = self diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index a08e8056..a56fca83 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -202,7 +202,7 @@ impl KkHandshake { crypto, header, HandshakeKind::Kk1, - &meta, + meta, self.local_transport_params, ); let (skem_ciphertext, skem_secret) = @@ -242,7 +242,7 @@ impl KkHandshake { crypto, header, HandshakeKind::Kk2, - &meta, + meta, self.local_transport_params, ); let remote_ephemeral = self @@ -271,11 +271,7 @@ impl KkHandshake { }) } - pub fn read_1( - &mut self, - crypto: &impl QlCrypto, - message: &Kk1, - ) -> Result<(), WireError> { + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Kk1) -> Result<(), WireError> { if self.step != KkStep::Recv1 { return Err(WireError::InvalidState); } @@ -286,7 +282,7 @@ impl KkHandshake { crypto, message.header, HandshakeKind::Kk1, - &message.meta, + message.meta, message.transport_params, ); self.symmetric @@ -303,11 +299,7 @@ impl KkHandshake { Ok(()) } - pub fn read_2( - &mut self, - crypto: &impl QlCrypto, - message: &Kk2, - ) -> Result<(), WireError> { + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Kk2) -> Result<(), WireError> { if self.step != KkStep::Recv2 { return Err(WireError::InvalidState); } @@ -318,7 +310,7 @@ impl KkHandshake { crypto, message.header, HandshakeKind::Kk2, - &message.meta, + message.meta, message.transport_params, ); let local_ephemeral = self diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 79e0f7ad..64a1a6d2 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -393,7 +393,7 @@ fn mix_hash_routed_handshake( crypto: &impl QlCrypto, header: HandshakeHeader, kind: HandshakeKind, - meta: &HandshakeMeta, + meta: HandshakeMeta, transport_params: TransportParams, ) { mix_hash_handshake_preamble( @@ -411,7 +411,7 @@ fn mix_hash_pairing_handshake( crypto: &impl QlCrypto, header: XxHeader, kind: HandshakeKind, - meta: &HandshakeMeta, + meta: HandshakeMeta, transport_params: TransportParams, ) { mix_hash_handshake_preamble( @@ -429,7 +429,7 @@ fn mix_hash_handshake_preamble( crypto: &impl QlCrypto, header: &[u8], kind: HandshakeKind, - meta: &HandshakeMeta, + meta: HandshakeMeta, transport_params: TransportParams, ) { symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index c4e7294f..8f252263 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -295,7 +295,7 @@ impl XxHandshake { crypto, header, HandshakeKind::Xx1, - &meta, + meta, self.local_transport_params, ); mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); @@ -314,11 +314,7 @@ impl XxHandshake { }) } - pub fn read_1( - &mut self, - crypto: &impl QlCrypto, - message: &Xx1, - ) -> Result<(), WireError> { + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Xx1) -> Result<(), WireError> { if self.step != XxStep::Recv1 { return Err(WireError::InvalidState); } @@ -329,7 +325,7 @@ impl XxHandshake { crypto, message.header, HandshakeKind::Xx1, - &message.meta, + message.meta, message.transport_params, ); mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); @@ -356,7 +352,7 @@ impl XxHandshake { crypto, header, HandshakeKind::Xx2, - &meta, + meta, self.local_transport_params, ); @@ -381,11 +377,7 @@ impl XxHandshake { }) } - pub fn read_2( - &mut self, - crypto: &impl QlCrypto, - message: &Xx2, - ) -> Result<(), WireError> { + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Xx2) -> Result<(), WireError> { if self.step != XxStep::Recv2 { return Err(WireError::InvalidState); } @@ -396,7 +388,7 @@ impl XxHandshake { crypto, message.header, HandshakeKind::Xx2, - &message.meta, + message.meta, message.transport_params, ); @@ -433,7 +425,7 @@ impl XxHandshake { crypto, header, HandshakeKind::Xx3, - &meta, + meta, self.local_transport_params, ); @@ -457,11 +449,7 @@ impl XxHandshake { }) } - pub fn read_3( - &mut self, - crypto: &impl QlCrypto, - message: &Xx3, - ) -> Result<(), WireError> { + pub fn read_3(&mut self, crypto: &impl QlCrypto, message: &Xx3) -> Result<(), WireError> { if self.step != XxStep::Recv3 { return Err(WireError::InvalidState); } @@ -476,7 +464,7 @@ impl XxHandshake { crypto, message.header, HandshakeKind::Xx3, - &message.meta, + message.meta, message.transport_params, ); @@ -508,7 +496,7 @@ impl XxHandshake { crypto, header, HandshakeKind::Xx4, - &meta, + meta, self.local_transport_params, ); @@ -529,11 +517,7 @@ impl XxHandshake { }) } - pub fn read_4( - &mut self, - crypto: &impl QlCrypto, - message: &Xx4, - ) -> Result<(), WireError> { + pub fn read_4(&mut self, crypto: &impl QlCrypto, message: &Xx4) -> Result<(), WireError> { if self.step != XxStep::Recv4 { return Err(WireError::InvalidState); } @@ -548,7 +532,7 @@ impl XxHandshake { crypto, message.header, HandshakeKind::Xx4, - &message.meta, + message.meta, message.transport_params, ); From 0a0c4e6f6998a84e2dfb5a6610f829a9a4f36d39 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 29 Apr 2026 09:48:18 -0400 Subject: [PATCH 291/304] ql: better coalescing across restransmits and inner byte queue --- ql-fsm/src/session/stream_tx.rs | 123 +++++++++++++++++++++++++++----- 1 file changed, 105 insertions(+), 18 deletions(-) diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs index e4d2d3b6..15533922 100644 --- a/ql-fsm/src/session/stream_tx.rs +++ b/ql-fsm/src/session/stream_tx.rs @@ -182,19 +182,29 @@ impl StreamTx { max_payload: usize, peer_max_offset: u64, ) -> Option { - // TODO: coalesce a lost range with contiguous unsent tail bytes when they fit in the same - // transmit budget. That would let a repacked record send one larger StreamData frame - // instead of retransmitting the lost prefix first and the new tail later. + let budget_end = |start: u64| { + start + .saturating_add(max_payload as u64) + .min(peer_max_offset) + }; + + // prefer the lowest lost bytes before sending new bytes if let Some(range) = self.retransmits.peek_min() { - let end = range - .end - .min(range.start.saturating_add(max_payload as u64)) - .min(peer_max_offset); + let mut end = range.end.min(budget_end(range.start)); + + // extend only when lost bytes end where unsent bytes begin + if end == range.end && range.end == self.unsent { + end = self.end_offset().min(budget_end(range.start)); + } + if end > range.start { let range = self.retransmits.pop_min().unwrap(); if end < range.end { self.retransmits.insert(end..range.end); } + + // mark any new bytes in this frame as sent + self.unsent = self.unsent.max(end); return Some(StreamTxRange { offset: range.start, len: usize::try_from(end - range.start).unwrap(), @@ -203,11 +213,9 @@ impl StreamTx { } } + // send bytes that have not been sent yet if self.unsent < self.end_offset() { - let end = self - .end_offset() - .min(self.unsent.saturating_add(max_payload as u64)) - .min(peer_max_offset); + let end = self.end_offset().min(budget_end(self.unsent)); if end > self.unsent { let start = self.unsent; self.unsent = end; @@ -219,11 +227,15 @@ impl StreamTx { } } - let final_offset = self.final_offset.filter(|final_offset| { - matches!(final_offset.state, SendState::Lost | SendState::Unsent) - && final_offset.offset <= peer_max_offset - })?; - self.final_offset.as_mut().unwrap().state = SendState::Sent; + // send a fin after all data has been sent + let final_offset = + self.final_offset + .as_mut() + .filter(|TrackedFinalOffset { offset, state }| { + (*state == SendState::Lost || *state == SendState::Unsent) + && *offset <= peer_max_offset + })?; + final_offset.state = SendState::Sent; Some(StreamTxRange { offset: final_offset.offset, len: 0, @@ -375,7 +387,7 @@ mod tests { use super::{StreamTx, StreamTxRange}; #[test] - fn append_tracks_unsent_tail() { + fn append_tracks_unsent_bytes() { let mut tx = StreamTx::new(); tx.append(Bytes::from_static(b"abc")); tx.append(Bytes::from_static(b"de")); @@ -391,7 +403,7 @@ mod tests { } #[test] - fn lost_range_is_selected_before_unsent_tail() { + fn lost_range_is_selected_before_unsent_bytes() { let mut tx = StreamTx::new(); tx.append(Bytes::from_static(b"abcdef")); @@ -408,6 +420,81 @@ mod tests { ); } + #[test] + fn lost_range_coalesces_contiguous_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abc")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"def")); + + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 6, + fin: false, + }) + ); + assert_eq!(tx.poll_transmit(6, u64::MAX), None); + } + + #[test] + fn lost_range_coalesces_only_new_bytes_that_fit() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abc")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"def")); + + assert_eq!( + tx.poll_transmit(5, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 5, + fin: false, + }) + ); + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 5, + len: 1, + fin: false, + }) + ); + } + + #[test] + fn non_contiguous_lost_range_does_not_coalesce_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdef")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + let _second = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"ghi")); + + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 3, + fin: false, + }) + ); + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 6, + len: 3, + fin: false, + }) + ); + } + #[test] fn acked_prefix_is_trimmed() { let mut tx = StreamTx::new(); From 416309898773b958b08334914f2228b2ea4a93fb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 12 May 2026 09:12:45 -0400 Subject: [PATCH 292/304] ql-rpc: remove copied() --- ql-rpc/src/router/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 47945b3c..dfdbc960 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -39,7 +39,7 @@ where pub fn handle(&self, stream: St) -> Option<(RouteId, Sp::Handle)> { let route_id = stream.route_id()?; - let Some(route) = self.routes.get(&route_id).copied() else { + let Some(route) = self.routes.get(&route_id) else { close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); return None; }; From e596d5e03f5a327c1892402c174a6932c29fc57b Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 12 May 2026 11:17:32 -0400 Subject: [PATCH 293/304] ql: include peer xids in xx handshake --- ql-fsm/src/fsm.rs | 4 +- ql-fsm/src/handshake/mod.rs | 8 +-- ql-fsm/src/handshake/xx.rs | 28 ++++++-- ql-fsm/src/lib.rs | 8 ++- ql-fsm/src/pairing.rs | 38 +++++++++++ ql-fsm/src/tests/handshake.rs | 9 ++- ql-fsm/src/tests/mod.rs | 21 +++++- ql-runtime/src/command.rs | 4 +- ql-runtime/src/driver/mod.rs | 4 +- ql-runtime/src/handle/mod.rs | 8 +-- ql-runtime/src/lib.rs | 2 +- ql-runtime/src/tests/handshake.rs | 10 ++- ql-runtime/src/tests/mod.rs | 4 +- ql-wire/src/handshake/mod.rs | 41 ++--------- ql-wire/src/handshake/xx.rs | 102 +++++++++++++++++++++------- ql-wire/src/tests.rs | 109 ++++++++++++++++++++++++++---- 16 files changed, 296 insertions(+), 104 deletions(-) create mode 100644 ql-fsm/src/pairing.rs diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index aacdbd67..c73fb6fc 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -94,8 +94,8 @@ pub fn handle_disarm_pairing(fsm: &mut QlFsm) { handshake::handle_disarm_pairing(fsm); } -pub fn handle_connect_xx(fsm: &mut QlFsm, token: ql_wire::PairingToken, crypto: &impl QlCrypto) { - handshake::handle_connect_xx(fsm, token, crypto); +pub fn handle_connect_xx(fsm: &mut QlFsm, invite: crate::PairingInvite, crypto: &impl QlCrypto) { + handshake::handle_connect_xx(fsm, invite, crypto); } pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index 34bc0648..2431bbc8 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -2,9 +2,7 @@ mod ik; mod kk; mod xx; -use ql_wire::{ - self as wire, EphemeralPublicKey, HandshakeMeta, PairingToken, QlCrypto, QlHandshakeRecord, -}; +use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHandshakeRecord}; use crate::{ fsm::emit_peer_status, @@ -27,9 +25,9 @@ pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), Ok(()) } -pub fn handle_connect_xx(fsm: &mut QlFsm, token: PairingToken, crypto: &impl QlCrypto) { +pub fn handle_connect_xx(fsm: &mut QlFsm, invite: crate::PairingInvite, crypto: &impl QlCrypto) { prepare_for_outbound_connect(fsm); - xx::start_initiator(fsm, crypto, token); + xx::start_initiator(fsm, crypto, invite.token, invite.xid); } pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index 63cef457..dbc72006 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, PairingToken, QlCrypto, QlHandshakeRecord, Xx1, Xx2, Xx3, Xx4}; +use ql_wire::{self as wire, PairingToken, QlCrypto, QlHandshakeRecord, Xx1, Xx2, Xx3, Xx4, XID}; use super::{ emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, @@ -8,11 +8,17 @@ use crate::{ QlFsm, ReceiveError, }; -pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, token: PairingToken) { +pub fn start_initiator( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + token: PairingToken, + remote_xid: XID, +) { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::XxHandshake::new_initiator( crypto, fsm.identity.clone(), + remote_xid, token, super::local_transport_params(fsm), ); @@ -37,18 +43,25 @@ pub fn handle_xx1( return Ok(()); } match fsm.state.armed_pairing_token { - Some(expected) if expected.id(crypto) != message.header.pairing_id => { + Some(expected) if expected.id(crypto) != message.pairing_id => { Err(ReceiveError::InvalidPairingId { expected: expected.id(crypto), - actual: message.header.pairing_id, + actual: message.pairing_id, }) } + Some(_) + if message.header.recipient != fsm.identity.xid + || message.header.sender == fsm.identity.xid => + { + Err(ReceiveError::InvalidXid) + } Some(token) => { reset_connected_session_if_needed(fsm); let mut handshake = wire::XxHandshake::new_responder( crypto, fsm.identity.clone(), + message.header.sender, token, super::local_transport_params(fsm), ); @@ -153,7 +166,12 @@ pub fn should_ignore_inbound(fsm: &QlFsm, crypto: &impl QlCrypto, message: &Xx1) LinkState::Idle | LinkState::Connected(_) => false, LinkState::IkInitiator(_) | LinkState::KkInitiator(_) | LinkState::XxResponder(_) => true, LinkState::XxInitiator(state) => { - if state.handshake.pairing_id(crypto) != message.header.pairing_id { + if state.handshake.pairing_id(crypto) != message.pairing_id { + return false; + } + if message.header.recipient != fsm.identity.xid + || message.header.sender != state.handshake.remote_xid() + { return false; } super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs index 860cedc8..2010e308 100644 --- a/ql-fsm/src/lib.rs +++ b/ql-fsm/src/lib.rs @@ -21,6 +21,7 @@ mod error; mod fsm; mod handshake; +mod pairing; mod session; pub(crate) mod state; #[cfg(test)] @@ -33,6 +34,7 @@ use std::{ pub use bytes::Bytes; pub use error::*; +pub use pairing::PairingInvite; use ql_wire::{ PairingToken, PeerBundle, QlCrypto, QlIdentity, RouteId, SessionClose, SessionCloseCode, StreamClose, StreamId, @@ -231,10 +233,10 @@ impl QlFsm { fsm::handle_disarm_pairing(self); } - /// starts an outbound xx handshake using the supplied pairing token - pub fn connect_xx(&mut self, now: Instant, token: PairingToken, crypto: &impl QlCrypto) { + /// starts an outbound xx handshake using a pairing invite + pub fn connect_xx(&mut self, now: Instant, invite: PairingInvite, crypto: &impl QlCrypto) { self.state.now = now; - fsm::handle_connect_xx(self, token, crypto); + fsm::handle_connect_xx(self, invite, crypto); } /// starts an IK handshake with the currently bound peer diff --git a/ql-fsm/src/pairing.rs b/ql-fsm/src/pairing.rs new file mode 100644 index 00000000..fef19954 --- /dev/null +++ b/ql-fsm/src/pairing.rs @@ -0,0 +1,38 @@ +use ql_wire::{ByteSlice, PairingToken, Reader, WireDecode, WireEncode, WireError, XID}; + +/// Out-of-band invite consumed by the initiator of an XX pairing +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PairingInvite { + pub xid: XID, + pub token: PairingToken, +} + +impl PairingInvite { + pub const VERSION: u8 = 1; + pub const WIRE_SIZE: usize = size_of::() + XID::SIZE + PairingToken::SIZE; +} + +impl WireEncode for PairingInvite { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + Self::VERSION.encode(out); + self.xid.encode(out); + self.token.encode(out); + } +} + +impl WireDecode for PairingInvite { + fn decode(reader: &mut Reader) -> Result { + if reader.decode::()? != Self::VERSION { + return Err(WireError::InvalidPayload); + } + + Ok(Self { + xid: reader.decode()?, + token: reader.decode()?, + }) + } +} diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 9b1d749f..301b9680 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -103,7 +103,14 @@ fn connect_methods_require_bound_peer() { assert_eq!(fsm.connect_ik(time, &crypto), Err(NoPeerError)); assert_eq!(fsm.connect_kk(time, &crypto), Err(NoPeerError)); - fsm.connect_xx(time, pairing_token(2), &crypto); + fsm.connect_xx( + time, + PairingInvite { + xid: ql_wire::XID([2; ql_wire::XID::SIZE]), + token: pairing_token(2), + }, + &crypto, + ); } #[test] diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 662c577b..703070cf 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -6,13 +6,13 @@ use std::time::{Duration, Instant}; use ql_wire::{ self, test_identities, test_identity, ConnectionId, PairingToken, QlCrypto, SessionKey, - SoftwareCrypto, TransportParams, + SoftwareCrypto, TransportParams, XID, }; use crate::{ session::{SessionConfig, SessionFsm, StreamParity}, state::{ConnectedState, LinkState, SessionTransport}, - Event, NoPeerError, OutboundWrite, QlFsm, QlFsmConfig, WriteId, + Event, NoPeerError, OutboundWrite, PairingInvite, QlFsm, QlFsmConfig, WriteId, }; type TestCrypto = SoftwareCrypto; @@ -199,8 +199,23 @@ impl Harness { fn connect_xx(&mut self, side: Side, token: PairingToken) { let time = self.time(); + let remote_xid = self.remote_xid(side); let Node { fsm, crypto } = self.node_mut(side); - fsm.connect_xx(time, token, crypto); + fsm.connect_xx( + time, + PairingInvite { + xid: remote_xid, + token, + }, + crypto, + ); + } + + fn remote_xid(&self, side: Side) -> XID { + match side { + Side::A => self.b.fsm.identity.xid, + Side::B => self.a.fsm.identity.xid, + } } fn deliver(&mut self, side: Side, record: Vec) { diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs index 07d3c7a5..4a47a45e 100644 --- a/ql-runtime/src/command.rs +++ b/ql-runtime/src/command.rs @@ -1,4 +1,4 @@ -use ql_fsm::NoSessionError; +use ql_fsm::{NoSessionError, PairingInvite}; use ql_wire::{ CloseTarget, PairingToken, PeerBundle, RouteId, SessionCloseCode, StreamCloseCode, StreamId, }; @@ -15,7 +15,7 @@ pub enum Command { }, DisarmPairing, StartPairing { - token: PairingToken, + invite: PairingInvite, }, OpenStream { route_id: RouteId, diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index 0ae54699..d68bfce2 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -198,9 +198,9 @@ impl DriverState { log::info!("disarming inbound pairing"); fsm.disarm_pairing(); } - Command::StartPairing { token } => { + Command::StartPairing { invite } => { log::info!(" starting XX pairing"); - fsm.connect_xx(Instant::now(), token, platform); + fsm.connect_xx(Instant::now(), invite, platform); } Command::CloseSession { code } => { log::info!("closing session: code={code:?}"); diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs index e98ce584..1782c17a 100644 --- a/ql-runtime/src/handle/mod.rs +++ b/ql-runtime/src/handle/mod.rs @@ -1,4 +1,4 @@ -use ql_fsm::NoSessionError; +use ql_fsm::{NoSessionError, PairingInvite}; use ql_wire::{PairingToken, PeerBundle, RouteId, SessionCloseCode, StreamId}; use crate::command::Command; @@ -38,9 +38,9 @@ impl RuntimeHandle { self.send(Command::DisarmPairing); } - /// starts an outbound xx handshake using the supplied pairing token - pub fn start_pairing(&self, token: PairingToken) { - self.send(Command::StartPairing { token }); + /// starts an outbound xx handshake using an out-of-band pairing invite + pub fn start_pairing(&self, invite: PairingInvite) { + self.send(Command::StartPairing { invite }); } /// closes the current encrypted session diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs index b4305d14..33783456 100644 --- a/ql-runtime/src/lib.rs +++ b/ql-runtime/src/lib.rs @@ -1,4 +1,4 @@ -pub use ql_fsm::NoSessionError; +pub use ql_fsm::{NoSessionError, PairingInvite}; pub use self::{error::QlStreamError, handle::*, platform::*}; diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index d641beee..416ec842 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -132,7 +132,10 @@ async fn start_pairing_round_trip_connects_when_armed() { spawn_forwarder(outbound_b, inbound_a_tx); handle_b.arm_pairing(token); - handle_a.start_pairing(token); + handle_a.start_pairing(PairingInvite { + xid: identity_b.xid, + token, + }); await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; @@ -158,7 +161,10 @@ async fn start_pairing_does_not_connect_when_unarmed() { spawn_forwarder(outbound_a, inbound_b_tx); spawn_forwarder(outbound_b, inbound_a_tx); - handle_a.start_pairing(token); + handle_a.start_pairing(PairingInvite { + xid: identity_b.xid, + token, + }); assert_no_status_for( &status_a, diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 066903a5..71083724 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -20,8 +20,8 @@ use ql_wire::{ use tokio::{task::LocalSet, time::Sleep}; use crate::{ - new_runtime, platform::QlTimer, NoSessionError, QlFsmConfig, QlStream, QlStreamError, - RuntimeConfig, RuntimeHandle, + new_runtime, platform::QlTimer, NoSessionError, PairingInvite, QlFsmConfig, QlStream, + QlStreamError, RuntimeConfig, RuntimeHandle, }; mod handshake; diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 64a1a6d2..6ba9209c 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -55,33 +55,6 @@ impl codec::WireDecode for HandshakeHeader { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct XxHeader { - pub pairing_id: PairingId, -} - -impl XxHeader { - pub const WIRE_SIZE: usize = PairingId::SIZE; -} - -impl WireEncode for XxHeader { - fn encoded_len(&self) -> usize { - Self::WIRE_SIZE - } - - fn encode(&self, out: &mut W) { - self.pairing_id.encode(out); - } -} - -impl codec::WireDecode for XxHeader { - fn decode(reader: &mut codec::Reader) -> Result { - Ok(Self { - pairing_id: reader.decode()?, - }) - } -} - #[derive(Debug, Clone, PartialEq, Eq)] pub struct EphemeralPublicKey { pub mlkem_public_key: MlKemPublicKey, @@ -409,19 +382,15 @@ fn mix_hash_routed_handshake( fn mix_hash_pairing_handshake( symmetric: &mut SymmetricState, crypto: &impl QlCrypto, - header: XxHeader, + header: HandshakeHeader, kind: HandshakeKind, meta: HandshakeMeta, + pairing_id: PairingId, transport_params: TransportParams, ) { - mix_hash_handshake_preamble( - symmetric, - crypto, - &header.encode_vec(), - kind, - meta, - transport_params, - ); + let mut preamble = header.encode_vec(); + pairing_id.encode(&mut preamble); + mix_hash_handshake_preamble(symmetric, crypto, &preamble, kind, meta, transport_params); } fn mix_hash_handshake_preamble( diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index 8f252263..2f026ed9 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -4,24 +4,26 @@ use super::{ initialize_transport_params, mix_hash_ephemeral, mix_hash_pairing_handshake, mix_psk_pairing_token, require_handshake_meta, require_transport_params, EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, - FinalizedHandshake, Role, SymmetricState, TransportParams, XxHeader, + FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, }; use crate::{ codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PairingId, PairingToken, - PeerBundle, QlCrypto, QlIdentity, WireEncode, WireError, + PeerBundle, QlCrypto, QlIdentity, WireEncode, WireError, XID, }; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx1 { - pub header: XxHeader, + pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub pairing_id: PairingId, pub transport_params: TransportParams, pub ephemeral: EphemeralPublicKey, } impl Xx1 { - pub const WIRE_SIZE: usize = XxHeader::WIRE_SIZE + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + TransportParams::WIRE_SIZE + EphemeralPublicKey::WIRE_SIZE; } @@ -31,6 +33,7 @@ impl codec::WireDecode for Xx1 { Ok(Self { header: reader.decode()?, meta: reader.decode()?, + pairing_id: reader.decode()?, transport_params: reader.decode()?, ephemeral: reader.decode()?, }) @@ -45,6 +48,7 @@ impl WireEncode for Xx1 { fn encode(&self, out: &mut W) { self.header.encode(out); self.meta.encode(out); + self.pairing_id.encode(out); self.transport_params.encode(out); self.ephemeral.encode(out); } @@ -52,16 +56,18 @@ impl WireEncode for Xx1 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx2 { - pub header: XxHeader, + pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub pairing_id: PairingId, pub transport_params: TransportParams, pub ekem_ciphertext: MlKemCiphertext, pub static_bundle: EncryptedPeerBundle, } impl Xx2 { - pub const WIRE_SIZE: usize = XxHeader::WIRE_SIZE + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + TransportParams::WIRE_SIZE + MlKemCiphertext::SIZE + EncryptedPeerBundle::WIRE_SIZE; @@ -72,6 +78,7 @@ impl codec::WireDecode for Xx2 { Ok(Self { header: reader.decode()?, meta: reader.decode()?, + pairing_id: reader.decode()?, transport_params: reader.decode()?, ekem_ciphertext: reader.decode()?, static_bundle: reader.decode()?, @@ -87,6 +94,7 @@ impl WireEncode for Xx2 { fn encode(&self, out: &mut W) { self.header.encode(out); self.meta.encode(out); + self.pairing_id.encode(out); self.transport_params.encode(out); self.ekem_ciphertext.encode(out); self.static_bundle.encode(out); @@ -95,16 +103,18 @@ impl WireEncode for Xx2 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx3 { - pub header: XxHeader, + pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub pairing_id: PairingId, pub transport_params: TransportParams, pub skem_ciphertext: EncryptedMlKemCiphertext, pub static_bundle: EncryptedPeerBundle, } impl Xx3 { - pub const WIRE_SIZE: usize = XxHeader::WIRE_SIZE + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + TransportParams::WIRE_SIZE + EncryptedMlKemCiphertext::WIRE_SIZE + EncryptedPeerBundle::WIRE_SIZE; @@ -115,6 +125,7 @@ impl codec::WireDecode for Xx3 { Ok(Self { header: reader.decode()?, meta: reader.decode()?, + pairing_id: reader.decode()?, transport_params: reader.decode()?, skem_ciphertext: reader.decode()?, static_bundle: reader.decode()?, @@ -130,6 +141,7 @@ impl WireEncode for Xx3 { fn encode(&self, out: &mut W) { self.header.encode(out); self.meta.encode(out); + self.pairing_id.encode(out); self.transport_params.encode(out); self.skem_ciphertext.encode(out); self.static_bundle.encode(out); @@ -138,15 +150,17 @@ impl WireEncode for Xx3 { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Xx4 { - pub header: XxHeader, + pub header: HandshakeHeader, pub meta: HandshakeMeta, + pub pairing_id: PairingId, pub transport_params: TransportParams, pub skem_ciphertext: EncryptedMlKemCiphertext, } impl Xx4 { - pub const WIRE_SIZE: usize = XxHeader::WIRE_SIZE + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + TransportParams::WIRE_SIZE + EncryptedMlKemCiphertext::WIRE_SIZE; } @@ -156,6 +170,7 @@ impl codec::WireDecode for Xx4 { Ok(Self { header: reader.decode()?, meta: reader.decode()?, + pairing_id: reader.decode()?, transport_params: reader.decode()?, skem_ciphertext: reader.decode()?, }) @@ -170,6 +185,7 @@ impl WireEncode for Xx4 { fn encode(&self, out: &mut W) { self.header.encode(out); self.meta.encode(out); + self.pairing_id.encode(out); self.transport_params.encode(out); self.skem_ciphertext.encode(out); } @@ -194,6 +210,7 @@ pub struct XxHandshake { step: XxStep, symmetric: SymmetricState, local: QlIdentity, + remote_xid: XID, pairing_token: PairingToken, remote_bundle: Option, local_ephemeral: Option, @@ -207,6 +224,7 @@ impl XxHandshake { pub fn new_initiator( crypto: &impl QlCrypto, local: QlIdentity, + remote_xid: XID, pairing_token: PairingToken, local_transport_params: TransportParams, ) -> Self { @@ -215,6 +233,7 @@ impl XxHandshake { step: XxStep::Send1, symmetric: init_xx_symmetric(crypto), local, + remote_xid, pairing_token, remote_bundle: None, local_ephemeral: None, @@ -228,6 +247,7 @@ impl XxHandshake { pub fn new_responder( crypto: &impl QlCrypto, local: QlIdentity, + remote_xid: XID, pairing_token: PairingToken, local_transport_params: TransportParams, ) -> Self { @@ -236,6 +256,7 @@ impl XxHandshake { step: XxStep::Recv1, symmetric: init_xx_symmetric(crypto), local, + remote_xid, pairing_token, remote_bundle: None, local_ephemeral: None, @@ -258,22 +279,39 @@ impl XxHandshake { self.pairing_token.id(crypto) } + pub fn remote_xid(&self) -> XID { + self.remote_xid + } + pub fn remote_bundle(&self) -> Option<&PeerBundle> { self.remote_bundle.as_ref() } - fn header(&self, crypto: &impl QlCrypto) -> XxHeader { - XxHeader { - pairing_id: self.pairing_token.id(crypto), + fn header(&self) -> HandshakeHeader { + HandshakeHeader { + sender: self.local.xid, + recipient: self.remote_xid, } } fn ensure_inbound_header( &self, crypto: &impl QlCrypto, - header: XxHeader, + header: HandshakeHeader, + pairing_id: PairingId, ) -> Result<(), WireError> { - if header == self.header(crypto) { + if header.sender == self.remote_xid + && header.recipient == self.local.xid + && pairing_id == self.pairing_token.id(crypto) + { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } + + fn ensure_remote_bundle(&self, bundle: &PeerBundle) -> Result<(), WireError> { + if bundle.xid == self.remote_xid { Ok(()) } else { Err(WireError::InvalidPayload) @@ -289,13 +327,15 @@ impl XxHandshake { return Err(WireError::InvalidState); } initialize_handshake_meta(&mut self.handshake_meta, meta)?; - let header = self.header(crypto); + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); mix_hash_pairing_handshake( &mut self.symmetric, crypto, header, HandshakeKind::Xx1, meta, + pairing_id, self.local_transport_params, ); mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); @@ -309,6 +349,7 @@ impl XxHandshake { Ok(Xx1 { header, meta, + pairing_id, transport_params: self.local_transport_params, ephemeral, }) @@ -319,13 +360,14 @@ impl XxHandshake { return Err(WireError::InvalidState); } initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; - self.ensure_inbound_header(crypto, message.header)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; mix_hash_pairing_handshake( &mut self.symmetric, crypto, message.header, HandshakeKind::Xx1, message.meta, + message.pairing_id, message.transport_params, ); mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); @@ -346,13 +388,15 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), meta)?; - let header = self.header(crypto); + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); mix_hash_pairing_handshake( &mut self.symmetric, crypto, header, HandshakeKind::Xx2, meta, + pairing_id, self.local_transport_params, ); @@ -371,6 +415,7 @@ impl XxHandshake { Ok(Xx2 { header, meta, + pairing_id, transport_params: self.local_transport_params, ekem_ciphertext, static_bundle, @@ -382,13 +427,14 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; - self.ensure_inbound_header(crypto, message.header)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; mix_hash_pairing_handshake( &mut self.symmetric, crypto, message.header, HandshakeKind::Xx2, message.meta, + message.pairing_id, message.transport_params, ); @@ -404,6 +450,7 @@ impl XxHandshake { let remote_bundle = decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + self.ensure_remote_bundle(&remote_bundle)?; self.remote_bundle = Some(remote_bundle); initialize_transport_params(&mut self.remote_transport_params, message.transport_params)?; self.step = XxStep::Send3; @@ -419,13 +466,15 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), meta)?; - let header = self.header(crypto); + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); mix_hash_pairing_handshake( &mut self.symmetric, crypto, header, HandshakeKind::Xx3, meta, + pairing_id, self.local_transport_params, ); @@ -443,6 +492,7 @@ impl XxHandshake { Ok(Xx3 { header, meta, + pairing_id, transport_params: self.local_transport_params, skem_ciphertext, static_bundle, @@ -454,7 +504,7 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; - self.ensure_inbound_header(crypto, message.header)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; require_transport_params( self.remote_transport_params.as_ref(), message.transport_params, @@ -465,6 +515,7 @@ impl XxHandshake { message.header, HandshakeKind::Xx3, message.meta, + message.pairing_id, message.transport_params, ); @@ -476,6 +527,7 @@ impl XxHandshake { let remote_bundle = decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + self.ensure_remote_bundle(&remote_bundle)?; self.remote_bundle = Some(remote_bundle); self.step = XxStep::Send4; Ok(()) @@ -490,13 +542,15 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), meta)?; - let header = self.header(crypto); + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); mix_hash_pairing_handshake( &mut self.symmetric, crypto, header, HandshakeKind::Xx4, meta, + pairing_id, self.local_transport_params, ); @@ -512,6 +566,7 @@ impl XxHandshake { Ok(Xx4 { header, meta, + pairing_id, transport_params: self.local_transport_params, skem_ciphertext, }) @@ -522,7 +577,7 @@ impl XxHandshake { return Err(WireError::InvalidState); } require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; - self.ensure_inbound_header(crypto, message.header)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; require_transport_params( self.remote_transport_params.as_ref(), message.transport_params, @@ -533,6 +588,7 @@ impl XxHandshake { message.header, HandshakeKind::Xx4, message.meta, + message.pairing_id, message.transport_params, ); diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 36d80d59..d0fcfa9f 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -58,9 +58,10 @@ fn pairing_id(byte: u8) -> PairingId { PairingId([byte; PairingId::SIZE]) } -fn xx_header(byte: u8) -> XxHeader { - XxHeader { - pairing_id: pairing_id(byte), +fn xx_header(sender: u8, recipient: u8) -> HandshakeHeader { + HandshakeHeader { + sender: xid(sender), + recipient: xid(recipient), } } @@ -136,8 +137,9 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { assert_eq!(decode_handshake_record(kk_encoded.as_slice()), kk); let xx = QlHandshakeRecord::Xx1(Xx1 { - header: xx_header(3), + header: xx_header(1, 2), meta: handshake_meta(3), + pairing_id: pairing_id(3), transport_params: handshake_transport_params(196_608), ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::new(Box::new([17; MlKemPublicKey::SIZE])), @@ -497,15 +499,82 @@ fn xx_handshake_rejects_tampered_pairing_id() { let (initiator, responder) = test_identities(&crypto); let token = pairing_token(7); - let mut initiator_state = - XxHandshake::new_initiator(&crypto, initiator, token, TransportParams::default()); - let mut responder_state = - XxHandshake::new_responder(&crypto, responder, token, TransportParams::default()); + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.xid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder, + initiator.xid, + token, + TransportParams::default(), + ); let mut m1 = initiator_state .write_1(&crypto, handshake_meta(31)) .unwrap(); - m1.header.pairing_id = pairing_id(8); + m1.pairing_id = pairing_id(8); + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidPayload) + ); +} + +#[test] +fn xx_handshake_rejects_tampered_sender_or_recipient() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(7); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.xid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.xid, + token, + TransportParams::default(), + ); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.header.sender = responder.xid; + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidPayload) + ); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.xid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.xid, + token, + TransportParams::default(), + ); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.header.recipient = initiator.xid; assert_eq!( responder_state.read_1(&crypto, &m1), @@ -522,12 +591,14 @@ fn xx_handshake_rejects_repeated_transport_param_change() { let mut initiator_state = XxHandshake::new_initiator( &crypto, initiator.clone(), + responder.xid, token, handshake_transport_params(12_288), ); let mut responder_state = XxHandshake::new_responder( &crypto, responder, + initiator.xid, token, handshake_transport_params(24_576), ); @@ -561,10 +632,20 @@ fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { let initiator_params = handshake_transport_params(28_672); let responder_params = handshake_transport_params(57_344); - let mut initiator_state = - XxHandshake::new_initiator(&crypto, initiator.clone(), token, initiator_params); - let mut responder_state = - XxHandshake::new_responder(&crypto, responder.clone(), token, responder_params); + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.xid, + token, + initiator_params, + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.xid, + token, + responder_params, + ); assert_eq!(initiator_state.pairing_token(), token); assert_eq!(responder_state.pairing_token(), token); @@ -774,12 +855,14 @@ fn protocol_record_size_breakdown() { let mut xx_initiator = XxHandshake::new_initiator( &crypto, initiator.clone(), + responder.xid, token, TransportParams::default(), ); let mut xx_responder = XxHandshake::new_responder( &crypto, responder.clone(), + initiator.xid, token, TransportParams::default(), ); From 4a87474316a13e590d1a4d3f33ee3f89565280a4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 13 May 2026 15:17:46 -0400 Subject: [PATCH 294/304] ql: detailed recv errors --- ql-fsm/src/error.rs | 43 +++++++++++++++++-------------- ql-fsm/src/fsm.rs | 16 +++++++----- ql-fsm/src/handshake/ik.rs | 27 +++++++++++++++----- ql-fsm/src/handshake/kk.rs | 29 +++++++++++++++------ ql-fsm/src/handshake/mod.rs | 2 +- ql-fsm/src/handshake/xx.rs | 49 ++++++++++++++++++++++++++++-------- ql-fsm/src/tests/proptest.rs | 17 +++++++++---- ql-wire/src/error.rs | 10 ++++++++ ql-wire/src/handshake/mod.rs | 8 +++--- ql-wire/src/handshake/xx.rs | 15 ++++++----- ql-wire/src/tests.rs | 10 ++++---- 11 files changed, 154 insertions(+), 72 deletions(-) diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index 99b829d6..e38108b6 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -7,11 +7,18 @@ use ql_wire::{PairingId, WireError}; #[derive(Debug, Clone, PartialEq, Eq)] pub enum ReceiveError { - InvalidPayload, - InvalidState, - Expired, - DecryptFailed, + InvalidRecordHeader(WireError), + InvalidRecordVersion, + InvalidHandshakeRecord(WireError), + InvalidSessionRecord(WireError), + InvalidSessionConnectionId, + InvalidSessionPayload(WireError), + InvalidIkHandshake(WireError), + InvalidKkHandshake(WireError), + InvalidXxHandshake(WireError), + InvalidRemoteBundle, InvalidXid, + NoPeer, NoSession, NotPairingMode, InvalidPairingId { @@ -23,11 +30,20 @@ pub enum ReceiveError { impl Display for ReceiveError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::InvalidPayload => f.write_str("invalid payload"), - Self::InvalidState => f.write_str("invalid state"), - Self::Expired => f.write_str("expired"), - Self::DecryptFailed => f.write_str("decryption failed"), + Self::InvalidRecordHeader(error) => write!(f, "invalid record header: {error}"), + Self::InvalidRecordVersion => f.write_str("invalid record version"), + Self::InvalidHandshakeRecord(error) => { + write!(f, "invalid handshake record: {error}") + } + Self::InvalidSessionRecord(error) => write!(f, "invalid session record: {error}"), + Self::InvalidSessionConnectionId => f.write_str("invalid session connection id"), + Self::InvalidSessionPayload(error) => write!(f, "invalid session payload: {error}"), + Self::InvalidIkHandshake(error) => write!(f, "invalid ik handshake: {error}"), + Self::InvalidKkHandshake(error) => write!(f, "invalid kk handshake: {error}"), + Self::InvalidXxHandshake(error) => write!(f, "invalid xx handshake: {error}"), + Self::InvalidRemoteBundle => f.write_str("invalid remote bundle"), Self::InvalidXid => f.write_str("invalid xid"), + Self::NoPeer => f.write_str("no bound peer"), Self::NoSession => f.write_str("no active session"), Self::NotPairingMode => f.write_str("not in pairing mode"), Self::InvalidPairingId { expected, actual } => { @@ -42,17 +58,6 @@ impl Display for ReceiveError { impl std::error::Error for ReceiveError {} -impl From for ReceiveError { - fn from(value: WireError) -> Self { - match value { - WireError::InvalidPayload => Self::InvalidPayload, - WireError::InvalidState => Self::InvalidState, - WireError::Expired => Self::Expired, - WireError::DecryptFailed => Self::DecryptFailed, - } - } -} - impl From for ReceiveError { fn from(_: NoSessionError) -> Self { Self::NoSession diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs index c73fb6fc..036a336e 100644 --- a/ql-fsm/src/fsm.rs +++ b/ql-fsm/src/fsm.rs @@ -112,15 +112,17 @@ pub fn receive( crypto: &impl QlCrypto, ) -> Result<(), ReceiveError> { let mut reader = wire::Reader::new(bytes.as_mut_slice()); - let header = wire::RecordHeader::decode(&mut reader)?; + let header = + wire::RecordHeader::decode(&mut reader).map_err(ReceiveError::InvalidRecordHeader)?; if header.version != wire::QL_WIRE_VERSION { - return Err(ReceiveError::InvalidPayload); + return Err(ReceiveError::InvalidRecordVersion); } match header.record_type { wire::RecordType::Handshake => { - let record = wire::QlHandshakeRecord::decode(&mut reader)?; + let record = wire::QlHandshakeRecord::decode(&mut reader) + .map_err(ReceiveError::InvalidHandshakeRecord)?; handshake::handle_handshake_record(fsm, crypto, &record) } wire::RecordType::Session => { @@ -128,16 +130,18 @@ pub fn receive( let QlFsm { state, events, .. } = fsm; let conn = state.link.connected_mut_or_err()?; let (decrypt_len, seq) = { - let record = wire::QlSessionRecord::decode(&mut reader)?; + let record = wire::QlSessionRecord::decode(&mut reader) + .map_err(ReceiveError::InvalidSessionRecord)?; if record.header.connection_id != conn.transport.rx_connection_id { - return Err(ReceiveError::InvalidPayload); + return Err(ReceiveError::InvalidSessionConnectionId); } let payload = wire::decrypt_record( crypto, &record.header, record.payload, &conn.transport.rx_key, - )?; + ) + .map_err(ReceiveError::InvalidSessionPayload)?; (payload.len(), record.header.seq) }; diff --git a/ql-fsm/src/handshake/ik.rs b/ql-fsm/src/handshake/ik.rs index 9e4cb2cd..10d210de 100644 --- a/ql-fsm/src/handshake/ik.rs +++ b/ql-fsm/src/handshake/ik.rs @@ -53,9 +53,17 @@ pub fn handle_ik1( fsm.state.peer.clone(), super::local_transport_params(fsm), ); - handshake.read_1(crypto, message)?; - let outbound = handshake.write_2(crypto, message.meta)?; - let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidIkHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidIkHandshake)?; + let (transport, remote_bundle) = SessionTransport::from_finalized( + handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidIkHandshake)?, + ); finish_handshake(fsm, transport, remote_bundle)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Ik2(outbound)); @@ -76,14 +84,21 @@ pub fn handle_ik2( return Ok(()); } - state.handshake.read_2(crypto, message)?; + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidIkHandshake)?; } let LinkState::IkInitiator(state) = fsm.state.link.take() else { unreachable!("active IK initiator was checked above"); }; - let (transport, remote_bundle) = - SessionTransport::from_finalized(state.handshake.finalize(crypto)?); + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidIkHandshake)?, + ); finish_handshake(fsm, transport, remote_bundle) } diff --git a/ql-fsm/src/handshake/kk.rs b/ql-fsm/src/handshake/kk.rs index 99615561..140a9c5e 100644 --- a/ql-fsm/src/handshake/kk.rs +++ b/ql-fsm/src/handshake/kk.rs @@ -38,7 +38,7 @@ pub fn handle_kk1( } let Some(peer) = fsm.state.peer.clone() else { - return Err(ReceiveError::InvalidPayload); + return Err(ReceiveError::NoPeer); }; if message.header.recipient != fsm.identity.xid || message.header.sender != peer.xid { return Err(ReceiveError::InvalidXid); @@ -52,9 +52,17 @@ pub fn handle_kk1( peer, super::local_transport_params(fsm), ); - handshake.read_1(crypto, message)?; - let outbound = handshake.write_2(crypto, message.meta)?; - let (transport, remote_bundle) = SessionTransport::from_finalized(handshake.finalize(crypto)?); + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidKkHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidKkHandshake)?; + let (transport, remote_bundle) = SessionTransport::from_finalized( + handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidKkHandshake)?, + ); finish_handshake(fsm, transport, remote_bundle)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Kk2(outbound)); @@ -75,14 +83,21 @@ pub fn handle_kk2( return Ok(()); } - state.handshake.read_2(crypto, message)?; + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidKkHandshake)?; } let LinkState::KkInitiator(state) = fsm.state.link.take() else { unreachable!("active KK initiator was checked above"); }; - let (transport, remote_bundle) = - SessionTransport::from_finalized(state.handshake.finalize(crypto)?); + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidKkHandshake)?, + ); finish_handshake(fsm, transport, remote_bundle) } diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index 2431bbc8..e2d2a5a2 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -98,7 +98,7 @@ pub fn finish_handshake( let xid = remote_bundle.xid; if let Some(peer) = fsm.state.peer.as_ref() { if peer != &remote_bundle { - return Err(ReceiveError::InvalidPayload); + return Err(ReceiveError::InvalidRemoteBundle); } } else { fsm.state.peer = Some(remote_bundle); diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index dbc72006..79653b46 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -65,8 +65,12 @@ pub fn handle_xx1( token, super::local_transport_params(fsm), ); - handshake.read_1(crypto, message)?; - let outbound = handshake.write_2(crypto, message.meta)?; + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidXxHandshake)?; fsm.state.link = LinkState::XxResponder(XxResponderState { handshake, handshake_meta: message.meta, @@ -94,8 +98,14 @@ pub fn handle_xx2( return Ok(()); } - state.handshake.read_2(crypto, message)?; - let outbound = state.handshake.write_3(crypto, message.meta)?; + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + let outbound = state + .handshake + .write_3(crypto, message.meta) + .map_err(ReceiveError::InvalidXxHandshake)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Xx3(outbound)); } @@ -116,16 +126,26 @@ pub fn handle_xx3( return Ok(()); } - state.handshake.read_3(crypto, message)?; + state + .handshake + .read_3(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; let handshake_meta = state.handshake_meta; let LinkState::XxResponder(mut state) = fsm.state.link.take() else { unreachable!("active XX responder was checked above"); }; - let outbound = state.handshake.write_4(crypto, handshake_meta)?; + let outbound = state + .handshake + .write_4(crypto, handshake_meta) + .map_err(ReceiveError::InvalidXxHandshake)?; fsm.state.handshake = None; enqueue_handshake(fsm, QlHandshakeRecord::Xx4(outbound)); - let (transport, remote_bundle) = - SessionTransport::from_finalized(state.handshake.finalize(crypto)?); + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidXxHandshake)?, + ); finish_handshake(fsm, transport, remote_bundle) } @@ -143,14 +163,21 @@ pub fn handle_xx4( return Ok(()); } - state.handshake.read_4(crypto, message)?; + state + .handshake + .read_4(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; } let LinkState::XxInitiator(state) = fsm.state.link.take() else { unreachable!("active XX initiator was checked above"); }; - let (transport, remote_bundle) = - SessionTransport::from_finalized(state.handshake.finalize(crypto)?); + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidXxHandshake)?, + ); finish_handshake(fsm, transport, remote_bundle) } diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs index d95513b3..bc97ca77 100644 --- a/ql-fsm/src/tests/proptest.rs +++ b/ql-fsm/src/tests/proptest.rs @@ -7,7 +7,7 @@ extern crate proptest as proptest_crate; use bytes::Bytes; use proptest_crate::{collection::vec, prelude::*, test_runner::TestCaseResult}; -use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId, WireError}; use super::*; @@ -544,10 +544,17 @@ impl Runner { matches!( error, ReceiveError::NoSession - | ReceiveError::InvalidState - | ReceiveError::Expired - | ReceiveError::InvalidPayload - | ReceiveError::DecryptFailed + | ReceiveError::NoPeer + | ReceiveError::InvalidRemoteBundle + | ReceiveError::InvalidSessionPayload(WireError::InvalidPayload) + | ReceiveError::InvalidSessionPayload(WireError::DecryptFailed) + | ReceiveError::InvalidIkHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidIkHandshake(WireError::InvalidState) + | ReceiveError::InvalidKkHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidKkHandshake(WireError::InvalidState) + | ReceiveError::InvalidXxHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidXxHandshake(WireError::InvalidState) + | ReceiveError::InvalidXxHandshake(WireError::DecryptFailed) ), "unexpected receive error on side {side:?}: {error:?}" ); diff --git a/ql-wire/src/error.rs b/ql-wire/src/error.rs index 6f17d648..8da1eec0 100644 --- a/ql-wire/src/error.rs +++ b/ql-wire/src/error.rs @@ -3,6 +3,11 @@ use core::fmt; #[derive(Debug, Clone, PartialEq, Eq)] pub enum WireError { InvalidPayload, + InvalidHandshakeHeader, + InvalidHandshakeMeta, + InvalidPairingId, + InvalidRemoteBundle, + InvalidTransportParams, Expired, DecryptFailed, InvalidState, @@ -12,6 +17,11 @@ impl fmt::Display for WireError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let message = match self { Self::InvalidPayload => "invalid payload", + Self::InvalidHandshakeHeader => "invalid handshake header", + Self::InvalidHandshakeMeta => "invalid handshake meta", + Self::InvalidPairingId => "invalid pairing id", + Self::InvalidRemoteBundle => "invalid remote bundle", + Self::InvalidTransportParams => "invalid transport params", Self::Expired => "expired", Self::DecryptFailed => "decryption failed", Self::InvalidState => "invalid state", diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index 6ba9209c..fcbe02c4 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -413,7 +413,7 @@ fn initialize_handshake_meta( meta: HandshakeMeta, ) -> Result<(), WireError> { match expected { - Some(stored) if *stored != meta => Err(WireError::InvalidPayload), + Some(stored) if *stored != meta => Err(WireError::InvalidHandshakeMeta), Some(_) => Ok(()), None => { *expected = Some(meta); @@ -428,7 +428,7 @@ fn require_handshake_meta( ) -> Result<(), WireError> { match expected { Some(stored) if *stored == meta => Ok(()), - _ => Err(WireError::InvalidPayload), + _ => Err(WireError::InvalidHandshakeMeta), } } @@ -437,7 +437,7 @@ fn initialize_transport_params( transport_params: TransportParams, ) -> Result<(), WireError> { match expected { - Some(stored) if *stored != transport_params => Err(WireError::InvalidPayload), + Some(stored) if *stored != transport_params => Err(WireError::InvalidTransportParams), Some(_) => Ok(()), None => { *expected = Some(transport_params); @@ -452,7 +452,7 @@ fn require_transport_params( ) -> Result<(), WireError> { match expected { Some(stored) if *stored == transport_params => Ok(()), - _ => Err(WireError::InvalidPayload), + _ => Err(WireError::InvalidTransportParams), } } diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index 2f026ed9..fb35e861 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -300,21 +300,20 @@ impl XxHandshake { header: HandshakeHeader, pairing_id: PairingId, ) -> Result<(), WireError> { - if header.sender == self.remote_xid - && header.recipient == self.local.xid - && pairing_id == self.pairing_token.id(crypto) - { - Ok(()) - } else { - Err(WireError::InvalidPayload) + if header.sender != self.remote_xid || header.recipient != self.local.xid { + return Err(WireError::InvalidHandshakeHeader); } + if pairing_id != self.pairing_token.id(crypto) { + return Err(WireError::InvalidPairingId); + } + Ok(()) } fn ensure_remote_bundle(&self, bundle: &PeerBundle) -> Result<(), WireError> { if bundle.xid == self.remote_xid { Ok(()) } else { - Err(WireError::InvalidPayload) + Err(WireError::InvalidRemoteBundle) } } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index d0fcfa9f..ae4660ff 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -182,7 +182,7 @@ fn ik_handshake_rejects_tampered_handshake_meta() { assert_eq!( initiator_state.read_2(&crypto, &m2), - Err(WireError::InvalidPayload) + Err(WireError::InvalidHandshakeMeta) ); } @@ -521,7 +521,7 @@ fn xx_handshake_rejects_tampered_pairing_id() { assert_eq!( responder_state.read_1(&crypto, &m1), - Err(WireError::InvalidPayload) + Err(WireError::InvalidPairingId) ); } @@ -553,7 +553,7 @@ fn xx_handshake_rejects_tampered_sender_or_recipient() { assert_eq!( responder_state.read_1(&crypto, &m1), - Err(WireError::InvalidPayload) + Err(WireError::InvalidHandshakeHeader) ); let mut initiator_state = XxHandshake::new_initiator( @@ -578,7 +578,7 @@ fn xx_handshake_rejects_tampered_sender_or_recipient() { assert_eq!( responder_state.read_1(&crypto, &m1), - Err(WireError::InvalidPayload) + Err(WireError::InvalidHandshakeHeader) ); } @@ -620,7 +620,7 @@ fn xx_handshake_rejects_repeated_transport_param_change() { assert_eq!( responder_state.read_3(&crypto, &m3), - Err(WireError::InvalidPayload) + Err(WireError::InvalidTransportParams) ); } From f0ae09fe9cc8921324a7a355a415fda61d741253 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 14 May 2026 09:02:44 -0400 Subject: [PATCH 295/304] ql-rpc: add duplex modality --- ql-rpc/src/router/builder.rs | 34 ++++++++ ql-rpc/src/router/mod.rs | 1 + ql-rpc/src/rpc/duplex/client.rs | 149 ++++++++++++++++++++++++++++++++ ql-rpc/src/rpc/duplex/codec.rs | 90 +++++++++++++++++++ ql-rpc/src/rpc/duplex/mod.rs | 25 ++++++ ql-rpc/src/rpc/duplex/server.rs | 38 ++++++++ ql-rpc/src/rpc/mod.rs | 2 + ql-runtime/src/rpc/duplex.rs | 59 +++++++++++++ ql-runtime/src/rpc/mod.rs | 22 ++++- ql-runtime/src/tests/rpc.rs | 85 +++++++++++++++++- 10 files changed, 500 insertions(+), 5 deletions(-) create mode 100644 ql-rpc/src/rpc/duplex/client.rs create mode 100644 ql-rpc/src/rpc/duplex/codec.rs create mode 100644 ql-rpc/src/rpc/duplex/mod.rs create mode 100644 ql-rpc/src/rpc/duplex/server.rs create mode 100644 ql-runtime/src/rpc/duplex.rs diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index 9910eb5c..a31887be 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -9,6 +9,10 @@ use crate::{ server::{handle_download_inner, DownloadHandler}, Download as DownloadRpc, }, + duplex::{ + server::{handle_duplex_inner, DuplexHandler}, + Duplex as DuplexRpc, + }, notification::{ server::{handle_notification_inner, NotificationHandler}, Notification as NotificationRpc, @@ -111,6 +115,19 @@ where }) } + pub fn duplex(self) -> Self + where + M: DuplexRpc + 'static, + S: DuplexHandler + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_duplex_inner::( + state, config, reader, writer, + )) + }) + } + pub fn download(self) -> Self where M: DownloadRpc + 'static, @@ -200,6 +217,23 @@ where }) } + pub fn duplex(self) -> Self + where + M: DuplexRpc + 'static, + M::InitiatorEvent: Send + 'static, + M::ResponderEvent: Send + 'static, + S: DuplexHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_duplex_inner::( + state, config, reader, writer, + )) + }) + } + pub fn download(self) -> Self where M: DownloadRpc + 'static, diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index dfdbc960..3d1ab6cd 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -10,6 +10,7 @@ pub use self::{builder::RouterBuilder, config::RouterConfig, mode::*}; use crate::{close_stream, RpcStream}; pub use crate::{ download::{DownloadHandler, DownloadResponder, DownloadWriter}, + duplex::{DuplexHandler, DuplexPeer}, notification::NotificationHandler, progress::{ProgressHandler, ProgressResponder}, request::{RequestHandler, Response}, diff --git a/ql-rpc/src/rpc/duplex/client.rs b/ql-rpc/src/rpc/duplex/client.rs new file mode 100644 index 00000000..579a9f5e --- /dev/null +++ b/ql-rpc/src/rpc/duplex/client.rs @@ -0,0 +1,149 @@ +use std::{ + future::poll_fn, + marker::PhantomData, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use crate::{ + duplex::{codec, Duplex, EventReader, ReadStep}, + finish_bytes, write_bytes, CallError, RpcCodec, RpcRead, RpcWrite, StreamCloseCode, +}; + +pub struct DuplexCall +where + M: Duplex, + W: RpcWrite, + R: RpcRead, +{ + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub struct DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +pub struct DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + stream: R, + reader: Option>, +} + +impl DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + pub fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, event: &T) -> Result<(), W::Error> { + let writer = self.writer.as_mut().expect("duplex writer exists"); + let mut encoded = Vec::new(); + codec::encode_event(event, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().expect("duplex writer exists"); + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +impl DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream, + reader: Some(EventReader::default()), + } + } + + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| self.poll_next_event(cx)).await + } + + pub fn poll_next_event( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + loop { + let Some(reader) = self.reader.take() else { + return Poll::Ready(None); + }; + + let reader = match reader.advance() { + Ok(ReadStep::Event { value, next }) => { + self.reader = Some(next); + return Poll::Ready(Some(Ok(value))); + } + Ok(ReadStep::NeedMore(next)) => next, + Err(error) => return Poll::Ready(Some(Err(error.into()))), + }; + + match self.stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + self.reader = Some(reader.push(chunk)); + } + Poll::Ready(Ok(None)) => { + self.reader = None; + if reader.is_empty() { + return Poll::Ready(None); + } + return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); + } + Poll::Ready(Err(error)) => { + self.reader = None; + return Poll::Ready(Some(Err(CallError::Transport(error)))); + } + Poll::Pending => { + self.reader = Some(reader); + return Poll::Pending; + } + } + } + } + + pub fn close(self, code: StreamCloseCode) { + self.stream.close(code); + } + + pub fn into_inner(self) -> R { + self.stream + } +} diff --git a/ql-rpc/src/rpc/duplex/codec.rs b/ql-rpc/src/rpc/duplex/codec.rs new file mode 100644 index 00000000..f7950e89 --- /dev/null +++ b/ql-rpc/src/rpc/duplex/codec.rs @@ -0,0 +1,90 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, CodecError, RpcCodec}; + +pub fn encode_event(event: &T, out: &mut (impl BufMut + AsMut<[u8]>)) +where + T: RpcCodec, +{ + codec::encode_value_part(event, out) +} + +pub enum ReadStep { + NeedMore(EventReader), + Event { value: T, next: EventReader }, +} + +pub struct EventReader { + bytes: codec::ChunkQueue, + marker: PhantomData T>, +} + +impl Default for EventReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl EventReader { + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); + self + } + + pub fn is_empty(&self) -> bool { + self.bytes.remaining() == 0 + } + + pub fn advance(self) -> Result, CodecError> { + let mut this = self; + let Some(mut body) = this.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore(this)); + }; + + let value = { + let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + value + }; + Ok(ReadStep::Event { value, next: this }) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{encode_event, EventReader, ReadStep}; + + #[test] + fn event_reader_emits_multiple_events() { + let mut encoded = Vec::new(); + encode_event(&b"one".to_vec(), &mut encoded); + encode_event(&b"two".to_vec(), &mut encoded); + + let reader = match EventReader::>::default() + .push(Bytes::from(encoded)) + .advance() + .unwrap() + { + ReadStep::Event { value, next } => { + assert_eq!(value, b"one".to_vec()); + next + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + ReadStep::Event { value, next } => { + assert_eq!(value, b"two".to_vec()); + assert!(next.is_empty()); + } + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/rpc/duplex/mod.rs b/ql-rpc/src/rpc/duplex/mod.rs new file mode 100644 index 00000000..8ff772e0 --- /dev/null +++ b/ql-rpc/src/rpc/duplex/mod.rs @@ -0,0 +1,25 @@ +use crate::{RouteId, RpcCodec}; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::{DuplexCall, DuplexReceiver, DuplexSender}; +pub use codec::{encode_event, EventReader, ReadStep}; +pub use server::{DuplexHandler, DuplexPeer}; + +/// rpc where both sides exchange typed events on the same stream +/// +/// The initiator opens the routed stream. After that, either side may send any +/// number of events of its directional event type until it finishes or closes +/// its write side. +pub trait Duplex { + /// route used to dispatch this rpc family + const ROUTE: RouteId; + /// codec error shared by both directional event values + type Error; + /// typed event sent by the side that opened the stream + type InitiatorEvent: RpcCodec; + /// typed event sent by the side handling the route + type ResponderEvent: RpcCodec; +} diff --git a/ql-rpc/src/rpc/duplex/server.rs b/ql-rpc/src/rpc/duplex/server.rs new file mode 100644 index 00000000..8dc6093b --- /dev/null +++ b/ql-rpc/src/rpc/duplex/server.rs @@ -0,0 +1,38 @@ +use crate::{ + duplex::{Duplex, DuplexReceiver, DuplexSender}, + RpcRead, RpcStream, RpcWrite, +}; + +pub trait DuplexHandler +where + M: Duplex, + St: RpcStream, +{ + fn handle(self, peer: DuplexPeer); +} + +pub struct DuplexPeer +where + M: Duplex, + W: RpcWrite, + R: RpcRead, +{ + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub(crate) async fn handle_duplex_inner( + state: S, + _config: crate::RouterConfig, + reader: St::Reader, + writer: St::Writer, +) where + M: Duplex + 'static, + S: DuplexHandler + 'static, + St: RpcStream + 'static, +{ + state.handle(DuplexPeer { + sender: DuplexSender::new(writer), + receiver: DuplexReceiver::new(reader), + }); +} diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index b3b85cf4..1f8f32fb 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -6,6 +6,7 @@ //! client and server helpers for encoding, decoding, and handler glue pub mod download; +pub mod duplex; pub mod notification; pub mod progress; pub mod request; @@ -14,6 +15,7 @@ pub mod upload; mod utils; pub use download::Download; +pub use duplex::Duplex; pub use notification::Notification; pub use progress::Progress; pub use request::Request; diff --git a/ql-runtime/src/rpc/duplex.rs b/ql-runtime/src/rpc/duplex.rs new file mode 100644 index 00000000..cdad6670 --- /dev/null +++ b/ql-runtime/src/rpc/duplex.rs @@ -0,0 +1,59 @@ +use futures_lite::future::poll_fn; +use ql_rpc::duplex::Duplex as DuplexRpc; + +use super::RpcError; +use crate::{QlStreamError, StreamReader, StreamWriter}; + +pub struct DuplexCall { + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub struct DuplexSender +where + T: ql_rpc::RpcCodec, +{ + pub(super) inner: ql_rpc::duplex::DuplexSender, +} + +pub struct DuplexReceiver +where + T: ql_rpc::RpcCodec, +{ + pub(super) inner: ql_rpc::duplex::DuplexReceiver, +} + +impl DuplexSender +where + T: ql_rpc::RpcCodec, +{ + pub async fn send(&mut self, event: &T) -> Result<(), QlStreamError> { + self.inner.send(event).await + } + + pub async fn finish(self) -> Result<(), QlStreamError> { + self.inner.finish().await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl DuplexReceiver +where + T: ql_rpc::RpcCodec, +{ + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| { + self.inner + .poll_next_event(cx) + .map(|item| item.map(|result| Ok(result?))) + }) + .await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs index c6960b17..d8be02c5 100644 --- a/ql-runtime/src/rpc/mod.rs +++ b/ql-runtime/src/rpc/mod.rs @@ -1,7 +1,8 @@ -pub use self::{download::*, error::*, progress::*, subscription::*, upload::*}; +pub use self::{download::*, duplex::*, error::*, progress::*, subscription::*, upload::*}; mod adapter; mod download; +mod duplex; mod error; mod progress; mod subscription; @@ -10,6 +11,7 @@ mod upload; use bytes::Bytes; use ql_rpc::{ download::{self as rpc_download, Download as DownloadRpc}, + duplex::{self as rpc_duplex, Duplex as DuplexRpc}, notification::{self, Notification}, progress::{self as rpc_progress, Progress}, request::{self, Request as RequestRpc}, @@ -111,6 +113,24 @@ impl RpcHandle { inner: rpc_upload::UploadCall::new(stream.writer, stream.reader), }) } + + pub async fn duplex(&self) -> Result, RpcError> + where + M: DuplexRpc, + { + let stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::ROUTE)) + .await?; + Ok(DuplexCall { + sender: DuplexSender { + inner: rpc_duplex::DuplexSender::new(stream.writer), + }, + receiver: DuplexReceiver { + inner: rpc_duplex::DuplexReceiver::new(stream.reader), + }, + }) + } } impl RpcHandle { diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index 5335998f..cc08934b 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -9,10 +9,10 @@ use std::{ use bytes::Bytes; use futures_lite::StreamExt; use ql_rpc::{ - DownloadHandler, DownloadResponder, DownloadWriter, LocalSpawn, NotificationHandler, - ProgressHandler, ProgressResponder, RequestHandler, Response, RouteId, SendSpawn, - StreamCloseCode, SubscriptionHandler, SubscriptionResponder, UploadHandler, UploadReader, - UploadResponder, + DownloadHandler, DownloadResponder, DownloadWriter, DuplexHandler, DuplexPeer, LocalSpawn, + NotificationHandler, ProgressHandler, ProgressResponder, RequestHandler, Response, RouteId, + SendSpawn, StreamCloseCode, SubscriptionHandler, SubscriptionResponder, UploadHandler, + UploadReader, UploadResponder, }; use super::*; @@ -74,6 +74,15 @@ impl ql_rpc::upload::Upload for BlobUpload { type Response = Vec; } +struct Chat; + +impl ql_rpc::duplex::Duplex for Chat { + const ROUTE: RouteId = RouteId::from_u32(56); + type Error = core::convert::Infallible; + type InitiatorEvent = Vec; + type ResponderEvent = Vec; +} + #[tokio::test(flavor = "current_thread")] async fn rpc_request() { #[derive(Clone)] @@ -459,3 +468,71 @@ async fn rpc_upload() { }) .await; } + +#[tokio::test(flavor = "current_thread")] +async fn rpc_duplex() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl DuplexHandler for RouterState { + fn handle(self, mut peer: DuplexPeer) { + let seen = self.seen.clone(); + tokio::task::spawn_local(async move { + let first = peer.receiver.next_event().await.unwrap().unwrap(); + seen.borrow_mut().push(first); + + peer.sender + .send(&b"challenge-response".to_vec()) + .await + .unwrap(); + + let second = peer.receiver.next_event().await.unwrap().unwrap(); + seen.borrow_mut().push(second); + + peer.sender.finish().await.unwrap(); + }); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) + .duplex::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await; + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut chat = rpc.duplex::().await.unwrap(); + chat.sender.send(&b"challenge".to_vec()).await.unwrap(); + assert_eq!( + chat.receiver.next_event().await.unwrap().unwrap(), + b"challenge-response".to_vec() + ); + chat.sender.send(&b"verification".to_vec()).await.unwrap(); + chat.sender.finish().await.unwrap(); + assert!(chat.receiver.next_event().await.is_none()); + + assert_eq!( + seen.borrow().as_slice(), + &[b"challenge".to_vec(), b"verification".to_vec()] + ); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} From 9322a76581de971e37730a8f58c966705e319e99 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 14 May 2026 10:39:35 -0400 Subject: [PATCH 296/304] ql-rpc: expose router ids --- ql-rpc/src/router/builder.rs | 15 +++++++-------- ql-rpc/src/router/mod.rs | 31 +++++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index a31887be..bb5d2e18 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -1,8 +1,6 @@ -use std::collections::HashMap; - use super::{ - LocalSpawn, LocalSpawner, RouteFn, Router, RouterConfig, RpcStream, SendSpawn, SendSpawner, - Spawner, + LocalSpawn, LocalSpawner, RouteEntry, RouteFn, Router, RouterConfig, RpcStream, SendSpawn, + SendSpawner, Spawner, }; use crate::{ download::{ @@ -33,7 +31,6 @@ use crate::{ server::{handle_upload_inner, UploadHandler}, Upload as UploadRpc, }, - RouteId, }; pub struct RouterBuilder @@ -42,7 +39,7 @@ where { config: RouterConfig, spawner: Sp, - routes: HashMap>, + routes: Vec>, } impl RouterBuilder @@ -53,7 +50,7 @@ where Self { config: RouterConfig::default(), spawner, - routes: std::collections::HashMap::new(), + routes: Vec::new(), } } @@ -68,6 +65,7 @@ where } pub fn build(mut self, state: S) -> Router { + self.routes.sort_by_key(|entry| entry.route_id); self.routes.shrink_to_fit(); Router { config: self.config, @@ -78,9 +76,10 @@ where } fn add_route(mut self, route_id: crate::RouteId, route: RouteFn) -> Self { - if self.routes.insert(route_id, route).is_some() { + if self.routes.iter().any(|entry| entry.route_id == route_id) { panic!("duplicate rpc route {}", route_id.into_inner()); } + self.routes.push(RouteEntry::new(route_id, route)); self } } diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 3d1ab6cd..334bd161 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use crate::{RouteId, StreamCloseCode}; mod builder; @@ -25,7 +23,24 @@ where config: RouterConfig, state: S, spawner: Sp, - routes: HashMap>, + routes: Vec>, +} + +struct RouteEntry +where + Sp: Spawner, +{ + route_id: RouteId, + route: RouteFn, +} + +impl RouteEntry +where + Sp: Spawner, +{ + fn new(route_id: RouteId, route: RouteFn) -> Self { + Self { route_id, route } + } } impl Router @@ -40,13 +55,21 @@ where pub fn handle(&self, stream: St) -> Option<(RouteId, Sp::Handle)> { let route_id = stream.route_id()?; - let Some(route) = self.routes.get(&route_id) else { + let Ok(index) = self + .routes + .binary_search_by_key(&route_id, |entry| entry.route_id) + else { close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); return None; }; + let route = self.routes[index].route; Some(( route_id, route(&self.spawner, self.state.clone(), self.config, stream), )) } + + pub fn route_ids(&self) -> impl ExactSizeIterator + '_ { + self.routes.iter().map(|entry| entry.route_id) + } } From cb48acdba29a834083e266a3297d777715627da4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 14 May 2026 12:59:47 -0400 Subject: [PATCH 297/304] ql-rpc: better router --- Cargo.lock | 12 ++ ql-rpc/Cargo.toml | 1 + ql-rpc/src/router/builder.rs | 195 +++++++++++++-------- ql-rpc/src/router/mod.rs | 34 ++-- ql-rpc/src/router/mode.rs | 38 +---- ql-rpc/src/rpc/download/mod.rs | 2 +- ql-rpc/src/rpc/download/server.rs | 19 ++- ql-rpc/src/rpc/duplex/mod.rs | 2 +- ql-rpc/src/rpc/duplex/server.rs | 25 ++- ql-rpc/src/rpc/notification/mod.rs | 2 +- ql-rpc/src/rpc/notification/server.rs | 19 ++- ql-rpc/src/rpc/progress/mod.rs | 2 +- ql-rpc/src/rpc/progress/server.rs | 19 ++- ql-rpc/src/rpc/request/mod.rs | 2 +- ql-rpc/src/rpc/request/server.rs | 19 ++- ql-rpc/src/rpc/subscription/mod.rs | 2 +- ql-rpc/src/rpc/subscription/server.rs | 23 ++- ql-rpc/src/rpc/upload/mod.rs | 2 +- ql-rpc/src/rpc/upload/server.rs | 28 ++- ql-runtime/src/tests/rpc.rs | 236 +++++++++++++++----------- 20 files changed, 408 insertions(+), 274 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 872dd247..123d0e59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2358,6 +2358,7 @@ name = "ql-rpc" version = "0.1.0" dependencies = [ "bytes", + "trait-variant", ] [[package]] @@ -3146,6 +3147,17 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "typenum" version = "1.18.0" diff --git a/ql-rpc/Cargo.toml b/ql-rpc/Cargo.toml index 897ee9ef..51a764dc 100644 --- a/ql-rpc/Cargo.toml +++ b/ql-rpc/Cargo.toml @@ -7,3 +7,4 @@ license = "Proprietary" [dependencies] bytes = { version = "1" } +trait-variant = { version = "0.1" } diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs index bb5d2e18..b59a84e6 100644 --- a/ql-rpc/src/router/builder.rs +++ b/ql-rpc/src/router/builder.rs @@ -1,56 +1,41 @@ +use std::marker::PhantomData; + use super::{ - LocalSpawn, LocalSpawner, RouteEntry, RouteFn, Router, RouterConfig, RpcStream, SendSpawn, - SendSpawner, Spawner, + LocalSpawner, RouteEntry, RouteFn, Router, RouterConfig, RpcStream, SendSpawner, Spawner, }; use crate::{ - download::{ - server::{handle_download_inner, DownloadHandler}, - Download as DownloadRpc, - }, - duplex::{ - server::{handle_duplex_inner, DuplexHandler}, - Duplex as DuplexRpc, - }, - notification::{ - server::{handle_notification_inner, NotificationHandler}, - Notification as NotificationRpc, - }, - progress::{ - server::{handle_progress_inner, ProgressHandler}, - Progress as ProgressRpc, - }, - request::{ - server::{handle_request_inner, RequestHandler}, - Request as RequestRpc, - }, - subscription::{ - server::{handle_subscription_inner, SubscriptionHandler}, - Subscription as SubscriptionRpc, - }, - upload::{ - server::{handle_upload_inner, UploadHandler}, - Upload as UploadRpc, - }, + download::{server::*, Download as DownloadRpc}, + duplex::{server::*, Duplex as DuplexRpc}, + notification::{server::*, Notification as NotificationRpc}, + progress::{server::*, Progress as ProgressRpc}, + request::{server::*, Request as RequestRpc}, + subscription::{server::*, Subscription as SubscriptionRpc}, + upload::{server::*, Upload as UploadRpc}, }; -pub struct RouterBuilder +pub struct LocalRoutes; +pub struct SendRoutes; + +pub struct RouterBuilder where Sp: Spawner, { config: RouterConfig, spawner: Sp, routes: Vec>, + marker: PhantomData Mode>, } -impl RouterBuilder +impl RouterBuilder where Sp: Spawner, { - pub fn new(spawner: Sp) -> Self { + pub(crate) fn new(spawner: Sp) -> Self { Self { config: RouterConfig::default(), spawner, routes: Vec::new(), + marker: PhantomData, } } @@ -84,19 +69,25 @@ where } } -impl RouterBuilder +impl RouterBuilder where + Sp: LocalSpawner, St: RpcStream + 'static, { pub fn request(self) -> Self where M: RequestRpc + 'static, - S: RequestHandler + 'static, + S: RequestHandlerLocal + 'static, { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_request_inner::( - state, config, reader, writer, + spawner.spawn(handle_request_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -104,12 +95,17 @@ where pub fn notification(self) -> Self where M: NotificationRpc + 'static, - S: NotificationHandler + 'static, + S: NotificationHandlerLocal + 'static, { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_notification_inner::( - state, config, reader, writer, + spawner.spawn(handle_notification_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -117,12 +113,16 @@ where pub fn duplex(self) -> Self where M: DuplexRpc + 'static, - S: DuplexHandler + 'static, + S: DuplexHandlerLocal + 'static, { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_duplex_inner::( - state, config, reader, writer, + spawner.spawn(handle_duplex_inner::( + state, + config, + reader, + writer, + S::handle, )) }) } @@ -130,12 +130,17 @@ where pub fn download(self) -> Self where M: DownloadRpc + 'static, - S: DownloadHandler + 'static, + S: DownloadHandlerLocal + 'static, { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_download_inner::( - state, config, reader, writer, + spawner.spawn(handle_download_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -143,12 +148,17 @@ where pub fn subscription(self) -> Self where M: SubscriptionRpc + 'static, - S: SubscriptionHandler + 'static, + S: SubscriptionHandlerLocal + 'static, { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_subscription_inner::( - state, config, reader, writer, + spawner.spawn(handle_subscription_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -156,12 +166,17 @@ where pub fn progress(self) -> Self where M: ProgressRpc + 'static, - S: ProgressHandler + 'static, + S: ProgressHandlerLocal + 'static, { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_progress_inner::( - state, config, reader, writer, + spawner.spawn(handle_progress_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -169,19 +184,25 @@ where pub fn upload(self) -> Self where M: UploadRpc + 'static, - S: UploadHandler + 'static, + S: UploadHandlerLocal + 'static, { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_upload_inner::( - state, config, reader, writer, + spawner.spawn(handle_upload_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } } -impl RouterBuilder +impl RouterBuilder where + Sp: SendSpawner + Send, St: RpcStream + 'static, { pub fn request(self) -> Self @@ -194,8 +215,13 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_request_inner::( - state, config, reader, writer, + spawner.spawn(handle_request_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -210,8 +236,13 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_notification_inner::( - state, config, reader, writer, + spawner.spawn(handle_notification_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -227,8 +258,12 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_duplex_inner::( - state, config, reader, writer, + spawner.spawn(handle_duplex_inner::( + state, + config, + reader, + writer, + S::handle, )) }) } @@ -243,8 +278,13 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_download_inner::( - state, config, reader, writer, + spawner.spawn(handle_download_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -259,8 +299,13 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_subscription_inner::( - state, config, reader, writer, + spawner.spawn(handle_subscription_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -275,8 +320,13 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_progress_inner::( - state, config, reader, writer, + spawner.spawn(handle_progress_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } @@ -291,8 +341,13 @@ where { self.add_route(M::ROUTE, |spawner, state, config, stream| { let (reader, writer) = stream.split(); - spawner.spawn(handle_upload_inner::( - state, config, reader, writer, + spawner.spawn(handle_upload_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, )) }) } diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 334bd161..522f0e52 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -4,16 +4,20 @@ mod builder; mod config; mod mode; -pub use self::{builder::RouterBuilder, config::RouterConfig, mode::*}; +pub use self::{ + builder::{LocalRoutes, RouterBuilder, SendRoutes}, + config::RouterConfig, + mode::*, +}; use crate::{close_stream, RpcStream}; pub use crate::{ - download::{DownloadHandler, DownloadResponder, DownloadWriter}, - duplex::{DuplexHandler, DuplexPeer}, - notification::NotificationHandler, - progress::{ProgressHandler, ProgressResponder}, - request::{RequestHandler, Response}, - subscription::{SubscriptionHandler, SubscriptionResponder}, - upload::{UploadHandler, UploadReader, UploadResponder}, + download::{DownloadHandler, DownloadHandlerLocal, DownloadResponder, DownloadWriter}, + duplex::{DuplexHandler, DuplexHandlerLocal, DuplexPeer}, + notification::{NotificationHandler, NotificationHandlerLocal}, + progress::{ProgressHandler, ProgressHandlerLocal, ProgressResponder}, + request::{RequestHandler, RequestHandlerLocal, Response}, + subscription::{SubscriptionHandler, SubscriptionHandlerLocal, SubscriptionResponder}, + upload::{UploadHandler, UploadHandlerLocal, UploadReader, UploadResponder}, }; pub struct Router @@ -49,8 +53,18 @@ where St: RpcStream, Sp: Spawner, { - pub fn builder(spawner: Sp) -> RouterBuilder { - RouterBuilder::::new(spawner) + pub fn builder_local(spawner: Sp) -> RouterBuilder + where + Sp: LocalSpawner, + { + RouterBuilder::::new(spawner) + } + + pub fn builder_send(spawner: Sp) -> RouterBuilder + where + Sp: SendSpawner, + { + RouterBuilder::::new(spawner) } pub fn handle(&self, stream: St) -> Option<(RouteId, Sp::Handle)> { diff --git a/ql-rpc/src/router/mode.rs b/ql-rpc/src/router/mode.rs index 5d22d706..33b6c06a 100644 --- a/ql-rpc/src/router/mode.rs +++ b/ql-rpc/src/router/mode.rs @@ -1,11 +1,11 @@ -use std::{future::Future, pin::Pin}; +use std::future::Future; use crate::RouterConfig; pub type RouteFn = fn(&Sp, S, RouterConfig, St) -> ::Handle; -pub trait Spawner { - type Handle: Future + 'static; +pub trait Spawner: Clone + 'static { + type Handle; } pub trait LocalSpawner: Spawner { @@ -19,35 +19,3 @@ pub trait SendSpawner: Spawner { where F: Future + Send + 'static; } - -#[derive(Debug, Clone, Copy, Default)] -pub struct LocalSpawn; - -impl Spawner for LocalSpawn { - type Handle = Pin + 'static>>; -} - -impl LocalSpawner for LocalSpawn { - fn spawn(&self, fut: F) -> Self::Handle - where - F: Future + 'static, - { - Box::pin(fut) - } -} - -#[derive(Debug, Clone, Copy, Default)] -pub struct SendSpawn; - -impl Spawner for SendSpawn { - type Handle = Pin + Send + 'static>>; -} - -impl SendSpawner for SendSpawn { - fn spawn(&self, fut: F) -> Self::Handle - where - F: Future + Send + 'static, - { - Box::pin(fut) - } -} diff --git a/ql-rpc/src/rpc/download/mod.rs b/ql-rpc/src/rpc/download/mod.rs index 27da9898..967158cc 100644 --- a/ql-rpc/src/rpc/download/mod.rs +++ b/ql-rpc/src/rpc/download/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod server; pub use client::{DownloadCall, DownloadReader}; pub use codec::{encode_request, encode_response_header, ReadStep, ResponseHeaderReader}; -pub use server::{DownloadHandler, DownloadResponder, DownloadWriter}; +pub use server::{DownloadHandler, DownloadHandlerLocal, DownloadResponder, DownloadWriter}; /// rpc where the responder returns metadata first and raw bytes after that /// diff --git a/ql-rpc/src/rpc/download/server.rs b/ql-rpc/src/rpc/download/server.rs index 43560c39..0e26b7b6 100644 --- a/ql-rpc/src/rpc/download/server.rs +++ b/ql-rpc/src/rpc/download/server.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{future::Future, marker::PhantomData}; use bytes::Bytes; @@ -7,12 +7,13 @@ use crate::{ RouterConfig, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; -pub trait DownloadHandler +#[trait_variant::make(DownloadHandler: Send)] +pub trait DownloadHandlerLocal where M: DownloadRpc, St: RpcStream, { - fn handle( + async fn handle( self, message: M::Request, responder: DownloadResponder, @@ -108,21 +109,25 @@ where } } -pub(crate) async fn handle_download_inner( +pub(crate) async fn handle_download_inner( state: S, config: RouterConfig, mut reader: St::Reader, writer: St::Writer, + handle: H, + handle_transport_error: E, ) where M: DownloadRpc + 'static, - S: DownloadHandler + 'static, St: RpcStream + 'static, + H: FnOnce(S, M::Request, DownloadResponder) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), { let request = match read_eof_request::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); - state.handle_transport_error(&error); + handle_transport_error(&state, &error); if let Some(code) = code { reader.close(code); writer.close(code); @@ -131,5 +136,5 @@ pub(crate) async fn handle_download_inner( } }; - state.handle(request, DownloadResponder::new(writer)); + handle(state, request, DownloadResponder::new(writer)).await; } diff --git a/ql-rpc/src/rpc/duplex/mod.rs b/ql-rpc/src/rpc/duplex/mod.rs index 8ff772e0..c8f3f603 100644 --- a/ql-rpc/src/rpc/duplex/mod.rs +++ b/ql-rpc/src/rpc/duplex/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod server; pub use client::{DuplexCall, DuplexReceiver, DuplexSender}; pub use codec::{encode_event, EventReader, ReadStep}; -pub use server::{DuplexHandler, DuplexPeer}; +pub use server::{DuplexHandler, DuplexHandlerLocal, DuplexPeer}; /// rpc where both sides exchange typed events on the same stream /// diff --git a/ql-rpc/src/rpc/duplex/server.rs b/ql-rpc/src/rpc/duplex/server.rs index 8dc6093b..bf024335 100644 --- a/ql-rpc/src/rpc/duplex/server.rs +++ b/ql-rpc/src/rpc/duplex/server.rs @@ -1,14 +1,17 @@ +use std::future::Future; + use crate::{ duplex::{Duplex, DuplexReceiver, DuplexSender}, RpcRead, RpcStream, RpcWrite, }; -pub trait DuplexHandler +#[trait_variant::make(DuplexHandler: Send)] +pub trait DuplexHandlerLocal where M: Duplex, St: RpcStream, { - fn handle(self, peer: DuplexPeer); + async fn handle(self, peer: DuplexPeer); } pub struct DuplexPeer @@ -21,18 +24,24 @@ where pub receiver: DuplexReceiver, } -pub(crate) async fn handle_duplex_inner( +pub(crate) async fn handle_duplex_inner( state: S, _config: crate::RouterConfig, reader: St::Reader, writer: St::Writer, + handle: H, ) where M: Duplex + 'static, - S: DuplexHandler + 'static, St: RpcStream + 'static, + H: FnOnce(S, DuplexPeer) -> HF, + HF: Future, { - state.handle(DuplexPeer { - sender: DuplexSender::new(writer), - receiver: DuplexReceiver::new(reader), - }); + handle( + state, + DuplexPeer { + sender: DuplexSender::new(writer), + receiver: DuplexReceiver::new(reader), + }, + ) + .await; } diff --git a/ql-rpc/src/rpc/notification/mod.rs b/ql-rpc/src/rpc/notification/mod.rs index e773378f..d57bb3f5 100644 --- a/ql-rpc/src/rpc/notification/mod.rs +++ b/ql-rpc/src/rpc/notification/mod.rs @@ -4,7 +4,7 @@ pub(crate) mod client; pub(crate) mod server; pub use client::encode_notification; -pub use server::NotificationHandler; +pub use server::{NotificationHandler, NotificationHandlerLocal}; /// one-way rpc that carries a single typed payload and no typed response /// diff --git a/ql-rpc/src/rpc/notification/server.rs b/ql-rpc/src/rpc/notification/server.rs index fc98684c..c9a4fdba 100644 --- a/ql-rpc/src/rpc/notification/server.rs +++ b/ql-rpc/src/rpc/notification/server.rs @@ -1,33 +1,40 @@ +use std::future::Future; + use crate::{ notification::Notification as NotificationRpc, rpc::read_eof_request, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; -pub trait NotificationHandler +#[trait_variant::make(NotificationHandler: Send)] +pub trait NotificationHandlerLocal where M: NotificationRpc, St: RpcStream, { - fn handle(self, message: M::Payload); + async fn handle(self, message: M::Payload); fn handle_transport_error(&self, _error: &St::Error) {} } -pub(crate) async fn handle_notification_inner( +pub(crate) async fn handle_notification_inner( state: S, config: RouterConfig, mut reader: St::Reader, writer: St::Writer, + handle: H, + handle_transport_error: E, ) where M: NotificationRpc + 'static, - S: NotificationHandler + 'static, St: RpcStream + 'static, + H: FnOnce(S, M::Payload) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), { let notification = match read_eof_request::(&mut reader, config).await { Ok(notification) => notification, Err(error) => { let code = error.close_code(); - state.handle_transport_error(&error); + handle_transport_error(&state, &error); if let Some(code) = code { reader.close(code); writer.close(code); @@ -37,5 +44,5 @@ pub(crate) async fn handle_notification_inner( }; writer.close(StreamCloseCode::CANCELLED); - state.handle(notification); + handle(state, notification).await; } diff --git a/ql-rpc/src/rpc/progress/mod.rs b/ql-rpc/src/rpc/progress/mod.rs index 5828def0..1ee935e0 100644 --- a/ql-rpc/src/rpc/progress/mod.rs +++ b/ql-rpc/src/rpc/progress/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod server; pub use client::ProgressCall; pub use codec::{encode_progress, encode_request, encode_response, ReadStep, ResponseReader}; -pub use server::{ProgressHandler, ProgressResponder}; +pub use server::{ProgressHandler, ProgressHandlerLocal, ProgressResponder}; /// rpc where the responder streams progress values before a final response /// diff --git a/ql-rpc/src/rpc/progress/server.rs b/ql-rpc/src/rpc/progress/server.rs index 4acf1553..8599d93e 100644 --- a/ql-rpc/src/rpc/progress/server.rs +++ b/ql-rpc/src/rpc/progress/server.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{future::Future, marker::PhantomData}; use bytes::Bytes; @@ -9,12 +9,13 @@ use crate::{ write_bytes, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; -pub trait ProgressHandler +#[trait_variant::make(ProgressHandler: Send)] +pub trait ProgressHandlerLocal where M: Progress, St: RpcStream, { - fn handle(self, request: M::Request, responder: ProgressResponder); + async fn handle(self, request: M::Request, responder: ProgressResponder); fn handle_transport_error(&self, _error: &St::Error) {} } @@ -74,21 +75,25 @@ where } } -pub(crate) async fn handle_progress_inner( +pub(crate) async fn handle_progress_inner( state: S, config: RouterConfig, mut reader: St::Reader, writer: St::Writer, + handle: H, + handle_transport_error: E, ) where M: Progress + 'static, - S: ProgressHandler + 'static, St: RpcStream + 'static, + H: FnOnce(S, M::Request, ProgressResponder) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), { let request = match read_framed_request::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); - state.handle_transport_error(&error); + handle_transport_error(&state, &error); if let Some(code) = code { reader.close(code); writer.close(code); @@ -97,5 +102,5 @@ pub(crate) async fn handle_progress_inner( } }; - state.handle(request, ProgressResponder::new(writer)); + handle(state, request, ProgressResponder::new(writer)).await; } diff --git a/ql-rpc/src/rpc/request/mod.rs b/ql-rpc/src/rpc/request/mod.rs index 3c690542..a81ba523 100644 --- a/ql-rpc/src/rpc/request/mod.rs +++ b/ql-rpc/src/rpc/request/mod.rs @@ -4,7 +4,7 @@ pub(crate) mod client; pub(crate) mod server; pub use client::{encode_request, encode_response, read_response}; -pub use server::{RequestHandler, Response}; +pub use server::{RequestHandler, RequestHandlerLocal, Response}; /// request-response rpc with exactly one typed value in each direction /// diff --git a/ql-rpc/src/rpc/request/server.rs b/ql-rpc/src/rpc/request/server.rs index 949dacef..e3347061 100644 --- a/ql-rpc/src/rpc/request/server.rs +++ b/ql-rpc/src/rpc/request/server.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{future::Future, marker::PhantomData}; use bytes::Bytes; @@ -7,12 +7,13 @@ use crate::{ RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; -pub trait RequestHandler +#[trait_variant::make(RequestHandler: Send)] +pub trait RequestHandlerLocal where M: RequestRpc, St: RpcStream, { - fn handle(self, message: M::Request, responder: Response); + async fn handle(self, message: M::Request, responder: Response); fn handle_transport_error(&self, _error: &St::Error) {} } @@ -64,21 +65,25 @@ where } } -pub(crate) async fn handle_request_inner( +pub(crate) async fn handle_request_inner( state: S, config: RouterConfig, mut reader: St::Reader, writer: St::Writer, + handle: H, + handle_transport_error: E, ) where M: RequestRpc + 'static, - S: RequestHandler + 'static, St: RpcStream + 'static, + H: FnOnce(S, M::Request, Response) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), { let request = match read_eof_request::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); - state.handle_transport_error(&error); + handle_transport_error(&state, &error); if let Some(code) = code { reader.close(code); writer.close(code); @@ -87,5 +92,5 @@ pub(crate) async fn handle_request_inner( } }; - state.handle(request, Response::new(writer)); + handle(state, request, Response::new(writer)).await; } diff --git a/ql-rpc/src/rpc/subscription/mod.rs b/ql-rpc/src/rpc/subscription/mod.rs index 0c4790bf..f66ffa7a 100644 --- a/ql-rpc/src/rpc/subscription/mod.rs +++ b/ql-rpc/src/rpc/subscription/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod server; pub use client::SubscriptionCall; pub use codec::{encode_item, encode_request, ReadStep, ResponseReader}; -pub use server::{SubscriptionHandler, SubscriptionResponder}; +pub use server::{SubscriptionHandler, SubscriptionHandlerLocal, SubscriptionResponder}; /// rpc where one request opens a stream of typed events /// diff --git a/ql-rpc/src/rpc/subscription/server.rs b/ql-rpc/src/rpc/subscription/server.rs index 7a193462..32fac4f6 100644 --- a/ql-rpc/src/rpc/subscription/server.rs +++ b/ql-rpc/src/rpc/subscription/server.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{future::Future, marker::PhantomData}; use bytes::Bytes; @@ -8,12 +8,17 @@ use crate::{ StreamError, }; -pub trait SubscriptionHandler +#[trait_variant::make(SubscriptionHandler: Send)] +pub trait SubscriptionHandlerLocal where M: SubscriptionRpc, St: RpcStream, { - fn handle(self, message: M::Request, responder: SubscriptionResponder); + async fn handle( + self, + message: M::Request, + responder: SubscriptionResponder, + ); fn handle_transport_error(&self, _error: &St::Error) {} } @@ -69,21 +74,25 @@ where } } -pub(crate) async fn handle_subscription_inner( +pub(crate) async fn handle_subscription_inner( state: S, config: RouterConfig, mut reader: St::Reader, writer: St::Writer, + handle: H, + handle_transport_error: E, ) where M: SubscriptionRpc + 'static, - S: SubscriptionHandler + 'static, St: RpcStream + 'static, + H: FnOnce(S, M::Request, SubscriptionResponder) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), { let request = match read_eof_request::(&mut reader, config).await { Ok(request) => request, Err(error) => { let code = error.close_code(); - state.handle_transport_error(&error); + handle_transport_error(&state, &error); if let Some(code) = code { reader.close(code); writer.close(code); @@ -92,5 +101,5 @@ pub(crate) async fn handle_subscription_inner( } }; - state.handle(request, SubscriptionResponder::new(writer)); + handle(state, request, SubscriptionResponder::new(writer)).await; } diff --git a/ql-rpc/src/rpc/upload/mod.rs b/ql-rpc/src/rpc/upload/mod.rs index be4a4b6c..985bb5e7 100644 --- a/ql-rpc/src/rpc/upload/mod.rs +++ b/ql-rpc/src/rpc/upload/mod.rs @@ -4,7 +4,7 @@ pub(crate) mod client; pub(crate) mod server; pub use client::{encode_request, UploadCall}; -pub use server::{UploadHandler, UploadReader, UploadResponder}; +pub use server::{UploadHandler, UploadHandlerLocal, UploadReader, UploadResponder}; /// rpc where the caller uploads raw bytes after a typed request /// diff --git a/ql-rpc/src/rpc/upload/server.rs b/ql-rpc/src/rpc/upload/server.rs index 1a702ad4..b12783c8 100644 --- a/ql-rpc/src/rpc/upload/server.rs +++ b/ql-rpc/src/rpc/upload/server.rs @@ -1,5 +1,5 @@ use std::{ - future::poll_fn, + future::{poll_fn, Future}, task::{Context, Poll}, }; @@ -10,12 +10,13 @@ use crate::{ RpcStream, RpcWrite, StreamCloseCode, StreamError, Upload, }; -pub trait UploadHandler +#[trait_variant::make(UploadHandler: Send)] +pub trait UploadHandlerLocal where M: Upload, St: RpcStream, { - fn handle( + async fn handle( self, request: M::Request, upload: UploadReader, @@ -89,22 +90,31 @@ where } } -pub(crate) async fn handle_upload_inner( +pub(crate) async fn handle_upload_inner( state: S, config: RouterConfig, mut reader: St::Reader, writer: St::Writer, + handle: H, + handle_transport_error: E, ) where M: Upload + 'static, - S: UploadHandler + 'static, St: RpcStream + 'static, + H: FnOnce( + S, + M::Request, + UploadReader, + UploadResponder, + ) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), { let (request, buffered) = match read_framed_request_prefix::(&mut reader, config).await { Ok(value) => value, Err(error) => { let code = error.close_code(); - state.handle_transport_error(&error); + handle_transport_error(&state, &error); if let Some(code) = code { reader.close(code); writer.close(code); @@ -113,12 +123,14 @@ pub(crate) async fn handle_upload_inner( } }; - state.handle( + handle( + state, request, UploadReader { buffered, stream: reader, }, UploadResponder::new(writer), - ); + ) + .await; } diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index cc08934b..ed6a6eba 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -1,5 +1,6 @@ use std::{ cell::RefCell, + future::Future, rc::Rc, str::Utf8Error, sync::{Arc, Mutex}, @@ -9,15 +10,48 @@ use std::{ use bytes::Bytes; use futures_lite::StreamExt; use ql_rpc::{ - DownloadHandler, DownloadResponder, DownloadWriter, DuplexHandler, DuplexPeer, LocalSpawn, - NotificationHandler, ProgressHandler, ProgressResponder, RequestHandler, Response, RouteId, - SendSpawn, StreamCloseCode, SubscriptionHandler, SubscriptionResponder, UploadHandler, - UploadReader, UploadResponder, + DownloadHandlerLocal, DownloadResponder, DownloadWriter, DuplexHandlerLocal, DuplexPeer, + LocalSpawner, NotificationHandlerLocal, ProgressHandlerLocal, ProgressResponder, + RequestHandler, RequestHandlerLocal, Response, RouteId, SendSpawner, Spawner, StreamCloseCode, + SubscriptionHandlerLocal, SubscriptionResponder, UploadHandlerLocal, UploadReader, + UploadResponder, }; use super::*; use crate::{rpc::RpcError, QlStream, StreamWriter}; +#[derive(Debug, Clone, Copy)] +struct TokioLocalSpawner; + +impl Spawner for TokioLocalSpawner { + type Handle = tokio::task::JoinHandle<()>; +} + +impl LocalSpawner for TokioLocalSpawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + 'static, + { + tokio::task::spawn_local(fut) + } +} + +#[derive(Debug, Clone, Copy)] +struct TokioSendSpawner; + +impl Spawner for TokioSendSpawner { + type Handle = tokio::task::JoinHandle<()>; +} + +impl SendSpawner for TokioSendSpawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + Send + 'static, + { + tokio::task::spawn(fut) + } +} + struct Echo; impl ql_rpc::request::Request for Echo { @@ -91,12 +125,10 @@ async fn rpc_request() { } impl RequestHandler for RouterState { - fn handle(self, request: String, response: Response) { + async fn handle(self, request: String, response: Response) { let seen = self.seen.clone(); - tokio::task::spawn(async move { - seen.lock().unwrap().push(request); - let _ = response.respond("world".into()).await; - }); + seen.lock().unwrap().push(request); + let _ = response.respond("world".into()).await; } } @@ -106,15 +138,16 @@ async fn rpc_request() { let inbound_b = pair.take_inbound(Side::B); let seen = Arc::new(Mutex::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, SendSpawn>::builder(SendSpawn) - .request::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, TokioSendSpawner>::builder_send(TokioSendSpawner) + .request::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { let fut = assert_send(fut); - fut.await + fut.await.unwrap(); } }); @@ -142,8 +175,8 @@ async fn rpc_notification() { seen: Rc>>>, } - impl NotificationHandler for RouterState { - fn handle(self, payload: Vec) { + impl NotificationHandlerLocal for RouterState { + async fn handle(self, payload: Vec) { self.seen.borrow_mut().push(payload); } } @@ -154,14 +187,15 @@ async fn rpc_notification() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) - .notification::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .notification::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { - fut.await; + fut.await.unwrap(); } }); @@ -186,19 +220,17 @@ async fn rpc_subscrption() { seen: Rc>>>, } - impl SubscriptionHandler for RouterState { - fn handle( + impl SubscriptionHandlerLocal for RouterState { + async fn handle( self, request: Vec, mut response: SubscriptionResponder, StreamWriter>, ) { let seen = self.seen.clone(); - tokio::task::spawn_local(async move { - seen.borrow_mut().push(request); - let _ = response.send(b"one".to_vec()).await; - let _ = response.send(b"two".to_vec()).await; - let _ = response.finish().await; - }); + seen.borrow_mut().push(request); + let _ = response.send(b"one".to_vec()).await; + let _ = response.send(b"two".to_vec()).await; + let _ = response.finish().await; } } @@ -208,14 +240,15 @@ async fn rpc_subscrption() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) - .subscription::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .subscription::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { - fut.await; + fut.await.unwrap(); } }); @@ -239,11 +272,9 @@ async fn rpc_router_enforces_max_request_bytes() { #[derive(Clone)] struct LimitedState; - impl RequestHandler for LimitedState { - fn handle(self, request: String, response: Response) { - tokio::task::spawn_local(async move { - let _ = response.respond(request).await; - }); + impl RequestHandlerLocal for LimitedState { + async fn handle(self, request: String, response: Response) { + let _ = response.respond(request).await; } } @@ -251,15 +282,16 @@ async fn rpc_router_enforces_max_request_bytes() { let mut pair = TestPair::new(default_runtime_config()); pair.connect_and_wait(Side::A).await; let inbound_b = pair.take_inbound(Side::B); - let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) - .max_request_bytes(4) - .request::() - .build(LimitedState); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .max_request_bytes(4) + .request::() + .build(LimitedState); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { - fut.await + fut.await.unwrap(); } }); @@ -285,19 +317,17 @@ async fn rpc_progress() { seen: Rc>>>, } - impl ProgressHandler for RouterState { - fn handle( + impl ProgressHandlerLocal for RouterState { + async fn handle( self, request: Vec, mut responder: ProgressResponder, ) { let seen = self.seen.clone(); - tokio::task::spawn_local(async move { - seen.borrow_mut().push(request); - responder.send(b"10".to_vec()).await.unwrap(); - responder.send(b"90".to_vec()).await.unwrap(); - responder.finish(b"done".to_vec()).await.unwrap(); - }); + seen.borrow_mut().push(request); + responder.send(b"10".to_vec()).await.unwrap(); + responder.send(b"90".to_vec()).await.unwrap(); + responder.finish(b"done".to_vec()).await.unwrap(); } } @@ -307,14 +337,15 @@ async fn rpc_progress() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) - .progress::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .progress::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { - fut.await; + fut.await.unwrap(); } }); @@ -342,17 +373,19 @@ async fn rpc_download() { seen: Rc>>>, } - impl DownloadHandler for RouterState { - fn handle(self, request: Vec, responder: DownloadResponder, StreamWriter>) { + impl DownloadHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + responder: DownloadResponder, StreamWriter>, + ) { let seen = self.seen.clone(); - tokio::task::spawn_local(async move { - seen.borrow_mut().push(request); - let mut writer: DownloadWriter = - responder.respond(b"image/png".to_vec()).await.unwrap(); - writer.send(Bytes::from_static(b"abc")).await.unwrap(); - writer.send(Bytes::from_static(b"def")).await.unwrap(); - writer.finish().await.unwrap(); - }); + seen.borrow_mut().push(request); + let mut writer: DownloadWriter = + responder.respond(b"image/png".to_vec()).await.unwrap(); + writer.send(Bytes::from_static(b"abc")).await.unwrap(); + writer.send(Bytes::from_static(b"def")).await.unwrap(); + writer.finish().await.unwrap(); } } @@ -362,14 +395,15 @@ async fn rpc_download() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) - .download::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .download::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { - fut.await; + fut.await.unwrap(); } }); @@ -407,8 +441,8 @@ async fn rpc_upload() { uploads: Rc>>>, } - impl UploadHandler for RouterState { - fn handle( + impl UploadHandlerLocal for RouterState { + async fn handle( self, request: Vec, mut upload: UploadReader, @@ -416,17 +450,15 @@ async fn rpc_upload() { ) { let requests = self.requests.clone(); let uploads = self.uploads.clone(); - tokio::task::spawn_local(async move { - requests.borrow_mut().push(request); + requests.borrow_mut().push(request); - let mut body = Vec::new(); - while let Some(chunk) = upload.read_chunk().await.unwrap() { - body.extend_from_slice(&chunk); - } - uploads.borrow_mut().push(body.clone()); + let mut body = Vec::new(); + while let Some(chunk) = upload.read_chunk().await.unwrap() { + body.extend_from_slice(&chunk); + } + uploads.borrow_mut().push(body.clone()); - responder.respond(body).await.unwrap(); - }); + responder.respond(body).await.unwrap(); } } @@ -437,17 +469,18 @@ async fn rpc_upload() { let requests = Rc::new(RefCell::new(Vec::new())); let uploads = Rc::new(RefCell::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) - .upload::() - .build(RouterState { - requests: requests.clone(), - uploads: uploads.clone(), - }); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .upload::() + .build(RouterState { + requests: requests.clone(), + uploads: uploads.clone(), + }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { - fut.await; + fut.await.unwrap(); } }); @@ -476,23 +509,21 @@ async fn rpc_duplex() { seen: Rc>>>, } - impl DuplexHandler for RouterState { - fn handle(self, mut peer: DuplexPeer) { + impl DuplexHandlerLocal for RouterState { + async fn handle(self, mut peer: DuplexPeer) { let seen = self.seen.clone(); - tokio::task::spawn_local(async move { - let first = peer.receiver.next_event().await.unwrap().unwrap(); - seen.borrow_mut().push(first); + let first = peer.receiver.next_event().await.unwrap().unwrap(); + seen.borrow_mut().push(first); - peer.sender - .send(&b"challenge-response".to_vec()) - .await - .unwrap(); + peer.sender + .send(&b"challenge-response".to_vec()) + .await + .unwrap(); - let second = peer.receiver.next_event().await.unwrap().unwrap(); - seen.borrow_mut().push(second); + let second = peer.receiver.next_event().await.unwrap().unwrap(); + seen.borrow_mut().push(second); - peer.sender.finish().await.unwrap(); - }); + peer.sender.finish().await.unwrap(); } } @@ -502,14 +533,15 @@ async fn rpc_duplex() { let inbound_b = pair.take_inbound(Side::B); let seen = Rc::new(RefCell::new(Vec::new())); - let router = ql_rpc::Router::<_, QlStream, LocalSpawn>::builder(LocalSpawn) - .duplex::() - .build(RouterState { seen: seen.clone() }); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .duplex::() + .build(RouterState { seen: seen.clone() }); let responder = tokio::task::spawn_local(async move { let inbound = inbound_b.recv().await.unwrap(); if let Some((_, fut)) = router.handle(inbound) { - fut.await; + fut.await.unwrap(); } }); From ea890d4265705c0df373a856ddf6baace3568847 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 14 May 2026 19:30:34 -0400 Subject: [PATCH 298/304] ql-rpc: download with multiple parts --- ql-rpc/src/chunk_queue.rs | 50 ++++- ql-rpc/src/router/mod.rs | 2 +- ql-rpc/src/rpc/download/client.rs | 119 +++++++++--- ql-rpc/src/rpc/download/codec.rs | 301 +++++++++++++++++++++++++++++- ql-rpc/src/rpc/download/mod.rs | 21 ++- ql-rpc/src/rpc/download/server.rs | 109 +++++++++-- ql-runtime/src/rpc/download.rs | 40 ++-- ql-runtime/src/tests/rpc.rs | 37 ++-- 8 files changed, 591 insertions(+), 88 deletions(-) diff --git a/ql-rpc/src/chunk_queue.rs b/ql-rpc/src/chunk_queue.rs index d26429a9..33f62998 100644 --- a/ql-rpc/src/chunk_queue.rs +++ b/ql-rpc/src/chunk_queue.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; use bytes::{Buf, Bytes}; -use crate::Error; +use crate::{CodecError, Error}; const LENGTH_SIZE: usize = 8; @@ -25,6 +25,14 @@ impl ChunkQueue { self.remaining } + pub fn expect_empty(&self) -> Result<(), CodecError> { + if self.remaining > 0 { + Err(CodecError::Rpc(Error::TrailingBytes)) + } else { + Ok(()) + } + } + pub fn pop_front(&mut self, max_len: usize) -> Option { let front = self.chunks.front_mut()?; let chunk = if max_len >= front.len() { @@ -61,6 +69,27 @@ impl ChunkQueue { Ok(Some((kind, DrainBuf::new(self, len)))) } + pub fn try_take_tagged_part_header(&mut self) -> Result, Error> { + let mut bytes = self.peek(); + let Ok(kind) = bytes.try_get_u8() else { + return Ok(None); + }; + let Some(len) = read_part_len_header(&mut bytes)? else { + return Ok(None); + }; + + self.advance(1 + LENGTH_SIZE); + Ok(Some((kind, len))) + } + + pub fn try_take_body(&mut self, len: usize) -> Option> { + if self.remaining < len { + return None; + } + + Some(DrainBuf::new(self, len)) + } + fn peek_next_part_len(&self) -> Result, Error> { let mut bytes = self.peek(); read_next_part_len(&mut bytes) @@ -169,6 +198,14 @@ impl<'a> DrainBuf<'a> { remaining: len, } } + + pub fn expect_empty(&self) -> Result<(), CodecError> { + if self.remaining > 0 { + Err(CodecError::Rpc(Error::TrailingBytes)) + } else { + Ok(()) + } + } } impl Buf for DrainBuf<'_> { @@ -197,12 +234,19 @@ impl Drop for DrainBuf<'_> { } fn read_next_part_len(bytes: &mut B) -> Result, Error> { - let Ok(len) = bytes.try_get_u64_le() else { + let Some(len) = read_part_len_header(bytes)? else { return Ok(None); }; - let len: usize = len.try_into().map_err(|_| Error::LengthOverflow)?; if bytes.remaining() < len { return Ok(None); } Ok(Some(len)) } + +fn read_part_len_header(bytes: &mut B) -> Result, Error> { + let Ok(len) = bytes.try_get_u64_le() else { + return Ok(None); + }; + let len: usize = len.try_into().map_err(|_| Error::LengthOverflow)?; + Ok(Some(len)) +} diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs index 522f0e52..31e973ac 100644 --- a/ql-rpc/src/router/mod.rs +++ b/ql-rpc/src/router/mod.rs @@ -11,7 +11,7 @@ pub use self::{ }; use crate::{close_stream, RpcStream}; pub use crate::{ - download::{DownloadHandler, DownloadHandlerLocal, DownloadResponder, DownloadWriter}, + download::{DownloadHandler, DownloadHandlerLocal, DownloadStart, DownloadWriter}, duplex::{DuplexHandler, DuplexHandlerLocal, DuplexPeer}, notification::{NotificationHandler, NotificationHandlerLocal}, progress::{ProgressHandler, ProgressHandlerLocal, ProgressResponder}, diff --git a/ql-rpc/src/rpc/download/client.rs b/ql-rpc/src/rpc/download/client.rs index 3156b8b8..9f8280e6 100644 --- a/ql-rpc/src/rpc/download/client.rs +++ b/ql-rpc/src/rpc/download/client.rs @@ -1,13 +1,11 @@ -use std::{ - future::poll_fn, - task::{Context, Poll}, -}; +use std::future::poll_fn; use bytes::Bytes; +use super::codec::FrameKind; use crate::{ - download::{Download, ReadStep, ResponseHeaderReader}, - CallError, ChunkQueue, RpcRead, + download::{Download, PartReadStep, ReadStep, ResponseHeaderReader}, + CallError, RpcRead, }; pub struct DownloadCall @@ -19,12 +17,23 @@ where reader: Option>, } -pub struct DownloadReader +pub struct DownloadPart<'a, M, R> +where + M: Download, + R: RpcRead, +{ + parent: &'a mut DownloadReader, + finished: bool, +} + +pub struct DownloadReader where + M: Download, R: RpcRead, { - buffered: ChunkQueue, stream: R, + reader: crate::download::PartFrameReader, + finished: bool, } impl DownloadCall @@ -41,7 +50,7 @@ where pub async fn into_reader( mut self, - ) -> Result<(M::ResponseHeader, DownloadReader), CallError> { + ) -> Result<(M::ResponseHeader, DownloadReader), CallError> { loop { let reader = self.reader.take().expect("download reader is present"); let reader = match reader.advance() { @@ -49,8 +58,9 @@ where return Ok(( value, DownloadReader { - buffered: bytes, stream: self.stream, + reader: crate::download::PartFrameReader::new(bytes), + finished: false, }, )); } @@ -73,38 +83,87 @@ where } } -impl DownloadReader +impl DownloadReader where + M: Download, R: RpcRead, { - pub fn poll_read( + pub async fn next_part( &mut self, - max_len: usize, - cx: &mut Context<'_>, - ) -> Poll, R::Error>> { - if let Some(chunk) = self.buffered.pop_front(max_len) { - return Poll::Ready(Ok(Some(chunk))); + ) -> Result)>, CallError> + { + if self.finished { + return Ok(None); } - self.stream.poll_read(max_len, cx) - } - - pub fn poll_read_chunk( - &mut self, - cx: &mut Context<'_>, - ) -> Poll, R::Error>> { - self.poll_read(usize::MAX, cx) + match self.read_frame().await? { + PartReadStep::PartHeader(value) => Ok(Some(( + value, + DownloadPart { + parent: self, + finished: false, + }, + ))), + PartReadStep::Finish => { + self.finished = true; + Ok(None) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } } - pub async fn read(&mut self, max_len: usize) -> Result, R::Error> { - poll_fn(|cx| self.poll_read(max_len, cx)).await - } + async fn read_frame(&mut self) -> Result, CallError> { + loop { + match self.reader.advance() { + Ok(PartReadStep::NeedMore) => {} + Ok(step) => return Ok(step), + Err(error) => return Err(error.into()), + } - pub async fn read_chunk(&mut self) -> Result, R::Error> { - self.read(usize::MAX).await + match poll_fn(|cx| self.stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader.push(chunk); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(CallError::Transport(error)), + } + } } pub fn into_inner(self) -> R { self.stream } } + +impl DownloadPart<'_, M, R> +where + M: Download, + R: RpcRead, +{ + pub async fn read_chunk(&mut self) -> Result, CallError> { + if self.finished { + return Ok(None); + } + + match self.parent.read_frame().await? { + PartReadStep::BodyBytes(bytes) => Ok(Some(bytes)), + PartReadStep::EndPart => { + self.finished = true; + Ok(None) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::Finish => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::Finish.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } +} diff --git a/ql-rpc/src/rpc/download/codec.rs b/ql-rpc/src/rpc/download/codec.rs index 53e54ecc..53332773 100644 --- a/ql-rpc/src/rpc/download/codec.rs +++ b/ql-rpc/src/rpc/download/codec.rs @@ -23,11 +23,38 @@ pub enum ReadStep { }, } +pub enum PartReadStep { + NeedMore, + PartHeader(M::PartHeader), + BodyBytes(Bytes), + EndPart, + Finish, +} + pub struct ResponseHeaderReader { bytes: codec::ChunkQueue, marker: PhantomData M>, } +pub struct PartFrameReader { + bytes: codec::ChunkQueue, + pending_frame: PendingFrame, + marker: PhantomData M>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PendingFrame { + None, + Control { kind: FrameKind, len: usize }, + Body { remaining: usize }, +} + +impl PendingFrame { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::None) + } +} + impl Default for ResponseHeaderReader { fn default() -> Self { Self { @@ -44,13 +71,11 @@ impl ResponseHeaderReader { } pub fn advance(mut self) -> Result, CodecError> { - let Some(mut body) = self.bytes.try_take_part()? else { - return Ok(ReadStep::NeedMore(self)); - }; - let value = { + let Some(mut body) = self.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore(self)); + }; let value = M::ResponseHeader::decode_value(&mut body).map_err(CodecError::Codec)?; - drop(body); value }; @@ -60,3 +85,269 @@ impl ResponseHeaderReader { }) } } + +impl PartFrameReader { + pub fn new(bytes: ChunkQueue) -> Self { + Self { + bytes, + pending_frame: PendingFrame::None, + marker: PhantomData, + } + } + + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn advance(&mut self) -> Result, CodecError> { + loop { + match self.pending_frame.take() { + PendingFrame::Body { remaining } => { + if remaining == 0 { + continue; + } + + let Some(bytes) = self.bytes.pop_front(remaining) else { + self.pending_frame = PendingFrame::Body { remaining }; + return Ok(PartReadStep::NeedMore); + }; + + let remaining = remaining - bytes.len(); + self.pending_frame = if remaining == 0 { + PendingFrame::None + } else { + PendingFrame::Body { remaining } + }; + return Ok(PartReadStep::BodyBytes(bytes)); + } + PendingFrame::Control { kind, len } => { + let Some(mut body) = self.bytes.try_take_body(len) else { + self.pending_frame = PendingFrame::Control { kind, len }; + return Ok(PartReadStep::NeedMore); + }; + + match kind { + FrameKind::PartHeader => { + let value = M::PartHeader::decode_value(&mut body) + .map_err(CodecError::Codec)?; + return Ok(PartReadStep::PartHeader(value)); + } + FrameKind::BodyChunk => unreachable!("body chunk is not a control frame"), + FrameKind::EndPart => { + body.expect_empty()?; + return Ok(PartReadStep::EndPart); + } + FrameKind::Finish => { + body.expect_empty()?; + drop(body); + self.bytes.expect_empty()?; + return Ok(PartReadStep::Finish); + } + } + } + PendingFrame::None => { + let Some((kind, len)) = self + .bytes + .try_take_tagged_part_header() + .map_err(CodecError::Rpc)? + else { + return Ok(PartReadStep::NeedMore); + }; + + let kind = FrameKind::try_from(kind).map_err(CodecError::Rpc)?; + self.pending_frame = if kind == FrameKind::BodyChunk { + PendingFrame::Body { remaining: len } + } else { + PendingFrame::Control { kind, len } + }; + } + } + } + } +} + +pub fn encode_part_header( + part_header: &M::PartHeader, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + encode_tagged_value_part(FrameKind::PartHeader, part_header, out) +} + +pub fn encode_body_chunk(bytes: &Bytes, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::BodyChunk, bytes, out) +} + +pub fn encode_end_part(out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_empty_part(FrameKind::EndPart, out) +} + +pub fn encode_finish(out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_empty_part(FrameKind::Finish, out) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub(super) enum FrameKind { + PartHeader = 1, + BodyChunk = 2, + EndPart = 3, + Finish = 4, +} + +impl FrameKind { + pub fn tag(self) -> u8 { + self as u8 + } +} + +impl TryFrom for FrameKind { + type Error = crate::Error; + + fn try_from(value: u8) -> Result { + match value { + x if x == Self::PartHeader.tag() => Ok(Self::PartHeader), + x if x == Self::BodyChunk.tag() => Ok(Self::BodyChunk), + x if x == Self::EndPart.tag() => Ok(Self::EndPart), + x if x == Self::Finish.tag() => Ok(Self::Finish), + other => Err(crate::Error::UnexpectedFrameKind(other)), + } + } +} + +fn encode_tagged_value_part>( + kind: FrameKind, + value: &T, + out: &mut B, +) { + out.put_u8(kind.tag()); + let payload_start = codec::reserve_length(out); + value.encode_value(out); + codec::backpatch_length(out, payload_start); +} + +fn encode_tagged_empty_part>(kind: FrameKind, out: &mut B) { + out.put_u8(kind.tag()); + let payload_start = codec::reserve_length(out); + codec::backpatch_length(out, payload_start); +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{ + encode_body_chunk, encode_end_part, encode_finish, encode_part_header, PartFrameReader, + PartReadStep, + }; + use crate::{download::Download, RouteId}; + + struct Files; + + impl Download for Files { + const ROUTE: RouteId = RouteId::from_u32(12); + type Error = core::convert::Infallible; + type Request = Vec; + type ResponseHeader = Vec; + type PartHeader = Vec; + } + + #[test] + fn part_reader_emits_multipart_sequence() { + let mut encoded = Vec::new(); + encode_part_header::(&b"a.txt".to_vec(), &mut encoded); + encode_body_chunk(&Bytes::from_static(b"hel"), &mut encoded); + encode_body_chunk(&Bytes::from_static(b"lo"), &mut encoded); + encode_end_part(&mut encoded); + encode_part_header::(&b"b.txt".to_vec(), &mut encoded); + encode_end_part(&mut encoded); + encode_finish(&mut encoded); + + let mut reader = PartFrameReader::::new(Default::default()); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => { + assert_eq!(value, b"a.txt".to_vec()); + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"hel")), + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"lo")), + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::EndPart => {} + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => { + assert_eq!(value, b"b.txt".to_vec()); + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::EndPart => {} + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::Finish => {} + _ => unreachable!(), + } + } + + #[test] + fn part_reader_waits_for_complete_frame() { + let mut encoded = Vec::new(); + encode_part_header::(&b"a.txt".to_vec(), &mut encoded); + let encoded = Bytes::from(encoded); + + let mut reader = PartFrameReader::::new(Default::default()); + reader.push(encoded.slice(..4)); + match reader.advance().unwrap() { + PartReadStep::NeedMore => {} + _ => unreachable!(), + }; + + reader.push(encoded.slice(4..)); + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => assert_eq!(value, b"a.txt".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn body_chunk_frame_streams_after_header() { + let mut encoded = Vec::new(); + encode_body_chunk(&Bytes::from_static(b"hello"), &mut encoded); + let encoded = Bytes::from(encoded); + + let mut reader = PartFrameReader::::new(Default::default()); + reader.push(encoded.slice(..9)); + match reader.advance().unwrap() { + PartReadStep::NeedMore => {} + _ => unreachable!(), + }; + + reader.push(encoded.slice(9..11)); + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"he")), + _ => unreachable!(), + }; + + reader.push(encoded.slice(11..)); + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"llo")), + _ => unreachable!(), + }; + } +} diff --git a/ql-rpc/src/rpc/download/mod.rs b/ql-rpc/src/rpc/download/mod.rs index 967158cc..5cef967c 100644 --- a/ql-rpc/src/rpc/download/mod.rs +++ b/ql-rpc/src/rpc/download/mod.rs @@ -4,15 +4,20 @@ pub(crate) mod client; pub(crate) mod codec; pub(crate) mod server; -pub use client::{DownloadCall, DownloadReader}; -pub use codec::{encode_request, encode_response_header, ReadStep, ResponseHeaderReader}; -pub use server::{DownloadHandler, DownloadHandlerLocal, DownloadResponder, DownloadWriter}; +pub use client::{DownloadCall, DownloadPart, DownloadReader}; +pub use codec::{ + encode_body_chunk, encode_end_part, encode_finish, encode_part_header, encode_request, + encode_response_header, PartFrameReader, PartReadStep, ReadStep, ResponseHeaderReader, +}; +pub use server::{ + DownloadHandler, DownloadHandlerLocal, DownloadPartWriter, DownloadStart, DownloadWriter, +}; -/// rpc where the responder returns metadata first and raw bytes after that +/// rpc where the responder returns metadata first and then zero or more byte parts /// /// the typed portion of the response ends at [`Self::ResponseHeader`] -/// after the header is decoded, the rest of the stream is exposed as raw byte -/// chunks through [`DownloadReader`] +/// after the header is decoded, the rest of the stream is exposed as typed +/// part headers followed by raw byte chunks through [`DownloadReader`] pub trait Download { /// route used to dispatch this rpc family const ROUTE: RouteId; @@ -20,6 +25,8 @@ pub trait Download { type Error; /// typed input needed to start the download type Request: RpcCodec; - /// typed metadata available before body bytes arrive + /// typed metadata available before parts arrive type ResponseHeader: RpcCodec; + /// typed metadata available before each byte part arrives + type PartHeader: RpcCodec; } diff --git a/ql-rpc/src/rpc/download/server.rs b/ql-rpc/src/rpc/download/server.rs index 0e26b7b6..8308d7f7 100644 --- a/ql-rpc/src/rpc/download/server.rs +++ b/ql-rpc/src/rpc/download/server.rs @@ -3,8 +3,14 @@ use std::{future::Future, marker::PhantomData}; use bytes::Bytes; use crate::{ - codec, download::Download as DownloadRpc, finish_bytes, rpc::read_eof_request, write_bytes, - RouterConfig, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, + codec, + download::{ + encode_body_chunk, encode_end_part, encode_finish, encode_part_header, + Download as DownloadRpc, + }, + finish_bytes, + rpc::read_eof_request, + write_bytes, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; #[trait_variant::make(DownloadHandler: Send)] @@ -13,33 +19,41 @@ where M: DownloadRpc, St: RpcStream, { - async fn handle( - self, - message: M::Request, - responder: DownloadResponder, - ); + async fn handle(self, message: M::Request, download: DownloadStart); fn handle_transport_error(&self, _error: &St::Error) {} } -pub struct DownloadResponder +pub struct DownloadStart where + M: DownloadRpc, W: RpcWrite, { writer: Option, - marker: PhantomData T>, + marker: PhantomData M>, } -pub struct DownloadWriter +pub struct DownloadWriter where + M: DownloadRpc, W: RpcWrite, { writer: Option, + marker: PhantomData M>, } -impl DownloadResponder +pub struct DownloadPartWriter<'a, M, W> where - T: RpcCodec, + M: DownloadRpc, + W: RpcWrite, +{ + parent: &'a mut DownloadWriter, + finished: bool, +} + +impl DownloadStart +where + M: DownloadRpc, W: RpcWrite, { pub(crate) fn new(writer: W) -> Self { @@ -49,13 +63,17 @@ where } } - pub async fn respond(mut self, response_header: T) -> Result, W::Error> { + pub async fn start( + mut self, + response_header: M::ResponseHeader, + ) -> Result, W::Error> { let mut writer = self.writer.take().expect("download writer exists"); let mut encoded = Vec::new(); codec::encode_value_part(&response_header, &mut encoded); write_bytes(&mut writer, Bytes::from(encoded)).await?; Ok(DownloadWriter { writer: Some(writer), + marker: PhantomData, }) } @@ -66,8 +84,9 @@ where } } -impl Drop for DownloadResponder +impl Drop for DownloadStart where + M: DownloadRpc, W: RpcWrite, { fn drop(&mut self) { @@ -77,17 +96,30 @@ where } } -impl DownloadWriter +impl DownloadWriter where + M: DownloadRpc, W: RpcWrite, { - pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, W::Error> { let writer = self.writer.as_mut().expect("download writer exists"); - write_bytes(writer, bytes).await + let mut encoded = Vec::new(); + encode_part_header::(&part_header, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + Ok(DownloadPartWriter { + parent: self, + finished: false, + }) } pub async fn finish(mut self) -> Result<(), W::Error> { let mut writer = self.writer.take().expect("download writer exists"); + let mut encoded = Vec::new(); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; finish_bytes(&mut writer).await } @@ -98,8 +130,9 @@ where } } -impl Drop for DownloadWriter +impl Drop for DownloadWriter where + M: DownloadRpc, W: RpcWrite, { fn drop(&mut self) { @@ -109,6 +142,42 @@ where } } +impl DownloadPartWriter<'_, M, W> +where + M: DownloadRpc, + W: RpcWrite, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().expect("download writer exists"); + let mut encoded = Vec::new(); + encode_body_chunk(&bytes, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().expect("download writer exists"); + let mut encoded = Vec::new(); + encode_end_part(&mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + self.finished = true; + Ok(()) + } +} + +impl Drop for DownloadPartWriter<'_, M, W> +where + M: DownloadRpc, + W: RpcWrite, +{ + fn drop(&mut self) { + if !self.finished { + if let Some(writer) = self.parent.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } + } +} + pub(crate) async fn handle_download_inner( state: S, config: RouterConfig, @@ -119,7 +188,7 @@ pub(crate) async fn handle_download_inner( ) where M: DownloadRpc + 'static, St: RpcStream + 'static, - H: FnOnce(S, M::Request, DownloadResponder) -> HF, + H: FnOnce(S, M::Request, DownloadStart) -> HF, HF: Future, E: FnOnce(&S, &St::Error), { @@ -136,5 +205,5 @@ pub(crate) async fn handle_download_inner( } }; - handle(state, request, DownloadResponder::new(writer)).await; + handle(state, request, DownloadStart::new(writer)).await; } diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs index a560dd69..6068e357 100644 --- a/ql-runtime/src/rpc/download.rs +++ b/ql-runtime/src/rpc/download.rs @@ -2,14 +2,18 @@ use bytes::Bytes; use ql_rpc::download::Download as DownloadRpc; use super::RpcError; -use crate::{QlStreamError, StreamReader}; +use crate::StreamReader; pub struct DownloadCall { pub(super) inner: ql_rpc::download::DownloadCall, } -pub struct DownloadReader { - pub(super) inner: ql_rpc::download::DownloadReader, +pub struct DownloadReader { + pub(super) inner: ql_rpc::download::DownloadReader, +} + +pub struct DownloadPart<'a, M: DownloadRpc> { + inner: ql_rpc::download::DownloadPart<'a, M, StreamReader>, } impl DownloadCall @@ -18,22 +22,36 @@ where { pub async fn into_reader( self, - ) -> Result<(M::ResponseHeader, DownloadReader), RpcError> { + ) -> Result<(M::ResponseHeader, DownloadReader), RpcError> { let (header, inner) = self.inner.into_reader().await?; Ok((header, DownloadReader { inner })) } } -impl DownloadReader { - pub async fn read(&mut self, max_len: usize) -> Result, QlStreamError> { - self.inner.read(max_len).await - } - - pub async fn read_chunk(&mut self) -> Result, QlStreamError> { - self.inner.read_chunk().await +impl DownloadReader +where + M: DownloadRpc, +{ + pub async fn next_part( + &mut self, + ) -> Result)>, RpcError> { + Ok(self + .inner + .next_part() + .await? + .map(|(header, inner)| (header, DownloadPart { inner }))) } pub fn close(self, code: ql_wire::StreamCloseCode) { self.inner.into_inner().close(code); } } + +impl DownloadPart<'_, M> +where + M: DownloadRpc, +{ + pub async fn read_chunk(&mut self) -> Result, RpcError> { + Ok(self.inner.read_chunk().await?) + } +} diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index ed6a6eba..e8c03f7a 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -10,9 +10,9 @@ use std::{ use bytes::Bytes; use futures_lite::StreamExt; use ql_rpc::{ - DownloadHandlerLocal, DownloadResponder, DownloadWriter, DuplexHandlerLocal, DuplexPeer, - LocalSpawner, NotificationHandlerLocal, ProgressHandlerLocal, ProgressResponder, - RequestHandler, RequestHandlerLocal, Response, RouteId, SendSpawner, Spawner, StreamCloseCode, + DownloadHandlerLocal, DownloadStart, DuplexHandlerLocal, DuplexPeer, LocalSpawner, + NotificationHandlerLocal, ProgressHandlerLocal, ProgressResponder, RequestHandler, + RequestHandlerLocal, Response, RouteId, SendSpawner, Spawner, StreamCloseCode, SubscriptionHandlerLocal, SubscriptionResponder, UploadHandlerLocal, UploadReader, UploadResponder, }; @@ -97,6 +97,7 @@ impl ql_rpc::download::Download for BlobDownload { type Error = core::convert::Infallible; type Request = Vec; type ResponseHeader = Vec; + type PartHeader = Vec; } struct BlobUpload; @@ -377,14 +378,18 @@ async fn rpc_download() { async fn handle( self, request: Vec, - responder: DownloadResponder, StreamWriter>, + download: DownloadStart, ) { let seen = self.seen.clone(); seen.borrow_mut().push(request); - let mut writer: DownloadWriter = - responder.respond(b"image/png".to_vec()).await.unwrap(); - writer.send(Bytes::from_static(b"abc")).await.unwrap(); - writer.send(Bytes::from_static(b"def")).await.unwrap(); + let mut writer = download.start(b"image/png".to_vec()).await.unwrap(); + let mut part = writer.start_part(b"icon".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"abc")).await.unwrap(); + part.send(Bytes::from_static(b"def")).await.unwrap(); + part.finish().await.unwrap(); + let mut part = writer.start_part(b"manifest".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"{}")).await.unwrap(); + part.finish().await.unwrap(); writer.finish().await.unwrap(); } } @@ -414,15 +419,25 @@ async fn rpc_download() { .unwrap(); let (header, mut reader) = download.into_reader().await.unwrap(); assert_eq!(header, b"image/png".to_vec()); + let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); + assert_eq!(part_header, b"icon".to_vec()); assert_eq!( - reader.read_chunk().await.unwrap(), + part.read_chunk().await.unwrap(), Some(Bytes::from_static(b"abc")) ); assert_eq!( - reader.read_chunk().await.unwrap(), + part.read_chunk().await.unwrap(), Some(Bytes::from_static(b"def")) ); - assert_eq!(reader.read_chunk().await.unwrap(), None); + assert_eq!(part.read_chunk().await.unwrap(), None); + let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); + assert_eq!(part_header, b"manifest".to_vec()); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"{}")) + ); + assert_eq!(part.read_chunk().await.unwrap(), None); + assert!(reader.next_part().await.unwrap().is_none()); assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); tokio::time::timeout(Duration::from_secs(2), responder) From d3c9169e5f34935542688038ac69385b350b08ba Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 15 May 2026 09:50:04 -0400 Subject: [PATCH 299/304] ql-rpc: upload parts --- ql-rpc/src/framed_value.rs | 51 +++++++- ql-rpc/src/rpc/download/client.rs | 91 +++++++++---- ql-rpc/src/rpc/download/mod.rs | 12 +- ql-rpc/src/rpc/download/server.rs | 12 +- ql-rpc/src/rpc/mod.rs | 1 + .../src/rpc/{download/codec.rs => parts.rs} | 102 +++------------ ql-rpc/src/rpc/upload/client.rs | 84 ++++++++++-- ql-rpc/src/rpc/upload/mod.rs | 10 +- ql-rpc/src/rpc/upload/server.rs | 121 ++++++++++++++---- ql-rpc/src/rpc/utils.rs | 27 ++-- ql-runtime/src/rpc/download.rs | 4 - ql-runtime/src/rpc/upload.rs | 26 +++- ql-runtime/src/tests/rpc.rs | 68 ++++++---- 13 files changed, 396 insertions(+), 213 deletions(-) rename ql-rpc/src/rpc/{download/codec.rs => parts.rs} (74%) diff --git a/ql-rpc/src/framed_value.rs b/ql-rpc/src/framed_value.rs index b76007f5..600357da 100644 --- a/ql-rpc/src/framed_value.rs +++ b/ql-rpc/src/framed_value.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use bytes::Bytes; -use crate::{chunk_queue::ChunkQueue, CodecError, Error, RpcCodec}; +use crate::{chunk_queue::ChunkQueue, CodecError, RpcCodec}; /// reads one length-delimited rpc value from buffered byte chunks pub struct FramedReader { @@ -15,6 +15,11 @@ pub enum FramedReadStep { Value(T), } +pub enum FramedPrefixStep { + NeedMore(FramedReader), + Value { value: T, bytes: ChunkQueue }, +} + impl Default for FramedReader { fn default() -> Self { Self { @@ -31,17 +36,27 @@ impl FramedReader { } pub fn advance(self) -> Result, CodecError> { + match self.advance_prefix()? { + FramedPrefixStep::NeedMore(next) => Ok(FramedReadStep::NeedMore(next)), + FramedPrefixStep::Value { value, bytes } => { + bytes.expect_empty()?; + Ok(FramedReadStep::Value(value)) + } + } + } + + pub fn advance_prefix(self) -> Result, CodecError> { let mut this = self; let Some(mut body) = this.bytes.try_take_part()? else { - return Ok(FramedReadStep::NeedMore(this)); + return Ok(FramedPrefixStep::NeedMore(this)); }; let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; drop(body); - if this.bytes.remaining() > 0 { - return Err(CodecError::Rpc(Error::TrailingBytes)); - } - Ok(FramedReadStep::Value(value)) + Ok(FramedPrefixStep::Value { + value, + bytes: this.bytes, + }) } } @@ -49,7 +64,7 @@ impl FramedReader { mod tests { use bytes::Bytes; - use super::{FramedReadStep, FramedReader}; + use super::{FramedPrefixStep, FramedReadStep, FramedReader}; use crate::codec::encode_value_part; #[test] @@ -87,4 +102,26 @@ mod tests { _ => unreachable!(), } } + + #[test] + fn value_reader_returns_prefix_remainder() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + encoded.extend_from_slice(b"tail"); + + match FramedReader::>::default() + .push(Bytes::from(encoded)) + .advance_prefix() + .unwrap() + { + FramedPrefixStep::Value { value, mut bytes } => { + assert_eq!(value, b"hello".to_vec()); + assert_eq!( + bytes.pop_front(usize::MAX), + Some(Bytes::from_static(b"tail")) + ); + } + _ => unreachable!(), + } + } } diff --git a/ql-rpc/src/rpc/download/client.rs b/ql-rpc/src/rpc/download/client.rs index 9f8280e6..9175dbc5 100644 --- a/ql-rpc/src/rpc/download/client.rs +++ b/ql-rpc/src/rpc/download/client.rs @@ -1,11 +1,11 @@ use std::future::poll_fn; -use bytes::Bytes; +use bytes::{BufMut, Bytes}; -use super::codec::FrameKind; use crate::{ - download::{Download, PartReadStep, ReadStep, ResponseHeaderReader}, - CallError, RpcRead, + download::{Download, PartReadStep}, + rpc::parts::FrameKind, + CallError, FramedPrefixStep, FramedReader, RpcCodec, RpcRead, StreamCloseCode, }; pub struct DownloadCall @@ -13,8 +13,8 @@ where M: Download, R: RpcRead, { - stream: R, - reader: Option>, + stream: Option, + reader: Option>, } pub struct DownloadPart<'a, M, R> @@ -31,8 +31,8 @@ where M: Download, R: RpcRead, { - stream: R, - reader: crate::download::PartFrameReader, + stream: Option, + reader: crate::download::PartFrameReader, finished: bool, } @@ -43,8 +43,8 @@ where { pub fn new(stream: R) -> Self { Self { - stream, - reader: Some(ResponseHeaderReader::default()), + stream: Some(stream), + reader: Some(FramedReader::default()), } } @@ -53,22 +53,23 @@ where ) -> Result<(M::ResponseHeader, DownloadReader), CallError> { loop { let reader = self.reader.take().expect("download reader is present"); - let reader = match reader.advance() { - Ok(ReadStep::ResponseHeader { value, bytes }) => { + let reader = match reader.advance_prefix() { + Ok(FramedPrefixStep::Value { value, bytes }) => { return Ok(( value, DownloadReader { - stream: self.stream, - reader: crate::download::PartFrameReader::new(bytes), + stream: self.stream.take(), + reader: crate::download::PartFrameReader::::new(bytes), finished: false, }, )); } - Ok(ReadStep::NeedMore(next)) => next, + Ok(FramedPrefixStep::NeedMore(next)) => next, Err(error) => return Err(error.into()), }; - match poll_fn(|cx| self.stream.poll_read(usize::MAX, cx)).await { + let stream = self.stream.as_mut().expect("download stream exists"); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { Ok(Some(chunk)) => { self.reader = Some(reader.push(chunk)); } @@ -78,8 +79,20 @@ where } } - pub fn into_inner(self) -> R { - self.stream + fn close(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for DownloadCall +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + self.close(StreamCloseCode::CANCELLED); } } @@ -118,7 +131,9 @@ where } } - async fn read_frame(&mut self) -> Result, CallError> { + async fn read_frame( + &mut self, + ) -> Result, CallError> { loop { match self.reader.advance() { Ok(PartReadStep::NeedMore) => {} @@ -126,7 +141,8 @@ where Err(error) => return Err(error.into()), } - match poll_fn(|cx| self.stream.poll_read(usize::MAX, cx)).await { + let stream = self.stream.as_mut().expect("download stream exists"); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { Ok(Some(chunk)) => { self.reader.push(chunk); } @@ -136,8 +152,23 @@ where } } - pub fn into_inner(self) -> R { - self.stream + fn close(&mut self, code: StreamCloseCode) { + self.finished = true; + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for DownloadReader +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.close(StreamCloseCode::CANCELLED); + } } } @@ -167,3 +198,19 @@ where } } } + +impl Drop for DownloadPart<'_, M, R> +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close(StreamCloseCode::CANCELLED); + } + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + request.encode_value(out) +} diff --git a/ql-rpc/src/rpc/download/mod.rs b/ql-rpc/src/rpc/download/mod.rs index 5cef967c..c37fc90a 100644 --- a/ql-rpc/src/rpc/download/mod.rs +++ b/ql-rpc/src/rpc/download/mod.rs @@ -1,18 +1,18 @@ use crate::{RouteId, RpcCodec}; pub(crate) mod client; -pub(crate) mod codec; pub(crate) mod server; -pub use client::{DownloadCall, DownloadPart, DownloadReader}; -pub use codec::{ - encode_body_chunk, encode_end_part, encode_finish, encode_part_header, encode_request, - encode_response_header, PartFrameReader, PartReadStep, ReadStep, ResponseHeaderReader, -}; +pub use client::{encode_request, DownloadCall, DownloadPart, DownloadReader}; pub use server::{ DownloadHandler, DownloadHandlerLocal, DownloadPartWriter, DownloadStart, DownloadWriter, }; +pub use crate::rpc::parts::{ + encode_body_chunk, encode_end_part, encode_finish, encode_part_header, PartFrameReader, + PartReadStep, +}; + /// rpc where the responder returns metadata first and then zero or more byte parts /// /// the typed portion of the response ends at [`Self::ResponseHeader`] diff --git a/ql-rpc/src/rpc/download/server.rs b/ql-rpc/src/rpc/download/server.rs index 8308d7f7..7b421e9b 100644 --- a/ql-rpc/src/rpc/download/server.rs +++ b/ql-rpc/src/rpc/download/server.rs @@ -4,12 +4,12 @@ use bytes::Bytes; use crate::{ codec, - download::{ - encode_body_chunk, encode_end_part, encode_finish, encode_part_header, - Download as DownloadRpc, - }, + download::Download as DownloadRpc, finish_bytes, - rpc::read_eof_request, + rpc::{ + parts::{encode_body_chunk, encode_end_part, encode_finish, encode_part_header}, + read_eof_request, + }, write_bytes, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, }; @@ -107,7 +107,7 @@ where ) -> Result, W::Error> { let writer = self.writer.as_mut().expect("download writer exists"); let mut encoded = Vec::new(); - encode_part_header::(&part_header, &mut encoded); + encode_part_header(&part_header, &mut encoded); write_bytes(writer, Bytes::from(encoded)).await?; Ok(DownloadPartWriter { parent: self, diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index 1f8f32fb..7b573afb 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -8,6 +8,7 @@ pub mod download; pub mod duplex; pub mod notification; +pub(crate) mod parts; pub mod progress; pub mod request; pub mod subscription; diff --git a/ql-rpc/src/rpc/download/codec.rs b/ql-rpc/src/rpc/parts.rs similarity index 74% rename from ql-rpc/src/rpc/download/codec.rs rename to ql-rpc/src/rpc/parts.rs index 53332773..47ff1e87 100644 --- a/ql-rpc/src/rpc/download/codec.rs +++ b/ql-rpc/src/rpc/parts.rs @@ -2,44 +2,20 @@ use std::marker::PhantomData; use bytes::{BufMut, Bytes}; -use crate::{codec, download::Download, ChunkQueue, CodecError, RpcCodec}; +use crate::{codec, ChunkQueue, CodecError, RpcCodec}; -pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { - request.encode_value(out) -} - -pub fn encode_response_header( - response_header: &M::ResponseHeader, - out: &mut (impl BufMut + AsMut<[u8]>), -) { - codec::encode_value_part(response_header, out) -} - -pub enum ReadStep { - NeedMore(ResponseHeaderReader), - ResponseHeader { - value: M::ResponseHeader, - bytes: ChunkQueue, - }, -} - -pub enum PartReadStep { +pub enum PartReadStep { NeedMore, - PartHeader(M::PartHeader), + PartHeader(H), BodyBytes(Bytes), EndPart, Finish, } -pub struct ResponseHeaderReader { - bytes: codec::ChunkQueue, - marker: PhantomData M>, -} - -pub struct PartFrameReader { +pub struct PartFrameReader { bytes: codec::ChunkQueue, pending_frame: PendingFrame, - marker: PhantomData M>, + marker: PhantomData H>, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -55,38 +31,7 @@ impl PendingFrame { } } -impl Default for ResponseHeaderReader { - fn default() -> Self { - Self { - bytes: codec::ChunkQueue::default(), - marker: PhantomData, - } - } -} - -impl ResponseHeaderReader { - pub fn push(mut self, chunk: Bytes) -> Self { - self.bytes.push(chunk); - self - } - - pub fn advance(mut self) -> Result, CodecError> { - let value = { - let Some(mut body) = self.bytes.try_take_part()? else { - return Ok(ReadStep::NeedMore(self)); - }; - let value = M::ResponseHeader::decode_value(&mut body).map_err(CodecError::Codec)?; - value - }; - - Ok(ReadStep::ResponseHeader { - value, - bytes: self.bytes, - }) - } -} - -impl PartFrameReader { +impl PartFrameReader { pub fn new(bytes: ChunkQueue) -> Self { Self { bytes, @@ -99,7 +44,7 @@ impl PartFrameReader { self.bytes.push(chunk); } - pub fn advance(&mut self) -> Result, CodecError> { + pub fn advance(&mut self) -> Result, CodecError> { loop { match self.pending_frame.take() { PendingFrame::Body { remaining } => { @@ -128,8 +73,7 @@ impl PartFrameReader { match kind { FrameKind::PartHeader => { - let value = M::PartHeader::decode_value(&mut body) - .map_err(CodecError::Codec)?; + let value = H::decode_value(&mut body).map_err(CodecError::Codec)?; return Ok(PartReadStep::PartHeader(value)); } FrameKind::BodyChunk => unreachable!("body chunk is not a control frame"), @@ -166,10 +110,7 @@ impl PartFrameReader { } } -pub fn encode_part_header( - part_header: &M::PartHeader, - out: &mut (impl BufMut + AsMut<[u8]>), -) { +pub fn encode_part_header(part_header: &H, out: &mut (impl BufMut + AsMut<[u8]>)) { encode_tagged_value_part(FrameKind::PartHeader, part_header, out) } @@ -239,30 +180,19 @@ mod tests { encode_body_chunk, encode_end_part, encode_finish, encode_part_header, PartFrameReader, PartReadStep, }; - use crate::{download::Download, RouteId}; - - struct Files; - - impl Download for Files { - const ROUTE: RouteId = RouteId::from_u32(12); - type Error = core::convert::Infallible; - type Request = Vec; - type ResponseHeader = Vec; - type PartHeader = Vec; - } #[test] fn part_reader_emits_multipart_sequence() { let mut encoded = Vec::new(); - encode_part_header::(&b"a.txt".to_vec(), &mut encoded); + encode_part_header(&b"a.txt".to_vec(), &mut encoded); encode_body_chunk(&Bytes::from_static(b"hel"), &mut encoded); encode_body_chunk(&Bytes::from_static(b"lo"), &mut encoded); encode_end_part(&mut encoded); - encode_part_header::(&b"b.txt".to_vec(), &mut encoded); + encode_part_header(&b"b.txt".to_vec(), &mut encoded); encode_end_part(&mut encoded); encode_finish(&mut encoded); - let mut reader = PartFrameReader::::new(Default::default()); + let mut reader = PartFrameReader::>::new(Default::default()); reader.push(Bytes::from(encoded)); match reader.advance().unwrap() { @@ -306,12 +236,12 @@ mod tests { } #[test] - fn part_reader_waits_for_complete_frame() { + fn part_reader_waits_for_complete_header_frame() { let mut encoded = Vec::new(); - encode_part_header::(&b"a.txt".to_vec(), &mut encoded); + encode_part_header(&b"a.txt".to_vec(), &mut encoded); let encoded = Bytes::from(encoded); - let mut reader = PartFrameReader::::new(Default::default()); + let mut reader = PartFrameReader::>::new(Default::default()); reader.push(encoded.slice(..4)); match reader.advance().unwrap() { PartReadStep::NeedMore => {} @@ -331,7 +261,7 @@ mod tests { encode_body_chunk(&Bytes::from_static(b"hello"), &mut encoded); let encoded = Bytes::from(encoded); - let mut reader = PartFrameReader::::new(Default::default()); + let mut reader = PartFrameReader::>::new(Default::default()); reader.push(encoded.slice(..9)); match reader.advance().unwrap() { PartReadStep::NeedMore => {} diff --git a/ql-rpc/src/rpc/upload/client.rs b/ql-rpc/src/rpc/upload/client.rs index a31a1fa6..47cb6544 100644 --- a/ql-rpc/src/rpc/upload/client.rs +++ b/ql-rpc/src/rpc/upload/client.rs @@ -1,8 +1,10 @@ use bytes::{BufMut, Bytes}; use crate::{ - finish_bytes, read_bytes, upload::Upload, write_bytes, CallError, ChunkQueue, RpcCodec, - RpcRead, RpcWrite, + finish_bytes, read_bytes, + rpc::parts::{encode_body_chunk, encode_end_part, encode_finish, encode_part_header}, + upload::Upload, + write_bytes, CallError, ChunkQueue, RpcCodec, RpcRead, RpcWrite, StreamCloseCode, }; pub struct UploadCall @@ -16,6 +18,16 @@ where marker: std::marker::PhantomData M>, } +pub struct UploadPartWriter<'a, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + parent: &'a mut UploadCall, + finished: bool, +} + impl UploadCall where M: Upload, @@ -30,13 +42,27 @@ where } } - pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, W::Error> { let writer = self.writer.as_mut().expect("upload writer exists"); - write_bytes(writer, bytes).await + let mut encoded = Vec::new(); + encode_part_header(&part_header, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + Ok(UploadPartWriter { + parent: self, + finished: false, + }) } pub async fn finish(mut self) -> Result> { let mut writer = self.writer.take().expect("upload writer exists"); + let mut encoded = Vec::new(); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)) + .await + .map_err(CallError::Transport)?; finish_bytes(&mut writer) .await .map_err(CallError::Transport)?; @@ -57,6 +83,15 @@ where } Ok(value) } + + fn close(&mut self, code: StreamCloseCode) { + if let Some(reader) = self.reader.take() { + reader.close(code); + } + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } } impl Drop for UploadCall @@ -66,11 +101,42 @@ where R: RpcRead, { fn drop(&mut self) { - if let Some(reader) = self.reader.take() { - reader.close(crate::StreamCloseCode::CANCELLED); - } - if let Some(writer) = self.writer.take() { - writer.close(crate::StreamCloseCode::CANCELLED); + self.close(StreamCloseCode::CANCELLED); + } +} + +impl UploadPartWriter<'_, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().expect("upload writer exists"); + let mut encoded = Vec::new(); + encode_body_chunk(&bytes, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().expect("upload writer exists"); + let mut encoded = Vec::new(); + encode_end_part(&mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + self.finished = true; + Ok(()) + } +} + +impl Drop for UploadPartWriter<'_, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close(StreamCloseCode::CANCELLED); } } } diff --git a/ql-rpc/src/rpc/upload/mod.rs b/ql-rpc/src/rpc/upload/mod.rs index 985bb5e7..b6d2a7e1 100644 --- a/ql-rpc/src/rpc/upload/mod.rs +++ b/ql-rpc/src/rpc/upload/mod.rs @@ -3,13 +3,13 @@ use crate::{RouteId, RpcCodec}; pub(crate) mod client; pub(crate) mod server; -pub use client::{encode_request, UploadCall}; -pub use server::{UploadHandler, UploadHandlerLocal, UploadReader, UploadResponder}; +pub use client::{encode_request, UploadCall, UploadPartWriter}; +pub use server::{UploadHandler, UploadHandlerLocal, UploadPart, UploadReader, UploadResponder}; -/// rpc where the caller uploads raw bytes after a typed request +/// rpc where the caller uploads zero or more byte parts after a typed request /// /// the typed request usually describes how the responder should interpret the -/// following byte stream +/// following parts /// the request is length-delimited so raw upload bytes can follow immediately /// once the upload reaches eof, the responder returns one typed /// [`Self::Response`] @@ -20,6 +20,8 @@ pub trait Upload { type Error; /// typed input needed before request body bytes arrive type Request: RpcCodec; + /// typed metadata available before each byte part arrives + type PartHeader: RpcCodec; /// typed terminal result after the upload body is fully read type Response: RpcCodec; } diff --git a/ql-rpc/src/rpc/upload/server.rs b/ql-rpc/src/rpc/upload/server.rs index b12783c8..9354e3d9 100644 --- a/ql-rpc/src/rpc/upload/server.rs +++ b/ql-rpc/src/rpc/upload/server.rs @@ -1,13 +1,14 @@ -use std::{ - future::{poll_fn, Future}, - task::{Context, Poll}, -}; +use std::future::{poll_fn, Future}; use bytes::Bytes; use crate::{ - request::Response, rpc::read_framed_request_prefix, ChunkQueue, RouterConfig, RpcRead, - RpcStream, RpcWrite, StreamCloseCode, StreamError, Upload, + request::Response, + rpc::{ + parts::{FrameKind, PartFrameReader, PartReadStep}, + read_framed_request_prefix, + }, + RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, Upload, }; #[trait_variant::make(UploadHandler: Send)] @@ -19,19 +20,30 @@ where async fn handle( self, request: M::Request, - upload: UploadReader, + upload: UploadReader, responder: UploadResponder, ); fn handle_transport_error(&self, _error: &St::Error) {} } -pub struct UploadReader +pub struct UploadReader where + M: Upload, R: RpcRead, { - buffered: ChunkQueue, stream: R, + reader: PartFrameReader, + finished: bool, +} + +pub struct UploadPart<'a, M, R> +where + M: Upload, + R: RpcRead, +{ + parent: &'a mut UploadReader, + finished: bool, } pub struct UploadResponder @@ -41,28 +53,59 @@ where inner: Response, } -impl UploadReader +impl UploadReader where + M: Upload, R: RpcRead, { - pub fn poll_read( + pub async fn next_part( &mut self, - max_len: usize, - cx: &mut Context<'_>, - ) -> Poll, R::Error>> { - if let Some(chunk) = self.buffered.pop_front(max_len) { - return Poll::Ready(Ok(Some(chunk))); + ) -> Result)>, crate::CallError> + { + if self.finished { + return Ok(None); } - self.stream.poll_read(max_len, cx) + match self.read_frame().await? { + PartReadStep::PartHeader(value) => Ok(Some(( + value, + UploadPart { + parent: self, + finished: false, + }, + ))), + PartReadStep::Finish => { + self.finished = true; + Ok(None) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } } - pub async fn read(&mut self, max_len: usize) -> Result, R::Error> { - poll_fn(|cx| self.poll_read(max_len, cx)).await - } + async fn read_frame( + &mut self, + ) -> Result, crate::CallError> { + loop { + match self.reader.advance() { + Ok(PartReadStep::NeedMore) => {} + Ok(step) => return Ok(step), + Err(error) => return Err(error.into()), + } - pub async fn read_chunk(&mut self) -> Result, R::Error> { - self.read(usize::MAX).await + match poll_fn(|cx| self.stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader.push(chunk); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(crate::CallError::Transport(error)), + } + } } pub fn into_inner(self) -> R { @@ -70,6 +113,35 @@ where } } +impl UploadPart<'_, M, R> +where + M: Upload, + R: RpcRead, +{ + pub async fn read_chunk( + &mut self, + ) -> Result, crate::CallError> { + if self.finished { + return Ok(None); + } + + match self.parent.read_frame().await? { + PartReadStep::BodyBytes(bytes) => Ok(Some(bytes)), + PartReadStep::EndPart => { + self.finished = true; + Ok(None) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::Finish => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::Finish.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } +} + impl UploadResponder where T: crate::RpcCodec, @@ -103,7 +175,7 @@ pub(crate) async fn handle_upload_inner( H: FnOnce( S, M::Request, - UploadReader, + UploadReader, UploadResponder, ) -> HF, HF: Future, @@ -127,8 +199,9 @@ pub(crate) async fn handle_upload_inner( state, request, UploadReader { - buffered, stream: reader, + reader: PartFrameReader::new(buffered), + finished: false, }, UploadResponder::new(writer), ) diff --git a/ql-rpc/src/rpc/utils.rs b/ql-rpc/src/rpc/utils.rs index 90c1df0f..bf5f49ea 100644 --- a/ql-rpc/src/rpc/utils.rs +++ b/ql-rpc/src/rpc/utils.rs @@ -1,6 +1,6 @@ use crate::{ - read_bytes, ChunkQueue, CodecError, FramedReadStep, FramedReader, RouterConfig, RpcCodec, - RpcRead, StreamCloseCode, + read_bytes, ChunkQueue, CodecError, FramedPrefixStep, FramedReadStep, FramedReader, + RouterConfig, RpcCodec, RpcRead, StreamCloseCode, }; /// reads one length-delimited value and rejects trailing bytes @@ -57,24 +57,15 @@ where T: RpcCodec, R: RpcRead, { - let mut bytes = ChunkQueue::default(); + let mut value_reader = FramedReader::::default(); let mut total_read = 0usize; loop { - let maybe_value = { - match bytes.try_take_part() { - Ok(Some(mut body)) => { - let value = - T::decode_value(&mut body).map_err(|_error| StreamCloseCode::REFUSED)?; - drop(body); - Some(value) - } - Ok(None) => None, - Err(_error) => return Err(StreamCloseCode::REFUSED.into()), - } - }; - if let Some(value) = maybe_value { - return Ok((value, bytes)); + match value_reader.advance_prefix() { + Ok(FramedPrefixStep::Value { value, bytes }) => return Ok((value, bytes)), + Ok(FramedPrefixStep::NeedMore(next)) => value_reader = next, + Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), } let remaining = config.max_request_bytes.saturating_sub(total_read); @@ -85,7 +76,7 @@ where match read_bytes(reader, remaining).await { Ok(Some(chunk)) => { total_read += chunk.len(); - bytes.push(chunk); + value_reader = value_reader.push(chunk); } Ok(None) => return Err(StreamCloseCode::REFUSED.into()), Err(error) => return Err(error), diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs index 6068e357..a85a20e3 100644 --- a/ql-runtime/src/rpc/download.rs +++ b/ql-runtime/src/rpc/download.rs @@ -41,10 +41,6 @@ where .await? .map(|(header, inner)| (header, DownloadPart { inner }))) } - - pub fn close(self, code: ql_wire::StreamCloseCode) { - self.inner.into_inner().close(code); - } } impl DownloadPart<'_, M> diff --git a/ql-runtime/src/rpc/upload.rs b/ql-runtime/src/rpc/upload.rs index c749e44b..33ee3665 100644 --- a/ql-runtime/src/rpc/upload.rs +++ b/ql-runtime/src/rpc/upload.rs @@ -8,15 +8,37 @@ pub struct UploadCall { pub(super) inner: ql_rpc::upload::UploadCall, } +pub struct UploadPartWriter<'a, M: UploadRpc> { + inner: ql_rpc::upload::UploadPartWriter<'a, M, crate::StreamWriter, crate::StreamReader>, +} + impl UploadCall where M: UploadRpc, { - pub async fn send(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { - self.inner.send(bytes).await + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, QlStreamError> { + Ok(UploadPartWriter { + inner: self.inner.start_part(part_header).await?, + }) } pub async fn finish(self) -> Result> { self.inner.finish().await.map_err(RpcError::from) } } + +impl UploadPartWriter<'_, M> +where + M: UploadRpc, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + self.inner.send(bytes).await + } + + pub async fn finish(self) -> Result<(), QlStreamError> { + self.inner.finish().await + } +} diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index e8c03f7a..b8863f31 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -106,6 +106,7 @@ impl ql_rpc::upload::Upload for BlobUpload { const ROUTE: RouteId = RouteId::from_u32(55); type Error = core::convert::Infallible; type Request = Vec; + type PartHeader = Vec; type Response = Vec; } @@ -419,24 +420,28 @@ async fn rpc_download() { .unwrap(); let (header, mut reader) = download.into_reader().await.unwrap(); assert_eq!(header, b"image/png".to_vec()); - let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); - assert_eq!(part_header, b"icon".to_vec()); - assert_eq!( - part.read_chunk().await.unwrap(), - Some(Bytes::from_static(b"abc")) - ); - assert_eq!( - part.read_chunk().await.unwrap(), - Some(Bytes::from_static(b"def")) - ); - assert_eq!(part.read_chunk().await.unwrap(), None); - let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); - assert_eq!(part_header, b"manifest".to_vec()); - assert_eq!( - part.read_chunk().await.unwrap(), - Some(Bytes::from_static(b"{}")) - ); - assert_eq!(part.read_chunk().await.unwrap(), None); + { + let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); + assert_eq!(part_header, b"icon".to_vec()); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"abc")) + ); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"def")) + ); + assert_eq!(part.read_chunk().await.unwrap(), None); + } + { + let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); + assert_eq!(part_header, b"manifest".to_vec()); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"{}")) + ); + assert_eq!(part.read_chunk().await.unwrap(), None); + } assert!(reader.next_part().await.unwrap().is_none()); assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); @@ -460,7 +465,7 @@ async fn rpc_upload() { async fn handle( self, request: Vec, - mut upload: UploadReader, + mut upload: UploadReader, responder: UploadResponder, StreamWriter>, ) { let requests = self.requests.clone(); @@ -468,8 +473,13 @@ async fn rpc_upload() { requests.borrow_mut().push(request); let mut body = Vec::new(); - while let Some(chunk) = upload.read_chunk().await.unwrap() { - body.extend_from_slice(&chunk); + while let Some((part_header, mut part)) = upload.next_part().await.unwrap() { + body.extend_from_slice(&part_header); + body.push(b':'); + while let Some(chunk) = part.read_chunk().await.unwrap() { + body.extend_from_slice(&chunk); + } + body.push(b';'); } uploads.borrow_mut().push(body.clone()); @@ -501,13 +511,21 @@ async fn rpc_upload() { let rpc = pair.side_mut(Side::A).handle.rpc(); let mut upload = rpc.upload::(&b"logo".to_vec()).await.unwrap(); - upload.send(Bytes::from_static(b"abc")).await.unwrap(); - upload.send(Bytes::from_static(b"def")).await.unwrap(); + let mut part = upload.start_part(b"icon".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"abc")).await.unwrap(); + part.send(Bytes::from_static(b"def")).await.unwrap(); + part.finish().await.unwrap(); + let mut part = upload.start_part(b"manifest".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"{}")).await.unwrap(); + part.finish().await.unwrap(); let response = upload.finish().await.unwrap(); - assert_eq!(response, b"abcdef".to_vec()); + assert_eq!(response, b"icon:abcdef;manifest:{};".to_vec()); assert_eq!(requests.borrow().as_slice(), &[b"logo".to_vec()]); - assert_eq!(uploads.borrow().as_slice(), &[b"abcdef".to_vec()]); + assert_eq!( + uploads.borrow().as_slice(), + &[b"icon:abcdef;manifest:{};".to_vec()] + ); tokio::time::timeout(Duration::from_secs(2), responder) .await From a4839e4ca667c344b20ede8bb9152522f611927e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 15 May 2026 10:18:12 -0400 Subject: [PATCH 300/304] ql-rpc: clean up apis --- ql-rpc/src/rpc/download/client.rs | 41 ++++++++++------ ql-rpc/src/rpc/download/server.rs | 10 ++-- ql-rpc/src/rpc/duplex/client.rs | 69 ++++++++++++++++----------- ql-rpc/src/rpc/duplex/codec.rs | 32 ++++++------- ql-rpc/src/rpc/progress/client.rs | 53 ++++++++++++-------- ql-rpc/src/rpc/progress/codec.rs | 43 +++++++---------- ql-rpc/src/rpc/progress/server.rs | 4 +- ql-rpc/src/rpc/request/server.rs | 2 +- ql-rpc/src/rpc/subscription/client.rs | 66 ++++++++++++++++--------- ql-rpc/src/rpc/subscription/codec.rs | 22 +++------ ql-rpc/src/rpc/subscription/server.rs | 4 +- ql-rpc/src/rpc/upload/client.rs | 10 ++-- ql-rpc/src/rpc/upload/server.rs | 52 ++++++++++++++++---- ql-runtime/src/rpc/download.rs | 12 +++++ ql-runtime/src/rpc/progress.rs | 9 ++++ ql-runtime/src/rpc/subscription.rs | 4 ++ 16 files changed, 266 insertions(+), 167 deletions(-) diff --git a/ql-rpc/src/rpc/download/client.rs b/ql-rpc/src/rpc/download/client.rs index 9175dbc5..42793489 100644 --- a/ql-rpc/src/rpc/download/client.rs +++ b/ql-rpc/src/rpc/download/client.rs @@ -33,7 +33,6 @@ where { stream: Option, reader: crate::download::PartFrameReader, - finished: bool, } impl DownloadCall @@ -52,15 +51,15 @@ where mut self, ) -> Result<(M::ResponseHeader, DownloadReader), CallError> { loop { - let reader = self.reader.take().expect("download reader is present"); + let reader = self.reader.take().unwrap(); let reader = match reader.advance_prefix() { Ok(FramedPrefixStep::Value { value, bytes }) => { + let stream = self.stream.take().unwrap(); return Ok(( value, DownloadReader { - stream: self.stream.take(), + stream: Some(stream), reader: crate::download::PartFrameReader::::new(bytes), - finished: false, }, )); } @@ -68,7 +67,7 @@ where Err(error) => return Err(error.into()), }; - let stream = self.stream.as_mut().expect("download stream exists"); + let stream = self.stream.as_mut().unwrap(); match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { Ok(Some(chunk)) => { self.reader = Some(reader.push(chunk)); @@ -79,7 +78,11 @@ where } } - fn close(&mut self, code: StreamCloseCode) { + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { if let Some(stream) = self.stream.take() { stream.close(code); } @@ -92,7 +95,7 @@ where R: RpcRead, { fn drop(&mut self) { - self.close(StreamCloseCode::CANCELLED); + self.close_inner(StreamCloseCode::CANCELLED); } } @@ -105,7 +108,7 @@ where &mut self, ) -> Result)>, CallError> { - if self.finished { + if self.stream.is_none() { return Ok(None); } @@ -118,7 +121,7 @@ where }, ))), PartReadStep::Finish => { - self.finished = true; + self.stream.take(); Ok(None) } PartReadStep::BodyBytes(_) => { @@ -141,7 +144,7 @@ where Err(error) => return Err(error.into()), } - let stream = self.stream.as_mut().expect("download stream exists"); + let stream = self.stream.as_mut().unwrap(); match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { Ok(Some(chunk)) => { self.reader.push(chunk); @@ -152,8 +155,11 @@ where } } - fn close(&mut self, code: StreamCloseCode) { - self.finished = true; + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { if let Some(stream) = self.stream.take() { stream.close(code); } @@ -166,8 +172,8 @@ where R: RpcRead, { fn drop(&mut self) { - if !self.finished { - self.close(StreamCloseCode::CANCELLED); + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); } } } @@ -197,6 +203,11 @@ where PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), } } + + pub fn close(mut self, code: StreamCloseCode) { + self.parent.close_inner(code); + self.finished = true; + } } impl Drop for DownloadPart<'_, M, R> @@ -206,7 +217,7 @@ where { fn drop(&mut self) { if !self.finished { - self.parent.close(StreamCloseCode::CANCELLED); + self.parent.close_inner(StreamCloseCode::CANCELLED); } } } diff --git a/ql-rpc/src/rpc/download/server.rs b/ql-rpc/src/rpc/download/server.rs index 7b421e9b..f9b3d7d0 100644 --- a/ql-rpc/src/rpc/download/server.rs +++ b/ql-rpc/src/rpc/download/server.rs @@ -67,7 +67,7 @@ where mut self, response_header: M::ResponseHeader, ) -> Result, W::Error> { - let mut writer = self.writer.take().expect("download writer exists"); + let mut writer = self.writer.take().unwrap(); let mut encoded = Vec::new(); codec::encode_value_part(&response_header, &mut encoded); write_bytes(&mut writer, Bytes::from(encoded)).await?; @@ -105,7 +105,7 @@ where &mut self, part_header: M::PartHeader, ) -> Result, W::Error> { - let writer = self.writer.as_mut().expect("download writer exists"); + let writer = self.writer.as_mut().unwrap(); let mut encoded = Vec::new(); encode_part_header(&part_header, &mut encoded); write_bytes(writer, Bytes::from(encoded)).await?; @@ -116,7 +116,7 @@ where } pub async fn finish(mut self) -> Result<(), W::Error> { - let mut writer = self.writer.take().expect("download writer exists"); + let mut writer = self.writer.take().unwrap(); let mut encoded = Vec::new(); encode_finish(&mut encoded); write_bytes(&mut writer, Bytes::from(encoded)).await?; @@ -148,14 +148,14 @@ where W: RpcWrite, { pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { - let writer = self.parent.writer.as_mut().expect("download writer exists"); + let writer = self.parent.writer.as_mut().unwrap(); let mut encoded = Vec::new(); encode_body_chunk(&bytes, &mut encoded); write_bytes(writer, Bytes::from(encoded)).await } pub async fn finish(mut self) -> Result<(), W::Error> { - let writer = self.parent.writer.as_mut().expect("download writer exists"); + let writer = self.parent.writer.as_mut().unwrap(); let mut encoded = Vec::new(); encode_end_part(&mut encoded); write_bytes(writer, Bytes::from(encoded)).await?; diff --git a/ql-rpc/src/rpc/duplex/client.rs b/ql-rpc/src/rpc/duplex/client.rs index 579a9f5e..e76050a6 100644 --- a/ql-rpc/src/rpc/duplex/client.rs +++ b/ql-rpc/src/rpc/duplex/client.rs @@ -35,8 +35,8 @@ where T: RpcCodec, R: RpcRead, { - stream: R, - reader: Option>, + stream: Option, + reader: EventReader, } impl DuplexSender @@ -52,14 +52,14 @@ where } pub async fn send(&mut self, event: &T) -> Result<(), W::Error> { - let writer = self.writer.as_mut().expect("duplex writer exists"); + let writer = self.writer.as_mut().unwrap(); let mut encoded = Vec::new(); codec::encode_event(event, &mut encoded); write_bytes(writer, Bytes::from(encoded)).await } pub async fn finish(mut self) -> Result<(), W::Error> { - let mut writer = self.writer.take().expect("duplex writer exists"); + let mut writer = self.writer.take().unwrap(); finish_bytes(&mut writer).await } @@ -89,8 +89,8 @@ where { pub fn new(stream: R) -> Self { Self { - stream, - reader: Some(EventReader::default()), + stream: Some(stream), + reader: EventReader::default(), } } @@ -102,48 +102,63 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll>>> { + if self.stream.is_none() { + return Poll::Ready(None); + } + loop { - let Some(reader) = self.reader.take() else { - return Poll::Ready(None); - }; - - let reader = match reader.advance() { - Ok(ReadStep::Event { value, next }) => { - self.reader = Some(next); - return Poll::Ready(Some(Ok(value))); + match self.reader.advance() { + Ok(ReadStep::Event(value)) => return Poll::Ready(Some(Ok(value))), + Ok(ReadStep::NeedMore) => {} + Err(error) => { + self.stream.take(); + return Poll::Ready(Some(Err(error.into()))); } - Ok(ReadStep::NeedMore(next)) => next, - Err(error) => return Poll::Ready(Some(Err(error.into()))), - }; + } - match self.stream.poll_read(usize::MAX, cx) { + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { Poll::Ready(Ok(Some(chunk))) => { - self.reader = Some(reader.push(chunk)); + self.reader.push(chunk); } Poll::Ready(Ok(None)) => { - self.reader = None; - if reader.is_empty() { + if self.reader.is_empty() { + self.stream.take(); return Poll::Ready(None); } + self.stream.take(); return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); } Poll::Ready(Err(error)) => { - self.reader = None; + self.stream.take(); return Poll::Ready(Some(Err(CallError::Transport(error)))); } Poll::Pending => { - self.reader = Some(reader); return Poll::Pending; } } } } - pub fn close(self, code: StreamCloseCode) { - self.stream.close(code); + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } } +} - pub fn into_inner(self) -> R { - self.stream +impl Drop for DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } } } diff --git a/ql-rpc/src/rpc/duplex/codec.rs b/ql-rpc/src/rpc/duplex/codec.rs index f7950e89..68bc87c7 100644 --- a/ql-rpc/src/rpc/duplex/codec.rs +++ b/ql-rpc/src/rpc/duplex/codec.rs @@ -12,8 +12,8 @@ where } pub enum ReadStep { - NeedMore(EventReader), - Event { value: T, next: EventReader }, + NeedMore, + Event(T), } pub struct EventReader { @@ -31,19 +31,17 @@ impl Default for EventReader { } impl EventReader { - pub fn push(mut self, chunk: Bytes) -> Self { + pub fn push(&mut self, chunk: Bytes) { self.bytes.push(chunk); - self } pub fn is_empty(&self) -> bool { self.bytes.remaining() == 0 } - pub fn advance(self) -> Result, CodecError> { - let mut this = self; - let Some(mut body) = this.bytes.try_take_part()? else { - return Ok(ReadStep::NeedMore(this)); + pub fn advance(&mut self) -> Result, CodecError> { + let Some(mut body) = self.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore); }; let value = { @@ -51,7 +49,7 @@ impl EventReader { drop(body); value }; - Ok(ReadStep::Event { value, next: this }) + Ok(ReadStep::Event(value)) } } @@ -67,22 +65,20 @@ mod tests { encode_event(&b"one".to_vec(), &mut encoded); encode_event(&b"two".to_vec(), &mut encoded); - let reader = match EventReader::>::default() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { - ReadStep::Event { value, next } => { + let mut reader = EventReader::>::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + ReadStep::Event(value) => { assert_eq!(value, b"one".to_vec()); - next } _ => unreachable!(), }; match reader.advance().unwrap() { - ReadStep::Event { value, next } => { + ReadStep::Event(value) => { assert_eq!(value, b"two".to_vec()); - assert!(next.is_empty()); + assert!(reader.is_empty()); } _ => unreachable!(), } diff --git a/ql-rpc/src/rpc/progress/client.rs b/ql-rpc/src/rpc/progress/client.rs index bbd68928..c2218c97 100644 --- a/ql-rpc/src/rpc/progress/client.rs +++ b/ql-rpc/src/rpc/progress/client.rs @@ -6,7 +6,7 @@ use std::{ use crate::{ progress::{Progress, ReadStep, ResponseReader}, - CallError, Error, RpcRead, + CallError, Error, RpcRead, StreamCloseCode, }; pub struct ProgressCall @@ -14,7 +14,7 @@ where M: Progress, R: RpcRead, { - stream: R, + stream: Option, state: State, } @@ -42,7 +42,7 @@ where { pub fn new(stream: R) -> Self { Self { - stream, + stream: Some(stream), state: State::Reading(ResponseReader::default()), } } @@ -53,40 +53,32 @@ where fn poll_step(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - let reader = match std::mem::replace(&mut self.state, State::Invalid) { + let reader = match &mut self.state { State::Reading(reader) => reader, - state @ (State::Terminal(_) | State::Done) => { - self.state = state; - return Poll::Ready(None); - } + State::Terminal(_) | State::Done => return Poll::Ready(None), State::Invalid => panic!("invalid state"), }; match reader.advance() { - Ok(ReadStep::Progress { value, next }) => { - self.state = State::Reading(next); - return Poll::Ready(Some(value)); - } + Ok(ReadStep::Progress(value)) => return Poll::Ready(Some(value)), Ok(ReadStep::Response(response)) => { self.state = State::Terminal(Ok(response)); return Poll::Ready(None); } - Ok(ReadStep::NeedMore(next)) => { - self.state = State::Reading(next); - } + Ok(ReadStep::NeedMore) => {} Err(error) => { self.state = State::Terminal(Err(error.into())); return Poll::Ready(None); } } - match self.stream.poll_read(usize::MAX, cx) { + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { Poll::Ready(Ok(Some(chunk))) => { - let State::Reading(reader) = std::mem::replace(&mut self.state, State::Invalid) - else { + let State::Reading(reader) = &mut self.state else { panic!("invalid state"); }; - self.state = State::Reading(reader.push(chunk)); + reader.push(chunk); } Poll::Ready(Ok(None)) => { self.state = State::Terminal(Err(Error::MissingResponse.into())); @@ -104,6 +96,29 @@ where pub fn poll_next_progress(&mut self, cx: &mut Context<'_>) -> Poll> { self.poll_step(cx) } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + self.state = State::Done; + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for ProgressCall +where + M: Progress, + R: RpcRead, +{ + fn drop(&mut self) { + if matches!(self.state, State::Reading(_)) { + self.close_inner(StreamCloseCode::CANCELLED); + } + } } impl Future for ProgressCall diff --git a/ql-rpc/src/rpc/progress/codec.rs b/ql-rpc/src/rpc/progress/codec.rs index c56af0dd..11e0d235 100644 --- a/ql-rpc/src/rpc/progress/codec.rs +++ b/ql-rpc/src/rpc/progress/codec.rs @@ -5,11 +5,8 @@ use bytes::{BufMut, Bytes}; use crate::{codec, progress::Progress, CodecError, Error, RpcCodec}; pub enum ReadStep { - NeedMore(ResponseReader), - Progress { - value: M::Progress, - next: ResponseReader, - }, + NeedMore, + Progress(M::Progress), Response(M::Response), } @@ -28,17 +25,14 @@ impl Default for ResponseReader { } impl ResponseReader { - pub fn push(mut self, chunk: Bytes) -> Self { + pub fn push(&mut self, chunk: Bytes) { self.bytes.push(chunk); - self } - pub fn advance(self) -> Result, CodecError> { - let mut this = self; - - let Some((kind, mut body)) = this.bytes.try_take_tagged_part().map_err(CodecError::Rpc)? + pub fn advance(&mut self) -> Result, CodecError> { + let Some((kind, mut body)) = self.bytes.try_take_tagged_part().map_err(CodecError::Rpc)? else { - return Ok(ReadStep::NeedMore(this)); + return Ok(ReadStep::NeedMore); }; match kind { @@ -48,12 +42,12 @@ impl ResponseReader { drop(body); value }; - Ok(ReadStep::Progress { value, next: this }) + Ok(ReadStep::Progress(value)) } x if x == FrameKind::Response as u8 => { let response = M::Response::decode_value(&mut body).map_err(CodecError::Codec)?; drop(body); - if this.bytes.remaining() > 0 { + if self.bytes.remaining() > 0 { Err(CodecError::Rpc(Error::TrailingBytes)) } else { Ok(ReadStep::Response(response)) @@ -117,14 +111,12 @@ mod tests { encode_progress::(&b"10%".to_vec(), &mut encoded); encode_response::(&b"done".to_vec(), &mut encoded); - let reader = match ResponseReader::::default() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { - ReadStep::Progress { value, next } => { + let mut reader = ResponseReader::::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + ReadStep::Progress(value) => { assert_eq!(value, b"10%".to_vec()); - next } _ => unreachable!(), }; @@ -139,11 +131,10 @@ mod tests { let mut encoded = Vec::new(); encode_response::(&b"done".to_vec(), &mut encoded); - match ResponseReader::::default() - .push(Bytes::from(encoded)) - .advance() - .unwrap() - { + let mut reader = ResponseReader::::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { ReadStep::Response(value) => assert_eq!(value, b"done".to_vec()), _ => unreachable!(), } diff --git a/ql-rpc/src/rpc/progress/server.rs b/ql-rpc/src/rpc/progress/server.rs index 8599d93e..b94421cf 100644 --- a/ql-rpc/src/rpc/progress/server.rs +++ b/ql-rpc/src/rpc/progress/server.rs @@ -42,14 +42,14 @@ where } pub async fn send(&mut self, progress: M::Progress) -> Result<(), W::Error> { - let writer = self.writer.as_mut().expect("progress writer exists"); + let writer = self.writer.as_mut().unwrap(); let mut encoded = Vec::new(); encode_progress::(&progress, &mut encoded); write_bytes(writer, Bytes::from(encoded)).await } pub async fn finish(mut self, response: M::Response) -> Result<(), W::Error> { - let mut writer = self.writer.take().expect("progress writer exists"); + let mut writer = self.writer.take().unwrap(); let mut encoded = Vec::new(); encode_response::(&response, &mut encoded); write_bytes(&mut writer, Bytes::from(encoded)).await?; diff --git a/ql-rpc/src/rpc/request/server.rs b/ql-rpc/src/rpc/request/server.rs index e3347061..5211cce2 100644 --- a/ql-rpc/src/rpc/request/server.rs +++ b/ql-rpc/src/rpc/request/server.rs @@ -39,7 +39,7 @@ where } pub async fn respond(mut self, response: T) -> Result<(), W::Error> { - let mut writer = self.writer.take().expect("response writer exists"); + let mut writer = self.writer.take().unwrap(); let mut encoded = Vec::new(); response.encode_value(&mut encoded); write_bytes(&mut writer, Bytes::from(encoded)).await?; diff --git a/ql-rpc/src/rpc/subscription/client.rs b/ql-rpc/src/rpc/subscription/client.rs index fe6b3838..fe6aa5b1 100644 --- a/ql-rpc/src/rpc/subscription/client.rs +++ b/ql-rpc/src/rpc/subscription/client.rs @@ -5,7 +5,7 @@ use std::{ use crate::{ subscription::{ReadStep, ResponseReader, Subscription}, - CallError, RpcRead, + CallError, RpcRead, StreamCloseCode, }; pub struct SubscriptionCall @@ -13,8 +13,8 @@ where M: Subscription, R: RpcRead, { - stream: R, - reader: Option>, + stream: Option, + reader: ResponseReader, } impl SubscriptionCall @@ -24,8 +24,8 @@ where { pub fn new(stream: R) -> Self { Self { - stream, - reader: Some(ResponseReader::default()), + stream: Some(stream), + reader: ResponseReader::default(), } } @@ -37,43 +37,63 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll>>> { - loop { - let Some(reader) = self.reader.take() else { - return Poll::Ready(None); - }; + if self.stream.is_none() { + return Poll::Ready(None); + } - let reader = match reader.advance() { - Ok(ReadStep::Item { value, next }) => { - self.reader = Some(next); - return Poll::Ready(Some(Ok(value))); + loop { + match self.reader.advance() { + Ok(ReadStep::Item(value)) => return Poll::Ready(Some(Ok(value))), + Ok(ReadStep::NeedMore) => {} + Err(error) => { + self.stream.take(); + return Poll::Ready(Some(Err(error.into()))); } - Ok(ReadStep::NeedMore(next)) => next, - Err(error) => return Poll::Ready(Some(Err(error.into()))), - }; + } - match self.stream.poll_read(usize::MAX, cx) { + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { Poll::Ready(Ok(Some(chunk))) => { - self.reader = Some(reader.push(chunk)); + self.reader.push(chunk); } Poll::Ready(Ok(None)) => { - if reader.is_empty() { + if self.reader.is_empty() { + self.stream.take(); return Poll::Ready(None); } + self.stream.take(); return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); } Poll::Ready(Err(error)) => { - self.reader = None; + self.stream.take(); return Poll::Ready(Some(Err(CallError::Transport(error)))); } Poll::Pending => { - self.reader = Some(reader); return Poll::Pending; } } } } - pub fn into_inner(self) -> R { - self.stream + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } } } diff --git a/ql-rpc/src/rpc/subscription/codec.rs b/ql-rpc/src/rpc/subscription/codec.rs index 525e817c..bdd16209 100644 --- a/ql-rpc/src/rpc/subscription/codec.rs +++ b/ql-rpc/src/rpc/subscription/codec.rs @@ -16,11 +16,8 @@ pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + As } pub enum ReadStep { - NeedMore(ResponseReader), - Item { - value: M::Event, - next: ResponseReader, - }, + NeedMore, + Item(M::Event), } pub struct ResponseReader { @@ -38,19 +35,17 @@ impl Default for ResponseReader { } impl ResponseReader { - pub fn push(mut self, chunk: Bytes) -> Self { + pub fn push(&mut self, chunk: Bytes) { self.bytes.push(chunk); - self } pub fn is_empty(&self) -> bool { self.bytes.remaining() == 0 } - pub fn advance(self) -> Result, CodecError> { - let mut this = self; - let Some(mut body) = this.bytes.try_take_part()? else { - return Ok(ReadStep::NeedMore(this)); + pub fn advance(&mut self) -> Result, CodecError> { + let Some(mut body) = self.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore); }; let item = { @@ -58,9 +53,6 @@ impl ResponseReader { drop(body); item }; - Ok(ReadStep::Item { - value: item, - next: this, - }) + Ok(ReadStep::Item(item)) } } diff --git a/ql-rpc/src/rpc/subscription/server.rs b/ql-rpc/src/rpc/subscription/server.rs index 32fac4f6..6dfdd4b0 100644 --- a/ql-rpc/src/rpc/subscription/server.rs +++ b/ql-rpc/src/rpc/subscription/server.rs @@ -44,7 +44,7 @@ where } pub async fn send(&mut self, event: T) -> Result<(), W::Error> { - let writer = self.writer.as_mut().expect("subscription writer exists"); + let writer = self.writer.as_mut().unwrap(); let mut encoded = Vec::new(); codec::encode_value_part(&event, &mut encoded); write_bytes(writer, Bytes::from(encoded)).await?; @@ -52,7 +52,7 @@ where } pub async fn finish(mut self) -> Result<(), W::Error> { - let mut writer = self.writer.take().expect("subscription writer exists"); + let mut writer = self.writer.take().unwrap(); finish_bytes(&mut writer).await } diff --git a/ql-rpc/src/rpc/upload/client.rs b/ql-rpc/src/rpc/upload/client.rs index 47cb6544..b41dedcd 100644 --- a/ql-rpc/src/rpc/upload/client.rs +++ b/ql-rpc/src/rpc/upload/client.rs @@ -46,7 +46,7 @@ where &mut self, part_header: M::PartHeader, ) -> Result, W::Error> { - let writer = self.writer.as_mut().expect("upload writer exists"); + let writer = self.writer.as_mut().unwrap(); let mut encoded = Vec::new(); encode_part_header(&part_header, &mut encoded); write_bytes(writer, Bytes::from(encoded)).await?; @@ -57,7 +57,7 @@ where } pub async fn finish(mut self) -> Result> { - let mut writer = self.writer.take().expect("upload writer exists"); + let mut writer = self.writer.take().unwrap(); let mut encoded = Vec::new(); encode_finish(&mut encoded); write_bytes(&mut writer, Bytes::from(encoded)) @@ -67,7 +67,7 @@ where .await .map_err(CallError::Transport)?; - let mut reader = self.reader.take().expect("upload reader exists"); + let mut reader = self.reader.take().unwrap(); let mut bytes = ChunkQueue::default(); while let Some(chunk) = read_bytes(&mut reader, usize::MAX) @@ -112,14 +112,14 @@ where R: RpcRead, { pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { - let writer = self.parent.writer.as_mut().expect("upload writer exists"); + let writer = self.parent.writer.as_mut().unwrap(); let mut encoded = Vec::new(); encode_body_chunk(&bytes, &mut encoded); write_bytes(writer, Bytes::from(encoded)).await } pub async fn finish(mut self) -> Result<(), W::Error> { - let writer = self.parent.writer.as_mut().expect("upload writer exists"); + let writer = self.parent.writer.as_mut().unwrap(); let mut encoded = Vec::new(); encode_end_part(&mut encoded); write_bytes(writer, Bytes::from(encoded)).await?; diff --git a/ql-rpc/src/rpc/upload/server.rs b/ql-rpc/src/rpc/upload/server.rs index 9354e3d9..d2e6765b 100644 --- a/ql-rpc/src/rpc/upload/server.rs +++ b/ql-rpc/src/rpc/upload/server.rs @@ -32,9 +32,8 @@ where M: Upload, R: RpcRead, { - stream: R, + stream: Option, reader: PartFrameReader, - finished: bool, } pub struct UploadPart<'a, M, R> @@ -62,7 +61,7 @@ where &mut self, ) -> Result)>, crate::CallError> { - if self.finished { + if self.stream.is_none() { return Ok(None); } @@ -75,7 +74,7 @@ where }, ))), PartReadStep::Finish => { - self.finished = true; + self.stream.take(); Ok(None) } PartReadStep::BodyBytes(_) => { @@ -98,7 +97,8 @@ where Err(error) => return Err(error.into()), } - match poll_fn(|cx| self.stream.poll_read(usize::MAX, cx)).await { + let stream = self.stream.as_mut().unwrap(); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { Ok(Some(chunk)) => { self.reader.push(chunk); } @@ -108,8 +108,26 @@ where } } - pub fn into_inner(self) -> R { - self.stream + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for UploadReader +where + M: Upload, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } } } @@ -140,6 +158,23 @@ where PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), } } + + pub fn close(mut self, code: StreamCloseCode) { + self.parent.close_inner(code); + self.finished = true; + } +} + +impl Drop for UploadPart<'_, M, R> +where + M: Upload, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close_inner(StreamCloseCode::CANCELLED); + } + } } impl UploadResponder @@ -199,9 +234,8 @@ pub(crate) async fn handle_upload_inner( state, request, UploadReader { - stream: reader, + stream: Some(reader), reader: PartFrameReader::new(buffered), - finished: false, }, UploadResponder::new(writer), ) diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs index a85a20e3..0b723db2 100644 --- a/ql-runtime/src/rpc/download.rs +++ b/ql-runtime/src/rpc/download.rs @@ -26,6 +26,10 @@ where let (header, inner) = self.inner.into_reader().await?; Ok((header, DownloadReader { inner })) } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } } impl DownloadReader @@ -41,6 +45,10 @@ where .await? .map(|(header, inner)| (header, DownloadPart { inner }))) } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } } impl DownloadPart<'_, M> @@ -50,4 +58,8 @@ where pub async fn read_chunk(&mut self) -> Result, RpcError> { Ok(self.inner.read_chunk().await?) } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } } diff --git a/ql-runtime/src/rpc/progress.rs b/ql-runtime/src/rpc/progress.rs index 1c3984d4..a22da20f 100644 --- a/ql-runtime/src/rpc/progress.rs +++ b/ql-runtime/src/rpc/progress.rs @@ -16,6 +16,15 @@ pub struct ProgressCall { impl Unpin for ProgressCall where M: Progress {} +impl ProgressCall +where + M: Progress, +{ + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + impl Stream for ProgressCall where M: Progress, diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs index 0dfd807e..45a08a6b 100644 --- a/ql-runtime/src/rpc/subscription.rs +++ b/ql-runtime/src/rpc/subscription.rs @@ -22,6 +22,10 @@ where pub async fn next_event(&mut self) -> Option>> { poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } } impl Stream for Subscription From 179d61cd18703fa989b857ba805f838ffa9e07ec Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Fri, 15 May 2026 16:42:35 -0400 Subject: [PATCH 301/304] ql: route trait --- ql-rpc/src/rpc/download/mod.rs | 7 +++---- ql-rpc/src/rpc/duplex/mod.rs | 7 +++---- ql-rpc/src/rpc/mod.rs | 7 +++++++ ql-rpc/src/rpc/notification/mod.rs | 7 +++---- ql-rpc/src/rpc/progress/codec.rs | 7 +++++-- ql-rpc/src/rpc/progress/mod.rs | 7 +++---- ql-rpc/src/rpc/request/mod.rs | 7 +++---- ql-rpc/src/rpc/subscription/mod.rs | 7 +++---- ql-rpc/src/rpc/upload/mod.rs | 7 +++---- 9 files changed, 33 insertions(+), 30 deletions(-) diff --git a/ql-rpc/src/rpc/download/mod.rs b/ql-rpc/src/rpc/download/mod.rs index c37fc90a..5ed34aed 100644 --- a/ql-rpc/src/rpc/download/mod.rs +++ b/ql-rpc/src/rpc/download/mod.rs @@ -1,4 +1,5 @@ -use crate::{RouteId, RpcCodec}; +use super::Route; +use crate::RpcCodec; pub(crate) mod client; pub(crate) mod server; @@ -18,9 +19,7 @@ pub use crate::rpc::parts::{ /// the typed portion of the response ends at [`Self::ResponseHeader`] /// after the header is decoded, the rest of the stream is exposed as typed /// part headers followed by raw byte chunks through [`DownloadReader`] -pub trait Download { - /// route used to dispatch this rpc family - const ROUTE: RouteId; +pub trait Download: Route { /// codec error shared by request and response header values type Error; /// typed input needed to start the download diff --git a/ql-rpc/src/rpc/duplex/mod.rs b/ql-rpc/src/rpc/duplex/mod.rs index c8f3f603..a9622029 100644 --- a/ql-rpc/src/rpc/duplex/mod.rs +++ b/ql-rpc/src/rpc/duplex/mod.rs @@ -1,4 +1,5 @@ -use crate::{RouteId, RpcCodec}; +use super::Route; +use crate::RpcCodec; pub(crate) mod client; pub(crate) mod codec; @@ -13,9 +14,7 @@ pub use server::{DuplexHandler, DuplexHandlerLocal, DuplexPeer}; /// The initiator opens the routed stream. After that, either side may send any /// number of events of its directional event type until it finishes or closes /// its write side. -pub trait Duplex { - /// route used to dispatch this rpc family - const ROUTE: RouteId; +pub trait Duplex: Route { /// codec error shared by both directional event values type Error; /// typed event sent by the side that opened the stream diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs index 7b573afb..2d84f050 100644 --- a/ql-rpc/src/rpc/mod.rs +++ b/ql-rpc/src/rpc/mod.rs @@ -5,6 +5,8 @@ //! route dispatch uses [`crate::RouteId`] and the submodules provide the matching //! client and server helpers for encoding, decoding, and handler glue +use crate::RouteId; + pub mod download; pub mod duplex; pub mod notification; @@ -15,6 +17,11 @@ pub mod subscription; pub mod upload; mod utils; +pub trait Route { + /// route used to dispatch this rpc family + const ROUTE: RouteId; +} + pub use download::Download; pub use duplex::Duplex; pub use notification::Notification; diff --git a/ql-rpc/src/rpc/notification/mod.rs b/ql-rpc/src/rpc/notification/mod.rs index d57bb3f5..4740a64f 100644 --- a/ql-rpc/src/rpc/notification/mod.rs +++ b/ql-rpc/src/rpc/notification/mod.rs @@ -1,4 +1,5 @@ -use crate::{RouteId, RpcCodec}; +use super::Route; +use crate::RpcCodec; pub(crate) mod client; pub(crate) mod server; @@ -10,9 +11,7 @@ pub use server::{NotificationHandler, NotificationHandlerLocal}; /// /// the server reads [`Self::Payload`] to eof and then closes the response side /// of the stream -pub trait Notification { - /// route used to dispatch this notification - const ROUTE: RouteId; +pub trait Notification: Route { /// codec error for the notification payload type Error; /// typed payload emitted by the caller diff --git a/ql-rpc/src/rpc/progress/codec.rs b/ql-rpc/src/rpc/progress/codec.rs index 11e0d235..a0dc1b8c 100644 --- a/ql-rpc/src/rpc/progress/codec.rs +++ b/ql-rpc/src/rpc/progress/codec.rs @@ -93,12 +93,15 @@ mod tests { use bytes::Bytes; use super::{encode_progress, encode_response, ReadStep, ResponseReader}; - use crate::{progress::Progress, RouteId}; + use crate::{progress::Progress, Route, RouteId}; struct Watch; - impl Progress for Watch { + impl Route for Watch { const ROUTE: RouteId = RouteId::from_u32(11); + } + + impl Progress for Watch { type Error = core::convert::Infallible; type Request = Vec; type Progress = Vec; diff --git a/ql-rpc/src/rpc/progress/mod.rs b/ql-rpc/src/rpc/progress/mod.rs index 1ee935e0..b21c826d 100644 --- a/ql-rpc/src/rpc/progress/mod.rs +++ b/ql-rpc/src/rpc/progress/mod.rs @@ -1,4 +1,5 @@ -use crate::{RouteId, RpcCodec}; +use super::Route; +use crate::RpcCodec; pub(crate) mod client; pub(crate) mod codec; @@ -14,9 +15,7 @@ pub use server::{ProgressHandler, ProgressHandlerLocal, ProgressResponder}; /// response frames are tagged so the client can distinguish /// [`Self::Progress`] items from the final [`Self::Response`] /// reaching eof before the final response is an error -pub trait Progress { - /// route used to dispatch this rpc family - const ROUTE: RouteId; +pub trait Progress: Route { /// codec error shared by request, progress, and response values type Error; /// typed input sent by the caller diff --git a/ql-rpc/src/rpc/request/mod.rs b/ql-rpc/src/rpc/request/mod.rs index a81ba523..adf32597 100644 --- a/ql-rpc/src/rpc/request/mod.rs +++ b/ql-rpc/src/rpc/request/mod.rs @@ -1,4 +1,5 @@ -use crate::{RouteId, RpcCodec}; +use super::Route; +use crate::RpcCodec; pub(crate) mod client; pub(crate) mod server; @@ -12,9 +13,7 @@ pub use server::{RequestHandler, RequestHandlerLocal, Response}; /// request stream after encoding [`Self::Request`] /// the response is also read to eof and rejects trailing bytes after /// [`Self::Response`] -pub trait Request { - /// route used to dispatch this rpc family - const ROUTE: RouteId; +pub trait Request: Route { /// codec error shared by request and response values type Error; /// typed input sent by the caller diff --git a/ql-rpc/src/rpc/subscription/mod.rs b/ql-rpc/src/rpc/subscription/mod.rs index f66ffa7a..672eb9bc 100644 --- a/ql-rpc/src/rpc/subscription/mod.rs +++ b/ql-rpc/src/rpc/subscription/mod.rs @@ -1,4 +1,5 @@ -use crate::{RouteId, RpcCodec}; +use super::Route; +use crate::RpcCodec; pub(crate) mod client; pub(crate) mod codec; @@ -12,9 +13,7 @@ pub use server::{SubscriptionHandler, SubscriptionHandlerLocal, SubscriptionResp /// /// event frames are length-delimited and the stream ends cleanly at eof /// any partial trailing frame is reported as truncation on the client side -pub trait Subscription { - /// route used to dispatch this rpc family - const ROUTE: RouteId; +pub trait Subscription: Route { /// codec error shared by request and event values type Error; /// typed input that starts the subscription diff --git a/ql-rpc/src/rpc/upload/mod.rs b/ql-rpc/src/rpc/upload/mod.rs index b6d2a7e1..9f96a824 100644 --- a/ql-rpc/src/rpc/upload/mod.rs +++ b/ql-rpc/src/rpc/upload/mod.rs @@ -1,4 +1,5 @@ -use crate::{RouteId, RpcCodec}; +use super::Route; +use crate::RpcCodec; pub(crate) mod client; pub(crate) mod server; @@ -13,9 +14,7 @@ pub use server::{UploadHandler, UploadHandlerLocal, UploadPart, UploadReader, Up /// the request is length-delimited so raw upload bytes can follow immediately /// once the upload reaches eof, the responder returns one typed /// [`Self::Response`] -pub trait Upload { - /// route used to dispatch this rpc family - const ROUTE: RouteId; +pub trait Upload: Route { /// codec error shared by request and response values type Error; /// typed input needed before request body bytes arrive From 415beb73d3a2dd3c791a960f7a67709eee206978 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Mon, 18 May 2026 09:39:06 -0400 Subject: [PATCH 302/304] ql-rpc: download server/client complete method --- ql-rpc/src/rpc/download/client.rs | 29 +++++++++++++--- ql-rpc/src/rpc/download/server.rs | 12 +++++++ ql-runtime/src/rpc/download.rs | 10 +++--- ql-runtime/src/tests/rpc.rs | 56 ++++++++++++++++++++++++++++++- 4 files changed, 97 insertions(+), 10 deletions(-) diff --git a/ql-rpc/src/rpc/download/client.rs b/ql-rpc/src/rpc/download/client.rs index 42793489..9a648181 100644 --- a/ql-rpc/src/rpc/download/client.rs +++ b/ql-rpc/src/rpc/download/client.rs @@ -47,7 +47,7 @@ where } } - pub async fn into_reader( + pub async fn start( mut self, ) -> Result<(M::ResponseHeader, DownloadReader), CallError> { loop { @@ -134,6 +134,29 @@ where } } + pub async fn complete(mut self) -> Result<(), CallError> { + match self.read_frame().await? { + PartReadStep::Finish => { + self.stream.take(); + Ok(()) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + async fn read_frame( &mut self, ) -> Result, CallError> { @@ -155,10 +178,6 @@ where } } - pub fn close(mut self, code: StreamCloseCode) { - self.close_inner(code); - } - fn close_inner(&mut self, code: StreamCloseCode) { if let Some(stream) = self.stream.take() { stream.close(code); diff --git a/ql-rpc/src/rpc/download/server.rs b/ql-rpc/src/rpc/download/server.rs index f9b3d7d0..fcdcb047 100644 --- a/ql-rpc/src/rpc/download/server.rs +++ b/ql-rpc/src/rpc/download/server.rs @@ -63,6 +63,7 @@ where } } + /// send the response header and begin streaming parts pub async fn start( mut self, response_header: M::ResponseHeader, @@ -77,6 +78,17 @@ where }) } + /// send a header-only response and finish the stream + pub async fn complete(mut self, response_header: M::ResponseHeader) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + codec::encode_value_part(&response_header, &mut encoded); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await + } + + /// close the stream with a transport code pub fn close(mut self, code: StreamCloseCode) { if let Some(writer) = self.writer.take() { writer.close(code); diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs index 0b723db2..d3b63585 100644 --- a/ql-runtime/src/rpc/download.rs +++ b/ql-runtime/src/rpc/download.rs @@ -20,10 +20,8 @@ impl DownloadCall where M: DownloadRpc, { - pub async fn into_reader( - self, - ) -> Result<(M::ResponseHeader, DownloadReader), RpcError> { - let (header, inner) = self.inner.into_reader().await?; + pub async fn start(self) -> Result<(M::ResponseHeader, DownloadReader), RpcError> { + let (header, inner) = self.inner.start().await?; Ok((header, DownloadReader { inner })) } @@ -46,6 +44,10 @@ where .map(|(header, inner)| (header, DownloadPart { inner }))) } + pub async fn complete(self) -> Result<(), RpcError> { + self.inner.complete().await.map_err(RpcError::from) + } + pub fn close(self, code: ql_wire::StreamCloseCode) { self.inner.close(ql_rpc::StreamCloseCode(code.0)); } diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs index b8863f31..d6147c30 100644 --- a/ql-runtime/src/tests/rpc.rs +++ b/ql-runtime/src/tests/rpc.rs @@ -418,7 +418,7 @@ async fn rpc_download() { .download::(&b"logo".to_vec()) .await .unwrap(); - let (header, mut reader) = download.into_reader().await.unwrap(); + let (header, mut reader) = download.start().await.unwrap(); assert_eq!(header, b"image/png".to_vec()); { let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); @@ -453,6 +453,60 @@ async fn rpc_download() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn rpc_download_complete() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl DownloadHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + download: DownloadStart, + ) { + self.seen.borrow_mut().push(request); + download.complete(b"not found".to_vec()).await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .download::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let download = rpc + .download::(&b"logo".to_vec()) + .await + .unwrap(); + let (header, reader) = download.start().await.unwrap(); + assert_eq!(header, b"not found".to_vec()); + reader.complete().await.unwrap(); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn rpc_upload() { #[derive(Clone)] From 468ecf110ca2f6a1340517acda4c30117747aff9 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Tue, 26 May 2026 06:16:07 -0400 Subject: [PATCH 303/304] ql: xid to qid --- QL_V2.md | 8 +-- ql-fsm/src/error.rs | 4 +- ql-fsm/src/handshake/ik.rs | 10 ++-- ql-fsm/src/handshake/kk.rs | 6 +-- ql-fsm/src/handshake/mod.rs | 6 +-- ql-fsm/src/handshake/xx.rs | 16 +++--- ql-fsm/src/pairing.rs | 10 ++-- ql-fsm/src/session/stream_parity.rs | 4 +- ql-fsm/src/session/tests.rs | 6 +-- ql-fsm/src/tests/handshake.rs | 6 +-- ql-fsm/src/tests/mod.rs | 22 ++++---- ql-fsm/src/tests/session.rs | 4 +- ql-runtime/src/driver/mod.rs | 2 +- ql-runtime/src/driver/test.rs | 8 +-- ql-runtime/src/platform.rs | 4 +- ql-runtime/src/tests/handshake.rs | 18 +++---- ql-runtime/src/tests/mod.rs | 24 ++++----- ql-runtime/src/tests/session.rs | 6 +-- ql-runtime/src/tests/stream.rs | 8 +-- ql-wire/src/handshake/ik.rs | 10 ++-- ql-wire/src/handshake/kk.rs | 8 +-- ql-wire/src/handshake/mod.rs | 14 +++-- ql-wire/src/handshake/xx.rs | 24 ++++----- ql-wire/src/identity.rs | 37 +++++++------ ql-wire/src/lib.rs | 4 +- ql-wire/src/qid.rs | 44 +++++++++++++++ ql-wire/src/testing.rs | 11 ++-- ql-wire/src/tests.rs | 84 +++++++++++++++++++++-------- ql-wire/src/xid.rs | 25 --------- 29 files changed, 250 insertions(+), 183 deletions(-) create mode 100644 ql-wire/src/qid.rs delete mode 100644 ql-wire/src/xid.rs diff --git a/QL_V2.md b/QL_V2.md index 32adf645..0062c7e6 100644 --- a/QL_V2.md +++ b/QL_V2.md @@ -25,8 +25,8 @@ QLv2 is not: ## Core terms - `peer`: one QLv2 endpoint -- `XID`: a stable 16-byte peer identifier -- `peer bundle`: public peer information: `version`, `xid`, `capabilities`, and ML-KEM public key +- `QID`: a stable 16-byte peer identifier +- `peer bundle`: public peer information: `version`, `qid`, `capabilities`, and ML-KEM public key - `pairing token`: an out-of-band secret that authorizes an `XX` pairing attempt - `pairing_id`: the visible identifier derived from a pairing token and carried on `XX` records - `session`: one live encrypted channel with directional keys and directional connection IDs @@ -71,7 +71,7 @@ Today, varints are used for: QLv2 has two routed known-peer handshakes and one pairing handshake: -- `IK` and `KK` carry a visible `sender` and `recipient` XID +- `IK` and `KK` carry a visible `sender` and `recipient` QID - `XX` carries a visible `pairing_id` #### IK @@ -296,7 +296,7 @@ A stream has two independent lanes: Important properties: - either peer can open a stream -- stream IDs are split by parity derived from XID ordering, so both peers can open streams without collision +- stream IDs are split by parity derived from QID ordering, so both peers can open streams without collision - stream IDs increase monotonically within each parity namespace and must not repeat within a session - ordering is preserved within a stream lane - different streams can make progress independently diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs index e38108b6..9bf2a915 100644 --- a/ql-fsm/src/error.rs +++ b/ql-fsm/src/error.rs @@ -17,7 +17,7 @@ pub enum ReceiveError { InvalidKkHandshake(WireError), InvalidXxHandshake(WireError), InvalidRemoteBundle, - InvalidXid, + InvalidQid, NoPeer, NoSession, NotPairingMode, @@ -42,7 +42,7 @@ impl Display for ReceiveError { Self::InvalidKkHandshake(error) => write!(f, "invalid kk handshake: {error}"), Self::InvalidXxHandshake(error) => write!(f, "invalid xx handshake: {error}"), Self::InvalidRemoteBundle => f.write_str("invalid remote bundle"), - Self::InvalidXid => f.write_str("invalid xid"), + Self::InvalidQid => f.write_str("invalid qid"), Self::NoPeer => f.write_str("no bound peer"), Self::NoSession => f.write_str("no active session"), Self::NotPairingMode => f.write_str("not in pairing mode"), diff --git a/ql-fsm/src/handshake/ik.rs b/ql-fsm/src/handshake/ik.rs index 10d210de..7e6ebd1e 100644 --- a/ql-fsm/src/handshake/ik.rs +++ b/ql-fsm/src/handshake/ik.rs @@ -36,12 +36,12 @@ pub fn handle_ik1( if should_ignore_inbound(fsm, message) { return Ok(()); } - if message.header.recipient != fsm.identity.xid { - return Err(ReceiveError::InvalidXid); + if message.header.recipient != fsm.identity.qid { + return Err(ReceiveError::InvalidQid); } if let Some(peer) = fsm.state.peer.as_ref() { - if message.header.sender != peer.xid { - return Err(ReceiveError::InvalidXid); + if message.header.sender != peer.qid { + return Err(ReceiveError::InvalidQid); } } @@ -110,7 +110,7 @@ pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { | LinkState::XxInitiator(_) | LinkState::XxResponder(_) => false, LinkState::IkInitiator(state) => { - if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { + if fsm.state.peer.as_ref().map(|peer| peer.qid) != Some(message.header.sender) { return false; } super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) diff --git a/ql-fsm/src/handshake/kk.rs b/ql-fsm/src/handshake/kk.rs index 140a9c5e..e78c8a6d 100644 --- a/ql-fsm/src/handshake/kk.rs +++ b/ql-fsm/src/handshake/kk.rs @@ -40,8 +40,8 @@ pub fn handle_kk1( let Some(peer) = fsm.state.peer.clone() else { return Err(ReceiveError::NoPeer); }; - if message.header.recipient != fsm.identity.xid || message.header.sender != peer.xid { - return Err(ReceiveError::InvalidXid); + if message.header.recipient != fsm.identity.qid || message.header.sender != peer.qid { + return Err(ReceiveError::InvalidQid); } reset_connected_session_if_needed(fsm); @@ -109,7 +109,7 @@ pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { | LinkState::XxResponder(_) => false, LinkState::IkInitiator(_) => true, LinkState::KkInitiator(state) => { - if fsm.state.peer.as_ref().map(|peer| peer.xid) != Some(message.header.sender) { + if fsm.state.peer.as_ref().map(|peer| peer.qid) != Some(message.header.sender) { return false; } super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs index e2d2a5a2..1881f66e 100644 --- a/ql-fsm/src/handshake/mod.rs +++ b/ql-fsm/src/handshake/mod.rs @@ -27,7 +27,7 @@ pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), pub fn handle_connect_xx(fsm: &mut QlFsm, invite: crate::PairingInvite, crypto: &impl QlCrypto) { prepare_for_outbound_connect(fsm); - xx::start_initiator(fsm, crypto, invite.token, invite.xid); + xx::start_initiator(fsm, crypto, invite.token, invite.qid); } pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { @@ -95,7 +95,7 @@ pub fn finish_handshake( transport: SessionTransport, remote_bundle: wire::PeerBundle, ) -> Result<(), ReceiveError> { - let xid = remote_bundle.xid; + let qid = remote_bundle.qid; if let Some(peer) = fsm.state.peer.as_ref() { if peer != &remote_bundle { return Err(ReceiveError::InvalidRemoteBundle); @@ -108,7 +108,7 @@ pub fn finish_handshake( let config = &fsm.config; let session = SessionFsm::new( SessionConfig { - local_parity: StreamParity::for_local(fsm.identity.xid, xid), + local_parity: StreamParity::for_local(fsm.identity.qid, qid), record_max_size: config.session_record_max_size, ack_delay: config.session_record_ack_delay, retransmit_timeout: config.session_record_retransmit_timeout, diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs index 79653b46..c9a289e0 100644 --- a/ql-fsm/src/handshake/xx.rs +++ b/ql-fsm/src/handshake/xx.rs @@ -1,4 +1,4 @@ -use ql_wire::{self as wire, PairingToken, QlCrypto, QlHandshakeRecord, Xx1, Xx2, Xx3, Xx4, XID}; +use ql_wire::{self as wire, PairingToken, QlCrypto, QlHandshakeRecord, Xx1, Xx2, Xx3, Xx4, QID}; use super::{ emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, @@ -12,13 +12,13 @@ pub fn start_initiator( fsm: &mut QlFsm, crypto: &impl QlCrypto, token: PairingToken, - remote_xid: XID, + remote_qid: QID, ) { let meta = super::next_handshake_meta(fsm); let mut handshake = wire::XxHandshake::new_initiator( crypto, fsm.identity.clone(), - remote_xid, + remote_qid, token, super::local_transport_params(fsm), ); @@ -50,10 +50,10 @@ pub fn handle_xx1( }) } Some(_) - if message.header.recipient != fsm.identity.xid - || message.header.sender == fsm.identity.xid => + if message.header.recipient != fsm.identity.qid + || message.header.sender == fsm.identity.qid => { - Err(ReceiveError::InvalidXid) + Err(ReceiveError::InvalidQid) } Some(token) => { reset_connected_session_if_needed(fsm); @@ -196,8 +196,8 @@ pub fn should_ignore_inbound(fsm: &QlFsm, crypto: &impl QlCrypto, message: &Xx1) if state.handshake.pairing_id(crypto) != message.pairing_id { return false; } - if message.header.recipient != fsm.identity.xid - || message.header.sender != state.handshake.remote_xid() + if message.header.recipient != fsm.identity.qid + || message.header.sender != state.handshake.remote_qid() { return false; } diff --git a/ql-fsm/src/pairing.rs b/ql-fsm/src/pairing.rs index fef19954..4b8361b8 100644 --- a/ql-fsm/src/pairing.rs +++ b/ql-fsm/src/pairing.rs @@ -1,15 +1,15 @@ -use ql_wire::{ByteSlice, PairingToken, Reader, WireDecode, WireEncode, WireError, XID}; +use ql_wire::{ByteSlice, PairingToken, Reader, WireDecode, WireEncode, WireError, QID}; /// Out-of-band invite consumed by the initiator of an XX pairing #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct PairingInvite { - pub xid: XID, + pub qid: QID, pub token: PairingToken, } impl PairingInvite { pub const VERSION: u8 = 1; - pub const WIRE_SIZE: usize = size_of::() + XID::SIZE + PairingToken::SIZE; + pub const WIRE_SIZE: usize = size_of::() + QID::SIZE + PairingToken::SIZE; } impl WireEncode for PairingInvite { @@ -19,7 +19,7 @@ impl WireEncode for PairingInvite { fn encode(&self, out: &mut W) { Self::VERSION.encode(out); - self.xid.encode(out); + self.qid.encode(out); self.token.encode(out); } } @@ -31,7 +31,7 @@ impl WireDecode for PairingInvite { } Ok(Self { - xid: reader.decode()?, + qid: reader.decode()?, token: reader.decode()?, }) } diff --git a/ql-fsm/src/session/stream_parity.rs b/ql-fsm/src/session/stream_parity.rs index 87c9ef33..70f60776 100644 --- a/ql-fsm/src/session/stream_parity.rs +++ b/ql-fsm/src/session/stream_parity.rs @@ -1,4 +1,4 @@ -use ql_wire::{StreamId, XID}; +use ql_wire::{StreamId, QID}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamParity { @@ -7,7 +7,7 @@ pub enum StreamParity { } impl StreamParity { - pub fn for_local(local: XID, peer: XID) -> Self { + pub fn for_local(local: QID, peer: QID) -> Self { match local.0.cmp(&peer.0) { std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Even, std::cmp::Ordering::Greater => Self::Odd, diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs index 2226753f..f1f29879 100644 --- a/ql-fsm/src/session/tests.rs +++ b/ql-fsm/src/session/tests.rs @@ -4,7 +4,7 @@ use bytes::Bytes; use ql_wire::{ decode_session_frames, parse_session_frames, CloseTarget, RecordAck, RecordSeq, RouteId, SessionFrame, SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamHeader, - StreamId, VarInt, XID, + StreamId, VarInt, QID, }; use super::{SessionConfig, SessionEvent, SessionFsm}; @@ -429,8 +429,8 @@ fn remote_stream_close_is_reliable_and_retried() { #[test] fn stream_ids_follow_even_odd_xid_ordering() { let now = Instant::now(); - let even = StreamParity::for_local(XID([1; XID::SIZE]), XID([2; XID::SIZE])); - let odd = StreamParity::for_local(XID([2; XID::SIZE]), XID([1; XID::SIZE])); + let even = StreamParity::for_local(QID([1; QID::SIZE]), QID([2; QID::SIZE])); + let odd = StreamParity::for_local(QID([2; QID::SIZE]), QID([1; QID::SIZE])); let even_id = SessionFsm::new( SessionConfig { diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index 301b9680..eae4f4d3 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -96,7 +96,7 @@ fn ik_connect_learns_remote_initial_stream_receive_window() { #[test] fn connect_methods_require_bound_peer() { let time = Harness::paired_known(QlFsmConfig::default()).time(); - let identity = test_identity(&SoftwareCrypto); + let identity = generate_identity(&SoftwareCrypto); let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); let crypto = SoftwareCrypto; @@ -106,7 +106,7 @@ fn connect_methods_require_bound_peer() { fsm.connect_xx( time, PairingInvite { - xid: ql_wire::XID([2; ql_wire::XID::SIZE]), + qid: ql_wire::QID([2; ql_wire::QID::SIZE]), token: pairing_token(2), }, &crypto, @@ -343,7 +343,7 @@ fn bind_peer_clears_queued_handshake_output() { harness .a .fsm - .bind_peer(test_identity(&SoftwareCrypto).bundle()); + .bind_peer(generate_identity(&SoftwareCrypto).bundle()); assert!(harness.drain_events(Side::A).is_empty()); assert!(harness.next_outbound(Side::A).is_none()); diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs index 703070cf..c4d005e4 100644 --- a/ql-fsm/src/tests/mod.rs +++ b/ql-fsm/src/tests/mod.rs @@ -5,8 +5,8 @@ mod session; use std::time::{Duration, Instant}; use ql_wire::{ - self, test_identities, test_identity, ConnectionId, PairingToken, QlCrypto, SessionKey, - SoftwareCrypto, TransportParams, XID, + self, generate_identity, test_identities, ConnectionId, PairingToken, QlCrypto, SessionKey, + SoftwareCrypto, TransportParams, QID, }; use crate::{ @@ -199,22 +199,22 @@ impl Harness { fn connect_xx(&mut self, side: Side, token: PairingToken) { let time = self.time(); - let remote_xid = self.remote_xid(side); + let remote_qid = self.remote_qid(side); let Node { fsm, crypto } = self.node_mut(side); fsm.connect_xx( time, PairingInvite { - xid: remote_xid, + qid: remote_qid, token, }, crypto, ); } - fn remote_xid(&self, side: Side) -> XID { + fn remote_qid(&self, side: Side) -> QID { match side { - Side::A => self.b.fsm.identity.xid, - Side::B => self.a.fsm.identity.xid, + Side::A => self.b.fsm.identity.qid, + Side::B => self.a.fsm.identity.qid, } } @@ -299,14 +299,14 @@ fn pairing_token(byte: u8) -> PairingToken { fn session_config(harness: &Harness, a: bool) -> SessionConfig { let (local, peer, config) = if a { ( - harness.a.fsm.identity.xid, - harness.a.fsm.state.peer.as_ref().unwrap().xid, + harness.a.fsm.identity.qid, + harness.a.fsm.state.peer.as_ref().unwrap().qid, harness.a.fsm.config, ) } else { ( - harness.b.fsm.identity.xid, - harness.b.fsm.state.peer.as_ref().unwrap().xid, + harness.b.fsm.identity.qid, + harness.b.fsm.state.peer.as_ref().unwrap().qid, harness.b.fsm.config, ) }; diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs index edd2cf44..c55e51c1 100644 --- a/ql-fsm/src/tests/session.rs +++ b/ql-fsm/src/tests/session.rs @@ -150,11 +150,11 @@ fn simultaneous_opens_use_even_and_odd_stream_ids() { assert_ne!(stream_id_a, stream_id_b); assert!( - StreamParity::for_local(harness.a.fsm.identity.xid, harness.b.fsm.identity.xid) + StreamParity::for_local(harness.a.fsm.identity.qid, harness.b.fsm.identity.qid) .matches(stream_id_a) ); assert!( - StreamParity::for_local(harness.b.fsm.identity.xid, harness.a.fsm.identity.xid) + StreamParity::for_local(harness.b.fsm.identity.qid, harness.a.fsm.identity.qid) .matches(stream_id_b) ); diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs index d68bfce2..35de1bf0 100644 --- a/ql-runtime/src/driver/mod.rs +++ b/ql-runtime/src/driver/mod.rs @@ -305,7 +305,7 @@ impl DriverState { } } Event::PeerStatusChanged(status) => { - let peer = fsm.peer().map(|peer| peer.xid); + let peer = fsm.peer().map(|peer| peer.qid); log::info!("peer status changed: peer={peer:?} status={status:?}"); if status == ql_fsm::PeerStatus::Unpaired { for (_, mut stream) in self.streams.drain() { diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index 9b325c08..db0116ea 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -1,4 +1,4 @@ -use ql_wire::{test_identity, NoopCrypto, PeerBundle, SoftwareCrypto, StreamClose, XID}; +use ql_wire::{generate_identity, NoopCrypto, PeerBundle, SoftwareCrypto, StreamClose, QID}; use super::*; use crate::{ @@ -37,7 +37,7 @@ impl QlPlatform for NoopCrypto { fn persist_peer(&self, _peer: PeerBundle) {} - fn handle_peer_status(&self, _peer: Option, _status: ql_fsm::PeerStatus) {} + fn handle_peer_status(&self, _peer: Option, _status: ql_fsm::PeerStatus) {} fn handle_inbound(&self, _event: QlStream) {} } @@ -58,7 +58,7 @@ fn new_driver_state() -> (DriverState, QlFsm) { }, QlFsm::new( ql_fsm::QlFsmConfig::default(), - test_identity(&SoftwareCrypto), + generate_identity(&SoftwareCrypto), Instant::now(), ), ) @@ -180,7 +180,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { #[test] fn unpaired_status_fails_and_reaps_all_streams() { let (mut state, mut fsm) = new_driver_state(); - let peer = test_identity(&SoftwareCrypto).bundle(); + let peer = generate_identity(&SoftwareCrypto).bundle(); let stream_id = StreamId(1u32.into()); let (runtime_tx, _runtime_rx) = async_channel::unbounded(); let (_, _, reader_io, writer_io) = io::new_stream( diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs index a0020789..331bfe7a 100644 --- a/ql-runtime/src/platform.rs +++ b/ql-runtime/src/platform.rs @@ -6,7 +6,7 @@ use std::{ }; use ql_fsm::{PeerStatus, ReceiveError}; -use ql_wire::{PeerBundle, QlCrypto, XID}; +use ql_wire::{PeerBundle, QlCrypto, QID}; use crate::QlStream; @@ -37,7 +37,7 @@ pub trait QlPlatform: QlCrypto { fn persist_peer(&self, peer: PeerBundle); - fn handle_peer_status(&self, peer: Option, status: PeerStatus); + fn handle_peer_status(&self, peer: Option, status: PeerStatus); fn handle_inbound(&self, event: QlStream); fn handle_recv_error(&self, _error: ReceiveError) {} } diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs index 416ec842..65731bbc 100644 --- a/ql-runtime/src/tests/handshake.rs +++ b/ql-runtime/src/tests/handshake.rs @@ -48,7 +48,7 @@ async fn handshake_timeout_disconnects() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, Some(identity_b.xid), PeerStatus::Disconnected).await; + await_status(&status_a, Some(identity_b.qid), PeerStatus::Disconnected).await; }) .await; } @@ -75,8 +75,8 @@ async fn rejected_session_write_is_reissued() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; - await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -104,7 +104,7 @@ async fn rejected_session_write_is_reissued() { assert_no_status_for( &status_a, - Some(identity_b.xid), + Some(identity_b.qid), PeerStatus::Disconnected, Duration::from_millis(150), ) @@ -133,12 +133,12 @@ async fn start_pairing_round_trip_connects_when_armed() { handle_b.arm_pairing(token); handle_a.start_pairing(PairingInvite { - xid: identity_b.xid, + qid: identity_b.qid, token, }); - await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; - await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; }) .await; } @@ -162,13 +162,13 @@ async fn start_pairing_does_not_connect_when_unarmed() { spawn_forwarder(outbound_b, inbound_a_tx); handle_a.start_pairing(PairingInvite { - xid: identity_b.xid, + qid: identity_b.qid, token, }); assert_no_status_for( &status_a, - Some(identity_b.xid), + Some(identity_b.qid), PeerStatus::Connected, Duration::from_millis(150), ) diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index 71083724..e22e5041 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -13,9 +13,9 @@ use async_channel::{Receiver, Sender}; use futures_lite::Stream; use ql_fsm::PeerStatus; use ql_wire::{ - test_identities, test_identity, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, - Nonce, PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, RecordHeader, - RecordType, RouteId, SessionKey, SoftwareCrypto, WireDecode, XID, + generate_identity, test_identities, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, + MlKemPublicKey, Nonce, PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, + RecordHeader, RecordType, RouteId, SessionKey, SoftwareCrypto, WireDecode, QID, }; use tokio::{task::LocalSet, time::Sleep}; @@ -43,7 +43,7 @@ fn init_test_logger() { #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct StatusEvent { - peer: Option, + peer: Option, status: PeerStatus, } @@ -175,7 +175,7 @@ impl TestPlatform { struct TestSide { handle: RuntimeHandle, status: Receiver, - peer: XID, + peer: QID, inbound: Receiver, } @@ -259,13 +259,13 @@ impl TestPair { a: TestSide { handle: handle_a, status: status_a, - peer: identity_a.xid, + peer: identity_a.qid, inbound: inbound_a, }, b: TestSide { handle: handle_b, status: status_b, - peer: identity_b.xid, + peer: identity_b.qid, inbound: inbound_b, }, }, @@ -439,7 +439,7 @@ impl crate::platform::QlPlatform for TestPlatform { fn persist_peer(&self, _peer: PeerBundle) {} - fn handle_peer_status(&self, peer: Option, status: PeerStatus) { + fn handle_peer_status(&self, peer: Option, status: PeerStatus) { let _ = self.status.try_send(StatusEvent { peer, status }); } @@ -608,7 +608,7 @@ where .unwrap_or_else(|_| panic!("local runtime test exceeded {duration:?}")); } -async fn await_status(receiver: &Receiver, peer: Option, stage: PeerStatus) { +async fn await_status(receiver: &Receiver, peer: Option, stage: PeerStatus) { tokio::time::timeout(Duration::from_secs(2), async { loop { if let Ok(event) = receiver.recv().await { @@ -624,7 +624,7 @@ async fn await_status(receiver: &Receiver, peer: Option, stage async fn assert_no_status_for( receiver: &Receiver, - peer: Option, + peer: Option, status: PeerStatus, window: Duration, ) { @@ -679,7 +679,7 @@ fn default_runtime_config() -> RuntimeConfig { #[test] fn runtime_is_send() { let config = default_runtime_config(); - let identity = test_identity(&SoftwareCrypto); + let identity = generate_identity(&SoftwareCrypto); let (platform, _, _, _) = TestPlatform::new(); let (runtime, _handle) = new_runtime(identity, platform, config); let _run: Box + Send> = Box::new(runtime.run()); @@ -688,7 +688,7 @@ fn runtime_is_send() { #[test] fn runtime_exits_when_last_handle_drops() { let config = default_runtime_config(); - let identity = test_identity(&SoftwareCrypto); + let identity = generate_identity(&SoftwareCrypto); let (platform, _, _, _) = TestPlatform::new(); let (runtime, handle) = new_runtime(identity, platform, config); let (done_tx, done_rx) = oneshot::channel(); diff --git a/ql-runtime/src/tests/session.rs b/ql-runtime/src/tests/session.rs index 6b185d24..ec351e35 100644 --- a/ql-runtime/src/tests/session.rs +++ b/ql-runtime/src/tests/session.rs @@ -183,8 +183,8 @@ async fn session_timeout_disconnects_and_fails_pending_open() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; - await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; let responder_task = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); @@ -199,7 +199,7 @@ async fn session_timeout_disconnects_and_fails_pending_open() { let err = pending.writer.finish().await.unwrap_err(); assert!(matches!(err, QlStreamError::NoSession)); - await_status(&status_a, Some(identity_b.xid), PeerStatus::Disconnected).await; + await_status(&status_a, Some(identity_b.qid), PeerStatus::Disconnected).await; let result = tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs index 008a3c6c..176711c8 100644 --- a/ql-runtime/src/tests/stream.rs +++ b/ql-runtime/src/tests/stream.rs @@ -306,8 +306,8 @@ async fn max_concurrent_message_writes_is_respected() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; - await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { for _ in 0..4 { @@ -381,8 +381,8 @@ async fn stream_round_trip_survives_encrypted_packet_drops() { register_peers(&handle_a, &handle_b, &identity_a, &identity_b); handle_a.connect(); - await_status(&status_a, Some(identity_b.xid), PeerStatus::Connected).await; - await_status(&status_b, Some(identity_a.xid), PeerStatus::Connected).await; + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; let responder = tokio::task::spawn_local(async move { let stream = inbound_b.recv().await.unwrap(); diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 03bdc032..26fa5621 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -173,13 +173,13 @@ impl IkHandshake { fn outbound_header(&self) -> Result { let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; Ok(HandshakeHeader { - sender: self.local.xid, - recipient: remote_bundle.xid, + sender: self.local.qid, + recipient: remote_bundle.qid, }) } fn ensure_inbound_recipient(&self, header: HandshakeHeader) -> Result<(), WireError> { - if header.recipient == self.local.xid { + if header.recipient == self.local.qid { Ok(()) } else { Err(WireError::InvalidPayload) @@ -188,7 +188,7 @@ impl IkHandshake { fn ensure_known_remote_sender(&self, header: HandshakeHeader) -> Result<(), WireError> { if let Some(remote_bundle) = self.remote_bundle.as_ref() { - if header.sender != remote_bundle.xid { + if header.sender != remote_bundle.qid { return Err(WireError::InvalidPayload); } } @@ -310,7 +310,7 @@ impl IkHandshake { let remote_bundle = decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; - if remote_bundle.xid != message.header.sender { + if remote_bundle.qid != message.header.sender { return Err(WireError::InvalidPayload); } match self.remote_bundle.as_ref() { diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs index a56fca83..2ad5ee2a 100644 --- a/ql-wire/src/handshake/kk.rs +++ b/ql-wire/src/handshake/kk.rs @@ -167,15 +167,15 @@ impl KkHandshake { fn outbound_header(&self) -> HandshakeHeader { HandshakeHeader { - sender: self.local.xid, - recipient: self.remote_bundle.xid, + sender: self.local.qid, + recipient: self.remote_bundle.qid, } } fn inbound_header(&self) -> HandshakeHeader { HandshakeHeader { - sender: self.remote_bundle.xid, - recipient: self.local.xid, + sender: self.remote_bundle.qid, + recipient: self.local.qid, } } diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index fcbe02c4..fcd5ccaf 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -1,7 +1,7 @@ use crate::{ codec, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, Nonce, PeerBundle, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, - ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + ENCRYPTED_MESSAGE_AUTH_SIZE, QID, }; mod ik; @@ -27,12 +27,12 @@ const HANDSHAKE_PREAMBLE_DOMAIN: &[u8] = b"ql-wire:handshake-preamble:v1"; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct HandshakeHeader { - pub sender: XID, - pub recipient: XID, + pub sender: QID, + pub recipient: QID, } impl HandshakeHeader { - pub const WIRE_SIZE: usize = XID::SIZE * 2; + pub const WIRE_SIZE: usize = QID::SIZE * 2; } impl WireEncode for HandshakeHeader { @@ -473,7 +473,11 @@ fn decrypt_peer_bundle( bundle: &EncryptedPeerBundle, ) -> Result { let plaintext = symmetric.decrypt_and_hash(crypto, bundle.as_bytes())?; - PeerBundle::decode_exact(plaintext.as_slice()) + let bundle = PeerBundle::decode_exact(plaintext.as_slice())?; + if !bundle.qid_matches_public_key(crypto) { + return Err(WireError::InvalidRemoteBundle); + } + Ok(bundle) } fn encrypt_mlkem_ciphertext( diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index fb35e861..b688d757 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -8,7 +8,7 @@ use super::{ }; use crate::{ codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PairingId, PairingToken, - PeerBundle, QlCrypto, QlIdentity, WireEncode, WireError, XID, + PeerBundle, QlCrypto, QlIdentity, WireEncode, WireError, QID, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -210,7 +210,7 @@ pub struct XxHandshake { step: XxStep, symmetric: SymmetricState, local: QlIdentity, - remote_xid: XID, + remote_qid: QID, pairing_token: PairingToken, remote_bundle: Option, local_ephemeral: Option, @@ -224,7 +224,7 @@ impl XxHandshake { pub fn new_initiator( crypto: &impl QlCrypto, local: QlIdentity, - remote_xid: XID, + remote_qid: QID, pairing_token: PairingToken, local_transport_params: TransportParams, ) -> Self { @@ -233,7 +233,7 @@ impl XxHandshake { step: XxStep::Send1, symmetric: init_xx_symmetric(crypto), local, - remote_xid, + remote_qid, pairing_token, remote_bundle: None, local_ephemeral: None, @@ -247,7 +247,7 @@ impl XxHandshake { pub fn new_responder( crypto: &impl QlCrypto, local: QlIdentity, - remote_xid: XID, + remote_qid: QID, pairing_token: PairingToken, local_transport_params: TransportParams, ) -> Self { @@ -256,7 +256,7 @@ impl XxHandshake { step: XxStep::Recv1, symmetric: init_xx_symmetric(crypto), local, - remote_xid, + remote_qid, pairing_token, remote_bundle: None, local_ephemeral: None, @@ -279,8 +279,8 @@ impl XxHandshake { self.pairing_token.id(crypto) } - pub fn remote_xid(&self) -> XID { - self.remote_xid + pub fn remote_qid(&self) -> QID { + self.remote_qid } pub fn remote_bundle(&self) -> Option<&PeerBundle> { @@ -289,8 +289,8 @@ impl XxHandshake { fn header(&self) -> HandshakeHeader { HandshakeHeader { - sender: self.local.xid, - recipient: self.remote_xid, + sender: self.local.qid, + recipient: self.remote_qid, } } @@ -300,7 +300,7 @@ impl XxHandshake { header: HandshakeHeader, pairing_id: PairingId, ) -> Result<(), WireError> { - if header.sender != self.remote_xid || header.recipient != self.local.xid { + if header.sender != self.remote_qid || header.recipient != self.local.qid { return Err(WireError::InvalidHandshakeHeader); } if pairing_id != self.pairing_token.id(crypto) { @@ -310,7 +310,7 @@ impl XxHandshake { } fn ensure_remote_bundle(&self, bundle: &PeerBundle) -> Result<(), WireError> { - if bundle.xid == self.remote_xid { + if bundle.qid == self.remote_qid { Ok(()) } else { Err(WireError::InvalidRemoteBundle) diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 1de12a01..3a358401 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -1,20 +1,26 @@ use crate::{ - codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, WireEncode, - WireError, XID, + codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, QlHash, WireEncode, + WireError, QID, }; #[derive(Debug, Clone, PartialEq, Eq)] pub struct PeerBundle { pub version: u16, - pub xid: XID, + pub qid: QID, pub capabilities: u32, pub mlkem_public_key: MlKemPublicKey, + // todo: add + // pub name: String } impl PeerBundle { pub const VERSION: u16 = 1; pub const WIRE_SIZE: usize = - size_of::() + XID::SIZE + size_of::() + MlKemPublicKey::SIZE; + size_of::() + QID::SIZE + size_of::() + MlKemPublicKey::SIZE; + + pub fn qid_matches_public_key(&self, crypto: &impl QlHash) -> bool { + self.qid.matches_public_key(crypto, &self.mlkem_public_key) + } } impl WireEncode for PeerBundle { @@ -24,7 +30,7 @@ impl WireEncode for PeerBundle { fn encode(&self, out: &mut W) { self.version.encode(out); - self.xid.encode(out); + self.qid.encode(out); self.capabilities.encode(out); self.mlkem_public_key.encode(out); } @@ -34,7 +40,7 @@ impl codec::WireDecode for PeerBundle { fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { version: reader.decode()?, - xid: reader.decode()?, + qid: reader.decode()?, capabilities: reader.decode()?, mlkem_public_key: reader.decode()?, }) @@ -43,7 +49,7 @@ impl codec::WireDecode for PeerBundle { #[derive(Debug, Clone)] pub struct QlIdentity { - pub xid: XID, + pub qid: QID, pub mlkem_private_key: MlKemPrivateKey, pub mlkem_public_key: MlKemPublicKey, pub capabilities: u32, @@ -51,15 +57,16 @@ pub struct QlIdentity { impl QlIdentity { pub const WIRE_SIZE: usize = - XID::SIZE + MlKemPrivateKey::SIZE + MlKemPublicKey::SIZE + size_of::(); + QID::SIZE + MlKemPrivateKey::SIZE + MlKemPublicKey::SIZE + size_of::(); pub fn new( - xid: XID, + crypto: &impl QlHash, mlkem_private_key: MlKemPrivateKey, mlkem_public_key: MlKemPublicKey, ) -> Self { + let qid = QID::derive(crypto, &mlkem_public_key); Self { - xid, + qid, mlkem_private_key, mlkem_public_key, capabilities: 0, @@ -75,7 +82,7 @@ impl QlIdentity { pub fn bundle(&self) -> PeerBundle { PeerBundle { version: PeerBundle::VERSION, - xid: self.xid, + qid: self.qid, capabilities: self.capabilities, mlkem_public_key: self.mlkem_public_key.clone(), } @@ -88,7 +95,7 @@ impl WireEncode for QlIdentity { } fn encode(&self, out: &mut W) { - self.xid.encode(out); + self.qid.encode(out); self.mlkem_private_key.as_bytes().encode(out); self.mlkem_public_key.encode(out); self.capabilities.encode(out); @@ -98,7 +105,7 @@ impl WireEncode for QlIdentity { impl codec::WireDecode for QlIdentity { fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { - xid: reader.decode()?, + qid: reader.decode()?, mlkem_private_key: MlKemPrivateKey::new(reader.decode()?), mlkem_public_key: reader.decode()?, capabilities: reader.decode()?, @@ -106,10 +113,10 @@ impl codec::WireDecode for QlIdentity { } } -pub fn generate_identity(crypto: &impl QlCrypto, xid: XID) -> QlIdentity { +pub fn generate_identity(crypto: &impl QlCrypto) -> QlIdentity { let MlKemKeyPair { private: mlkem_private_key, public: mlkem_public_key, } = crypto.mlkem_generate_keypair(); - QlIdentity::new(xid, mlkem_private_key, mlkem_public_key) + QlIdentity::new(crypto, mlkem_private_key, mlkem_public_key) } diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs index 63b5e633..2ad8a171 100644 --- a/ql-wire/src/lib.rs +++ b/ql-wire/src/lib.rs @@ -15,11 +15,11 @@ mod header; mod identity; mod nonce; mod pq; +mod qid; mod record; #[cfg(any(feature = "test-utils", test))] mod testing; mod varint; -mod xid; pub use bytes::*; pub use codec::*; @@ -32,11 +32,11 @@ pub use header::*; pub use identity::*; pub use nonce::*; pub use pq::*; +pub use qid::*; pub use record::*; #[cfg(any(feature = "test-utils", test))] pub use testing::*; pub use varint::*; -pub use xid::*; pub const QL_WIRE_VERSION: u8 = 1; pub const ENCRYPTED_MESSAGE_AUTH_SIZE: usize = 16; diff --git a/ql-wire/src/qid.rs b/ql-wire/src/qid.rs new file mode 100644 index 00000000..55c6684f --- /dev/null +++ b/ql-wire/src/qid.rs @@ -0,0 +1,44 @@ +use crate::{codec, ByteSlice, MlKemPublicKey, QlHash, WireEncode, WireError, ML_KEM_SUITE_TAG}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct QID(pub [u8; Self::SIZE]); + +impl QID { + pub const SIZE: usize = 16; + + pub fn derive(crypto: &impl QlHash, mlkem_public_key: &MlKemPublicKey) -> Self { + let digest = crypto.sha256(&[ + b"quantum-link qid v1", + ML_KEM_SUITE_TAG, + mlkem_public_key.as_bytes(), + ]); + let mut qid = [0u8; Self::SIZE]; + qid.copy_from_slice(&digest[..Self::SIZE]); + Self(qid) + } + + pub fn matches_public_key( + &self, + crypto: &impl QlHash, + mlkem_public_key: &MlKemPublicKey, + ) -> bool { + *self == Self::derive(crypto, mlkem_public_key) + } +} + +impl WireEncode for QID { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for QID { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} diff --git a/ql-wire/src/testing.rs b/ql-wire/src/testing.rs index 83b4fbde..11b251c4 100644 --- a/ql-wire/src/testing.rs +++ b/ql-wire/src/testing.rs @@ -4,7 +4,7 @@ use sha2::{Digest, Sha256}; use crate::{ MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, - QlHash, QlIdentity, QlKem, QlRandom, SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, XID, + QlHash, QlIdentity, QlKem, QlRandom, SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, }; #[derive(Debug, Default, Clone, Copy)] @@ -13,12 +13,11 @@ pub struct SoftwareCrypto; #[derive(Debug, Default, Clone, Copy)] pub struct NoopCrypto; -pub fn test_identity(crypto: &impl QlCrypto) -> QlIdentity { - crate::generate_identity(crypto, XID(random_array(crypto))) -} - pub fn test_identities(crypto: &impl QlCrypto) -> (QlIdentity, QlIdentity) { - (test_identity(crypto), test_identity(crypto)) + ( + crate::generate_identity(crypto), + crate::generate_identity(crypto), + ) } impl QlRandom for SoftwareCrypto { diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index ae4660ff..4298ce7b 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -11,8 +11,8 @@ fn decode_session_record(bytes: &[u8]) -> QlSessionRecord> { record.into_owned() } -fn xid(byte: u8) -> XID { - XID([byte; XID::SIZE]) +fn qid(byte: u8) -> QID { + QID([byte; QID::SIZE]) } fn varint(value: u64) -> VarInt { @@ -45,8 +45,8 @@ fn handshake_transport_params(window: u32) -> TransportParams { fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { HandshakeHeader { - sender: xid(sender), - recipient: xid(recipient), + sender: qid(sender), + recipient: qid(recipient), } } @@ -60,8 +60,8 @@ fn pairing_id(byte: u8) -> PairingId { fn xx_header(sender: u8, recipient: u8) -> HandshakeHeader { HandshakeHeader { - sender: xid(sender), - recipient: xid(recipient), + sender: qid(sender), + recipient: qid(recipient), } } @@ -86,7 +86,7 @@ fn encrypt_record( #[test] fn peer_bundle_round_trip() { let crypto = SoftwareCrypto; - let identity = test_identity(&crypto).with_capabilities(0x55aa_33cc); + let identity = generate_identity(&crypto).with_capabilities(0x55aa_33cc); let bundle = identity.bundle(); let encoded = bundle.encode_vec(); @@ -95,6 +95,44 @@ fn peer_bundle_round_trip() { assert_eq!(decoded, bundle); } +#[test] +fn qid_derives_from_mlkem_public_key() { + let crypto = SoftwareCrypto; + let public_key = MlKemPublicKey::new(Box::new([42; MlKemPublicKey::SIZE])); + let qid = QID::derive(&crypto, &public_key); + + let digest = crypto.sha256(&[ + b"quantum-link qid v1", + ML_KEM_SUITE_TAG, + public_key.as_bytes(), + ]); + let mut expected = [0u8; QID::SIZE]; + expected.copy_from_slice(&digest[..QID::SIZE]); + + assert_eq!(qid, QID(expected)); + assert!(qid.matches_public_key(&crypto, &public_key)); +} + +#[test] +fn qid_changes_when_mlkem_public_key_changes() { + let crypto = SoftwareCrypto; + let first = MlKemPublicKey::new(Box::new([1; MlKemPublicKey::SIZE])); + let second = MlKemPublicKey::new(Box::new([2; MlKemPublicKey::SIZE])); + + assert_ne!(QID::derive(&crypto, &first), QID::derive(&crypto, &second)); +} + +#[test] +fn peer_bundle_detects_tampered_qid() { + let crypto = SoftwareCrypto; + let identity = generate_identity(&crypto); + let mut bundle = identity.bundle(); + + bundle.qid = qid(9); + + assert!(!bundle.qid_matches_public_key(&crypto)); +} + #[test] fn handshake_record_round_trip_supports_ik_kk_and_xx() { let ik = QlHandshakeRecord::Ik1(Ik1 { @@ -267,7 +305,7 @@ fn ik_handshake_rejects_tampered_handshake_header() { let mut m1 = initiator_state .write_1(&crypto, handshake_meta(90)) .unwrap(); - m1.header.sender = xid(9); + m1.header.sender = qid(9); assert_eq!( responder_state.read_1(&crypto, &m1), @@ -279,7 +317,7 @@ fn ik_handshake_rejects_tampered_handshake_header() { fn ik_handshake_rejects_bound_remote_bundle_mismatch() { let crypto = SoftwareCrypto; let (initiator, responder) = test_identities(&crypto); - let bogus = test_identity(&crypto); + let bogus = generate_identity(&crypto); let mut initiator_state = IkHandshake::new_initiator( &crypto, @@ -502,14 +540,14 @@ fn xx_handshake_rejects_tampered_pairing_id() { let mut initiator_state = XxHandshake::new_initiator( &crypto, initiator.clone(), - responder.xid, + responder.qid, token, TransportParams::default(), ); let mut responder_state = XxHandshake::new_responder( &crypto, responder, - initiator.xid, + initiator.qid, token, TransportParams::default(), ); @@ -534,14 +572,14 @@ fn xx_handshake_rejects_tampered_sender_or_recipient() { let mut initiator_state = XxHandshake::new_initiator( &crypto, initiator.clone(), - responder.xid, + responder.qid, token, TransportParams::default(), ); let mut responder_state = XxHandshake::new_responder( &crypto, responder.clone(), - initiator.xid, + initiator.qid, token, TransportParams::default(), ); @@ -549,7 +587,7 @@ fn xx_handshake_rejects_tampered_sender_or_recipient() { let mut m1 = initiator_state .write_1(&crypto, handshake_meta(31)) .unwrap(); - m1.header.sender = responder.xid; + m1.header.sender = responder.qid; assert_eq!( responder_state.read_1(&crypto, &m1), @@ -559,14 +597,14 @@ fn xx_handshake_rejects_tampered_sender_or_recipient() { let mut initiator_state = XxHandshake::new_initiator( &crypto, initiator.clone(), - responder.xid, + responder.qid, token, TransportParams::default(), ); let mut responder_state = XxHandshake::new_responder( &crypto, responder.clone(), - initiator.xid, + initiator.qid, token, TransportParams::default(), ); @@ -574,7 +612,7 @@ fn xx_handshake_rejects_tampered_sender_or_recipient() { let mut m1 = initiator_state .write_1(&crypto, handshake_meta(31)) .unwrap(); - m1.header.recipient = initiator.xid; + m1.header.recipient = initiator.qid; assert_eq!( responder_state.read_1(&crypto, &m1), @@ -591,14 +629,14 @@ fn xx_handshake_rejects_repeated_transport_param_change() { let mut initiator_state = XxHandshake::new_initiator( &crypto, initiator.clone(), - responder.xid, + responder.qid, token, handshake_transport_params(12_288), ); let mut responder_state = XxHandshake::new_responder( &crypto, responder, - initiator.xid, + initiator.qid, token, handshake_transport_params(24_576), ); @@ -635,14 +673,14 @@ fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { let mut initiator_state = XxHandshake::new_initiator( &crypto, initiator.clone(), - responder.xid, + responder.qid, token, initiator_params, ); let mut responder_state = XxHandshake::new_responder( &crypto, responder.clone(), - initiator.xid, + initiator.qid, token, responder_params, ); @@ -855,14 +893,14 @@ fn protocol_record_size_breakdown() { let mut xx_initiator = XxHandshake::new_initiator( &crypto, initiator.clone(), - responder.xid, + responder.qid, token, TransportParams::default(), ); let mut xx_responder = XxHandshake::new_responder( &crypto, responder.clone(), - initiator.xid, + initiator.qid, token, TransportParams::default(), ); diff --git a/ql-wire/src/xid.rs b/ql-wire/src/xid.rs deleted file mode 100644 index f7500af6..00000000 --- a/ql-wire/src/xid.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::{codec, ByteSlice, WireEncode, WireError}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(transparent)] -pub struct XID(pub [u8; Self::SIZE]); - -impl XID { - pub const SIZE: usize = 16; -} - -impl WireEncode for XID { - fn encoded_len(&self) -> usize { - Self::SIZE - } - - fn encode(&self, out: &mut W) { - self.0.encode(out); - } -} - -impl codec::WireDecode for XID { - fn decode(reader: &mut codec::Reader) -> Result { - Ok(Self(reader.decode()?)) - } -} From a4d42e6b420ad94de6455f8b0dd0eaa61dfb726a Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Wed, 27 May 2026 08:46:14 -0400 Subject: [PATCH 304/304] ql: identity name --- ql-fsm/src/tests/handshake.rs | 4 +- ql-runtime/src/driver/test.rs | 4 +- ql-runtime/src/tests/mod.rs | 4 +- ql-wire/src/handshake/ik.rs | 16 +++--- ql-wire/src/handshake/mod.rs | 26 +++++----- ql-wire/src/handshake/xx.rs | 32 +++++------- ql-wire/src/identity.rs | 96 ++++++++++++++++++++++++++++++----- ql-wire/src/testing.rs | 4 +- ql-wire/src/tests.rs | 27 ++++++++-- 9 files changed, 144 insertions(+), 69 deletions(-) diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs index eae4f4d3..4de4f06e 100644 --- a/ql-fsm/src/tests/handshake.rs +++ b/ql-fsm/src/tests/handshake.rs @@ -96,7 +96,7 @@ fn ik_connect_learns_remote_initial_stream_receive_window() { #[test] fn connect_methods_require_bound_peer() { let time = Harness::paired_known(QlFsmConfig::default()).time(); - let identity = generate_identity(&SoftwareCrypto); + let identity = generate_identity(&SoftwareCrypto, "identity").unwrap(); let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); let crypto = SoftwareCrypto; @@ -343,7 +343,7 @@ fn bind_peer_clears_queued_handshake_output() { harness .a .fsm - .bind_peer(generate_identity(&SoftwareCrypto).bundle()); + .bind_peer(generate_identity(&SoftwareCrypto, "peer").unwrap().bundle()); assert!(harness.drain_events(Side::A).is_empty()); assert!(harness.next_outbound(Side::A).is_none()); diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs index db0116ea..af4ab63a 100644 --- a/ql-runtime/src/driver/test.rs +++ b/ql-runtime/src/driver/test.rs @@ -58,7 +58,7 @@ fn new_driver_state() -> (DriverState, QlFsm) { }, QlFsm::new( ql_fsm::QlFsmConfig::default(), - generate_identity(&SoftwareCrypto), + generate_identity(&SoftwareCrypto, "driver").unwrap(), Instant::now(), ), ) @@ -180,7 +180,7 @@ fn local_close_command_reaps_when_other_half_is_already_closed() { #[test] fn unpaired_status_fails_and_reaps_all_streams() { let (mut state, mut fsm) = new_driver_state(); - let peer = generate_identity(&SoftwareCrypto).bundle(); + let peer = generate_identity(&SoftwareCrypto, "peer").unwrap().bundle(); let stream_id = StreamId(1u32.into()); let (runtime_tx, _runtime_rx) = async_channel::unbounded(); let (_, _, reader_io, writer_io) = io::new_stream( diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs index e22e5041..af368738 100644 --- a/ql-runtime/src/tests/mod.rs +++ b/ql-runtime/src/tests/mod.rs @@ -679,7 +679,7 @@ fn default_runtime_config() -> RuntimeConfig { #[test] fn runtime_is_send() { let config = default_runtime_config(); - let identity = generate_identity(&SoftwareCrypto); + let identity = generate_identity(&SoftwareCrypto, "runtime").unwrap(); let (platform, _, _, _) = TestPlatform::new(); let (runtime, _handle) = new_runtime(identity, platform, config); let _run: Box + Send> = Box::new(runtime.run()); @@ -688,7 +688,7 @@ fn runtime_is_send() { #[test] fn runtime_exits_when_last_handle_drops() { let config = default_runtime_config(); - let identity = generate_identity(&SoftwareCrypto); + let identity = generate_identity(&SoftwareCrypto, "runtime").unwrap(); let (platform, _, _, _) = TestPlatform::new(); let (runtime, handle) = new_runtime(identity, platform, config); let (done_tx, done_rx) = oneshot::channel(); diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs index 26fa5621..628e30e7 100644 --- a/ql-wire/src/handshake/ik.rs +++ b/ql-wire/src/handshake/ik.rs @@ -20,15 +20,6 @@ pub struct Ik1 { pub static_bundle: EncryptedPeerBundle, } -impl Ik1 { - pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE - + HandshakeMeta::WIRE_SIZE - + TransportParams::WIRE_SIZE - + MlKemCiphertext::SIZE - + EphemeralPublicKey::WIRE_SIZE - + EncryptedPeerBundle::WIRE_SIZE; -} - impl codec::WireDecode for Ik1 { fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { @@ -44,7 +35,12 @@ impl codec::WireDecode for Ik1 { impl WireEncode for Ik1 { fn encoded_len(&self) -> usize { - Self::WIRE_SIZE + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EphemeralPublicKey::WIRE_SIZE + + self.static_bundle.encoded_len() } fn encode(&self, out: &mut W) { diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs index fcd5ccaf..a9b7cf87 100644 --- a/ql-wire/src/handshake/mod.rs +++ b/ql-wire/src/handshake/mod.rs @@ -83,7 +83,7 @@ impl codec::WireDecode for EphemeralPublicKey { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct EncryptedMlKemCiphertext(Box<[u8; Self::WIRE_SIZE]>); +pub struct EncryptedMlKemCiphertext(pub Box<[u8; Self::WIRE_SIZE]>); impl EncryptedMlKemCiphertext { pub const WIRE_SIZE: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; @@ -114,33 +114,33 @@ impl codec::WireDecode for EncryptedMlKemCiphertext { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct EncryptedPeerBundle(Box<[u8; Self::WIRE_SIZE]>); +pub struct EncryptedPeerBundle(pub Box<[u8]>); impl EncryptedPeerBundle { - pub const WIRE_SIZE: usize = PeerBundle::WIRE_SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; + pub const MAX_WIRE_SIZE: usize = PeerBundle::MAX_WIRE_SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; - pub fn new(data: Box<[u8; Self::WIRE_SIZE]>) -> Self { - Self(data) - } - - pub fn as_bytes(&self) -> &[u8; Self::WIRE_SIZE] { + pub fn as_bytes(&self) -> &[u8] { self.0.as_ref() } } impl WireEncode for EncryptedPeerBundle { fn encoded_len(&self) -> usize { - Self::WIRE_SIZE + self.0.len() } fn encode(&self, out: &mut W) { - self.0.as_ref().encode(out); + self.as_bytes().encode(out); } } impl codec::WireDecode for EncryptedPeerBundle { fn decode(reader: &mut codec::Reader) -> Result { - Ok(Self::new(reader.decode()?)) + let data = reader.take_rest(); + if data.len() > Self::MAX_WIRE_SIZE { + return Err(WireError::InvalidPayload); + } + Ok(Self(data.to_vec().into_boxed_slice())) } } @@ -462,9 +462,7 @@ fn encrypt_peer_bundle( bundle: &PeerBundle, ) -> Result { let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode_vec())?; - let out: Box<[u8; EncryptedPeerBundle::WIRE_SIZE]> = - ciphertext.try_into().map_err(|_| WireError::InvalidState)?; - Ok(EncryptedPeerBundle::new(out)) + Ok(EncryptedPeerBundle(ciphertext.into_boxed_slice())) } fn decrypt_peer_bundle( diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs index b688d757..0b6452d4 100644 --- a/ql-wire/src/handshake/xx.rs +++ b/ql-wire/src/handshake/xx.rs @@ -64,15 +64,6 @@ pub struct Xx2 { pub static_bundle: EncryptedPeerBundle, } -impl Xx2 { - pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE - + HandshakeMeta::WIRE_SIZE - + PairingId::SIZE - + TransportParams::WIRE_SIZE - + MlKemCiphertext::SIZE - + EncryptedPeerBundle::WIRE_SIZE; -} - impl codec::WireDecode for Xx2 { fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { @@ -88,7 +79,12 @@ impl codec::WireDecode for Xx2 { impl WireEncode for Xx2 { fn encoded_len(&self) -> usize { - Self::WIRE_SIZE + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + self.static_bundle.encoded_len() } fn encode(&self, out: &mut W) { @@ -111,15 +107,6 @@ pub struct Xx3 { pub static_bundle: EncryptedPeerBundle, } -impl Xx3 { - pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE - + HandshakeMeta::WIRE_SIZE - + PairingId::SIZE - + TransportParams::WIRE_SIZE - + EncryptedMlKemCiphertext::WIRE_SIZE - + EncryptedPeerBundle::WIRE_SIZE; -} - impl codec::WireDecode for Xx3 { fn decode(reader: &mut codec::Reader) -> Result { Ok(Self { @@ -135,7 +122,12 @@ impl codec::WireDecode for Xx3 { impl WireEncode for Xx3 { fn encoded_len(&self) -> usize { - Self::WIRE_SIZE + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE + + self.static_bundle.encoded_len() } fn encode(&self, out: &mut W) { diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs index 3a358401..1f8dbee7 100644 --- a/ql-wire/src/identity.rs +++ b/ql-wire/src/identity.rs @@ -1,6 +1,8 @@ +use std::ops::Deref; + use crate::{ - codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, QlHash, WireEncode, - WireError, QID, + codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, QlHash, VarInt, + WireEncode, WireError, QID, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -9,14 +11,14 @@ pub struct PeerBundle { pub qid: QID, pub capabilities: u32, pub mlkem_public_key: MlKemPublicKey, - // todo: add - // pub name: String + pub name: QlName, } impl PeerBundle { pub const VERSION: u16 = 1; - pub const WIRE_SIZE: usize = + pub const FIXED_WIRE_SIZE: usize = size_of::() + QID::SIZE + size_of::() + MlKemPublicKey::SIZE; + pub const MAX_WIRE_SIZE: usize = Self::FIXED_WIRE_SIZE + VarInt::MAX_SIZE + QlName::MAX_LEN; pub fn qid_matches_public_key(&self, crypto: &impl QlHash) -> bool { self.qid.matches_public_key(crypto, &self.mlkem_public_key) @@ -25,7 +27,7 @@ impl PeerBundle { impl WireEncode for PeerBundle { fn encoded_len(&self) -> usize { - Self::WIRE_SIZE + Self::FIXED_WIRE_SIZE + self.name.encoded_len() } fn encode(&self, out: &mut W) { @@ -33,6 +35,7 @@ impl WireEncode for PeerBundle { self.qid.encode(out); self.capabilities.encode(out); self.mlkem_public_key.encode(out); + self.name.encode(out); } } @@ -43,6 +46,7 @@ impl codec::WireDecode for PeerBundle { qid: reader.decode()?, capabilities: reader.decode()?, mlkem_public_key: reader.decode()?, + name: reader.decode()?, }) } } @@ -53,24 +57,29 @@ pub struct QlIdentity { pub mlkem_private_key: MlKemPrivateKey, pub mlkem_public_key: MlKemPublicKey, pub capabilities: u32, + pub name: QlName, } impl QlIdentity { - pub const WIRE_SIZE: usize = + pub const FIXED_WIRE_SIZE: usize = QID::SIZE + MlKemPrivateKey::SIZE + MlKemPublicKey::SIZE + size_of::(); + pub const MAX_WIRE_SIZE: usize = Self::FIXED_WIRE_SIZE + VarInt::MAX_SIZE + QlName::MAX_LEN; pub fn new( crypto: &impl QlHash, mlkem_private_key: MlKemPrivateKey, mlkem_public_key: MlKemPublicKey, - ) -> Self { + name: impl Into, + ) -> Result { + let name = QlName::new(name)?; let qid = QID::derive(crypto, &mlkem_public_key); - Self { + Ok(Self { qid, mlkem_private_key, mlkem_public_key, capabilities: 0, - } + name, + }) } #[must_use] @@ -79,19 +88,25 @@ impl QlIdentity { self } + pub fn with_name(mut self, name: impl Into) -> Result { + self.name = QlName::new(name)?; + Ok(self) + } + pub fn bundle(&self) -> PeerBundle { PeerBundle { version: PeerBundle::VERSION, qid: self.qid, capabilities: self.capabilities, mlkem_public_key: self.mlkem_public_key.clone(), + name: self.name.clone(), } } } impl WireEncode for QlIdentity { fn encoded_len(&self) -> usize { - Self::WIRE_SIZE + Self::FIXED_WIRE_SIZE + self.name.encoded_len() } fn encode(&self, out: &mut W) { @@ -99,6 +114,7 @@ impl WireEncode for QlIdentity { self.mlkem_private_key.as_bytes().encode(out); self.mlkem_public_key.encode(out); self.capabilities.encode(out); + self.name.encode(out); } } @@ -109,14 +125,68 @@ impl codec::WireDecode for QlIdentity { mlkem_private_key: MlKemPrivateKey::new(reader.decode()?), mlkem_public_key: reader.decode()?, capabilities: reader.decode()?, + name: reader.decode()?, }) } } -pub fn generate_identity(crypto: &impl QlCrypto) -> QlIdentity { +pub fn generate_identity( + crypto: &impl QlCrypto, + name: impl Into, +) -> Result { let MlKemKeyPair { private: mlkem_private_key, public: mlkem_public_key, } = crypto.mlkem_generate_keypair(); - QlIdentity::new(crypto, mlkem_private_key, mlkem_public_key) + QlIdentity::new(crypto, mlkem_private_key, mlkem_public_key, name) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QlName(String); + +impl QlName { + pub const MAX_LEN: usize = 256; + + pub fn new(name: impl Into) -> Result { + let name = name.into(); + if name.is_empty() || name.len() > Self::MAX_LEN { + return Err(WireError::InvalidPayload); + } + Ok(Self(name)) + } +} + +impl Deref for QlName { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl WireEncode for QlName { + fn encoded_len(&self) -> usize { + let len = VarInt::try_from(self.0.len()).unwrap(); + len.encoded_len() + self.0.len() + } + + fn encode(&self, out: &mut W) { + VarInt::try_from(self.0.len()) + .expect("identity name length fits in varint") + .encode(out); + self.0.as_bytes().encode(out); + } +} + +impl codec::WireDecode for QlName { + fn decode(reader: &mut codec::Reader) -> Result { + let len = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + if len == 0 || len > Self::MAX_LEN { + return Err(WireError::InvalidPayload); + } + let bytes = reader.take_bytes(len)?; + let name = std::str::from_utf8(&bytes).map_err(|_| WireError::InvalidPayload)?; + Ok(QlName::new(name)?) + } } diff --git a/ql-wire/src/testing.rs b/ql-wire/src/testing.rs index 11b251c4..a1223c12 100644 --- a/ql-wire/src/testing.rs +++ b/ql-wire/src/testing.rs @@ -15,8 +15,8 @@ pub struct NoopCrypto; pub fn test_identities(crypto: &impl QlCrypto) -> (QlIdentity, QlIdentity) { ( - crate::generate_identity(crypto), - crate::generate_identity(crypto), + crate::generate_identity(crypto, "alice").unwrap(), + crate::generate_identity(crypto, "bob").unwrap(), ) } diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs index 4298ce7b..61826ddd 100644 --- a/ql-wire/src/tests.rs +++ b/ql-wire/src/tests.rs @@ -86,13 +86,32 @@ fn encrypt_record( #[test] fn peer_bundle_round_trip() { let crypto = SoftwareCrypto; - let identity = generate_identity(&crypto).with_capabilities(0x55aa_33cc); + let identity = generate_identity(&crypto, "alice") + .unwrap() + .with_capabilities(0x55aa_33cc); let bundle = identity.bundle(); let encoded = bundle.encode_vec(); let decoded = PeerBundle::decode_exact(encoded.as_slice()).unwrap(); assert_eq!(decoded, bundle); + assert_eq!(&*decoded.name, "alice"); +} + +#[test] +fn identity_name_validation() { + assert_eq!( + QlName::new("a".repeat(QlName::MAX_LEN)).unwrap().len(), + QlName::MAX_LEN + ); + assert!(matches!( + QlName::new(""), + Err(WireError::InvalidPayload) + )); + assert!(matches!( + QlName::new("a".repeat(QlName::MAX_LEN + 1)), + Err(WireError::InvalidPayload) + )); } #[test] @@ -125,7 +144,7 @@ fn qid_changes_when_mlkem_public_key_changes() { #[test] fn peer_bundle_detects_tampered_qid() { let crypto = SoftwareCrypto; - let identity = generate_identity(&crypto); + let identity = generate_identity(&crypto, "alice").unwrap(); let mut bundle = identity.bundle(); bundle.qid = qid(9); @@ -143,7 +162,7 @@ fn handshake_record_round_trip_supports_ik_kk_and_xx() { ephemeral: EphemeralPublicKey { mlkem_public_key: MlKemPublicKey::new(Box::new([9; MlKemPublicKey::SIZE])), }, - static_bundle: EncryptedPeerBundle::new(Box::new([13; EncryptedPeerBundle::WIRE_SIZE])), + static_bundle: EncryptedPeerBundle(vec![13; 64].into_boxed_slice()), }); let ik_encoded = encode_record_vec(RecordType::Handshake, &ik); assert_eq!( @@ -317,7 +336,7 @@ fn ik_handshake_rejects_tampered_handshake_header() { fn ik_handshake_rejects_bound_remote_bundle_mismatch() { let crypto = SoftwareCrypto; let (initiator, responder) = test_identities(&crypto); - let bogus = generate_identity(&crypto); + let bogus = generate_identity(&crypto, "bogus").unwrap(); let mut initiator_state = IkHandshake::new_initiator( &crypto,