From 47ece851d86e441371ac0cddb69824d951506c67 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Thu, 4 Jul 2024 15:01:52 +0800 Subject: [PATCH] refactor: demonstrate graceful shutdown on compute node (#17533) Signed-off-by: Bugen Zhao --- Cargo.lock | 1 + src/cmd/src/lib.rs | 24 ++-- src/cmd_all/src/bin/risingwave.rs | 12 +- src/cmd_all/src/standalone.rs | 12 +- src/common/Cargo.toml | 1 + src/common/src/util/mod.rs | 1 + src/compute/src/lib.rs | 15 +-- src/compute/src/server.rs | 109 ++++++++---------- src/connector/Cargo.toml | 2 +- .../compaction_test/src/bin/compaction.rs | 2 +- .../compaction_test/src/bin/delete_range.rs | 2 +- src/tests/simulation/src/cluster.rs | 3 +- .../tests/integration_tests/batch/mod.rs | 3 +- src/utils/runtime/src/lib.rs | 73 +++++++++++- 14 files changed, 158 insertions(+), 102 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4b24429682e00..c99a5c162b549 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10884,6 +10884,7 @@ dependencies = [ "thiserror-ext", "tinyvec", "tokio-retry", + "tokio-util", "toml 0.8.12", "tower-layer", "tower-service", diff --git a/src/cmd/src/lib.rs b/src/cmd/src/lib.rs index 59ffda9e76557..2225b1d0cd530 100644 --- a/src/cmd/src/lib.rs +++ b/src/cmd/src/lib.rs @@ -37,27 +37,31 @@ risingwave_expr_impl::enable!(); // Entry point functions. -pub fn compute(opts: ComputeNodeOpts) { +pub fn compute(opts: ComputeNodeOpts) -> ! { init_risingwave_logger(LoggerSettings::from_opts(&opts)); - main_okk(risingwave_compute::start(opts)); + main_okk(|shutdown| risingwave_compute::start(opts, shutdown)); } -pub fn meta(opts: MetaNodeOpts) { +pub fn meta(opts: MetaNodeOpts) -> ! { init_risingwave_logger(LoggerSettings::from_opts(&opts)); - main_okk(risingwave_meta_node::start(opts)); + // TODO(shutdown): pass the shutdown token + main_okk(|_| risingwave_meta_node::start(opts)); } -pub fn frontend(opts: FrontendOpts) { +pub fn frontend(opts: FrontendOpts) -> ! { init_risingwave_logger(LoggerSettings::from_opts(&opts)); - main_okk(risingwave_frontend::start(opts)); + // TODO(shutdown): pass the shutdown token + main_okk(|_| risingwave_frontend::start(opts)); } -pub fn compactor(opts: CompactorOpts) { +pub fn compactor(opts: CompactorOpts) -> ! { init_risingwave_logger(LoggerSettings::from_opts(&opts)); - main_okk(risingwave_compactor::start(opts)); + // TODO(shutdown): pass the shutdown token + main_okk(|_| risingwave_compactor::start(opts)); } -pub fn ctl(opts: CtlOpts) { +pub fn ctl(opts: CtlOpts) -> ! { init_risingwave_logger(LoggerSettings::new("ctl").stderr(true)); - main_okk(risingwave_ctl::start(opts)); + // TODO(shutdown): pass the shutdown token + main_okk(|_| risingwave_ctl::start(opts)); } diff --git a/src/cmd_all/src/bin/risingwave.rs b/src/cmd_all/src/bin/risingwave.rs index 5a4e55c981043..13c73217b77cc 100644 --- a/src/cmd_all/src/bin/risingwave.rs +++ b/src/cmd_all/src/bin/risingwave.rs @@ -110,7 +110,7 @@ enum Component { impl Component { /// Start the component from the given `args` without `argv[0]`. - fn start(self, matches: &ArgMatches) { + fn start(self, matches: &ArgMatches) -> ! { eprintln!("launching `{}`", self); fn parse_opts(matches: &ArgMatches) -> T { @@ -224,19 +224,20 @@ fn main() { component.start(&matches); } -fn standalone(opts: StandaloneOpts) { +fn standalone(opts: StandaloneOpts) -> ! { let opts = risingwave_cmd_all::parse_standalone_opt_args(&opts); let settings = risingwave_rt::LoggerSettings::from_opts(&opts) .with_target("risingwave_storage", Level::WARN) .with_thread_name(true); risingwave_rt::init_risingwave_logger(settings); - risingwave_rt::main_okk(risingwave_cmd_all::standalone(opts)).unwrap(); + // TODO(shutdown): pass the shutdown token + risingwave_rt::main_okk(|_| risingwave_cmd_all::standalone(opts)); } /// For single node, the internals are just a config mapping from its /// high level options to standalone mode node-level options. /// We will start a standalone instance, with all nodes in the same process. -fn single_node(opts: SingleNodeOpts) { +fn single_node(opts: SingleNodeOpts) -> ! { if env::var(TELEMETRY_CLUSTER_TYPE).is_err() { env::set_var(TELEMETRY_CLUSTER_TYPE, TELEMETRY_CLUSTER_TYPE_SINGLE_NODE); } @@ -245,7 +246,8 @@ fn single_node(opts: SingleNodeOpts) { .with_target("risingwave_storage", Level::WARN) .with_thread_name(true); risingwave_rt::init_risingwave_logger(settings); - risingwave_rt::main_okk(risingwave_cmd_all::standalone(opts)).unwrap(); + // TODO(shutdown): pass the shutdown token + risingwave_rt::main_okk(|_| risingwave_cmd_all::standalone(opts)); } #[cfg(test)] diff --git a/src/cmd_all/src/standalone.rs b/src/cmd_all/src/standalone.rs index cb009f124758f..ca7b551d7bb45 100644 --- a/src/cmd_all/src/standalone.rs +++ b/src/cmd_all/src/standalone.rs @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::Result; use clap::Parser; use risingwave_common::config::MetaBackend; use risingwave_common::util::meta_addr::MetaAddressStrategy; +use risingwave_common::util::tokio_util::sync::CancellationToken; use risingwave_compactor::CompactorOpts; use risingwave_compute::ComputeNodeOpts; use risingwave_frontend::FrontendOpts; @@ -183,7 +183,7 @@ pub async fn standalone( frontend_opts, compactor_opts, }: ParsedStandaloneOpts, -) -> Result<()> { +) { tracing::info!("launching Risingwave in standalone mode"); let mut is_in_memory = false; @@ -215,7 +215,11 @@ pub async fn standalone( } if let Some(opts) = compute_opts { tracing::info!("starting compute-node thread with cli args: {:?}", opts); - let _compute_handle = tokio::spawn(async move { risingwave_compute::start(opts).await }); + // TODO(shutdown): pass the shutdown token + let _compute_handle = + tokio::spawn( + async move { risingwave_compute::start(opts, CancellationToken::new()).await }, + ); } if let Some(opts) = frontend_opts.clone() { tracing::info!("starting frontend-node thread with cli args: {:?}", opts); @@ -265,8 +269,6 @@ It SHOULD NEVER be used in benchmarks and production environment!!!" // support it? signal::ctrl_c().await.unwrap(); tracing::info!("Ctrl+C received, now exiting"); - - Ok(()) } #[cfg(test)] diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index ae6f67faf3aac..3ae8fb38fcd5d 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -117,6 +117,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ "signal", ] } tokio-retry = "0.3" +tokio-util = { workspace = true } toml = "0.8" tracing = "0.1" tracing-futures = { version = "0.2", features = ["futures-03"] } diff --git a/src/common/src/util/mod.rs b/src/common/src/util/mod.rs index c8027ad46e381..20dac5906c91d 100644 --- a/src/common/src/util/mod.rs +++ b/src/common/src/util/mod.rs @@ -41,3 +41,4 @@ pub mod stream_graph_visitor; pub mod tracing; pub mod value_encoding; pub mod worker_util; +pub use tokio_util; diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 87f052470ea7d..7a3d2be65d1df 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -38,6 +38,7 @@ use risingwave_common::config::{AsyncStackTraceOption, MetricLevel, OverrideConf use risingwave_common::util::meta_addr::MetaAddressStrategy; use risingwave_common::util::resource_util::cpu::total_cpu_available; use risingwave_common::util::resource_util::memory::system_memory_available_bytes; +use risingwave_common::util::tokio_util::sync::CancellationToken; use serde::{Deserialize, Serialize}; /// If `total_memory_bytes` is not specified, the default memory limit will be set to @@ -198,7 +199,10 @@ fn validate_opts(opts: &ComputeNodeOpts) { use crate::server::compute_node_serve; /// Start compute node -pub fn start(opts: ComputeNodeOpts) -> Pin + Send>> { +pub fn start( + opts: ComputeNodeOpts, + shutdown: CancellationToken, +) -> Pin + Send>> { // WARNING: don't change the function signature. Making it `async fn` will cause // slow compile in release mode. Box::pin(async move { @@ -218,14 +222,7 @@ pub fn start(opts: ComputeNodeOpts) -> Pin + Send>> .unwrap(); tracing::info!("advertise addr is {}", advertise_addr); - let (join_handle_vec, _shutdown_send) = - compute_node_serve(listen_addr, advertise_addr, opts).await; - - tracing::info!("Server listening at {}", listen_addr); - - for join_handle in join_handle_vec { - join_handle.await.unwrap(); - } + compute_node_serve(listen_addr, advertise_addr, opts, shutdown).await; }) } diff --git a/src/compute/src/server.rs b/src/compute/src/server.rs index a28631a29bea6..7527b27b1b9f6 100644 --- a/src/compute/src/server.rs +++ b/src/compute/src/server.rs @@ -35,6 +35,7 @@ use risingwave_common::telemetry::manager::TelemetryManager; use risingwave_common::telemetry::telemetry_env_enabled; use risingwave_common::util::addr::HostAddr; use risingwave_common::util::pretty_bytes::convert; +use risingwave_common::util::tokio_util::sync::CancellationToken; use risingwave_common::{GIT_SHA, RW_VERSION}; use risingwave_common_heap_profiling::HeapProfiler; use risingwave_common_service::metrics_manager::MetricsManager; @@ -64,7 +65,6 @@ use risingwave_storage::opts::StorageOpts; use risingwave_storage::StateStoreImpl; use risingwave_stream::executor::monitor::global_streaming_metrics; use risingwave_stream::task::{LocalStreamManager, StreamEnvironment}; -use thiserror_ext::AsReport; use tokio::sync::oneshot::Sender; use tokio::task::JoinHandle; use tower::Layer; @@ -84,11 +84,14 @@ use crate::telemetry::ComputeTelemetryCreator; use crate::ComputeNodeOpts; /// Bootstraps the compute-node. +/// +/// Returns when the `shutdown` token is triggered. pub async fn compute_node_serve( listen_addr: SocketAddr, advertise_addr: HostAddr, opts: ComputeNodeOpts, -) -> (Vec>, Sender<()>) { + shutdown: CancellationToken, +) { // Load the configuration. let config = load_config(&opts.config_path, &opts); info!("Starting compute node",); @@ -168,6 +171,7 @@ pub async fn compute_node_serve( let worker_id = meta_client.worker_id(); info!("Assigned worker node id {}", worker_id); + // TODO(shutdown): remove this as there's no need to gracefully shutdown the sub-tasks. let mut sub_tasks: Vec<(JoinHandle<()>, Sender<()>)> = vec![]; // Initialize the metrics subsystem. let source_metrics = Arc::new(GLOBAL_SOURCE_METRICS.clone()); @@ -191,8 +195,6 @@ pub async fn compute_node_serve( hummock_metrics.clone(), )); - let mut join_handle_vec = vec![]; - let await_tree_config = match &config.streaming.async_stack_trace { AsyncStackTraceOption::Off => None, c => await_tree::ConfigBuilder::default() @@ -399,63 +401,39 @@ pub async fn compute_node_serve( #[cfg(not(madsim))] SpillOp::clean_spill_directory().await.unwrap(); - let (shutdown_send, mut shutdown_recv) = tokio::sync::oneshot::channel::<()>(); - let join_handle = tokio::spawn(async move { - tonic::transport::Server::builder() - .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE) - .initial_stream_window_size(STREAM_WINDOW_SIZE) - .http2_max_pending_accept_reset_streams(Some( - config.server.grpc_max_reset_stream as usize, - )) - .layer(TracingExtractLayer::new()) - // XXX: unlimit the max message size to allow arbitrary large SQL input. - .add_service(TaskServiceServer::new(batch_srv).max_decoding_message_size(usize::MAX)) - .add_service( - ExchangeServiceServer::new(exchange_srv).max_decoding_message_size(usize::MAX), - ) - .add_service({ - let await_tree_reg = stream_srv.mgr.await_tree_reg().cloned(); - let srv = - StreamServiceServer::new(stream_srv).max_decoding_message_size(usize::MAX); - #[cfg(madsim)] - { - srv - } - #[cfg(not(madsim))] - { - AwaitTreeMiddlewareLayer::new_optional(await_tree_reg).layer(srv) - } - }) - .add_service(MonitorServiceServer::new(monitor_srv)) - .add_service(ConfigServiceServer::new(config_srv)) - .add_service(HealthServer::new(health_srv)) - .monitored_serve_with_shutdown( - listen_addr, - "grpc-compute-node-service", - TcpConfig { - tcp_nodelay: true, - keepalive_duration: None, - }, - async move { - tokio::select! { - _ = tokio::signal::ctrl_c() => {}, - _ = &mut shutdown_recv => { - for (join_handle, shutdown_sender) in sub_tasks { - if let Err(_err) = shutdown_sender.send(()) { - tracing::warn!("Failed to send shutdown"); - continue; - } - if let Err(err) = join_handle.await { - tracing::warn!(error = %err.as_report(), "Failed to join shutdown"); - } - } - }, - } - }, - ) - .await; - }); - join_handle_vec.push(join_handle); + let server = tonic::transport::Server::builder() + .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE) + .initial_stream_window_size(STREAM_WINDOW_SIZE) + .http2_max_pending_accept_reset_streams(Some(config.server.grpc_max_reset_stream as usize)) + .layer(TracingExtractLayer::new()) + // XXX: unlimit the max message size to allow arbitrary large SQL input. + .add_service(TaskServiceServer::new(batch_srv).max_decoding_message_size(usize::MAX)) + .add_service(ExchangeServiceServer::new(exchange_srv).max_decoding_message_size(usize::MAX)) + .add_service({ + let await_tree_reg = stream_srv.mgr.await_tree_reg().cloned(); + let srv = StreamServiceServer::new(stream_srv).max_decoding_message_size(usize::MAX); + #[cfg(madsim)] + { + srv + } + #[cfg(not(madsim))] + { + AwaitTreeMiddlewareLayer::new_optional(await_tree_reg).layer(srv) + } + }) + .add_service(MonitorServiceServer::new(monitor_srv)) + .add_service(ConfigServiceServer::new(config_srv)) + .add_service(HealthServer::new(health_srv)) + .monitored_serve_with_shutdown( + listen_addr, + "grpc-compute-node-service", + TcpConfig { + tcp_nodelay: true, + keepalive_duration: None, + }, + shutdown.clone().cancelled_owned(), + ); + let _server_handle = tokio::spawn(server); // Boot metrics service. if config.server.metrics_level > MetricLevel::Disabled { @@ -464,8 +442,15 @@ pub async fn compute_node_serve( // All set, let the meta service know we're ready. meta_client.activate(&advertise_addr).await.unwrap(); + // Wait for the shutdown signal. + let _ = shutdown.cancelled().await; + + // TODO(shutdown): gracefully unregister from the meta service. - (join_handle_vec, shutdown_send) + // NOTE(shutdown): We can't simply join the tonic server here because it only returns when all + // existing connections are closed, while we have long-running streaming calls that never + // close. From the other side, there's also no need to gracefully shutdown them if we have + // unregistered from the meta service. } /// Check whether the compute node has enough memory to perform computing tasks. Apart from storage, diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 3d8bf618eca58..b00c8013c70e4 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -160,7 +160,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ tokio-postgres = { version = "0.7", features = ["with-uuid-1"] } tokio-retry = "0.3" tokio-stream = "0.1" -tokio-util = { version = "0.7", features = ["codec", "io"] } +tokio-util = { workspace = true, features = ["codec", "io"] } tonic = { workspace = true } tracing = "0.1" typed-builder = "^0.18" diff --git a/src/tests/compaction_test/src/bin/compaction.rs b/src/tests/compaction_test/src/bin/compaction.rs index 62f47b568f1e9..604a6c607d30a 100644 --- a/src/tests/compaction_test/src/bin/compaction.rs +++ b/src/tests/compaction_test/src/bin/compaction.rs @@ -22,5 +22,5 @@ fn main() { risingwave_rt::init_risingwave_logger(risingwave_rt::LoggerSettings::default()); - risingwave_rt::main_okk(risingwave_compaction_test::start(opts)) + risingwave_rt::main_okk(|_| risingwave_compaction_test::start(opts)) } diff --git a/src/tests/compaction_test/src/bin/delete_range.rs b/src/tests/compaction_test/src/bin/delete_range.rs index f154314fb5c8a..1861ca1b9b03f 100644 --- a/src/tests/compaction_test/src/bin/delete_range.rs +++ b/src/tests/compaction_test/src/bin/delete_range.rs @@ -23,5 +23,5 @@ fn main() { risingwave_rt::init_risingwave_logger(risingwave_rt::LoggerSettings::default()); - risingwave_rt::main_okk(risingwave_compaction_test::start_delete_range(opts)) + risingwave_rt::main_okk(|_| risingwave_compaction_test::start_delete_range(opts)) } diff --git a/src/tests/simulation/src/cluster.rs b/src/tests/simulation/src/cluster.rs index 8caf8b52931ca..351e45279fd5e 100644 --- a/src/tests/simulation/src/cluster.rs +++ b/src/tests/simulation/src/cluster.rs @@ -32,6 +32,7 @@ use itertools::Itertools; use madsim::runtime::{Handle, NodeHandle}; use rand::seq::IteratorRandom; use rand::Rng; +use risingwave_common::util::tokio_util::sync::CancellationToken; #[cfg(madsim)] use risingwave_object_store::object::sim::SimServer as ObjectStoreSimServer; use risingwave_pb::common::WorkerNode; @@ -500,7 +501,7 @@ impl Cluster { .name(format!("compute-{i}")) .ip([192, 168, 3, i as u8].into()) .cores(conf.compute_node_cores) - .init(move || risingwave_compute::start(opts.clone())) + .init(move || risingwave_compute::start(opts.clone(), CancellationToken::new())) .build(); } diff --git a/src/tests/simulation/tests/integration_tests/batch/mod.rs b/src/tests/simulation/tests/integration_tests/batch/mod.rs index 25c690cfdcbf8..1ee5132884a2a 100644 --- a/src/tests/simulation/tests/integration_tests/batch/mod.rs +++ b/src/tests/simulation/tests/integration_tests/batch/mod.rs @@ -18,6 +18,7 @@ use std::io::Write; use clap::Parser; use itertools::Itertools; +use risingwave_common::util::tokio_util::sync::CancellationToken; use risingwave_simulation::cluster::{Cluster, ConfigPath, Configuration, Session}; use tokio::time::Duration; @@ -44,7 +45,7 @@ fn create_compute_node(cluster: &Cluster, idx: usize, role: &str) { .name(format!("compute-{idx}")) .ip([192, 168, 3, idx as u8].into()) .cores(config.compute_node_cores) - .init(move || risingwave_compute::start(opts.clone())) + .init(move || risingwave_compute::start(opts.clone(), CancellationToken::new())) .build(); } diff --git a/src/utils/runtime/src/lib.rs b/src/utils/runtime/src/lib.rs index b4501d991f05b..6418c18ea103a 100644 --- a/src/utils/runtime/src/lib.rs +++ b/src/utils/runtime/src/lib.rs @@ -19,8 +19,13 @@ #![feature(panic_update_hook)] #![feature(let_chains)] +#![feature(exitcode_exit_method)] + +use std::pin::pin; +use std::process::ExitCode; use futures::Future; +use risingwave_common::util::tokio_util::sync::CancellationToken; mod logger; pub use logger::*; @@ -33,7 +38,24 @@ use prof::*; /// Start RisingWave components with configs from environment variable. /// -/// Currently, the following env variables will be read: +/// # Shutdown on Ctrl-C +/// +/// The given closure `f` will take a [`CancellationToken`] as an argument. When a `SIGINT` signal +/// is received (typically by pressing `Ctrl-C`), [`CancellationToken::cancel`] will be called to +/// notify all subscribers to shutdown. You can use [`.cancelled()`](CancellationToken::cancelled) +/// to get notified on this. +/// +/// Users may also send a second `SIGINT` signal to force shutdown. In this case, this function +/// will exit the process with a non-zero exit code. +/// +/// When `f` returns, this function will assume that the component has finished its work and it's +/// safe to exit. Therefore, this function will exit the process with exit code 0 **without** +/// waiting for background tasks to finish. In other words, it's the responsibility of `f` to +/// ensure that all essential background tasks are finished before returning. +/// +/// # Environment variables +/// +/// Currently, the following environment variables will be read and used to configure the runtime. /// /// * `RW_WORKER_THREADS` (alias of `TOKIO_WORKER_THREADS`): number of tokio worker threads. If /// not set, it will be decided by tokio. Note that this can still be overridden by per-module @@ -42,9 +64,10 @@ use prof::*; /// debug mode, and disable in release mode. /// * `RW_PROFILE_PATH`: the path to generate flamegraph. If set, then profiling is automatically /// enabled. -pub fn main_okk(f: F) -> F::Output +pub fn main_okk(f: F) -> ! where - F: Future + Send + 'static, + F: FnOnce(CancellationToken) -> Fut, + Fut: Future + Send + 'static, { set_panic_hook(); @@ -73,10 +96,48 @@ where spawn_prof_thread(profile_path); } - tokio::runtime::Builder::new_multi_thread() + let future_with_shutdown = async move { + let shutdown = CancellationToken::new(); + let mut fut = pin!(f(shutdown.clone())); + + tokio::select! { + biased; + result = tokio::signal::ctrl_c() => { + result.expect("failed to receive ctrl-c signal"); + tracing::info!("received ctrl-c, shutting down... (press ctrl-c again to force shutdown)"); + + // Send shutdown signal. + shutdown.cancel(); + + // While waiting for the future to finish, listen for the second ctrl-c signal. + tokio::select! { + biased; + result = tokio::signal::ctrl_c() => { + result.expect("failed to receive ctrl-c signal"); + tracing::warn!("forced shutdown"); + + // Directly exit the process **here** instead of returning from the future, since + // we don't even want to run destructors but only exit as soon as possible. + ExitCode::FAILURE.exit_process(); + } + _ = &mut fut => {}, + } + } + _ = &mut fut => {}, + } + }; + + let runtime = tokio::runtime::Builder::new_multi_thread() .thread_name("rw-main") .enable_all() .build() - .unwrap() - .block_on(f) + .unwrap(); + + runtime.block_on(future_with_shutdown); + + // Shutdown the runtime and exit the process, without waiting for background tasks to finish. + // See the doc on this function for more details. + // TODO(shutdown): is it necessary to shutdown here as we're going to exit? + runtime.shutdown_background(); + ExitCode::SUCCESS.exit_process(); }