diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 2f73fab4..8790bfca 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -88,7 +88,7 @@ async fn main() -> Result<(), Box> { .ok(); }); - s.on_disconnect(move |s, _| async move { + s.on_disconnect(|s: SocketRef| { if let Some(username) = s.extensions.get::() { let i = NUM_USERS.fetch_sub(1, std::sync::atomic::Ordering::SeqCst) - 1; let res = Res::UserEvent { diff --git a/examples/private-messaging/src/handlers.rs b/examples/private-messaging/src/handlers.rs index b72354e9..1962a4e1 100644 --- a/examples/private-messaging/src/handlers.rs +++ b/examples/private-messaging/src/handlers.rs @@ -61,7 +61,7 @@ pub fn on_connection(s: SocketRef, TryData(auth): TryData) { }, ); - s.on_disconnect(|s, _| async move { + s.on_disconnect(|s: SocketRef| { let mut session = s.extensions.get::().unwrap().clone(); session.connected = false; diff --git a/socketioxide/src/handler/disconnect.rs b/socketioxide/src/handler/disconnect.rs new file mode 100644 index 00000000..90db32e7 --- /dev/null +++ b/socketioxide/src/handler/disconnect.rs @@ -0,0 +1,219 @@ +//! [`DisconnectHandler`] trait and implementations, used to handle the disconnect event. +//! It has a flexible axum-like API, you can put any arguments as long as it implements the [`FromDisconnectParts`] trait. +//! +//! You can also implement the [`FromDisconnectParts`] trait for your own types. +//! See the [`extract`](super::extract) module doc for more details on available extractors. +//! +//! Handlers can be _optionally_ async. +//! +//! ## Example with sync closures +//! ```rust +//! # use socketioxide::SocketIo; +//! # use serde_json::Error; +//! # use socketioxide::extract::*; +//! # use socketioxide::socket::DisconnectReason; +//! let (svc, io) = SocketIo::new_svc(); +//! io.ns("/", |s: SocketRef| { +//! s.on_disconnect(|s: SocketRef, reason: DisconnectReason| { +//! println!("Socket {} was disconnected because {} ", s.id, reason); +//! }); +//! }); +//! ``` +//! +//! ## Example with async closures +//! ```rust +//! # use socketioxide::SocketIo; +//! # use serde_json::Error; +//! # use socketioxide::extract::*; +//! let (svc, io) = SocketIo::new_svc(); +//! io.ns("/", |s: SocketRef| { +//! s.on_disconnect(move |s: SocketRef| async move { +//! println!("Socket {} was disconnected", s.id); +//! }); +//! }); +//! ``` +//! +//! ## Example with async non anonymous functions +//! ```rust +//! # use socketioxide::SocketIo; +//! # use serde_json::Error; +//! # use socketioxide::extract::*; +//! # use socketioxide::socket::DisconnectReason; +//! async fn handler(s: SocketRef, reason: DisconnectReason) { +//! tokio::time::sleep(std::time::Duration::from_secs(1)).await; +//! println!("Socket disconnected on {} namespace with id and reason: {} {}", s.ns(), s.id, reason); +//! } +//! +//! let (svc, io) = SocketIo::new_svc(); +//! +//! // You can reuse the same handler for multiple sockets +//! io.ns("/", |s: SocketRef| { +//! s.on_disconnect(handler); +//! }); +//! io.ns("/admin", |s: SocketRef| { +//! s.on_disconnect(handler); +//! }); +//! ``` +use std::sync::Arc; + +use futures::Future; + +use crate::{ + adapter::Adapter, + socket::{DisconnectReason, Socket}, +}; + +use super::MakeErasedHandler; + +/// A Type Erased [`DisconnectHandler`] so it can be stored in a HashMap +pub(crate) type BoxedDisconnectHandler = Box>; +pub(crate) trait ErasedDisconnectHandler: Send + Sync + 'static { + fn call(&self, s: Arc>, reason: DisconnectReason); +} + +impl MakeErasedHandler +where + T: Send + Sync + 'static, + H: DisconnectHandler + Send + Sync + 'static, +{ + pub fn new_disconnect_boxed(inner: H) -> Box> { + Box::new(MakeErasedHandler::new(inner)) + } +} + +impl ErasedDisconnectHandler for MakeErasedHandler +where + H: DisconnectHandler + Send + Sync + 'static, + T: Send + Sync + 'static, +{ + #[inline(always)] + fn call(&self, s: Arc>, reason: DisconnectReason) { + self.handler.call(s, reason); + } +} + +/// A trait used to extract the arguments from the disconnect event. +/// The `Result` associated type is used to return an error if the extraction fails, +/// in this case the [`DisconnectHandler`] is not called. +/// +/// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. +/// * See the [`extract`](super::extract) module doc for more details on available extractors. +pub trait FromDisconnectParts: Sized { + /// The error type returned by the extractor + type Error: std::error::Error + 'static; + + /// Extract the arguments from the disconnect event. + /// If it fails, the handler is not called + fn from_disconnect_parts( + s: &Arc>, + reason: DisconnectReason, + ) -> Result; +} + +/// Define a handler for the disconnect event. +/// It is implemented for closures with up to 16 arguments. They must implement the [`FromDisconnectParts`] trait. +/// +/// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. +/// * See the [`extract`](super::extract) module doc for more details on available extractors. +pub trait DisconnectHandler: Send + Sync + 'static { + /// Call the handler with the given arguments. + fn call(&self, s: Arc>, reason: DisconnectReason); + + #[doc(hidden)] + fn phantom(&self) -> std::marker::PhantomData { + std::marker::PhantomData + } +} + +mod private { + #[derive(Debug, Copy, Clone)] + pub enum Sync {} + #[derive(Debug, Copy, Clone)] + pub enum Async {} +} + +macro_rules! impl_handler_async { + ( + [$($ty:ident),*] + ) => { + #[allow(non_snake_case, unused)] + impl DisconnectHandler for F + where + F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, + Fut: Future + Send + 'static, + A: Adapter, + $( $ty: FromDisconnectParts + Send, )* + { + fn call(&self, s: Arc>, reason: DisconnectReason) { + $( + let $ty = match $ty::from_disconnect_parts(&s, reason) { + Ok(v) => v, + Err(_e) => { + #[cfg(feature = "tracing")] + tracing::error!("Error while extracting data: {}", _e); + return; + }, + }; + )* + + let fut = (self.clone())($($ty,)*); + tokio::spawn(fut); + + } + } + }; +} + +macro_rules! impl_handler { + ( + [$($ty:ident),*] + ) => { + #[allow(non_snake_case, unused)] + impl DisconnectHandler for F + where + F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, + A: Adapter, + $( $ty: FromDisconnectParts + Send, )* + { + fn call(&self, s: Arc>, reason: DisconnectReason) { + $( + let $ty = match $ty::from_disconnect_parts(&s, reason) { + Ok(v) => v, + Err(_e) => { + #[cfg(feature = "tracing")] + tracing::error!("Error while extracting data: {}", _e); + return; + }, + }; + )* + + (self.clone())($($ty,)*); + } + } + }; +} +#[rustfmt::skip] +macro_rules! all_the_tuples { + ($name:ident) => { + $name!([]); + $name!([T1]); + $name!([T1, T2]); + $name!([T1, T2, T3]); + $name!([T1, T2, T3, T4]); + $name!([T1, T2, T3, T4, T5]); + $name!([T1, T2, T3, T4, T5, T6]); + $name!([T1, T2, T3, T4, T5, T6, T7]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16]); + }; +} + +all_the_tuples!(impl_handler_async); +all_the_tuples!(impl_handler); diff --git a/socketioxide/src/handler/extract.rs b/socketioxide/src/handler/extract.rs index d2bf04b7..f6aac605 100644 --- a/socketioxide/src/handler/extract.rs +++ b/socketioxide/src/handler/extract.rs @@ -1,4 +1,5 @@ -//! ### Extractors for [`ConnectHandler`](super::ConnectHandler) and [`MessageHandler`](super::MessageHandler). +//! ### Extractors for [`ConnectHandler`](super::ConnectHandler), [`MessageHandler`](super::MessageHandler) +//! and [`DisconnectHandler`](super::DisconnectHandler). //! //! They can be used to extract data from the context of the handler and get specific params. Here are some examples of extractors: //! * [`Data`]: extracts and deserialize to json any data, if a deserialization error occurs the handler won't be called: @@ -10,8 +11,10 @@ //! * [`SocketRef`]: extracts a reference to the [`Socket`] //! * [`Bin`]: extract a binary payload for a given message. Because it consumes the event it should be the last argument //! * [`AckSender`]: Can be used to send an ack response to the current message event +//! * [`ProtocolVersion`](crate::ProtocolVersion): extracts the protocol version +//! * [`TransportType`](crate::TransportType): extracts the transport type //! -//! ### You can also implement your own Extractor with the [`FromConnectParts`]and [`FromMessageParts`] traits +//! ### You can also implement your own Extractor with the [`FromConnectParts`], [`FromMessageParts`] and [`FromDisconnectParts`] traits //! When implementing these traits, if you clone the [`Arc`] make sure that it is dropped at least when the socket is disconnected. //! Otherwise it will create a memory leak. It is why the [`SocketRef`] extractor is used instead of cloning the socket for common usage. //! @@ -81,7 +84,9 @@ use std::convert::Infallible; use std::sync::Arc; use super::message::FromMessageParts; +use super::FromDisconnectParts; use super::{connect::FromConnectParts, message::FromMessage}; +use crate::socket::DisconnectReason; use crate::{ adapter::{Adapter, LocalAdapter}, packet::Packet, @@ -203,6 +208,12 @@ impl FromMessageParts for SocketRef { Ok(SocketRef(s.clone())) } } +impl FromDisconnectParts for SocketRef { + type Error = Infallible; + fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + Ok(SocketRef(s.clone())) + } +} impl std::ops::Deref for SocketRef { type Target = Socket; @@ -317,6 +328,12 @@ impl FromMessageParts for crate::ProtocolVersion { Ok(s.protocol()) } } +impl FromDisconnectParts for crate::ProtocolVersion { + type Error = Infallible; + fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + Ok(s.protocol()) + } +} impl FromConnectParts for crate::TransportType { type Error = Infallible; @@ -341,6 +358,23 @@ impl FromMessageParts for crate::TransportType { } } +impl FromDisconnectParts for crate::TransportType { + type Error = Infallible; + fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + Ok(s.transport_type()) + } +} + +impl FromDisconnectParts for DisconnectReason { + type Error = Infallible; + fn from_disconnect_parts( + _: &Arc>, + reason: DisconnectReason, + ) -> Result { + Ok(reason) + } +} + /// An Extractor that contains a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). /// It implements [`std::ops::Deref`] to access the inner type so you can use it as a normal reference. /// diff --git a/socketioxide/src/handler/mod.rs b/socketioxide/src/handler/mod.rs index 51ebe39f..2868fb38 100644 --- a/socketioxide/src/handler/mod.rs +++ b/socketioxide/src/handler/mod.rs @@ -1,12 +1,15 @@ //! Functions and types used to handle incoming connections and messages. -//! There is two main types of handlers: [`ConnectHandler`] and [`MessageHandler`]. -//! Both handlers can be async or not. +//! There is three main types of handlers: [`ConnectHandler`], [`MessageHandler`] and [`DisconnectHandler`]. +//! All handlers can be async or not. pub mod connect; +pub mod disconnect; pub mod extract; pub mod message; pub(crate) use connect::BoxedConnectHandler; pub use connect::{ConnectHandler, FromConnectParts}; +pub(crate) use disconnect::BoxedDisconnectHandler; +pub use disconnect::{DisconnectHandler, FromDisconnectParts}; pub(crate) use message::BoxedMessageHandler; pub use message::{FromMessage, FromMessageParts, MessageHandler}; /// A struct used to erase the type of a [`ConnectHandler`] or [`MessageHandler`] so it can be stored in a map diff --git a/socketioxide/src/layer.rs b/socketioxide/src/layer.rs index c6a046ef..6cc05c99 100644 --- a/socketioxide/src/layer.rs +++ b/socketioxide/src/layer.rs @@ -1,4 +1,4 @@ -//! ## A tower [`Layer`](tower::Layer) for socket.io so it can be used as a middleware with frameworks supporting layers. +//! ## A tower [`Layer`] for socket.io so it can be used as a middleware with frameworks supporting layers. //! //! #### Example with axum : //! ```rust diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index 56df5c39..cdf2239e 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -139,17 +139,25 @@ //! ``` //! //! ## Handlers -//! Handlers are functions or clonable closures that are given to the `io.ns` and the `socket.on` methods. They can be async or sync and can take from 0 to 16 arguments that implements the [`FromConnectParts`](handler::FromConnectParts) trait for the [`ConnectHandler`](handler::ConnectHandler) and the [`FromMessageParts`](handler::FromMessageParts) for the [`MessageHandler`](handler::MessageHandler). They are greatly inspired by the axum handlers. +//! Handlers are functions or clonable closures that are given to the `io.ns`, the `socket.on` and the `socket.on_disconnect` fns. +//! They can be async or sync and can take from 0 to 16 arguments that implements the [`FromConnectParts`](handler::FromConnectParts) +//! trait for the [`ConnectHandler`](handler::ConnectHandler), the [`FromMessageParts`](handler::FromMessageParts) for +//! the [`MessageHandler`](handler::MessageHandler) and the [`FromDisconnectParts`](handler::FromDisconnectParts) for +//! the [`DisconnectHandler`](handler::DisconnectHandler). +//! They are greatly inspired by the axum handlers. //! //! If they are async, a new task will be spawned for each incoming connection/message so it doesn't block the event management task. //! //! * Check the [`handler::connect`] module doc for more details on the connect handler //! * Check the [`handler::message`] module doc for more details on the message handler. +//! * Check the [`handler::disconnect`] module doc for more details on the disconnect handler. //! * Check the [`handler::extract`] module doc for more details on the extractors. //! //! ## Extractors //! Handlers params are called extractors and are used to extract data from the incoming connection/message. They are inspired by the axum extractors. -//! An extractor is a struct that implements the [`FromConnectParts`](handler::FromConnectParts) trait for the [`ConnectHandler`](handler::ConnectHandler) and the [`FromMessageParts`](handler::FromMessageParts) for the [`MessageHandler`](handler::MessageHandler). +//! An extractor is a struct that implements the [`FromConnectParts`](handler::FromConnectParts) trait for the [`ConnectHandler`](handler::ConnectHandler) +//! the [`FromMessageParts`](handler::FromMessageParts) for the [`MessageHandler`](handler::MessageHandler) and the +//! [`FromDisconnectParts`](handler::FromDisconnectParts) for the [`DisconnectHandler`](handler::DisconnectHandler). //! //! Here are some examples of extractors: //! * [`Data`](extract::Data): extracts and deserialize to json any data, if a deserialize error occurs the handler won't be called @@ -161,6 +169,8 @@ //! * [`SocketRef`](extract::Data): extracts a reference to the [`Socket`](socket::Socket) //! * [`Bin`](extract::Data): extract a binary payload for a given message. Because it consumes the event it should be the last argument //! * [`AckSender`](extract::Data): Can be used to send an ack response to the current message event +//! * [`ProtocolVersion`]: extracts the protocol version of the socket +//! * [`TransportType`]: extracts the transport type of the socket //! //! ### Extractor order //! Extractors are run in the order of their declaration in the handler signature. If an extractor returns an error, the handler won't be called and a `tracing::error!` call will be emitted if the `tracing` feature is enabled. @@ -173,7 +183,7 @@ //! There are three types of events: //! * The connect event is emitted when a new connection is established. It can be handled with the [`ConnectHandler`](handler::ConnectHandler) and the `io.ns` method. //! * The message event is emitted when a new message is received. It can be handled with the [`MessageHandler`](handler::MessageHandler) and the `socket.on` method. -//! * The disconnect event is emitted when a socket is closed. Contrary to the two previous events, the callback is not flexible, it *must* be async and have the following signature `async fn(SocketRef, DisconnectReason)`. It can be handled with the `socket.on_disconnect` method. +//! * The disconnect event is emitted when a socket is closed. It can be handled with the [`DisconnectHandler`](handler::DisconnectHandler) and the `socket.on_disconnect` method. //! //! Only one handler can exist for an event so registering a new handler for an event will replace the previous one. //! diff --git a/socketioxide/src/service.rs b/socketioxide/src/service.rs index d69181e4..9e376b5c 100644 --- a/socketioxide/src/service.rs +++ b/socketioxide/src/service.rs @@ -1,4 +1,4 @@ -//! ## A tower [`Service`](tower::Service) for socket.io so it can be used with frameworks supporting tower services. +//! ## A tower [`Service`] for socket.io so it can be used with frameworks supporting tower services. //! //! #### Example with a `Warp` inner service : //! ```rust diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index 26dc3be7..06887292 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -1,5 +1,5 @@ //! A [`Socket`] represents a client connected to a namespace. -//! The socket struct itself should not be used directly, but through a [`SocketRef`]. +//! The socket struct itself should not be used directly, but through a [`SocketRef`](crate::extract::SocketRef). use std::{ borrow::Cow, collections::HashMap, @@ -13,7 +13,6 @@ use std::{ }; use engineioxide::{sid::Sid, socket::DisconnectReason as EIoDisconnectReason}; -use futures::{future::BoxFuture, Future}; use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value; use tokio::sync::oneshot; @@ -24,7 +23,10 @@ use crate::extensions::Extensions; use crate::{ adapter::{Adapter, LocalAdapter, Room}, errors::{AckError, Error}, - handler::{BoxedMessageHandler, MakeErasedHandler, MessageHandler}, + handler::{ + BoxedDisconnectHandler, BoxedMessageHandler, DisconnectHandler, MakeErasedHandler, + MessageHandler, + }, ns::Namespace, operators::{Operators, RoomParam}, packet::{BinaryPacket, Packet, PacketData}, @@ -33,15 +35,12 @@ use crate::{ use crate::{ client::SocketData, errors::{AdapterError, SendError}, - extract::SocketRef, }; -type DisconnectCallback = Box< - dyn FnOnce(SocketRef, DisconnectReason) -> BoxFuture<'static, ()> + Send + Sync + 'static, ->; - /// All the possible reasons for a [`Socket`] to be disconnected from a namespace. -#[derive(Debug, Clone, Eq, PartialEq)] +/// +/// It can be used as an extractor in the [`on_disconnect`](crate::handler::disconnect) handler. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum DisconnectReason { /// The client gracefully closed the connection TransportClose, @@ -112,12 +111,12 @@ pub struct AckResponse { /// A Socket represents a client connected to a namespace. /// It is used to send and receive messages from the client, join and leave rooms, etc. -/// The socket struct itself should not be used directly, but through a [`SocketRef`]. +/// The socket struct itself should not be used directly, but through a [`SocketRef`](crate::extract::SocketRef). pub struct Socket { config: Arc, ns: Arc>, message_handlers: RwLock, BoxedMessageHandler>>, - disconnect_handler: Mutex>>, + disconnect_handler: Mutex>>, ack_message: Mutex>>>, ack_counter: AtomicI64, /// The socket id @@ -219,15 +218,16 @@ impl Socket { } /// ## Registers a disconnect handler. + /// You can register only one disconnect handler per socket. If you register multiple handlers, only the last one will be used. /// - /// Contrary to [`ConnectHandler`](crate::handler::ConnectHandler) and [`MessageHandler`]. - /// Arguments are not dynamic and the handler should always be async. + /// * See the [`disconnect`](crate::handler::disconnect) module doc for more details on disconnect handler. + /// * See the [`extract`](crate::extract) module doc for more details on available extractors. /// /// The callback will be called when the socket is disconnected from the server or the client or when the underlying connection crashes. /// A [`DisconnectReason`] is passed to the callback to indicate the reason for the disconnection. /// ### Example /// ``` - /// # use socketioxide::{SocketIo, extract::*}; + /// # use socketioxide::{SocketIo, socket::DisconnectReason, extract::*}; /// # use serde_json::Value; /// # use std::sync::Arc; /// let (_, io) = SocketIo::new_svc(); @@ -236,17 +236,17 @@ impl Socket { /// // Close the current socket /// socket.disconnect().ok(); /// }); - /// socket.on_disconnect(|socket, reason| async move { + /// socket.on_disconnect(|socket: SocketRef, reason: DisconnectReason| async move { /// println!("Socket {} on ns {} disconnected, reason: {:?}", socket.id, socket.ns(), reason); /// }); /// }); - pub fn on_disconnect(&self, callback: C) + pub fn on_disconnect(&self, callback: C) where - C: Fn(SocketRef, DisconnectReason) -> F + Send + Sync + 'static, - F: Future + Send + 'static, + C: DisconnectHandler + Send + Sync + 'static, + T: Send + Sync + 'static, { - let handler = Box::new(move |s, r| Box::pin(callback(s, r)) as _); - *self.disconnect_handler.lock().unwrap() = Some(handler); + let handler = MakeErasedHandler::new_disconnect_boxed(callback); + self.disconnect_handler.lock().unwrap().replace(handler); } /// Emits a message to the client @@ -585,7 +585,7 @@ impl Socket { /// It maybe also close when the underlying transport is closed or failed. pub(crate) fn close(self: Arc, reason: DisconnectReason) -> Result<(), AdapterError> { if let Some(handler) = self.disconnect_handler.lock().unwrap().take() { - tokio::spawn(handler(SocketRef::new(self.clone()), reason)); + handler.call(self.clone(), reason); } self.ns.remove_socket(self.id)?; diff --git a/socketioxide/tests/disconnect_reason.rs b/socketioxide/tests/disconnect_reason.rs index d4d1cf86..bc1cc354 100644 --- a/socketioxide/tests/disconnect_reason.rs +++ b/socketioxide/tests/disconnect_reason.rs @@ -26,10 +26,9 @@ fn attach_handler(io: &SocketIo, chan_size: usize) -> mpsc::Receiver