diff --git a/core/Cargo.lock b/core/Cargo.lock index 28417ce9e28b..d2917fe4721d 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -3814,9 +3814,9 @@ dependencies = [ "rustls 0.23.15", "rustls-pki-types", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tower-service", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", ] [[package]] @@ -5052,6 +5052,7 @@ dependencies = [ "reqwest", "rocksdb", "rust-nebula", + "rustls 0.23.15", "serde", "serde_json", "sha1", @@ -5064,10 +5065,12 @@ dependencies = [ "surrealdb", "tikv-client", "tokio", + "tokio-rustls 0.26.1", "tracing", "tracing-opentelemetry", "tracing-subscriber", "uuid", + "webpki-roots 0.26.7", ] [[package]] @@ -6425,7 +6428,7 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-retry2", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tokio-util", "url", ] @@ -6613,7 +6616,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tokio-util", "tower-service", "url", @@ -6621,7 +6624,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", "windows-registry", ] @@ -7615,7 +7618,7 @@ dependencies = [ "tokio-stream", "tracing", "url", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", ] [[package]] @@ -8413,12 +8416,11 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ "rustls 0.23.15", - "rustls-pki-types", "tokio", ] @@ -8444,9 +8446,9 @@ dependencies = [ "rustls 0.23.15", "rustls-pki-types", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tungstenite", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", ] [[package]] @@ -8534,7 +8536,7 @@ dependencies = [ "rustls-pemfile 2.2.0", "socket2", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tokio-stream", "tower 0.4.13", "tower-layer", @@ -8889,7 +8891,7 @@ dependencies = [ "rustls 0.23.15", "rustls-pki-types", "url", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", ] [[package]] @@ -9190,9 +9192,9 @@ checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "webpki-roots" -version = "0.26.6" +version = "0.26.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" dependencies = [ "rustls-pki-types", ] diff --git a/core/Cargo.toml b/core/Cargo.toml index d404845dd3e2..03442f89eaad 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -159,7 +159,12 @@ services-ipmfs = [] services-koofr = [] services-lakefs = [] services-libsql = ["dep:hrana-client-proto"] -services-memcached = ["dep:bb8"] +services-memcached = [ + "dep:bb8", + "dep:tokio-rustls", + "dep:webpki-roots", + "dep:rustls", +] services-memory = [] services-mini-moka = ["dep:mini-moka"] services-moka = ["dep:moka"] @@ -359,6 +364,12 @@ monoio = { version = "0.2.4", optional = true, features = [ "unlinkat", "renameat", ] } +# for services-memcached +rustls = { version = "0.23.15", default-features = false, features = [ + "std", +], optional = true } +tokio-rustls = { version = "0.26.1", optional = true } +webpki-roots = { version = "0.26.7", optional = true } # Layers # for layers-async-backtrace diff --git a/core/src/services/memcached/backend.rs b/core/src/services/memcached/backend.rs index 5ddfa4b9c114..d014ee4db109 100644 --- a/core/src/services/memcached/backend.rs +++ b/core/src/services/memcached/backend.rs @@ -15,18 +15,23 @@ // specific language governing permissions and limitations // under the License. +use std::path::PathBuf; +use std::sync::Arc; use std::time::Duration; -use bb8::RunError; -use tokio::net::TcpStream; -use tokio::sync::OnceCell; - use super::binary; use crate::raw::adapters::kv; use crate::raw::*; use crate::services::MemcachedConfig; use crate::*; +use bb8::RunError; +use rustls::pki_types::pem::PemObject; +use rustls::pki_types::{CertificateDer, ServerName}; +use tokio::net::TcpStream; +use tokio::sync::OnceCell; +use tokio_rustls::TlsConnector; + impl Configurator for MemcachedConfig { type Builder = MemcachedBuilder; fn into_builder(self) -> Self::Builder { @@ -82,6 +87,18 @@ impl MemcachedBuilder { self.config.default_ttl = Some(ttl); self } + + /// Set the tls connect on. + pub fn tls(mut self, tls: bool) -> Self { + self.config.tls = Some(tls); + self + } + + /// Set the tls connect on. + pub fn cafile(mut self, cafile: PathBuf) -> Self { + self.config.cafile = Some(cafile); + self + } } impl Builder for MemcachedBuilder { @@ -126,6 +143,14 @@ impl Builder for MemcachedBuilder { .with_context("endpoint", &endpoint), ); }; + if self.config.tls.unwrap_or(false) { + ServerName::try_from(host.clone()).map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "Invalid dns name error") + .with_context("service", Scheme::Memcached) + .with_context("host", &host) + .set_source(err) + })?; + } let port = if let Some(port) = uri.port_u16() { port } else { @@ -150,6 +175,9 @@ impl Builder for MemcachedBuilder { endpoint, username: self.config.username.clone(), password: self.config.password.clone(), + tls: self.config.tls.clone(), + cafile: self.config.cafile.clone(), + host, conn, default_ttl: self.config.default_ttl, }) @@ -166,6 +194,9 @@ pub struct Adapter { username: Option, password: Option, default_ttl: Option, + tls: Option, + cafile: Option, + host: String, conn: OnceCell>, } @@ -178,6 +209,9 @@ impl Adapter { &self.endpoint, self.username.clone(), self.password.clone(), + self.tls.clone(), + self.cafile.clone(), + &self.host, ); bb8::Pool::builder().build(mgr).await.map_err(|err| { @@ -246,14 +280,27 @@ struct MemcacheConnectionManager { address: String, username: Option, password: Option, + tls: Option, + cafile: Option, + host: String, } impl MemcacheConnectionManager { - fn new(address: &str, username: Option, password: Option) -> Self { + fn new( + address: &str, + username: Option, + password: Option, + tls: Option, + cafile: Option, + host: &str, + ) -> Self { Self { address: address.to_string(), username, password, + tls, + cafile, + host: host.to_string(), } } } @@ -265,14 +312,71 @@ impl bb8::ManageConnection for MemcacheConnectionManager { /// TODO: Implement unix stream support. async fn connect(&self) -> Result { - let conn = TcpStream::connect(&self.address) - .await - .map_err(new_std_io_error)?; - let mut conn = binary::Connection::new(conn); + let conn = if self.tls.unwrap_or(false) { + let mut root_cert_store = rustls::RootCertStore::empty(); + + if let Some(cafile) = &self.cafile { + for cert in CertificateDer::pem_file_iter(cafile).map_err(|err| match err { + rustls::pki_types::pem::Error::Io(err) => new_std_io_error(err), + _ => unreachable!(), + })? { + root_cert_store + .add(cert.map_err(|err| match err { + rustls::pki_types::pem::Error::Io(err) => new_std_io_error(err), + _ => unreachable!(), + })?) + .map_err(|err| { + Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err) + })?; + } + } else { + root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + } - if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) { - conn.auth(username, password).await?; - } + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let connector = TlsConnector::from(Arc::new(config)); + let conn = TcpStream::connect(&self.address) + .await + .map_err(new_std_io_error)?; + let domain = ServerName::try_from(self.host.as_str()) + .map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "Invalid dns name error") + .with_context("service", Scheme::Memcached) + .with_context("address", &self.address) + .set_source(err) + })? + .to_owned(); + + let conn = connector.connect(domain, conn).await.map_err(|err| { + Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err) + })?; + + let mut conn = binary::TlsConnection::new(conn); + + if let (Some(username), Some(password)) = + (self.username.as_ref(), self.password.as_ref()) + { + conn.auth(username, password).await?; + } + binary::Connection::Tls(conn) + } else { + let conn = TcpStream::connect(&self.address) + .await + .map_err(new_std_io_error)?; + + let mut conn = binary::TcpConnection::new(conn); + + if let (Some(username), Some(password)) = + (self.username.as_ref(), self.password.as_ref()) + { + conn.auth(username, password).await?; + } + + binary::Connection::Tcp(conn) + }; Ok(conn) } diff --git a/core/src/services/memcached/binary.rs b/core/src/services/memcached/binary.rs index f24db3a4dbe2..3c9c8620d316 100644 --- a/core/src/services/memcached/binary.rs +++ b/core/src/services/memcached/binary.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. +use crate::raw::*; +use crate::*; + use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; use tokio::io::{self}; use tokio::net::TcpStream; - -use crate::raw::*; -use crate::*; +use tokio_rustls::client::TlsStream; pub(super) mod constants { pub const OK_STATUS: u16 = 0x0; @@ -61,7 +62,10 @@ pub struct PacketHeader { } impl PacketHeader { - pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> { + pub async fn write( + self, + writer: &mut T, + ) -> io::Result<()> { writer.write_u8(self.magic).await?; writer.write_u8(self.opcode).await?; writer.write_u16(self.key_length).await?; @@ -74,7 +78,9 @@ impl PacketHeader { Ok(()) } - pub async fn read(reader: &mut TcpStream) -> Result { + pub async fn read( + reader: &mut T, + ) -> Result { let header = PacketHeader { magic: reader.read_u8().await?, opcode: reader.read_u8().await?, @@ -90,6 +96,39 @@ impl PacketHeader { } } +pub enum Connection { + Tls(TlsConnection), + Tcp(TcpConnection), +} + +impl Connection { + pub async fn version(&mut self) -> Result { + match self { + Self::Tls(conn) => conn.version().await, + Self::Tcp(conn) => conn.version().await, + } + } + pub async fn get(&mut self, key: &str) -> Result>> { + match self { + Self::Tls(conn) => conn.get(key).await, + Self::Tcp(conn) => conn.get(key).await, + } + } + + pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { + match self { + Self::Tls(conn) => conn.set(key, val, expiration).await, + Self::Tcp(conn) => conn.set(key, val, expiration).await, + } + } + pub async fn delete(&mut self, key: &str) -> Result<()> { + match self { + Self::Tls(conn) => conn.delete(key).await, + Self::Tcp(conn) => conn.delete(key).await, + } + } +} + pub struct Response { header: PacketHeader, _key: Vec, @@ -97,11 +136,160 @@ pub struct Response { value: Vec, } -pub struct Connection { +pub struct TlsConnection { + io: BufReader>, +} + +impl TlsConnection { + pub fn new(io: TlsStream) -> Self { + Self { + io: BufReader::new(io), + } + } + + pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> { + let writer = self.io.get_mut(); + let key = "PLAIN"; + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::StartAuth as u8, + key_length: key.len() as u16, + total_body_length: (key.len() + username.len() + password.len() + 2) as u32, + ..Default::default() + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer + .write_all(key.as_bytes()) + .await + .map_err(new_std_io_error)?; + writer + .write_all(format!("\x00{}\x00{}", username, password).as_bytes()) + .await + .map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + parse_response(writer).await?; + Ok(()) + } + + pub async fn version(&mut self) -> Result { + let writer = self.io.get_mut(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Version as u8, + ..Default::default() + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + let response = parse_response(writer).await?; + let version = String::from_utf8(response.value); + match version { + Ok(version) => Ok(version), + Err(e) => { + Err(Error::new(ErrorKind::Unexpected, "unexpected data received").set_source(e)) + } + } + } + + pub async fn get(&mut self, key: &str) -> Result>> { + let writer = self.io.get_mut(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Get as u8, + key_length: key.len() as u16, + total_body_length: key.len() as u32, + ..Default::default() + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer + .write_all(key.as_bytes()) + .await + .map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + match parse_response(writer).await { + Ok(response) => { + if response.header.vbucket_id_or_status == 0x1 { + return Ok(None); + } + Ok(Some(response.value)) + } + Err(e) => Err(e), + } + } + + pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { + let writer = self.io.get_mut(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Set as u8, + key_length: key.len() as u16, + extras_length: 8, + total_body_length: (8 + key.len() + val.len()) as u32, + ..Default::default() + }; + let extras = StoreExtras { + flags: 0, + expiration, + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer + .write_u32(extras.flags) + .await + .map_err(new_std_io_error)?; + writer + .write_u32(extras.expiration) + .await + .map_err(new_std_io_error)?; + writer + .write_all(key.as_bytes()) + .await + .map_err(new_std_io_error)?; + writer.write_all(val).await.map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + + parse_response(writer).await?; + Ok(()) + } + + pub async fn delete(&mut self, key: &str) -> Result<()> { + let writer = self.io.get_mut(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Delete as u8, + key_length: key.len() as u16, + total_body_length: key.len() as u32, + ..Default::default() + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer + .write_all(key.as_bytes()) + .await + .map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + parse_response(writer).await?; + Ok(()) + } +} + +pub struct TcpConnection { io: BufReader, } -impl Connection { +impl TcpConnection { pub fn new(io: TcpStream) -> Self { Self { io: BufReader::new(io), @@ -246,8 +434,12 @@ impl Connection { } } -pub async fn parse_response(reader: &mut TcpStream) -> Result { - let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?; +pub async fn parse_response( + reader: &mut T, +) -> Result { + let header = PacketHeader::read::(reader) + .await + .map_err(new_std_io_error)?; if header.vbucket_id_or_status != constants::OK_STATUS && header.vbucket_id_or_status != constants::KEY_NOT_FOUND diff --git a/core/src/services/memcached/config.rs b/core/src/services/memcached/config.rs index f0b5815ff7e6..edcecd038d70 100644 --- a/core/src/services/memcached/config.rs +++ b/core/src/services/memcached/config.rs @@ -16,6 +16,7 @@ // under the License. use std::fmt::Debug; +use std::path::PathBuf; use std::time::Duration; use serde::Deserialize; @@ -40,4 +41,8 @@ pub struct MemcachedConfig { pub password: Option, /// The default ttl for put operations. pub default_ttl: Option, + /// default is false + pub tls: Option, + /// Path to the CA certificate for TLS verification. + pub cafile: Option, }