Skip to content

Commit

Permalink
fix: optimize code
Browse files Browse the repository at this point in the history
  • Loading branch information
poltao committed May 22, 2024
1 parent bf3e8ab commit 1039baf
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 30 deletions.
3 changes: 1 addition & 2 deletions src/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ where
max_send_message_size: opts.max_send_message_size.as_bytes() as usize,
tls: opts.tls.clone(),
};
let mut builder = GrpcServerBuilder::new(grpc_config, grpc_runtime);
builder = builder
let builder = GrpcServerBuilder::new(grpc_config, grpc_runtime)
.with_tls_config(opts.tls.clone())
.context(error::InvalidTlsConfigSnafu)?;
Ok(builder)
Expand Down
31 changes: 18 additions & 13 deletions src/servers/src/grpc/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use api::v1::prometheus_gateway_server::PrometheusGatewayServer;
use api::v1::region::region_server::RegionServer;
use arrow_flight::flight_service_server::FlightServiceServer;
use auth::UserProviderRef;
use common_grpc::error::{InvalidConfigFilePathSnafu, Result};
use common_grpc::error::{Error, InvalidConfigFilePathSnafu, Result};
use common_runtime::Runtime;
use opentelemetry_proto::tonic::collector::metrics::v1::metrics_service_server::MetricsServiceServer;
use opentelemetry_proto::tonic::collector::trace::v1::trace_service_server::TraceServiceServer;
Expand All @@ -40,7 +40,7 @@ use crate::grpc::otlp::OtlpService;
use crate::grpc::prom_query_gateway::PrometheusGatewayService;
use crate::prometheus_handler::PrometheusHandlerRef;
use crate::query_handler::OpenTelemetryProtocolHandlerRef;
use crate::tls::{TlsMode, TlsOption};
use crate::tls::TlsOption;

/// Add a gRPC service (`service`) to a `builder`([RoutesBuilder]).
/// This macro will automatically add some gRPC properties to the service.
Expand Down Expand Up @@ -164,18 +164,23 @@ impl GrpcServerBuilder {
}

pub fn with_tls_config(mut self, tls_option: TlsOption) -> Result<Self> {
let tls_config = match tls_option.mode {
TlsMode::Require => {
let cert = std::fs::read_to_string(tls_option.cert_path)
.context(InvalidConfigFilePathSnafu)?;
let key = std::fs::read_to_string(tls_option.key_path)
.context(InvalidConfigFilePathSnafu)?;
let identity = Identity::from_pem(cert, key);
Some(ServerTlsConfig::new().identity(identity))
}
_ => None,
// tonic does not support watching for tls config changes
// so we don't support it either for now
if tls_option.watch {
return Err(Error::NotSupported {
feat: "grpc tls watch".to_string(),
});
}
self.tls_config = if tls_option.should_force_tls() {
let cert = std::fs::read_to_string(tls_option.cert_path)
.context(InvalidConfigFilePathSnafu)?;
let key =
std::fs::read_to_string(tls_option.key_path).context(InvalidConfigFilePathSnafu)?;
let identity = Identity::from_pem(cert, key);
Some(ServerTlsConfig::new().identity(identity))
} else {
None
};
self.tls_config = tls_config;
Ok(self)
}

Expand Down
21 changes: 8 additions & 13 deletions tests-integration/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
use servers::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
use servers::server::Server;
use servers::tls::{ReloadableTlsServerConfig, TlsMode};
use servers::tls::ReloadableTlsServerConfig;
use servers::Mode;
use session::context::QueryContext;

Expand Down Expand Up @@ -512,19 +512,14 @@ pub async fn setup_grpc_server_with(
let flight_handler = Arc::new(greptime_request_handler.clone());

let grpc_config = grpc_config.unwrap_or_default();
let grpc_builder = GrpcServerBuilder::new(grpc_config.clone(), runtime);
let grpc_builder = match grpc_config.tls.mode {
TlsMode::Require => grpc_builder.with_tls_config(grpc_config.tls).unwrap(),
_ => grpc_builder,
};
let grpc_builder = GrpcServerBuilder::new(grpc_config.clone(), runtime)
.database_handler(greptime_request_handler)
.flight_handler(flight_handler)
.prometheus_handler(fe_instance_ref.clone(), user_provider)
.with_tls_config(grpc_config.tls)
.unwrap();

let fe_grpc_server = Arc::new(
grpc_builder
.database_handler(greptime_request_handler)
.flight_handler(flight_handler)
.prometheus_handler(fe_instance_ref.clone(), user_provider)
.build(),
);
let fe_grpc_server = Arc::new(grpc_builder.build());

let fe_grpc_addr = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let fe_grpc_addr = fe_grpc_server
Expand Down
30 changes: 28 additions & 2 deletions tests-integration/tests/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use api::v1::alter_expr::Kind;
use api::v1::promql_request::Promql;
use api::v1::{
Expand All @@ -25,6 +27,8 @@ use common_catalog::consts::MITO_ENGINE;
use common_grpc::channel_manager::ClientTlsOption;
use common_query::Output;
use common_recordbatch::RecordBatches;
use common_runtime::Runtime;
use servers::grpc::builder::GrpcServerBuilder;
use servers::grpc::GrpcServerConfig;
use servers::http::prometheus::{
PromData, PromQueryResult, PromSeriesMatrix, PromSeriesVector, PrometheusJsonResponse,
Expand Down Expand Up @@ -152,6 +156,7 @@ pub async fn test_grpc_zstd_compression(store_type: StorageType) {
let config = GrpcServerConfig {
max_recv_message_size: 1024,
max_send_message_size: 1024,
tls: TlsOption::default(),
};
let (addr, mut guard, fe_grpc_server) =
setup_grpc_server_with(store_type, "auto_create_table", None, Some(config)).await;
Expand Down Expand Up @@ -766,9 +771,8 @@ pub async fn test_grpc_tls_config(store_type: StorageType) {
);
db.sql("show tables;").await.unwrap();
}

// test corrupted client key
{
// test corrupted client key
client_tls.client_key_path = client_corrupted;
let grpc_client = Client::with_tls_and_urls(vec![addr], client_tls.clone()).unwrap();
let db = Database::new_with_dbname(
Expand All @@ -778,6 +782,28 @@ pub async fn test_grpc_tls_config(store_type: StorageType) {
let re = db.sql("show tables;").await;
assert!(re.is_err());
}
// test grpc unsupported tls watch
{
let tls = TlsOption {
watch: true,
..Default::default()
};
let config = GrpcServerConfig {
max_recv_message_size: 1024,
max_send_message_size: 1024,
tls,
};
let runtime = Arc::new(Runtime::builder().build().unwrap());
let grpc_builder =
GrpcServerBuilder::new(config.clone(), runtime).with_tls_config(config.tls);
assert!(grpc_builder.is_err());
}

let _ = fe_grpc_server.shutdown().await;
guard.remove_all().await;
}

#[tokio::test(flavor = "multi_thread")]
async fn test_grpc() {
test_grpc_tls_config(StorageType::File).await;
}

0 comments on commit 1039baf

Please sign in to comment.