From 406eba49c4b5e39030d04b8cae7847450230828d Mon Sep 17 00:00:00 2001
From: xdustinface
Date: Wed, 29 Apr 2026 23:53:57 +1000
Subject: [PATCH] refactor(dash-spv): split peer `TcpStream` and add per-peer
 writer task

`Peer` now splits its `TcpStream` into independent read and write halves
via `tokio::io::split`. The read half lives behind its own `Mutex` so the
reader can frame inbound bytes without contending with senders. The write
half is owned by a dedicated per-peer writer task that drains an
`mpsc::Receiver` onto the socket. `send_message` now just queues into the
sending end of that channel and returns immediately.

This removes the single `Mutex` that previously serialised the reader, the
maintenance ping, every distributed send, and every broadcast through the
same critical section, so a long inbound `Headers2` decompression no longer
holds outbound pings off the wire. The reader loop likewise no longer needs
a write lock just to call `receive_message`.

The `Peer` API is adjusted accordingly: `send_message`, `receive_message`,
and `handle_ping` now take `&self` since the only state they touch lives
behind `Arc`s. `bytes_sent` becomes an `AtomicU64` shared with the writer
task. `Peer::connect_instance` is retained for compatibility but routes
through the same path. A minimal standalone sketch of the split-stream plus
writer-task pattern is appended after the patch.
---
 dash-spv/src/network/constants.rs     |   1 -
 dash-spv/src/network/manager.rs       |  27 +-
 dash-spv/src/network/peer.rs          | 751 +++++++++++++++-----------
 masternode-seeds-fetcher/src/probe.rs |   2 +-
 4 files changed, 443 insertions(+), 338 deletions(-)

diff --git a/dash-spv/src/network/constants.rs b/dash-spv/src/network/constants.rs
index e1278f0f6..d337516a7 100644
--- a/dash-spv/src/network/constants.rs
+++ b/dash-spv/src/network/constants.rs
@@ -25,5 +25,4 @@ pub const PEER_DISCOVERY_INTERVAL: Duration = Duration::from_secs(60); // Discov
 
 // DNS and polling intervals
 pub const DNS_DISCOVERY_DELAY: Duration = Duration::from_secs(10);
-pub const MESSAGE_POLL_INTERVAL: Duration = Duration::from_millis(10);
 pub const MESSAGE_RECEIVE_TIMEOUT: Duration = Duration::from_millis(100);
diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs
index c68cfcce8..1e6172059 100644
--- a/dash-spv/src/network/manager.rs
+++ b/dash-spv/src/network/manager.rs
@@ -459,26 +459,19 @@ impl PeerNetworkManager {
             }
         };
 
-        // Read message with minimal lock time
+        // Read with only a shared peer lock so senders can run concurrently. The
+        // inner read state has its own mutex, and `receive_message` is fully
+        // waker-driven on the read half. No polling sleep needed.
         let msg_result = {
-            // Try to get a read lock first to check if peer is available
             let peer_guard = peer.read().await;
             if !peer_guard.is_connected() {
                 tracing::warn!("Breaking peer reader loop for {} - peer no longer connected (iteration {})", addr, loop_iteration);
-                drop(peer_guard);
                 break;
             }
-            drop(peer_guard);
-
-            // Now get write lock only for the duration of the read
-            let mut peer_guard = peer.write().await;
             tokio::select!
{
                 message = peer_guard.receive_message() => { message },
-                _ = tokio::time::sleep(MESSAGE_POLL_INTERVAL) => {
-                    Ok(None)
-                },
                 _ = shutdown_token.cancelled() => {
                     tracing::info!("Breaking peer reader loop for {} - shutdown signal received while reading (iteration {})", addr, loop_iteration);
                     break;
@@ -519,7 +512,7 @@ impl PeerNetworkManager {
                     );
                     // Send our known addresses
                     let response = addrv2_handler.build_addr_response().await;
-                    let mut peer_guard = peer.write().await;
+                    let peer_guard = peer.read().await;
                     if let Err(e) = peer_guard.send_message(response).await {
                         tracing::error!(
                             "Failed to send addr response to {}: {}",
@@ -531,12 +524,16 @@ impl PeerNetworkManager {
                 }
                 NetworkMessage::Ping(nonce) => {
                     // Handle ping directly
-                    let mut peer_guard = peer.write().await;
+                    let peer_guard = peer.read().await;
                     if let Err(e) = peer_guard.handle_ping(*nonce).await {
                         tracing::error!("Failed to handle ping from {}: {}", addr, e);
                         // If we can't send pong, connection is likely broken
                         if matches!(e, NetworkError::ConnectionFailed(_)) {
                             tracing::warn!("Breaking peer reader loop for {} - failed to send pong response (iteration {})", addr, loop_iteration);
+                            // Drop the read guard before acquiring the write
+                            // guard on the same RwLock to avoid a self-deadlock.
+                            drop(peer_guard);
+                            peer.write().await.mark_disconnected();
                             break;
                         }
                     }
@@ -688,6 +685,7 @@ impl PeerNetworkManager {
                 match e {
                     NetworkError::PeerDisconnected => {
                         tracing::info!("Peer {} disconnected", addr);
+                        peer.write().await.mark_disconnected();
                         break;
                     }
                     NetworkError::Timeout => {
@@ -757,6 +755,7 @@ impl PeerNetworkManager {
                         }
                     }
 
+                    peer.write().await.mark_disconnected();
                     break;
                 }
             }
@@ -1237,7 +1236,7 @@ impl PeerNetworkManager {
             other => other,
         };
 
-        let mut peer_guard = peer.write().await;
+        let peer_guard = peer.read().await;
         peer_guard
             .send_message(message)
             .await
@@ -1263,7 +1262,7 @@ impl PeerNetworkManager {
             let msg = message.clone();
 
             let handle = tokio::spawn(async move {
-                let mut peer_guard = peer.write().await;
+                let peer_guard = peer.read().await;
                 peer_guard.send_message(msg).await.map_err(Error::Network)
             });
             handles.push(handle);
diff --git a/dash-spv/src/network/peer.rs b/dash-spv/src/network/peer.rs
index beec9cfff..57bfd180a 100644
--- a/dash-spv/src/network/peer.rs
+++ b/dash-spv/src/network/peer.rs
@@ -3,11 +3,12 @@ use dashcore::network::constants::ServiceFlags;
 use std::collections::HashMap;
 use std::net::SocketAddr;
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
 use std::time::{Duration, SystemTime};
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
 use tokio::net::TcpStream;
-use tokio::sync::Mutex;
+use tokio::sync::{mpsc, Mutex};
 
 use dashcore::consensus::{encode, Decodable};
 use dashcore::network::message::{NetworkMessage, RawNetworkMessage};
@@ -17,22 +18,29 @@ use crate::error::{NetworkError, NetworkResult};
 use crate::network::constants::PING_INTERVAL;
 use crate::network::Message;
 
-/// Internal state for the TCP connection
-struct ConnectionState {
-    stream: TcpStream,
-    // Stateful message framing buffer to ensure full frames before decoding
+/// Capacity of the per-peer outbound message channel. Generous enough to absorb bursts
+/// while still bounding memory if the writer task or socket falls behind.
+const OUTBOUND_CHANNEL_CAPACITY: usize = 256;
+
+const READ_CHUNK_SIZE: usize = 8192;
+
+struct ReadState {
+    read_half: ReadHalf<TcpStream>,
     framing_buffer: Vec<u8>,
 }
 
-/// Dash P2P peer
+/// Dash P2P peer.
+///
+/// The peer's outbound writes are funnelled through a per-peer writer task that owns the
+/// `WriteHalf` of the underlying TCP stream. Inbound reads happen on the `ReadHalf` under
+/// a separate mutex so reading never blocks writing and vice versa.
 pub struct Peer {
     address: SocketAddr,
-    // Use a single mutex to protect both the write stream and read buffer
-    // This ensures no concurrent access to the underlying socket
-    state: Option<Arc<Mutex<ConnectionState>>>,
+    read_state: Option<Arc<Mutex<ReadState>>>,
+    out_tx: Option<mpsc::Sender<NetworkMessage>>,
     timeout: Duration,
    connected_at: Option<SystemTime>,
-    bytes_sent: u64,
+    bytes_sent: Arc<AtomicU64>,
     network: Network,
     // Ping/pong state
     last_ping_sent: Option<SystemTime>,
@@ -57,10 +65,11 @@ impl Peer {
     pub fn new(address: SocketAddr, timeout: Duration, network: Network) -> Self {
         Self {
             address,
-            state: None,
+            read_state: None,
+            out_tx: None,
             timeout,
             connected_at: None,
-            bytes_sent: 0,
+            bytes_sent: Arc::new(AtomicU64::new(0)),
             network,
             last_ping_sent: None,
             last_pong_received: None,
@@ -96,29 +105,9 @@ impl Peer {
             NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e))
         })?;
 
-        let state = ConnectionState {
-            stream,
-            framing_buffer: Vec::new(),
-        };
-
-        Ok(Self {
-            address,
-            state: Some(Arc::new(Mutex::new(state))),
-            timeout,
-            connected_at: Some(SystemTime::now()),
-            bytes_sent: 0,
-            network,
-            last_ping_sent: None,
-            last_pong_received: None,
-            pending_pings: HashMap::new(),
-            version: None,
-            services: None,
-            user_agent: None,
-            best_height: None,
-            relay: None,
-            prefers_headers2: false,
-            sent_sendheaders2: false,
-        })
+        let mut peer = Self::new(address, timeout, network);
+        peer.install_stream(stream);
+        Ok(peer)
     }
 
     pub fn version(&self) -> Option<u32> {
@@ -166,29 +155,36 @@ impl Peer {
             NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e))
         })?;
 
-        let state = ConnectionState {
-            stream,
-            framing_buffer: Vec::new(),
-        };
-
-        self.state = Some(Arc::new(Mutex::new(state)));
-        self.connected_at = Some(SystemTime::now());
+        self.install_stream(stream);
 
         tracing::info!("Connected to peer {}", self.address);
 
         Ok(())
     }
 
+    fn install_stream(&mut self, stream: TcpStream) {
+        let (read_half, write_half) = tokio::io::split(stream);
+        let (out_tx, out_rx) = mpsc::channel(OUTBOUND_CHANNEL_CAPACITY);
+
+        self.read_state = Some(Arc::new(Mutex::new(ReadState {
+            read_half,
+            framing_buffer: Vec::new(),
+        })));
+        self.out_tx = Some(out_tx);
+        self.connected_at = Some(SystemTime::now());
+
+        spawn_writer_task(
+            self.address,
+            self.network.magic(),
+            write_half,
+            out_rx,
+            self.bytes_sent.clone(),
+        );
+    }
+
     /// Disconnect from the peer.
     pub async fn disconnect(&mut self) -> NetworkResult<()> {
-        if let Some(state_arc) = self.state.take() {
-            if let Ok(state_mutex) = Arc::try_unwrap(state_arc) {
-                let mut state = state_mutex.into_inner();
-                let _ = state.stream.shutdown().await;
-            }
-        }
-        self.connected_at = None;
-
+        self.tear_down_connection();
         tracing::info!("Disconnected from peer {}", self.address);
 
         Ok(())
@@ -296,9 +292,9 @@ impl Peer {
     }
 
     /// Helper function to read some bytes into the framing buffer.
-    async fn read_some(state: &mut ConnectionState) -> std::io::Result<usize> {
-        let mut tmp = [0u8; 8192];
-        match state.stream.read(&mut tmp).await {
+    async fn read_some(state: &mut ReadState) -> std::io::Result<usize> {
+        let mut tmp = [0u8; READ_CHUNK_SIZE];
+        match state.read_half.read(&mut tmp).await {
             Ok(0) => Ok(0),
             Ok(n) => {
                 state.framing_buffer.extend_from_slice(&tmp[..n]);
@@ -308,172 +304,95 @@ impl Peer {
         }
     }
 
-    /// Send a message to the peer.
-    pub async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()> {
-        let state_arc = self
-            .state
+    /// Send a message to the peer by handing it off to the per-peer writer task.
+    ///
+    /// Awaits channel capacity if the writer task is behind, and only fails when the peer
+    /// is disconnected (the writer task has exited and closed the channel).
+    pub async fn send_message(&self, message: NetworkMessage) -> NetworkResult<()> {
+        let tx = self
+            .out_tx
             .as_ref()
             .ok_or_else(|| NetworkError::ConnectionFailed("Not connected".to_string()))?;
 
-        let raw_message = RawNetworkMessage {
-            magic: self.network.magic(),
-            payload: message,
-        };
-
-        let serialized = encode::serialize(&raw_message);
-
-        // Log details for debugging headers2 issues
-        if matches!(
-            raw_message.payload,
-            NetworkMessage::GetHeaders2(_) | NetworkMessage::GetHeaders(_)
-        ) {
-            let msg_type = match raw_message.payload {
-                NetworkMessage::GetHeaders2(_) => "GetHeaders2",
-                NetworkMessage::GetHeaders(_) => "GetHeaders",
-                _ => "Unknown",
-            };
-            tracing::debug!(
-                "Sending {} raw bytes (len={}): {:02x?}",
-                msg_type,
-                serialized.len(),
-                &serialized[..std::cmp::min(100, serialized.len())]
-            );
+        if matches!(&message, NetworkMessage::GetHeaders2(_) | NetworkMessage::GetHeaders(_)) {
+            tracing::debug!("Queueing {} for {}", message.cmd(), self.address);
         }
 
-        // Lock the state for the entire write operation
-        let mut state = state_arc.lock().await;
-
-        // Write with error handling
-        match state.stream.write_all(&serialized).await {
-            Ok(_) => {
-                // Flush to ensure data is sent immediately
-                if let Err(e) = state.stream.flush().await {
-                    tracing::warn!("Failed to flush socket {}: {}", self.address, e);
-                }
-                self.bytes_sent += serialized.len() as u64;
-                tracing::debug!("Sent message to {}: {:?}", self.address, raw_message.payload);
-                Ok(())
-            }
-            Err(e) => {
-                tracing::warn!("Disconnecting {} due to write error: {}", self.address, e);
-                // Drop the lock before clearing connection state
-                drop(state);
-                // Clear connection state on write error
-                self.state = None;
-                self.connected_at = None;
-                Err(NetworkError::ConnectionFailed(format!("Write failed: {}", e)))
-            }
-        }
+        tx.send(message).await.map_err(|_| NetworkError::PeerDisconnected)
     }
 
     /// Receive a message from the peer.
-    pub async fn receive_message(&mut self) -> NetworkResult<Option<Message>> {
-        // If the state was cleared e.g. by a write-path broken pipe, treat as disconnected
-        // so the reader loop handles it identically to a read-path EOF.
-        let state_arc = self.state.as_ref().ok_or(NetworkError::PeerDisconnected)?;
-
-        // Lock the state for the entire read operation
-        // This ensures no concurrent access to the socket
+    pub async fn receive_message(&self) -> NetworkResult<Option<Message>> {
+        let state_arc = self.read_state.as_ref().ok_or(NetworkError::PeerDisconnected)?;
         let mut state = state_arc.lock().await;
 
         // Buffered, stateful framing
         const HEADER_LEN: usize = 24; // magic[4] + cmd[12] + length[4] + checksum[4]
         const MAX_RESYNC_STEPS_PER_CALL: usize = 64;
 
-        let result = async {
-            let magic_bytes = self.network.magic().to_le_bytes();
-            let mut resync_steps = 0usize;
+        let magic_bytes = self.network.magic().to_le_bytes();
+        let mut resync_steps = 0usize;
 
-            loop {
-                // Ensure header availability
-                if state.framing_buffer.len() < HEADER_LEN {
-                    match Self::read_some(&mut state).await {
-                        Ok(0) => {
-                            tracing::info!("Peer {} closed connection (EOF)", self.address);
-                            return Err(NetworkError::PeerDisconnected);
-                        }
-                        Ok(_) => {}
-                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
-                            return Ok(None);
-                        }
-                        Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
-                            return Ok(None);
-                        }
-                        Err(ref e)
-                            if e.kind() == std::io::ErrorKind::ConnectionAborted
-                                || e.kind() == std::io::ErrorKind::ConnectionReset =>
-                        {
-                            tracing::info!("Peer {} connection reset/aborted", self.address);
-                            return Err(NetworkError::PeerDisconnected);
-                        }
-                        Err(e) => {
-                            return Err(NetworkError::ConnectionFailed(format!(
-                                "Read failed: {}",
-                                e
-                            )));
-                        }
+        loop {
+            // Ensure header availability
+            if state.framing_buffer.len() < HEADER_LEN {
+                match Self::read_some(&mut state).await {
+                    Ok(0) => {
+                        tracing::info!("Peer {} closed connection (EOF)", self.address);
+                        return Err(NetworkError::PeerDisconnected);
+                    }
+                    Ok(_) => {}
+                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+                        return Ok(None);
+                    }
+                    Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
+                        return Ok(None);
+                    }
+                    Err(ref e)
+                        if e.kind() == std::io::ErrorKind::ConnectionAborted
+                            || e.kind() == std::io::ErrorKind::ConnectionReset =>
+                    {
+                        tracing::info!("Peer {} connection reset/aborted", self.address);
+                        return Err(NetworkError::PeerDisconnected);
+                    }
+                    Err(e) => {
+                        return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e)));
                     }
                 }
+            }
 
-                // Align to magic
-                if state.framing_buffer.len() >= 4 && state.framing_buffer[..4] != magic_bytes {
-                    if let Some(pos) =
-                        state.framing_buffer.windows(4).position(|w| w == magic_bytes)
-                    {
-                        if pos > 0 {
-                            tracing::warn!(
-                                "{}: stream desync: skipping {} stray bytes before magic",
-                                self.address,
-                                pos
-                            );
-                            state.framing_buffer.drain(0..pos);
-                            resync_steps += 1;
-                            if resync_steps >= MAX_RESYNC_STEPS_PER_CALL {
-                                return Ok(None);
-                            }
-                            continue;
-                        }
-                    } else {
-                        // Keep last 3 bytes of potential magic prefix
-                        if state.framing_buffer.len() > 3 {
-                            let dropped = state.framing_buffer.len() - 3;
-                            tracing::warn!(
-                                "{}: stream desync: dropping {} bytes (no magic found)",
-                                self.address,
-                                dropped
-                            );
-                            state.framing_buffer.drain(0..dropped);
-                            resync_steps += 1;
-                            if resync_steps >= MAX_RESYNC_STEPS_PER_CALL {
-                                return Ok(None);
-                            }
-                        }
-                        // Need more data
-                        match Self::read_some(&mut state).await {
-                            Ok(0) => {
-                                tracing::info!("Peer {} closed connection (EOF)", self.address);
-                                return Err(NetworkError::PeerDisconnected);
-                            }
-                            Ok(_) => {}
-                            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
-                                return Ok(None);
-                            }
-                            Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
-                                return Ok(None);
-                            }
-                            Err(e) =>
{
-                                return Err(NetworkError::ConnectionFailed(format!(
-                                    "Read failed: {}",
-                                    e
-                                )));
-                            }
-                        }
-                        continue;
-                    }
-                }
+            // Align to magic
+            if state.framing_buffer.len() >= 4 && state.framing_buffer[..4] != magic_bytes {
+                if let Some(pos) = state.framing_buffer.windows(4).position(|w| w == magic_bytes) {
+                    if pos > 0 {
+                        tracing::warn!(
+                            "{}: stream desync: skipping {} stray bytes before magic",
+                            self.address,
+                            pos
+                        );
+                        state.framing_buffer.drain(0..pos);
+                        resync_steps += 1;
+                        if resync_steps >= MAX_RESYNC_STEPS_PER_CALL {
+                            return Ok(None);
+                        }
+                        continue;
+                    }
+                } else {
+                    // Keep last 3 bytes of potential magic prefix
+                    if state.framing_buffer.len() > 3 {
+                        let dropped = state.framing_buffer.len() - 3;
+                        tracing::warn!(
+                            "{}: stream desync: dropping {} bytes (no magic found)",
+                            self.address,
+                            dropped
+                        );
+                        state.framing_buffer.drain(0..dropped);
+                        resync_steps += 1;
+                        if resync_steps >= MAX_RESYNC_STEPS_PER_CALL {
+                            return Ok(None);
+                        }
+                    }
+                    // Need more data
+                    match Self::read_some(&mut state).await {
+                        Ok(0) => {
+                            tracing::info!("Peer {} closed connection (EOF)", self.address);
+                            return Err(NetworkError::PeerDisconnected);
+                        }
+                        Ok(_) => {}
+                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+                            return Ok(None);
+                        }
+                        Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
+                            return Ok(None);
+                        }
+                        Err(e) => {
+                            return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e)));
+                        }
+                    }
+                    continue;
+                }
+            }
 
-                // Ensure full header
-                if state.framing_buffer.len() < HEADER_LEN {
-                    match Self::read_some(&mut state).await {
-                        Ok(0) => {
-                            tracing::info!("Peer {} closed connection (EOF)", self.address);
-                            return Err(NetworkError::PeerDisconnected);
-                        }
-                        Ok(_) => {}
-                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
-                            return Ok(None);
-                        }
-                        Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
-                            return Ok(None);
-                        }
-                        Err(e) => {
-                            return Err(NetworkError::ConnectionFailed(format!(
-                                "Read failed: {}",
-                                e
-                            )));
-                        }
-                    }
-                    continue;
-                }
+            // Ensure full header
+            if state.framing_buffer.len() < HEADER_LEN {
+                match Self::read_some(&mut state).await {
+                    Ok(0) => {
+                        tracing::info!("Peer {} closed connection (EOF)", self.address);
+                        return Err(NetworkError::PeerDisconnected);
+                    }
+                    Ok(_) => {}
+                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+                        return Ok(None);
+                    }
+                    Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
+                        return Ok(None);
+                    }
+                    Err(e) => {
+                        return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e)));
+                    }
+                }
+                continue;
+            }
 
-                // Parse header fields
-                let length_le = u32::from_le_bytes([
-                    state.framing_buffer[16],
-                    state.framing_buffer[17],
-                    state.framing_buffer[18],
-                    state.framing_buffer[19],
-                ]) as usize;
-                let header_checksum = [
-                    state.framing_buffer[20],
-                    state.framing_buffer[21],
-                    state.framing_buffer[22],
-                    state.framing_buffer[23],
-                ];
-                // Validate announced length to prevent unbounded accumulation or overflow
-                if length_le > dashcore::network::message::MAX_MSG_SIZE {
-                    return Err(NetworkError::ProtocolError(format!(
-                        "Declared payload length {} exceeds MAX_MSG_SIZE {}",
-                        length_le,
-                        dashcore::network::message::MAX_MSG_SIZE
-                    )));
-                }
-                let total_len = match HEADER_LEN.checked_add(length_le) {
-                    Some(v) => v,
-                    None => {
-                        return Err(NetworkError::ProtocolError(
-                            "Message length overflow".to_string(),
-                        ));
-                    }
-                };
-
-                // Ensure full frame available
-                if state.framing_buffer.len() < total_len {
-                    match Self::read_some(&mut state).await {
-                        Ok(0) => {
-                            tracing::info!("Peer {} closed connection (EOF)", self.address);
-                            return Err(NetworkError::PeerDisconnected);
-                        }
-                        Ok(_) => {}
-                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
-                            return Ok(None);
-                        }
-                        Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
-                            return Ok(None);
-                        }
-                        Err(e) => {
-                            return Err(NetworkError::ConnectionFailed(format!(
-                                "Read failed: {}",
-                                e
-                            )));
-                        }
-                    }
-                    continue;
-                }
-
-                // Verify checksum
-                let payload_slice = &state.framing_buffer[HEADER_LEN..total_len];
-                let expected = {
-                    let checksum = <dashcore::hashes::sha256d::Hash as dashcore::hashes::Hash>::hash(
-                        payload_slice,
-                    );
-                    [checksum[0], checksum[1], checksum[2], checksum[3]]
-                };
-                if expected != header_checksum {
-                    tracing::warn!(
-                        "Skipping message with invalid checksum from {}: expected {:02x?}, actual {:02x?}",
-                        self.address,
-                        expected,
-                        header_checksum
-                    );
-                    if header_checksum == [0, 0, 0, 0] {
-                        tracing::warn!(
-                            "All-zeros checksum detected from {}, likely corrupted stream - resyncing",
-                            self.address
-                        );
-                    }
-                    // Resync by dropping a byte and retrying
-                    state.framing_buffer.drain(0..1);
-                    resync_steps += 1;
-                    if resync_steps >= MAX_RESYNC_STEPS_PER_CALL {
-                        return Ok(None);
-                    }
-                    continue;
-                }
+            // Parse header fields
+            let length_le = u32::from_le_bytes([
+                state.framing_buffer[16],
+                state.framing_buffer[17],
+                state.framing_buffer[18],
+                state.framing_buffer[19],
+            ]) as usize;
+            let header_checksum = [
+                state.framing_buffer[20],
+                state.framing_buffer[21],
+                state.framing_buffer[22],
+                state.framing_buffer[23],
+            ];
+            // Validate announced length to prevent unbounded accumulation or overflow
+            if length_le >
dashcore::network::message::MAX_MSG_SIZE {
+                return Err(NetworkError::ProtocolError(format!(
+                    "Declared payload length {} exceeds MAX_MSG_SIZE {}",
+                    length_le,
+                    dashcore::network::message::MAX_MSG_SIZE
+                )));
+            }
+            let total_len = match HEADER_LEN.checked_add(length_le) {
+                Some(v) => v,
+                None => {
+                    return Err(NetworkError::ProtocolError("Message length overflow".to_string()));
+                }
+            };
+
+            // Ensure full frame available
+            if state.framing_buffer.len() < total_len {
+                match Self::read_some(&mut state).await {
+                    Ok(0) => {
+                        tracing::info!("Peer {} closed connection (EOF)", self.address);
+                        return Err(NetworkError::PeerDisconnected);
+                    }
+                    Ok(_) => {}
+                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+                        return Ok(None);
+                    }
+                    Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
+                        return Ok(None);
+                    }
+                    Err(e) => {
+                        return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e)));
+                    }
+                }
+                continue;
+            }
+
+            // Verify checksum
+            let payload_slice = &state.framing_buffer[HEADER_LEN..total_len];
+            let expected = {
+                let checksum =
+                    <dashcore::hashes::sha256d::Hash as dashcore::hashes::Hash>::hash(payload_slice);
+                [checksum[0], checksum[1], checksum[2], checksum[3]]
+            };
+            if expected != header_checksum {
+                tracing::warn!(
+                    "Skipping message with invalid checksum from {}: expected {:02x?}, actual {:02x?}",
+                    self.address,
+                    expected,
+                    header_checksum
+                );
+                if header_checksum == [0, 0, 0, 0] {
+                    tracing::warn!(
+                        "All-zeros checksum detected from {}, likely corrupted stream - resyncing",
+                        self.address
+                    );
+                }
+                // Resync by dropping a byte and retrying
+                state.framing_buffer.drain(0..1);
+                resync_steps += 1;
+                if resync_steps >= MAX_RESYNC_STEPS_PER_CALL {
+                    return Ok(None);
+                }
+                continue;
+            }
+
+            // Decode full RawNetworkMessage from the frame using existing decoder
+            let mut cursor = std::io::Cursor::new(&state.framing_buffer[..total_len]);
+            match RawNetworkMessage::consensus_decode(&mut cursor) {
+                Ok(raw_message) => {
+                    // Consume bytes
+                    state.framing_buffer.drain(0..total_len);
+
+                    // Validate magic matches our network
+                    if raw_message.magic != self.network.magic() {
+                        tracing::warn!(
+                            "Received message with wrong magic bytes: expected {:#x}, got {:#x}",
+                            self.network.magic(),
+                            raw_message.magic
+                        );
+                        return Err(NetworkError::ProtocolError(format!(
+                            "Wrong magic bytes: expected {:#x}, got {:#x}",
+                            self.network.magic(),
+                            raw_message.magic
+                        )));
+                    }
+
+                    tracing::trace!(
+                        "Successfully decoded message from {}: {:?}",
+                        self.address,
+                        raw_message.payload.cmd()
+                    );
+
+                    return Ok(Some(Message::new(self.address, raw_message.payload)));
+                }
+                Err(e) => {
+                    tracing::warn!(
+                        "{}: decode error after framing ({}), attempting resync",
+                        self.address,
+                        e
+                    );
+                    state.framing_buffer.drain(0..1);
+                    resync_steps += 1;
+                    if resync_steps >= MAX_RESYNC_STEPS_PER_CALL {
+                        return Ok(None);
+                    }
+                    continue;
+                }
+            }
+        }
 
-                // Decode full RawNetworkMessage from the frame using existing decoder
-                let mut cursor = std::io::Cursor::new(&state.framing_buffer[..total_len]);
-                match
RawNetworkMessage::consensus_decode(&mut cursor) {
-                    Ok(raw_message) => {
-                        // Consume bytes
-                        state.framing_buffer.drain(0..total_len);
-
-                        // Validate magic matches our network
-                        if raw_message.magic != self.network.magic() {
-                            tracing::warn!(
-                                "Received message with wrong magic bytes: expected {:#x}, got {:#x}",
-                                self.network.magic(),
-                                raw_message.magic
-                            );
-                            return Err(NetworkError::ProtocolError(format!(
-                                "Wrong magic bytes: expected {:#x}, got {:#x}",
-                                self.network.magic(),
-                                raw_message.magic
-                            )));
-                        }
-
-                        tracing::trace!(
-                            "Successfully decoded message from {}: {:?}",
-                            self.address,
-                            raw_message.payload.cmd()
-                        );
-
-                        return Ok(Some(Message::new(self.address, raw_message.payload)));
-                    }
-                    Err(e) => {
-                        tracing::warn!(
-                            "{}: decode error after framing ({}), attempting resync",
-                            self.address,
-                            e
-                        );
-                        state.framing_buffer.drain(0..1);
-                        resync_steps += 1;
-                        if resync_steps >= MAX_RESYNC_STEPS_PER_CALL {
-                            return Ok(None);
-                        }
-                        continue;
-                    }
-                }
-            }
-        }
-        .await;
-
-        // Drop the lock before disconnecting
-        drop(state);
+    }
 
-        // Handle disconnection if needed
-        if let Err(NetworkError::PeerDisconnected) = &result {
-            self.state = None;
-            self.connected_at = None;
-        }
+    /// Mark the peer disconnected, closing the outbound channel so the writer task winds down.
+    pub fn mark_disconnected(&mut self) {
+        self.tear_down_connection();
+    }
 
-        result
+    fn tear_down_connection(&mut self) {
+        self.read_state = None;
+        // Dropping `out_tx` closes the channel. The writer task drains any queued
+        // messages, flushes the socket, and exits cleanly.
+        self.out_tx = None;
+        self.connected_at = None;
     }
 
     /// Check if the connection is active.
     pub fn is_connected(&self) -> bool {
-        self.state.is_some()
+        self.read_state.is_some() && self.out_tx.is_some()
     }
 
     /// Check if connection appears healthy (not just connected).
@@ -681,7 +616,7 @@ impl Peer {
 
     /// Get connection statistics.
     pub fn stats(&self) -> (u64, u64) {
-        (self.bytes_sent, 0) // TODO: Track bytes received
+        (self.bytes_sent.load(Ordering::Relaxed), 0) // TODO: Track bytes received
     }
 
     /// Send a ping message with a random nonce.
@@ -701,7 +636,7 @@ impl Peer {
     }
 
     /// Handle a received ping message by sending a pong response.
-    pub async fn handle_ping(&mut self, nonce: u64) -> NetworkResult<()> {
+    pub async fn handle_ping(&self, nonce: u64) -> NetworkResult<()> {
         let pong_message = NetworkMessage::Pong(nonce);
 
         self.send_message(pong_message).await?;
@@ -828,6 +763,37 @@ impl Peer {
     }
 }
 
+fn spawn_writer_task(
+    address: SocketAddr,
+    magic: u32,
+    mut write_half: WriteHalf<TcpStream>,
+    mut out_rx: mpsc::Receiver<NetworkMessage>,
+    bytes_sent: Arc<AtomicU64>,
+) {
+    tokio::spawn(async move {
+        while let Some(message) = out_rx.recv().await {
+            let raw = RawNetworkMessage {
+                magic,
+                payload: message,
+            };
+            let serialized = encode::serialize(&raw);
+            if let Err(e) = write_half.write_all(&serialized).await {
+                tracing::warn!("Writer task for {} stopping after write error: {}", address, e);
+                break;
+            }
+            if let Err(e) = write_half.flush().await {
+                tracing::warn!("Writer task for {} stopping after flush error: {}", address, e);
+                break;
+            }
+            bytes_sent.fetch_add(serialized.len() as u64, Ordering::Relaxed);
+            tracing::debug!("Sent message to {}: {:?}", address, raw.payload);
+        }
+
+        // Best-effort close to release the kernel-side socket promptly.
+        let _ = write_half.shutdown().await;
+    });
+}
+
 #[cfg(test)]
 impl Peer {
     pub(crate) fn set_services(&mut self, flags: ServiceFlags) {
@@ -837,11 +803,50 @@ impl Peer {
 
 #[cfg(test)]
 mod tests {
+    use std::future;
     use std::net::SocketAddr;
     use std::time::{Duration, SystemTime};
 
+    use dashcore::consensus::encode;
+    use dashcore::network::message::{NetworkMessage, RawNetworkMessage};
+    use dashcore::Network;
+    use tokio::io::{AsyncReadExt, AsyncWriteExt};
+    use tokio::net::{TcpListener, TcpStream};
+
     use super::Peer;
 
+    /// Build a connected `Peer` paired with the server-side `TcpStream` for the test to drive.
+    async fn paired_peer() -> (Peer, TcpStream) {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let listen_addr = listener.local_addr().unwrap();
+        let accept = tokio::spawn(async move {
+            let (stream, _) = listener.accept().await.unwrap();
+            stream
+        });
+        let peer = Peer::connect(listen_addr, 5, Network::Regtest).await.unwrap();
+        let server_stream = accept.await.unwrap();
+        (peer, server_stream)
+    }
+
+    /// Block until `n` bytes have been read from `stream` or the timeout expires.
+    async fn read_exact_with_timeout(
+        stream: &mut TcpStream,
+        n: usize,
+        timeout: Duration,
+    ) -> Vec<u8> {
+        let mut buf = vec![0u8; n];
+        tokio::time::timeout(timeout, stream.read_exact(&mut buf)).await.unwrap().unwrap();
+        buf
+    }
+
+    /// Serialize a `NetworkMessage` into the framed wire bytes a peer would send.
+    fn frame(magic: u32, payload: NetworkMessage) -> Vec<u8> {
+        encode::serialize(&RawNetworkMessage {
+            magic,
+            payload,
+        })
+    }
+
     #[test]
     fn remove_expired_pings() {
         let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
@@ -871,4 +876,106 @@ mod tests {
         assert!(peer.remove_expired_pings());
         assert!(peer.pending_pings.is_empty());
     }
+
+    /// `send_message` must hit the wire even while a concurrent reader is parked inside
+    /// `receive_message` holding the read state mutex.
+    #[tokio::test]
+    async fn send_completes_while_reader_holds_read_state() {
+        let (peer, mut server) = paired_peer().await;
+
+        // Park a fake "reader" that owns the inner read state mutex but never reads bytes.
+        let read_state = peer.read_state.as_ref().cloned().unwrap();
+        let reader_handle = tokio::spawn(async move {
+            let _guard = read_state.lock().await;
+            future::pending::<()>().await
+        });
+
+        // Give the reader task a tick to actually acquire the mutex.
+        tokio::task::yield_now().await;
+
+        let payload = NetworkMessage::Ping(0xDEADBEEF);
+        let expected = frame(Network::Regtest.magic(), payload.clone());
+        peer.send_message(payload).await.unwrap();
+        let observed =
+            read_exact_with_timeout(&mut server, expected.len(), Duration::from_millis(500)).await;
+        assert_eq!(observed, expected);
+
+        reader_handle.abort();
+        drop(peer);
+    }
+
+    /// `receive_message` and `send_message` must be runnable concurrently on the same peer
+    /// without either blocking the other.
+    #[tokio::test]
+    async fn concurrent_send_and_receive() {
+        let (peer, mut server) = paired_peer().await;
+        let magic = Network::Regtest.magic();
+
+        let pong = frame(magic, NetworkMessage::Pong(424242));
+        server.write_all(&pong).await.unwrap();
+        server.flush().await.unwrap();
+
+        let send_fut = peer.send_message(NetworkMessage::Ping(7));
+        let recv_fut = peer.receive_message();
+        let (send_res, recv_res) = tokio::join!(send_fut, recv_fut);
+        send_res.unwrap();
+        let received = recv_res.unwrap().unwrap();
+        assert!(matches!(received.inner(), NetworkMessage::Pong(424242)));
+
+        let expected_ping = frame(magic, NetworkMessage::Ping(7));
+        let observed =
+            read_exact_with_timeout(&mut server, expected_ping.len(), Duration::from_secs(2)).await;
+        assert_eq!(observed, expected_ping);
+    }
+
+    /// All queued messages must reach the wire before the writer task winds down on
+    /// disconnect.
+    #[tokio::test]
+    async fn writer_task_drains_on_disconnect() {
+        let (mut peer, mut server) = paired_peer().await;
+        let magic = Network::Regtest.magic();
+
+        let nonces: Vec<u64> = (0..16).collect();
+        for nonce in &nonces {
+            peer.send_message(NetworkMessage::Ping(*nonce)).await.unwrap();
+        }
+
+        let _ = peer.disconnect().await;
+
+        let mut frame_size = 0;
+        for nonce in &nonces {
+            let expected = frame(magic, NetworkMessage::Ping(*nonce));
+            if frame_size == 0 {
+                frame_size = expected.len();
+            }
+            let observed =
+                read_exact_with_timeout(&mut server, frame_size, Duration::from_secs(2)).await;
+            assert_eq!(observed, expected);
+        }
+    }
+
+    /// When the peer side of the socket is closed, the writer task must surface a clean
+    /// disconnect signal back through `send_message` rather than panicking or stalling.
+    #[tokio::test]
+    async fn writer_task_signals_socket_close() {
+        let (peer, server) = paired_peer().await;
+        drop(server);
+
+        // The writer task may take a moment to observe the closed socket. Loop until
+        // `send_message` reports `PeerDisconnected` or we hit a generous timeout.
+        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
+        loop {
+            let result = peer.send_message(NetworkMessage::Ping(1)).await;
+            if let Err(super::NetworkError::PeerDisconnected) = result {
+                return;
+            }
+            if tokio::time::Instant::now() >= deadline {
+                panic!(
+                    "writer task did not signal disconnect within 2s, last result: {:?}",
+                    result
+                );
+            }
+            tokio::time::sleep(Duration::from_millis(20)).await;
+        }
+    }
 }
diff --git a/masternode-seeds-fetcher/src/probe.rs b/masternode-seeds-fetcher/src/probe.rs
index 49a1de9cc..38edd8542 100644
--- a/masternode-seeds-fetcher/src/probe.rs
+++ b/masternode-seeds-fetcher/src/probe.rs
@@ -54,7 +54,7 @@ pub async fn probe_core(
 }
 
 async fn probe_core_inner(peer_addr: SocketAddr, network: DashNetwork) -> Result> {
-    let mut peer = Peer::connect(peer_addr, CONNECT_TIMEOUT.as_secs(), network).await?;
+    let peer = Peer::connect(peer_addr, CONNECT_TIMEOUT.as_secs(), network).await?;
 
     let version = VersionMessage::new(
         ServiceFlags::NONE,
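---

Appendix (not part of the patch): a minimal, self-contained sketch of the
split-stream plus writer-task pattern the commit message describes, as
referenced above. It is illustrative only — raw byte frames stand in for the
Dash wire format, and `SketchPeer`, `spawn_writer`, and `recv_some` are
hypothetical names rather than dash-spv APIs. It assumes tokio with the
"full" feature enabled.

// Sketch of the pattern: split a TcpStream, hand the write half to a
// per-peer writer task, and turn sends into channel handoffs. Hypothetical
// names throughout; not the dash-spv implementation.
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex};

struct SketchPeer {
    // Reader-owned state: framing can take as long as it likes without
    // holding outbound traffic off the wire.
    read_half: Arc<Mutex<ReadHalf<TcpStream>>>,
    // Senders only touch the channel; the writer task owns the socket half.
    out_tx: mpsc::Sender<Vec<u8>>,
    bytes_sent: Arc<AtomicU64>,
}

impl SketchPeer {
    fn new(stream: TcpStream) -> Self {
        let (read_half, write_half) = tokio::io::split(stream);
        let (out_tx, out_rx) = mpsc::channel(256); // bounded: backpressure, not OOM
        let bytes_sent = Arc::new(AtomicU64::new(0));
        spawn_writer(write_half, out_rx, bytes_sent.clone());
        Self {
            read_half: Arc::new(Mutex::new(read_half)),
            out_tx,
            bytes_sent,
        }
    }

    // Queue a frame and return; fails only once the writer task has exited.
    async fn send(&self, frame: Vec<u8>) -> Result<(), &'static str> {
        self.out_tx.send(frame).await.map_err(|_| "peer disconnected")
    }

    // Read under the read-state mutex; never contends with `send`.
    async fn recv_some(&self) -> std::io::Result<Vec<u8>> {
        let mut half = self.read_half.lock().await;
        let mut buf = [0u8; 8192];
        let n = half.read(&mut buf).await?;
        Ok(buf[..n].to_vec())
    }
}

// The writer task drains the channel onto the socket. Dropping every sender
// closes the channel, so `recv` hands out whatever was queued and then
// returns None, letting the task flush and exit cleanly.
fn spawn_writer(
    mut write_half: WriteHalf<TcpStream>,
    mut out_rx: mpsc::Receiver<Vec<u8>>,
    bytes_sent: Arc<AtomicU64>,
) {
    tokio::spawn(async move {
        while let Some(frame) = out_rx.recv().await {
            if write_half.write_all(&frame).await.is_err() {
                break; // socket gone; senders observe a closed channel
            }
            bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
        }
        let _ = write_half.shutdown().await;
    });
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Loop a socket back to itself so the sketch runs end to end.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;
    let client = TcpStream::connect(addr).await?;
    let (server, _) = listener.accept().await?;

    let peer = SketchPeer::new(client);
    let echo = SketchPeer::new(server);

    peer.send(b"ping".to_vec()).await.unwrap();
    let got = echo.recv_some().await?;
    assert_eq!(got, b"ping");
    // The counter is bumped after the write lands, so it may lag by a tick.
    println!("bytes sent so far: {}", peer.bytes_sent.load(Ordering::Relaxed));
    Ok(())
}

The bounded channel is the load-bearing design choice here, mirroring
`OUTBOUND_CHANNEL_CAPACITY` in the patch: a slow socket exerts backpressure
on senders at `send` instead of growing an unbounded in-memory queue.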