diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index ac17063f1..89c75e72d 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -30,6 +30,7 @@ rust_library( "src/proto_stream_utils.rs", "src/resource_info.rs", "src/retry.rs", + "src/shutdown_guard.rs", "src/store_trait.rs", "src/task.rs", "src/tls_utils.rs", diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index 17edbf700..2df2bec21 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -32,6 +32,7 @@ pub mod platform_properties; pub mod proto_stream_utils; pub mod resource_info; pub mod retry; +pub mod shutdown_guard; pub mod store_trait; pub mod task; pub mod tls_utils; diff --git a/nativelink-util/src/shutdown_guard.rs b/nativelink-util/src/shutdown_guard.rs new file mode 100644 index 000000000..8c52c15ad --- /dev/null +++ b/nativelink-util/src/shutdown_guard.rs @@ -0,0 +1,128 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use tokio::sync::watch; + +#[derive(Clone, Copy, Eq, PartialEq, Hash)] +pub enum Priority { + // The least important priority. + LeastImportant = 2, + // All priorities greater than 1 must be completed. + P1 = 1, + // All other priorities must be completed. 
+ P0 = 0, +} + +impl From for Priority { + fn from(value: usize) -> Self { + match value { + 0 => Priority::P0, + 1 => Priority::P1, + _ => Priority::LeastImportant, + } + } +} + +impl From for usize { + fn from(priority: Priority) -> usize { + match priority { + Priority::P0 => 0, + Priority::P1 => 1, + Priority::LeastImportant => 2, + } + } +} + +/// Tracks other services that have registered to be notified when +/// the process is being shut down. +pub struct ShutdownGuard { + priority: Priority, + tx: watch::Sender>, + rx: watch::Receiver>, +} + +impl ShutdownGuard { + /// Waits for all priorities less important than the given + /// priority to be completed. + pub async fn wait_for(&mut self, priority: Priority) { + if priority != self.priority { + // Promote our priority to the new priority. + self.tx.send_modify(|map| { + let old_count = map.remove(&self.priority).unwrap_or(0).saturating_sub(1); + map.insert(self.priority, old_count); + + self.priority = priority; + + let new_count = map.get(&priority).unwrap_or(&0).saturating_add(1); + map.insert(priority, new_count); + }); + } + // Ignore error because the receiver will never be closed + // if the sender is still alive here. 
+ let _ = self + .rx + .wait_for(|map| { + let start = usize::from(priority) + 1; + let end = usize::from(Priority::LeastImportant); + for p in start..=end { + if *map.get(&p.into()).unwrap_or(&0) > 0 { + return false; + } + } + true + }) + .await; + } +} + +impl Default for ShutdownGuard { + fn default() -> Self { + let priority = Priority::LeastImportant; + let mut map = HashMap::new(); + map.insert(priority, 1); // Count the root guard: Drop and promotion each subtract one for it. + let (tx, rx) = watch::channel(map); + Self { priority, tx, rx } + } +} + +impl Clone for ShutdownGuard { + fn clone(&self) -> Self { + self.tx.send_modify(|map| { + map.insert( + Priority::LeastImportant, // The clone is created at LeastImportant, so count it there. + map.get(&Priority::LeastImportant) + .unwrap_or(&0) + .saturating_add(1), + ); + }); + Self { + priority: Priority::LeastImportant, + tx: self.tx.clone(), + rx: self.rx.clone(), + } + } +} + +impl Drop for ShutdownGuard { + fn drop(&mut self) { + self.tx.send_modify(|map| { + map.insert( + self.priority, + map.get(&self.priority).unwrap_or(&0).saturating_sub(1), + ); + }); + } +} diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs index c7fd5670d..482f39963 100644 --- a/nativelink-worker/src/local_worker.rs +++ b/nativelink-worker/src/local_worker.rs @@ -36,10 +36,11 @@ use nativelink_util::common::fs; use nativelink_util::digest_hasher::{DigestHasherFunc, ACTIVE_HASHER_FUNC}; use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime}; use nativelink_util::origin_context::ActiveOriginContext; +use nativelink_util::shutdown_guard::ShutdownGuard; use nativelink_util::store_trait::Store; use nativelink_util::{spawn, tls_utils}; use tokio::process; -use tokio::sync::{broadcast, mpsc, oneshot}; +use tokio::sync::{broadcast, mpsc}; use tokio::time::sleep; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::Streaming; @@ -168,7 +169,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, async fn run( &mut self, update_for_worker_stream: Streaming, - shutdown_rx: &mut 
broadcast::Receiver>>, + shutdown_rx: &mut broadcast::Receiver, ) -> Result<(), Error> { // This big block of logic is designed to help simplify upstream components. Upstream // components can write standard futures that return a `Result<(), Error>` and this block @@ -528,7 +529,7 @@ impl LocalWorker { #[instrument(skip(self), level = Level::INFO)] pub async fn run( mut self, - mut shutdown_rx: broadcast::Receiver>>, + mut shutdown_rx: broadcast::Receiver, ) -> Result<(), Error> { let sleep_fn = self .sleep_fn diff --git a/nativelink-worker/src/running_actions_manager.rs b/nativelink-worker/src/running_actions_manager.rs index 83d62dcd8..6e207f4c0 100644 --- a/nativelink-worker/src/running_actions_manager.rs +++ b/nativelink-worker/src/running_actions_manager.rs @@ -63,6 +63,7 @@ use nativelink_util::action_messages::{ use nativelink_util::common::{fs, DigestInfo}; use nativelink_util::digest_hasher::{DigestHasher, DigestHasherFunc}; use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime}; +use nativelink_util::shutdown_guard::ShutdownGuard; use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo}; use nativelink_util::{background_spawn, spawn, spawn_blocking}; use parking_lot::Mutex; @@ -1350,10 +1351,7 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static { hasher: DigestHasherFunc, ) -> impl Future> + Send; - fn complete_actions( - &self, - complete_msg: Arc>, - ) -> impl Future + Send; + fn complete_actions(&self, complete_msg: ShutdownGuard) -> impl Future + Send; fn kill_all(&self) -> impl Future + Send; @@ -1886,9 +1884,9 @@ impl RunningActionsManager for RunningActionsManagerImpl { } // Waits for all running actions to complete and signals completion. - // Use the Arc> to signal the completion of the actions + // Use the ShutdownGuard to signal the completion of the actions // Dropping the sender automatically notifies the process to terminate. 
- async fn complete_actions(&self, _complete_msg: Arc>) { + async fn complete_actions(&self, _complete_msg: ShutdownGuard) { let _ = self .action_done_tx .subscribe() diff --git a/nativelink-worker/tests/utils/local_worker_test_utils.rs b/nativelink-worker/tests/utils/local_worker_test_utils.rs index 6eb349ef4..336cca9ad 100644 --- a/nativelink-worker/tests/utils/local_worker_test_utils.rs +++ b/nativelink-worker/tests/utils/local_worker_test_utils.rs @@ -24,11 +24,12 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: ExecuteResult, GoingAwayRequest, KeepAliveRequest, SupportedProperties, UpdateForWorker, }; use nativelink_util::channel_body_for_tests::ChannelBody; +use nativelink_util::shutdown_guard::ShutdownGuard; use nativelink_util::spawn; use nativelink_util::task::JoinHandleDropGuard; use nativelink_worker::local_worker::LocalWorker; use nativelink_worker::worker_api_client_wrapper::WorkerApiClientTrait; -use tokio::sync::{broadcast, mpsc, oneshot}; +use tokio::sync::{broadcast, mpsc}; use tonic::Status; use tonic::{ codec::Codec, // Needed for .decoder(). 
@@ -198,7 +199,7 @@ pub async fn setup_local_worker_with_config(local_worker_config: LocalWorkerConf }), Box::new(move |_| Box::pin(async move { /* No sleep */ })), ); - let (shutdown_tx_test, _) = broadcast::channel::>>(BROADCAST_CAPACITY); + let (shutdown_tx_test, _) = broadcast::channel::(BROADCAST_CAPACITY); let drop_guard = spawn!("local_worker_spawn", async move { worker.run(shutdown_tx_test.subscribe()).await diff --git a/nativelink-worker/tests/utils/mock_running_actions_manager.rs b/nativelink-worker/tests/utils/mock_running_actions_manager.rs index 542ebd93b..705600df2 100644 --- a/nativelink-worker/tests/utils/mock_running_actions_manager.rs +++ b/nativelink-worker/tests/utils/mock_running_actions_manager.rs @@ -22,8 +22,9 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: use nativelink_util::action_messages::{ActionResult, OperationId}; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; +use nativelink_util::shutdown_guard::ShutdownGuard; use nativelink_worker::running_actions_manager::{Metrics, RunningAction, RunningActionsManager}; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::mpsc; #[derive(Debug)] enum RunningActionManagerCalls { @@ -167,10 +168,7 @@ impl RunningActionsManager for MockRunningActionsManager { Ok(()) } - fn complete_actions( - &self, - _complete_msg: Arc>, - ) -> impl Future + Send { + fn complete_actions(&self, _complete_msg: ShutdownGuard) -> impl Future + Send { future::ready(()) } diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs index 7e2276a63..42ae7096c 100644 --- a/src/bin/nativelink.rs +++ b/src/bin/nativelink.rs @@ -53,6 +53,7 @@ use nativelink_util::health_utils::HealthRegistryBuilder; use nativelink_util::metrics_utils::{set_metrics_enabled_for_this_thread, Counter}; use nativelink_util::operation_state_manager::ClientStateManager; use nativelink_util::origin_context::OriginContext; +use nativelink_util::shutdown_guard::{Priority, 
ShutdownGuard}; use nativelink_util::store_trait::{ set_default_digest_size_health_check, DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG, }; @@ -69,7 +70,7 @@ use tokio::net::TcpListener; use tokio::select; #[cfg(target_family = "unix")] use tokio::signal::unix::{signal, SignalKind}; -use tokio::sync::{broadcast, oneshot}; +use tokio::sync::broadcast; use tokio_rustls::rustls::pki_types::{CertificateDer, CertificateRevocationListDer}; use tokio_rustls::rustls::server::WebPkiClientVerifier; use tokio_rustls::rustls::{RootCertStore, ServerConfig as TlsServerConfig}; @@ -166,7 +167,7 @@ impl RootMetricsComponent for ConnectedClientsMetrics {} async fn inner_main( cfg: CasConfig, server_start_timestamp: u64, - shutdown_tx: broadcast::Sender>>, + shutdown_tx: broadcast::Sender, ) -> Result<(), Box> { fn into_encoding(from: HttpCompressionAlgorithm) -> Option { match from { @@ -1040,9 +1041,9 @@ fn main() -> Result<(), Box> { // Initiates the shutdown process by broadcasting the shutdown signal via the `oneshot::Sender` to all listeners. // Each listener will perform its cleanup and then drop its `oneshot::Sender`, signaling completion. // Once all `oneshot::Sender` instances are dropped, the worker knows it can safely terminate. 
- let (shutdown_tx, _) = broadcast::channel::>>(BROADCAST_CAPACITY); + let (shutdown_tx, _) = broadcast::channel::(BROADCAST_CAPACITY); let shutdown_tx_clone = shutdown_tx.clone(); - let (complete_tx, complete_rx) = oneshot::channel::<()>(); + let mut shutdown_guard = ShutdownGuard::default(); runtime.spawn(async move { tokio::signal::ctrl_c() @@ -1054,15 +1055,14 @@ fn main() -> Result<(), Box> { #[cfg(target_family = "unix")] { - let complete_tx = Arc::new(complete_tx); runtime.spawn(async move { signal(SignalKind::terminate()) .expect("Failed to listen to SIGTERM") .recv() .await; event!(Level::WARN, "Process terminated via SIGTERM",); - let _ = shutdown_tx_clone.send(complete_tx); - let _ = complete_rx.await; + let _ = shutdown_tx_clone.send(shutdown_guard.clone()); + let _ = shutdown_guard.wait_for(Priority::P0).await; event!(Level::WARN, "Successfully shut down nativelink.",); std::process::exit(143); });