Skip to content

Commit

Permalink
feat: add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
poltao committed May 17, 2024
1 parent a36b285 commit 7ed1f65
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 7 deletions.
13 changes: 12 additions & 1 deletion src/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
use api::v1::region::region_client::RegionClient as PbRegionClient;
use api::v1::HealthCheckRequest;
use arrow_flight::flight_service_client::FlightServiceClient;
use common_grpc::channel_manager::ChannelManager;
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption};
use parking_lot::RwLock;
use snafu::{OptionExt, ResultExt};
use tonic::transport::Channel;
Expand Down Expand Up @@ -86,6 +86,17 @@ impl Client {
Self::with_manager_and_urls(ChannelManager::new(), urls)
}

pub fn with_tls_and_urls<U, A>(urls: A, client_tls: ClientTlsOption) -> Result<Self>
where
U: AsRef<str>,
A: AsRef<[U]>,
{
let channel_config = ChannelConfig::default().client_tls_config(client_tls);
let channel_manager = ChannelManager::with_tls_config(channel_config)
.context(error::CreateTlsChannelSnafu)?;
Ok(Self::with_manager_and_urls(channel_manager, urls))
}

pub fn with_manager_and_urls<U, A>(channel_manager: ChannelManager, urls: A) -> Self
where
U: AsRef<str>,
Expand Down
13 changes: 10 additions & 3 deletions src/client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ pub enum Error {
source: common_grpc::error::Error,
},

#[snafu(display("Failed to create Tls channel manager"))]
CreateTlsChannel {
#[snafu(implicit)]
location: Location,
source: common_grpc::error::Error,
},

#[snafu(display("Failed to request RegionServer, code: {}", code))]
RegionServer {
code: Code,
Expand Down Expand Up @@ -129,9 +136,9 @@ impl ErrorExt for Error {
Error::FlightGet { source, .. }
| Error::HandleRequest { source, .. }
| Error::RegionServer { source, .. } => source.status_code(),
Error::CreateChannel { source, .. } | Error::ConvertFlightData { source, .. } => {
source.status_code()
}
Error::CreateChannel { source, .. }
| Error::ConvertFlightData { source, .. }
| Error::CreateTlsChannel { source, .. } => source.status_code(),
Error::IllegalGrpcClientState { .. } => StatusCode::Unexpected,
}
}
Expand Down
11 changes: 9 additions & 2 deletions tests-integration/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
use servers::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
use servers::server::Server;
use servers::tls::ReloadableTlsServerConfig;
use servers::tls::{ReloadableTlsServerConfig, TlsMode};
use servers::Mode;
use session::context::QueryContext;

Expand Down Expand Up @@ -511,8 +511,15 @@ pub async fn setup_grpc_server_with(

let flight_handler = Arc::new(greptime_request_handler.clone());

let grpc_config = grpc_config.unwrap_or_default();
let grpc_builder = GrpcServerBuilder::new(grpc_config.clone(), runtime);
let grpc_builder = match grpc_config.tls.mode {
TlsMode::Require => grpc_builder.with_tls_config(grpc_config.tls).unwrap(),
_ => grpc_builder,
};

let fe_grpc_server = Arc::new(
GrpcServerBuilder::new(grpc_config.unwrap_or_default(), runtime)
grpc_builder
.database_handler(greptime_request_handler)
.flight_handler(flight_handler)
.prometheus_handler(fe_instance_ref.clone(), user_provider)
Expand Down
49 changes: 48 additions & 1 deletion tests-integration/tests/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use api::v1::{
use auth::user_provider_from_option;
use client::{Client, OutputData, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::consts::MITO_ENGINE;
use common_grpc::channel_manager::ClientTlsOption;
use common_query::Output;
use common_recordbatch::RecordBatches;
use servers::grpc::GrpcServerConfig;
Expand All @@ -30,7 +31,7 @@ use servers::http::prometheus::{
PrometheusResponse,
};
use servers::server::Server;
use servers::tls::TlsOption;
use servers::tls::{TlsMode, TlsOption};
use tests_integration::database::Database;
use tests_integration::test_util::{
setup_grpc_server, setup_grpc_server_with, setup_grpc_server_with_user_provider, StorageType,
Expand Down Expand Up @@ -77,6 +78,7 @@ macro_rules! grpc_tests {
test_health_check,
test_prom_gateway_query,
test_grpc_timezone,
test_grpc_tls_config,
);
)*
};
Expand Down Expand Up @@ -704,3 +706,48 @@ async fn to_batch(output: Output) -> String {
.pretty_print()
.unwrap()
}

pub async fn test_grpc_tls_config(store_type: StorageType) {
let comm_dir = std::path::PathBuf::from_iter([
std::env!("CARGO_RUSTC_CURRENT_DIR"),
"src/common/grpc/tests",
]);
let ca_path = comm_dir.join("tls/server.cert.pem");
let cert_path = comm_dir.join("tls/client.cert.pem");
let key_path = comm_dir.join("tls/client.key.pem");
let ca_path_string = ca_path.to_str().unwrap().to_string();
let cert_path_str = cert_path.to_str().unwrap();
let key_path_str = key_path.to_str().unwrap();

let tls = TlsOption::new(
Some(TlsMode::Require),
Some(cert_path_str.to_string()),
Some(key_path_str.to_string()),
);
let config = GrpcServerConfig {
max_recv_message_size: 1024,
max_send_message_size: 1024,
tls,
};
let (addr, mut guard, fe_grpc_server) =
setup_grpc_server_with(store_type, "tls_create_table", None, Some(config)).await;

let client_tls = ClientTlsOption {
server_ca_cert_path: ca_path_string,
client_cert_path: cert_path_str.to_string(),
client_key_path: key_path_str.to_string(),
};
let grpc_client = Client::with_tls_and_urls(vec![addr], client_tls).unwrap();
let db = Database::new_with_dbname(
format!("{}-{}", DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME),
grpc_client,
);
db.sql("show tables;").await.unwrap();
let _ = fe_grpc_server.shutdown().await;
guard.remove_all().await;
}

#[tokio::test(flavor = "multi_thread")]
async fn test_grpc_server_builder() {
test_grpc_tls_config(StorageType::File).await;
}

0 comments on commit 7ed1f65

Please sign in to comment.