diff --git a/Cargo.lock b/Cargo.lock
index 1c00c62918..febcba6291 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5257,6 +5257,7 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
 name = "infra_utils"
 version = "0.0.0"
 dependencies = [
+ "nix 0.20.2",
  "num-traits 0.2.19",
  "pretty_assertions",
  "regex",
diff --git a/crates/infra_utils/Cargo.toml b/crates/infra_utils/Cargo.toml
index 4f086f55a0..08741a3f4f 100644
--- a/crates/infra_utils/Cargo.toml
+++ b/crates/infra_utils/Cargo.toml
@@ -12,11 +12,12 @@ workspace = true
 [dependencies]
 num-traits.workspace = true
 regex.workspace = true
-tokio = { workspace = true, features = ["process", "time"] }
+tokio = { workspace = true, features = ["process", "rt", "time"] }
 tracing.workspace = true
 
 [dev-dependencies]
+nix.workspace = true
 pretty_assertions.workspace = true
 rstest.workspace = true
-tokio = { workspace = true, features = ["macros", "rt", "sync"] }
+tokio = { workspace = true, features = ["macros", "rt", "signal", "sync"] }
 tracing-subscriber = { workspace = true, features = ["env-filter"] }
diff --git a/crates/infra_utils/src/lib.rs b/crates/infra_utils/src/lib.rs
index 2a6f8ed8e6..548702976e 100644
--- a/crates/infra_utils/src/lib.rs
+++ b/crates/infra_utils/src/lib.rs
@@ -2,5 +2,6 @@ pub mod command;
 pub mod metrics;
 pub mod path;
 pub mod run_until;
+pub mod tasks;
 pub mod tracing;
 pub mod type_name;
diff --git a/crates/infra_utils/src/tasks.rs b/crates/infra_utils/src/tasks.rs
new file mode 100644
index 0000000000..0f40bf7e29
--- /dev/null
+++ b/crates/infra_utils/src/tasks.rs
@@ -0,0 +1,63 @@
+use std::future::Future;
+
+use tokio::task::JoinHandle;
+use tracing::error;
+
+#[cfg(test)]
+#[path = "tasks_test.rs"]
+mod tasks_test;
+
+/// Spawns a monitored asynchronous task in Tokio.
+///
+/// This function spawns two tasks:
+/// 1. The first task executes the provided future.
+/// 2. The second task awaits the completion of the first task.
+///    - If the first task completes successfully, then it returns its result.
+///    - If the first task panics, it logs the error and terminates the process with exit code 1.
+///
+/// # Type Parameters
+///
+/// - `F`: The type of the future to be executed. Must implement `Future` and be `Send + 'static`.
+/// - `T`: The output type of the future. Must be `Send + 'static`.
+///
+/// # Arguments
+///
+/// - `future`: The future to be executed by the spawned task.
+///
+/// # Returns
+///
+/// A `JoinHandle` of the second monitoring task.
+pub fn spawn_with_exit_on_panic<F, T>(future: F) -> JoinHandle<T>
+where
+    F: Future<Output = T> + Send + 'static,
+    T: Send + 'static,
+{
+    inner_spawn_with_exit_on_panic(future, exit_process)
+}
+
+// Use an inner function to enable injecting the exit function for testing.
+pub(crate) fn inner_spawn_with_exit_on_panic<F, E, T>(future: F, on_exit_f: E) -> JoinHandle<T>
+where
+    F: Future<Output = T> + Send + 'static,
+    E: FnOnce() + Send + 'static,
+    T: Send + 'static,
+{
+    // Spawn the first task to execute the future
+    let monitored_task = tokio::spawn(future);
+
+    // Spawn the second task to await the first task and assert its completion
+    tokio::spawn(async move {
+        match monitored_task.await {
+            Ok(res) => res,
+            Err(err) => {
+                error!("Monitored task failed: {:?}", err);
+                on_exit_f();
+                unreachable!()
+            }
+        }
+    })
+}
+
+pub(crate) fn exit_process() {
+    std::process::exit(1);
+}
diff --git a/crates/infra_utils/src/tasks_test.rs b/crates/infra_utils/src/tasks_test.rs
new file mode 100644
index 0000000000..25d93437bd
--- /dev/null
+++ b/crates/infra_utils/src/tasks_test.rs
@@ -0,0 +1,48 @@
+use rstest::rstest;
+use tokio::signal::unix::{signal, SignalKind};
+use tokio::time::{sleep, timeout, Duration};
+
+use crate::tasks::{inner_spawn_with_exit_on_panic, spawn_with_exit_on_panic};
+
+#[rstest]
+#[tokio::test]
+async fn test_spawn_with_exit_on_panic_success() {
+    let handle = spawn_with_exit_on_panic(async {
+        sleep(Duration::from_millis(10)).await;
+    });
+
+    // Await the monitoring task
+    handle.await.unwrap();
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_spawn_with_exit_on_panic_failure() {
+    // Mock exit process function: instead of calling `std::process::exit(1)`, send 'SIGTERM' to
+    // self.
+    let mock_exit_process = || {
+        // Use fully-qualified nix modules to avoid ambiguity with the tokio ones.
+        let pid = nix::unistd::getpid();
+        nix::sys::signal::kill(pid, nix::sys::signal::Signal::SIGTERM)
+            .expect("Failed to send signal");
+    };
+
+    // Set up a SIGTERM handler.
+    let mut sigterm = signal(SignalKind::terminate()).expect("Failed to set up SIGTERM handler");
+
+    // Spawn a task that panics, and uses the SIGTERM mocked exit process function.
+    inner_spawn_with_exit_on_panic(
+        async {
+            panic!("This task will fail!");
+        },
+        mock_exit_process,
+    );
+
+    // Assert the task failure is detected and that the mocked exit process function is called by
+    // awaiting for the SIGTERM signal. Bound the timeout to ensure the test does not hang
+    // indefinitely.
+    assert!(
+        timeout(Duration::from_millis(10), sigterm.recv()).await.is_ok(),
+        "Did not receive SIGTERM signal."
+    );
+}