From 03b2cbc4d808ac1015b89824f34c8c1aaf1fe27e Mon Sep 17 00:00:00 2001 From: Nicolas del Valle Date: Mon, 20 May 2024 13:39:06 +0700 Subject: [PATCH 1/2] Add Frame size limit --- src/codec.rs | 21 +++++++++++++++++++-- tests/connection.rs | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index f92f605..d94bb9b 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -1,5 +1,6 @@ use bytes::{Buf, BytesMut}; use std::convert::TryInto; +use std::env; use std::io::Cursor; use tokio_util::codec::Decoder; @@ -8,16 +9,32 @@ use crate::Error; pub struct FrameCodec; +impl FrameCodec { + fn max_frame_size() -> usize { + env::var("MAX_FRAME_SIZE") + .map(|s| s.parse().expect("MAX_FRAME_SIZE must be a number")) + .unwrap_or(512 * 1024 * 1024) + } +} + impl Decoder for FrameCodec { type Item = Frame; type Error = Error; // TODO: // * Use src.reserve. This is a more efficient way to allocate space in the buffer. - // * Return an error if the frame is too large. This is a simple way to prevent a malicious - // client from sending a large frame and causing the server to run out of memory. // * Read more here: https://docs.rs/tokio-util/latest/tokio_util/codec/index.html fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + // Check if the frame size exceeds a certain limit to prevent DoS attacks + + println!("src.len(): {}", src.len()); + + if src.len() > FrameCodec::max_frame_size() { + return Err("frame size exceeds limit".into()); + } + + print!("processing frame: "); + let mut cursor = Cursor::new(&src[..]); let frame = match Frame::parse(&mut cursor) { Ok(frame) => frame, diff --git a/tests/connection.rs b/tests/connection.rs index a93fe61..e15f037 100644 --- a/tests/connection.rs +++ b/tests/connection.rs @@ -218,3 +218,39 @@ async fn test_parse_incomplete_frame() { ])); assert_eq!(actual, expected); } + +#[tokio::test] +async fn test_max_frame_size_limit() { + let one_mb = 1024 * 1024; + std::env::set_var("MAX_FRAME_SIZE", one_mb.to_string()); + + let (tcp_stream_tx, tcp_stream) = create_tcp_connection().await.unwrap(); + let peer_addr = tcp_stream.peer_addr().unwrap(); + let mut connection = Connection::new(tcp_stream, peer_addr); + + // Frame below limit size calculation: + // The frame format includes a length indicator and data terminated with \r\n. + // For a frame just below the 1 MB limit (one_mb - 1 bytes): + // - Length Indicator: $1048575\r\n + // - $: 1 byte + // - 1048575: 7 bytes (for the length) + // - \r\n: 2 bytes (CRLF) + // Total length indicator size: 1 + 7 + 2 = 10 bytes + // - Data size: To fit within the limit, the data itself should be one_mb - 1 - 10 bytes. + // Since the data terminates with \r\n, the actual data size should be one_mb - 12 bytes. + let frame_below_limit = format!("${}\r\n{}\r\n", one_mb - 1, "A".repeat(one_mb - 12)); + + let frame_above_limit = format!("${}\r\n{}\r\n", one_mb + 1, "A".repeat(one_mb + 1)); + + tcp_stream_tx.send(frame_below_limit.into_bytes()).unwrap(); + tcp_stream_tx.send(frame_above_limit.into_bytes()).unwrap(); + + let _frame_below_limit = connection.read_frame().await.unwrap(); + let frame_above_limit_result = connection.read_frame().await; + let frame_above_limit_error = frame_above_limit_result.unwrap_err(); + + assert_eq!( + frame_above_limit_error.to_string(), + "frame size exceeds limit" + ); +} From 9f3ddc5ac8e44373df3cb262c6a389cea1c3047c Mon Sep 17 00:00:00 2001 From: Nicolas del Valle Date: Mon, 20 May 2024 13:40:05 +0700 Subject: [PATCH 2/2] Add comment --- src/codec.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index d94bb9b..77b5989 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -25,16 +25,11 @@ impl Decoder for FrameCodec { // * Use src.reserve. This is a more efficient way to allocate space in the buffer. // * Read more here: https://docs.rs/tokio-util/latest/tokio_util/codec/index.html fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - // Check if the frame size exceeds a certain limit to prevent DoS attacks - - println!("src.len(): {}", src.len()); - + // Check if the frame size exceeds the limit to prevent DoS attacks. if src.len() > FrameCodec::max_frame_size() { return Err("frame size exceeds limit".into()); } - print!("processing frame: "); - let mut cursor = Cursor::new(&src[..]); let frame = match Frame::parse(&mut cursor) { Ok(frame) => frame,