Skip to content

Commit

Permalink
refactor: demonstrate graceful shutdown on compute node (#17533)
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Jul 4, 2024
1 parent 3e7418c commit 47ece85
Show file tree
Hide file tree
Showing 14 changed files with 158 additions and 102 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 14 additions & 10 deletions src/cmd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,31 @@ risingwave_expr_impl::enable!();

// Entry point functions.

pub fn compute(opts: ComputeNodeOpts) {
pub fn compute(opts: ComputeNodeOpts) -> ! {
init_risingwave_logger(LoggerSettings::from_opts(&opts));
main_okk(risingwave_compute::start(opts));
main_okk(|shutdown| risingwave_compute::start(opts, shutdown));
}

pub fn meta(opts: MetaNodeOpts) {
pub fn meta(opts: MetaNodeOpts) -> ! {
init_risingwave_logger(LoggerSettings::from_opts(&opts));
main_okk(risingwave_meta_node::start(opts));
// TODO(shutdown): pass the shutdown token
main_okk(|_| risingwave_meta_node::start(opts));
}

pub fn frontend(opts: FrontendOpts) {
pub fn frontend(opts: FrontendOpts) -> ! {
init_risingwave_logger(LoggerSettings::from_opts(&opts));
main_okk(risingwave_frontend::start(opts));
// TODO(shutdown): pass the shutdown token
main_okk(|_| risingwave_frontend::start(opts));
}

pub fn compactor(opts: CompactorOpts) {
pub fn compactor(opts: CompactorOpts) -> ! {
init_risingwave_logger(LoggerSettings::from_opts(&opts));
main_okk(risingwave_compactor::start(opts));
// TODO(shutdown): pass the shutdown token
main_okk(|_| risingwave_compactor::start(opts));
}

pub fn ctl(opts: CtlOpts) {
pub fn ctl(opts: CtlOpts) -> ! {
init_risingwave_logger(LoggerSettings::new("ctl").stderr(true));
main_okk(risingwave_ctl::start(opts));
// TODO(shutdown): pass the shutdown token
main_okk(|_| risingwave_ctl::start(opts));
}
12 changes: 7 additions & 5 deletions src/cmd_all/src/bin/risingwave.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ enum Component {

impl Component {
/// Start the component from the given `args` without `argv[0]`.
fn start(self, matches: &ArgMatches) {
fn start(self, matches: &ArgMatches) -> ! {
eprintln!("launching `{}`", self);

fn parse_opts<T: FromArgMatches>(matches: &ArgMatches) -> T {
Expand Down Expand Up @@ -224,19 +224,20 @@ fn main() {
component.start(&matches);
}

fn standalone(opts: StandaloneOpts) {
fn standalone(opts: StandaloneOpts) -> ! {
let opts = risingwave_cmd_all::parse_standalone_opt_args(&opts);
let settings = risingwave_rt::LoggerSettings::from_opts(&opts)
.with_target("risingwave_storage", Level::WARN)
.with_thread_name(true);
risingwave_rt::init_risingwave_logger(settings);
risingwave_rt::main_okk(risingwave_cmd_all::standalone(opts)).unwrap();
// TODO(shutdown): pass the shutdown token
risingwave_rt::main_okk(|_| risingwave_cmd_all::standalone(opts));
}

/// For single-node mode, the internals are just a config mapping from its
/// high-level options to standalone-mode node-level options.
/// We will start a standalone instance, with all nodes in the same process.
fn single_node(opts: SingleNodeOpts) {
fn single_node(opts: SingleNodeOpts) -> ! {
if env::var(TELEMETRY_CLUSTER_TYPE).is_err() {
env::set_var(TELEMETRY_CLUSTER_TYPE, TELEMETRY_CLUSTER_TYPE_SINGLE_NODE);
}
Expand All @@ -245,7 +246,8 @@ fn single_node(opts: SingleNodeOpts) {
.with_target("risingwave_storage", Level::WARN)
.with_thread_name(true);
risingwave_rt::init_risingwave_logger(settings);
risingwave_rt::main_okk(risingwave_cmd_all::standalone(opts)).unwrap();
// TODO(shutdown): pass the shutdown token
risingwave_rt::main_okk(|_| risingwave_cmd_all::standalone(opts));
}

#[cfg(test)]
Expand Down
12 changes: 7 additions & 5 deletions src/cmd_all/src/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Result;
use clap::Parser;
use risingwave_common::config::MetaBackend;
use risingwave_common::util::meta_addr::MetaAddressStrategy;
use risingwave_common::util::tokio_util::sync::CancellationToken;
use risingwave_compactor::CompactorOpts;
use risingwave_compute::ComputeNodeOpts;
use risingwave_frontend::FrontendOpts;
Expand Down Expand Up @@ -183,7 +183,7 @@ pub async fn standalone(
frontend_opts,
compactor_opts,
}: ParsedStandaloneOpts,
) -> Result<()> {
) {
tracing::info!("launching Risingwave in standalone mode");

let mut is_in_memory = false;
Expand Down Expand Up @@ -215,7 +215,11 @@ pub async fn standalone(
}
if let Some(opts) = compute_opts {
tracing::info!("starting compute-node thread with cli args: {:?}", opts);
let _compute_handle = tokio::spawn(async move { risingwave_compute::start(opts).await });
// TODO(shutdown): pass the shutdown token
let _compute_handle =
tokio::spawn(
async move { risingwave_compute::start(opts, CancellationToken::new()).await },
);
}
if let Some(opts) = frontend_opts.clone() {
tracing::info!("starting frontend-node thread with cli args: {:?}", opts);
Expand Down Expand Up @@ -265,8 +269,6 @@ It SHOULD NEVER be used in benchmarks and production environment!!!"
// support it?
signal::ctrl_c().await.unwrap();
tracing::info!("Ctrl+C received, now exiting");

Ok(())
}

#[cfg(test)]
Expand Down
1 change: 1 addition & 0 deletions src/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [
"signal",
] }
tokio-retry = "0.3"
tokio-util = { workspace = true }
toml = "0.8"
tracing = "0.1"
tracing-futures = { version = "0.2", features = ["futures-03"] }
Expand Down
1 change: 1 addition & 0 deletions src/common/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ pub mod stream_graph_visitor;
pub mod tracing;
pub mod value_encoding;
pub mod worker_util;
pub use tokio_util;
15 changes: 6 additions & 9 deletions src/compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use risingwave_common::config::{AsyncStackTraceOption, MetricLevel, OverrideConf
use risingwave_common::util::meta_addr::MetaAddressStrategy;
use risingwave_common::util::resource_util::cpu::total_cpu_available;
use risingwave_common::util::resource_util::memory::system_memory_available_bytes;
use risingwave_common::util::tokio_util::sync::CancellationToken;
use serde::{Deserialize, Serialize};

/// If `total_memory_bytes` is not specified, the default memory limit will be set to
Expand Down Expand Up @@ -198,7 +199,10 @@ fn validate_opts(opts: &ComputeNodeOpts) {
use crate::server::compute_node_serve;

/// Start compute node
pub fn start(opts: ComputeNodeOpts) -> Pin<Box<dyn Future<Output = ()> + Send>> {
pub fn start(
opts: ComputeNodeOpts,
shutdown: CancellationToken,
) -> Pin<Box<dyn Future<Output = ()> + Send>> {
// WARNING: don't change the function signature. Making it `async fn` will cause
// slow compile in release mode.
Box::pin(async move {
Expand All @@ -218,14 +222,7 @@ pub fn start(opts: ComputeNodeOpts) -> Pin<Box<dyn Future<Output = ()> + Send>>
.unwrap();
tracing::info!("advertise addr is {}", advertise_addr);

let (join_handle_vec, _shutdown_send) =
compute_node_serve(listen_addr, advertise_addr, opts).await;

tracing::info!("Server listening at {}", listen_addr);

for join_handle in join_handle_vec {
join_handle.await.unwrap();
}
compute_node_serve(listen_addr, advertise_addr, opts, shutdown).await;
})
}

Expand Down
109 changes: 47 additions & 62 deletions src/compute/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use risingwave_common::telemetry::manager::TelemetryManager;
use risingwave_common::telemetry::telemetry_env_enabled;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::pretty_bytes::convert;
use risingwave_common::util::tokio_util::sync::CancellationToken;
use risingwave_common::{GIT_SHA, RW_VERSION};
use risingwave_common_heap_profiling::HeapProfiler;
use risingwave_common_service::metrics_manager::MetricsManager;
Expand Down Expand Up @@ -64,7 +65,6 @@ use risingwave_storage::opts::StorageOpts;
use risingwave_storage::StateStoreImpl;
use risingwave_stream::executor::monitor::global_streaming_metrics;
use risingwave_stream::task::{LocalStreamManager, StreamEnvironment};
use thiserror_ext::AsReport;
use tokio::sync::oneshot::Sender;
use tokio::task::JoinHandle;
use tower::Layer;
Expand All @@ -84,11 +84,14 @@ use crate::telemetry::ComputeTelemetryCreator;
use crate::ComputeNodeOpts;

/// Bootstraps the compute-node.
///
/// Returns when the `shutdown` token is triggered.
pub async fn compute_node_serve(
listen_addr: SocketAddr,
advertise_addr: HostAddr,
opts: ComputeNodeOpts,
) -> (Vec<JoinHandle<()>>, Sender<()>) {
shutdown: CancellationToken,
) {
// Load the configuration.
let config = load_config(&opts.config_path, &opts);
info!("Starting compute node",);
Expand Down Expand Up @@ -168,6 +171,7 @@ pub async fn compute_node_serve(
let worker_id = meta_client.worker_id();
info!("Assigned worker node id {}", worker_id);

// TODO(shutdown): remove this as there's no need to gracefully shut down the sub-tasks.
let mut sub_tasks: Vec<(JoinHandle<()>, Sender<()>)> = vec![];
// Initialize the metrics subsystem.
let source_metrics = Arc::new(GLOBAL_SOURCE_METRICS.clone());
Expand All @@ -191,8 +195,6 @@ pub async fn compute_node_serve(
hummock_metrics.clone(),
));

let mut join_handle_vec = vec![];

let await_tree_config = match &config.streaming.async_stack_trace {
AsyncStackTraceOption::Off => None,
c => await_tree::ConfigBuilder::default()
Expand Down Expand Up @@ -399,63 +401,39 @@ pub async fn compute_node_serve(
#[cfg(not(madsim))]
SpillOp::clean_spill_directory().await.unwrap();

let (shutdown_send, mut shutdown_recv) = tokio::sync::oneshot::channel::<()>();
let join_handle = tokio::spawn(async move {
tonic::transport::Server::builder()
.initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
.initial_stream_window_size(STREAM_WINDOW_SIZE)
.http2_max_pending_accept_reset_streams(Some(
config.server.grpc_max_reset_stream as usize,
))
.layer(TracingExtractLayer::new())
// XXX: unlimit the max message size to allow arbitrary large SQL input.
.add_service(TaskServiceServer::new(batch_srv).max_decoding_message_size(usize::MAX))
.add_service(
ExchangeServiceServer::new(exchange_srv).max_decoding_message_size(usize::MAX),
)
.add_service({
let await_tree_reg = stream_srv.mgr.await_tree_reg().cloned();
let srv =
StreamServiceServer::new(stream_srv).max_decoding_message_size(usize::MAX);
#[cfg(madsim)]
{
srv
}
#[cfg(not(madsim))]
{
AwaitTreeMiddlewareLayer::new_optional(await_tree_reg).layer(srv)
}
})
.add_service(MonitorServiceServer::new(monitor_srv))
.add_service(ConfigServiceServer::new(config_srv))
.add_service(HealthServer::new(health_srv))
.monitored_serve_with_shutdown(
listen_addr,
"grpc-compute-node-service",
TcpConfig {
tcp_nodelay: true,
keepalive_duration: None,
},
async move {
tokio::select! {
_ = tokio::signal::ctrl_c() => {},
_ = &mut shutdown_recv => {
for (join_handle, shutdown_sender) in sub_tasks {
if let Err(_err) = shutdown_sender.send(()) {
tracing::warn!("Failed to send shutdown");
continue;
}
if let Err(err) = join_handle.await {
tracing::warn!(error = %err.as_report(), "Failed to join shutdown");
}
}
},
}
},
)
.await;
});
join_handle_vec.push(join_handle);
let server = tonic::transport::Server::builder()
.initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE)
.initial_stream_window_size(STREAM_WINDOW_SIZE)
.http2_max_pending_accept_reset_streams(Some(config.server.grpc_max_reset_stream as usize))
.layer(TracingExtractLayer::new())
// XXX: unlimit the max message size to allow arbitrary large SQL input.
.add_service(TaskServiceServer::new(batch_srv).max_decoding_message_size(usize::MAX))
.add_service(ExchangeServiceServer::new(exchange_srv).max_decoding_message_size(usize::MAX))
.add_service({
let await_tree_reg = stream_srv.mgr.await_tree_reg().cloned();
let srv = StreamServiceServer::new(stream_srv).max_decoding_message_size(usize::MAX);
#[cfg(madsim)]
{
srv
}
#[cfg(not(madsim))]
{
AwaitTreeMiddlewareLayer::new_optional(await_tree_reg).layer(srv)
}
})
.add_service(MonitorServiceServer::new(monitor_srv))
.add_service(ConfigServiceServer::new(config_srv))
.add_service(HealthServer::new(health_srv))
.monitored_serve_with_shutdown(
listen_addr,
"grpc-compute-node-service",
TcpConfig {
tcp_nodelay: true,
keepalive_duration: None,
},
shutdown.clone().cancelled_owned(),
);
let _server_handle = tokio::spawn(server);

// Boot metrics service.
if config.server.metrics_level > MetricLevel::Disabled {
Expand All @@ -464,8 +442,15 @@ pub async fn compute_node_serve(

// All set, let the meta service know we're ready.
meta_client.activate(&advertise_addr).await.unwrap();
// Wait for the shutdown signal.
let _ = shutdown.cancelled().await;

// TODO(shutdown): gracefully unregister from the meta service.

(join_handle_vec, shutdown_send)
// NOTE(shutdown): We can't simply join the tonic server here because it only returns when all
// existing connections are closed, while we have long-running streaming calls that never
// close. On the other hand, there's also no need to gracefully shut them down once we have
// unregistered from the meta service.
}

/// Check whether the compute node has enough memory to perform computing tasks. Apart from storage,
Expand Down
2 changes: 1 addition & 1 deletion src/connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [
tokio-postgres = { version = "0.7", features = ["with-uuid-1"] }
tokio-retry = "0.3"
tokio-stream = "0.1"
tokio-util = { version = "0.7", features = ["codec", "io"] }
tokio-util = { workspace = true, features = ["codec", "io"] }
tonic = { workspace = true }
tracing = "0.1"
typed-builder = "^0.18"
Expand Down
2 changes: 1 addition & 1 deletion src/tests/compaction_test/src/bin/compaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ fn main() {

risingwave_rt::init_risingwave_logger(risingwave_rt::LoggerSettings::default());

risingwave_rt::main_okk(risingwave_compaction_test::start(opts))
risingwave_rt::main_okk(|_| risingwave_compaction_test::start(opts))
}
2 changes: 1 addition & 1 deletion src/tests/compaction_test/src/bin/delete_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ fn main() {

risingwave_rt::init_risingwave_logger(risingwave_rt::LoggerSettings::default());

risingwave_rt::main_okk(risingwave_compaction_test::start_delete_range(opts))
risingwave_rt::main_okk(|_| risingwave_compaction_test::start_delete_range(opts))
}
3 changes: 2 additions & 1 deletion src/tests/simulation/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use itertools::Itertools;
use madsim::runtime::{Handle, NodeHandle};
use rand::seq::IteratorRandom;
use rand::Rng;
use risingwave_common::util::tokio_util::sync::CancellationToken;
#[cfg(madsim)]
use risingwave_object_store::object::sim::SimServer as ObjectStoreSimServer;
use risingwave_pb::common::WorkerNode;
Expand Down Expand Up @@ -500,7 +501,7 @@ impl Cluster {
.name(format!("compute-{i}"))
.ip([192, 168, 3, i as u8].into())
.cores(conf.compute_node_cores)
.init(move || risingwave_compute::start(opts.clone()))
.init(move || risingwave_compute::start(opts.clone(), CancellationToken::new()))
.build();
}

Expand Down
Loading

0 comments on commit 47ece85

Please sign in to comment.