Skip to content

Commit

Permalink
Add Frame size limit
Browse files Browse the repository at this point in the history
  • Loading branch information
ndelvalle committed May 20, 2024
1 parent dc34513 commit 03b2cbc
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/codec.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use bytes::{Buf, BytesMut};
use std::convert::TryInto;
use std::env;
use std::io::Cursor;
use tokio_util::codec::Decoder;

Expand All @@ -8,16 +9,32 @@ use crate::Error;

pub struct FrameCodec;

impl FrameCodec {
fn max_frame_size() -> usize {
env::var("MAX_FRAME_SIZE")
.map(|s| s.parse().expect("MAX_FRAME_SIZE must be a number"))
.unwrap_or(512 * 1024 * 1024)
}
}

impl Decoder for FrameCodec {
type Item = Frame;
type Error = Error;

// TODO:
// * Use src.reserve. This is a more efficient way to allocate space in the buffer.
// * Return an error if the frame is too large. This is a simple way to prevent a malicious
// client from sending a large frame and causing the server to run out of memory.
// * Read more here: https://docs.rs/tokio-util/latest/tokio_util/codec/index.html
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
// Check if the frame size exceeds a certain limit to prevent DoS attacks

println!("src.len(): {}", src.len());

if src.len() > FrameCodec::max_frame_size() {
return Err("frame size exceeds limit".into());
}

print!("processing frame: ");

let mut cursor = Cursor::new(&src[..]);
let frame = match Frame::parse(&mut cursor) {
Ok(frame) => frame,
Expand Down
36 changes: 36 additions & 0 deletions tests/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,39 @@ async fn test_parse_incomplete_frame() {
]));
assert_eq!(actual, expected);
}

#[tokio::test]
async fn test_max_frame_size_limit() {
let one_mb = 1024 * 1024;
std::env::set_var("MAX_FRAME_SIZE", one_mb.to_string());

let (tcp_stream_tx, tcp_stream) = create_tcp_connection().await.unwrap();
let peer_addr = tcp_stream.peer_addr().unwrap();
let mut connection = Connection::new(tcp_stream, peer_addr);

// Frame below limit size calculation:
// The frame format includes a length indicator and data terminated with \r\n.
// For a frame just below the 1 MB limit (one_mb - 1 bytes):
// - Length Indicator: $1048575\r\n
// - $: 1 byte
// - 1048575: 7 bytes (for the length)
// - \r\n: 2 bytes (CRLF)
// Total length indicator size: 1 + 7 + 2 = 10 bytes
// - Data size: To fit within the limit, the data itself should be one_mb - 1 - 10 bytes.
// Since the data terminates with \r\n, the actual data size should be one_mb - 12 bytes.
let frame_below_limit = format!("${}\r\n{}\r\n", one_mb - 1, "A".repeat(one_mb - 12));

let frame_above_limit = format!("${}\r\n{}\r\n", one_mb + 1, "A".repeat(one_mb + 1));

tcp_stream_tx.send(frame_below_limit.into_bytes()).unwrap();
tcp_stream_tx.send(frame_above_limit.into_bytes()).unwrap();

let _frame_below_limit = connection.read_frame().await.unwrap();
let frame_above_limit_result = connection.read_frame().await;
let frame_above_limit_error = frame_above_limit_result.unwrap_err();

assert_eq!(
frame_above_limit_error.to_string(),
"frame size exceeds limit"
);
}

0 comments on commit 03b2cbc

Please sign in to comment.