diff --git a/crates/engineioxide/Cargo.toml b/crates/engineioxide/Cargo.toml index b8ac8ae8..93425318 100644 --- a/crates/engineioxide/Cargo.toml +++ b/crates/engineioxide/Cargo.toml @@ -54,7 +54,8 @@ tracing-subscriber.workspace = true hyper = { workspace = true, features = ["server", "http1"] } criterion.workspace = true axum.workspace = true -hyper-util = { workspace = true, features = ["tokio", "client-legacy"] } +tokio-stream = "0.1" +tokio-util = { version = "0.7", features = ["io"], default-features = false } [features] v3 = ["memchr", "unicode-segmentation", "itoa"] diff --git a/crates/engineioxide/src/service/mod.rs b/crates/engineioxide/src/service/mod.rs index af9b6a27..f34745fe 100644 --- a/crates/engineioxide/src/service/mod.rs +++ b/crates/engineioxide/src/service/mod.rs @@ -165,7 +165,7 @@ impl EngineIoService where H: EngineIoHandler, { - /// Create a new engine.io over websocket through a raw stream. + /// Create a new engine.io conn over websocket through a raw stream. /// Mostly used for testing. pub fn ws_init( &self, diff --git a/crates/engineioxide/tests/fixture.rs b/crates/engineioxide/tests/fixture.rs index 08bd6fd1..669a5662 100644 --- a/crates/engineioxide/tests/fixture.rs +++ b/crates/engineioxide/tests/fixture.rs @@ -8,7 +8,7 @@ use std::{ time::Duration, }; -use bytes::{BufMut, Bytes}; +use bytes::Bytes; use engineioxide::{ config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService, sid::Sid, ProtocolVersion, @@ -20,10 +20,12 @@ use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, sync::mpsc, }; +use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_tungstenite::{ tungstenite::{handshake::client::generate_key, protocol::Role}, WebSocketStream, }; +use tokio_util::io::StreamReader; use tower_service::Service; /// An OpenPacket is used to initiate a connection @@ -89,20 +91,32 @@ pub async fn create_ws_connection( new_ws_mock_conn(svc, ProtocolVersion::V4, None).await } -pub struct StreamImpl(mpsc::UnboundedSender, mpsc::UnboundedReceiver); +pin_project_lite::pin_project! { + pub struct StreamImpl { + tx: mpsc::UnboundedSender>, + #[pin] + rx: StreamReader>, Bytes>, + } +} +impl StreamImpl { + pub fn new( + tx: mpsc::UnboundedSender>, + rx: mpsc::UnboundedReceiver>, + ) -> Self { + Self { + tx, + rx: StreamReader::new(UnboundedReceiverStream::new(rx)), + } + } +} impl AsyncRead for StreamImpl { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.1.poll_recv(cx).map(|e| { - if let Some(e) = e { - buf.put(e); - } - Ok(()) - }) + self.project().rx.poll_read(cx, buf) } } impl AsyncWrite for StreamImpl { @@ -112,7 +126,10 @@ impl AsyncWrite for StreamImpl { buf: &[u8], ) -> Poll> { let len = buf.len(); - self.0.send(Bytes::copy_from_slice(buf)).unwrap(); + self.project() + .tx + .send(Ok(Bytes::copy_from_slice(buf))) + .unwrap(); Poll::Ready(Ok(len)) } @@ -120,11 +137,7 @@ impl AsyncWrite for StreamImpl { Poll::Ready(Ok(())) } - fn poll_shutdown( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - self.1.close(); + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } @@ -148,11 +161,10 @@ async fn new_ws_mock_conn( .unwrap() .into_parts() .0; - - tokio::spawn(svc.ws_init(StreamImpl(tx, rx1), protocol, sid, parts)); + tokio::spawn(svc.ws_init(StreamImpl::new(tx, rx1), protocol, sid, parts)); tokio_tungstenite::WebSocketStream::from_raw_socket( - StreamImpl(tx1, rx), + StreamImpl::new(tx1, rx), Role::Client, Default::default(), )