diff --git a/src/control_message.rs b/src/control_message.rs index 66c4a28..c1fb291 100644 --- a/src/control_message.rs +++ b/src/control_message.rs @@ -1,8 +1,48 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, net::IpAddr}; use crate::socket::Timestamp; -pub(crate) const fn control_message_space() -> usize { +#[cfg(target_os = "linux")] +const SCM_TIMESTAMPING_CMSG_SIZE: usize = control_message_space::<[libc::timespec; 3]>(); +#[cfg(any(target_os = "linux", target_os = "freebsd"))] +const SCM_TIMESTAMP_NS_CMSG_SIZE: usize = control_message_space::(); +const SCM_TIMESTAMP_CMSG_SIZE: usize = control_message_space::(); +#[cfg(target_os = "linux")] +const RECEIVERR_CMSG_SIZE: usize = + control_message_space::<(libc::sock_extended_err, libc::sockaddr_storage)>(); +#[cfg(target_os = "linux")] +const IP_PKTINFO_CMSG_SIZE: usize = control_message_space::(); +#[cfg(any(target_os = "freebsd", target_os = "macos"))] +const IP_RECVDSTADDR_CMSG_SIZE: usize = control_message_space::(); +const IP6_PKTINFO_CMSG_SIZE: usize = control_message_space::(); + +// Utility needed since the ord trait max function is not usable in const environments +const fn max(a: usize, b: usize) -> usize { + if a > b { + a + } else { + b + } +} + +#[cfg(target_os = "linux")] +pub(crate) const EXPECTED_MAX_CMSG_SIZE: usize = + max( + max(SCM_TIMESTAMPING_CMSG_SIZE, SCM_TIMESTAMP_NS_CMSG_SIZE), + SCM_TIMESTAMP_CMSG_SIZE, + ) + max(IP_PKTINFO_CMSG_SIZE, IP6_PKTINFO_CMSG_SIZE) + + RECEIVERR_CMSG_SIZE; +#[cfg(target_os = "freebsd")] +pub(crate) const EXPECTED_MAX_CMSG_SIZE: usize = + max(SCM_TIMESTAMP_NS_CMSG_SIZE, SCM_TIMESTAMP_CMSG_SIZE) + + max(IP_RECVDSTADDR_CMSG_SIZE, IP6_PKTINFO_CMSG_SIZE); +#[cfg(target_os = "macos")] +pub(crate) const EXPECTED_MAX_CMSG_SIZE: usize = + SCM_TIMESTAMP_CMSG_SIZE + max(IP_RECVDSTADDR_CMSG_SIZE, IP6_PKTINFO_CMSG_SIZE); +#[cfg(not(any(target_os = "linux", target_os = "freebsd", target_os = "macos")))] +pub(crate) const EXPECTED_MAX_CMSG_SIZE: usize = SCM_TIMESTAMP_CMSG_SIZE + IP6_PKTINFO_CMSG_SIZE; + +const fn control_message_space() -> usize { // Safety: CMSG_SPACE is safe to call (unsafe { libc::CMSG_SPACE((std::mem::size_of::()) as _) }) as usize } @@ -66,6 +106,7 @@ pub(crate) enum ControlMessage { }, #[cfg(target_os = "linux")] ReceiveError(libc::sock_extended_err), + DestinationIp(IpAddr), Other(libc::cmsghdr), } @@ -168,6 +209,55 @@ impl Iterator for ControlMessageIterator<'_> { ControlMessage::ReceiveError(error) } + + #[cfg(target_os = "linux")] + (libc::SOL_IP, libc::IP_PKTINFO) => { + // Safety: + // current_msg was constructed from a pointer that pointed to a valid + // control message. + // IP_PKTINFO always has a in_pktinfo in the data + let pktinfo = unsafe { + let ptr = libc::CMSG_DATA(current_msg) as *const libc::in_pktinfo; + std::ptr::read_unaligned(ptr) + }; + + ControlMessage::DestinationIp( + std::net::Ipv4Addr::from_bits(u32::from_be(pktinfo.ipi_addr.s_addr)).into(), + ) + } + + #[cfg(any(target_os = "freebsd", target_os = "macos"))] + (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => { + // Safety: + // current_msg was constructed from a pointer that pointed to a valid + // control message. + // IP_RECVDSTADDR always has a in_addr in the data + let in_addr = unsafe { + let ptr = libc::CMSG_DATA(current_msg) as *const libc::in_addr; + std::ptr::read_unaligned(ptr) + }; + + ControlMessage::DestinationIp( + std::net::Ipv4Addr::from_bits(u32::from_be(in_addr.s_addr)).into(), + ) + } + + (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { + // Safety: + // current_msg was constructed from a pointer that pointed to a valid + // control message. + // IPV6_PKTINFO always has a in6_pktinfo in the data + let pktinfo = unsafe { + let ptr = libc::CMSG_DATA(current_msg) as *const libc::in6_pktinfo; + std::ptr::read_unaligned(ptr) + }; + + ControlMessage::DestinationIp( + std::net::Ipv6Addr::from_bits(u128::from_be_bytes(pktinfo.ipi6_addr.s6_addr)) + .into(), + ) + } + _ => ControlMessage::Other(*current_msg), }) } diff --git a/src/networkaddress.rs b/src/networkaddress.rs index 741130b..d920bb0 100644 --- a/src/networkaddress.rs +++ b/src/networkaddress.rs @@ -1,5 +1,5 @@ use std::{ - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::fd::RawFd, }; @@ -24,11 +24,15 @@ pub(crate) mod sealed { pub struct PrivateToken; } -pub trait NetworkAddress: Sized + SealedNA { +pub trait NetworkAddress: Copy + Sized + SealedNA { #[doc(hidden)] fn to_sockaddr(&self, _token: PrivateToken) -> libc::sockaddr_storage; #[doc(hidden)] fn from_sockaddr(addr: libc::sockaddr_storage, _token: PrivateToken) -> Option; + #[doc(hidden)] + fn from_ip_and_port(addr: IpAddr, port: u16) -> Option; + #[doc(hidden)] + fn port(&self) -> u16; } pub trait MulticastJoinable: NetworkAddress + SealedMC { @@ -98,6 +102,17 @@ impl NetworkAddress for SocketAddrV4 { u16::from_be_bytes(input.sin_port.to_ne_bytes()), )) } + + fn from_ip_and_port(addr: IpAddr, port: u16) -> Option { + match addr { + IpAddr::V4(addr) => Some(SocketAddrV4::new(addr, port)), + IpAddr::V6(_) => None, + } + } + + fn port(&self) -> u16 { + self.port() + } } impl SealedNA for SocketAddrV6 {} @@ -154,6 +169,17 @@ impl NetworkAddress for SocketAddrV6 { input.sin6_scope_id, )) } + + fn from_ip_and_port(addr: IpAddr, port: u16) -> Option { + match addr { + IpAddr::V4(_) => None, + IpAddr::V6(addr) => Some(SocketAddrV6::new(addr, port, 0, 0)), + } + } + + fn port(&self) -> u16 { + self.port() + } } impl SealedNA for SocketAddr {} @@ -179,4 +205,12 @@ impl NetworkAddress for SocketAddr { _ => None, } } + + fn from_ip_and_port(addr: IpAddr, port: u16) -> Option { + Some(SocketAddr::new(addr, port)) + } + + fn port(&self) -> u16 { + self.port() + } } diff --git a/src/networkaddress/linux.rs b/src/networkaddress/linux.rs index 18752af..8c45d6e 100644 --- a/src/networkaddress/linux.rs +++ b/src/networkaddress/linux.rs @@ -120,6 +120,15 @@ impl NetworkAddress for EthernetAddress { input.sll_ifindex, )) } + + fn from_ip_and_port(_addr: std::net::IpAddr, _port: u16) -> Option { + None + } + + fn port(&self) -> u16 { + // Ethernet doesn't have a port, zero is a decent sentinal value to cover that. + 0 + } } impl SealedMC for EthernetAddress {} diff --git a/src/raw_socket.rs b/src/raw_socket.rs index 847dd0f..ab5722d 100644 --- a/src/raw_socket.rs +++ b/src/raw_socket.rs @@ -1,9 +1,11 @@ use std::{ io::IoSliceMut, + mem::transmute, os::fd::{AsRawFd, RawFd}, + ptr::write_unaligned, }; -use libc::{c_void, sockaddr, sockaddr_storage}; +use libc::{c_void, in6_addr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_storage}; use crate::{ cerr, @@ -12,6 +14,10 @@ use crate::{ }, }; +#[cfg(any(target_os = "macos", target_os = "freebsd"))] +mod bsdlike; +#[cfg(not(any(target_os = "macos", target_os = "freebsd", target_os = "linux")))] +mod fallback; #[cfg(target_os = "freebsd")] mod freebsd; #[cfg(target_os = "linux")] @@ -43,6 +49,25 @@ impl RawSocket { }) } + pub(crate) fn enable_destination_ipv6(&self) -> std::io::Result<()> { + // SAFETY: + // + // - the socket is provided by (safe) rust, and will outlive the call + // - method is guaranteed to be a valid "name" argument + // - the options pointer outlives the call + // - the `option_len` corresponds with the options pointer + unsafe { + cerr(libc::setsockopt( + self.fd, + libc::IPPROTO_IPV6, + libc::IPV6_RECVPKTINFO, + &(1 as libc::c_int) as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ))?; + } + Ok(()) + } + pub(crate) fn bind(&self, addr: sockaddr_storage) -> std::io::Result<()> { // Per posix, it may be invalid to specify a length larger than that of the family. let len = sockaddr_len(addr); @@ -195,6 +220,108 @@ impl RawSocket { Ok(()) } + pub(crate) fn send_from_to( + &self, + msg: &[u8], + from: sockaddr_storage, + to: sockaddr_storage, + ) -> std::io::Result<()> { + match from.ss_family as libc::c_int { + libc::AF_INET => { + // Safety: + // Transmuting &sockaddr_storage into another sockaddr reference type is safe, and in this case the lifetimes work out. + let from = unsafe { transmute::<&sockaddr_storage, &sockaddr_in>(&from) }; + self.send_from_to_v4(msg, from.sin_addr, to) + } + libc::AF_INET6 => { + // Safety: + // Transmuting &sockaddr_storage into another sockaddr reference type is safe, and in this case the lifetimes work out. + let from = unsafe { transmute::<&sockaddr_storage, &sockaddr_in6>(&from) }; + self.send_from_to_v6(msg, from.sin6_addr, to) + } + _ => Err(std::io::ErrorKind::InvalidInput.into()), + } + } + + pub(crate) fn send_from(&self, msg: &[u8], addr: sockaddr_storage) -> std::io::Result<()> { + match addr.ss_family as libc::c_int { + libc::AF_INET => { + // Safety: + // Transmuting &sockaddr_storage into another sockaddr reference type is safe, and in this case the lifetimes work out. + let from = unsafe { transmute::<&sockaddr_storage, &sockaddr_in>(&addr) }; + self.send_from_v4(msg, from.sin_addr) + } + libc::AF_INET6 => { + // Safety: + // Transmuting &sockaddr_storage into another sockaddr reference type is safe, and in this case the lifetimes work out. + let from = unsafe { transmute::<&sockaddr_storage, &sockaddr_in6>(&addr) }; + self.send_from_v6(msg, from.sin6_addr) + } + _ => Err(std::io::ErrorKind::InvalidInput.into()), + } + } + + pub(crate) fn send_from_v6(&self, msg: &[u8], addr: in6_addr) -> std::io::Result<()> { + let control_message = control_message( + libc::IPPROTO_IPV6, + libc::IPV6_PKTINFO, + libc::in6_pktinfo { + ipi6_addr: addr, + ipi6_ifindex: 0, + }, + ); + + let mut iov = libc::iovec { + iov_base: msg.as_ptr() as *mut libc::c_void, + iov_len: msg.len(), + }; + + let mut msghdr = empty_msghdr(); + msghdr.msg_iov = &raw mut iov; + msghdr.msg_iovlen = 1; + msghdr.msg_control = control_message.as_ptr() as *mut _; + msghdr.msg_controllen = control_message.len() as _; + + // Safety: + // msghdr is valid. + cerr(unsafe { libc::sendmsg(self.fd, &raw const msghdr, 0) } as _).map(|_| {}) + } + + pub(crate) fn send_from_to_v6( + &self, + msg: &[u8], + from: in6_addr, + to: sockaddr_storage, + ) -> std::io::Result<()> { + let to_len = sockaddr_len(to); + + let control_message = control_message( + libc::IPPROTO_IPV6, + libc::IPV6_PKTINFO, + libc::in6_pktinfo { + ipi6_addr: from, + ipi6_ifindex: 0, + }, + ); + + let mut iov = libc::iovec { + iov_base: msg.as_ptr() as *mut libc::c_void, + iov_len: msg.len(), + }; + + let mut msghdr = empty_msghdr(); + msghdr.msg_name = &raw const to as *mut _; + msghdr.msg_namelen = to_len; + msghdr.msg_iov = &raw mut iov; + msghdr.msg_iovlen = 1; + msghdr.msg_control = control_message.as_ptr() as *mut _; + msghdr.msg_controllen = control_message.len() as _; + + // Safety: + // msghdr is valid. + cerr(unsafe { libc::sendmsg(self.fd, &raw const msghdr, 0) } as _).map(|_| {}) + } + pub(crate) fn getsockname(&self) -> std::io::Result { let mut addr = zeroed_sockaddr_storage(); let mut addr_len: libc::socklen_t = std::mem::size_of_val(&addr) as _; @@ -240,6 +367,39 @@ fn sockaddr_len(addr: sockaddr_storage) -> u32 { }) } +// Generate a control message with T as its contents +// Guarantees that the resulting vec contains valid control messages. +fn control_message(level: libc::c_int, type_: libc::c_int, content: T) -> Vec { + // Safety: + // libc::CMSG_SPACE is always safe to call. + let mut control_message = vec![0u8; unsafe { libc::CMSG_SPACE(size_of::() as _) } as _]; + + // Safety: + // libc::CMSG_LEN is always safe to call. + let header = libc::cmsghdr { + cmsg_len: unsafe { libc::CMSG_LEN(size_of::() as _) } as _, + cmsg_level: level, + cmsg_type: type_, + }; + // Safety: + // libc::CMSG_SPACE ensures we have sufficient space for the control message header. + unsafe { write_unaligned(control_message.as_mut_ptr() as *mut libc::cmsghdr, header) }; + + // Safety: + // libc::CMSG_SPACE ensures we have sufficient space for the control message contents. + // libc::CMSG_DATA ensures we write that content at a valid offset. + // libc::CMSG_DATA provides a valid pointer to the contents of a control message when provided + // with a valid pointer to a control message header, which we have in the buffer. + unsafe { + write_unaligned( + libc::CMSG_DATA(control_message.as_mut_ptr() as *mut libc::cmsghdr) as *mut T, + content, + ) + }; + + control_message +} + impl Drop for RawSocket { fn drop(&mut self) { // Safety: close is always safe to call on a file descriptor diff --git a/src/raw_socket/bsdlike.rs b/src/raw_socket/bsdlike.rs new file mode 100644 index 0000000..e171c5b --- /dev/null +++ b/src/raw_socket/bsdlike.rs @@ -0,0 +1,60 @@ +use libc::{in_addr, sockaddr_storage}; + +use crate::{cerr, control_message::empty_msghdr, raw_socket::sockaddr_len}; + +use super::{control_message, RawSocket}; + +impl RawSocket { + pub(crate) fn enable_destination_ipv4(&self) -> std::io::Result<()> { + // SAFETY: + // + // - the socket is provided by (safe) rust, and will outlive the call + // - method is guaranteed to be a valid "name" argument + // - the options pointer outlives the call + // - the `option_len` corresponds with the options pointer + unsafe { + cerr(libc::setsockopt( + self.fd, + libc::IPPROTO_IP, + libc::IP_RECVDSTADDR, + &(1 as libc::c_int) as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ))?; + } + Ok(()) + } + + pub(crate) fn send_from_v4(&self, msg: &[u8], _addr: in_addr) -> std::io::Result<()> { + // FreeBSD and similar don't support setting an IPv4 source address + // on connected sockets. + self.send(msg) + } + + pub(crate) fn send_from_to_v4( + &self, + msg: &[u8], + from: in_addr, + to: sockaddr_storage, + ) -> std::io::Result<()> { + let to_len = sockaddr_len(to); + + let control_message = control_message(libc::IPPROTO_IP, libc::IP_SENDSRCADDR, from); + + let mut iov = libc::iovec { + iov_base: msg.as_ptr() as *mut libc::c_void, + iov_len: msg.len(), + }; + + let mut msghdr = empty_msghdr(); + msghdr.msg_name = &raw const to as *mut _; + msghdr.msg_namelen = to_len; + msghdr.msg_iov = &raw mut iov; + msghdr.msg_iovlen = 1; + msghdr.msg_control = control_message.as_ptr() as *mut _; + msghdr.msg_controllen = control_message.len() as _; + + // Safety: + // msghdr is valid. + cerr(unsafe { libc::sendmsg(self.fd, &raw const msghdr, 0) } as _).map(|_| {}) + } +} diff --git a/src/raw_socket/fallback.rs b/src/raw_socket/fallback.rs new file mode 100644 index 0000000..960985b --- /dev/null +++ b/src/raw_socket/fallback.rs @@ -0,0 +1,25 @@ +use libc::{in_addr, sockaddr_storage}; + +use super::RawSocket; + +impl RawSocket { + pub(crate) fn enable_destination_ipv4(&self) -> std::io::Result<()> { + // Noop, fallback to local address. + Ok(()) + } + + pub(crate) fn send_from_v4(&self, msg: &[u8], addr: in_addr) -> std::io::Result<()> { + // Fallback, ignore the from + self.send(msg) + } + + pub(crate) fn send_from_to_v4( + &self, + msg: &[u8], + from: in_addr, + to: sockaddr_storage, + ) -> std::io::Result<()> { + // Fallback, ignore the from + self.send_to(msg, to) + } +} diff --git a/src/raw_socket/linux.rs b/src/raw_socket/linux.rs index 2fe83ac..bbed6ab 100644 --- a/src/raw_socket/linux.rs +++ b/src/raw_socket/linux.rs @@ -1,8 +1,12 @@ use std::net::Ipv4Addr; -use crate::{cerr, interface::InterfaceName}; +use libc::{in_addr, sockaddr_storage}; -use super::RawSocket; +use crate::{ + cerr, control_message::empty_msghdr, interface::InterfaceName, raw_socket::sockaddr_len, +}; + +use super::{control_message, RawSocket}; #[repr(C)] struct SoTimestamping { @@ -11,6 +15,25 @@ struct SoTimestamping { } impl RawSocket { + pub(crate) fn enable_destination_ipv4(&self) -> std::io::Result<()> { + // SAFETY: + // + // - the socket is provided by (safe) rust, and will outlive the call + // - method is guaranteed to be a valid "name" argument + // - the options pointer outlives the call + // - the `option_len` corresponds with the options pointer + unsafe { + cerr(libc::setsockopt( + self.fd, + libc::IPPROTO_IP, + libc::IP_PKTINFO, + &(1 as libc::c_int) as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ))?; + } + Ok(()) + } + pub(crate) fn so_timestamping(&self, options: u32, bind_phc: u32) -> std::io::Result<()> { // Documentation on the timestamping calls: // @@ -190,4 +213,67 @@ impl RawSocket { })?; Ok(()) } + + pub(crate) fn send_from_v4(&self, msg: &[u8], addr: in_addr) -> std::io::Result<()> { + let control_message = control_message( + libc::IPPROTO_IP, + libc::IP_PKTINFO, + libc::in_pktinfo { + ipi_ifindex: 0, + ipi_spec_dst: addr, + ipi_addr: libc::in_addr { s_addr: 0 }, + }, + ); + + let mut iov = libc::iovec { + iov_base: msg.as_ptr() as *mut libc::c_void, + iov_len: msg.len(), + }; + + let mut msghdr = empty_msghdr(); + msghdr.msg_iov = &raw mut iov; + msghdr.msg_iovlen = 1; + msghdr.msg_control = control_message.as_ptr() as *mut _; + msghdr.msg_controllen = control_message.len() as _; + + // Safety: + // msghdr is valid. + cerr(unsafe { libc::sendmsg(self.fd, &raw const msghdr, 0) } as _).map(|_| {}) + } + + pub(crate) fn send_from_to_v4( + &self, + msg: &[u8], + from: in_addr, + to: sockaddr_storage, + ) -> std::io::Result<()> { + let to_len = sockaddr_len(to); + + let control_message = control_message( + libc::IPPROTO_IP, + libc::IP_PKTINFO, + libc::in_pktinfo { + ipi_ifindex: 0, + ipi_spec_dst: from, + ipi_addr: libc::in_addr { s_addr: 0 }, + }, + ); + + let mut iov = libc::iovec { + iov_base: msg.as_ptr() as *mut libc::c_void, + iov_len: msg.len(), + }; + + let mut msghdr = empty_msghdr(); + msghdr.msg_name = &raw const to as *mut _; + msghdr.msg_namelen = to_len; + msghdr.msg_iov = &raw mut iov; + msghdr.msg_iovlen = 1; + msghdr.msg_control = control_message.as_ptr() as *mut _; + msghdr.msg_controllen = control_message.len() as _; + + // Safety: + // msghdr is valid. + cerr(unsafe { libc::sendmsg(self.fd, &raw const msghdr, 0) } as _).map(|_| {}) + } } diff --git a/src/socket.rs b/src/socket.rs index 3ffb5a0..a0f4868 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -3,7 +3,7 @@ use std::{marker::PhantomData, net::SocketAddr, os::fd::AsRawFd}; use tokio::io::{unix::AsyncFd, Interest}; use crate::{ - control_message::{control_message_space, ControlMessage, MessageQueue}, + control_message::{ControlMessage, MessageQueue, EXPECTED_MAX_CMSG_SIZE}, interface::InterfaceName, networkaddress::{sealed::PrivateToken, MulticastJoinable, NetworkAddress}, raw_socket::RawSocket, @@ -98,6 +98,7 @@ fn select_timestamp( pub struct RecvResult { pub bytes_read: usize, pub remote_addr: A, + pub local_addr: A, pub timestamp: Option, } @@ -107,34 +108,31 @@ pub struct Socket { socket: AsyncFd, #[cfg(target_os = "linux")] send_counter: u32, - _addr: PhantomData, + local_addr: A, _state: PhantomData, } +#[non_exhaustive] pub struct Open; +#[non_exhaustive] pub struct Connected; impl Socket { - pub fn local_addr(&self) -> std::io::Result { - let addr = self.socket.get_ref().getsockname()?; - A::from_sockaddr(addr, PrivateToken).ok_or_else(|| std::io::ErrorKind::Other.into()) - } - - pub fn peer_addr(&self) -> std::io::Result { - let addr = self.socket.get_ref().getpeername()?; - A::from_sockaddr(addr, PrivateToken).ok_or_else(|| std::io::ErrorKind::Other.into()) + pub fn local_addr(&self) -> A { + self.local_addr } pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result> { self.socket .async_io(Interest::READABLE, |socket| { - let mut control_buf = [0; control_message_space::<[libc::timespec; 3]>()]; + let mut control_buf = [0; EXPECTED_MAX_CMSG_SIZE]; // loops for when we receive an interrupt during the recv let (bytes_read, control_messages, remote_address) = socket.receive_message(buf, &mut control_buf, MessageQueue::Normal)?; let mut timestamp = None; + let mut local_addr = self.local_addr; // Loops through the control messages, but we should only get a single message // in practice @@ -153,6 +151,12 @@ impl Socket { ); } + ControlMessage::DestinationIp(addr) => { + if let Some(addr) = A::from_ip_and_port(addr, self.local_addr.port()) { + local_addr = addr; + } + } + ControlMessage::Other(msg) => { tracing::debug!( "unexpected control message on receive: {} {}", @@ -169,6 +173,7 @@ impl Socket { Ok(RecvResult { bytes_read, remote_addr, + local_addr, timestamp, }) }) @@ -204,6 +209,41 @@ impl Socket { } } + pub async fn send_from_to( + &mut self, + buf: &[u8], + from: A, + to: A, + ) -> std::io::Result> { + let from = from.to_sockaddr(PrivateToken); + let to = to.to_sockaddr(PrivateToken); + + self.socket + .async_io(Interest::WRITABLE, |socket| { + socket.send_from_to(buf, from, to) + }) + .await?; + + if matches!( + self.timestamp_mode, + InterfaceTimestampMode::HardwarePTPAll | InterfaceTimestampMode::SoftwareAll + ) { + #[cfg(target_os = "linux")] + { + let expected_counter = self.send_counter; + self.send_counter = self.send_counter.wrapping_add(1); + self.fetch_send_timestamp(expected_counter).await + } + + #[cfg(not(target_os = "linux"))] + { + unreachable!("Should not be able to create send timestamping sockets on platforms other than linux") + } + } else { + Ok(None) + } + } + pub fn connect(self, addr: A) -> std::io::Result> { let addr = addr.to_sockaddr(PrivateToken); self.socket.get_ref().connect(addr)?; @@ -212,13 +252,18 @@ impl Socket { socket: self.socket, #[cfg(target_os = "linux")] send_counter: self.send_counter, - _addr: PhantomData, + local_addr: self.local_addr, _state: PhantomData, }) } } impl Socket { + pub fn peer_addr(&self) -> std::io::Result { + let addr = self.socket.get_ref().getpeername()?; + A::from_sockaddr(addr, PrivateToken).ok_or_else(|| std::io::ErrorKind::Other.into()) + } + pub async fn send(&mut self, buf: &[u8]) -> std::io::Result> { self.socket .async_io(Interest::WRITABLE, |socket| socket.send(buf)) @@ -243,6 +288,32 @@ impl Socket { Ok(None) } } + + pub async fn send_from(&mut self, buf: &[u8], from: A) -> std::io::Result> { + let from = from.to_sockaddr(PrivateToken); + self.socket + .async_io(Interest::WRITABLE, |socket| socket.send_from(buf, from)) + .await?; + + if matches!( + self.timestamp_mode, + InterfaceTimestampMode::HardwarePTPAll | InterfaceTimestampMode::SoftwareAll + ) { + #[cfg(target_os = "linux")] + { + let expected_counter = self.send_counter; + self.send_counter = self.send_counter.wrapping_add(1); + self.fetch_send_timestamp(expected_counter).await + } + + #[cfg(not(target_os = "linux"))] + { + unreachable!("Should not be able to create send timestamping sockets on platforms other than linux") + } + } else { + Ok(None) + } + } } impl Socket { @@ -264,16 +335,23 @@ pub fn open_ip( SocketAddr::V4(_) => RawSocket::open(libc::PF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP), SocketAddr::V6(_) => RawSocket::open(libc::PF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP), }?; + match addr { + SocketAddr::V4(_) => socket.enable_destination_ipv4()?, + SocketAddr::V6(_) => socket.enable_destination_ipv6()?, + } socket.bind(addr.to_sockaddr(PrivateToken))?; socket.set_nonblocking(true)?; configure_timestamping(&socket, None, timestamping.into(), None)?; + let local_addr = SocketAddr::from_sockaddr(socket.getsockname()?, PrivateToken) + .ok_or::(std::io::ErrorKind::Other.into())?; + Ok(Socket { timestamp_mode: timestamping.into(), socket: AsyncFd::new(socket)?, #[cfg(target_os = "linux")] send_counter: 0, - _addr: PhantomData, + local_addr, _state: PhantomData, }) } @@ -287,16 +365,23 @@ pub fn connect_address( SocketAddr::V4(_) => RawSocket::open(libc::PF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP), SocketAddr::V6(_) => RawSocket::open(libc::PF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP), }?; + match addr { + SocketAddr::V4(_) => socket.enable_destination_ipv4()?, + SocketAddr::V6(_) => socket.enable_destination_ipv6()?, + } socket.connect(addr.to_sockaddr(PrivateToken))?; socket.set_nonblocking(true)?; configure_timestamping(&socket, None, timestamping.into(), None)?; + let local_addr = SocketAddr::from_sockaddr(socket.getsockname()?, PrivateToken) + .ok_or::(std::io::ErrorKind::Other.into())?; + Ok(Socket { timestamp_mode: timestamping.into(), socket: AsyncFd::new(socket)?, #[cfg(target_os = "linux")] send_counter: 0, - _addr: PhantomData, + local_addr, _state: PhantomData, }) } @@ -304,7 +389,7 @@ pub fn connect_address( #[cfg(test)] mod tests { use super::*; - use std::net::{IpAddr, Ipv4Addr}; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; #[tokio::test] async fn test_open_ip() { @@ -328,4 +413,137 @@ mod tests { assert_eq!(recv_result.bytes_read, 3); assert_eq!(&buf[0..3], &[4, 5, 6]); } + + #[tokio::test] + async fn test_open_ip_dest_addr() { + let a = open_ip( + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 5127), + GeneralTimestampMode::None, + ) + .unwrap(); + let mut b = connect_address( + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5127), + GeneralTimestampMode::None, + ) + .unwrap(); + assert!(b.send(&[1, 2, 3]).await.is_ok()); + let mut buf = [0; 4]; + let recv_result = a.recv(&mut buf).await.unwrap(); + assert_eq!(recv_result.bytes_read, 3); + assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.local_addr, + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5127) + ); + assert_ne!(a.local_addr().ip(), IpAddr::V4(Ipv4Addr::LOCALHOST)); + + let a = open_ip( + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5129), + GeneralTimestampMode::None, + ) + .unwrap(); + let mut b = connect_address( + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5129), + GeneralTimestampMode::None, + ) + .unwrap(); + assert!(b.send(&[1, 2, 3]).await.is_ok()); + let mut buf = [0; 4]; + let recv_result = a.recv(&mut buf).await.unwrap(); + assert_eq!(recv_result.bytes_read, 3); + assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.local_addr, + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5129) + ); + assert_ne!(a.local_addr().ip(), IpAddr::V6(Ipv6Addr::LOCALHOST)); + } + + #[tokio::test] + async fn test_send_from() { + let mut a = open_ip( + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 5130), + GeneralTimestampMode::None, + ) + .unwrap(); + let mut b = connect_address( + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5130), + GeneralTimestampMode::None, + ) + .unwrap(); + b.send_from( + &[1, 2, 3], + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + ) + .await + .unwrap(); + let mut buf = [0; 4]; + let recv_result = a.recv(&mut buf).await.unwrap(); + assert_eq!(recv_result.bytes_read, 3); + assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.remote_addr.ip(), + IpAddr::V4(Ipv4Addr::LOCALHOST) + ); + + a.send_from_to( + &[1, 2, 3], + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + dbg!(b.local_addr()), + ) + .await + .unwrap(); + let mut buf = [0; 4]; + let recv_result = b.recv(&mut buf).await.unwrap(); + assert_eq!(recv_result.bytes_read, 3); + assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.remote_addr.ip(), + IpAddr::V4(Ipv4Addr::LOCALHOST) + ); + } + + #[tokio::test] + async fn test_send_from_v6() { + let mut a = open_ip( + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5131), + GeneralTimestampMode::None, + ) + .unwrap(); + let mut b = connect_address( + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5131), + GeneralTimestampMode::None, + ) + .unwrap(); + b.send_from( + &[1, 2, 3], + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), + ) + .await + .unwrap(); + let mut buf = [0; 4]; + let recv_result = a.recv(&mut buf).await.unwrap(); + assert_eq!(recv_result.bytes_read, 3); + assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.remote_addr.ip(), + IpAddr::V6(Ipv6Addr::LOCALHOST) + ); + + a.send_from_to( + &[1, 2, 3], + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), + dbg!(b.local_addr()), + ) + .await + .unwrap(); + let mut buf = [0; 4]; + let recv_result = b.recv(&mut buf).await.unwrap(); + assert_eq!(recv_result.bytes_read, 3); + assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.remote_addr.ip(), + IpAddr::V6(Ipv6Addr::LOCALHOST) + ); + } } diff --git a/src/socket/linux.rs b/src/socket/linux.rs index 9d3ea61..a4c1c73 100644 --- a/src/socket/linux.rs +++ b/src/socket/linux.rs @@ -6,7 +6,7 @@ use std::{ use tokio::io::{unix::AsyncFd, Interest}; use crate::{ - control_message::{control_message_space, ControlMessage, MessageQueue}, + control_message::{ControlMessage, MessageQueue, EXPECTED_MAX_CMSG_SIZE}, interface::{lookup_phc, InterfaceName}, networkaddress::{sealed::PrivateToken, EthernetAddress, MacAddress, NetworkAddress}, raw_socket::RawSocket, @@ -52,10 +52,7 @@ impl Socket { &self, expected_counter: u32, ) -> std::io::Result> { - const CONTROL_SIZE: usize = control_message_space::<[libc::timespec; 3]>() - + control_message_space::<(libc::sock_extended_err, libc::sockaddr_storage)>(); - - let mut control_buf = [0; CONTROL_SIZE]; + let mut control_buf = [0; EXPECTED_MAX_CMSG_SIZE]; // NOTE: this read could block! let (_, control_messages, _) = self.socket.get_ref().receive_message( @@ -93,6 +90,10 @@ impl Socket { } } + ControlMessage::DestinationIp(_) => { + tracing::debug!("unexpected destination ip control message"); + } + ControlMessage::Other(msg) => { tracing::debug!( msg.cmsg_level, @@ -163,6 +164,8 @@ pub fn open_interface_udp( ) -> std::io::Result> { // Setup the socket let socket = RawSocket::open(libc::PF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP)?; + socket.enable_destination_ipv4()?; + socket.enable_destination_ipv6()?; socket.reuse_addr()?; socket.ipv6_v6only(false)?; socket.bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).to_sockaddr(PrivateToken))?; @@ -185,11 +188,14 @@ pub fn open_interface_udp( } socket.set_nonblocking(true)?; + let local_addr = SocketAddr::from_sockaddr(socket.getsockname()?, PrivateToken) + .ok_or::(std::io::ErrorKind::Other.into())?; + Ok(Socket { timestamp_mode: timestamping, socket: AsyncFd::new(socket)?, send_counter: 0, - _addr: PhantomData, + local_addr, _state: PhantomData, }) } @@ -202,6 +208,7 @@ pub fn open_interface_udp4( ) -> std::io::Result> { // Setup the socket let socket = RawSocket::open(libc::PF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP)?; + socket.enable_destination_ipv4()?; socket.reuse_addr()?; socket.bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).to_sockaddr(PrivateToken))?; socket.bind_to_device(interface)?; @@ -223,11 +230,14 @@ pub fn open_interface_udp4( } socket.set_nonblocking(true)?; + let local_addr = SocketAddrV4::from_sockaddr(socket.getsockname()?, PrivateToken) + .ok_or::(std::io::ErrorKind::Other.into())?; + Ok(Socket { timestamp_mode: timestamping, socket: AsyncFd::new(socket)?, send_counter: 0, - _addr: PhantomData, + local_addr, _state: PhantomData, }) } @@ -240,6 +250,7 @@ pub fn open_interface_udp6( ) -> std::io::Result> { // Setup the socket let socket = RawSocket::open(libc::PF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP)?; + socket.enable_destination_ipv6()?; socket.reuse_addr()?; socket.ipv6_v6only(true)?; socket.bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).to_sockaddr(PrivateToken))?; @@ -262,11 +273,14 @@ pub fn open_interface_udp6( } socket.set_nonblocking(true)?; + let local_addr = SocketAddrV6::from_sockaddr(socket.getsockname()?, PrivateToken) + .ok_or::(std::io::ErrorKind::Other.into())?; + Ok(Socket { timestamp_mode: timestamping, socket: AsyncFd::new(socket)?, send_counter: 0, - _addr: PhantomData, + local_addr, _state: PhantomData, }) } @@ -308,11 +322,14 @@ pub fn open_interface_ethernet( } socket.set_nonblocking(true)?; + let local_addr = EthernetAddress::from_sockaddr(socket.getsockname()?, PrivateToken) + .ok_or::(std::io::ErrorKind::Other.into())?; + Ok(Socket { timestamp_mode: timestamping, socket: AsyncFd::new(socket)?, send_counter: 0, - _addr: PhantomData, + local_addr, _state: PhantomData, }) } @@ -325,6 +342,48 @@ mod tests { use super::*; + #[tokio::test] + async fn test_open_udp() { + use std::str::FromStr; + let a = open_interface_udp( + InterfaceName::from_str("lo").unwrap(), + 5128, + super::InterfaceTimestampMode::None, + None, + ) + .unwrap(); + + let mut b = connect_address( + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5128), + GeneralTimestampMode::None, + ) + .unwrap(); + assert!(b.send(&[1, 2, 3]).await.is_ok()); + let mut buf = [0; 4]; + let recv_result = a.recv(&mut buf).await.unwrap(); + assert_eq!(recv_result.bytes_read, 3); + assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.local_addr, + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5128) + ); + + let mut b = connect_address( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 1)), 5128), + GeneralTimestampMode::None, + ) + .unwrap(); + assert!(b.send(&[1, 2, 3]).await.is_ok()); + let mut buf = [0; 4]; + let recv_result = a.recv(&mut buf).await.unwrap(); + assert_eq!(recv_result.bytes_read, 3); + assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.local_addr, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 1)), 5128) + ); + } + #[tokio::test] async fn test_open_udp6() { use std::str::FromStr; @@ -345,6 +404,10 @@ mod tests { let recv_result = a.recv(&mut buf).await.unwrap(); assert_eq!(recv_result.bytes_read, 3); assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.local_addr, + SocketAddrV6::new(Ipv6Addr::LOCALHOST, 5123, 0, 0) + ); assert!(a.send_to(&[4, 5, 6], recv_result.remote_addr).await.is_ok()); let recv_result = b.recv(&mut buf).await.unwrap(); assert_eq!(recv_result.bytes_read, 3); @@ -371,6 +434,10 @@ mod tests { let recv_result = a.recv(&mut buf).await.unwrap(); assert_eq!(recv_result.bytes_read, 3); assert_eq!(&buf[0..3], &[1, 2, 3]); + assert_eq!( + recv_result.local_addr, + SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5124) + ); assert!(a.send_to(&[4, 5, 6], recv_result.remote_addr).await.is_ok()); let recv_result = b.recv(&mut buf).await.unwrap(); assert_eq!(recv_result.bytes_read, 3);