diff --git a/Cargo.lock b/Cargo.lock index 84d28103f312..29ea367b9825 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3585,6 +3585,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "fsevent-sys" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2" +dependencies = [ + "libc", +] + [[package]] name = "fst" version = "0.4.7" @@ -4369,6 +4378,26 @@ dependencies = [ "snafu", ] +[[package]] +name = "inotify" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" +dependencies = [ + "bitflags 1.3.2", + "inotify-sys", + "libc", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "instant" version = "0.1.12" @@ -4566,6 +4595,26 @@ dependencies = [ "indexmap 2.1.0", ] +[[package]] +name = "kqueue" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7447f1ca1b7b563588a205fe93dea8df60fd981423a768bc1c0ded35ed147d0c" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" +dependencies = [ + "bitflags 1.3.2", + "libc", +] + [[package]] name = "lalrpop" version = "0.19.12" @@ -5655,6 +5704,25 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" +[[package]] +name = "notify" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" +dependencies = [ + "bitflags 2.4.1", + "crossbeam-channel", + "filetime", + "fsevent-sys", + "inotify", + "kqueue", + "libc", + "log", + "mio", + "walkdir", + "windows-sys 0.48.0", +] + [[package]] name = "ntapi" version = "0.4.1" @@ -8995,6 +9063,7 @@ dependencies = [ "lazy_static", "mime_guess", "mysql_async", + "notify", "once_cell", "openmetrics-parser", "opensrv-mysql", @@ -9027,6 +9096,7 @@ dependencies = [ "sql", "strum 0.25.0", "table", + "tempfile", "tikv-jemalloc-ctl", "tokio", "tokio-postgres", diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 9fcc7372c858..4866dd7fab78 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use auth::UserProviderRef; use common_base::Plugins; use common_runtime::Builder as RuntimeBuilder; -use servers::error::InternalIoSnafu; use servers::grpc::builder::GrpcServerBuilder; use servers::grpc::greptime_handler::GreptimeRequestHandler; use servers::grpc::{GrpcServer, GrpcServerConfig}; @@ -30,6 +29,7 @@ use servers::postgres::PostgresServer; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter; use servers::query_handler::sql::ServerSqlQueryHandlerAdapter; use servers::server::{Server, ServerHandlers}; +use servers::tls::{watch_tls_config, ReloadableTlsServerConfig}; use snafu::ResultExt; use crate::error::{self, Result, StartServerSnafu}; @@ -195,6 +195,12 @@ where let opts = &opts.mysql; let mysql_addr = parse_addr(&opts.addr)?; + let tls_server_config = Arc::new( + ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?, + ); + + watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?; + let mysql_io_runtime = Arc::new( RuntimeBuilder::default() .worker_threads(opts.runtime_size) @@ -210,11 +216,7 @@ where )), Arc::new(MysqlSpawnConfig::new( opts.tls.should_force_tls(), - opts.tls - .setup() - .context(InternalIoSnafu) - .context(StartServerSnafu)? - .map(Arc::new), + tls_server_config, opts.reject_no_database.unwrap_or(false), )), ); @@ -226,6 +228,12 @@ where let opts = &opts.postgres; let pg_addr = parse_addr(&opts.addr)?; + let tls_server_config = Arc::new( + ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?, + ); + + watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?; + let pg_io_runtime = Arc::new( RuntimeBuilder::default() .worker_threads(opts.runtime_size) @@ -236,7 +244,8 @@ where let pg_server = Box::new(PostgresServer::new( ServerSqlQueryHandlerAdapter::arc(instance.clone()), - opts.tls.clone(), + opts.tls.should_force_tls(), + tls_server_config, pg_io_runtime, user_provider.clone(), )) as Box; diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index d0aba8af463d..69f318815493 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -59,6 +59,7 @@ influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", bran itertools.workspace = true lazy_static.workspace = true mime_guess = "2.0" +notify = "6.1" once_cell.workspace = true openmetrics-parser = "0.4" opensrv-mysql = "0.7.0" @@ -121,6 +122,7 @@ script = { workspace = true, features = ["python"] } serde_json.workspace = true session = { workspace = true, features = ["testing"] } table.workspace = true +tempfile = "3.0.0" tokio-postgres = "0.7" tokio-postgres-rustls = "0.11" tokio-test = "0.4" diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 4ebbdc55445c..2640454a94af 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -441,6 +441,12 @@ pub enum Error { "Invalid parameter, physical_table is not expected when metric engine is disabled" ))] UnexpectedPhysicalTable { location: Location }, + + #[snafu(display("Failed to initialize a watcher for file"))] + FileWatch { + #[snafu(source)] + error: notify::Error, + }, } pub type Result = std::result::Result; @@ -462,7 +468,8 @@ impl ErrorExt for Error { | CatalogError { .. } | GrpcReflectionService { .. } | BuildHttpResponse { .. } - | Arrow { .. } => StatusCode::Internal, + | Arrow { .. } + | FileWatch { .. } => StatusCode::Internal, UnsupportedDataType { .. } => StatusCode::Unsupported, diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index 785a47a120bf..a71a1dc62313 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -33,6 +33,7 @@ use crate::error::{Error, Result}; use crate::mysql::handler::MysqlInstanceShim; use crate::query_handler::sql::ServerSqlQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; +use crate::tls::ReloadableTlsServerConfig; // Default size of ResultSet write buffer: 100KB const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024; @@ -68,7 +69,7 @@ impl MysqlSpawnRef { pub struct MysqlSpawnConfig { // tls config force_tls: bool, - tls: Option>, + tls: Arc, // other shim config reject_no_database: bool, } @@ -76,7 +77,7 @@ pub struct MysqlSpawnConfig { impl MysqlSpawnConfig { pub fn new( force_tls: bool, - tls: Option>, + tls: Arc, reject_no_database: bool, ) -> MysqlSpawnConfig { MysqlSpawnConfig { @@ -87,7 +88,7 @@ impl MysqlSpawnConfig { } fn tls(&self) -> Option> { - self.tls.clone() + self.tls.get_server_config() } } diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 6a4d7a112dda..3ed9f5f40ebf 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -29,19 +29,20 @@ use super::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder}; use crate::error::Result; use crate::query_handler::sql::ServerSqlQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; -use crate::tls::TlsOption; +use crate::tls::ReloadableTlsServerConfig; pub struct PostgresServer { base_server: BaseTcpServer, make_handler: Arc, - tls: TlsOption, + tls_server_config: Arc, } impl PostgresServer { /// Creates a new Postgres server with provided query_handler and async runtime pub fn new( query_handler: ServerSqlQueryHandlerRef, - tls: TlsOption, + force_tls: bool, + tls_server_config: Arc, io_runtime: Arc, user_provider: Option, ) -> PostgresServer { @@ -49,14 +50,14 @@ impl PostgresServer { MakePostgresServerHandlerBuilder::default() .query_handler(query_handler.clone()) .user_provider(user_provider.clone()) - .force_tls(tls.should_force_tls()) + .force_tls(force_tls) .build() .unwrap(), ); PostgresServer { base_server: BaseTcpServer::create_server("Postgres", io_runtime), make_handler, - tls, + tls_server_config, } } @@ -64,12 +65,16 @@ impl PostgresServer { &self, io_runtime: Arc, accepting_stream: AbortableStream, - tls_acceptor: Option>, ) -> impl Future { let handler_maker = self.make_handler.clone(); + let tls_server_config = self.tls_server_config.clone(); accepting_stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); - let tls_acceptor = tls_acceptor.clone(); + + let tls_acceptor = tls_server_config + .get_server_config() + .map(|server_config| Arc::new(TlsAcceptor::from(server_config))); + let handler_maker = handler_maker.clone(); async move { @@ -119,14 +124,8 @@ impl Server for PostgresServer { async fn start(&self, listening: SocketAddr) -> Result { let (stream, addr) = self.base_server.bind(listening).await?; - debug!("Starting PostgreSQL with TLS option: {:?}", self.tls); - let tls_acceptor = self - .tls - .setup()? - .map(|server_conf| Arc::new(TlsAcceptor::from(Arc::new(server_conf)))); - let io_runtime = self.base_server.io_runtime(); - let join_handle = common_runtime::spawn_read(self.accept(io_runtime, stream, tls_acceptor)); + let join_handle = common_runtime::spawn_read(self.accept(io_runtime, stream)); self.base_server.start_with(join_handle).await?; Ok(addr) diff --git a/src/servers/src/tls.rs b/src/servers/src/tls.rs index deb695bb4558..1c0be507e723 100644 --- a/src/servers/src/tls.rs +++ b/src/servers/src/tls.rs @@ -13,14 +13,23 @@ // limitations under the License. use std::fs::File; -use std::io::{BufReader, Error, ErrorKind}; - +use std::io::{BufReader, Error as IoError, ErrorKind}; +use std::path::Path; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::mpsc::channel; +use std::sync::{Arc, RwLock}; + +use common_telemetry::{error, info}; +use notify::{EventKind, RecursiveMode, Watcher}; use rustls::ServerConfig; use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys}; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use serde::{Deserialize, Serialize}; +use snafu::ResultExt; use strum::EnumString; +use crate::error::{FileWatchSnafu, InternalIoSnafu, Result}; + /// TlsMode is used for Mysql and Postgres server start up. #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)] #[serde(rename_all = "snake_case")] @@ -73,27 +82,38 @@ impl TlsOption { tls_option } - pub fn setup(&self) -> Result, Error> { + pub fn setup(&self) -> Result> { if let TlsMode::Disable = self.mode { return Ok(None); } - let cert = certs(&mut BufReader::new(File::open(&self.cert_path)?)) - .collect::, Error>>()?; + let cert = certs(&mut BufReader::new( + File::open(&self.cert_path).context(InternalIoSnafu)?, + )) + .collect::, IoError>>() + .context(InternalIoSnafu)?; let key = { - let mut pkcs8 = pkcs8_private_keys(&mut BufReader::new(File::open(&self.key_path)?)) - .map(|key| key.map(PrivateKeyDer::from)) - .collect::, Error>>()?; + let mut pkcs8 = pkcs8_private_keys(&mut BufReader::new( + File::open(&self.key_path).context(InternalIoSnafu)?, + )) + .map(|key| key.map(PrivateKeyDer::from)) + .collect::, IoError>>() + .context(InternalIoSnafu)?; + if !pkcs8.is_empty() { pkcs8.remove(0) } else { - let mut rsa = rsa_private_keys(&mut BufReader::new(File::open(&self.key_path)?)) - .map(|key| key.map(PrivateKeyDer::from)) - .collect::, Error>>()?; + let mut rsa = rsa_private_keys(&mut BufReader::new( + File::open(&self.key_path).context(InternalIoSnafu)?, + )) + .map(|key| key.map(PrivateKeyDer::from)) + .collect::, IoError>>() + .context(InternalIoSnafu)?; if !rsa.is_empty() { rsa.remove(0) } else { - return Err(Error::new(ErrorKind::InvalidInput, "invalid key")); + return Err(IoError::new(ErrorKind::InvalidInput, "invalid key")) + .context(InternalIoSnafu); } } }; @@ -110,6 +130,104 @@ impl TlsOption { pub fn should_force_tls(&self) -> bool { !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer) } + + pub fn cert_path(&self) -> &Path { + Path::new(&self.cert_path) + } + + pub fn key_path(&self) -> &Path { + Path::new(&self.key_path) + } +} + +/// A mutable container for TLS server config +/// +/// This struct allows dynamic reloading of server certificates and keys +pub struct ReloadableTlsServerConfig { + tls_option: TlsOption, + config: RwLock>>, + version: AtomicUsize, +} + +impl ReloadableTlsServerConfig { + /// Create server config by loading configuration from `TlsOption` + pub fn try_new(tls_option: TlsOption) -> Result { + let server_config = tls_option.setup()?; + Ok(Self { + tls_option, + config: RwLock::new(server_config.map(Arc::new)), + version: AtomicUsize::new(0), + }) + } + + /// Reread server certificates and keys from file system. + pub fn reload(&self) -> Result<()> { + let server_config = self.tls_option.setup()?; + *self.config.write().unwrap() = server_config.map(Arc::new); + self.version.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + + /// Get the server config hold by this container + pub fn get_server_config(&self) -> Option> { + self.config.read().unwrap().clone() + } + + /// Get associated `TlsOption` + pub fn get_tls_option(&self) -> &TlsOption { + &self.tls_option + } + + /// Get version of current config + /// + /// this version will auto increase when server config get reloaded. + pub fn get_version(&self) -> usize { + self.version.load(Ordering::Relaxed) + } +} + +pub fn watch_tls_config(tls_server_config: Arc) -> Result<()> { + if tls_server_config.get_tls_option().mode == TlsMode::Disable { + return Ok(()); + } + + let tls_server_config_for_watcher = tls_server_config.clone(); + + let (tx, rx) = channel::>(); + let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu)?; + + watcher + .watch( + tls_server_config.get_tls_option().cert_path(), + RecursiveMode::NonRecursive, + ) + .context(FileWatchSnafu)?; + + watcher + .watch( + tls_server_config.get_tls_option().key_path(), + RecursiveMode::NonRecursive, + ) + .context(FileWatchSnafu)?; + + std::thread::spawn(move || { + let _watcher = watcher; + while let Ok(res) = rx.recv() { + if let Ok(event) = res { + match event.kind { + EventKind::Modify(_) | EventKind::Create(_) => { + info!("Detected TLS cert/key file change: {:?}", event); + if let Err(err) = tls_server_config_for_watcher.reload() { + error!(err; "Failed to reload TLS server config"); + } + } + _ => {} + } + } + } + }); + + Ok(()) } #[cfg(test)] @@ -237,4 +355,44 @@ mod tests { assert!(!t.key_path.is_empty()); assert!(!t.cert_path.is_empty()); } + + #[test] + fn test_tls_file_change_watch() { + let dir = tempfile::tempdir().unwrap(); + let cert_path = dir.path().join("serevr.crt"); + let key_path = dir.path().join("server.key"); + + std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir"); + std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir"); + + let server_tls = TlsOption { + mode: TlsMode::Require, + cert_path: cert_path + .clone() + .into_os_string() + .into_string() + .expect("failed to convert path to string"), + key_path: key_path + .clone() + .into_os_string() + .into_string() + .expect("failed to convert path to string"), + }; + + let server_config = Arc::new( + ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"), + ); + watch_tls_config(server_config.clone()).expect("failed to watch server config"); + + assert_eq!(0, server_config.get_version()); + assert!(server_config.get_server_config().is_some()); + + std::fs::copy("tests/ssl/server-pkcs8.key", &key_path) + .expect("failed to copy key to tmpdir"); + + // waiting for async load + std::thread::sleep(std::time::Duration::from_millis(100)); + assert!(server_config.get_version() > 1); + assert!(server_config.get_server_config().is_some()); + } } diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 18f5865a05e9..3cbac4ee9369 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -30,7 +30,7 @@ use rand::Rng; use servers::error::Result; use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; use servers::server::Server; -use servers::tls::TlsOption; +use servers::tls::{ReloadableTlsServerConfig, TlsOption}; use table::test_util::MemTable; use table::TableRef; @@ -59,12 +59,17 @@ fn create_mysql_server(table: TableRef, opts: MysqlOpts<'_>) -> Result);