diff --git a/Cargo.lock b/Cargo.lock index 8f306f785..7845ffde6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -387,9 +387,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.37" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eccb054f56cbd38340b380d4a8e69ef1f02f1af43db2f0cc817a4774d80ae071" +checksum = "ed93b9805f8ba930df42c2590f05453d5ec36cbb85d018868a5b24d31f6ac000" dependencies = [ "clap_builder", "clap_derive", @@ -397,9 +397,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.37" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2" +checksum = "379026ff283facf611b0ea629334361c4211d1b12ee01024eec1591133b04120" dependencies = [ "anstyle", "clap_lex", @@ -818,6 +818,7 @@ dependencies = [ "futures", "libc", "monero-serai", + "pin-project-lite", "rayon", "serde", "tokio", @@ -3062,9 +3063,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.19.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", "getrandom 0.3.3", diff --git a/Cargo.toml b/Cargo.toml index b00836c55..850c43d28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,6 +132,7 @@ monero-serai = { git = "https://github.com/Cuprate/serai.git", rev = "e nu-ansi-term = { version = "0.46", default-features = false } paste = { version = "1", default-features = false } pin-project = { version = "1", default-features = false } +pin-project-lite = { version = "0.2.16", default-features = false } randomx-rs = { git = "https://github.com/Cuprate/randomx-rs.git", rev = "e09955c", default-features = false } rand = { version = "0.8", default-features = false } rand_distr = { version = "0.4", default-features = false } diff --git a/helper/Cargo.toml b/helper/Cargo.toml index ad70e18bf..daef57aac 100644 --- a/helper/Cargo.toml +++ b/helper/Cargo.toml @@ -24,6 +24,7 @@ time = ["dep:chrono", "std"] thread = ["std", "dep:target_os_lib"] tx = ["dep:monero-serai"] fmt = ["map", "std"] +timeout = ["std", "dep:pin-project-lite", "dep:tokio"] [dependencies] cuprate-constants = { workspace = true, optional = true, features = ["block"] } @@ -35,6 +36,8 @@ dirs = { workspace = true, optional = true } futures = { workspace = true, optional = true, features = ["std"] } monero-serai = { workspace = true, optional = true } rayon = { workspace = true, optional = true } +tokio = { workspace = true, optional = true, features = ["time"] } +pin-project-lite = { workspace = true, optional = true } serde = { workspace = true, optional = true, features = ["derive"] } diff --git a/helper/src/lib.rs b/helper/src/lib.rs index bf464042e..f48417128 100644 --- a/helper/src/lib.rs +++ b/helper/src/lib.rs @@ -14,6 +14,9 @@ pub mod cast; #[cfg(feature = "fs")] pub mod fs; +#[cfg(feature = "timeout")] +pub mod timeout; + pub mod network; #[cfg(feature = "num")] diff --git a/helper/src/timeout.rs b/helper/src/timeout.rs new file mode 100644 index 000000000..4b5da1ddd --- /dev/null +++ b/helper/src/timeout.rs @@ -0,0 +1,529 @@ +//! IO Timeout Wrapper +//! +//! This module implements wrapper around [`AsyncRead`]/[`AsyncWrite`] types to return `TimedOut` error if +//! they haven't been able to complete operation after a period of time. +//! +//! This is used as a denial of service mitigation mechanism against keep-alive or one-way spamming connections. +//! +//! Internally these wrappers abstract the `Duration` field to welcome shared data structures that can be used to +//! adapt the timeout period on the fly. +//! + +use std::{ + future::Future, + io::{Error, ErrorKind}, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use pin_project_lite::pin_project; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time::{sleep_until, Instant, Sleep}, +}; + +/// Helper trait that add [`ExtractDuration::extract_duration`] function that can compute a timeout Duration +/// from a reference. This is sensibly the same as `D where Duration: From<&D>` but with +/// no specific lifetime requirement. +/// +/// This trait is implemented for [`Duration`]. +pub trait ExtractDuration: Clone + Unpin { + fn extract_duration(&self) -> Duration; +} + +impl ExtractDuration for Duration { + fn extract_duration(&self) -> Duration { + *self + } +} + +/// A timeout state with a specified duration. +/// +/// `D` implements [`ExtractDuration`] trait which +/// permit custom types to compute a Timeout duration. +/// +/// This can be useful for shared data structures that +/// modify the timeout on the fly. +pub struct TimeoutState { + timeout: D, + refresh: bool, + sleep: Pin>, +} + +impl TimeoutState +where + Self: Unpin, +{ + /// Create a new [`TimeoutState`] with the given timeout type. + pub fn new(timeout: D) -> Self { + Self { + timeout, + refresh: true, + sleep: Box::pin(sleep_until(Instant::now())), + } + } + + /// Poll inner [`Sleep`] for completion. Update its deadline on first use and return + /// `Poll::Ready(Error::from(ErrorKind::TimedOut))` on completion, `Poll::Pending` otherwise + pub fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut proj = self; + + // On first poll after refresh activate couldown. + if proj.refresh { + proj.refresh = false; + let timeout = proj.timeout.extract_duration(); + proj.sleep.as_mut().reset(Instant::now() + timeout); + } + + proj.sleep + .as_mut() + .poll(cx) + .map(|()| Error::from(ErrorKind::TimedOut)) + } +} + +// Helper macros for reducing redundancy. This main logic is present in every poll. +macro_rules! poll_or_timeout { + ($self:ident::$io:ident..$timeout:ident => $poll:ident, $cx:ident, $($arg:expr),*) => {{ + let proj = $self .project(); + + match proj.$io.$poll($cx, $($arg),*) { + Poll::Pending => proj.$timeout.as_mut().poll($cx).map(Err), + Poll::Ready(r) => { + proj.$timeout.refresh = true; + Poll::Ready(r) + } + } + }}; + ($self:ident::$io:ident..$timeout:ident => $poll:ident, $cx:ident) => {{ + let proj = $self .project(); + + match proj.$io.$poll($cx) { + Poll::Pending => proj.$timeout.as_mut().poll($cx).map(Err), + Poll::Ready(r) => { + proj.$timeout.refresh = true; + Poll::Ready(r) + } + } + }}; +} + +pin_project! { + /// A timeout wrapper around an [`AsyncWrite`] implemented type. + /// + /// Returns a `TimedOut` error if any poll operation have been returning + /// `Poll::Pending` for the timeout duration. + pub struct WriteTimeout { + #[pin] + writer: W, + timeout: Pin>>, + } +} + +impl WriteTimeout { + /// Create a new [`WriteTimeout`] from a writer and an [`ExtractDuration`] enabled type. + pub fn new(writer: W, timeout: D) -> Self { + Self { + writer, + timeout: Box::pin(TimeoutState::new(timeout)), + } + } +} + +impl AsyncWrite for WriteTimeout { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + poll_or_timeout!(self::writer..timeout => poll_write, cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[std::io::IoSlice<'_>], + ) -> Poll> { + poll_or_timeout!(self::writer..timeout => poll_write_vectored, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + poll_or_timeout!(self::writer..timeout => poll_flush, cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + poll_or_timeout!(self::writer..timeout => poll_shutdown, cx) + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } +} + +pin_project! { + /// A timeout wrapper around an [`AsyncRead`] implemented type. + /// + /// Returns a `TimedOut` error if `poll_read` have been returning + /// `Poll::Pending` for the timeout duration. + pub struct ReadTimeout { + #[pin] + reader: R, + timeout: Pin>>, + } +} + +impl ReadTimeout { + /// Create a new [`ReadTimeout`] from a reader and an [`ExtractDuration`] enabled type. + pub fn new(reader: R, timeout: D) -> Self { + Self { + reader, + timeout: Box::pin(TimeoutState::new(timeout)), + } + } +} + +impl AsyncRead for ReadTimeout { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + poll_or_timeout!(self::reader..timeout => poll_read, cx, buf) + } +} + +pin_project! { + /// A timeout wrapper around an [`AsyncRead`] + [`AsyncWrite`] implemented type. + /// + /// Returns a `TimedOut` error if `poll_read` have been returning + /// `Poll::Pending` for the timeout duration. + pub struct StreamTimeout { + #[pin] + stream: S, + write_timeout: Pin>>, + read_timeout: Pin>> + } +} + +impl StreamTimeout { + /// Create a new [`StreamTimeout`] from a stream and two [`ExtractDuration`] enabled type. + pub fn new(stream: S, write_timeout: DW, read_timeout: DR) -> Self { + Self { + stream, + write_timeout: Box::pin(TimeoutState::new(write_timeout)), + read_timeout: Box::pin(TimeoutState::new(read_timeout)), + } + } +} + +impl AsyncWrite + for StreamTimeout +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + poll_or_timeout!(self::stream..write_timeout => poll_write, cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[std::io::IoSlice<'_>], + ) -> Poll> { + poll_or_timeout!(self::stream..write_timeout => poll_write_vectored, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + poll_or_timeout!(self::stream..write_timeout => poll_flush, cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + poll_or_timeout!(self::stream..write_timeout => poll_shutdown, cx) + } + + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() + } +} + +impl AsyncRead + for StreamTimeout +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + poll_or_timeout!(self::stream..read_timeout => poll_read, cx, buf) + } +} + +#[cfg(test)] +mod test { + use std::{ + future::Future, + io::ErrorKind, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Duration, + }; + + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, + select, + task::JoinSet, + time::{sleep, timeout}, + }; + + use crate::timeout::{ReadTimeout, StreamTimeout, WriteTimeout}; + + #[cfg(target_os = "macos")] + const TEST_TIMEOUT: Duration = Duration::from_secs(2); + #[cfg(not(target_os = "macos"))] + const TEST_TIMEOUT: Duration = Duration::from_secs(10); + + fn within_current_thread_runtime(future: impl Future) { + // Start tokio runtime + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + + runtime.block_on(future); + } + + // Common setup used between TCP tests. + async fn spawn_tcp_setup(port: u16, client_test: C, listener_test: L) + where + R1: Future + Send + 'static, + R2: Future + Send + 'static, + C: Fn(TcpStream) -> R1 + Send + 'static, + L: Fn(TcpStream) -> R2 + Send + 'static, + { + let socketaddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); + + let listener = TcpListener::bind(socketaddr) + .await + .expect("Unable to bind TCP Listener"); + + let mut set = JoinSet::new(); + + // Spawn Listener + set.spawn(async move { + let connection = listener + .accept() + .await + .expect("Unable to accept incoming connection"); + + listener_test(connection.0).await; + }); + + // Spawn client + set.spawn(async move { + let Ok(stream) = timeout(TEST_TIMEOUT, TcpStream::connect(socketaddr)) + .await + .expect("Unable to connect listener") + else { + panic!("No connection has been made to the listener!"); + }; + + client_test(stream).await; + }); + + set.join_all().await; + } + + #[test] + fn tcp_write_timeout_err() { + within_current_thread_runtime(spawn_tcp_setup( + 60031, + async |_stream: TcpStream| { + sleep(TEST_TIMEOUT + Duration::from_secs(1)).await; + }, + async |stream: TcpStream| { + let (_reader, writer) = stream.into_split(); + + // Wrap writer half into a WriteTimeout. + let mut writer = WriteTimeout::new(writer, TEST_TIMEOUT); + + // Write. + let buf = vec![1_u8; 64 * 1024_usize.pow(2)]; // 64MiB + select! { + r = writer.write_all(&buf) => { + if let Err(err) = r { + assert_eq!(err.kind(), ErrorKind::TimedOut); + } else { + panic!("Buffer have been successfully flushed. This test needs to updated.") + } + } + () = sleep(TEST_TIMEOUT + Duration::from_secs(1)) => { + panic!("No error has been returned after +1 seconds.") + } + } + }, + )); + } + + #[test] + fn tcp_write_timeout_ok() { + within_current_thread_runtime(spawn_tcp_setup( + 60032, + async |stream: TcpStream| { + let (mut reader, _writer) = stream.into_split(); + + sleep(TEST_TIMEOUT / 2).await; + + loop { + let mut buf = vec![0_u8; 64 * 1024]; // 64KiB + if reader.read_exact(&mut buf).await.is_err() { + break; + } + } + }, + async |stream: TcpStream| { + let (_reader, writer) = stream.into_split(); + + // Wrap writer half into a WriteTimeout. + let mut writer = WriteTimeout::new(writer, TEST_TIMEOUT); + + // Write. + let buf = vec![1_u8; 1024_usize.pow(2)]; // 1MiB + select! { + r = writer.write_all(&buf) => { + assert!(r.is_ok()); + } + () = sleep(TEST_TIMEOUT + Duration::from_secs(1)) => { + panic!("No error has been returned after +1 seconds.") + } + } + }, + )); + } + + #[test] + fn tcp_read_timeout_err() { + within_current_thread_runtime(spawn_tcp_setup( + 60033, + async |_stream: TcpStream| { + sleep(TEST_TIMEOUT + Duration::from_secs(1)).await; + }, + async |stream: TcpStream| { + let (reader, _writer) = stream.into_split(); + + // Wrap reader half into a ReadTimeout. + let mut reader = ReadTimeout::new(reader, TEST_TIMEOUT); + + // Try to read. + let mut buf = vec![0_u8; 1024]; // 1KiB + select! { + r = reader.read_buf(&mut buf) => { + if let Err(err) = r { + assert_eq!(err.kind(), ErrorKind::TimedOut); + } else { + panic!("The buffer has been successfully filled. This test needs to updated.") + } + } + () = sleep(TEST_TIMEOUT + Duration::from_secs(1)) => { + panic!("No error has been returned after +1 seconds.") + } + } + }, + )); + } + + #[test] + fn tcp_read_timeout_ok() { + within_current_thread_runtime(spawn_tcp_setup( + 60034, + async |stream: TcpStream| { + let (_reader, mut writer) = stream.into_split(); + + sleep(TEST_TIMEOUT / 2).await; + + let _ = writer + .write(&[1]) + .await + .expect("Unable to write into TCP stream"); + }, + async |stream: TcpStream| { + let (reader, _writer) = stream.into_split(); + + // Wrap reader half into a ReadTimeout. + let mut reader = ReadTimeout::new(reader, TEST_TIMEOUT); + + // Try to read + let mut buf = vec![0_u8; 1024]; // 1KiB + select! { + r = reader.read_buf(&mut buf) => { + assert!(r.is_ok()); + } + () = sleep(TEST_TIMEOUT + Duration::from_secs(1)) => { + panic!("No error has been returned after +1 seconds.") + } + } + }, + )); + } + + #[test] + fn tcp_stream_read_timeout_err() { + within_current_thread_runtime(spawn_tcp_setup( + 60035, + async |_stream: TcpStream| { + sleep(TEST_TIMEOUT + Duration::from_secs(1)).await; + }, + async |stream: TcpStream| { + // Wrap stream into StreamTimeout + let mut stream = StreamTimeout::new(stream, TEST_TIMEOUT, TEST_TIMEOUT); + + // Try to read + let mut buf = vec![0_u8; 1024]; // 1KiB + select! { + r = stream.read_buf(&mut buf) => { + if let Err(err) = r { + assert_eq!(err.kind(), ErrorKind::TimedOut); + } else { + panic!("The buffer has been successfully filled. This test needs to updated.") + } + } + () = sleep(TEST_TIMEOUT + Duration::from_secs(1)) => { + panic!("No error has been returned after +1 seconds.") + } + } + }, + )); + } + + #[test] + fn tcp_stream_write_timeout_err() { + within_current_thread_runtime(spawn_tcp_setup( + 60036, + async |_stream: TcpStream| { + sleep(TEST_TIMEOUT + Duration::from_secs(1)).await; + }, + async |stream: TcpStream| { + // Wrap stream into StreamTimeout + let mut stream = StreamTimeout::new(stream, TEST_TIMEOUT, TEST_TIMEOUT); + + // Try to write + let buf = vec![1_u8; 64 * 1024_usize.pow(2)]; // 16MiB + select! { + r = stream.write_all(&buf) => { + if let Err(err) = r { + assert_eq!(err.kind(), ErrorKind::TimedOut); + } else { + panic!("Buffer have been successfully flushed. This test needs to updated.") + } + } + () = sleep(TEST_TIMEOUT * 2) => { + panic!("No error has been returned after +1 seconds.") + } + } + }, + )); + } +} diff --git a/p2p/p2p-core/Cargo.toml b/p2p/p2p-core/Cargo.toml index 4515e5bbe..6ef3878f2 100644 --- a/p2p/p2p-core/Cargo.toml +++ b/p2p/p2p-core/Cargo.toml @@ -10,7 +10,7 @@ default = ["borsh"] borsh = ["dep:borsh", "cuprate-pruning/borsh"] [dependencies] -cuprate-helper = { workspace = true, features = ["asynch"], default-features = false } +cuprate-helper = { workspace = true, features = ["asynch", "timeout"], default-features = false } cuprate-wire = { workspace = true, features = ["tracing"] } cuprate-pruning = { workspace = true } cuprate-types = { workspace = true } diff --git a/p2p/p2p-core/src/transports/tcp.rs b/p2p/p2p-core/src/transports/tcp.rs index 31f58ea75..383880d1b 100644 --- a/p2p/p2p-core/src/transports/tcp.rs +++ b/p2p/p2p-core/src/transports/tcp.rs @@ -40,7 +40,7 @@ pub struct TcpServerConfig { pub port: u16, /// Number of milliseconds before timeout at TCP writing - _send_timeout: Duration, + send_timeout: Duration, } impl Default for TcpServerConfig { @@ -49,7 +49,7 @@ impl Default for TcpServerConfig { ipv4: Some(Ipv4Addr::UNSPECIFIED), ipv6: None, port: 18081, - _send_timeout: Duration::from_secs(20), + send_timeout: Duration::from_secs(20), } } } @@ -60,6 +60,8 @@ pub struct TcpInBoundStream { listener_v4: Option, /// IPv6 TCP listener listener_v6: Option, + /// Send Timeout + _send_timeout: Duration, } impl Stream for TcpInBoundStream { @@ -141,6 +143,7 @@ impl> Transport for Tcp { return Ok(TcpInBoundStream { listener_v4: ipv4_listener, listener_v6: None, + _send_timeout: config.send_timeout, }); } @@ -153,6 +156,7 @@ impl> Transport for Tcp { Ok(TcpInBoundStream { listener_v4: ipv4_listener, listener_v6: ipv6_listener, + _send_timeout: config.send_timeout, }) } }