Skip to content

Commit

Permalink
config: allows bind configuration with existing socket
Browse files Browse the repository at this point in the history
  • Loading branch information
BiagioFesta committed Nov 8, 2024
1 parent 4d05c33 commit ba97acc
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 71 deletions.
129 changes: 95 additions & 34 deletions wtransport/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,20 @@ use crate::tls::build_native_cert_store;
use crate::tls::Identity;
use quinn::EndpointConfig;
use quinn::TransportConfig;
use socket2::Domain as SocketDomain;
use socket2::Protocol as SocketProtocol;
use socket2::Socket;
use socket2::Type as SocketType;
use std::fmt::Debug;
use std::fmt::Display;
use std::future::Future;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::net::SocketAddr;
use std::net::SocketAddrV4;
use std::net::SocketAddrV6;
use std::net::UdpSocket;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -227,10 +233,9 @@ pub struct InvalidIdleTimeout;
/// .build();
/// # Ok(())
/// # }
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct ServerConfig {
pub(crate) bind_address: SocketAddr,
pub(crate) dual_stack_config: Ipv6DualStackConfig,
pub(crate) bind_address_config: BindAddressConfig,
pub(crate) endpoint_config: quinn::EndpointConfig,
pub(crate) quic_config: quinn::ServerConfig,
}
Expand Down Expand Up @@ -330,8 +335,7 @@ impl ServerConfigBuilder<states::WantsBindAddress> {
address: SocketAddr,
) -> ServerConfigBuilder<states::WantsIdentity> {
ServerConfigBuilder(states::WantsIdentity {
bind_address: address,
dual_stack_config: Ipv6DualStackConfig::OsDefault,
bind_address_config: BindAddressConfig::from(address),
})
}

Expand All @@ -344,8 +348,17 @@ impl ServerConfigBuilder<states::WantsBindAddress> {
dual_stack_config: Ipv6DualStackConfig,
) -> ServerConfigBuilder<states::WantsIdentity> {
ServerConfigBuilder(states::WantsIdentity {
bind_address: address.into(),
dual_stack_config,
bind_address_config: BindAddressConfig::AddressV6(address, dual_stack_config),
})
}

/// Configures the server to bind to a pre-existing [`UdpSocket`].
///
/// This allows the server to use an already created socket, which may be beneficial
/// for scenarios where socket reuse or specific socket configuration is needed.
pub fn with_bind_socket(self, socket: UdpSocket) -> ServerConfigBuilder<states::WantsIdentity> {
ServerConfigBuilder(states::WantsIdentity {
bind_address_config: BindAddressConfig::Socket(socket),
})
}
}
Expand Down Expand Up @@ -508,8 +521,7 @@ impl ServerConfigBuilder<states::WantsIdentity> {
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn build_with_quic_config(self, quic_config: QuicServerConfig) -> ServerConfig {
ServerConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
bind_address_config: self.0.bind_address_config,
endpoint_config: EndpointConfig::default(),
quic_config,
}
Expand All @@ -522,8 +534,7 @@ impl ServerConfigBuilder<states::WantsIdentity> {
transport_config: TransportConfig,
) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
ServerConfigBuilder(states::WantsTransportConfigServer {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
bind_address_config: self.0.bind_address_config,
tls_config,
endpoint_config,
transport_config,
Expand Down Expand Up @@ -551,8 +562,7 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
quic_config.migration(self.0.migration);

ServerConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
bind_address_config: self.0.bind_address_config,
endpoint_config: self.0.endpoint_config,
quic_config,
}
Expand Down Expand Up @@ -700,10 +710,9 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
/// .keep_alive_interval(Some(Duration::from_secs(3)))
/// .build();
/// ```
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct ClientConfig {
pub(crate) bind_address: SocketAddr,
pub(crate) dual_stack_config: Ipv6DualStackConfig,
pub(crate) bind_address_config: BindAddressConfig,
pub(crate) endpoint_config: quinn::EndpointConfig,
pub(crate) quic_config: quinn::ClientConfig,
pub(crate) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
Expand Down Expand Up @@ -813,8 +822,7 @@ impl ClientConfigBuilder<states::WantsBindAddress> {
address: SocketAddr,
) -> ClientConfigBuilder<states::WantsRootStore> {
ClientConfigBuilder(states::WantsRootStore {
bind_address: address,
dual_stack_config: Ipv6DualStackConfig::OsDefault,
bind_address_config: BindAddressConfig::from(address),
})
}

Expand All @@ -827,8 +835,17 @@ impl ClientConfigBuilder<states::WantsBindAddress> {
dual_stack_config: Ipv6DualStackConfig,
) -> ClientConfigBuilder<states::WantsRootStore> {
ClientConfigBuilder(states::WantsRootStore {
bind_address: address.into(),
dual_stack_config,
bind_address_config: BindAddressConfig::AddressV6(address, dual_stack_config),
})
}

/// Configures the client to bind to a pre-existing [`UdpSocket`].
///
/// This allows the client to use an already created socket, which can be useful in cases
/// where socket reuse or specific socket configurations are necessary.
pub fn with_bind_socket(self, socket: UdpSocket) -> ServerConfigBuilder<states::WantsIdentity> {
ServerConfigBuilder(states::WantsIdentity {
bind_address_config: BindAddressConfig::Socket(socket),
})
}
}
Expand Down Expand Up @@ -1022,8 +1039,7 @@ impl ClientConfigBuilder<states::WantsRootStore> {
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn build_with_quic_config(self, quic_config: QuicClientConfig) -> ClientConfig {
ClientConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
bind_address_config: self.0.bind_address_config,
endpoint_config: EndpointConfig::default(),
quic_config,
dns_resolver: Arc::<TokioDnsResolver>::default(),
Expand All @@ -1037,8 +1053,7 @@ impl ClientConfigBuilder<states::WantsRootStore> {
transport_config: TransportConfig,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
ClientConfigBuilder(states::WantsTransportConfigClient {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
bind_address_config: self.0.bind_address_config,
tls_config,
endpoint_config,
transport_config,
Expand All @@ -1062,8 +1077,7 @@ impl ClientConfigBuilder<states::WantsTransportConfigClient> {
quic_config.transport_config(Arc::new(self.0.transport_config));

ClientConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
bind_address_config: self.0.bind_address_config,
endpoint_config: self.0.endpoint_config,
quic_config,
dns_resolver: self.0.dns_resolver,
Expand Down Expand Up @@ -1127,6 +1141,57 @@ impl Default for ClientConfigBuilder<states::WantsBindAddress> {
}
}

#[derive(Debug)]
pub(crate) enum BindAddressConfig {
AddressV4(SocketAddrV4),
AddressV6(SocketAddrV6, Ipv6DualStackConfig),
Socket(UdpSocket),
}

impl BindAddressConfig {
pub(crate) fn bind_socket(self) -> std::io::Result<UdpSocket> {
let (bind_address, dual_stack_config) = match self {
BindAddressConfig::AddressV4(address) => {
(SocketAddr::from(address), Ipv6DualStackConfig::OsDefault)
}
BindAddressConfig::AddressV6(address, ipv6_dual_stack_config) => {
(SocketAddr::from(address), ipv6_dual_stack_config)
}
BindAddressConfig::Socket(socket) => {
return Ok(socket);
}
};

let domain = match bind_address {
SocketAddr::V4(_) => SocketDomain::IPV4,
SocketAddr::V6(_) => SocketDomain::IPV6,
};

let socket = Socket::new(domain, SocketType::DGRAM, Some(SocketProtocol::UDP))?;

match dual_stack_config {
Ipv6DualStackConfig::OsDefault => {}
Ipv6DualStackConfig::Deny => socket.set_only_v6(true)?,
Ipv6DualStackConfig::Allow => socket.set_only_v6(false)?,
}

socket.bind(&bind_address.into())?;

Ok(UdpSocket::from(socket))
}
}

impl From<SocketAddr> for BindAddressConfig {
fn from(value: SocketAddr) -> Self {
match value {
SocketAddr::V4(address) => BindAddressConfig::AddressV4(address),
SocketAddr::V6(address) => {
BindAddressConfig::AddressV6(address, Ipv6DualStackConfig::OsDefault)
}
}
}
}

/// State-types for client/server builder.
pub mod states {
use super::*;
Expand All @@ -1136,20 +1201,17 @@ pub mod states {

/// Config builder state where the caller must supply TLS certificate.
pub struct WantsIdentity {
pub(super) bind_address: SocketAddr,
pub(super) dual_stack_config: Ipv6DualStackConfig,
pub(super) bind_address_config: BindAddressConfig,
}

/// Config builder state where the caller must supply TLS root store.
pub struct WantsRootStore {
pub(super) bind_address: SocketAddr,
pub(super) dual_stack_config: Ipv6DualStackConfig,
pub(super) bind_address_config: BindAddressConfig,
}

/// Config builder state where transport properties can be set.
pub struct WantsTransportConfigServer {
pub(super) bind_address: SocketAddr,
pub(super) dual_stack_config: Ipv6DualStackConfig,
pub(super) bind_address_config: BindAddressConfig,
pub(super) tls_config: TlsServerConfig,
pub(super) endpoint_config: quinn::EndpointConfig,
pub(super) transport_config: quinn::TransportConfig,
Expand All @@ -1158,8 +1220,7 @@ pub mod states {

/// Config builder state where transport properties can be set.
pub struct WantsTransportConfigClient {
pub(super) bind_address: SocketAddr,
pub(super) dual_stack_config: Ipv6DualStackConfig,
pub(super) bind_address_config: BindAddressConfig,
pub(super) tls_config: TlsClientConfig,
pub(super) endpoint_config: quinn::EndpointConfig,
pub(super) transport_config: quinn::TransportConfig,
Expand Down
43 changes: 6 additions & 37 deletions wtransport/src/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::config::ClientConfig;
use crate::config::DnsResolver;
use crate::config::Ipv6DualStackConfig;
use crate::config::ServerConfig;
use crate::connection::Connection;
use crate::driver::streams::session::StreamSession;
Expand All @@ -12,10 +11,6 @@ use crate::error::ConnectingError;
use crate::error::ConnectionError;
use crate::VarInt;
use quinn::TokioRuntime;
use socket2::Domain as SocketDomain;
use socket2::Protocol as SocketProtocol;
use socket2::Socket;
use socket2::Type as SocketType;
use std::collections::HashMap;
use std::future::Future;
use std::future::IntoFuture;
Expand Down Expand Up @@ -107,28 +102,6 @@ pub struct Endpoint<Side> {
}

impl<Side> Endpoint<Side> {
fn bind_socket(
bind_address: SocketAddr,
dual_stack_config: Ipv6DualStackConfig,
) -> std::io::Result<Socket> {
let domain = match bind_address {
SocketAddr::V4(_) => SocketDomain::IPV4,
SocketAddr::V6(_) => SocketDomain::IPV6,
};

let socket = Socket::new(domain, SocketType::DGRAM, Some(SocketProtocol::UDP))?;

match dual_stack_config {
Ipv6DualStackConfig::OsDefault => {}
Ipv6DualStackConfig::Deny => socket.set_only_v6(true)?,
Ipv6DualStackConfig::Allow => socket.set_only_v6(false)?,
}

socket.bind(&bind_address.into())?;

Ok(socket)
}

/// Closes all of this endpoint's connections immediately and cease accepting new connections.
pub fn close(&self, error_code: VarInt, reason: &[u8]) {
self.endpoint.close(varint_w2q(error_code), reason);
Expand All @@ -155,12 +128,10 @@ impl Endpoint<endpoint_side::Server> {
pub fn server(server_config: ServerConfig) -> std::io::Result<Self> {
let endpoint_config = server_config.endpoint_config;
let quic_config = server_config.quic_config;
let socket =
Self::bind_socket(server_config.bind_address, server_config.dual_stack_config)?;
let socket = server_config.bind_address_config.bind_socket()?;
let runtime = Arc::new(TokioRuntime);

let endpoint =
quinn::Endpoint::new(endpoint_config, Some(quic_config), socket.into(), runtime)?;
let endpoint = quinn::Endpoint::new(endpoint_config, Some(quic_config), socket, runtime)?;

Ok(Self {
endpoint,
Expand Down Expand Up @@ -195,9 +166,8 @@ impl Endpoint<endpoint_side::Server> {
/// If `false`, the bind address configuration will be ignored.
pub fn reload_config(&self, server_config: ServerConfig, rebind: bool) -> std::io::Result<()> {
if rebind {
let socket =
Self::bind_socket(server_config.bind_address, server_config.dual_stack_config)?;
self.endpoint.rebind(socket.into())?;
let socket = server_config.bind_address_config.bind_socket()?;
self.endpoint.rebind(socket)?;
}

let quic_config = server_config.quic_config;
Expand All @@ -212,11 +182,10 @@ impl Endpoint<endpoint_side::Client> {
pub fn client(client_config: ClientConfig) -> std::io::Result<Self> {
let endpoint_config = client_config.endpoint_config;
let quic_config = client_config.quic_config;
let socket =
Self::bind_socket(client_config.bind_address, client_config.dual_stack_config)?;
let socket = client_config.bind_address_config.bind_socket()?;
let runtime = Arc::new(TokioRuntime);

let mut endpoint = quinn::Endpoint::new(endpoint_config, None, socket.into(), runtime)?;
let mut endpoint = quinn::Endpoint::new(endpoint_config, None, socket, runtime)?;

endpoint.set_default_client_config(quic_config);

Expand Down

0 comments on commit ba97acc

Please sign in to comment.