Commit

Support graceful shutdown for network connections

Waits for network connections to close during graceful shutdown.

allada committed Oct 29, 2024
1 parent 51a2fd4 commit 39db838
Showing 11 changed files with 267 additions and 105 deletions.
4 changes: 4 additions & 0 deletions nativelink-config/src/cas_server.rs
@@ -406,6 +406,10 @@ pub struct ServerConfig {

/// Services to attach to server.
pub services: Option<ServicesConfig>,

/// Do not wait for connections to close during a graceful shutdown.
#[serde(default)]
pub experimental_connections_dont_block_graceful_shutdown: bool,
}

#[allow(non_camel_case_types)]
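Because of `#[serde(default)]`, the flag defaults to `false`, so connections block graceful shutdown unless explicitly opted out. As a rough illustration of the intended semantics, here is a minimal sketch of a shutdown path gated on such a flag; everything except the flag name is hypothetical and not part of this commit:

```rust
use std::future::Future;

// Minimal stand-in for the real `ServerConfig` in cas_server.rs.
struct ServerConfig {
    experimental_connections_dont_block_graceful_shutdown: bool,
}

// `drain_connections` is a placeholder future that resolves once all open
// network connections have closed.
async fn finish_graceful_shutdown(
    cfg: &ServerConfig,
    drain_connections: impl Future<Output = ()>,
) {
    if cfg.experimental_connections_dont_block_graceful_shutdown {
        return; // Opt out: exit without waiting for connections.
    }
    drain_connections.await; // Default: wait for connections to close.
}
```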
1 change: 1 addition & 0 deletions nativelink-macro/src/lib.rs
@@ -34,6 +34,7 @@ pub fn nativelink_test(attr: TokenStream, item: TokenStream) -> TokenStream {
#[allow(clippy::disallowed_methods)]
#[tokio::test(#attr)]
async fn #fn_name(#fn_inputs) #fn_output {
nativelink_util::shutdown_manager::ShutdownManager::init(&tokio::runtime::Handle::current());
// Error means already initialized, which is ok.
let _ = nativelink_util::init_tracing();
// If already set it's ok.
2 changes: 1 addition & 1 deletion nativelink-service/src/worker_api_server.rs
@@ -190,7 +190,7 @@ impl WorkerApiServer {
) -> Result<Response<()>, Error> {
let worker_id: WorkerId = going_away_request.worker_id.try_into()?;
self.scheduler
.remove_worker(&worker_id)
.set_drain_worker(&worker_id, true)
.await
.err_tip(|| "While calling WorkerApiServer::inner_going_away")?;
Ok(Response::new(()))
1 change: 1 addition & 0 deletions nativelink-util/BUILD.bazel
@@ -30,6 +30,7 @@ rust_library(
"src/proto_stream_utils.rs",
"src/resource_info.rs",
"src/retry.rs",
"src/shutdown_manager.rs",
"src/store_trait.rs",
"src/task.rs",
"src/tls_utils.rs",
1 change: 1 addition & 0 deletions nativelink-util/src/lib.rs
@@ -32,6 +32,7 @@ pub mod platform_properties;
pub mod proto_stream_utils;
pub mod resource_info;
pub mod retry;
pub mod shutdown_manager;
pub mod store_trait;
pub mod task;
pub mod tls_utils;
146 changes: 146 additions & 0 deletions nativelink-util/src/shutdown_manager.rs
@@ -0,0 +1,146 @@
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;
use tokio::runtime::Handle;
#[cfg(target_family = "unix")]
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{broadcast, oneshot};
use tracing::{event, Level};

static SHUTDOWN_MANAGER: ShutdownManager = ShutdownManager {
is_shutting_down: AtomicBool::new(false),
shutdown_tx: Mutex::new(None), // Will be initialized in `init`.
};

/// Broadcast Channel Capacity
/// Note: The actual capacity may be greater than the provided capacity.
const BROADCAST_CAPACITY: usize = 1;

/// ShutdownManager is a singleton that manages the shutdown of the
/// application. Services can register to be notified when a graceful
/// shutdown is initiated using [`ShutdownManager::wait_for_shutdown`].
/// When the future returned by [`ShutdownManager::wait_for_shutdown`] is
/// completed, the caller will then be handed a [`ShutdownGuard`] which
/// must be held until the caller has completed its shutdown procedure.
/// Once the caller has completed its shutdown procedure, the caller
/// must drop the [`ShutdownGuard`]. When all [`ShutdownGuard`]s have
/// been dropped, the application will then exit.
pub struct ShutdownManager {
is_shutting_down: AtomicBool,
shutdown_tx: Mutex<Option<broadcast::Sender<Arc<oneshot::Sender<()>>>>>,
}

impl ShutdownManager {
#[allow(clippy::disallowed_methods)]
pub fn init(runtime: &Handle) {
let (shutdown_tx, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);
*SHUTDOWN_MANAGER.shutdown_tx.lock() = Some(shutdown_tx);

runtime.spawn(async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen to SIGINT");
event!(Level::WARN, "User terminated process via SIGINT");
std::process::exit(130);
});

#[cfg(target_family = "unix")]
{
runtime.spawn(async move {
signal(SignalKind::terminate())
.expect("Failed to listen to SIGTERM")
.recv()
.await;
event!(Level::WARN, "Received SIGTERM, begginning shutdown.");
Self::graceful_shutdown();
});
}
}

pub fn is_shutting_down() -> bool {
SHUTDOWN_MANAGER.is_shutting_down.load(Ordering::Acquire)
}

#[allow(clippy::disallowed_methods)]
fn graceful_shutdown() {
if SHUTDOWN_MANAGER
.is_shutting_down
.swap(true, Ordering::Release)
{
event!(Level::WARN, "Shutdown already in progress.");
return;
}
tokio::spawn(async move {
let (complete_tx, complete_rx) = oneshot::channel::<()>();
let shutdown_guard = Arc::new(complete_tx);
{
let shutdown_tx_lock = SHUTDOWN_MANAGER.shutdown_tx.lock();
// No need to check result of send, since it will only fail if
// all receivers have been dropped, in which case it means we
// can safely shutdown.
let _ = shutdown_tx_lock
.as_ref()
.expect("ShutdownManager was never initialized")
.send(shutdown_guard);
}
// It is impossible for the result to be anything but Err(RecvError),
// which means all receivers have been dropped and we can safely shutdown.
let _ = complete_rx.await;
event!(Level::WARN, "All services gracefully shutdown.",);
std::process::exit(143);
});
}

pub async fn wait_for_shutdown(service_name: impl Into<String>) -> ShutdownGuard {
let service_name = service_name.into();
let mut shutdown_receiver = SHUTDOWN_MANAGER
.shutdown_tx
.lock()
.as_ref()
.expect("ShutdownManager was never initialized")
.subscribe();
let sender = shutdown_receiver
.recv()
.await
.expect("Shutdown sender dropped. This should never happen.");
event!(
Level::INFO,
"Service {service_name} has been notified of shutdown request"
);
ShutdownGuard {
service_name,
_guard: sender,
}
}
}

#[derive(Clone)]
pub struct ShutdownGuard {
service_name: String,
_guard: Arc<oneshot::Sender<()>>,
}

impl Drop for ShutdownGuard {
fn drop(&mut self) {
event!(
Level::INFO,
"Service {} has completed shutdown.",
self.service_name
);
}
}
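
To make the contract above concrete, here is a minimal usage sketch. Only `wait_for_shutdown` and the guard-drop semantics come from the file above; the service itself is illustrative:

```rust
use nativelink_util::shutdown_manager::{ShutdownGuard, ShutdownManager};

// Illustrative service: block until a graceful shutdown begins, finish
// in-flight work while holding the guard, then drop the guard. The process
// exits (code 143) only after every outstanding ShutdownGuard is dropped.
async fn run_my_service() {
    let guard: ShutdownGuard = ShutdownManager::wait_for_shutdown("MyService").await;
    // ... drain in-flight work here ...
    drop(guard); // Marks this service's shutdown as complete.
}
```

Completion detection rides on the `Arc<oneshot::Sender<()>>`: each subscriber receives a clone, and once the last clone is dropped the sender itself is dropped, `complete_rx` resolves with `RecvError`, and the process exits.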
61 changes: 45 additions & 16 deletions nativelink-worker/src/local_worker.rs
@@ -36,10 +36,11 @@ use nativelink_util::common::fs;
use nativelink_util::digest_hasher::{DigestHasherFunc, ACTIVE_HASHER_FUNC};
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime};
use nativelink_util::origin_context::ActiveOriginContext;
use nativelink_util::shutdown_manager::ShutdownManager;
use nativelink_util::store_trait::Store;
use nativelink_util::{spawn, tls_utils};
use tokio::process;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::sync::mpsc;
use tokio::time::sleep;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Streaming;
@@ -168,7 +169,6 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
async fn run(
&mut self,
update_for_worker_stream: Streaming<UpdateForWorker>,
shutdown_rx: &mut broadcast::Receiver<Arc<oneshot::Sender<()>>>,
) -> Result<(), Error> {
// This big block of logic is designed to help simplify upstream components. Upstream
// components can write standard futures that return a `Result<(), Error>` and this block
@@ -190,9 +190,20 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,

let mut update_for_worker_stream = update_for_worker_stream.fuse();

// If we are shutting down we need to hold onto the shutdown guard
// until we are done processing all the futures.
let mut _maybe_shutdown_guard = None;
loop {
select! {
maybe_update = update_for_worker_stream.next() => {
if maybe_update.is_none() && ShutdownManager::is_shutting_down() {
// Happy shutdown path: the stream closed because we are
// shutting down, so there is nothing to log.
continue;
}
match maybe_update
.err_tip(|| "UpdateForWorker stream closed early")?
.err_tip(|| "Got error in UpdateForWorker stream")?
@@ -349,22 +360,39 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
let fut = res.err_tip(|| "New future stream receives should never be closed")?;
futures.push(fut);
},
res = futures.next() => res.err_tip(|| "Keep-alive should always be pending. Likely unable to send data to scheduler")??,
complete_msg = shutdown_rx.recv().fuse() => {
event!(Level::WARN, "Worker loop received shutdown signal. Shutting down worker...",);
res = futures.next() => {
let res = res.err_tip(|| "Keep-alive should always be pending. Likely unable to send data to scheduler")?;
// If we are shutting down and we get an error, we want to
// keep draining, but not reconnect.
if ShutdownManager::is_shutting_down() {
// If we are shutting down and we only have keep alive left,
// we can exit.
if futures.len() == 1 {
return Ok(());
}
if res.is_err() {
event!(
Level::ERROR,
"During shutdown future failed with error: {:?}", res.unwrap_err(),
);
continue;
}
}
// If we are not shutting down and get an error, return the error.
res?;
},
shutdown_guard = ShutdownManager::wait_for_shutdown("LocalWorker").fuse() => {
_maybe_shutdown_guard = Some(shutdown_guard);
event!(Level::INFO, "Worker loop reveiced shutdown signal. Shutting down worker...",);
let mut grpc_client = self.grpc_client.clone();
let worker_id = self.worker_id.clone();
let running_actions_manager = self.running_actions_manager.clone();
let complete_msg_clone = complete_msg.map_err(|e| make_err!(Code::Internal, "Failed to receive shutdown message: {e:?}"))?.clone();
let shutdown_future = async move {
futures.push(async move {
if let Err(e) = grpc_client.going_away(GoingAwayRequest { worker_id }).await {
event!(Level::ERROR, "Failed to send GoingAwayRequest: {e}",);
return Err(e.into());
}
running_actions_manager.complete_actions(complete_msg_clone).await;
Ok::<(), Error>(())
};
futures.push(shutdown_future.boxed());
}.boxed());
},
};
}
@@ -526,10 +554,7 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
}

#[instrument(skip(self), level = Level::INFO)]
pub async fn run(
mut self,
mut shutdown_rx: broadcast::Receiver<Arc<oneshot::Sender<()>>>,
) -> Result<(), Error> {
pub async fn run(mut self) -> Result<(), Error> {
let sleep_fn = self
.sleep_fn
.take()
@@ -575,7 +600,11 @@ impl<T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorker<T, U> {
);

// Now listen for connections and run all other services.
if let Err(err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
let res = inner.run(update_for_worker_stream).await;
if ShutdownManager::is_shutting_down() {
return Ok(()); // Do not reconnect if we are shutting down.
}
if let Err(err) = res {
'no_more_actions: {
// Ensure there are no actions in transit before we try to kill
// all our actions.
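The select-loop changes above reduce to one pattern: once shutdown begins, keep polling the outstanding futures, log but swallow their errors, and stop when only the never-completing keep-alive future remains. A standalone sketch of that pattern, with types simplified (this is not the worker's actual signature):

```rust
use std::pin::Pin;

use futures::stream::{FuturesUnordered, StreamExt};

type WorkFuture = Pin<Box<dyn std::future::Future<Output = Result<(), String>> + Send>>;

// Drain every future except the keep-alive one, ignoring errors, mirroring
// the shutdown branch of LocalWorkerImpl::run above.
async fn drain_until_keep_alive_only(mut futures: FuturesUnordered<WorkFuture>) {
    while futures.len() > 1 {
        if let Some(Err(e)) = futures.next().await {
            eprintln!("During shutdown a future failed with error: {e:?}");
        }
    }
}
```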
19 changes: 1 addition & 18 deletions nativelink-worker/src/running_actions_manager.rs
@@ -1349,11 +1349,6 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {
hasher: DigestHasherFunc,
) -> impl Future<Output = Result<(), Error>> + Send;

fn complete_actions(
&self,
complete_msg: Arc<oneshot::Sender<()>>,
) -> impl Future<Output = ()> + Send;

fn kill_all(&self) -> impl Future<Output = ()> + Send;

fn kill_operation(
@@ -1884,17 +1879,6 @@ impl RunningActionsManager for RunningActionsManagerImpl {
Ok(())
}

// Waits for all running actions to complete and signals completion.
// Use the Arc<oneshot::Sender<()>> to signal the completion of the actions
// Dropping the sender automatically notifies the process to terminate.
async fn complete_actions(&self, _complete_msg: Arc<oneshot::Sender<()>>) {
let _ = self
.action_done_tx
.subscribe()
.wait_for(|_| self.running_actions.lock().is_empty())
.await;
}

// Note: When the future returns the process should be fully killed and cleaned up.
async fn kill_all(&self) {
self.metrics
Expand All @@ -1918,8 +1902,7 @@ impl RunningActionsManager for RunningActionsManagerImpl {
let _ = self
.action_done_tx
.subscribe()
.wait_for(|_| self.running_actions.lock().is_empty())
.await;
}

#[inline]
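The drain-wait at the end of `kill_all` is built on `tokio::sync::watch::Receiver::wait_for`, which tests the current value first and then every subsequent update, so a completion signal cannot be missed. A self-contained sketch of that pattern; the bookkeeping types are simplified stand-ins (nativelink uses a `parking_lot::Mutex` and its own action map):

```rust
use std::sync::{Arc, Mutex};

use tokio::sync::watch;

// Wait until the running-action set is observed empty. The sender side is
// expected to call `send(())` whenever an action finishes.
async fn wait_for_actions_to_drain(
    action_done_tx: &watch::Sender<()>,
    running_actions: Arc<Mutex<Vec<String>>>,
) {
    // Ignore the error: it only occurs if the sender is dropped, in which
    // case there is nothing left to wait for.
    let _ = action_done_tx
        .subscribe()
        .wait_for(move |_| running_actions.lock().unwrap().is_empty())
        .await;
}
```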
11 changes: 2 additions & 9 deletions nativelink-worker/tests/utils/local_worker_test_utils.rs
@@ -28,7 +28,7 @@ use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
use nativelink_worker::local_worker::LocalWorker;
use nativelink_worker::worker_api_client_wrapper::WorkerApiClientTrait;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::sync::mpsc;
use tonic::Status;
use tonic::{
codec::Codec, // Needed for .decoder().
@@ -40,10 +40,6 @@ use tonic::{

use super::mock_running_actions_manager::MockRunningActionsManager;

/// Broadcast Channel Capacity
/// Note: The actual capacity may be greater than the provided capacity.
const BROADCAST_CAPACITY: usize = 1;

#[derive(Debug)]
enum WorkerClientApiCalls {
ConnectWorker(SupportedProperties),
Expand Down Expand Up @@ -198,11 +194,8 @@ pub async fn setup_local_worker_with_config(local_worker_config: LocalWorkerConf
}),
Box::new(move |_| Box::pin(async move { /* No sleep */ })),
);
let (shutdown_tx_test, _) = broadcast::channel::<Arc<oneshot::Sender<()>>>(BROADCAST_CAPACITY);

let drop_guard = spawn!("local_worker_spawn", async move {
worker.run(shutdown_tx_test.subscribe()).await
});
let drop_guard = spawn!("local_worker_spawn", worker.run());

let (tx_stream, streaming_response) = setup_grpc_stream();
TestContext {