Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(services/memcached): Add TLS support for AWS ElastiCache #5499

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 102 additions & 80 deletions core/src/services/memcached/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ use std::time::Duration;
use bb8::RunError;
use tokio::net::TcpStream;
use tokio::sync::OnceCell;
use tokio_native_tls::native_tls;
use tokio_native_tls::TlsConnector;
use tokio_native_tls::TlsStream;

use super::binary;
use crate::raw::adapters::kv;
Expand Down Expand Up @@ -82,6 +85,22 @@ impl MemcachedBuilder {
self.config.default_ttl = Some(ttl);
self
}

/// Enable TLS for the connection.
///
/// Required for AWS ElastiCache Memcached serverless instances.
pub fn enable_tls(mut self, enable: bool) -> Self {
self.config.enable_tls = Some(enable);
self
}

/// Set the CA certificate file path for TLS verification.
pub fn ca_cert(mut self, ca_cert: &str) -> Self {
if !ca_cert.is_empty() {
self.config.ca_cert = Some(ca_cert.to_string());
}
self
}
}

impl Builder for MemcachedBuilder {
Expand All @@ -100,60 +119,32 @@ impl Builder for MemcachedBuilder {
.set_source(err)
})?;

match uri.scheme_str() {
// If scheme is none, we will use tcp by default.
None => (),
Some(scheme) => {
// We only support tcp by now.
if scheme != "tcp" {
return Err(Error::new(
ErrorKind::ConfigInvalid,
"endpoint is using invalid scheme",
)
.with_context("service", Scheme::Memcached)
.with_context("endpoint", &endpoint)
.with_context("scheme", scheme.to_string()));
}
}
};
let authority = uri.authority().ok_or_else(|| {
Error::new(ErrorKind::ConfigInvalid, "endpoint must contain authority")
.with_context("service", Scheme::Memcached)
.with_context("endpoint", &endpoint)
})?;

let host = if let Some(host) = uri.host() {
host.to_string()
} else {
return Err(
Error::new(ErrorKind::ConfigInvalid, "endpoint doesn't have host")
.with_context("service", Scheme::Memcached)
.with_context("endpoint", &endpoint),
);
};
let port = if let Some(port) = uri.port_u16() {
port
} else {
return Err(
Error::new(ErrorKind::ConfigInvalid, "endpoint doesn't have port")
.with_context("service", Scheme::Memcached)
.with_context("endpoint", &endpoint),
);
};
let endpoint = format!("{host}:{port}",);

let root = normalize_root(
self.config
.root
.clone()
.unwrap_or_else(|| "/".to_string())
.as_str(),
let root = normalize_root(&self.config.root.unwrap_or_default())?;

let manager = MemcacheConnectionManager::new(
authority.as_str(),
self.config.username,
self.config.password,
self.config.enable_tls.unwrap_or(false),
self.config.ca_cert,
);

let conn = OnceCell::new();
let pool = bb8::Pool::builder()
.max_size(1)
.build(manager)
.map_err(new_connection_error)?;

Ok(MemcachedBackend::new(Adapter {
endpoint,
username: self.config.username.clone(),
password: self.config.password.clone(),
conn,
pool: pool.into(),
root,
default_ttl: self.config.default_ttl,
})
.with_normalized_root(root))
}))
}
}

Expand All @@ -162,32 +153,14 @@ pub type MemcachedBackend = kv::Backend<Adapter>;

#[derive(Clone, Debug)]
pub struct Adapter {
endpoint: String,
username: Option<String>,
password: Option<String>,
pool: bb8::Pool<MemcacheConnectionManager>,
root: String,
default_ttl: Option<Duration>,
conn: OnceCell<bb8::Pool<MemcacheConnectionManager>>,
}

impl Adapter {
async fn conn(&self) -> Result<bb8::PooledConnection<'_, MemcacheConnectionManager>> {
let pool = self
.conn
.get_or_try_init(|| async {
let mgr = MemcacheConnectionManager::new(
&self.endpoint,
self.username.clone(),
self.password.clone(),
);

bb8::Pool::builder().build(mgr).await.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "connect to memecached failed")
.set_source(err)
})
})
.await?;

pool.get().await.map_err(|err| match err {
self.pool.get().await.map_err(|err| match err {
RunError::TimedOut => {
Error::new(ErrorKind::Unexpected, "get connection from pool failed").set_temporary()
}
Expand Down Expand Up @@ -242,18 +215,28 @@ impl kv::Adapter for Adapter {

/// A `bb8::ManageConnection` for `memcache_async::ascii::Protocol`.
#[derive(Clone, Debug)]
struct MemcacheConnectionManager {
pub struct MemcacheConnectionManager {
address: String,
username: Option<String>,
password: Option<String>,
enable_tls: bool,
ca_cert: Option<String>,
}

impl MemcacheConnectionManager {
fn new(address: &str, username: Option<String>, password: Option<String>) -> Self {
pub fn new(
address: &str,
username: Option<String>,
password: Option<String>,
enable_tls: bool,
ca_cert: Option<String>,
) -> Self {
Self {
address: address.to_string(),
username,
password,
enable_tls,
ca_cert,
}
}
}
Expand All @@ -263,17 +246,56 @@ impl bb8::ManageConnection for MemcacheConnectionManager {
type Connection = binary::Connection;
type Error = Error;

/// TODO: Implement unix stream support.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let conn = TcpStream::connect(&self.address)
.await
.map_err(new_std_io_error)?;
let mut conn = binary::Connection::new(conn);
let stream = if self.enable_tls {
let mut builder = native_tls::TlsConnector::builder();

// If CA cert is provided, add it to the builder
if let Some(ca_cert) = &self.ca_cert {
builder.add_root_certificate(
native_tls::Certificate::from_pem(
&std::fs::read(ca_cert).map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "failed to read CA certificate")
.set_source(err)
})?
).map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "invalid CA certificate")
.set_source(err)
})?
);
}

let connector = builder.build().map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "failed to build TLS connector")
.set_source(err)
})?;
let connector = TlsConnector::from(connector);

let tcp = TcpStream::connect(&self.address).await.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "failed to connect")
.set_source(err)
})?;

let domain = self.address.split(':').next().unwrap_or(&self.address);
let tls_stream = connector.connect(domain, tcp).await.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "TLS handshake failed")
.set_source(err)
})?;

if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) {
conn.auth(username, password).await?;
binary::Connection::new(Box::new(tls_stream)).await
} else {
let tcp = TcpStream::connect(&self.address).await.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "failed to connect")
.set_source(err)
})?;
binary::Connection::new(Box::new(tcp)).await
};

if let (Some(username), Some(password)) = (&self.username, &self.password) {
stream.authenticate(username, password).await?;
}
Ok(conn)

Ok(stream)
}

async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
Expand Down
25 changes: 12 additions & 13 deletions core/src/services/memcached/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::io::{self};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;

use crate::raw::*;
use crate::*;
Expand Down Expand Up @@ -61,7 +62,7 @@ pub struct PacketHeader {
}

impl PacketHeader {
pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> {
pub async fn write(self, writer: &mut dyn AsyncWrite) -> io::Result<()> {
writer.write_u8(self.magic).await?;
writer.write_u8(self.opcode).await?;
writer.write_u16(self.key_length).await?;
Expand All @@ -74,7 +75,7 @@ impl PacketHeader {
Ok(())
}

pub async fn read(reader: &mut TcpStream) -> Result<PacketHeader, io::Error> {
pub async fn read(reader: &mut dyn AsyncRead) -> Result<PacketHeader, io::Error> {
let header = PacketHeader {
magic: reader.read_u8().await?,
opcode: reader.read_u8().await?,
Expand All @@ -98,18 +99,16 @@ pub struct Response {
}

pub struct Connection {
io: BufReader<TcpStream>,
stream: Box<dyn AsyncRead + AsyncWrite + Send + Unpin>,
}

impl Connection {
pub fn new(io: TcpStream) -> Self {
Self {
io: BufReader::new(io),
}
pub async fn new(stream: Box<dyn AsyncRead + AsyncWrite + Send + Unpin>) -> Result<Self, Error> {
Ok(Self { stream })
}

pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let key = "PLAIN";
let request_header = PacketHeader {
magic: Magic::Request as u8,
Expand All @@ -136,7 +135,7 @@ impl Connection {
}

pub async fn version(&mut self) -> Result<String> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Version as u8,
Expand All @@ -158,7 +157,7 @@ impl Connection {
}

pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Get as u8,
Expand Down Expand Up @@ -187,7 +186,7 @@ impl Connection {
}

pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Set as u8,
Expand Down Expand Up @@ -224,7 +223,7 @@ impl Connection {
}

pub async fn delete(&mut self, key: &str) -> Result<()> {
let writer = self.io.get_mut();
let writer = &mut *self.stream;
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Delete as u8,
Expand All @@ -246,7 +245,7 @@ impl Connection {
}
}

pub async fn parse_response(reader: &mut TcpStream) -> Result<Response> {
pub async fn parse_response(reader: &mut dyn AsyncRead) -> Result<Response> {
let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?;

if header.vbucket_id_or_status != constants::OK_STATUS
Expand Down
6 changes: 6 additions & 0 deletions core/src/services/memcached/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,10 @@ pub struct MemcachedConfig {
pub password: Option<String>,
/// The default ttl for put operations.
pub default_ttl: Option<Duration>,
/// Enable TLS for the connection.
///
/// Required for AWS ElastiCache Memcached serverless instances.
pub enable_tls: Option<bool>,
/// Path to CA certificate file for TLS verification.
pub ca_cert: Option<String>,
}
Loading