diff --git a/core/src/services/memcached/backend.rs b/core/src/services/memcached/backend.rs index 5ddfa4b9c114..25e4527f9e87 100644 --- a/core/src/services/memcached/backend.rs +++ b/core/src/services/memcached/backend.rs @@ -20,6 +20,9 @@ use std::time::Duration; use bb8::RunError; use tokio::net::TcpStream; use tokio::sync::OnceCell; +use tokio_native_tls::native_tls; +use tokio_native_tls::TlsConnector; +use tokio_native_tls::TlsStream; use super::binary; use crate::raw::adapters::kv; @@ -82,6 +85,22 @@ impl MemcachedBuilder { self.config.default_ttl = Some(ttl); self } + + /// Enable TLS for the connection. + /// + /// Required for AWS ElastiCache Memcached serverless instances. + pub fn enable_tls(mut self, enable: bool) -> Self { + self.config.enable_tls = Some(enable); + self + } + + /// Set the CA certificate file path for TLS verification. + pub fn ca_cert(mut self, ca_cert: &str) -> Self { + if !ca_cert.is_empty() { + self.config.ca_cert = Some(ca_cert.to_string()); + } + self + } } impl Builder for MemcachedBuilder { @@ -100,60 +119,32 @@ impl Builder for MemcachedBuilder { .set_source(err) })?; - match uri.scheme_str() { - // If scheme is none, we will use tcp by default. - None => (), - Some(scheme) => { - // We only support tcp by now. - if scheme != "tcp" { - return Err(Error::new( - ErrorKind::ConfigInvalid, - "endpoint is using invalid scheme", - ) - .with_context("service", Scheme::Memcached) - .with_context("endpoint", &endpoint) - .with_context("scheme", scheme.to_string())); - } - } - }; + let authority = uri.authority().ok_or_else(|| { + Error::new(ErrorKind::ConfigInvalid, "endpoint must contain authority") + .with_context("service", Scheme::Memcached) + .with_context("endpoint", &endpoint) + })?; - let host = if let Some(host) = uri.host() { - host.to_string() - } else { - return Err( - Error::new(ErrorKind::ConfigInvalid, "endpoint doesn't have host") - .with_context("service", Scheme::Memcached) - .with_context("endpoint", &endpoint), - ); - }; - let port = if let Some(port) = uri.port_u16() { - port - } else { - return Err( - Error::new(ErrorKind::ConfigInvalid, "endpoint doesn't have port") - .with_context("service", Scheme::Memcached) - .with_context("endpoint", &endpoint), - ); - }; - let endpoint = format!("{host}:{port}",); - - let root = normalize_root( - self.config - .root - .clone() - .unwrap_or_else(|| "/".to_string()) - .as_str(), + let root = normalize_root(&self.config.root.unwrap_or_default())?; + + let manager = MemcacheConnectionManager::new( + authority.as_str(), + self.config.username, + self.config.password, + self.config.enable_tls.unwrap_or(false), + self.config.ca_cert, ); - let conn = OnceCell::new(); + let pool = bb8::Pool::builder() + .max_size(1) + .build(manager) + .map_err(new_connection_error)?; + Ok(MemcachedBackend::new(Adapter { - endpoint, - username: self.config.username.clone(), - password: self.config.password.clone(), - conn, + pool: pool.into(), + root, default_ttl: self.config.default_ttl, - }) - .with_normalized_root(root)) + })) } } @@ -162,32 +153,14 @@ pub type MemcachedBackend = kv::Backend; #[derive(Clone, Debug)] pub struct Adapter { - endpoint: String, - username: Option, - password: Option, + pool: bb8::Pool, + root: String, default_ttl: Option, - conn: OnceCell>, } impl Adapter { async fn conn(&self) -> Result> { - let pool = self - .conn - .get_or_try_init(|| async { - let mgr = MemcacheConnectionManager::new( - &self.endpoint, - self.username.clone(), - self.password.clone(), - ); - - bb8::Pool::builder().build(mgr).await.map_err(|err| { - Error::new(ErrorKind::ConfigInvalid, "connect to memecached failed") - .set_source(err) - }) - }) - .await?; - - pool.get().await.map_err(|err| match err { + self.pool.get().await.map_err(|err| match err { RunError::TimedOut => { Error::new(ErrorKind::Unexpected, "get connection from pool failed").set_temporary() } @@ -242,18 +215,28 @@ impl kv::Adapter for Adapter { /// A `bb8::ManageConnection` for `memcache_async::ascii::Protocol`. #[derive(Clone, Debug)] -struct MemcacheConnectionManager { +pub struct MemcacheConnectionManager { address: String, username: Option, password: Option, + enable_tls: bool, + ca_cert: Option, } impl MemcacheConnectionManager { - fn new(address: &str, username: Option, password: Option) -> Self { + pub fn new( + address: &str, + username: Option, + password: Option, + enable_tls: bool, + ca_cert: Option, + ) -> Self { Self { address: address.to_string(), username, password, + enable_tls, + ca_cert, } } } @@ -263,17 +246,56 @@ impl bb8::ManageConnection for MemcacheConnectionManager { type Connection = binary::Connection; type Error = Error; - /// TODO: Implement unix stream support. async fn connect(&self) -> Result { - let conn = TcpStream::connect(&self.address) - .await - .map_err(new_std_io_error)?; - let mut conn = binary::Connection::new(conn); + let stream = if self.enable_tls { + let mut builder = native_tls::TlsConnector::builder(); + + // If CA cert is provided, add it to the builder + if let Some(ca_cert) = &self.ca_cert { + builder.add_root_certificate( + native_tls::Certificate::from_pem( + &std::fs::read(ca_cert).map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "failed to read CA certificate") + .set_source(err) + })? + ).map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "invalid CA certificate") + .set_source(err) + })? + ); + } + + let connector = builder.build().map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "failed to build TLS connector") + .set_source(err) + })?; + let connector = TlsConnector::from(connector); + + let tcp = TcpStream::connect(&self.address).await.map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "failed to connect") + .set_source(err) + })?; + + let domain = self.address.split(':').next().unwrap_or(&self.address); + let tls_stream = connector.connect(domain, tcp).await.map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "TLS handshake failed") + .set_source(err) + })?; - if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) { - conn.auth(username, password).await?; + binary::Connection::new(Box::new(tls_stream)).await + } else { + let tcp = TcpStream::connect(&self.address).await.map_err(|err| { + Error::new(ErrorKind::ConfigInvalid, "failed to connect") + .set_source(err) + })?; + binary::Connection::new(Box::new(tcp)).await + }; + + if let (Some(username), Some(password)) = (&self.username, &self.password) { + stream.authenticate(username, password).await?; } - Ok(conn) + + Ok(stream) } async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { diff --git a/core/src/services/memcached/binary.rs b/core/src/services/memcached/binary.rs index f24db3a4dbe2..a1ffa6e1c4ba 100644 --- a/core/src/services/memcached/binary.rs +++ b/core/src/services/memcached/binary.rs @@ -20,6 +20,7 @@ use tokio::io::AsyncWriteExt; use tokio::io::BufReader; use tokio::io::{self}; use tokio::net::TcpStream; +use tokio_native_tls::TlsStream; use crate::raw::*; use crate::*; @@ -61,7 +62,7 @@ pub struct PacketHeader { } impl PacketHeader { - pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> { + pub async fn write(self, writer: &mut dyn AsyncWrite) -> io::Result<()> { writer.write_u8(self.magic).await?; writer.write_u8(self.opcode).await?; writer.write_u16(self.key_length).await?; @@ -74,7 +75,7 @@ impl PacketHeader { Ok(()) } - pub async fn read(reader: &mut TcpStream) -> Result { + pub async fn read(reader: &mut dyn AsyncRead) -> Result { let header = PacketHeader { magic: reader.read_u8().await?, opcode: reader.read_u8().await?, @@ -98,18 +99,16 @@ pub struct Response { } pub struct Connection { - io: BufReader, + stream: Box, } impl Connection { - pub fn new(io: TcpStream) -> Self { - Self { - io: BufReader::new(io), - } + pub async fn new(stream: Box) -> Result { + Ok(Self { stream }) } pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> { - let writer = self.io.get_mut(); + let writer = &mut *self.stream; let key = "PLAIN"; let request_header = PacketHeader { magic: Magic::Request as u8, @@ -136,7 +135,7 @@ impl Connection { } pub async fn version(&mut self) -> Result { - let writer = self.io.get_mut(); + let writer = &mut *self.stream; let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Version as u8, @@ -158,7 +157,7 @@ impl Connection { } pub async fn get(&mut self, key: &str) -> Result>> { - let writer = self.io.get_mut(); + let writer = &mut *self.stream; let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Get as u8, @@ -187,7 +186,7 @@ impl Connection { } pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { - let writer = self.io.get_mut(); + let writer = &mut *self.stream; let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Set as u8, @@ -224,7 +223,7 @@ impl Connection { } pub async fn delete(&mut self, key: &str) -> Result<()> { - let writer = self.io.get_mut(); + let writer = &mut *self.stream; let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Delete as u8, @@ -246,7 +245,7 @@ impl Connection { } } -pub async fn parse_response(reader: &mut TcpStream) -> Result { +pub async fn parse_response(reader: &mut dyn AsyncRead) -> Result { let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?; if header.vbucket_id_or_status != constants::OK_STATUS diff --git a/core/src/services/memcached/config.rs b/core/src/services/memcached/config.rs index f0b5815ff7e6..cde5a920cabe 100644 --- a/core/src/services/memcached/config.rs +++ b/core/src/services/memcached/config.rs @@ -40,4 +40,10 @@ pub struct MemcachedConfig { pub password: Option, /// The default ttl for put operations. pub default_ttl: Option, + /// Enable TLS for the connection. + /// + /// Required for AWS ElastiCache Memcached serverless instances. + pub enable_tls: Option, + /// Path to CA certificate file for TLS verification. + pub ca_cert: Option, }