Skip to content

Commit

Permalink
Merge branch 'main' into feat-state-management
Browse files Browse the repository at this point in the history
  • Loading branch information
Totodore authored Dec 5, 2023
2 parents e15ef1d + b7674cb commit ea0f210
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 37 deletions.
2 changes: 1 addition & 1 deletion examples/chat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.ok();
});

s.on_disconnect(move |s, _| async move {
s.on_disconnect(|s: SocketRef| {
if let Some(username) = s.extensions.get::<Username>() {
let i = NUM_USERS.fetch_sub(1, std::sync::atomic::Ordering::SeqCst) - 1;
let res = Res::UserEvent {
Expand Down
2 changes: 1 addition & 1 deletion examples/private-messaging/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub fn on_connection(s: SocketRef, TryData(auth): TryData<Auth>) {
},
);

s.on_disconnect(|s, _| async move {
s.on_disconnect(|s: SocketRef| {
let mut session = s.extensions.get::<Session>().unwrap().clone();
session.connected = false;

Expand Down
219 changes: 219 additions & 0 deletions socketioxide/src/handler/disconnect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
//! [`DisconnectHandler`] trait and implementations, used to handle the disconnect event.
//! It has a flexible axum-like API, you can put any arguments as long as it implements the [`FromDisconnectParts`] trait.
//!
//! You can also implement the [`FromDisconnectParts`] trait for your own types.
//! See the [`extract`](super::extract) module doc for more details on available extractors.
//!
//! Handlers can be _optionally_ async.
//!
//! ## Example with sync closures
//! ```rust
//! # use socketioxide::SocketIo;
//! # use serde_json::Error;
//! # use socketioxide::extract::*;
//! # use socketioxide::socket::DisconnectReason;
//! let (svc, io) = SocketIo::new_svc();
//! io.ns("/", |s: SocketRef| {
//! s.on_disconnect(|s: SocketRef, reason: DisconnectReason| {
//! println!("Socket {} was disconnected because {} ", s.id, reason);
//! });
//! });
//! ```
//!
//! ## Example with async closures
//! ```rust
//! # use socketioxide::SocketIo;
//! # use serde_json::Error;
//! # use socketioxide::extract::*;
//! let (svc, io) = SocketIo::new_svc();
//! io.ns("/", |s: SocketRef| {
//! s.on_disconnect(move |s: SocketRef| async move {
//! println!("Socket {} was disconnected", s.id);
//! });
//! });
//! ```
//!
//! ## Example with async non anonymous functions
//! ```rust
//! # use socketioxide::SocketIo;
//! # use serde_json::Error;
//! # use socketioxide::extract::*;
//! # use socketioxide::socket::DisconnectReason;
//! async fn handler(s: SocketRef, reason: DisconnectReason) {
//! tokio::time::sleep(std::time::Duration::from_secs(1)).await;
//! println!("Socket disconnected on {} namespace with id and reason: {} {}", s.ns(), s.id, reason);
//! }
//!
//! let (svc, io) = SocketIo::new_svc();
//!
//! // You can reuse the same handler for multiple sockets
//! io.ns("/", |s: SocketRef| {
//! s.on_disconnect(handler);
//! });
//! io.ns("/admin", |s: SocketRef| {
//! s.on_disconnect(handler);
//! });
//! ```
use std::sync::Arc;

use futures::Future;

use crate::{
adapter::Adapter,
socket::{DisconnectReason, Socket},
};

use super::MakeErasedHandler;

/// A Type Erased [`DisconnectHandler`] so it can be stored in a HashMap
pub(crate) type BoxedDisconnectHandler<A> = Box<dyn ErasedDisconnectHandler<A>>;
pub(crate) trait ErasedDisconnectHandler<A: Adapter>: Send + Sync + 'static {
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason);
}

impl<A: Adapter, T, H> MakeErasedHandler<H, A, T>
where
T: Send + Sync + 'static,
H: DisconnectHandler<A, T> + Send + Sync + 'static,
{
pub fn new_disconnect_boxed(inner: H) -> Box<dyn ErasedDisconnectHandler<A>> {
Box::new(MakeErasedHandler::new(inner))
}
}

impl<A: Adapter, T, H> ErasedDisconnectHandler<A> for MakeErasedHandler<H, A, T>
where
H: DisconnectHandler<A, T> + Send + Sync + 'static,
T: Send + Sync + 'static,
{
#[inline(always)]
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason) {
self.handler.call(s, reason);
}
}

/// A trait used to extract the arguments from the disconnect event.
/// The `Result` associated type is used to return an error if the extraction fails,
/// in this case the [`DisconnectHandler`] is not called.
///
/// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler.
/// * See the [`extract`](super::extract) module doc for more details on available extractors.
pub trait FromDisconnectParts<A: Adapter>: Sized {
/// The error type returned by the extractor
type Error: std::error::Error + 'static;

/// Extract the arguments from the disconnect event.
/// If it fails, the handler is not called
fn from_disconnect_parts(
s: &Arc<Socket<A>>,
reason: DisconnectReason,
) -> Result<Self, Self::Error>;
}

/// Define a handler for the disconnect event.
/// It is implemented for closures with up to 16 arguments. They must implement the [`FromDisconnectParts`] trait.
///
/// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler.
/// * See the [`extract`](super::extract) module doc for more details on available extractors.
pub trait DisconnectHandler<A: Adapter, T>: Send + Sync + 'static {
/// Call the handler with the given arguments.
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason);

#[doc(hidden)]
fn phantom(&self) -> std::marker::PhantomData<T> {
std::marker::PhantomData
}
}

mod private {
#[derive(Debug, Copy, Clone)]
pub enum Sync {}
#[derive(Debug, Copy, Clone)]
pub enum Async {}
}

macro_rules! impl_handler_async {
(
[$($ty:ident),*]
) => {
#[allow(non_snake_case, unused)]
impl<A, F, Fut, $($ty,)*> DisconnectHandler<A, (private::Async, $($ty,)*)> for F
where
F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = ()> + Send + 'static,
A: Adapter,
$( $ty: FromDisconnectParts<A> + Send, )*
{
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason) {
$(
let $ty = match $ty::from_disconnect_parts(&s, reason) {
Ok(v) => v,
Err(_e) => {
#[cfg(feature = "tracing")]
tracing::error!("Error while extracting data: {}", _e);
return;
},
};
)*

let fut = (self.clone())($($ty,)*);
tokio::spawn(fut);

}
}
};
}

macro_rules! impl_handler {
(
[$($ty:ident),*]
) => {
#[allow(non_snake_case, unused)]
impl<A, F, $($ty,)*> DisconnectHandler<A, (private::Sync, $($ty,)*)> for F
where
F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static,
A: Adapter,
$( $ty: FromDisconnectParts<A> + Send, )*
{
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason) {
$(
let $ty = match $ty::from_disconnect_parts(&s, reason) {
Ok(v) => v,
Err(_e) => {
#[cfg(feature = "tracing")]
tracing::error!("Error while extracting data: {}", _e);
return;
},
};
)*

(self.clone())($($ty,)*);
}
}
};
}
#[rustfmt::skip]
macro_rules! all_the_tuples {
($name:ident) => {
$name!([]);
$name!([T1]);
$name!([T1, T2]);
$name!([T1, T2, T3]);
$name!([T1, T2, T3, T4]);
$name!([T1, T2, T3, T4, T5]);
$name!([T1, T2, T3, T4, T5, T6]);
$name!([T1, T2, T3, T4, T5, T6, T7]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16]);
};
}

all_the_tuples!(impl_handler_async);
all_the_tuples!(impl_handler);
38 changes: 36 additions & 2 deletions socketioxide/src/handler/extract.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! ### Extractors for [`ConnectHandler`](super::ConnectHandler) and [`MessageHandler`](super::MessageHandler).
//! ### Extractors for [`ConnectHandler`](super::ConnectHandler), [`MessageHandler`](super::MessageHandler)
//! and [`DisconnectHandler`](super::DisconnectHandler).
//!
//! They can be used to extract data from the context of the handler and get specific params. Here are some examples of extractors:
//! * [`Data`]: extracts and deserialize to json any data, if a deserialization error occurs the handler won't be called:
Expand All @@ -10,8 +11,10 @@
//! * [`SocketRef`]: extracts a reference to the [`Socket`]
//! * [`Bin`]: extract a binary payload for a given message. Because it consumes the event it should be the last argument
//! * [`AckSender`]: Can be used to send an ack response to the current message event
//! * [`ProtocolVersion`](crate::ProtocolVersion): extracts the protocol version
//! * [`TransportType`](crate::TransportType): extracts the transport type
//!
//! ### You can also implement your own Extractor with the [`FromConnectParts`]and [`FromMessageParts`] traits
//! ### You can also implement your own Extractor with the [`FromConnectParts`], [`FromMessageParts`] and [`FromDisconnectParts`] traits
//! When implementing these traits, if you clone the [`Arc<Socket>`] make sure that it is dropped at least when the socket is disconnected.
//! Otherwise it will create a memory leak. It is why the [`SocketRef`] extractor is used instead of cloning the socket for common usage.
//!
Expand Down Expand Up @@ -81,7 +84,9 @@ use std::convert::Infallible;
use std::sync::Arc;

use super::message::FromMessageParts;
use super::FromDisconnectParts;
use super::{connect::FromConnectParts, message::FromMessage};
use crate::socket::DisconnectReason;
use crate::{
adapter::{Adapter, LocalAdapter},
packet::Packet,
Expand Down Expand Up @@ -203,6 +208,12 @@ impl<A: Adapter> FromMessageParts<A> for SocketRef<A> {
Ok(SocketRef(s.clone()))
}
}
impl<A: Adapter> FromDisconnectParts<A> for SocketRef<A> {
type Error = Infallible;
fn from_disconnect_parts(s: &Arc<Socket<A>>, _: DisconnectReason) -> Result<Self, Infallible> {
Ok(SocketRef(s.clone()))
}
}

impl<A: Adapter> std::ops::Deref for SocketRef<A> {
type Target = Socket<A>;
Expand Down Expand Up @@ -317,6 +328,12 @@ impl<A: Adapter> FromMessageParts<A> for crate::ProtocolVersion {
Ok(s.protocol())
}
}
impl<A: Adapter> FromDisconnectParts<A> for crate::ProtocolVersion {
type Error = Infallible;
fn from_disconnect_parts(s: &Arc<Socket<A>>, _: DisconnectReason) -> Result<Self, Infallible> {
Ok(s.protocol())
}
}

impl<A: Adapter> FromConnectParts<A> for crate::TransportType {
type Error = Infallible;
Expand All @@ -341,6 +358,23 @@ impl<A: Adapter> FromMessageParts<A> for crate::TransportType {
}
}

impl<A: Adapter> FromDisconnectParts<A> for crate::TransportType {
type Error = Infallible;
fn from_disconnect_parts(s: &Arc<Socket<A>>, _: DisconnectReason) -> Result<Self, Infallible> {
Ok(s.transport_type())
}
}

impl<A: Adapter> FromDisconnectParts<A> for DisconnectReason {
type Error = Infallible;
fn from_disconnect_parts(
_: &Arc<Socket<A>>,
reason: DisconnectReason,
) -> Result<Self, Infallible> {
Ok(reason)
}
}

/// An Extractor that contains a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder).
/// It implements [`std::ops::Deref`] to access the inner type so you can use it as a normal reference.
///
Expand Down
7 changes: 5 additions & 2 deletions socketioxide/src/handler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
//! Functions and types used to handle incoming connections and messages.
//! There is two main types of handlers: [`ConnectHandler`] and [`MessageHandler`].
//! Both handlers can be async or not.
//! There is three main types of handlers: [`ConnectHandler`], [`MessageHandler`] and [`DisconnectHandler`].
//! All handlers can be async or not.
pub mod connect;
pub mod disconnect;
pub mod extract;
pub mod message;

pub(crate) use connect::BoxedConnectHandler;
pub use connect::{ConnectHandler, FromConnectParts};
pub(crate) use disconnect::BoxedDisconnectHandler;
pub use disconnect::{DisconnectHandler, FromDisconnectParts};
pub(crate) use message::BoxedMessageHandler;
pub use message::{FromMessage, FromMessageParts, MessageHandler};
/// A struct used to erase the type of a [`ConnectHandler`] or [`MessageHandler`] so it can be stored in a map
Expand Down
2 changes: 1 addition & 1 deletion socketioxide/src/layer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! ## A tower [`Layer`](tower::Layer) for socket.io so it can be used as a middleware with frameworks supporting layers.
//! ## A tower [`Layer`] for socket.io so it can be used as a middleware with frameworks supporting layers.
//!
//! #### Example with axum :
//! ```rust
Expand Down
16 changes: 13 additions & 3 deletions socketioxide/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,25 @@
//! ```
//!
//! ## Handlers
//! Handlers are functions or clonable closures that are given to the `io.ns` and the `socket.on` methods. They can be async or sync and can take from 0 to 16 arguments that implements the [`FromConnectParts`](handler::FromConnectParts) trait for the [`ConnectHandler`](handler::ConnectHandler) and the [`FromMessageParts`](handler::FromMessageParts) for the [`MessageHandler`](handler::MessageHandler). They are greatly inspired by the axum handlers.
//! Handlers are functions or clonable closures that are given to the `io.ns`, the `socket.on` and the `socket.on_disconnect` fns.
//! They can be async or sync and can take from 0 to 16 arguments that implements the [`FromConnectParts`](handler::FromConnectParts)
//! trait for the [`ConnectHandler`](handler::ConnectHandler), the [`FromMessageParts`](handler::FromMessageParts) for
//! the [`MessageHandler`](handler::MessageHandler) and the [`FromDisconnectParts`](handler::FromDisconnectParts) for
//! the [`DisconnectHandler`](handler::DisconnectHandler).
//! They are greatly inspired by the axum handlers.
//!
//! If they are async, a new task will be spawned for each incoming connection/message so it doesn't block the event management task.
//!
//! * Check the [`handler::connect`] module doc for more details on the connect handler
//! * Check the [`handler::message`] module doc for more details on the message handler.
//! * Check the [`handler::disconnect`] module doc for more details on the disconnect handler.
//! * Check the [`handler::extract`] module doc for more details on the extractors.
//!
//! ## Extractors
//! Handlers params are called extractors and are used to extract data from the incoming connection/message. They are inspired by the axum extractors.
//! An extractor is a struct that implements the [`FromConnectParts`](handler::FromConnectParts) trait for the [`ConnectHandler`](handler::ConnectHandler) and the [`FromMessageParts`](handler::FromMessageParts) for the [`MessageHandler`](handler::MessageHandler).
//! An extractor is a struct that implements the [`FromConnectParts`](handler::FromConnectParts) trait for the [`ConnectHandler`](handler::ConnectHandler)
//! the [`FromMessageParts`](handler::FromMessageParts) for the [`MessageHandler`](handler::MessageHandler) and the
//! [`FromDisconnectParts`](handler::FromDisconnectParts) for the [`DisconnectHandler`](handler::DisconnectHandler).
//!
//! Here are some examples of extractors:
//! * [`Data`](extract::Data): extracts and deserialize to json any data, if a deserialize error occurs the handler won't be called
Expand All @@ -161,6 +169,8 @@
//! * [`SocketRef`](extract::Data): extracts a reference to the [`Socket`](socket::Socket)
//! * [`Bin`](extract::Data): extract a binary payload for a given message. Because it consumes the event it should be the last argument
//! * [`AckSender`](extract::Data): Can be used to send an ack response to the current message event
//! * [`ProtocolVersion`]: extracts the protocol version of the socket
//! * [`TransportType`]: extracts the transport type of the socket
//!
//! ### Extractor order
//! Extractors are run in the order of their declaration in the handler signature. If an extractor returns an error, the handler won't be called and a `tracing::error!` call will be emitted if the `tracing` feature is enabled.
Expand All @@ -173,7 +183,7 @@
//! There are three types of events:
//! * The connect event is emitted when a new connection is established. It can be handled with the [`ConnectHandler`](handler::ConnectHandler) and the `io.ns` method.
//! * The message event is emitted when a new message is received. It can be handled with the [`MessageHandler`](handler::MessageHandler) and the `socket.on` method.
//! * The disconnect event is emitted when a socket is closed. Contrary to the two previous events, the callback is not flexible, it *must* be async and have the following signature `async fn(SocketRef, DisconnectReason)`. It can be handled with the `socket.on_disconnect` method.
//! * The disconnect event is emitted when a socket is closed. It can be handled with the [`DisconnectHandler`](handler::DisconnectHandler) and the `socket.on_disconnect` method.
//!
//! Only one handler can exist for an event so registering a new handler for an event will replace the previous one.
//!
Expand Down
2 changes: 1 addition & 1 deletion socketioxide/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! ## A tower [`Service`](tower::Service) for socket.io so it can be used with frameworks supporting tower services.
//! ## A tower [`Service`] for socket.io so it can be used with frameworks supporting tower services.
//!
//! #### Example with a `Warp` inner service :
//! ```rust
Expand Down
Loading

0 comments on commit ea0f210

Please sign in to comment.