Skip to content

Commit

Permalink
feat(services/memcached): Add TLS support for AWS ElastiCache
Browse files Browse the repository at this point in the history
Implement TLS support for Memcached connections, particularly for AWS
ElastiCache serverless instances which require TLS. This change includes:

- Add TLS configuration options to MemcachedConfig
- Add TLS support in connection handling using tokio-native-tls
- Support custom CA certificates for AWS ElastiCache verification
- Update binary protocol handling for TLS streams
  • Loading branch information
AryanVBW committed Jan 2, 2025
1 parent 0146a12 commit c11cd14
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 93 deletions.
182 changes: 102 additions & 80 deletions core/src/services/memcached/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ use std::time::Duration;
use bb8::RunError;
use tokio::net::TcpStream;
use tokio::sync::OnceCell;
use tokio_native_tls::native_tls;
use tokio_native_tls::TlsConnector;
use tokio_native_tls::TlsStream;

use super::binary;
use crate::raw::adapters::kv;
Expand Down Expand Up @@ -82,6 +85,22 @@ impl MemcachedBuilder {
self.config.default_ttl = Some(ttl);
self
}

/// Enable TLS for the connection.
///
/// Required for AWS ElastiCache Memcached serverless instances.
pub fn enable_tls(mut self, enable: bool) -> Self {
self.config.enable_tls = Some(enable);
self
}

/// Set the CA certificate file path for TLS verification.
pub fn ca_cert(mut self, ca_cert: &str) -> Self {
if !ca_cert.is_empty() {
self.config.ca_cert = Some(ca_cert.to_string());
}
self
}
}

impl Builder for MemcachedBuilder {
Expand All @@ -100,60 +119,32 @@ impl Builder for MemcachedBuilder {
.set_source(err)
})?;

match uri.scheme_str() {
// If scheme is none, we will use tcp by default.
None => (),
Some(scheme) => {
// We only support tcp by now.
if scheme != "tcp" {
return Err(Error::new(
ErrorKind::ConfigInvalid,
"endpoint is using invalid scheme",
)
.with_context("service", Scheme::Memcached)
.with_context("endpoint", &endpoint)
.with_context("scheme", scheme.to_string()));
}
}
};
let authority = uri.authority().ok_or_else(|| {
Error::new(ErrorKind::ConfigInvalid, "endpoint must contain authority")
.with_context("service", Scheme::Memcached)
.with_context("endpoint", &endpoint)
})?;

let host = if let Some(host) = uri.host() {
host.to_string()
} else {
return Err(
Error::new(ErrorKind::ConfigInvalid, "endpoint doesn't have host")
.with_context("service", Scheme::Memcached)
.with_context("endpoint", &endpoint),
);
};
let port = if let Some(port) = uri.port_u16() {
port
} else {
return Err(
Error::new(ErrorKind::ConfigInvalid, "endpoint doesn't have port")
.with_context("service", Scheme::Memcached)
.with_context("endpoint", &endpoint),
);
};
let endpoint = format!("{host}:{port}",);

let root = normalize_root(
self.config
.root
.clone()
.unwrap_or_else(|| "/".to_string())
.as_str(),
let root = normalize_root(&self.config.root.unwrap_or_default())?;

let manager = MemcacheConnectionManager::new(
authority.as_str(),
self.config.username,
self.config.password,
self.config.enable_tls.unwrap_or(false),
self.config.ca_cert,
);

let conn = OnceCell::new();
let pool = bb8::Pool::builder()
.max_size(1)
.build(manager)
.map_err(new_connection_error)?;

Ok(MemcachedBackend::new(Adapter {
endpoint,
username: self.config.username.clone(),
password: self.config.password.clone(),
conn,
pool: pool.into(),
root,
default_ttl: self.config.default_ttl,
})
.with_normalized_root(root))
}))
}
}

Expand All @@ -162,32 +153,14 @@ pub type MemcachedBackend = kv::Backend<Adapter>;

#[derive(Clone, Debug)]
pub struct Adapter {
endpoint: String,
username: Option<String>,
password: Option<String>,
pool: bb8::Pool<MemcacheConnectionManager>,
root: String,
default_ttl: Option<Duration>,
conn: OnceCell<bb8::Pool<MemcacheConnectionManager>>,
}

impl Adapter {
async fn conn(&self) -> Result<bb8::PooledConnection<'_, MemcacheConnectionManager>> {
let pool = self
.conn
.get_or_try_init(|| async {
let mgr = MemcacheConnectionManager::new(
&self.endpoint,
self.username.clone(),
self.password.clone(),
);

bb8::Pool::builder().build(mgr).await.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "connect to memecached failed")
.set_source(err)
})
})
.await?;

pool.get().await.map_err(|err| match err {
self.pool.get().await.map_err(|err| match err {
RunError::TimedOut => {
Error::new(ErrorKind::Unexpected, "get connection from pool failed").set_temporary()
}
Expand Down Expand Up @@ -242,18 +215,28 @@ impl kv::Adapter for Adapter {

/// A `bb8::ManageConnection` for `memcache_async::ascii::Protocol`.
#[derive(Clone, Debug)]
struct MemcacheConnectionManager {
pub struct MemcacheConnectionManager {
address: String,
username: Option<String>,
password: Option<String>,
enable_tls: bool,
ca_cert: Option<String>,
}

impl MemcacheConnectionManager {
fn new(address: &str, username: Option<String>, password: Option<String>) -> Self {
pub fn new(
address: &str,
username: Option<String>,
password: Option<String>,
enable_tls: bool,
ca_cert: Option<String>,
) -> Self {
Self {
address: address.to_string(),
username,
password,
enable_tls,
ca_cert,
}
}
}
Expand All @@ -263,17 +246,56 @@ impl bb8::ManageConnection for MemcacheConnectionManager {
type Connection = binary::Connection;
type Error = Error;

/// TODO: Implement unix stream support.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let conn = TcpStream::connect(&self.address)
.await
.map_err(new_std_io_error)?;
let mut conn = binary::Connection::new(conn);
let stream = if self.enable_tls {
let mut builder = native_tls::TlsConnector::builder();

// If CA cert is provided, add it to the builder
if let Some(ca_cert) = &self.ca_cert {
builder.add_root_certificate(
native_tls::Certificate::from_pem(
&std::fs::read(ca_cert).map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "failed to read CA certificate")
.set_source(err)
})?
).map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "invalid CA certificate")
.set_source(err)
})?
);
}

let connector = builder.build().map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "failed to build TLS connector")
.set_source(err)
})?;
let connector = TlsConnector::from(connector);

let tcp = TcpStream::connect(&self.address).await.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "failed to connect")
.set_source(err)
})?;

let domain = self.address.split(':').next().unwrap_or(&self.address);
let tls_stream = connector.connect(domain, tcp).await.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "TLS handshake failed")
.set_source(err)
})?;

if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) {
conn.auth(username, password).await?;
binary::Connection::new(Box::new(tls_stream)).await
} else {
let tcp = TcpStream::connect(&self.address).await.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "failed to connect")
.set_source(err)
})?;
binary::Connection::new(Box::new(tcp)).await
};

if let (Some(username), Some(password)) = (&self.username, &self.password) {
stream.authenticate(username, password).await?;
}
Ok(conn)

Ok(stream)
}

async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
Expand Down
25 changes: 12 additions & 13 deletions core/src/services/memcached/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::io::{self};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;

use crate::raw::*;
use crate::*;
Expand Down Expand Up @@ -61,7 +62,7 @@ pub struct PacketHeader {
}

impl PacketHeader {
pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> {
pub async fn write(self, writer: &mut dyn AsyncWrite) -> io::Result<()> {
writer.write_u8(self.magic).await?;
writer.write_u8(self.opcode).await?;
writer.write_u16(self.key_length).await?;
Expand All @@ -74,7 +75,7 @@ impl PacketHeader {
Ok(())
}

pub async fn read(reader: &mut TcpStream) -> Result<PacketHeader, io::Error> {
pub async fn read(reader: &mut dyn AsyncRead) -> Result<PacketHeader, io::Error> {
let header = PacketHeader {
magic: reader.read_u8().await?,
opcode: reader.read_u8().await?,
Expand All @@ -98,18 +99,16 @@ pub struct Response {
}

pub struct Connection {
io: BufReader<TcpStream>,
stream: Box<dyn AsyncRead + AsyncWrite + Send + Unpin>,
}

impl Connection {
pub fn new(io: TcpStream) -> Self {
Self {
io: BufReader::new(io),
}
pub async fn new(stream: Box<dyn AsyncRead + AsyncWrite + Send + Unpin>) -> Result<Self, Error> {
Ok(Self { stream })
}

pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let key = "PLAIN";
let request_header = PacketHeader {
magic: Magic::Request as u8,
Expand All @@ -136,7 +135,7 @@ impl Connection {
}

pub async fn version(&mut self) -> Result<String> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Version as u8,
Expand All @@ -158,7 +157,7 @@ impl Connection {
}

pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Get as u8,
Expand Down Expand Up @@ -187,7 +186,7 @@ impl Connection {
}

pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Set as u8,
Expand Down Expand Up @@ -224,7 +223,7 @@ impl Connection {
}

pub async fn delete(&mut self, key: &str) -> Result<()> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Delete as u8,
Expand All @@ -246,7 +245,7 @@ impl Connection {
}
}

pub async fn parse_response(reader: &mut TcpStream) -> Result<Response> {
pub async fn parse_response(reader: &mut dyn AsyncRead) -> Result<Response> {
let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?;

if header.vbucket_id_or_status != constants::OK_STATUS
Expand Down
6 changes: 6 additions & 0 deletions core/src/services/memcached/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,10 @@ pub struct MemcachedConfig {
pub password: Option<String>,
/// The default ttl for put operations.
pub default_ttl: Option<Duration>,
/// Enable TLS for the connection.
///
/// Required for AWS ElastiCache Memcached serverless instances.
pub enable_tls: Option<bool>,
/// Path to CA certificate file for TLS verification.
pub ca_cert: Option<String>,
}

0 comments on commit c11cd14

Please sign in to comment.