diff --git a/Cargo.toml b/Cargo.toml index f0be666..65d78cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ storb_gateway = { version = "*", path = "crates/storb_gateway" } storb_node = { version = "*", path = "crates/storb_node" } storb_protocol = { version = "*", path = "crates/storb_protocol" } storb_storage = { version = "*", path = "crates/storb_storage" } +storb_rpc = { version = "*", path = "crates/storb_rpc" } # storb_client = { version = "*", path = "crates/storb_client" } # diff --git a/crates/storb_protocol/src/rpc.rs b/crates/storb_protocol/src/rpc.rs index e0e644a..af83cfa 100644 --- a/crates/storb_protocol/src/rpc.rs +++ b/crates/storb_protocol/src/rpc.rs @@ -27,12 +27,7 @@ enum RpcError { pub struct RPCClientOptions { /// Disable TLS pub insecure: bool, - pub server_addr: SocketAddr -} - -/// Configuration options for the RPC server. -pub struct RPCServerOptions { - pub addr: SocketAddr + pub server_addr: SocketAddr, } // TODO: remove unwraps @@ -62,10 +57,3 @@ pub async fn init_client_rpc(rpc_options: RpcClientOptions) -> Result { -// let port = rpc_options.addr.port(); -// let ip = rpc_options.addr.ip(); -// let socket = UdpSocket::bind(("0.0.0.0", quic_port))?; -// let server_config = configure_server(ip).unwrap(); -// } diff --git a/crates/storb_rpc/Cargo.toml b/crates/storb_rpc/Cargo.toml new file mode 100644 index 0000000..f9e9ed1 --- /dev/null +++ b/crates/storb_rpc/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "storb_rpc" +description = "QUIC and Cap'n'Proto RPC framework " +version.workspace = true +edition.workspace = true + +[dependencies] +capnp.workspace = true +capnp-rpc.workspace = true +futures.workspace = true +multiaddr.workspace = true +quinn.workspace = true +rcgen.workspace = true +rustls.workspace = true +serde.workspace = true +thiserror.workspace = true +tokio.workspace = true +tokio-util.workspace = true +tracing.workspace = true + +[dev-dependencies] +tracing-subscriber = "0.3" +tokio-util = { version = "0.7", features = ["compat"] } +rustls = { version = "0.23", features = ["aws-lc-rs"] } diff --git a/crates/storb_rpc/README.md b/crates/storb_rpc/README.md new file mode 100644 index 0000000..5ec08cd --- /dev/null +++ b/crates/storb_rpc/README.md @@ -0,0 +1,3 @@ +# `storb_rpc` + +RPC over Quic Library diff --git a/crates/storb_rpc/src/lib.rs b/crates/storb_rpc/src/lib.rs new file mode 100644 index 0000000..fb127d6 --- /dev/null +++ b/crates/storb_rpc/src/lib.rs @@ -0,0 +1,6 @@ +pub mod macros; +pub mod server; +pub mod service; + +pub use server::{AsyncRuntime, ServerBuilder, ServerOptions, TokioRuntime}; +pub use service::{CapnpServiceWrapper, GeneratedCapnpService, Service, ServiceBuilder}; diff --git a/crates/storb_rpc/src/macros.rs b/crates/storb_rpc/src/macros.rs new file mode 100644 index 0000000..c90512e --- /dev/null +++ b/crates/storb_rpc/src/macros.rs @@ -0,0 +1,8 @@ +pub mod include_proto { + #[macro_export] + macro_rules! include_proto { + ($schema:expr) => { + include!(concat!(env!("OUT_DIR"), "/", $schema, "_capnp.rs")) + }; + } +} diff --git a/crates/storb_rpc/src/server.rs b/crates/storb_rpc/src/server.rs new file mode 100644 index 0000000..8845878 --- /dev/null +++ b/crates/storb_rpc/src/server.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use capnp_rpc::{twoparty, RpcSystem}; +use quinn::{Endpoint, ServerConfig as QuinnServerConfig}; + +use crate::service::Service; + +pub trait AsyncRuntime: Send + Sync + 'static + Clone { + fn spawn(&self, future: F) + where + F: Future + Send + 'static; +} + +#[derive(Clone)] +pub struct TokioRuntime; + +impl AsyncRuntime for TokioRuntime { + fn spawn(&self, future: F) + where + F: Future + Send + 'static, + { + tokio::spawn(future); + } +} + +#[derive(Clone)] +pub struct ServerOptions { + pub addr: SocketAddr, + pub tls: Option, +} + +impl Default for ServerOptions { + fn default() -> Self { + Self { + addr: "127.0.0.1:5000".parse().unwrap(), + tls: None, + } + } +} + +pub struct ServerBuilder { + opts: ServerOptions, + services: Vec>, + runtime: R, +} + +impl ServerBuilder { + pub fn new() -> Self { + Self { + opts: ServerOptions::default(), + services: Vec::new(), + runtime: TokioRuntime, + } + } +} + +impl ServerBuilder +where + R: AsyncRuntime + Clone, +{ + pub fn bind_addr(mut self, addr: SocketAddr) -> Self { + self.opts.addr = addr; + self + } + + pub fn tls_config(mut self, cfg: Option) -> Self { + self.opts.tls = cfg; + self + } + + pub fn add_service(mut self, svc: S) -> Self { + self.services.push(Arc::new(svc)); + self + } + + pub fn with_runtime(self, runtime: NewR) -> ServerBuilder { + ServerBuilder { + opts: self.opts, + services: self.services, + runtime, + } + } + + pub async fn serve(self) -> anyhow::Result<()> { + let tls_cfg = if let Some(cfg) = self.opts.tls { + cfg + } else { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])?; + let key = cert.serialize_der()?; + let cert = cert.serialize_der()?; + let mut server_cfg = quinn::ServerConfig::with_single_cert(vec![cert], key)?; + Arc::make_mut(&mut server_cfg.transport).max_concurrent_uni_streams(0_u8.into()); + server_cfg + }; + + let (endpoint, mut incoming) = Endpoint::server(tls_cfg, self.opts.addr)?; + println!("Listening on {}", self.opts.addr); + + while let Some(conn) = incoming.next().await { + let services = self.services.clone(); + let runtime = self.runtime.clone(); + + runtime.spawn(async move { + if let Ok(new_conn) = conn.await { + println!("Connection from {}", new_conn.remote_address()); + + while let Ok((send, recv)) = new_conn.accept_bi().await { + let mut rpc_sys = RpcSystem::new( + Box::new(twoparty::VatNetwork::new( + recv, + send, + twoparty::Side::Server, + Default::default(), + )), + None, + ); + + for svc in &services { + svc.register_methods(&mut rpc_sys); + } + + tokio::task::spawn_local(rpc_sys.map(|_| ())); + } + } + }); + } + + Ok(()) + } +} diff --git a/crates/storb_rpc/src/service.rs b/crates/storb_rpc/src/service.rs new file mode 100644 index 0000000..e3b7bdd --- /dev/null +++ b/crates/storb_rpc/src/service.rs @@ -0,0 +1,149 @@ +use std::collections::HashMap; +use std::fmt; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; + +use capnp_rpc::twoparty::{Side, VatNetwork}; +use capnp_rpc::RpcSystem; +use quinn::{RecvStream, SendStream}; + +pub struct MetadataMap(pub HashMap); + +impl MetadataMap { + pub fn keys(&self) -> impl Iterator { + self.0.keys() + } +} + +pub struct RequestContext { + pub peer_addr: SocketAddr, + pub metadata: MetadataMap, +} + +impl fmt::Debug for RequestContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RequestContext") + .field("peer_addr", &self.peer_addr) + .field("metadata_keys", &self.metadata.keys().collect::>()) + .finish() + } +} + +pub trait Middleware: Send + Sync + 'static { + fn handle( + &self, + ctx: RequestContext, + next: Next, + ) -> Pin + Send + '_>>; +} + +impl fmt::Debug for dyn Middleware { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Middleware") + } +} + +pub struct Next<'a> { + idx: usize, + middlewares: &'a [Arc], + service_impl: &'a dyn GeneratedCapnpService, + rpc_system: &'a mut RpcSystem>, +} + +impl<'a> Next<'a> { + pub async fn run(mut self, ctx: RequestContext) { + if self.idx < self.middlewares.len() { + let mw = &self.middlewares[self.idx]; + self.idx += 1; + mw.handle(ctx, self).await; + } else { + self.service_impl.register_methods_impl(self.rpc_system); + } + } +} + +pub trait Service: Send + Sync + 'static { + /// Register all RPC endpoints on the given system. + fn register_methods(&self, rpc_system: &mut RpcSystem>); + + /// A stable name for logging + fn name(&self) -> &'static str; + + /// Health check for graceful shutdown and monitoring. + fn is_healthy(&self) -> bool { + true + } +} + +pub struct ServiceBuilder { + interceptors: Vec>, +} + +pub trait Interceptor: Send + Sync + 'static { + fn intercept(&self, ctx: &mut RequestContext); +} + +impl ServiceBuilder { + pub fn new() -> Self { + ServiceBuilder { + interceptors: Vec::new(), + } + } + + pub fn add_interceptor(mut self, interceptor: I) -> Self + where + I: Interceptor, + { + self.interceptors.push(Arc::new(interceptor)); + self + } + + pub fn wrap(self, impl_obj: T) -> CapnpServiceWrapper + where + T: GeneratedCapnpService, + { + CapnpServiceWrapper { + inner: Arc::new(impl_obj), + interceptors: self.interceptors, + } + } +} + +pub trait GeneratedCapnpService: Send + Sync + 'static { + fn register_methods_impl(&self, rpc_system: &mut RpcSystem>); + + fn service_name() -> &'static str + where + Self: Sized; + + fn is_healthy(&self) -> bool { + true + } +} + +pub struct CapnpServiceWrapper +where + T: GeneratedCapnpService, +{ + inner: Arc, + interceptors: Vec>, +} + +impl Service for CapnpServiceWrapper +where + T: GeneratedCapnpService, +{ + fn register_methods(&self, rpc_system: &mut RpcSystem>) { + self.inner.register_methods_impl(rpc_system); + } + + fn name(&self) -> &'static str { + T::service_name() + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } +}