diff --git a/src/backend/libc/net/syscalls.rs b/src/backend/libc/net/syscalls.rs index 3013f9922..220ce8bd1 100644 --- a/src/backend/libc/net/syscalls.rs +++ b/src/backend/libc/net/syscalls.rs @@ -7,11 +7,15 @@ use super::msghdr::with_xdp_msghdr; #[cfg(target_os = "linux")] use super::write_sockaddr::encode_sockaddr_xdp; use crate::backend::c; +#[cfg(target_os = "linux")] +use crate::backend::conv::ret_u32; use crate::backend::conv::{borrowed_fd, ret, ret_owned_fd, ret_send_recv, send_recv_len}; use crate::fd::{BorrowedFd, OwnedFd}; use crate::io; #[cfg(target_os = "linux")] use crate::net::xdp::SocketAddrXdp; +#[cfg(target_os = "linux")] +use crate::net::MMsgHdr; use crate::net::{SocketAddrAny, SocketAddrV4, SocketAddrV6}; use crate::utils::as_ptr; use core::mem::{size_of, MaybeUninit}; @@ -455,6 +459,23 @@ pub(crate) fn sendmsg_xdp( }) } +#[cfg(target_os = "linux")] +pub(crate) fn sendmmsg( + sockfd: BorrowedFd<'_>, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + unsafe { + ret_u32(c::sendmmsg( + borrowed_fd(sockfd), + msgs.as_mut_ptr() as _, + msgs.len().try_into().unwrap_or(c::c_uint::MAX), + bitflags_bits!(flags), + )) + .map(|ret| ret as usize) + } +} + #[cfg(not(any( apple, windows, diff --git a/src/backend/linux_raw/c.rs b/src/backend/linux_raw/c.rs index 4035bf945..f075b0abe 100644 --- a/src/backend/linux_raw/c.rs +++ b/src/backend/linux_raw/c.rs @@ -56,12 +56,12 @@ pub(crate) use linux_raw_sys::{ general::{O_CLOEXEC as SOCK_CLOEXEC, O_NONBLOCK as SOCK_NONBLOCK}, if_ether::*, net::{ - linger, msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet, __kernel_sa_family_t as sa_family_t, __kernel_sockaddr_storage as sockaddr_storage, - cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, AF_APPLETALK, - AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, AF_ECONET, - AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, AF_LLC, - AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE, + cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, linger, mmsghdr, + msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet, + AF_APPLETALK, AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, + AF_ECONET, AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, + AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE, AF_RXRPC, AF_SECURITY, AF_SNA, AF_TIPC, AF_UNIX, AF_UNSPEC, AF_WANPIPE, AF_X25, AF_XDP, IP6T_SO_ORIGINAL_DST, IPPROTO_FRAGMENT, IPPROTO_ICMPV6, IPPROTO_MH, IPPROTO_ROUTING, IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_FREEBIND, IPV6_MULTICAST_HOPS, diff --git a/src/backend/linux_raw/net/syscalls.rs b/src/backend/linux_raw/net/syscalls.rs index 4d4427a40..880c45ecc 100644 --- a/src/backend/linux_raw/net/syscalls.rs +++ b/src/backend/linux_raw/net/syscalls.rs @@ -16,6 +16,8 @@ use super::send_recv::{RecvFlags, SendFlags}; use super::write_sockaddr::encode_sockaddr_xdp; use super::write_sockaddr::{encode_sockaddr_v4, encode_sockaddr_v6}; use crate::backend::c; +#[cfg(target_os = "linux")] +use crate::backend::conv::slice_mut; use crate::backend::conv::{ by_mut, by_ref, c_int, c_uint, pass_usize, ret, ret_owned_fd, ret_usize, size_of, slice, socklen_t, zero, @@ -24,6 +26,8 @@ use crate::fd::{BorrowedFd, OwnedFd}; use crate::io::{self, IoSlice, IoSliceMut}; #[cfg(target_os = "linux")] use crate::net::xdp::SocketAddrXdp; +#[cfg(target_os = "linux")] +use crate::net::MMsgHdr; use crate::net::{ AddressFamily, Protocol, RecvAncillaryBuffer, RecvMsgReturn, SendAncillaryBuffer, Shutdown, SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6, SocketFlags, SocketType, @@ -36,8 +40,8 @@ use { crate::backend::reg::{ArgReg, SocketArg}, linux_raw_sys::net::{ SYS_ACCEPT, SYS_ACCEPT4, SYS_BIND, SYS_CONNECT, SYS_GETPEERNAME, SYS_GETSOCKNAME, - SYS_LISTEN, SYS_RECV, SYS_RECVFROM, SYS_RECVMSG, SYS_SEND, SYS_SENDMSG, SYS_SENDTO, - SYS_SHUTDOWN, SYS_SOCKET, SYS_SOCKETPAIR, + SYS_LISTEN, SYS_RECV, SYS_RECVFROM, SYS_RECVMSG, SYS_SEND, SYS_SENDMMSG, SYS_SENDMSG, + SYS_SENDTO, SYS_SHUTDOWN, SYS_SOCKET, SYS_SOCKETPAIR, }, }; @@ -439,6 +443,30 @@ pub(crate) fn sendmsg_xdp( }) } +#[cfg(target_os = "linux")] +#[inline] +pub(crate) fn sendmmsg( + sockfd: BorrowedFd<'_>, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + let (msgs, len) = slice_mut(msgs); + + #[cfg(not(target_arch = "x86"))] + let result = unsafe { ret_usize(syscall!(__NR_sendmmsg, sockfd, msgs, len, flags)) }; + + #[cfg(target_arch = "x86")] + let result = unsafe { + ret_usize(syscall!( + __NR_socketcall, + x86_sys(SYS_SENDMMSG), + slice_just_addr::, _>(&[sockfd.into(), msgs, len, flags.into()]) + )) + }; + + result +} + #[inline] pub(crate) fn shutdown(fd: BorrowedFd<'_>, how: Shutdown) -> io::Result<()> { #[cfg(not(target_arch = "x86"))] diff --git a/src/net/send_recv/msg.rs b/src/net/send_recv/msg.rs index 794485d9f..14fa0e717 100644 --- a/src/net/send_recv/msg.rs +++ b/src/net/send_recv/msg.rs @@ -2,11 +2,17 @@ #![allow(unsafe_code)] +#[cfg(target_os = "linux")] +use crate::backend::net::msghdr::{ + with_noaddr_msghdr, with_unix_msghdr, with_v4_msghdr, with_v6_msghdr, with_xdp_msghdr, +}; use crate::backend::{self, c}; use crate::fd::{AsFd, BorrowedFd, OwnedFd}; use crate::io::{self, IoSlice, IoSliceMut}; #[cfg(linux_kernel)] use crate::net::UCred; +#[cfg(target_os = "linux")] +use crate::net::{xdp::SocketAddrXdp, SocketAddrUnix}; use core::iter::FusedIterator; use core::marker::PhantomData; @@ -591,6 +597,75 @@ impl<'buf> Iterator for AncillaryDrain<'buf> { impl FusedIterator for AncillaryDrain<'_> {} +/// An ABI-compatible wrapper for `mmsghdr`, for sending multiple messages with +/// [sendmmsg]. +#[cfg(target_os = "linux")] +#[repr(transparent)] +pub struct MMsgHdr<'a> { + raw: c::mmsghdr, + _phantom: PhantomData<&'a mut ()>, +} + +#[cfg(target_os = "linux")] +impl<'a> MMsgHdr<'a> { + /// Constructs a new message with no destination address. + pub fn new(iov: &[IoSlice<'a>], control: &mut SendAncillaryBuffer<'_, '_, '_>) -> Self { + with_noaddr_msghdr(iov, control, Self::wrap) + } + + /// Constructs a new message to a specific IPv4 address. + pub fn new_v4( + addr: &SocketAddrV4, + iov: &[IoSlice<'a>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + ) -> Self { + with_v4_msghdr(addr, iov, control, Self::wrap) + } + + /// Constructs a new message to a specific IPv6 address. + pub fn new_v6( + addr: &SocketAddrV6, + iov: &[IoSlice<'a>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + ) -> Self { + with_v6_msghdr(addr, iov, control, Self::wrap) + } + + /// Constructs a new message to a specific Unix-domain address. + pub fn new_unix( + addr: &SocketAddrUnix, + iov: &[IoSlice<'a>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + ) -> Self { + with_unix_msghdr(addr, iov, control, Self::wrap) + } + + /// Constructs a new message to a specific XDP address. + pub fn new_xdp( + addr: &SocketAddrXdp, + iov: &[IoSlice<'a>], + control: &mut SendAncillaryBuffer<'_, '_, '_>, + ) -> Self { + with_xdp_msghdr(addr, iov, control, Self::wrap) + } + + fn wrap(msg_hdr: c::msghdr) -> Self { + Self { + raw: c::mmsghdr { + msg_hdr, + msg_len: 0, + }, + _phantom: PhantomData, + } + } + + /// Returns the number of bytes sent. This will return 0 until after a + /// successful call to [sendmmsg]. + pub fn bytes_sent(&self) -> usize { + self.raw.msg_len as _ + } +} + /// `sendmsg(msghdr)`—Sends a message on a socket. /// /// # References @@ -781,6 +856,22 @@ pub fn sendmsg_any( } } +/// `sendmmsg(msghdr)`—Sends multiple messages on a socket. +/// +/// # References +/// - [Linux] +/// +/// [Linux]: https://man7.org/linux/man-pages/man2/sendmmsg.2.html +#[inline] +#[cfg(target_os = "linux")] +pub fn sendmmsg( + socket: impl AsFd, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + backend::net::syscalls::sendmmsg(socket.as_fd(), msgs, flags) +} + /// `recvmsg(msghdr)`—Receives a message from a socket. /// /// # References diff --git a/tests/net/v4.rs b/tests/net/v4.rs index d770b657b..49cadef6c 100644 --- a/tests/net/v4.rs +++ b/tests/net/v4.rs @@ -194,3 +194,91 @@ fn test_v4_msg() { client.join().unwrap(); server.join().unwrap(); } + +#[test] +#[cfg(target_os = "linux")] +fn test_v4_sendmmsg() { + crate::init(); + + use std::net::TcpStream; + + use rustix::io::IoSlice; + use rustix::net::{sendmmsg, MMsgHdr}; + + fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0); + bind_v4(&connection_socket, &name).unwrap(); + + let who = match getsockname(&connection_socket).unwrap() { + SocketAddrAny::V4(addr) => addr, + _ => panic!(), + }; + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; 13]; + let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into(); + + std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer), "hello...world"); + } + + fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port); + let data_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap(); + connect_v4(&data_socket, &addr).unwrap(); + + let mut off = 0; + while off < 2 { + let sent = sendmmsg( + &data_socket, + &mut [ + MMsgHdr::new(&[IoSlice::new(b"hello")], &mut Default::default()), + MMsgHdr::new(&[IoSlice::new(b"...world")], &mut Default::default()), + ][off..], + SendFlags::empty(), + ) + .unwrap(); + + off += sent; + } + } + + let ready = Arc::new((Mutex::new(0_u16), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +} diff --git a/tests/net/v6.rs b/tests/net/v6.rs index 0d0a596c9..435ccf968 100644 --- a/tests/net/v6.rs +++ b/tests/net/v6.rs @@ -193,3 +193,94 @@ fn test_v6_msg() { client.join().unwrap(); server.join().unwrap(); } + +#[test] +#[cfg(target_os = "linux")] +fn test_v6_sendmmsg() { + crate::init(); + + use std::net::TcpStream; + + use rustix::io::IoSlice; + use rustix::net::{sendmmsg, MMsgHdr}; + + fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::INET6, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 0, 0, 0); + bind_v6(&connection_socket, &name).unwrap(); + + let who = match getsockname(&connection_socket).unwrap() { + SocketAddrAny::V6(addr) => addr, + _ => panic!(), + }; + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; 13]; + let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into(); + + std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer), "hello...world"); + } + + fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), port, 0, 0); + let data_socket = socket(AddressFamily::INET6, SocketType::STREAM, None).unwrap(); + connect_v6(&data_socket, &addr).unwrap(); + + let mut off = 0; + loop { + let sent = sendmmsg( + &data_socket, + &mut [ + MMsgHdr::new(&[IoSlice::new(b"hello")], &mut Default::default()), + MMsgHdr::new(&[IoSlice::new(b"...world")], &mut Default::default()), + ][off..], + SendFlags::empty(), + ) + .unwrap(); + + off += sent; + if off >= 2 { + break; + } + } + } + + let ready = Arc::new((Mutex::new(0_u16), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +}