diff --git a/.cargo/config.toml b/.cargo/config.toml index 733f815..1f360d6 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -68,7 +68,9 @@ rustflags = [ "-Wclippy::useless_transmute", "-Wclippy::verbose_file_reads", "-Wclippy::zero_sized_map_values", + "-Wclippy::undocumented_unsafe_blocks", "-Wfuture_incompatible", "-Wnonstandard_style", "-Wrust_2018_idioms", + "-Wlet-underscore", ] diff --git a/CHANGELOG.md b/CHANGELOG.md index c5f1691..a099af1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate +### Changed +- [PR#16](https://github.com/Jake-Shadle/xdp/pull/16) changed `RxRing` and `TxRing` to use the new `slab::Slab` trait. +- [PR#16](https://github.com/Jake-Shadle/xdp/pull/16) moved `HeapSlab` to the new `slab` module, and made it implement `slab::Slab`, changing it so that items are always pushed to the front and popped from the back, unlike the previous implementation which allowed both. + +### Added +- [PR#16](https://github.com/Jake-Shadle/xdp/pull/16) added a new `slab::StackSlab` fixed size ring buffer that implements `slab::Slab`. + +### Fixed +- [PR#16](https://github.com/Jake-Shadle/xdp/pull/16) fixed some undefined behavior in the netlink code used to query NIC capabilities. +- [PR#16](https://github.com/Jake-Shadle/xdp/pull/16) fixed a bug where TX metadata would not be added and would return an error if the packet headroom was not large enough for the metadata, this is irrelevant. + ## [0.5.0] - 2025-02-27 ### Changed - [PR#15](https://github.com/Jake-Shadle/xdp/pull/15) renamed `UdpPacket` -> `UdpHeaders`, and changed the contents to be the actual headers that can be de/serialized from/to the packet buffer. diff --git a/crates/integ/tests/tx_checksum.rs b/crates/integ/tests/tx_checksum.rs index dba7a81..5a0d7be 100644 --- a/crates/integ/tests/tx_checksum.rs +++ b/crates/integ/tests/tx_checksum.rs @@ -1,6 +1,7 @@ use test_utils::netlink::VethPair; use xdp::{ packet::{net_types as nt, *}, + slab::Slab, socket::*, umem::*, *, @@ -39,8 +40,8 @@ fn do_checksum_test(software: bool, vpair: &VethPair) { frame_size: FrameSize::TwoK, head_room: 20, frame_count: 64, - tx_metadata: software, software_checksum: software, + ..Default::default() } .build() .expect("invalid umem cfg"), @@ -110,168 +111,148 @@ fn do_checksum_test(software: bool, vpair: &VethPair) { }}; } - let res = std::thread::scope(|s| { - let client = s.spawn(|| -> Result<(), (&'static str, std::io::Error)> { - let dest: std::net::SocketAddr = (vpair.inside.ipv4, 7777).into(); - let local = client_socket.local_addr().unwrap(); - println!("sending {}b {local} -> {dest}", clientp.len()); - client_socket - .send_to(clientp, dest) - .map_err(|err| ("failed to send first request", err))?; - - let mut response = [0u8; 20]; - - println!("receiving {}b {local} <- {dest}", serverp.len()); - let (read, addr) = client_socket - .recv_from(&mut response) - .map_err(|err| ("failed to receive first response", err))?; - assert_eq!(&response[..read], serverp); - assert_eq!(addr, (vpair.inside.ipv4, sport).into()); - - println!("sending {}b {local} -> {dest}", clientp.len()); - client_socket - .send_to(clientp, dest) - .map_err(|err| ("failed to send second request", err))?; - - println!("receiving {}b {local} <- {dest}", serverp.len()); - let (read, addr) = client_socket - .recv_from(&mut response) - .map_err(|err| ("failed to receive second response", err))?; - - assert_eq!(&response[..read], serverp); - assert_eq!(addr, (vpair.inside.ipv4, sport).into()); - - Ok(()) - }); - - let server = s.spawn(|| { - let timeout = PollTimeout::new(Some(std::time::Duration::from_millis(100))); - - let mut slab = xdp::HeapSlab::with_capacity(BATCH_SIZE); - - unsafe { - poll_loop!({ - xdp_socket.poll_read(timeout).unwrap(); - if rx.recv(&umem, &mut slab) == 1 { - break; - } - }); - - let mut packet = slab.pop_back().unwrap(); - let udp = nt::UdpHeaders::parse_packet(&packet) - .expect("failed to parse packet") - .expect("not a UDP packet"); - - // For this packet, we calculate the full checksum - packet.adjust_tail(-(udp.data_length as i32)).unwrap(); - packet.insert(udp.data_offset, serverp).unwrap(); - - let nt::IpHdr::V4(mut copy) = udp.ip else { - unreachable!() - }; - std::mem::swap(&mut copy.destination, &mut copy.source); - copy.time_to_live -= 1; - - let mut new = nt::UdpHeaders { - eth: nt::EthHdr { - source: udp.eth.destination, - destination: udp.eth.source, - ether_type: udp.eth.ether_type, - }, - ip: nt::IpHdr::V4(copy), - udp: nt::UdpHdr { - destination: udp.udp.source, - source: sport.into(), - length: 0.into(), - check: 0, - }, - data_offset: udp.data_offset, - data_length: serverp.len(), - }; - - // For this packet, we calculate the full checksum - let data_checksum = csum::partial(serverp, 0); - let full_checksum = new.calc_checksum(serverp.len(), data_checksum); - new.set_packet_headers(&mut packet, true).unwrap(); - println!("Full checksum: {full_checksum:04x}"); - - slab.push_back(packet); - assert_eq!(tx.send(&mut slab), 1); - - poll_loop!({ - xdp_socket.poll(timeout).unwrap(); - if cr.dequeue(&mut umem, 1) == 1 { - break; - } - }); - - poll_loop!({ - xdp_socket.poll_read(timeout).unwrap(); - if rx.recv(&umem, &mut slab) == 1 { - break; - } - }); - - let mut packet = slab.pop_back().unwrap(); - let udp = nt::UdpHeaders::parse_packet(&packet) - .expect("failed to parse packet") - .expect("not a UDP packet"); - - packet.adjust_tail(-(udp.data_length as i32)).unwrap(); - packet.insert(udp.data_offset, serverp).unwrap(); - - let nt::IpHdr::V4(mut copy) = udp.ip else { - unreachable!() - }; - std::mem::swap(&mut copy.destination, &mut copy.source); - copy.time_to_live -= 1; - - let mut new = nt::UdpHeaders { - eth: nt::EthHdr { - source: udp.eth.destination, - destination: udp.eth.source, - ether_type: udp.eth.ether_type, - }, - ip: nt::IpHdr::V4(copy), - udp: nt::UdpHdr { - destination: udp.udp.source, - source: sport.into(), - length: 0.into(), - check: 0, - }, - data_offset: udp.data_offset, - data_length: serverp.len(), - }; - new.set_packet_headers(&mut packet, true).unwrap(); - println!( - "partial checksum: {:04x}", - packet.calc_udp_checksum().unwrap() - ); - - slab.push_back(packet); - assert_eq!(tx.send(&mut slab), 1); - - poll_loop!({ - xdp_socket.poll(timeout).unwrap(); - if cr.dequeue(&mut umem, 1) == 1 { - break; - } - }); - } - }); + std::thread::spawn(move || { + let timeout = PollTimeout::new(Some(std::time::Duration::from_millis(100))); + + let mut slab = xdp::slab::StackSlab::::new(); - let err = if let Err(err) = client.join().unwrap() { - run.store(false, std::sync::atomic::Ordering::Relaxed); - Some(err) - } else { - None - }; + unsafe { + poll_loop!({ + xdp_socket.poll_read(timeout).unwrap(); + if rx.recv(&umem, &mut slab) == 1 { + break; + } + }); + + let mut packet = slab.pop_back().unwrap(); + let udp = nt::UdpHeaders::parse_packet(&packet) + .expect("failed to parse packet") + .expect("not a UDP packet"); + + // For this packet, we calculate the full checksum + packet.adjust_tail(-(udp.data_length as i32)).unwrap(); + packet.insert(udp.data_offset, serverp).unwrap(); + + let nt::IpHdr::V4(mut copy) = udp.ip else { + unreachable!() + }; + std::mem::swap(&mut copy.destination, &mut copy.source); + copy.time_to_live -= 1; + + let mut new = nt::UdpHeaders { + eth: nt::EthHdr { + source: udp.eth.destination, + destination: udp.eth.source, + ether_type: udp.eth.ether_type, + }, + ip: nt::IpHdr::V4(copy), + udp: nt::UdpHdr { + destination: udp.udp.source, + source: sport.into(), + length: 0.into(), + check: 0, + }, + data_offset: udp.data_offset, + data_length: serverp.len(), + }; + + // For this packet, we calculate the full checksum + let data_checksum = csum::partial(serverp, 0); + let full_checksum = new.calc_checksum(serverp.len(), data_checksum); + new.set_packet_headers(&mut packet, true).unwrap(); + println!("Full checksum: {full_checksum:04x}"); + + slab.push_front(packet); + assert_eq!(tx.send(&mut slab), 1); + + poll_loop!({ + xdp_socket.poll(timeout).unwrap(); + if cr.dequeue(&mut umem, 1) == 1 { + break; + } + }); - server.join().unwrap(); - err + poll_loop!({ + xdp_socket.poll_read(timeout).unwrap(); + if rx.recv(&umem, &mut slab) == 1 { + break; + } + }); + + let mut packet = slab.pop_back().unwrap(); + let udp = nt::UdpHeaders::parse_packet(&packet) + .expect("failed to parse packet") + .expect("not a UDP packet"); + + packet.adjust_tail(-(udp.data_length as i32)).unwrap(); + packet.insert(udp.data_offset, serverp).unwrap(); + + let nt::IpHdr::V4(mut copy) = udp.ip else { + unreachable!() + }; + std::mem::swap(&mut copy.destination, &mut copy.source); + copy.time_to_live -= 1; + + let mut new = nt::UdpHeaders { + eth: nt::EthHdr { + source: udp.eth.destination, + destination: udp.eth.source, + ether_type: udp.eth.ether_type, + }, + ip: nt::IpHdr::V4(copy), + udp: nt::UdpHdr { + destination: udp.udp.source, + source: sport.into(), + length: 0.into(), + check: 0, + }, + data_offset: udp.data_offset, + data_length: serverp.len(), + }; + new.set_packet_headers(&mut packet, true).unwrap(); + println!( + "partial checksum: {:04x}", + packet.calc_udp_checksum().unwrap() + ); + + slab.push_front(packet); + assert_eq!(tx.send(&mut slab), 1); + + poll_loop!({ + xdp_socket.poll(timeout).unwrap(); + if cr.dequeue(&mut umem, 1) == 1 { + break; + } + }); + } }); - if let Some((msg, err)) = res { - panic!("{msg}: {err}"); - } + let dest: std::net::SocketAddr = (vpair.inside.ipv4, 7777).into(); + let local = client_socket.local_addr().unwrap(); + println!("sending {}b {local} -> {dest}", clientp.len()); + client_socket + .send_to(clientp, dest) + .expect("failed to send first request"); + + let mut response = [0u8; 20]; + + println!("receiving {}b {local} <- {dest}", serverp.len()); + let (read, addr) = client_socket + .recv_from(&mut response) + .expect("failed to receive first response"); + assert_eq!(&response[..read], serverp); + assert_eq!(addr, (vpair.inside.ipv4, sport).into()); + + println!("sending {}b {local} -> {dest}", clientp.len()); + client_socket + .send_to(clientp, dest) + .expect("failed to send second request"); + + println!("receiving {}b {local} <- {dest}", serverp.len()); + let (read, addr) = client_socket + .recv_from(&mut response) + .expect("failed to receive second response"); + + assert_eq!(&response[..read], serverp); + assert_eq!(addr, (vpair.inside.ipv4, sport).into()); } diff --git a/crates/tests/tests/slab.rs b/crates/tests/tests/slab.rs new file mode 100644 index 0000000..7d2ee3f --- /dev/null +++ b/crates/tests/tests/slab.rs @@ -0,0 +1,90 @@ +use xdp::slab::Slab as _; + +xdp::slab!(TestSlab, u8); + +fn umem(frame_count: u32, head_room: u32) -> xdp::Umem { + use xdp::umem; + + umem::Umem::map( + umem::UmemCfgBuilder { + frame_size: umem::FrameSize::TwoK, + head_room, + frame_count, + ..Default::default() + } + .build() + .unwrap(), + ) + .unwrap() +} + +#[test] +fn edge_conditions() { + const CAP: usize = 64; + + let mut ss = TestSlab::::new(); + let mut umem = umem(80, 0); + + assert!(ss.is_empty()); + assert_eq!(ss.available(), CAP); + assert_eq!(ss.len(), 0); + + for _ in 0..CAP { + let packet = unsafe { umem.alloc() }.unwrap(); + ss.push_front(packet); + } + + assert!(!ss.is_empty()); + assert_eq!(ss.available(), 0); + assert_eq!(ss.len(), CAP); + + let over = unsafe { umem.alloc() }.unwrap(); + let over = ss.push_front(over).unwrap(); + + let back = ss.pop_back().unwrap(); + assert_eq!(ss.available(), 1); + assert_eq!(ss.len(), CAP - 1); + assert!(ss.push_front(over).is_none()); + + umem.free_packet(back); + + while let Some(p) = ss.pop_back() { + umem.free_packet(p); + } + + assert!(ss.is_empty()); + assert_eq!(ss.available(), CAP); + assert_eq!(ss.len(), 0); + + for i in 0..CAP { + let mut packet = unsafe { umem.alloc() }.unwrap(); + packet.insert(0, &[i as u8]).unwrap(); + ss.push_front(packet); + } + + assert!(!ss.is_empty()); + assert_eq!(ss.available(), 0); + assert_eq!(ss.len(), CAP); + + for _ in 0..9 { + for _ in 0..CAP { + let p = ss.pop_back().unwrap(); + assert_eq!(ss.len(), CAP - 1); + assert!(ss.push_front(p).is_none()); + } + } + + assert_eq!(ss.len(), CAP); + + for i in 0..CAP { + let p = ss.pop_back().unwrap(); + assert_eq!(&p[0..1], &[i as u8]); + if i % 2 == 1 { + assert!(ss.push_front(p).is_none()); + } else { + umem.free_packet(p); + } + } + + assert_eq!(ss.len(), CAP >> 1); +} diff --git a/src/affinity.rs b/src/affinity.rs index 6f8311d..c21bf9e 100644 --- a/src/affinity.rs +++ b/src/affinity.rs @@ -13,6 +13,7 @@ impl CoreId { /// Sets the core affinity for the current thread #[inline] pub fn set_affinity(self) -> Result<()> { + // SAFETY: syscall unsafe { let mut set = mem::zeroed(); @@ -63,6 +64,7 @@ impl CoreIds { /// Creates an iterator over the available CPUs #[inline] pub fn new() -> Result { + // SAFETY: syscall let set = unsafe { let mut set = mem::zeroed(); diff --git a/src/lib.rs b/src/lib.rs index 9b98fd6..5b6c9b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,7 @@ pub mod libc; mod mmap; pub mod nic; mod rings; +pub mod slab; pub mod socket; pub mod umem; @@ -38,76 +39,3 @@ pub use rings::{ CompletionRing, FillRing, RingConfig, RingConfigBuilder, Rings, RxRing, TxRing, WakableFillRing, WakableRings, WakableTxRing, }; - -// TODO: This is using VecDequeue (heap) internally, but in most situations this -// could just be fixed sizes with a const N: usize and stored on the stack, so -// might be worth doing that implementation inline, just not super important -/// A ring buffer used to do bulk pops from a [`RxRing`] or pushes to a [`TxRing`] -/// -/// This is allocated on the heap, but will _not_ grow, and is intended to be -/// allocated once before entering an I/O loop -pub struct HeapSlab { - vd: std::collections::VecDeque, -} - -impl HeapSlab { - /// Allocates a new [`Self`] with the maximum specified capacity - #[inline] - pub fn with_capacity(capacity: usize) -> Self { - Self { - vd: std::collections::VecDeque::with_capacity(capacity), - } - } - - /// The number of packets in the slab - #[inline] - pub fn len(&self) -> usize { - self.vd.len() - } - - /// True if the slab is empty - #[inline] - pub fn is_empty(&self) -> bool { - self.vd.is_empty() - } - - /// The number of packets that can be pushed to the slab - #[inline] - pub fn available(&self) -> usize { - self.vd.capacity() - self.vd.len() - } - - /// Pops the front packet if any - #[inline] - pub fn pop_front(&mut self) -> Option { - self.vd.pop_front() - } - - /// Pops the back packet if any - #[inline] - pub fn pop_back(&mut self) -> Option { - self.vd.pop_back() - } - - /// Pushes a packet to the front, returning `Some` if the slab is at capacity - #[inline] - pub fn push_front(&mut self, item: Packet) -> Option { - if self.available() > 0 { - self.vd.push_front(item); - None - } else { - Some(item) - } - } - - /// Pushes a packet to the back, returning `Some` if the slab is at capacity - #[inline] - pub fn push_back(&mut self, item: Packet) -> Option { - if self.available() > 0 { - self.vd.push_back(item); - None - } else { - Some(item) - } - } -} diff --git a/src/libc.rs b/src/libc.rs index cd15dd5..a68fbee 100644 --- a/src/libc.rs +++ b/src/libc.rs @@ -17,17 +17,17 @@ use std::{ffi::c_void, os::fd::RawFd}; /// Internal flags to enable TX offload features if they were enabled at /// socket creation -#[derive(Copy, Clone)] -#[repr(u32)] -pub enum InternalXdpFlags { - /// TX checksum offload is enabled - SupportsChecksumOffload = 1 << 31, - /// TX checksum offload is enabled in software - SoftwareOffload = (1 << 30) | (1 << 31), +pub(crate) mod InternalXdpFlags { + pub type Enum = u32; + + /// TX checksum offload is supported + pub const SUPPORTS_CHECKSUM_OFFLOAD: Enum = 1 << 31; + /// TX checksum offload is done in software rather than hardware + pub const USE_SOFTWARE_OFFLOAD: Enum = SUPPORTS_CHECKSUM_OFFLOAD | (1 << 30); /// TX completion timestamp is supported - CompletionTimestamp = 1 << 29, + pub const SUPPORTS_TIMESTAMP: Enum = 1 << 29; /// Mask of valid flags - Mask = 0xf0000000, + pub const MASK: Enum = 0xf0000000; } /// The bindings specific to the various rings used by `AF_XDP` sockets. @@ -229,6 +229,7 @@ pub mod xdp { pub offload: xsk_tx_offload, } + // SAFETY: POD unsafe impl crate::packet::Pod for xsk_tx_metadata {} /// Flags available when registering a [`crate::Umem`] with a socket @@ -445,7 +446,7 @@ pub(crate) mod iface { #[repr(C)] pub struct ifaddrs { pub ifa_next: *mut ifaddrs, - pub ifa_name: *mut c_char, + pub ifa_name: *mut u8, pub ifa_flags: u32, pub ifa_addr: *mut sockaddr, pub ifa_netmask: *mut sockaddr, @@ -459,7 +460,7 @@ pub(crate) mod iface { pub d_off: i64, pub d_reclen: u16, pub d_type: u8, - pub d_name: [c_char; 256], + pub d_name: [u8; 256], } #[repr(C)] @@ -469,7 +470,7 @@ pub(crate) mod iface { #[repr(C)] pub struct ifreq { - pub ifr_name: [c_char; 16], + pub ifr_name: [u8; 16], pub ifr_ifru: ifr_ifru, } @@ -483,21 +484,21 @@ pub(crate) mod iface { pub fn freeifaddrs(ifap: *mut ifaddrs) -> i32; /// - pub fn if_indextoname(ifindex: u32, ifname: *mut c_char) -> *mut c_char; + pub fn if_indextoname(ifindex: u32, ifname: *mut u8) -> *mut u8; /// - pub fn if_nametoindex(ifname: *const c_char) -> u32; + pub fn if_nametoindex(ifname: *const u8) -> u32; /// pub fn ioctl(fd: RawFd, request: u64, ...) -> i32; /// - pub fn opendir(dirname: *const c_char) -> *mut DIR; + pub fn opendir(dirname: *const u8) -> *mut DIR; /// pub fn closedir(dirp: *mut DIR) -> i32; /// pub fn readdir(dirp: *mut DIR) -> *mut dirent; - pub fn strncmp(cs: *const c_char, ct: *const c_char, n: usize) -> i32; + pub fn strncmp(cs: *const u8, ct: *const u8, n: usize) -> i32; } } diff --git a/src/mmap.rs b/src/mmap.rs index 7cadddc..3d22be2 100644 --- a/src/mmap.rs +++ b/src/mmap.rs @@ -18,7 +18,7 @@ fn page_size() -> usize { } pub struct Mmap { - addr: *mut std::ffi::c_void, + pub(crate) ptr: *mut u8, len: usize, } @@ -54,6 +54,7 @@ impl Mmap { flags: mmap::Flags::Enum, file: i32, ) -> std::io::Result { + // SAFETY: syscalls unsafe { let alignment = offset % page_size() as u64; let aligned_offset = offset - alignment; @@ -70,33 +71,25 @@ impl Mmap { Err(std::io::Error::last_os_error()) } else { Ok(Self { - addr: base.add(alignment as _), + ptr: base.add(alignment as _).cast(), len: length, }) } } } -} - -unsafe impl Sync for Mmap {} -unsafe impl Send for Mmap {} - -impl std::ops::Deref for Mmap { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - unsafe { std::slice::from_raw_parts(self.addr.cast(), self.len) } + #[inline] + pub fn len(&self) -> usize { + self.len } } -impl std::ops::DerefMut for Mmap { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { std::slice::from_raw_parts_mut(self.addr.cast(), self.len) } - } -} +// SAFETY: Safe to send across threads +unsafe impl Send for Mmap {} impl Drop for Mmap { fn drop(&mut self) { - unsafe { mmap::munmap(self.addr, self.len) }; + // SAFETY: syscall, the pointer is validated before we create an Mmap + unsafe { mmap::munmap(self.ptr.cast(), self.len) }; } } diff --git a/src/nic.rs b/src/nic.rs index c3bffc3..5a10112 100644 --- a/src/nic.rs +++ b/src/nic.rs @@ -301,19 +301,19 @@ impl NicIndex { /// `None` if the interface cannot be found #[inline] pub fn lookup_by_name(ifname: &std::ffi::CStr) -> std::io::Result> { - unsafe { - let res = iface::if_nametoindex(ifname.as_ptr()); - if res == 0 { - let err = std::io::Error::last_os_error(); - - if err.raw_os_error() == Some(iface::ENODEV) { - Ok(None) - } else { - Err(err) - } + // SAFETY: syscall, we give it a valid pointer + let res = unsafe { iface::if_nametoindex(ifname.as_ptr().cast()) }; + + if res == 0 { + let err = std::io::Error::last_os_error(); + + if err.raw_os_error() == Some(iface::ENODEV) { + Ok(None) } else { - Ok(Some(Self(res))) + Err(err) } + } else { + Ok(Some(Self(res))) } } @@ -321,6 +321,7 @@ impl NicIndex { #[inline] pub fn name(&self) -> std::io::Result { let mut name = [0; iface::IF_NAMESIZE]; + // SAFETY: syscall, we give it a valid pointer if unsafe { !iface::if_indextoname(self.0, name.as_mut_ptr()).is_null() } { let len = name .iter() @@ -337,6 +338,7 @@ impl NicIndex { pub fn addresses( &self, ) -> std::io::Result<(Option, Option)> { + // SAFETY: syscalls unsafe { let mut ifaddrs = std::mem::MaybeUninit::<*mut iface::ifaddrs>::uninit(); if iface::getifaddrs(ifaddrs.as_mut_ptr()) != 0 { @@ -349,6 +351,7 @@ impl NicIndex { struct Ifaddrs(*mut iface::ifaddrs); impl Drop for Ifaddrs { fn drop(&mut self) { + // SAFETY: syscall, we validate the pointer before allowing it to be freed unsafe { iface::freeifaddrs(self.0) }; } } @@ -506,10 +509,10 @@ impl NicIndex { continue; } - if entry.d_name[..2] == [b'r' as i8, b'x' as i8] { + if entry.d_name[..2] == [b'r', b'x'] { channels.max_rx += 1; channels.rx_count += 1; - } else if entry.d_name[..2] == [b't' as i8, b'x' as i8] { + } else if entry.d_name[..2] == [b't', b'x'] { channels.max_tx += 1; channels.tx_count += 1; } @@ -584,7 +587,7 @@ impl PartialEq for NicIndex { /// The human-readable name assigned to a network device #[derive(Copy, Clone)] pub struct NicName { - arr: [i8; iface::IF_NAMESIZE], + arr: [u8; iface::IF_NAMESIZE], len: usize, } @@ -593,10 +596,7 @@ impl NicName { /// unlikely case the interface name is not utf-8 #[inline] pub fn as_str(&self) -> Option<&str> { - std::str::from_utf8(unsafe { - std::slice::from_raw_parts(self.arr.as_ptr().cast(), self.len) - }) - .ok() + std::str::from_utf8(&self.arr[..self.len]).ok() } } @@ -676,9 +676,10 @@ impl Iterator for InterfaceIter { ifname[..name.len()].copy_from_slice(name.as_bytes()); ifname[name.len()] = 0; - let Ok(Some(iface)) = NicIndex::lookup_by_name(unsafe { - std::ffi::CStr::from_bytes_with_nul_unchecked(&ifname) - }) else { + let Ok(Some(iface)) = NicIndex::lookup_by_name( + // SAFETY: we ensure there is a null byte at the end + unsafe { std::ffi::CStr::from_bytes_with_nul_unchecked(&ifname) }, + ) else { continue; }; diff --git a/src/nic/netlink.rs b/src/nic/netlink.rs index ae5e651..53adf59 100644 --- a/src/nic/netlink.rs +++ b/src/nic/netlink.rs @@ -70,6 +70,7 @@ const NLMSGERR_ATTR_MSG: u16 = 1; macro_rules! len { ($record:ty) => { + // SAFETY: internal only unsafe impl Pod for $record {} impl $record { @@ -153,7 +154,7 @@ impl Buf { } #[inline] - fn read(&self, off: &mut usize) -> Result<&P> { + fn read(&self, off: &mut usize) -> Result

{ if *off + P::size() > self.len { return Err(Error::new( ErrorKind::UnexpectedEof, @@ -161,13 +162,15 @@ impl Buf { )); } - let p = unsafe { &*(self.buf.as_ptr().byte_offset(*off as _).cast()) }; + let p = + // SAFETY: we've validated we'll only read within bounds + unsafe { std::ptr::read_unaligned(self.buf.as_ptr().byte_offset(*off as _).cast()) }; *off += P::size(); Ok(p) } #[inline] - fn write(&mut self, off: &mut usize) -> Result<&mut P> { + fn write(&mut self, off: &mut usize, item: P) -> Result<()> { if *off + P::size() > self.len { return Err(Error::new( ErrorKind::UnexpectedEof, @@ -175,9 +178,12 @@ impl Buf { )); } - let p = unsafe { &mut *(self.buf.as_mut_ptr().byte_offset(*off as _).cast()) }; + // SAFETY: we've validated we'll only write within bounds + unsafe { + std::ptr::write_unaligned(self.buf.as_mut_ptr().byte_offset(*off as _).cast(), item); + }; *off += P::size(); - Ok(p) + Ok(()) } #[inline] @@ -300,12 +306,16 @@ impl NetlinkSocket { let seq = self.seq; self.seq += 1; + // SAFETY: various syscalls and buffer manipulation unsafe { let mut off = 0; let len = msg.len; - let hdr = msg.write::(&mut off)?; + + let mut hdr = msg.read::(&mut off)?; + off = 0; hdr.nlmsg_seq = seq; hdr.nlmsg_len = len as _; + msg.write(&mut off, hdr)?; let sent: usize = io_err!(socket::send( self.sock.as_raw_fd(), @@ -354,7 +364,7 @@ impl NetlinkSocket { let message = if msg_hdr.nlmsg_flags & msg_flags::ACK_TLVS != 0 { // We could also recover the offset of the failing attribute, but considering // we only do 2 requests and both have a single attribute.. - AttrIter::error(msg, msg_hdr, err_hdr, &mut offset).find_map(|(kind, data)| { + AttrIter::error(msg, &msg_hdr, &err_hdr, &mut offset).find_map(|(kind, data)| { (kind == NLMSGERR_ATTR_MSG).then_some(String::from_utf8_lossy(&data[..data.len() - 2]).into_owned()) }).unwrap_or_else(|| format!("received netlink error code {}, and we failed to retrieve the additional information provided by the kernel", err_hdr.error)) } else { @@ -373,7 +383,7 @@ impl NetlinkSocket { return Ok(None); } _other => { - let res = func(AttrIter::generic(msg, msg_hdr, &mut offset)?)?; + let res = func(AttrIter::generic(msg, &msg_hdr, &mut offset)?)?; if res.is_some() { return Ok(res); } diff --git a/src/packet.rs b/src/packet.rs index 0984b4d..9ba23f3 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -36,6 +36,10 @@ pub enum PacketError { /// The length of the actual valid contents length: usize, }, + /// TX checksum offload is not supported + ChecksumUnsupported, + /// TX timestamp is not supported + TimestampUnsupported, } impl PacketError { @@ -47,6 +51,8 @@ impl PacketError { Self::InvalidPacketLength {} => "invalid packet length", Self::InvalidOffset { .. } => "invalid offset", Self::InsufficientData { .. } => "insufficient data", + Self::ChecksumUnsupported => "TX checksum unsupported", + Self::TimestampUnsupported => "TX timestamp unsupported", } } } @@ -75,12 +81,16 @@ pub unsafe trait Pod: Sized { /// Gets a zeroed [`Self`] #[inline] fn zeroed() -> Self { + // SAFETY: by implementing Pod the user is saying that an all zero block + // is a valid representation of this type unsafe { std::mem::zeroed() } } /// Gets [`Self`] as a byte slice #[inline] fn as_bytes(&self) -> &[u8] { + // SAFETY: by implementing Pod the user is saying that the struct can be + // represented safely by a byte slice unsafe { std::slice::from_raw_parts((self as *const Self).cast(), std::mem::size_of::()) } @@ -130,12 +140,11 @@ pub enum CsumOffload { /// │ │ │ │ │ /// head +14 +34 +42 tail /// ``` -/// -/// pub struct Packet { /// The entire packet buffer, including headroom, initialized packet contents, /// and uninitialized/empty remainder - pub(crate) data: &'static mut [u8], + pub(crate) data: *mut u8, + pub(crate) capacity: usize, /// The offset in data where the packet starts pub(crate) head: usize, /// The offset in data where the packet ends @@ -148,16 +157,14 @@ impl Packet { /// Only used for testing #[doc(hidden)] pub fn testing_new(buf: &mut [u8; 2 * 1024]) -> Self { - unsafe { - Self { - data: std::mem::transmute::<&mut [u8], &'static mut [u8]>( - &mut buf[libc::xdp::XDP_PACKET_HEADROOM as usize..], - ), - head: 0, - tail: 0, - base: std::ptr::null(), - options: 0, - } + let data = &mut buf[libc::xdp::XDP_PACKET_HEADROOM as usize..]; + Self { + data: data.as_mut_ptr(), + capacity: data.len(), + head: 0, + tail: 0, + base: std::ptr::null(), + options: 0, } } @@ -179,7 +186,7 @@ impl Packet { /// part of every packet #[inline] pub fn capacity(&self) -> usize { - self.data.len() + self.capacity } /// Resets the tail of this packet, causing it to become empty @@ -202,7 +209,7 @@ impl Packet { /// offload or not #[inline] pub fn can_offload_checksum(&self) -> bool { - (self.options & libc::InternalXdpFlags::SupportsChecksumOffload as u32) != 0 + (self.options & libc::InternalXdpFlags::SUPPORTS_CHECKSUM_OFFLOAD) != 0 } /// Adjust the head of the packet up or down by `diff` bytes @@ -254,7 +261,7 @@ impl Packet { self.tail -= diff; } else { let diff = diff as usize; - if self.tail + diff > self.data.len() { + if self.tail + diff > self.capacity { return Err(PacketError::InvalidPacketLength {}); } @@ -289,7 +296,8 @@ impl Packet { }); } - Ok(unsafe { std::ptr::read_unaligned(self.data.as_ptr().byte_offset(start as _).cast()) }) + // SAFETY: we've validated the pointer read is within bounds + Ok(unsafe { std::ptr::read_unaligned(self.data.byte_offset(start as _).cast()) }) } /// Writes the contents of `item` at the specified `offset` @@ -320,12 +328,10 @@ impl Packet { }); } + // SAFETY: we've validated the pointer write is within bounds unsafe { std::ptr::write_unaligned( - self.data - .as_mut_ptr() - .byte_offset((self.head + offset) as _) - .cast(), + self.data.byte_offset((self.head + offset) as _).cast(), item, ); } @@ -353,7 +359,14 @@ impl Packet { }); } - array.copy_from_slice(&self.data[start..start + N]); + // SAFETY: we've validated the range of data we are reading is valid + unsafe { + std::ptr::copy_nonoverlapping( + self.data.byte_offset(offset as _), + array.as_mut_ptr(), + N, + ); + } Ok(()) } @@ -366,7 +379,7 @@ impl Packet { /// - The offset + `slice.len()` would exceed the capacity #[inline] pub fn insert(&mut self, offset: usize, slice: &[u8]) -> Result<(), PacketError> { - if self.tail + slice.len() > self.data.len() { + if self.tail + slice.len() > self.capacity { return Err(PacketError::InvalidPacketLength {}); } else if offset > self.tail { return Err(PacketError::InvalidOffset { @@ -377,27 +390,34 @@ impl Packet { let adjusted_offset = self.head + offset; let shift = self.tail + self.head - adjusted_offset; - if shift > 0 { - unsafe { - // Note that dst is declared before src, otherwise miri complains about UB - let dst = self - .data - .as_mut_ptr() - .byte_offset((adjusted_offset + slice.len()) as isize); - let src = self.data.as_ptr().byte_offset(adjusted_offset as isize); - std::ptr::copy(src, dst, shift); + + // SAFETY: we validate we're within bounds before doing any writes to the + // pointer, which is alive as long as the owning mmap + unsafe { + if shift > 0 { + std::ptr::copy( + self.data.byte_offset(adjusted_offset as isize), + self.data + .byte_offset((adjusted_offset + slice.len()) as isize), + shift, + ); } + + std::ptr::copy_nonoverlapping( + slice.as_ptr(), + self.data.byte_offset(adjusted_offset as _), + slice.len(), + ); } - self.data[adjusted_offset..adjusted_offset + slice.len()].copy_from_slice(slice); self.tail += slice.len(); Ok(()) } /// Sets the specified [TX metadata](https://github.com/torvalds/linux/blob/ae90f6a6170d7a7a1aa4fddf664fbd093e3023bc/Documentation/networking/xsk-tx-metadata.rst) /// - /// Calling this function requires that the [`crate::umem::UmemCfgBuilder::tx_metadata`] - /// was true. + /// Calling this function requires that the [`crate::umem::UmemCfgBuilder::tx_checksum`] + /// and/or [`crate::umem::UmemCfgBuilder::tx_timestamp`] were true /// /// - If `csum` is `CsumOffload::Request`, this will request that the Layer 4 /// checksum computation be offload to the NIC before transmission. Note that @@ -416,13 +436,21 @@ impl Packet { // This would mean the user is making a request that won't actually do anything debug_assert!(request_timestamp || matches!(csum, CsumOffload::Request { .. })); - if self.head < std::mem::size_of::() { - return Err(PacketError::InsufficientHeadroom { - diff: std::mem::size_of::(), - head: self.head, - }); + if matches!(csum, CsumOffload::Request { .. }) + && (self.options & libc::InternalXdpFlags::SUPPORTS_CHECKSUM_OFFLOAD) == 0 + { + return Err(PacketError::ChecksumUnsupported); + } else if request_timestamp + && (self.options & libc::InternalXdpFlags::SUPPORTS_TIMESTAMP) == 0 + { + return Err(PacketError::TimestampUnsupported); } + // SAFETY: While this looks pretty dangerous because we are getting a pointer + // before the base packet, it's actually safe as the presence of either the + // checksum offload or timestamp flags means the umem was registered with + // space for an xsk_tx_metadata that the kernel will also know the location + // of unsafe { let mut tx_meta = std::mem::zeroed::(); @@ -437,8 +465,9 @@ impl Packet { std::ptr::write_unaligned( self.data - .as_mut_ptr() - .byte_offset((self.head - std::mem::size_of::()) as _) + .byte_offset( + self.head as isize - std::mem::size_of::() as isize, + ) .cast(), tx_meta, ); @@ -448,33 +477,51 @@ impl Packet { Ok(()) } + + #[doc(hidden)] + #[inline] + pub fn inner_copy(&mut self) -> Self { + Self { + data: self.data, + capacity: self.capacity, + head: self.head, + tail: self.tail, + base: self.base, + options: self.options, + } + } } impl std::ops::Deref for Packet { type Target = [u8]; fn deref(&self) -> &Self::Target { - &self.data[self.head..self.tail] + // SAFETY: the pointer is valid as long as the mmap is alive + unsafe { &std::slice::from_raw_parts(self.data, self.capacity)[self.head..self.tail] } } } impl std::ops::DerefMut for Packet { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.data[self.head..self.tail] + // SAFETY: the pointer is valid as long as the mmap is alive + unsafe { + &mut std::slice::from_raw_parts_mut(self.data, self.capacity)[self.head..self.tail] + } } } impl From for libc::xdp::xdp_desc { fn from(packet: Packet) -> Self { libc::xdp::xdp_desc { + // SAFETY: the pointer is valid as long as the mmap it is allocated + // from is alive addr: unsafe { packet .data - .as_ptr() .byte_offset(packet.head as _) .offset_from(packet.base) as _ }, len: (packet.tail - packet.head) as _, - options: packet.options & !(libc::InternalXdpFlags::Mask as u32), + options: packet.options & !libc::InternalXdpFlags::MASK, } } } diff --git a/src/packet/csum.rs b/src/packet/csum.rs index 01fd891..bc505a5 100644 --- a/src/packet/csum.rs +++ b/src/packet/csum.rs @@ -22,6 +22,7 @@ pub fn to_u16(mut csum: u32) -> u16 { /// Add with carry #[inline] pub fn add(mut a: u32, b: u32) -> u32 { + // SAFETY: asm unsafe { std::arch::asm!( "addl {b:e}, {a:e}", @@ -84,6 +85,7 @@ pub fn partial(mut buf: &[u8], sum: u32) -> u32 { fn update_40(mut sum: u64, bytes: &[u8]) -> u64 { debug_assert_eq!(bytes.len(), 40); + // SAFETY: asm unsafe { std::arch::asm!( "addq 0*8({buf}), {sum}", @@ -111,6 +113,7 @@ pub fn partial(mut buf: &[u8], sum: u32) -> u32 { buf = &buf[80..]; } + // SAFETY: asm unsafe { std::arch::asm!( "addq {0}, {sum}", @@ -133,6 +136,7 @@ pub fn partial(mut buf: &[u8], sum: u32) -> u32 { let len = buf.len(); if len & 32 != 0 { + // SAFETY: asm unsafe { std::arch::asm!( "addq 0*8({buf}), {sum}", @@ -150,6 +154,7 @@ pub fn partial(mut buf: &[u8], sum: u32) -> u32 { } if len & 16 != 0 { + // SAFETY: asm unsafe { std::arch::asm!( "addq 0*8({buf}), {sum}", @@ -165,6 +170,7 @@ pub fn partial(mut buf: &[u8], sum: u32) -> u32 { } if len & 8 != 0 { + // SAFETY: asm unsafe { std::arch::asm!( "addq 0*8({buf}), {sum}", @@ -183,6 +189,7 @@ pub fn partial(mut buf: &[u8], sum: u32) -> u32 { // of the whole u64 let shift = ((-(len as i64) << 3) & 63) as u32; + // SAFETY: asm unsafe { // The kernel's load_unaligned_zeropad needs to take into account // this load potentially crossing page boundaries, but we don't have @@ -295,6 +302,7 @@ impl super::Packet { let udp_hdr = self.read::(offset)?; // https://en.wikipedia.org/wiki/User_Datagram_Protocol#IPv4_pseudo_header + // SAFETY: asm unsafe { let mut sum = 0; @@ -324,6 +332,7 @@ impl super::Packet { let udp_hdr = self.read::(offset)?; // https://en.wikipedia.org/wiki/User_Datagram_Protocol#IPv6_pseudo_header + // SAFETY: asm unsafe { let mut sum = ((udp_hdr.length.host() as u32).to_be() as u64) .wrapping_add((IpProto::Udp as u64).to_be()); @@ -392,6 +401,7 @@ impl nt::UdpHeaders { match &self.ip { nt::IpHdr::V4(v4) => { // https://en.wikipedia.org/wiki/User_Datagram_Protocol#IPv4_pseudo_header + // SAFETY: asm unsafe { std::arch::asm!( "addq {pseudo_udp}, {sum}", @@ -415,6 +425,7 @@ impl nt::UdpHeaders { } nt::IpHdr::V6(v6) => { // https://en.wikipedia.org/wiki/User_Datagram_Protocol#IPv6_pseudo_header + // SAFETY: asm unsafe { let source = v6.source; let destination = v6.destination; diff --git a/src/packet/net_types.rs b/src/packet/net_types.rs index 8deec19..22abd28 100644 --- a/src/packet/net_types.rs +++ b/src/packet/net_types.rs @@ -10,6 +10,7 @@ use std::{ macro_rules! len { ($record:ty) => { + // SAFETY: We only use this macro on types it is safe for unsafe impl Pod for $record {} impl $record { diff --git a/src/rings.rs b/src/rings.rs index a390329..c85b571 100644 --- a/src/rings.rs +++ b/src/rings.rs @@ -177,7 +177,7 @@ fn map_ring( offset: libc::RingPageOffsets, offsets: &libc::xdp_ring_offset, ) -> std::io::Result<(crate::mmap::Mmap, XskRing)> { - let mut mmap = crate::mmap::Mmap::map_ring( + let mmap = crate::mmap::Mmap::map_ring( offsets.desc as usize + (count as usize * std::mem::size_of::()), offset as u64, socket, @@ -185,7 +185,7 @@ fn map_ring( // SAFETY: The lifetime of the pointers are the same as the mmap let ring = unsafe { - let map = mmap.as_mut_ptr(); + let map = mmap.ptr; let producer = AtomicU32::from_ptr(map.byte_offset(offsets.producer as _).cast()); let consumer = AtomicU32::from_ptr(map.byte_offset(offsets.consumer as _).cast()); @@ -260,6 +260,7 @@ impl std::ops::Index for XskProducer { #[inline] fn index(&self, index: usize) -> &Self::Output { + // SAFETY: each ring impl ensures the index is valid unsafe { self.0.ring.get_unchecked(index) } } } @@ -267,6 +268,7 @@ impl std::ops::Index for XskProducer { impl std::ops::IndexMut for XskProducer { #[inline] fn index_mut(&mut self, index: usize) -> &mut Self::Output { + // SAFETY: each ring impl ensures the index is valid unsafe { self.0.ring.get_unchecked_mut(index) } } } diff --git a/src/rings/fill.rs b/src/rings/fill.rs index 0ab69cb..f754d67 100644 --- a/src/rings/fill.rs +++ b/src/rings/fill.rs @@ -51,8 +51,8 @@ impl FillRing { /// lower than the requested `num_packets` if the [`Umem`] didn't have enough /// open slots, or the rx ring had insufficient capacity pub unsafe fn enqueue(&mut self, umem: &mut Umem, num_packets: usize) -> usize { - let mut popper = umem.popper(); - let requested = std::cmp::min(popper.len(), num_packets); + let available = umem.available(); + let requested = std::cmp::min(available.len(), num_packets); if requested == 0 { return 0; } @@ -62,7 +62,7 @@ impl FillRing { if actual > 0 { let mask = self.ring.mask(); for i in idx..idx + actual { - self.ring[i & mask] = popper.pop(); + self.ring[i & mask] = available.pop_front().unwrap(); } self.ring.submit(actual as _); @@ -104,6 +104,7 @@ impl WakableFillRing { num_packets: usize, wakeup: bool, ) -> std::io::Result { + // SAFETY: FillRing::enqueue is unsafe let queued = unsafe { self.inner.enqueue(umem, num_packets) }; if queued > 0 && wakeup { diff --git a/src/rings/rx.rs b/src/rings/rx.rs index 575bad4..5d5896d 100644 --- a/src/rings/rx.rs +++ b/src/rings/rx.rs @@ -1,7 +1,7 @@ //! The [`RxRing`] is a consumer ring that userspace can dequeue packets that have //! been received on the NIC queue the ring is bound to -use crate::{HeapSlab, Umem, libc}; +use crate::{Umem, libc, slab::Slab}; /// Ring from which we can dequeue packets that have been filled by the kernel pub struct RxRing { @@ -48,7 +48,7 @@ impl RxRing { /// /// The packets returned in the slab must not outlive the [`Umem`] #[inline] - pub unsafe fn recv(&mut self, umem: &Umem, packets: &mut HeapSlab) -> usize { + pub unsafe fn recv(&mut self, umem: &Umem, packets: &mut S) -> usize { let nb = packets.available(); if nb == 0 { return 0; @@ -57,20 +57,19 @@ impl RxRing { let (actual, idx) = self.ring.peek(nb as _); if actual > 0 { - unsafe { self.do_recv(actual, idx, umem, packets) }; - } - - actual - } + let mask = self.ring.mask(); + for i in idx..idx + actual { + let desc = self.ring[i & mask]; + packets.push_front( + // SAFETY: The user is responsible for the lifetime of the + // packets we are returning + unsafe { umem.packet(desc) }, + ); + } - #[inline] - unsafe fn do_recv(&mut self, actual: usize, idx: usize, umem: &Umem, packets: &mut HeapSlab) { - let mask = self.ring.mask(); - for i in idx..idx + actual { - let desc = self.ring[i & mask]; - packets.push_back(unsafe { umem.packet(desc) }); + self.ring.release(actual as _); } - self.ring.release(actual as _); + actual } } diff --git a/src/rings/tx.rs b/src/rings/tx.rs index 62993dd..6c6a971 100644 --- a/src/rings/tx.rs +++ b/src/rings/tx.rs @@ -2,8 +2,8 @@ //! sent by the NIC the ring is bound to use crate::{ - HeapSlab, libc::{self, rings}, + slab::Slab, }; /// The ring used to enqueue packets for the kernel to send @@ -56,7 +56,7 @@ impl TxRing { /// The number of packets that were actually enqueued. This number can be /// lower than the requested `num_packets` if the ring doesn't have sufficient /// capacity - pub unsafe fn send(&mut self, packets: &mut HeapSlab) -> usize { + pub unsafe fn send(&mut self, packets: &mut S) -> usize { let requested = packets.len(); if requested == 0 { return 0; @@ -67,7 +67,7 @@ impl TxRing { if actual > 0 { let mask = self.ring.mask(); for i in idx..idx + actual { - let Some(packet) = packets.pop_front() else { + let Some(packet) = packets.pop_back() else { unreachable!() }; @@ -109,7 +109,12 @@ impl WakableTxRing { /// The number of packets that were actually enqueued. This number can be /// lower than the requested `num_packets` if the ring doesn't have sufficient /// capacity - pub unsafe fn send(&mut self, packets: &mut HeapSlab, wakeup: bool) -> std::io::Result { + pub unsafe fn send( + &mut self, + packets: &mut S, + wakeup: bool, + ) -> std::io::Result { + // SAFETY: TxRing::send is unsafe let queued = unsafe { self.inner.send(packets) }; if queued > 0 && wakeup { diff --git a/src/slab.rs b/src/slab.rs new file mode 100644 index 0000000..42674df --- /dev/null +++ b/src/slab.rs @@ -0,0 +1,162 @@ +//! Contains simple slab data structures for use with the TX and RX rings + +use crate::Packet; + +/// A fixed size buffer of packets +pub trait Slab { + /// The number of free slots available + fn available(&self) -> usize; + /// The number of occupied slots + fn len(&self) -> usize; + /// True if the slab is empty + fn is_empty(&self) -> bool; + /// Pushes a packet to the slab, returning the packet if the slab is at capacity + fn push_front(&mut self, packet: Packet) -> Option; + /// Pops the back packet if any + fn pop_back(&mut self) -> Option; +} + +// A heap allocated slab, using [`std::collections::VecDequeue`] +/// +/// This is allocated on the heap, but will _not_ grow, and is intended to be +/// allocated once before entering an I/O loop +pub struct HeapSlab { + vd: std::collections::VecDeque, +} + +impl HeapSlab { + /// Allocates a new [`Self`] with the maximum specified capacity + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + vd: std::collections::VecDeque::with_capacity(capacity), + } + } +} + +impl Slab for HeapSlab { + /// The number of packets in the slab + #[inline] + fn len(&self) -> usize { + self.vd.len() + } + + /// True if the slab is empty + #[inline] + fn is_empty(&self) -> bool { + self.vd.is_empty() + } + + /// The number of packets that can be pushed to the slab + #[inline] + fn available(&self) -> usize { + self.vd.capacity() - self.vd.len() + } + + /// Pops the front packet if any + #[inline] + fn pop_back(&mut self) -> Option { + self.vd.pop_back() + } + + /// Pushes a packet to the front, returning `Some` if the slab is at capacity + #[inline] + fn push_front(&mut self, item: Packet) -> Option { + if self.available() > 0 { + self.vd.push_front(item); + None + } else { + Some(item) + } + } +} + +struct AssertPowerOf2; + +impl AssertPowerOf2 { + const OK: () = assert!(usize::is_power_of_two(N), "must be a power of 2"); +} + +#[doc(hidden)] +pub const fn assert_power_of_2() { + let () = AssertPowerOf2::::OK; +} + +/// Slab impl macro, only public for creating testing slabs with integer types < usize +#[cfg_attr(debug_assertions, macro_export)] +macro_rules! slab { + ($name:ident, $int:ty) => { + /// A stack allocated, fixed size, ring buffer + pub struct $name { + ring: [$crate::Packet; N], + read: $int, + write: $int, + } + + impl $name { + /// Creates a new slab, `N` must be a power of 2 + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + $crate::slab::assert_power_of_2::(); + + Self { + // SAFETY: Packet is just a POD + ring: unsafe { std::mem::zeroed() }, + read: 0, + write: 0, + } + } + } + + impl $crate::slab::Slab for $name { + /// The current number of packets in the slab + #[inline] + fn len(&self) -> usize { + if self.write >= self.read { + (self.write - self.read) as _ + } else { + <$int>::MAX as usize - self.read as usize + self.write as usize + 1 + } + } + + /// True if the slab is empty + #[inline] + fn is_empty(&self) -> bool { + self.write == self.read + } + + /// The number of packets that can be pushed to the slab + #[inline] + fn available(&self) -> usize { + N - self.len() + } + + /// Pops the back packet if any + #[inline] + fn pop_back(&mut self) -> Option<$crate::Packet> { + if self.is_empty() { + return None; + } + + let index = self.read as usize % N; + self.read = self.read.wrapping_add(1); + Some(self.ring[index].inner_copy()) + } + + /// Pushes a packet to the front, returning `Some` if the slab is at capacity + #[inline] + fn push_front(&mut self, item: $crate::Packet) -> Option<$crate::Packet> { + if self.available() > 0 { + let index = self.write as usize % N; + self.write = self.write.wrapping_add(1); + self.ring[index] = item; + None + } else { + Some(item) + } + } + } + }; +} + +slab!(StackSlab, usize); diff --git a/src/socket.rs b/src/socket.rs index 44bd4dc..dc849d6 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -223,24 +223,28 @@ impl XdpSocketBuilder { cfg: &rings::RingConfig, ) -> Result { let mut flags = 0; - if !umem.frame_size.is_power_of_two() { + // Internally umem uses frame_size - head room for the capacity of + // each packet, but we need to readjust it here so the kernel knows + // the actual size + let chunk_size = umem.frame_size as u32 + xdp::XDP_PACKET_HEADROOM as u32; + if !chunk_size.is_power_of_two() { flags |= xdp::UmemFlags::XDP_UMEM_UNALIGNED_CHUNK_FLAG; } - if umem.options & InternalXdpFlags::SupportsChecksumOffload as u32 != 0 { + if umem.options != 0 { // This value is only available in very recent ~6.11 kernels and was introduced // for those who didn't zero initialize xdp_umem_reg flags |= xdp::UmemFlags::XDP_UMEM_TX_METADATA_LEN; - if umem.options & InternalXdpFlags::SoftwareOffload as u32 != 0 { + if umem.options & InternalXdpFlags::USE_SOFTWARE_OFFLOAD != 0 { flags |= xdp::UmemFlags::XDP_UMEM_TX_SW_CSUM; } } let umem_reg = xdp::XdpUmemReg { - addr: umem.mmap.as_ptr() as _, + addr: umem.mmap.ptr as _, len: umem.mmap.len() as _, - chunk_size: umem.frame_size as _, + chunk_size, headroom: umem.head_room as _, flags, tx_metadata_len: if umem.options != 0 { @@ -320,6 +324,7 @@ impl XdpSocketBuilder { sxdp_shared_umem_fd: 0, }; + // SAFETY: syscall, all inputs are valid if unsafe { socket::bind( self.sock.as_raw_fd(), @@ -336,15 +341,7 @@ impl XdpSocketBuilder { #[inline] fn set_sockopt(&mut self, name: OptName, val: &T) -> Result<(), SocketError> { - // let level = if matches!( - // name, - // OptName::PreferBusyPoll | OptName::BusyPoll | OptName::BusyPollBudget - // ) { - // libc::SOL_SOCKET - // } else { - // libc::SOL_XDP - // }; - + // SAFETY: syscall, all inputs are valid if unsafe { libc::socket::setsockopt( self.sock.as_raw_fd(), @@ -424,6 +421,7 @@ impl XdpSocket { #[inline] fn poll_inner(&self, events: i16, timeout: PollTimeout) -> std::io::Result { + // SAFETY: syscall, all inputs are valid let ret = unsafe { socket::poll( &mut socket::pollfd { diff --git a/src/umem.rs b/src/umem.rs index 7919fa8..2b154af 100644 --- a/src/umem.rs +++ b/src/umem.rs @@ -73,7 +73,9 @@ pub struct Umem { /// data when receiving, which allows the packet to grow downward when eg. /// changing from IPv4 -> IPv6 without needing to copying data upwards pub(crate) head_room: usize, - pub(crate) options: u32, + /// Flags that control how the umem is registered and thus what capabilities + /// each packet has + pub(crate) options: InternalXdpFlags::Enum, } impl Umem { @@ -88,18 +90,10 @@ impl Umem { Ok(Self { mmap, available, - frame_size: cfg.frame_size as _, - frame_mask: !(cfg.frame_size as u64 - 1), + frame_size: cfg.frame_size as usize - libc::xdp::XDP_PACKET_HEADROOM as usize, + frame_mask: cfg.frame_mask, head_room: cfg.head_room as _, - options: if cfg.tx_metadata { - InternalXdpFlags::SupportsChecksumOffload as u32 - } else { - 0 - } | if cfg.software_checksum { - InternalXdpFlags::SoftwareOffload as u32 - } else { - 0 - }, + options: cfg.options, }) } @@ -115,18 +109,17 @@ impl Umem { // SAFETY: Barring kernel bugs, we should only ever get valid addresses // within the range of our map unsafe { - let addr = self + let data = self .mmap - .as_ptr() - .byte_offset((desc.addr - self.head_room as u64) as _) - as *mut u8; - let data = std::slice::from_raw_parts_mut(addr, self.frame_size); + .ptr + .byte_offset((desc.addr - self.head_room as u64) as _); Packet { data, + capacity: self.frame_size, head: self.head_room, tail: self.head_room + desc.len as usize, - base: self.mmap.as_ptr(), + base: self.mmap.ptr, options: desc.options | self.options, } } @@ -143,19 +136,20 @@ impl Umem { pub unsafe fn alloc(&mut self) -> Option { let addr = self.available.pop_front()?; + // SAFETY: The free list of addresses will always be within the range + // of the mmap unsafe { - let addr = self + let data = self .mmap - .as_ptr() - .byte_offset((addr + libc::xdp::XDP_PACKET_HEADROOM) as _) - as *mut u8; - let data = std::slice::from_raw_parts_mut(addr, self.frame_size); + .ptr + .byte_offset((addr + libc::xdp::XDP_PACKET_HEADROOM) as _); Some(Packet { data, + capacity: self.frame_size, head: self.head_room, tail: self.head_room, - base: self.mmap.as_ptr(), + base: self.mmap.ptr, options: self.options, }) } @@ -175,18 +169,19 @@ impl Umem { #[inline] pub fn free_packet(&mut self, packet: Packet) { debug_assert_eq!( - packet.base, - self.mmap.as_ptr(), + packet.base, self.mmap.ptr, "the packet was not allocated from this Umem" ); - self.free_addr(unsafe { - packet - .data - .as_ptr() - .byte_offset(packet.head as _) - .offset_from(packet.base) as _ - }); + self.free_addr( + // SAFETY: We've checked that the packet is owned by this Umem + unsafe { + packet + .data + .byte_offset(packet.head as _) + .offset_from(packet.base) as _ + }, + ); } /// The equivalent of [`Self::free_addr`], but returns the timestamp the @@ -197,12 +192,15 @@ impl Umem { let align_offset = address % self.frame_size as u64; let timestamp = if align_offset >= std::mem::size_of::() as u64 { + // SAFETY: This is a pod, so even if this wasn't actually enabled when + // the packet was enqueued, it shouldn't result in UB unsafe { - let tx_meta = &*(self - .mmap - .as_ptr() - .byte_offset((address - std::mem::size_of::() as u64) as _) - .cast::()); + let tx_meta = std::ptr::read_unaligned( + self.mmap + .ptr + .byte_offset((address - std::mem::size_of::() as u64) as _) + .cast::(), + ); tx_meta.offload.completion } } else { @@ -214,29 +212,8 @@ impl Umem { } #[inline] - pub(crate) fn popper(&mut self) -> UmemPopper<'_> { - UmemPopper { - available: &mut self.available, - } - } -} - -pub(crate) struct UmemPopper<'umem> { - available: &'umem mut VecDeque, -} - -impl UmemPopper<'_> { - #[inline] - pub(crate) fn len(&self) -> usize { - self.available.len() - } - - #[inline] - pub(crate) fn pop(&mut self) -> u64 { - let Some(addr) = self.available.pop_front() else { - unreachable!() - }; - addr + pub(crate) fn available(&mut self) -> &mut VecDeque { + &mut self.available } } @@ -254,15 +231,14 @@ pub struct UmemCfgBuilder { pub frame_count: u32, /// If true, the [`Umem`] will be registered with the socket with an /// additional section before the packet that may be filled with TX metadata - /// that either request a checksum be calculated by the NIC, and/or that the + /// that either request a checksum be calculated by the NIC + pub tx_checksum: bool, + /// If true, the [`Umem`] will be , and/or that the /// transmission timestamp is set before being added to the completion queue - pub tx_metadata: bool, + pub tx_timestamp: bool, /// For testing purposes only, enables the /// [`libc::xdp::UmemFlags::XDP_UMEM_TX_SW_CSUM`] flag so the checksum is /// calculated by the driver in software - /// - /// Note that [`Self::tx_metadata`] must also be set to true when using - /// this option #[cfg(debug_assertions)] pub software_checksum: bool, } @@ -273,7 +249,8 @@ impl Default for UmemCfgBuilder { frame_size: FrameSize::FourK, // XSK_UMEM_DEFAULT_FRAME_SIZE head_room: 0, frame_count: 8 * 1024, - tx_metadata: false, + tx_checksum: false, + tx_timestamp: false, #[cfg(debug_assertions)] software_checksum: false, } @@ -281,9 +258,22 @@ impl Default for UmemCfgBuilder { } impl UmemCfgBuilder { + /// Creates a builder with TX checksum offload and/or timestamping if supported + /// by the NIC + pub fn new(tx_flags: crate::nic::XdpTxMetadata) -> Self { + Self { + tx_checksum: tx_flags.checksum(), + tx_timestamp: tx_flags.timestamp(), + ..Default::default() + } + } + /// Attempts build a [`UmemCfg`] that can be used with [`Umem::map`] pub fn build(self) -> Result { let frame_size = self.frame_size.try_into()?; + // For now we only allow 2k and 4k sizes, but if we supported unaligned + // frames in the future we'd need to change this + let frame_mask = !(frame_size as u64 - 1); let head_room = within_range!( self, @@ -292,15 +282,35 @@ impl UmemCfgBuilder { ); let frame_count = within_range!(self, frame_count, 1..u32::MAX as _); + let total_size = frame_count as usize * frame_size as usize; + if total_size > isize::MAX as usize { + return Err(Error::Cfg(crate::error::ConfigError { + name: "frame_count * frame_size", + kind: crate::error::ConfigErrorKind::OutOfRange { + size: total_size, + range: frame_size as usize..isize::MAX as usize, + }, + })); + } + + let mut options = 0; + if self.tx_checksum { + options |= InternalXdpFlags::SUPPORTS_CHECKSUM_OFFLOAD; + } + if self.tx_timestamp { + options |= InternalXdpFlags::SUPPORTS_TIMESTAMP; + } + #[cfg(debug_assertions)] + if self.software_checksum { + options |= InternalXdpFlags::USE_SOFTWARE_OFFLOAD; + } + Ok(UmemCfg { frame_size, + frame_mask, frame_count, head_room, - tx_metadata: self.tx_metadata, - #[cfg(debug_assertions)] - software_checksum: self.software_checksum, - #[cfg(not(debug_assertions))] - software_checksum: false, + options, }) } } @@ -309,8 +319,8 @@ impl UmemCfgBuilder { #[derive(Copy, Clone)] pub struct UmemCfg { frame_size: u32, + frame_mask: u64, frame_count: u32, head_room: u32, - tx_metadata: bool, - software_checksum: bool, + options: InternalXdpFlags::Enum, }