Skip to content

Commit

Permalink
Add ShutdownGuard to replace oneshot for shutdown (TraceMachina#1491)
Browse files Browse the repository at this point in the history
This makes makes it easier to manage more complex shutdown cases.
For example, some services might need to shutdown last, so we
introduce a priority system that services can opt into.
  • Loading branch information
allada authored Nov 22, 2024
1 parent 22707d7 commit a8c3217
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 23 deletions.
1 change: 1 addition & 0 deletions nativelink-util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ rust_library(
"src/proto_stream_utils.rs",
"src/resource_info.rs",
"src/retry.rs",
"src/shutdown_guard.rs",
"src/store_trait.rs",
"src/task.rs",
"src/tls_utils.rs",
Expand Down
1 change: 1 addition & 0 deletions nativelink-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub mod platform_properties;
pub mod proto_stream_utils;
pub mod resource_info;
pub mod retry;
pub mod shutdown_guard;
pub mod store_trait;
pub mod task;
pub mod tls_utils;
Expand Down
128 changes: 128 additions & 0 deletions nativelink-util/src/shutdown_guard.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use tokio::sync::watch;

#[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub enum Priority {
// The least important priority.
LeastImportant = 2,
// All priorities greater than 1 must be complted.
P1 = 1,
// All other priorities must be completed.
P0 = 0,
}

impl From<usize> for Priority {
fn from(value: usize) -> Self {
match value {
0 => Priority::P0,
1 => Priority::P1,
_ => Priority::LeastImportant,
}
}
}

impl From<Priority> for usize {
fn from(priority: Priority) -> usize {
match priority {
Priority::P0 => 0,
Priority::P1 => 1,
Priority::LeastImportant => 2,
}
}
}

/// Tracks other services that have registered to be notified when
/// the process is being shut down.
pub struct ShutdownGuard {
priority: Priority,
tx: watch::Sender<HashMap<Priority, usize>>,
rx: watch::Receiver<HashMap<Priority, usize>>,
}

impl ShutdownGuard {
/// Waits for all priorities less important than the given
/// priority to be completed.
pub async fn wait_for(&mut self, priority: Priority) {
if priority != self.priority {
// Promote our priority to the new priority.
self.tx.send_modify(|map| {
let old_count = map.remove(&self.priority).unwrap_or(0).saturating_sub(1);
map.insert(self.priority, old_count);

self.priority = priority;

let new_count = map.get(&priority).unwrap_or(&0).saturating_add(1);
map.insert(priority, new_count);
});
}
// Ignore error because the receiver will never be closed
// if the sender is still alive here.
let _ = self
.rx
.wait_for(|map| {
let start = usize::from(priority) + 1;
let end = usize::from(Priority::LeastImportant);
for p in start..=end {
if *map.get(&p.into()).unwrap_or(&0) > 0 {
return false;
}
}
true
})
.await;
}
}

impl Default for ShutdownGuard {
fn default() -> Self {
let priority = Priority::LeastImportant;
let mut map = HashMap::new();
map.insert(priority, 0);
let (tx, rx) = watch::channel(map);
Self { priority, tx, rx }
}
}

impl Clone for ShutdownGuard {
fn clone(&self) -> Self {
self.tx.send_modify(|map| {
map.insert(
self.priority,
map.get(&Priority::LeastImportant)
.unwrap_or(&0)
.saturating_add(1),
);
});
Self {
priority: Priority::LeastImportant,
tx: self.tx.clone(),
rx: self.rx.clone(),
}
}
}

impl Drop for ShutdownGuard {
fn drop(&mut self) {
self.tx.send_modify(|map| {
map.insert(
self.priority,
map.get(&self.priority).unwrap_or(&0).saturating_sub(1),
);
});
}
}
7 changes: 4 additions & 3 deletions nativelink-worker/src/local_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ use nativelink_util::common::fs;
use nativelink_util::digest_hasher::{DigestHasherFunc, ACTIVE_HASHER_FUNC};
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
use nativelink_util::origin_context::ActiveOriginContext;
use nativelink_util::shutdown_guard::ShutdownGuard;
use nativelink_util::store_trait::Store;
use nativelink_util::{spawn, tls_utils};
use tokio::process;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::sync::{broadcast, mpsc};
use tokio::time::sleep;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Streaming;
Expand Down Expand Up @@ -168,7 +169,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
async fn run(
&mut self,
update_for_worker_stream: Streaming<UpdateForWorker>,
shutdown_rx: &mut broadcast::Receiver<Arc<oneshot::Sender<()>>>,
shutdown_rx: &mut broadcast::Receiver<ShutdownGuard>,
) -> Result<(), Error> {
// This big block of logic is designed to help simplify upstream components. Upstream
// components can write standard futures that return a `Result<(), Error>` and this block
Expand Down Expand Up @@ -528,7 +529,7 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
#[instrument(skip(self), level = Level::INFO)]
pub async fn run(
mut self,
mut shutdown_rx: broadcast::Receiver<Arc<oneshot::Sender<()>>>,
mut shutdown_rx: broadcast::Receiver<ShutdownGuard>,
) -> Result<(), Error> {
let sleep_fn = self
.sleep_fn
Expand Down
10 changes: 4 additions & 6 deletions nativelink-worker/src/running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ use nativelink_util::action_messages::{
use nativelink_util::common::{fs, DigestInfo};
use nativelink_util::digest_hasher::{DigestHasher, DigestHasherFunc};
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
use nativelink_util::shutdown_guard::ShutdownGuard;
use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo};
use nativelink_util::{background_spawn, spawn, spawn_blocking};
use parking_lot::Mutex;
Expand Down Expand Up @@ -1350,10 +1351,7 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {
hasher: DigestHasherFunc,
) -> impl Future<Output = Result<(), Error>> + Send;

fn complete_actions(
&self,
complete_msg: Arc<oneshot::Sender<()>>,
) -> impl Future<Output = ()> + Send;
fn complete_actions(&self, complete_msg: ShutdownGuard) -> impl Future<Output = ()> + Send;

fn kill_all(&self) -> impl Future<Output = ()> + Send;

Expand Down Expand Up @@ -1886,9 +1884,9 @@ impl RunningActionsManager for RunningActionsManagerImpl {
}

// Waits for all running actions to complete and signals completion.
// Use the Arc<oneshot::Sender<()>> to signal the completion of the actions
// Use the ShutdownGuard to signal the completion of the actions
// Dropping the sender automatically notifies the process to terminate.
async fn complete_actions(&self, _complete_msg: Arc<oneshot::Sender<()>>) {
async fn complete_actions(&self, _complete_msg: ShutdownGuard) {
let _ = self
.action_done_tx
.subscribe()
Expand Down
5 changes: 3 additions & 2 deletions nativelink-worker/tests/utils/local_worker_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::
ExecuteResult, GoingAwayRequest, KeepAliveRequest, SupportedProperties, UpdateForWorker,
};
use nativelink_util::channel_body_for_tests::ChannelBody;
use nativelink_util::shutdown_guard::ShutdownGuard;
use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
use nativelink_worker::local_worker::LocalWorker;
use nativelink_worker::worker_api_client_wrapper::WorkerApiClientTrait;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::sync::{broadcast, mpsc};
use tonic::Status;
use tonic::{
codec::Codec, // Needed for .decoder().
Expand Down Expand Up @@ -198,7 +199,7 @@ pub async fn setup_local_worker_with_config(local_worker_config: LocalWorkerConf
}),
Box::new(move |_| Box::pin(async move { /* No sleep */ })),
);
let (shutdown_tx_test, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
let (shutdown_tx_test, _) = broadcast::channel::<ShutdownGuard>(BROADCAST_CAPACITY);

let drop_guard = spawn!("local_worker_spawn", async move {
worker.run(shutdown_tx_test.subscribe()).await
Expand Down
8 changes: 3 additions & 5 deletions nativelink-worker/tests/utils/mock_running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::
use nativelink_util::action_messages::{ActionResult, OperationId};
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::DigestHasherFunc;
use nativelink_util::shutdown_guard::ShutdownGuard;
use nativelink_worker::running_actions_manager::{Metrics, RunningAction, RunningActionsManager};
use tokio::sync::{mpsc, oneshot};
use tokio::sync::mpsc;

#[derive(Debug)]
enum RunningActionManagerCalls {
Expand Down Expand Up @@ -167,10 +168,7 @@ impl RunningActionsManager for MockRunningActionsManager {
Ok(())
}

fn complete_actions(
&self,
_complete_msg: Arc<oneshot::Sender<()>>,
) -> impl Future<Output = ()> + Send {
fn complete_actions(&self, _complete_msg: ShutdownGuard) -> impl Future<Output = ()> + Send {
future::ready(())
}

Expand Down
14 changes: 7 additions & 7 deletions src/bin/nativelink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ use nativelink_util::health_utils::HealthRegistryBuilder;
use nativelink_util::metrics_utils::{set_metrics_enabled_for_this_thread, Counter};
use nativelink_util::operation_state_manager::ClientStateManager;
use nativelink_util::origin_context::OriginContext;
use nativelink_util::shutdown_guard::{Priority, ShutdownGuard};
use nativelink_util::store_trait::{
set_default_digest_size_health_check, DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG,
};
Expand All @@ -69,7 +70,7 @@ use tokio::net::TcpListener;
use tokio::select;
#[cfg(target_family = "unix")]
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{broadcast, oneshot};
use tokio::sync::broadcast;
use tokio_rustls::rustls::pki_types::{CertificateDer, CertificateRevocationListDer};
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::rustls::{RootCertStore, ServerConfig as TlsServerConfig};
Expand Down Expand Up @@ -166,7 +167,7 @@ impl RootMetricsComponent for ConnectedClientsMetrics {}
async fn inner_main(
cfg: CasConfig,
server_start_timestamp: u64,
shutdown_tx: broadcast::Sender<Arc<oneshot::Sender<()>>>,
shutdown_tx: broadcast::Sender<ShutdownGuard>,
) -> Result<(), Box<dyn std::error::Error>> {
fn into_encoding(from: HttpCompressionAlgorithm) -> Option<CompressionEncoding> {
match from {
Expand Down Expand Up @@ -1040,9 +1041,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Initiates the shutdown process by broadcasting the shutdown signal via the `oneshot::Sender` to all listeners.
// Each listener will perform its cleanup and then drop its `oneshot::Sender`, signaling completion.
// Once all `oneshot::Sender` instances are dropped, the worker knows it can safely terminate.
let (shutdown_tx, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
let (shutdown_tx, _) = broadcast::channel::<ShutdownGuard>(BROADCAST_CAPACITY);
let shutdown_tx_clone = shutdown_tx.clone();
let (complete_tx, complete_rx) = oneshot::channel::<()>();
let mut shutdown_guard = ShutdownGuard::default();

runtime.spawn(async move {
tokio::signal::ctrl_c()
Expand All @@ -1054,15 +1055,14 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

#[cfg(target_family = "unix")]
{
let complete_tx = Arc::new(complete_tx);
runtime.spawn(async move {
signal(SignalKind::terminate())
.expect("Failed to listen to SIGTERM")
.recv()
.await;
event!(Level::WARN, "Process terminated via SIGTERM",);
let _ = shutdown_tx_clone.send(complete_tx);
let _ = complete_rx.await;
let _ = shutdown_tx_clone.send(shutdown_guard.clone());
let _ = shutdown_guard.wait_for(Priority::P0).await;
event!(Level::WARN, "Successfully shut down nativelink.",);
std::process::exit(143);
});
Expand Down

0 comments on commit a8c3217

Please sign in to comment.