Skip to content

Commit

Permalink
Support graceful shutdown for network connections
Browse files Browse the repository at this point in the history
Waits for network connections to close during graceful shutdown.
  • Loading branch information
allada committed Oct 29, 2024
1 parent 51a2fd4 commit 28d42d4
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 81 deletions.
1 change: 1 addition & 0 deletions nativelink-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub fn nativelink_test(attr: TokenStream, item: TokenStream) -> TokenStream {
#[allow(clippy::disallowed_methods)]
#[tokio::test(#attr)]
async fn #fn_name(#fn_inputs) #fn_output {
nativelink_util::shutdown_manager::ShutdownManager::init(&tokio::runtime::Handle::current());
// Error means already initialized, which is ok.
let _ = nativelink_util::init_tracing();
// If already set it's ok.
Expand Down
1 change: 1 addition & 0 deletions nativelink-util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ rust_library(
"src/proto_stream_utils.rs",
"src/resource_info.rs",
"src/retry.rs",
"src/shutdown_manager.rs",
"src/store_trait.rs",
"src/task.rs",
"src/tls_utils.rs",
Expand Down
1 change: 1 addition & 0 deletions nativelink-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub mod platform_properties;
pub mod proto_stream_utils;
pub mod resource_info;
pub mod retry;
pub mod shutdown_manager;
pub mod store_trait;
pub mod task;
pub mod tls_utils;
Expand Down
146 changes: 146 additions & 0 deletions nativelink-util/src/shutdown_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;
use tokio::runtime::Handle;
#[cfg(target_family = "unix")]
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{broadcast, oneshot};
use tracing::{event, Level};

static SHUTDOWN_MANAGER: ShutdownManager = ShutdownManager {
is_shutting_down: AtomicBool::new(false),
shutdown_tx: Mutex::new(None), // Will be initialized in `init`.
};

/// Broadcast Channel Capacity
/// Note: The actual capacity may be greater than the provided capacity.
const BROADCAST_CAPACITY: usize = 1;

/// ShutdownManager is a singleton that manages the shutdown of the
/// application. Services can register to be notified when a graceful
/// shutdown is initiated using [`ShutdownManager::wait_for_shutdown`].
/// When the future returned by [`ShutdownManager::wait_for_shutdown`] is
/// completed, the caller will then be handed a [`ShutdownGuard`] which
/// must be held until the caller has completed its shutdown procedure.
/// Once the caller has completed its shutdown procedure, the caller
/// must drop the [`ShutdownGuard`]. When all [`ShutdownGuard`]s have
/// been dropped, the application will then exit.
pub struct ShutdownManager {
is_shutting_down: AtomicBool,
shutdown_tx: Mutex<Option<broadcast::Sender<Arc<oneshot::Sender<()>>>>>,
}

impl ShutdownManager {
#[allow(clippy::disallowed_methods)]
pub fn init(runtime: &Handle) {
let (shutdown_tx, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
*SHUTDOWN_MANAGER.shutdown_tx.lock() = Some(shutdown_tx);

runtime.spawn(async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen to SIGINT");
event!(Level::WARN, "User terminated process via SIGINT");
std::process::exit(130);
});

#[cfg(target_family = "unix")]
{
runtime.spawn(async move {
signal(SignalKind::terminate())
.expect("Failed to listen to SIGTERM")
.recv()
.await;
event!(Level::WARN, "Received SIGTERM, begginning shutdown.");
Self::graceful_shutdown();
});
}
}

pub fn is_shutting_down() -> bool {
SHUTDOWN_MANAGER.is_shutting_down.load(Ordering::Acquire)
}

#[allow(clippy::disallowed_methods)]
fn graceful_shutdown() {
if SHUTDOWN_MANAGER
.is_shutting_down
.swap(true, Ordering::Release)
{
event!(Level::WARN, "Shutdown already in progress.");
return;
}
tokio::spawn(async move {
let (complete_tx, complete_rx) = oneshot::channel::<()>();
let shutdown_guard = Arc::new(complete_tx);
{
let shutdown_tx_lock = SHUTDOWN_MANAGER.shutdown_tx.lock();
// No need to check result of send, since it will only fail if
// all receivers have been dropped, in which case it means we
// can safely shutdown.
let _ = shutdown_tx_lock
.as_ref()
.expect("ShutdownManager was never initialized")
.send(shutdown_guard);
}
// It is impossible for the result to be anything but Err(RecvError),
// which means all receivers have been dropped and we can safely shutdown.
let _ = complete_rx.await;
event!(Level::WARN, "All services gracefully shutdown.",);
std::process::exit(143);
});
}

pub async fn wait_for_shutdown(service_name: impl Into<String>) -> ShutdownGuard {
let service_name = service_name.into();
let mut shutdown_receiver = SHUTDOWN_MANAGER
.shutdown_tx
.lock()
.as_ref()
.expect("ShutdownManager was never initialized")
.subscribe();
let sender = shutdown_receiver
.recv()
.await
.expect("Shutdown sender dropped. This should never happen.");
event!(
Level::INFO,
"Service {service_name} has been notified of shutdown request"
);
ShutdownGuard {
service_name,
_guard: sender,
}
}
}

#[derive(Clone)]
pub struct ShutdownGuard {
service_name: String,
_guard: Arc<oneshot::Sender<()>>,
}

impl Drop for ShutdownGuard {
fn drop(&mut self) {
event!(
Level::INFO,
"Service {} has completed shutdown.",
self.service_name
);
}
}
21 changes: 10 additions & 11 deletions nativelink-worker/src/local_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ use nativelink_util::common::fs;
use nativelink_util::digest_hasher::{DigestHasherFunc, ACTIVE_HASHER_FUNC};
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
use nativelink_util::origin_context::ActiveOriginContext;
use nativelink_util::shutdown_manager::ShutdownManager;
use nativelink_util::store_trait::Store;
use nativelink_util::{spawn, tls_utils};
use tokio::process;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::sync::mpsc;
use tokio::time::sleep;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Streaming;
Expand Down Expand Up @@ -168,7 +169,6 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
async fn run(
&mut self,
update_for_worker_stream: Streaming<UpdateForWorker>,
shutdown_rx: &mut broadcast::Receiver<Arc<oneshot::Sender<()>>>,
) -> Result<(), Error> {
// This big block of logic is designed to help simplify upstream components. Upstream
// components can write standard futures that return a `Result<(), Error>` and this block
Expand Down Expand Up @@ -350,18 +350,17 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
futures.push(fut);
},
res = futures.next() => res.err_tip(|| "Keep-alive should always pending. Likely unable to send data to scheduler")??,
complete_msg = shutdown_rx.recv().fuse() => {
event!(Level::WARN, "Worker loop reveiced shutdown signal. Shutting down worker...",);
complete_msg = ShutdownManager::wait_for_shutdown("LocalWorker").fuse() => {
event!(Level::INFO, "Worker loop reveiced shutdown signal. Shutting down worker...",);
let mut grpc_client = self.grpc_client.clone();
let worker_id = self.worker_id.clone();
let running_actions_manager = self.running_actions_manager.clone();
let complete_msg_clone = complete_msg.map_err(|e| make_err!(Code::Internal, "Failed to receive shutdown message: {e:?}"))?.clone();
let shutdown_future = async move {
if let Err(e) = grpc_client.going_away(GoingAwayRequest { worker_id }).await {
event!(Level::ERROR, "Failed to send GoingAwayRequest: {e}",);
return Err(e.into());
}
running_actions_manager.complete_actions(complete_msg_clone).await;
running_actions_manager.complete_actions(complete_msg).await;
Ok::<(), Error>(())
};
futures.push(shutdown_future.boxed());
Expand Down Expand Up @@ -526,10 +525,7 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
}

#[instrument(skip(self), level = Level::INFO)]
pub async fn run(
mut self,
mut shutdown_rx: broadcast::Receiver<Arc<oneshot::Sender<()>>>,
) -> Result<(), Error> {
pub async fn run(mut self) -> Result<(), Error> {
let sleep_fn = self
.sleep_fn
.take()
Expand Down Expand Up @@ -575,7 +571,10 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
);

// Now listen for connections and run all other services.
if let Err(err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
if let Err(err) = inner.run(update_for_worker_stream).await {
if ShutdownManager::is_shutting_down() {
return Ok(()); // Do not reconnect if we are shutting down.
}
'no_more_actions: {
// Ensure there are no actions in transit before we try to kill
// all our actions.
Expand Down
10 changes: 4 additions & 6 deletions nativelink-worker/src/running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ use nativelink_util::action_messages::{
use nativelink_util::common::{fs, DigestInfo};
use nativelink_util::digest_hasher::{DigestHasher, DigestHasherFunc};
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
use nativelink_util::shutdown_manager::ShutdownGuard;
use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo};
use nativelink_util::{background_spawn, spawn, spawn_blocking};
use parking_lot::Mutex;
Expand Down Expand Up @@ -1349,10 +1350,7 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {
hasher: DigestHasherFunc,
) -> impl Future<Output = Result<(), Error>> + Send;

fn complete_actions(
&self,
complete_msg: Arc<oneshot::Sender<()>>,
) -> impl Future<Output = ()> + Send;
fn complete_actions(&self, shutdown_guard: ShutdownGuard) -> impl Future<Output = ()> + Send;

fn kill_all(&self) -> impl Future<Output = ()> + Send;

Expand Down Expand Up @@ -1885,9 +1883,9 @@ impl RunningActionsManager for RunningActionsManagerImpl {
}

// Waits for all running actions to complete and signals completion.
// Use the Arc<oneshot::Sender<()>> to signal the completion of the actions
// Use the [`ShutdownGuard`] to signal the completion of the actions
// Dropping the sender automatically notifies the process to terminate.
async fn complete_actions(&self, _complete_msg: Arc<oneshot::Sender<()>>) {
async fn complete_actions(&self, _shutdown_guard: ShutdownGuard) {
let _ = self
.action_done_tx
.subscribe()
Expand Down
11 changes: 2 additions & 9 deletions nativelink-worker/tests/utils/local_worker_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
use nativelink_worker::local_worker::LocalWorker;
use nativelink_worker::worker_api_client_wrapper::WorkerApiClientTrait;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::sync::mpsc;
use tonic::Status;
use tonic::{
codec::Codec, // Needed for .decoder().
Expand All @@ -40,10 +40,6 @@ use tonic::{

use super::mock_running_actions_manager::MockRunningActionsManager;

/// Broadcast Channel Capacity
/// Note: The actual capacity may be greater than the provided capacity.
const BROADCAST_CAPACITY: usize = 1;

#[derive(Debug)]
enum WorkerClientApiCalls {
ConnectWorker(SupportedProperties),
Expand Down Expand Up @@ -198,11 +194,8 @@ pub async fn setup_local_worker_with_config(local_worker_config: LocalWorkerConf
}),
Box::new(move |_| Box::pin(async move { /* No sleep */ })),
);
let (shutdown_tx_test, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);

let drop_guard = spawn!("local_worker_spawn", async move {
worker.run(shutdown_tx_test.subscribe()).await
});
let drop_guard = spawn!("local_worker_spawn", worker.run());

let (tx_stream, streaming_response) = setup_grpc_stream();
TestContext {
Expand Down
8 changes: 3 additions & 5 deletions nativelink-worker/tests/utils/mock_running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::
use nativelink_util::action_messages::{ActionResult, OperationId};
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::DigestHasherFunc;
use nativelink_util::shutdown_manager::ShutdownGuard;
use nativelink_worker::running_actions_manager::{Metrics, RunningAction, RunningActionsManager};
use tokio::sync::{mpsc, oneshot};
use tokio::sync::mpsc;

#[derive(Debug)]
enum RunningActionManagerCalls {
Expand Down Expand Up @@ -167,10 +168,7 @@ impl RunningActionsManager for MockRunningActionsManager {
Ok(())
}

fn complete_actions(
&self,
_complete_msg: Arc<oneshot::Sender<()>>,
) -> impl Future<Output = ()> + Send {
fn complete_actions(&self, _shutdown_guard: ShutdownGuard) -> impl Future<Output = ()> + Send {
future::ready(())
}

Expand Down
Loading

0 comments on commit 28d42d4

Please sign in to comment.