diff --git a/src/client/src/client.rs b/src/client/src/client.rs index 5e82295c16f6..6fa3866d9a9d 100644 --- a/src/client/src/client.rs +++ b/src/client/src/client.rs @@ -19,7 +19,7 @@ use api::v1::prometheus_gateway_client::PrometheusGatewayClient; use api::v1::region::region_client::RegionClient as PbRegionClient; use api::v1::HealthCheckRequest; use arrow_flight::flight_service_client::FlightServiceClient; -use common_grpc::channel_manager::ChannelManager; +use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption}; use parking_lot::RwLock; use snafu::{OptionExt, ResultExt}; use tonic::transport::Channel; @@ -86,6 +86,17 @@ impl Client { Self::with_manager_and_urls(ChannelManager::new(), urls) } + pub fn with_tls_and_urls(urls: A, client_tls: ClientTlsOption) -> Result + where + U: AsRef, + A: AsRef<[U]>, + { + let channel_config = ChannelConfig::default().client_tls_config(client_tls); + let channel_manager = ChannelManager::with_tls_config(channel_config) + .context(error::CreateTlsChannelSnafu)?; + Ok(Self::with_manager_and_urls(channel_manager, urls)) + } + pub fn with_manager_and_urls(channel_manager: ChannelManager, urls: A) -> Self where U: AsRef, diff --git a/src/client/src/error.rs b/src/client/src/error.rs index 29197450b62d..e265662e9f2a 100644 --- a/src/client/src/error.rs +++ b/src/client/src/error.rs @@ -82,6 +82,13 @@ pub enum Error { source: common_grpc::error::Error, }, + #[snafu(display("Failed to create Tls channel manager"))] + CreateTlsChannel { + #[snafu(implicit)] + location: Location, + source: common_grpc::error::Error, + }, + #[snafu(display("Failed to request RegionServer, code: {}", code))] RegionServer { code: Code, @@ -129,9 +136,9 @@ impl ErrorExt for Error { Error::FlightGet { source, .. } | Error::HandleRequest { source, .. } | Error::RegionServer { source, .. } => source.status_code(), - Error::CreateChannel { source, .. } | Error::ConvertFlightData { source, .. } => { - source.status_code() - } + Error::CreateChannel { source, .. } + | Error::ConvertFlightData { source, .. } + | Error::CreateTlsChannel { source, .. } => source.status_code(), Error::IllegalGrpcClientState { .. } => StatusCode::Unexpected, } } diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 548404ece08d..9894606e79bc 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -50,7 +50,7 @@ use servers::postgres::PostgresServer; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter; use servers::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler}; use servers::server::Server; -use servers::tls::ReloadableTlsServerConfig; +use servers::tls::{ReloadableTlsServerConfig, TlsMode}; use servers::Mode; use session::context::QueryContext; @@ -511,8 +511,15 @@ pub async fn setup_grpc_server_with( let flight_handler = Arc::new(greptime_request_handler.clone()); + let grpc_config = grpc_config.unwrap_or_default(); + let grpc_builder = GrpcServerBuilder::new(grpc_config.clone(), runtime); + let grpc_builder = match grpc_config.tls.mode { + TlsMode::Require => grpc_builder.with_tls_config(grpc_config.tls).unwrap(), + _ => grpc_builder, + }; + let fe_grpc_server = Arc::new( - GrpcServerBuilder::new(grpc_config.unwrap_or_default(), runtime) + grpc_builder .database_handler(greptime_request_handler) .flight_handler(flight_handler) .prometheus_handler(fe_instance_ref.clone(), user_provider) diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index df5198a99969..756ed63bcc7d 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -22,6 +22,7 @@ use api::v1::{ use auth::user_provider_from_option; use client::{Client, OutputData, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::consts::MITO_ENGINE; +use common_grpc::channel_manager::ClientTlsOption; use common_query::Output; use common_recordbatch::RecordBatches; use servers::grpc::GrpcServerConfig; @@ -30,7 +31,7 @@ use servers::http::prometheus::{ PrometheusResponse, }; use servers::server::Server; -use servers::tls::TlsOption; +use servers::tls::{TlsMode, TlsOption}; use tests_integration::database::Database; use tests_integration::test_util::{ setup_grpc_server, setup_grpc_server_with, setup_grpc_server_with_user_provider, StorageType, @@ -77,6 +78,7 @@ macro_rules! grpc_tests { test_health_check, test_prom_gateway_query, test_grpc_timezone, + test_grpc_tls_config, ); )* }; @@ -704,3 +706,48 @@ async fn to_batch(output: Output) -> String { .pretty_print() .unwrap() } + +pub async fn test_grpc_tls_config(store_type: StorageType) { + let comm_dir = std::path::PathBuf::from_iter([ + std::env!("CARGO_RUSTC_CURRENT_DIR"), + "src/common/grpc/tests", + ]); + let ca_path = comm_dir.join("tls/server.cert.pem"); + let cert_path = comm_dir.join("tls/client.cert.pem"); + let key_path = comm_dir.join("tls/client.key.pem"); + let ca_path_string = ca_path.to_str().unwrap().to_string(); + let cert_path_str = cert_path.to_str().unwrap(); + let key_path_str = key_path.to_str().unwrap(); + + let tls = TlsOption::new( + Some(TlsMode::Require), + Some(cert_path_str.to_string()), + Some(key_path_str.to_string()), + ); + let config = GrpcServerConfig { + max_recv_message_size: 1024, + max_send_message_size: 1024, + tls, + }; + let (addr, mut guard, fe_grpc_server) = + setup_grpc_server_with(store_type, "tls_create_table", None, Some(config)).await; + + let client_tls = ClientTlsOption { + server_ca_cert_path: ca_path_string, + client_cert_path: cert_path_str.to_string(), + client_key_path: key_path_str.to_string(), + }; + let grpc_client = Client::with_tls_and_urls(vec![addr], client_tls).unwrap(); + let db = Database::new_with_dbname( + format!("{}-{}", DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME), + grpc_client, + ); + db.sql("show tables;").await.unwrap(); + let _ = fe_grpc_server.shutdown().await; + guard.remove_all().await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_grpc_server_builder() { + test_grpc_tls_config(StorageType::File).await; +}