Skip to content

Commit

Permalink
feat: make tls certificates/keys reloadable (part 1) (GreptimeTeam#3335)
Browse files Browse the repository at this point in the history
* feat: make tls certificates/keys reloadable (part 1)

* feat: add notify watcher for cert/key files

* test: add unit test for watcher

* fix: correct usage of watcher

* fix: skip watch when tls disabled
  • Loading branch information
sunng87 authored Feb 26, 2024
1 parent e859f0e commit 3887d20
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 43 deletions.
70 changes: 70 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 16 additions & 7 deletions src/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::sync::Arc;
use auth::UserProviderRef;
use common_base::Plugins;
use common_runtime::Builder as RuntimeBuilder;
use servers::error::InternalIoSnafu;
use servers::grpc::builder::GrpcServerBuilder;
use servers::grpc::greptime_handler::GreptimeRequestHandler;
use servers::grpc::{GrpcServer, GrpcServerConfig};
Expand All @@ -30,6 +29,7 @@ use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
use servers::query_handler::sql::ServerSqlQueryHandlerAdapter;
use servers::server::{Server, ServerHandlers};
use servers::tls::{watch_tls_config, ReloadableTlsServerConfig};
use snafu::ResultExt;

use crate::error::{self, Result, StartServerSnafu};
Expand Down Expand Up @@ -195,6 +195,12 @@ where
let opts = &opts.mysql;
let mysql_addr = parse_addr(&opts.addr)?;

let tls_server_config = Arc::new(
ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?,
);

watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;

let mysql_io_runtime = Arc::new(
RuntimeBuilder::default()
.worker_threads(opts.runtime_size)
Expand All @@ -210,11 +216,7 @@ where
)),
Arc::new(MysqlSpawnConfig::new(
opts.tls.should_force_tls(),
opts.tls
.setup()
.context(InternalIoSnafu)
.context(StartServerSnafu)?
.map(Arc::new),
tls_server_config,
opts.reject_no_database.unwrap_or(false),
)),
);
Expand All @@ -226,6 +228,12 @@ where
let opts = &opts.postgres;
let pg_addr = parse_addr(&opts.addr)?;

let tls_server_config = Arc::new(
ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?,
);

watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;

let pg_io_runtime = Arc::new(
RuntimeBuilder::default()
.worker_threads(opts.runtime_size)
Expand All @@ -236,7 +244,8 @@ where

let pg_server = Box::new(PostgresServer::new(
ServerSqlQueryHandlerAdapter::arc(instance.clone()),
opts.tls.clone(),
opts.tls.should_force_tls(),
tls_server_config,
pg_io_runtime,
user_provider.clone(),
)) as Box<dyn Server>;
Expand Down
2 changes: 2 additions & 0 deletions src/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", bran
itertools.workspace = true
lazy_static.workspace = true
mime_guess = "2.0"
notify = "6.1"
once_cell.workspace = true
openmetrics-parser = "0.4"
opensrv-mysql = "0.7.0"
Expand Down Expand Up @@ -121,6 +122,7 @@ script = { workspace = true, features = ["python"] }
serde_json.workspace = true
session = { workspace = true, features = ["testing"] }
table.workspace = true
tempfile = "3.0.0"
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.11"
tokio-test = "0.4"
Expand Down
9 changes: 8 additions & 1 deletion src/servers/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,12 @@ pub enum Error {
"Invalid parameter, physical_table is not expected when metric engine is disabled"
))]
UnexpectedPhysicalTable { location: Location },

#[snafu(display("Failed to initialize a watcher for file"))]
FileWatch {
#[snafu(source)]
error: notify::Error,
},
}

pub type Result<T> = std::result::Result<T, Error>;
Expand All @@ -462,7 +468,8 @@ impl ErrorExt for Error {
| CatalogError { .. }
| GrpcReflectionService { .. }
| BuildHttpResponse { .. }
| Arrow { .. } => StatusCode::Internal,
| Arrow { .. }
| FileWatch { .. } => StatusCode::Internal,

UnsupportedDataType { .. } => StatusCode::Unsupported,

Expand Down
7 changes: 4 additions & 3 deletions src/servers/src/mysql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use crate::error::{Error, Result};
use crate::mysql::handler::MysqlInstanceShim;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::ReloadableTlsServerConfig;

// Default size of ResultSet write buffer: 100KB
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
Expand Down Expand Up @@ -68,15 +69,15 @@ impl MysqlSpawnRef {
pub struct MysqlSpawnConfig {
// tls config
force_tls: bool,
tls: Option<Arc<ServerConfig>>,
tls: Arc<ReloadableTlsServerConfig>,
// other shim config
reject_no_database: bool,
}

impl MysqlSpawnConfig {
pub fn new(
force_tls: bool,
tls: Option<Arc<ServerConfig>>,
tls: Arc<ReloadableTlsServerConfig>,
reject_no_database: bool,
) -> MysqlSpawnConfig {
MysqlSpawnConfig {
Expand All @@ -87,7 +88,7 @@ impl MysqlSpawnConfig {
}

fn tls(&self) -> Option<Arc<ServerConfig>> {
self.tls.clone()
self.tls.get_server_config()
}
}

Expand Down
27 changes: 13 additions & 14 deletions src/servers/src/postgres/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,47 +29,52 @@ use super::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
use crate::error::Result;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::TlsOption;
use crate::tls::ReloadableTlsServerConfig;

pub struct PostgresServer {
base_server: BaseTcpServer,
make_handler: Arc<MakePostgresServerHandler>,
tls: TlsOption,
tls_server_config: Arc<ReloadableTlsServerConfig>,
}

impl PostgresServer {
/// Creates a new Postgres server with provided query_handler and async runtime
pub fn new(
query_handler: ServerSqlQueryHandlerRef,
tls: TlsOption,
force_tls: bool,
tls_server_config: Arc<ReloadableTlsServerConfig>,
io_runtime: Arc<Runtime>,
user_provider: Option<UserProviderRef>,
) -> PostgresServer {
let make_handler = Arc::new(
MakePostgresServerHandlerBuilder::default()
.query_handler(query_handler.clone())
.user_provider(user_provider.clone())
.force_tls(tls.should_force_tls())
.force_tls(force_tls)
.build()
.unwrap(),
);
PostgresServer {
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
make_handler,
tls,
tls_server_config,
}
}

fn accept(
&self,
io_runtime: Arc<Runtime>,
accepting_stream: AbortableStream,
tls_acceptor: Option<Arc<TlsAcceptor>>,
) -> impl Future<Output = ()> {
let handler_maker = self.make_handler.clone();
let tls_server_config = self.tls_server_config.clone();
accepting_stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let tls_acceptor = tls_acceptor.clone();

let tls_acceptor = tls_server_config
.get_server_config()
.map(|server_config| Arc::new(TlsAcceptor::from(server_config)));

let handler_maker = handler_maker.clone();

async move {
Expand Down Expand Up @@ -119,14 +124,8 @@ impl Server for PostgresServer {
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
let (stream, addr) = self.base_server.bind(listening).await?;

debug!("Starting PostgreSQL with TLS option: {:?}", self.tls);
let tls_acceptor = self
.tls
.setup()?
.map(|server_conf| Arc::new(TlsAcceptor::from(Arc::new(server_conf))));

let io_runtime = self.base_server.io_runtime();
let join_handle = common_runtime::spawn_read(self.accept(io_runtime, stream, tls_acceptor));
let join_handle = common_runtime::spawn_read(self.accept(io_runtime, stream));

self.base_server.start_with(join_handle).await?;
Ok(addr)
Expand Down
Loading

0 comments on commit 3887d20

Please sign in to comment.