diff --git a/Cargo.lock b/Cargo.lock index 5a79fda4..c23ed48c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -584,6 +584,7 @@ dependencies = [ "axum", "base64", "bytes", + "criterion", "futures", "http", "http-body 0.4.5", diff --git a/Cargo.toml b/Cargo.toml index a2b7d502..938f06f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ resolver = "2" [workspace.dependencies] futures = "0.3.27" tokio = "1.34.0" +tokio-tungstenite = "0.20.1" serde = { version = "1.0.193", features = ["derive"] } serde_json = "1.0.108" tower = { version = "0.4.13", default-features = false } @@ -14,12 +15,6 @@ thiserror = "1.0.40" tracing = "0.1.37" itoa = "1.0.9" -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -hyper = "0.14.25" -axum = "0.6.20" -warp = "0.3.6" -salvo = { version = "0.58.5", features = ["tower-compat"] } - # Hyper v0.1 http-body-v1 = { package = "http-body", version = "1.0.0-rc.2" } hyper-v1 = { package = "hyper", version = "1.0.0-rc.4", features = [ @@ -28,6 +23,21 @@ hyper-v1 = { package = "hyper", version = "1.0.0-rc.4", features = [ "http2", ] } +# Dev deps +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +criterion = { version = "0.5.1", features = ["html_reports"] } +hyper = { version = "0.14.25", features = [ + "http1", + "http2", + "server", + "stream", + "runtime", + "client", +] } +axum = "0.6.20" +warp = "0.3.6" +salvo = { version = "0.58.5", features = ["tower-compat"] } + [workspace.package] version = "0.7.2" edition = "2021" diff --git a/engineioxide/Cargo.toml b/engineioxide/Cargo.toml index 06533348..ac968d43 100644 --- a/engineioxide/Cargo.toml +++ b/engineioxide/Cargo.toml @@ -27,11 +27,11 @@ thiserror.workspace = true tokio = { workspace = true, features = ["rt", "time"] } tower.workspace = true hyper.workspace = true +tokio-tungstenite.workspace = true base64 = "0.21.0" bytes = "1.4.0" pin-project = "1.0.12" -tokio-tungstenite = "0.20.1" rand = "0.8.5" # Tracing @@ -48,14 +48,7 @@ http-body-v1 = { workspace = true, optional = true } [dev-dependencies] tokio = { workspace = true, features = ["macros", "parking_lot"] } tracing-subscriber.workspace = true -hyper = { workspace = true, features = [ - "http1", - "http2", - "server", - "stream", - "runtime", - "client", -] } +criterion.workspace = true warp.workspace = true axum.workspace = true salvo.workspace = true @@ -65,3 +58,13 @@ v3 = ["memchr", "unicode-segmentation"] test-utils = [] tracing = ["dep:tracing"] hyper-v1 = ["dep:hyper-v1", "dep:http-body-v1"] + +[[bench]] +name = "packet_encode" +path = "benches/packet_encode.rs" +harness = false + +[[bench]] +name = "packet_decode" +path = "benches/packet_decode.rs" +harness = false diff --git a/engineioxide/benches/packet_decode.rs b/engineioxide/benches/packet_decode.rs new file mode 100644 index 00000000..ea223048 --- /dev/null +++ b/engineioxide/benches/packet_decode.rs @@ -0,0 +1,32 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use engineioxide::Packet; + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("Decode packet ping/pong", |b| { + let packet: String = Packet::Ping.try_into().unwrap(); + b.iter(|| Packet::try_from(packet.as_str()).unwrap()) + }); + c.bench_function("Decode packet ping/pong upgrade", |b| { + let packet: String = Packet::PingUpgrade.try_into().unwrap(); + b.iter(|| Packet::try_from(packet.as_str()).unwrap()) + }); + c.bench_function("Decode packet message", |b| { + let packet: String = Packet::Message(black_box("Hello").to_string()) + .try_into() + .unwrap(); + b.iter(|| Packet::try_from(packet.as_str()).unwrap()) + }); + c.bench_function("Decode packet noop", |b| { + let packet: String = Packet::Noop.try_into().unwrap(); + b.iter(|| Packet::try_from(packet.as_str()).unwrap()) + }); + c.bench_function("Decode packet binary b64", |b| { + let packet: String = Packet::Binary(black_box(vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05])) + .try_into() + .unwrap(); + b.iter(|| Packet::try_from(packet.clone()).unwrap()) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/engineioxide/benches/packet_encode.rs b/engineioxide/benches/packet_encode.rs new file mode 100644 index 00000000..499e9e2b --- /dev/null +++ b/engineioxide/benches/packet_encode.rs @@ -0,0 +1,36 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use engineioxide::{config::EngineIoConfig, sid::Sid, OpenPacket, Packet, TransportType}; + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("Encode packet open", |b| { + let packet = Packet::Open(OpenPacket::new( + black_box(TransportType::Polling), + black_box(Sid::ZERO), + &EngineIoConfig::default(), + )); + b.iter(|| TryInto::::try_into(packet.clone())) + }); + c.bench_function("Encode packet ping/pong", |b| { + let packet = Packet::Ping; + b.iter(|| TryInto::::try_into(packet.clone())) + }); + c.bench_function("Encode packet ping/pong upgrade", |b| { + let packet = Packet::PingUpgrade; + b.iter(|| TryInto::::try_into(packet.clone())) + }); + c.bench_function("Encode packet message", |b| { + let packet = Packet::Message(black_box("Hello").to_string()); + b.iter(|| TryInto::::try_into(packet.clone())) + }); + c.bench_function("Encode packet noop", |b| { + let packet = Packet::Noop; + b.iter(|| TryInto::::try_into(packet.clone())) + }); + c.bench_function("Encode packet binary b64", |b| { + let packet = Packet::Binary(black_box(vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05])); + b.iter(|| TryInto::::try_into(packet.clone())) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/engineioxide/src/errors.rs b/engineioxide/src/errors.rs index becdd2a3..85d3e7ec 100644 --- a/engineioxide/src/errors.rs +++ b/engineioxide/src/errors.rs @@ -45,6 +45,8 @@ pub enum Error { #[error("Invalid packet length")] InvalidPacketLength, + #[error("Invalid packet type")] + InvalidPacketType(Option), } /// Convert an error into an http response @@ -64,10 +66,12 @@ impl From for Response> { .status(code) .body(ResponseBody::empty_response()) .unwrap(), - Error::BadPacket(_) => Response::builder() - .status(400) - .body(ResponseBody::empty_response()) - .unwrap(), + Error::BadPacket(_) | Error::InvalidPacketLength | Error::InvalidPacketType(_) => { + Response::builder() + .status(400) + .body(ResponseBody::empty_response()) + .unwrap() + } Error::PayloadTooLarge => Response::builder() .status(413) .body(ResponseBody::empty_response()) diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index fdbc46eb..098f06f4 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -35,6 +35,9 @@ pub use service::{ProtocolVersion, TransportType}; pub use socket::{DisconnectReason, Socket}; +#[cfg(feature = "test-utils")] +pub use packet::*; + pub mod config; pub mod handler; pub mod layer; diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index ac23cbd7..4dda495d 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -1,12 +1,13 @@ use base64::{engine::general_purpose, Engine}; -use serde::{de::Error, Deserialize, Serialize}; +use serde::Serialize; use crate::config::EngineIoConfig; +use crate::errors::Error; use crate::sid::Sid; use crate::TransportType; /// A Packet type to use when receiving and sending data from the client -#[derive(Debug, PartialEq, PartialOrd)] +#[derive(Debug, Clone, PartialEq, PartialOrd)] pub enum Packet { /// Open packet used to initiate a connection Open(OpenPacket), @@ -90,14 +91,14 @@ impl Packet { Packet::Noop => 1, Packet::Binary(data) => { if b64 { - 1 + ((data.len() as f64) / 3.).ceil() as usize * 4 + 1 + base64::encoded_len(data.len(), true).unwrap_or(usize::MAX - 1) } else { 1 + data.len() } } Packet::BinaryV3(data) => { if b64 { - 2 + ((data.len() as f64) / 3.).ceil() as usize * 4 + 2 + base64::encoded_len(data.len(), true).unwrap_or(usize::MAX - 2) } else { 1 + data.len() } @@ -108,77 +109,75 @@ impl Packet { /// Serialize a [Packet] to a [String] according to the Engine.IO protocol impl TryInto for Packet { - type Error = crate::errors::Error; + type Error = Error; fn try_into(self) -> Result { - let res = match self { + let len = self.get_size_hint(true); + let mut buffer = String::with_capacity(len); + match self { Packet::Open(open) => { - "0".to_string() + &serde_json::to_string(&open).map_err(Self::Error::from)? + buffer.push('0'); + buffer.push_str(&serde_json::to_string(&open)?); + } + Packet::Close => buffer.push('1'), + Packet::Ping => buffer.push('2'), + Packet::Pong => buffer.push('3'), + Packet::PingUpgrade => buffer.push_str("2probe"), + Packet::PongUpgrade => buffer.push_str("3probe"), + Packet::Message(msg) => { + buffer.push('4'); + buffer.push_str(&msg); + } + Packet::Upgrade => buffer.push('5'), + Packet::Noop => buffer.push('6'), + Packet::Binary(data) => { + buffer.push('b'); + general_purpose::STANDARD.encode_string(data, &mut buffer); + } + Packet::BinaryV3(data) => { + buffer.push_str("b4"); + general_purpose::STANDARD.encode_string(data, &mut buffer); } - Packet::Close => "1".to_string(), - Packet::Ping => "2".to_string(), - Packet::Pong => "3".to_string(), - Packet::PingUpgrade => "2probe".to_string(), - Packet::PongUpgrade => "3probe".to_string(), - Packet::Message(msg) => "4".to_string() + &msg, - Packet::Upgrade => "5".to_string(), - Packet::Noop => "6".to_string(), - Packet::Binary(data) => "b".to_string() + &general_purpose::STANDARD.encode(data), - Packet::BinaryV3(data) => "b4".to_string() + &general_purpose::STANDARD.encode(data), }; - Ok(res) + Ok(buffer) } } /// Deserialize a [Packet] from a [String] according to the Engine.IO protocol impl TryFrom<&str> for Packet { - type Error = crate::errors::Error; + type Error = Error; fn try_from(value: &str) -> Result { - let mut chars = value.chars(); - let packet_type = chars - .next() - .ok_or_else(|| serde_json::Error::custom("Packet type not found in packet string"))?; - let packet_data = chars.as_str(); - let is_upgrade = packet_data.starts_with("probe"); + let packet_type = value + .as_bytes() + .first() + .ok_or(Error::InvalidPacketType(None))?; + let is_upgrade = value.len() == 6 && &value[1..6] == "probe"; let res = match packet_type { - '0' => Packet::Open(serde_json::from_str(packet_data)?), - '1' => Packet::Close, - '2' => { - if is_upgrade { - Packet::PingUpgrade - } else { - Packet::Ping - } + b'1' => Packet::Close, + b'2' if is_upgrade => Packet::PingUpgrade, + b'2' => Packet::Ping, + b'3' if is_upgrade => Packet::PongUpgrade, + b'3' => Packet::Pong, + b'4' => Packet::Message(value[1..].to_string()), + b'5' => Packet::Upgrade, + b'6' => Packet::Noop, + b'b' if value.as_bytes().get(1) == Some(&b'4') => { + Packet::BinaryV3(general_purpose::STANDARD.decode(value[2..].as_bytes())?) } - '3' => { - if is_upgrade { - Packet::PongUpgrade - } else { - Packet::Pong - } - } - '4' => Packet::Message(packet_data.to_string()), - '5' => Packet::Upgrade, - '6' => Packet::Noop, - 'b' if value.starts_with("b4") => { - Packet::BinaryV3(general_purpose::STANDARD.decode(packet_data[1..].as_bytes())?) - } - 'b' => Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?), - c => Err(serde_json::Error::custom( - "Invalid packet type ".to_string() + &c.to_string(), - ))?, + b'b' => Packet::Binary(general_purpose::STANDARD.decode(value[1..].as_bytes())?), + c => Err(Error::InvalidPacketType(Some(*c as char)))?, }; Ok(res) } } impl TryFrom for Packet { - type Error = crate::errors::Error; + type Error = Error; fn try_from(value: String) -> Result { Packet::try_from(value.as_str()) } } /// An OpenPacket is used to initiate a connection -#[derive(Debug, Serialize, Deserialize, PartialEq, PartialOrd)] +#[derive(Debug, Clone, Serialize, PartialEq, PartialOrd)] #[serde(rename_all = "camelCase")] pub struct OpenPacket { sid: Sid, @@ -226,26 +225,9 @@ mod tests { assert_eq!(packet_str, format!("0{{\"sid\":\"{sid}\",\"upgrades\":[\"websocket\"],\"pingInterval\":25000,\"pingTimeout\":20000,\"maxPayload\":100000}}")); } - #[test] - fn test_open_packet_deserialize() { - let sid = Sid::new(); - let packet_str = format!("0{{\"sid\":\"{sid}\",\"upgrades\":[\"websocket\"],\"pingInterval\":25000,\"pingTimeout\":20000,\"maxPayload\":100000}}"); - let packet = Packet::try_from(packet_str.to_string()).unwrap(); - assert_eq!( - packet, - Packet::Open(OpenPacket { - sid, - upgrades: vec!["websocket".to_string()], - ping_interval: 25000, - ping_timeout: 20000, - max_payload: 1e5 as u64, - }) - ); - } - #[test] fn test_message_packet() { - let packet = Packet::Message("hello".to_string()); + let packet = Packet::Message("hello".into()); let packet_str: String = packet.try_into().unwrap(); assert_eq!(packet_str, "4hello"); } @@ -254,7 +236,7 @@ mod tests { fn test_message_packet_deserialize() { let packet_str = "4hello".to_string(); let packet: Packet = packet_str.try_into().unwrap(); - assert_eq!(packet, Packet::Message("hello".to_string())); + assert_eq!(packet, Packet::Message("hello".into())); } #[test] @@ -319,7 +301,7 @@ mod tests { let packet = Packet::PongUpgrade; assert_eq!(packet.get_size_hint(false), 6); - let packet = Packet::Message("hello".to_string()); + let packet = Packet::Message("hello".into()); assert_eq!(packet.get_size_hint(false), 6); let packet = Packet::Upgrade; diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index c64ea47a..6480aa91 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -106,7 +106,9 @@ impl From<&Error> for Option { WsTransport(tungstenite::Error::ConnectionClosed) => None, WsTransport(_) | Io(_) => Some(DisconnectReason::TransportError), BadPacket(_) | Serialize(_) | Base64(_) | StrUtf8(_) | PayloadTooLarge - | InvalidPacketLength => Some(DisconnectReason::PacketParsingError), + | InvalidPacketLength | InvalidPacketType(_) => { + Some(DisconnectReason::PacketParsingError) + } HeartbeatTimeout => Some(DisconnectReason::HeartbeatTimeout), _ => None, } diff --git a/socketioxide/Cargo.toml b/socketioxide/Cargo.toml index f7ecff05..3734a48e 100644 --- a/socketioxide/Cargo.toml +++ b/socketioxide/Cargo.toml @@ -48,6 +48,7 @@ engineioxide = { path = "../engineioxide", features = [ "tracing", "test-utils", ] } +tokio-tungstenite.workspace = true axum.workspace = true salvo.workspace = true warp.workspace = true @@ -57,16 +58,8 @@ tokio = { workspace = true, features = [ "rt-multi-thread", ] } tracing-subscriber.workspace = true -tokio-tungstenite = "0.20.0" -hyper = { workspace = true, features = [ - "http1", - "http2", - "server", - "stream", - "runtime", - "client", -] } -criterion = { version = "0.5.1", features = ["html_reports"] } +hyper.workspace = true +criterion.workspace = true # docs.rs-specific configuration [package.metadata.docs.rs] diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index 7454e074..1d954ea3 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -116,7 +116,7 @@ pub struct AckResponse { pub struct Socket { config: Arc, ns: Arc>, - message_handlers: RwLock>>, + message_handlers: RwLock, BoxedMessageHandler>>, disconnect_handler: Mutex>>, ack_message: Mutex>>>, ack_counter: AtomicI64, @@ -207,7 +207,7 @@ impl Socket { /// }); /// }); /// ``` - pub fn on(&self, event: impl Into, handler: H) + pub fn on(&self, event: impl Into>, handler: H) where H: MessageHandler, T: Send + Sync + 'static, @@ -267,7 +267,11 @@ impl Socket { /// }); /// }); /// ``` - pub fn emit(&self, event: impl Into, data: impl Serialize) -> Result<(), SendError> { + pub fn emit( + &self, + event: impl Into>, + data: impl Serialize, + ) -> Result<(), SendError> { let ns = self.ns(); let data = serde_json::to_value(data)?; if let Err(e) = self.send(Packet::event(ns, event.into(), data)) { @@ -305,7 +309,7 @@ impl Socket { /// ``` pub async fn emit_with_ack( &self, - event: impl Into, + event: impl Into>, data: impl Serialize, ) -> Result, AckError> where