Skip to content

Commit bff021a

Browse files
authored
Merge pull request #114 from Totodore/ft-max-payload-encoding
fix: apply max payload option when encoding (fix #113)
2 parents 4b80479 + a8fca85 commit bff021a

File tree

7 files changed

+374
-55
lines changed

7 files changed

+374
-55
lines changed

engineioxide/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ mod body;
1919
mod engine;
2020
mod futures;
2121
mod packet;
22+
mod peekable;
2223
mod transport;

engineioxide/src/packet.rs

+86-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,39 @@ impl Packet {
7171
_ => panic!("Packet is not a binary"),
7272
}
7373
}
74+
75+
/// Get the max size the packet could have when serialized
76+
///
77+
/// If b64 is true, it returns the max size when serialized to base64
78+
///
79+
/// The base64 max size factor is `ceil(n / 3) * 4`
80+
pub(crate) fn get_size_hint(&self, b64: bool) -> usize {
81+
match self {
82+
Packet::Open(_) => 151, // max possible size for the open packet serialized
83+
Packet::Close => 1,
84+
Packet::Ping => 1,
85+
Packet::Pong => 1,
86+
Packet::PingUpgrade => 6,
87+
Packet::PongUpgrade => 6,
88+
Packet::Message(msg) => 1 + msg.len(),
89+
Packet::Upgrade => 1,
90+
Packet::Noop => 1,
91+
Packet::Binary(data) => {
92+
if b64 {
93+
1 + ((data.len() as f64) / 3.).ceil() as usize * 4
94+
} else {
95+
1 + data.len()
96+
}
97+
}
98+
Packet::BinaryV3(data) => {
99+
if b64 {
100+
2 + ((data.len() as f64) / 3.).ceil() as usize * 4
101+
} else {
102+
1 + data.len()
103+
}
104+
}
105+
}
106+
}
74107
}
75108

76109
/// Serialize a [Packet] to a [String] according to the Engine.IO protocol
@@ -179,7 +212,7 @@ mod tests {
179212
use crate::config::EngineIoConfig;
180213

181214
use super::*;
182-
use std::convert::TryInto;
215+
use std::{convert::TryInto, time::Duration};
183216

184217
#[test]
185218
fn test_open_packet() {
@@ -249,4 +282,56 @@ mod tests {
249282
let packet: Packet = packet_str.try_into().unwrap();
250283
assert_eq!(packet, Packet::BinaryV3(vec![1, 2, 3]));
251284
}
285+
286+
#[test]
287+
fn test_packet_get_size_hint() {
288+
// Max serialized packet
289+
let open = OpenPacket::new(
290+
TransportType::Polling,
291+
Sid::MAX,
292+
&EngineIoConfig {
293+
max_buffer_size: usize::MAX,
294+
max_payload: u64::MAX,
295+
ping_interval: Duration::MAX,
296+
ping_timeout: Duration::MAX,
297+
transports: TransportType::Polling as u8 | TransportType::Websocket as u8,
298+
..Default::default()
299+
},
300+
);
301+
let size = serde_json::to_string(&open).unwrap().len();
302+
let packet = Packet::Open(open);
303+
assert_eq!(packet.get_size_hint(false), size);
304+
305+
let packet = Packet::Close;
306+
assert_eq!(packet.get_size_hint(false), 1);
307+
308+
let packet = Packet::Ping;
309+
assert_eq!(packet.get_size_hint(false), 1);
310+
311+
let packet = Packet::Pong;
312+
assert_eq!(packet.get_size_hint(false), 1);
313+
314+
let packet = Packet::PingUpgrade;
315+
assert_eq!(packet.get_size_hint(false), 6);
316+
317+
let packet = Packet::PongUpgrade;
318+
assert_eq!(packet.get_size_hint(false), 6);
319+
320+
let packet = Packet::Message("hello".to_string());
321+
assert_eq!(packet.get_size_hint(false), 6);
322+
323+
let packet = Packet::Upgrade;
324+
assert_eq!(packet.get_size_hint(false), 1);
325+
326+
let packet = Packet::Noop;
327+
assert_eq!(packet.get_size_hint(false), 1);
328+
329+
let packet = Packet::Binary(vec![1, 2, 3]);
330+
assert_eq!(packet.get_size_hint(false), 4);
331+
assert_eq!(packet.get_size_hint(true), 5);
332+
333+
let packet = Packet::BinaryV3(vec![1, 2, 3]);
334+
assert_eq!(packet.get_size_hint(false), 4);
335+
assert_eq!(packet.get_size_hint(true), 6);
336+
}
252337
}

engineioxide/src/peekable.rs

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
use tokio::sync::mpsc::{error::TryRecvError, Receiver};
2+
3+
/// Peekable receiver for polling transport
4+
/// It is a thin wrapper around a [`Receiver`](tokio::sync::mpsc::Receiver) that allows to peek the next packet without consuming it
5+
///
6+
/// Its main goal is to be able to peek the next packet without consuming it to calculate the
7+
/// packet length when using polling transport to check if it fits according to the max_payload setting
8+
#[derive(Debug)]
9+
pub struct PeekableReceiver<T> {
10+
rx: Receiver<T>,
11+
next: Option<T>,
12+
}
13+
impl<T> PeekableReceiver<T> {
14+
pub fn new(rx: Receiver<T>) -> Self {
15+
Self { rx, next: None }
16+
}
17+
pub fn peek(&mut self) -> Option<&T> {
18+
if self.next.is_none() {
19+
self.next = self.rx.try_recv().ok();
20+
}
21+
self.next.as_ref()
22+
}
23+
pub async fn recv(&mut self) -> Option<T> {
24+
if self.next.is_none() {
25+
self.rx.recv().await
26+
} else {
27+
self.next.take()
28+
}
29+
}
30+
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
31+
if self.next.is_none() {
32+
self.rx.try_recv()
33+
} else {
34+
Ok(self.next.take().unwrap())
35+
}
36+
}
37+
38+
pub fn close(&mut self) {
39+
self.rx.close()
40+
}
41+
}
42+
43+
#[cfg(test)]
44+
mod tests {
45+
use tokio::sync::Mutex;
46+
47+
#[tokio::test]
48+
async fn peek() {
49+
use super::PeekableReceiver;
50+
use crate::packet::Packet;
51+
use tokio::sync::mpsc::channel;
52+
53+
let (tx, rx) = channel(1);
54+
let rx = Mutex::new(PeekableReceiver::new(rx));
55+
let mut rx = rx.lock().await;
56+
57+
assert!(rx.peek().is_none());
58+
59+
tx.send(Packet::Ping).await.unwrap();
60+
assert_eq!(rx.peek(), Some(&Packet::Ping));
61+
assert_eq!(rx.recv().await, Some(Packet::Ping));
62+
assert!(rx.peek().is_none());
63+
64+
tx.send(Packet::Pong).await.unwrap();
65+
assert_eq!(rx.peek(), Some(&Packet::Pong));
66+
assert_eq!(rx.recv().await, Some(Packet::Pong));
67+
assert!(rx.peek().is_none());
68+
69+
tx.send(Packet::Close).await.unwrap();
70+
assert_eq!(rx.peek(), Some(&Packet::Close));
71+
assert_eq!(rx.recv().await, Some(Packet::Close));
72+
assert!(rx.peek().is_none());
73+
74+
tx.send(Packet::Close).await.unwrap();
75+
assert_eq!(rx.peek(), Some(&Packet::Close));
76+
assert_eq!(rx.recv().await, Some(Packet::Close));
77+
assert!(rx.peek().is_none());
78+
}
79+
}

engineioxide/src/socket.rs

+7-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ use tokio::{
1818
use tokio_tungstenite::tungstenite;
1919
use tracing::debug;
2020

21-
use crate::{config::EngineIoConfig, errors::Error, packet::Packet, service::ProtocolVersion};
21+
use crate::{
22+
config::EngineIoConfig, errors::Error, packet::Packet, peekable::PeekableReceiver,
23+
service::ProtocolVersion,
24+
};
2225
use crate::{sid_generator::Sid, transport::TransportType};
2326

2427
/// Http Request data used to create a socket
@@ -118,7 +121,7 @@ where
118121
/// * From the [encoder](crate::service::encoder) if the transport is polling
119122
/// * From the fn [`on_ws_req_init`](crate::engine::EngineIo) if the transport is websocket
120123
/// * Automatically via the [`close_session fn`](crate::engine::EngineIo::close_session) as a fallback. Because with polling transport, if the client is not currently polling then the encoder will never be able to close the channel
121-
pub(crate) internal_rx: Mutex<Receiver<Packet>>,
124+
pub(crate) internal_rx: Mutex<PeekableReceiver<Packet>>,
122125

123126
/// Channel to send [Packet] to the internal connection
124127
internal_tx: mpsc::Sender<Packet>,
@@ -166,7 +169,7 @@ where
166169
protocol,
167170
transport: AtomicU8::new(transport as u8),
168171

169-
internal_rx: Mutex::new(internal_rx),
172+
internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)),
170173
internal_tx,
171174

172175
heartbeat_rx: Mutex::new(heartbeat_rx),
@@ -409,7 +412,7 @@ where
409412
protocol: ProtocolVersion::V4,
410413
transport: AtomicU8::new(TransportType::Websocket as u8),
411414

412-
internal_rx: Mutex::new(internal_rx),
415+
internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)),
413416
internal_tx,
414417

415418
heartbeat_rx: Mutex::new(heartbeat_rx),

engineioxide/src/transport/polling/mod.rs

+9-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use crate::{
1515
packet::{OpenPacket, Packet},
1616
service::ProtocolVersion,
1717
sid_generator::Sid,
18+
transport::polling::payload::Payload,
1819
DisconnectReason, SocketReq,
1920
};
2021

@@ -49,7 +50,7 @@ where
4950

5051
engine.handler.on_connect(socket);
5152

52-
let packet: String = Packet::Open(packet).try_into()?;
53+
let packet: String = Packet::Open(packet).try_into().unwrap();
5354
let packet = {
5455
#[cfg(feature = "v3")]
5556
{
@@ -97,13 +98,16 @@ where
9798

9899
debug!("[sid={sid}] polling request");
99100

101+
let max_payload = engine.config.max_payload;
102+
100103
#[cfg(feature = "v3")]
101-
let (payload, is_binary) = payload::encoder(rx, protocol, socket.supports_binary).await?;
104+
let Payload { data, has_binary } =
105+
payload::encoder(rx, protocol, socket.supports_binary, max_payload).await?;
102106
#[cfg(not(feature = "v3"))]
103-
let (payload, is_binary) = payload::encoder(rx, protocol).await?;
107+
let Payload { data, has_binary } = payload::encoder(rx, protocol, max_payload).await?;
104108

105-
debug!("[sid={sid}] sending data: {:?}", payload);
106-
Ok(http_response(StatusCode::OK, payload, is_binary)?)
109+
debug!("[sid={sid}] sending data: {:?}", data);
110+
Ok(http_response(StatusCode::OK, data, has_binary)?)
107111
}
108112

109113
/// Handle http polling post request

0 commit comments

Comments
 (0)