From 28d42d442b5c2301dffb32d165986be68c3484ac Mon Sep 17 00:00:00 2001
From: Blaise Bruer
Date: Tue, 29 Oct 2024 13:25:33 -0500
Subject: [PATCH] Support graceful shutdown for network connections

Waits for network connections to close during graceful shutdown.
---
 nativelink-macro/src/lib.rs                     |   1 +
 nativelink-util/BUILD.bazel                     |   1 +
 nativelink-util/src/lib.rs                      |   1 +
 nativelink-util/src/shutdown_manager.rs         | 146 ++++++++++++++++++
 nativelink-worker/src/local_worker.rs           |  21 ++-
 .../src/running_actions_manager.rs              |  10 +-
 .../tests/utils/local_worker_test_utils.rs      |  11 +-
 .../utils/mock_running_actions_manager.rs       |   8 +-
 src/bin/nativelink.rs                           | 101 ++++++------
 9 files changed, 219 insertions(+), 81 deletions(-)
 create mode 100644 nativelink-util/src/shutdown_manager.rs

diff --git a/nativelink-macro/src/lib.rs b/nativelink-macro/src/lib.rs
index f37175569..9d5d3a77e 100644
--- a/nativelink-macro/src/lib.rs
+++ b/nativelink-macro/src/lib.rs
@@ -34,6 +34,7 @@ pub fn nativelink_test(attr: TokenStream, item: TokenStream) -> TokenStream {
         #[allow(clippy::disallowed_methods)]
         #[tokio::test(#attr)]
         async fn #fn_name(#fn_inputs) #fn_output {
+            nativelink_util::shutdown_manager::ShutdownManager::init(&tokio::runtime::Handle::current());
             // Error means already initialized, which is ok.
             let _ = nativelink_util::init_tracing();
             // If already set it's ok.
diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel
index ac17063f1..83e41f008 100644
--- a/nativelink-util/BUILD.bazel
+++ b/nativelink-util/BUILD.bazel
@@ -30,6 +30,7 @@ rust_library(
         "src/proto_stream_utils.rs",
         "src/resource_info.rs",
         "src/retry.rs",
+        "src/shutdown_manager.rs",
         "src/store_trait.rs",
        "src/task.rs",
        "src/tls_utils.rs",
diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs
index 17edbf700..9bf74f7bf 100644
--- a/nativelink-util/src/lib.rs
+++ b/nativelink-util/src/lib.rs
@@ -32,6 +32,7 @@ pub mod platform_properties;
 pub mod proto_stream_utils;
 pub mod resource_info;
 pub mod retry;
+pub mod shutdown_manager;
 pub mod store_trait;
 pub mod task;
 pub mod tls_utils;
diff --git a/nativelink-util/src/shutdown_manager.rs b/nativelink-util/src/shutdown_manager.rs
new file mode 100644
index 000000000..9170e410f
--- /dev/null
+++ b/nativelink-util/src/shutdown_manager.rs
@@ -0,0 +1,146 @@
+// Copyright 2024 The NativeLink Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
+use parking_lot::Mutex;
+use tokio::runtime::Handle;
+#[cfg(target_family = "unix")]
+use tokio::signal::unix::{signal, SignalKind};
+use tokio::sync::{broadcast, oneshot};
+use tracing::{event, Level};
+
+static SHUTDOWN_MANAGER: ShutdownManager = ShutdownManager {
+    is_shutting_down: AtomicBool::new(false),
+    shutdown_tx: Mutex::new(None), // Will be initialized in `init`.
+};
+
+/// Broadcast Channel Capacity
+/// Note: The actual capacity may be greater than the provided capacity.
+const BROADCAST_CAPACITY: usize = 1;
+
+/// ShutdownManager is a singleton that manages the shutdown of the
+/// application. Services can register to be notified when a graceful
+/// shutdown is initiated using [`ShutdownManager::wait_for_shutdown`].
+/// When the future returned by [`ShutdownManager::wait_for_shutdown`]
+/// completes, the caller is handed a [`ShutdownGuard`], which must be
+/// held until the caller has completed its shutdown procedure. Once
+/// the caller has completed its shutdown procedure, it must drop the
+/// [`ShutdownGuard`]. When all [`ShutdownGuard`]s have been dropped,
+/// the application will exit.
+pub struct ShutdownManager {
+    is_shutting_down: AtomicBool,
+    shutdown_tx: Mutex<Option<broadcast::Sender<Arc<oneshot::Sender<()>>>>>,
+}
+
+impl ShutdownManager {
+    #[allow(clippy::disallowed_methods)]
+    pub fn init(runtime: &Handle) {
+        let (shutdown_tx, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
+        *SHUTDOWN_MANAGER.shutdown_tx.lock() = Some(shutdown_tx);
+
+        runtime.spawn(async move {
+            tokio::signal::ctrl_c()
+                .await
+                .expect("Failed to listen to SIGINT");
+            event!(Level::WARN, "User terminated process via SIGINT");
+            std::process::exit(130);
+        });
+
+        #[cfg(target_family = "unix")]
+        {
+            runtime.spawn(async move {
+                signal(SignalKind::terminate())
+                    .expect("Failed to listen to SIGTERM")
+                    .recv()
+                    .await;
+                event!(Level::WARN, "Received SIGTERM, beginning shutdown.");
+                Self::graceful_shutdown();
+            });
+        }
+    }
+
+    pub fn is_shutting_down() -> bool {
+        SHUTDOWN_MANAGER.is_shutting_down.load(Ordering::Acquire)
+    }
+
+    #[allow(clippy::disallowed_methods)]
+    fn graceful_shutdown() {
+        if SHUTDOWN_MANAGER
+            .is_shutting_down
+            .swap(true, Ordering::Release)
+        {
+            event!(Level::WARN, "Shutdown already in progress.");
+            return;
+        }
+        tokio::spawn(async move {
+            let (complete_tx, complete_rx) = oneshot::channel::<()>();
+            let shutdown_guard = Arc::new(complete_tx);
+            {
+                let shutdown_tx_lock = SHUTDOWN_MANAGER.shutdown_tx.lock();
+                // No need to check result of send, since it will only fail if
+                // all receivers have been dropped, in which case it means we
+                // can safely shut down.
+                let _ = shutdown_tx_lock
+                    .as_ref()
+                    .expect("ShutdownManager was never initialized")
+                    .send(shutdown_guard);
+            }
+            // It is impossible for the result to be anything but Err(RecvError),
+            // which means all receivers have been dropped and we can safely shut down.
+            let _ = complete_rx.await;
+            event!(Level::WARN, "All services gracefully shut down.");
+            std::process::exit(143);
+        });
+    }
+
+    pub async fn wait_for_shutdown(service_name: impl Into<String>) -> ShutdownGuard {
+        let service_name = service_name.into();
+        let mut shutdown_receiver = SHUTDOWN_MANAGER
+            .shutdown_tx
+            .lock()
+            .as_ref()
+            .expect("ShutdownManager was never initialized")
+            .subscribe();
+        let sender = shutdown_receiver
+            .recv()
+            .await
+            .expect("Shutdown sender dropped. This should never happen.");
+        event!(
+            Level::INFO,
+            "Service {service_name} has been notified of shutdown request"
+        );
+        ShutdownGuard {
+            service_name,
+            _guard: sender,
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct ShutdownGuard {
+    service_name: String,
+    _guard: Arc<oneshot::Sender<()>>,
+}
+
+impl Drop for ShutdownGuard {
+    fn drop(&mut self) {
+        event!(
+            Level::INFO,
+            "Service {} has completed shutdown.",
+            self.service_name
+        );
+    }
+}
diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs
index 8a7f8b895..6fda27e42 100644
--- a/nativelink-worker/src/local_worker.rs
+++ b/nativelink-worker/src/local_worker.rs
@@ -36,10 +36,11 @@ use nativelink_util::common::fs;
 use nativelink_util::digest_hasher::{DigestHasherFunc, ACTIVE_HASHER_FUNC};
 use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
 use nativelink_util::origin_context::ActiveOriginContext;
+use nativelink_util::shutdown_manager::ShutdownManager;
 use nativelink_util::store_trait::Store;
 use nativelink_util::{spawn, tls_utils};
 use tokio::process;
-use tokio::sync::{broadcast, mpsc, oneshot};
+use tokio::sync::mpsc;
 use tokio::time::sleep;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tonic::Streaming;
@@ -168,7 +169,6 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
     async fn run(
         &mut self,
         update_for_worker_stream: Streaming<UpdateForWorker>,
-        shutdown_rx: &mut broadcast::Receiver<Arc<oneshot::Sender<()>>>,
     ) -> Result<(), Error> {
         // This big block of logic is designed to help simplify upstream components. Upstream
         // components can write standard futures that return a `Result<(), Error>` and this block
@@ -350,18 +350,17 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
                     futures.push(fut);
                 },
                 res = futures.next() => res.err_tip(|| "Keep-alive should always pending. Likely unable to send data to scheduler")??,
-                complete_msg = shutdown_rx.recv().fuse() => {
-                    event!(Level::WARN, "Worker loop reveiced shutdown signal. Shutting down worker...",);
+                complete_msg = ShutdownManager::wait_for_shutdown("LocalWorker").fuse() => {
+                    event!(Level::INFO, "Worker loop received shutdown signal. Shutting down worker...");
                     let mut grpc_client = self.grpc_client.clone();
                     let worker_id = self.worker_id.clone();
                     let running_actions_manager = self.running_actions_manager.clone();
-                    let complete_msg_clone = complete_msg.map_err(|e| make_err!(Code::Internal, "Failed to receive shutdown message: {e:?}"))?.clone();
                     let shutdown_future = async move {
                         if let Err(e) = grpc_client.going_away(GoingAwayRequest { worker_id }).await {
                             event!(Level::ERROR, "Failed to send GoingAwayRequest: {e}",);
                             return Err(e.into());
                         }
-                        running_actions_manager.complete_actions(complete_msg_clone).await;
+                        running_actions_manager.complete_actions(complete_msg).await;
                         Ok::<(), Error>(())
                     };
                     futures.push(shutdown_future.boxed());
@@ -526,10 +525,7 @@ impl LocalWorker {
     }
 
     #[instrument(skip(self), level = Level::INFO)]
-    pub async fn run(
-        mut self,
-        mut shutdown_rx: broadcast::Receiver<Arc<oneshot::Sender<()>>>,
-    ) -> Result<(), Error> {
+    pub async fn run(mut self) -> Result<(), Error> {
         let sleep_fn = self
             .sleep_fn
             .take()
@@ -575,7 +571,10 @@ impl LocalWorker {
         );
 
         // Now listen for connections and run all other services.
-        if let Err(err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
+        if let Err(err) = inner.run(update_for_worker_stream).await {
+            if ShutdownManager::is_shutting_down() {
+                return Ok(()); // Do not reconnect if we are shutting down.
+            }
             'no_more_actions: {
                 // Ensure there are no actions in transit before we try to kill
                 // all our actions.
diff --git a/nativelink-worker/src/running_actions_manager.rs b/nativelink-worker/src/running_actions_manager.rs
index b9c9c13aa..b834e439b 100644
--- a/nativelink-worker/src/running_actions_manager.rs
+++ b/nativelink-worker/src/running_actions_manager.rs
@@ -63,6 +63,7 @@ use nativelink_util::action_messages::{
 use nativelink_util::common::{fs, DigestInfo};
 use nativelink_util::digest_hasher::{DigestHasher, DigestHasherFunc};
 use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
+use nativelink_util::shutdown_manager::ShutdownGuard;
 use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo};
 use nativelink_util::{background_spawn, spawn, spawn_blocking};
 use parking_lot::Mutex;
@@ -1349,10 +1350,7 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {
         hasher: DigestHasherFunc,
     ) -> impl Future<Output = Result<(), Error>> + Send;
 
-    fn complete_actions(
-        &self,
-        complete_msg: Arc<oneshot::Sender<()>>,
-    ) -> impl Future<Output = ()> + Send;
+    fn complete_actions(&self, shutdown_guard: ShutdownGuard) -> impl Future<Output = ()> + Send;
 
     fn kill_all(&self) -> impl Future<Output = ()> + Send;
 
@@ -1885,9 +1883,9 @@ impl RunningActionsManager for RunningActionsManagerImpl {
     }
 
     // Waits for all running actions to complete and signals completion.
-    // Use the Arc<oneshot::Sender<()>> to signal the completion of the actions
+    // Use the [`ShutdownGuard`] to signal the completion of the actions.
     // Dropping the sender automatically notifies the process to terminate.
-    async fn complete_actions(&self, _complete_msg: Arc<oneshot::Sender<()>>) {
+    async fn complete_actions(&self, _shutdown_guard: ShutdownGuard) {
         let _ = self
             .action_done_tx
             .subscribe()
diff --git a/nativelink-worker/tests/utils/local_worker_test_utils.rs b/nativelink-worker/tests/utils/local_worker_test_utils.rs
index 6eb349ef4..63d2cb9cc 100644
--- a/nativelink-worker/tests/utils/local_worker_test_utils.rs
+++ b/nativelink-worker/tests/utils/local_worker_test_utils.rs
@@ -28,7 +28,7 @@ use nativelink_util::spawn;
 use nativelink_util::task::JoinHandleDropGuard;
 use nativelink_worker::local_worker::LocalWorker;
 use nativelink_worker::worker_api_client_wrapper::WorkerApiClientTrait;
-use tokio::sync::{broadcast, mpsc, oneshot};
+use tokio::sync::mpsc;
 use tonic::Status;
 use tonic::{
     codec::Codec, // Needed for .decoder().
@@ -40,10 +40,6 @@ use tonic::{
 
 use super::mock_running_actions_manager::MockRunningActionsManager;
 
-/// Broadcast Channel Capacity
-/// Note: The actual capacity may be greater than the provided capacity.
-const BROADCAST_CAPACITY: usize = 1;
-
 #[derive(Debug)]
 enum WorkerClientApiCalls {
     ConnectWorker(SupportedProperties),
@@ -198,11 +194,8 @@ pub async fn setup_local_worker_with_config(local_worker_config: LocalWorkerConf
         }),
         Box::new(move |_| Box::pin(async move { /* No sleep */ })),
     );
-    let (shutdown_tx_test, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
 
-    let drop_guard = spawn!("local_worker_spawn", async move {
-        worker.run(shutdown_tx_test.subscribe()).await
-    });
+    let drop_guard = spawn!("local_worker_spawn", worker.run());
 
     let (tx_stream, streaming_response) = setup_grpc_stream();
     TestContext {
diff --git a/nativelink-worker/tests/utils/mock_running_actions_manager.rs b/nativelink-worker/tests/utils/mock_running_actions_manager.rs
index 542ebd93b..add99403a 100644
--- a/nativelink-worker/tests/utils/mock_running_actions_manager.rs
+++ b/nativelink-worker/tests/utils/mock_running_actions_manager.rs
@@ -22,8 +22,9 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::
 use nativelink_util::action_messages::{ActionResult, OperationId};
 use nativelink_util::common::DigestInfo;
 use nativelink_util::digest_hasher::DigestHasherFunc;
+use nativelink_util::shutdown_manager::ShutdownGuard;
 use nativelink_worker::running_actions_manager::{Metrics, RunningAction, RunningActionsManager};
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::mpsc;
 
 #[derive(Debug)]
 enum RunningActionManagerCalls {
@@ -167,10 +168,7 @@ impl RunningActionsManager for MockRunningActionsManager {
         Ok(())
     }
 
-    fn complete_actions(
-        &self,
-        _complete_msg: Arc<oneshot::Sender<()>>,
-    ) -> impl Future<Output = ()> + Send {
+    fn complete_actions(&self, _shutdown_guard: ShutdownGuard) -> impl Future<Output = ()> + Send {
         future::ready(())
     }
 
diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs
index 1c61b9f75..08c1ca60b 100644
--- a/src/bin/nativelink.rs
+++ b/src/bin/nativelink.rs
@@ -53,6 +53,7 @@ use nativelink_util::health_utils::HealthRegistryBuilder;
 use nativelink_util::metrics_utils::{set_metrics_enabled_for_this_thread, Counter};
 use nativelink_util::operation_state_manager::ClientStateManager;
 use nativelink_util::origin_context::OriginContext;
+use nativelink_util::shutdown_manager::ShutdownManager;
 use nativelink_util::store_trait::{
     set_default_digest_size_health_check, DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG,
 };
@@ -67,9 +68,6 @@ use rustls_pemfile::{certs as extract_certs, crls as extract_crls};
 use scopeguard::guard;
 use tokio::net::TcpListener;
 use tokio::select;
-#[cfg(target_family = "unix")]
-use tokio::signal::unix::{signal, SignalKind};
-use tokio::sync::{broadcast, oneshot};
 use tokio_rustls::rustls::pki_types::{CertificateDer, CertificateRevocationListDer};
 use tokio_rustls::rustls::server::WebPkiClientVerifier;
 use tokio_rustls::rustls::{RootCertStore, ServerConfig as TlsServerConfig};
@@ -94,10 +92,6 @@ const DEFAULT_HEALTH_STATUS_CHECK_PATH: &str = "/status";
 /// Name of environment variable to disable metrics.
 const METRICS_DISABLE_ENV: &str = "NATIVELINK_DISABLE_METRICS";
 
-/// Broadcast Channel Capacity
-/// Note: The actual capacity may be greater than the provided capacity.
-const BROADCAST_CAPACITY: usize = 1;
-
 /// Backend for bazel remote execution / cache API.
 #[derive(Parser, Debug)]
 #[clap(
@@ -166,7 +160,6 @@ impl RootMetricsComponent for ConnectedClientsMetrics {}
 async fn inner_main(
     cfg: CasConfig,
     server_start_timestamp: u64,
-    shutdown_tx: broadcast::Sender<Arc<oneshot::Sender<()>>>,
 ) -> Result<(), Box<dyn std::error::Error>> {
     let health_registry_builder =
         Arc::new(AsyncMutex::new(HealthRegistryBuilder::new("nativelink")));
@@ -243,9 +236,15 @@ async fn inner_main(
         schedulers: action_schedulers.clone(),
     }));
 
-    for (server_cfg, connected_clients_mux) in servers_and_clients {
+    for (i, (server_cfg, connected_clients_mux)) in servers_and_clients.into_iter().enumerate() {
         let services = server_cfg.services.ok_or("'services' must be configured")?;
 
+        let name = if server_cfg.name.is_empty() {
+            format!("{i}")
+        } else {
+            server_cfg.name.clone()
+        };
+
         // Currently we only support http as our socket type.
         let ListenerConfig::http(http_config) = server_cfg.listener;
 
@@ -776,10 +775,21 @@ async fn inner_main(
         if let Some(value) = http_config.experimental_http2_max_header_list_size {
             http.http2().max_header_list_size(value);
         }
-        event!(Level::WARN, "Ready, listening on {socket_addr}",);
+
+        event!(Level::WARN, "Ready, listening on {socket_addr}");
         root_futures.push(Box::pin(async move {
+            let shutdown_guard = Arc::new(Mutex::new(None));
+            let name = format!("TcpSocketListener_{name}");
             loop {
                 select! {
+                    inner_shutdown_guard = ShutdownManager::wait_for_shutdown(name.clone()) => {
+                        let connected_clients = connected_clients_mux.inner.lock();
+                        if connected_clients.is_empty() {
+                            drop(shutdown_guard.lock().take());
+                        } else {
+                            *shutdown_guard.lock() = Some(inner_shutdown_guard);
+                        }
+                    }
                     accept_result = tcp_listener.accept() => {
                         match accept_result {
                             Ok((tcp_stream, remote_addr)) => {
@@ -796,6 +806,8 @@ async fn inner_main(
                                     .insert(SocketAddrWrapper(remote_addr));
                                 connected_clients_mux.counter.inc();
 
+                                let shutdown_guard = shutdown_guard.clone();
+                                let name = name.clone();
                                 // This is the safest way to guarantee that if our future
                                 // is ever dropped we will cleanup our data.
                                 let scope_guard = guard(
@@ -806,13 +818,32 @@ async fn inner_main(
                                             Level::INFO,
                                             ?remote_addr,
                                             ?socket_addr,
+                                            name,
                                             "Client disconnected"
                                         );
                                         if let Some(connected_clients_mux) = weak_connected_clients_mux.upgrade() {
-                                            connected_clients_mux
-                                                .inner
-                                                .lock()
-                                                .remove(&SocketAddrWrapper(remote_addr));
+                                            let mut connected_clients = connected_clients_mux.inner.lock();
+                                            connected_clients.remove(&SocketAddrWrapper(remote_addr));
+
+                                            if connected_clients.is_empty() {
+                                                event!(
+                                                    target: "nativelink::services",
+                                                    Level::INFO,
+                                                    ?remote_addr,
+                                                    ?socket_addr,
+                                                    name,
+                                                    "No more clients connected & received shutdown signal."
+                                                );
+                                                drop(shutdown_guard.lock().take());
+                                            } else if ShutdownManager::is_shutting_down() {
+                                                event!(
+                                                    target: "nativelink::services",
+                                                    Level::INFO,
+                                                    name,
+                                                    "Waiting on {} more clients to disconnect before shutting down.",
+                                                    connected_clients.len()
+                                                );
+                                            }
                                         }
                                     },
                                 );
@@ -942,9 +973,8 @@ async fn inner_main(
                     }
                     worker_names.insert(name.clone());
                     worker_metrics.insert(name.clone(), metrics);
-                    let shutdown_rx = shutdown_tx.subscribe();
                     let fut = Arc::new(OriginContext::new())
-                        .wrap_async(trace_span!("worker_ctx"), local_worker.run(shutdown_rx));
+                        .wrap_async(trace_span!("worker_ctx"), local_worker.run());
                     spawn!("worker", fut, ?name)
                 }
             };
@@ -1037,41 +1067,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .on_thread_start(move || set_metrics_enabled_for_this_thread(metrics_enabled))
         .build()?;
 
-    // Initiates the shutdown process by broadcasting the shutdown signal via the `oneshot::Sender` to all listeners.
-    // Each listener will perform its cleanup and then drop its `oneshot::Sender`, signaling completion.
-    // Once all `oneshot::Sender` instances are dropped, the worker knows it can safely terminate.
-    let (shutdown_tx, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
-    let shutdown_tx_clone = shutdown_tx.clone();
-    let (complete_tx, complete_rx) = oneshot::channel::<()>();
-
-    runtime.spawn(async move {
-        tokio::signal::ctrl_c()
-            .await
-            .expect("Failed to listen to SIGINT");
-        eprintln!("User terminated process via SIGINT");
-        std::process::exit(130);
-    });
-
-    #[cfg(target_family = "unix")]
-    {
-        let complete_tx = Arc::new(complete_tx);
-        runtime.spawn(async move {
-            signal(SignalKind::terminate())
-                .expect("Failed to listen to SIGTERM")
-                .recv()
-                .await;
-            event!(Level::WARN, "Process terminated via SIGTERM",);
-            let _ = shutdown_tx_clone.send(complete_tx);
-            let _ = complete_rx.await;
-            event!(Level::WARN, "Successfully shut down nativelink.",);
-            std::process::exit(143);
-        });
-    }
+    ShutdownManager::init(runtime.handle());
 
-    let _ = runtime.block_on(Arc::new(OriginContext::new()).wrap_async(
-        trace_span!("main"),
-        inner_main(cfg, server_start_time, shutdown_tx),
-    ));
+    let _ = runtime.block_on(
+        Arc::new(OriginContext::new())
+            .wrap_async(trace_span!("main"), inner_main(cfg, server_start_time)),
+    );
     }
     Ok(())
 }
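
Usage sketch (illustrative only, not part of the diff): a long-running service
participates in graceful shutdown by awaiting ShutdownManager::wait_for_shutdown
and holding the returned ShutdownGuard until its cleanup is finished. The
service name "MyService" and the function below are hypothetical; the only APIs
assumed are ShutdownManager and ShutdownGuard as added in this patch.

    use nativelink_util::shutdown_manager::ShutdownManager;

    async fn my_service() {
        // Completes only once a graceful shutdown (e.g. via SIGTERM) has begun.
        let shutdown_guard = ShutdownManager::wait_for_shutdown("MyService").await;

        // ... drain in-flight work here ...

        // Dropping the guard signals completion; once every outstanding
        // ShutdownGuard is dropped, the process exits with status 143.
        drop(shutdown_guard);
    }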