diff --git a/benches/overhead.rs b/benches/overhead.rs index 7b62830..710ad6d 100644 --- a/benches/overhead.rs +++ b/benches/overhead.rs @@ -1,4 +1,4 @@ -use chili::{Scope, ThreadPool}; +use chili::Scope; use divan::Bencher; struct Node { @@ -46,8 +46,7 @@ fn no_overhead(bencher: Bencher, nodes: (usize, usize)) { } let tree = Node::tree(nodes.0); - let thread_pool = ThreadPool::new().unwrap(); - let mut scope = thread_pool.scope(); + let mut scope = Scope::global(); bencher.bench_local(move || { assert_eq!(sum(&tree, &mut scope), nodes.1 as u64); @@ -66,8 +65,7 @@ fn chili_overhead(bencher: Bencher, nodes: (usize, usize)) { } let tree = Node::tree(nodes.0); - let thread_pool = ThreadPool::new().unwrap(); - let mut scope = thread_pool.scope(); + let mut scope = Scope::global(); bencher.bench_local(move || { assert_eq!(sum(&tree, &mut scope), nodes.1 as u64); diff --git a/src/lib.rs b/src/lib.rs index 3ecc720..c7155f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,7 @@ //! # Examples //! //! ``` -//! # use chili::{Scope, ThreadPool}; +//! # use chili::Scope; //! struct Node { //! val: u64, //! left: Option>, @@ -44,10 +44,7 @@ //! //! let tree = Node::tree(10); //! -//! let mut thread_pool = ThreadPool::new().unwrap(); -//! let mut scope = thread_pool.scope(); -//! -//! assert_eq!(sum(&tree, &mut scope), 1023); +//! assert_eq!(sum(&tree, &mut Scope::global()), 1023); //! ``` use std::{ @@ -58,7 +55,7 @@ use std::{ panic, sync::{ atomic::{AtomicBool, Ordering}, - Arc, Barrier, Condvar, Mutex, Weak, + Arc, Barrier, Condvar, Mutex, OnceLock, Weak, }, thread::{self, JoinHandle}, time::{Duration, Instant}, @@ -226,13 +223,13 @@ impl DerefMut for ThreadJobQueue<'_> { /// /// ``` /// # use chili::ThreadPool; -/// let mut tp = ThreadPool::new().unwrap(); +/// let mut tp = ThreadPool::new(); /// let mut s = tp.scope(); /// /// let mut vals = [0; 2]; /// let (left, right) = vals.split_at_mut(1); /// -/// s.join(|_|left[0] = 1, |_| right[0] = 1); +/// s.join(|_| left[0] = 1, |_| right[0] = 1); /// /// assert_eq!(vals, [1; 2]); /// ``` @@ -245,6 +242,18 @@ pub struct Scope<'s> { } impl<'s> Scope<'s> { + /// Returns the global scope. + /// + /// # Examples + /// + /// ``` + /// # use chili::Scope; + /// let _s = Scope::global(); + /// ``` + pub fn global() -> Scope<'static> { + ThreadPool::global().scope() + } + fn new_from_thread_pool(thread_pool: &'s ThreadPool) -> Self { let heartbeat = thread_pool.context.lock.lock().unwrap().new_heartbeat(); thread_pool @@ -421,14 +430,11 @@ impl<'s> Scope<'s> { /// # Examples /// /// ``` - /// # use chili::ThreadPool; - /// let mut tp = ThreadPool::new().unwrap(); - /// let mut s = tp.scope(); - /// + /// # use chili::Scope; /// let mut vals = [0; 2]; /// let (left, right) = vals.split_at_mut(1); /// - /// s.join(|_|left[0] = 1, |_| right[0] = 1); + /// Scope::global().join(|_| left[0] = 1, |_| right[0] = 1); /// /// assert_eq!(vals, [1; 2]); /// ``` @@ -451,15 +457,13 @@ impl<'s> Scope<'s> { /// # Examples /// /// ``` - /// # use chili::ThreadPool; - /// let mut tp = ThreadPool::new().unwrap(); - /// let mut s = tp.scope(); + /// # use chili::Scope; /// /// let mut vals = [0; 2]; /// let (left, right) = vals.split_at_mut(1); /// /// // Skip checking 7/8 calls to join_with_heartbeat_every. - /// s.join_with_heartbeat_every::<8, _, _, _, _>(|_|left[0] = 1, |_| right[0] = 1); + /// Scope::global().join_with_heartbeat_every::<8, _, _, _, _>(|_| left[0] = 1, |_| right[0] = 1); /// /// assert_eq!(vals, [1; 2]); /// ``` @@ -489,7 +493,7 @@ impl<'s> Scope<'s> { pub struct Config { /// The number of threads or `None` to use /// `std::thread::available_parallelism`. - pub thread_count: Option, + pub thread_count: Option>, /// The interval between heartbeats on any particular thread. pub heartbeat_interval: Duration, } @@ -503,6 +507,8 @@ impl Default for Config { } } +static GLOBAL_THREAD_POOL: OnceLock = OnceLock::new(); + /// A thread pool for running fork-join workloads. #[derive(Debug)] pub struct ThreadPool { @@ -518,9 +524,9 @@ impl ThreadPool { /// /// ``` /// # use chili::ThreadPool; - /// let _tp = ThreadPool::new().unwrap(); + /// let _tp = ThreadPool::new(); /// ``` - pub fn new() -> Option { + pub fn new() -> Self { Self::with_config(Config::default()) } @@ -529,19 +535,19 @@ impl ThreadPool { /// # Examples /// /// ``` - /// # use std::time::Duration; + /// # use std::{num::NonZero, time::Duration}; /// # use chili::{Config, ThreadPool}; /// let _tp = ThreadPool::with_config(Config { - /// thread_count: Some(1), + /// thread_count: Some(NonZero::new(1).unwrap()), /// heartbeat_interval: Duration::from_micros(50), - /// }).unwrap(); + /// }); /// ``` - pub fn with_config(config: Config) -> Option { - let thread_count = config.thread_count.or_else(|| { - thread::available_parallelism() - .ok() - .map(NonZero::::get) - })? - 1; + pub fn with_config(config: Config) -> Self { + let thread_count = config + .thread_count + .or_else(|| thread::available_parallelism().ok()) + .map(|thread_count| thread_count.get() - 1) + .unwrap_or_default(); let worker_barrier = Arc::new(Barrier::new(thread_count + 1)); let context = Arc::new(Context { @@ -562,13 +568,53 @@ impl ThreadPool { worker_barrier.wait(); - Some(Self { + Self { context: context.clone(), worker_handles, heartbeat_handle: Some(thread::spawn(move || { execute_heartbeat(context, config.heartbeat_interval, thread_count); })), - }) + } + } + + /// Sets the global thread pool to this one. + /// + /// The global thread pool can only be set once. Any subsequent call will + /// return the thread pool back. + /// + /// # Examples + /// + /// ``` + /// # use std::{num::NonZero, time::Duration}; + /// # use chili::{Config, ThreadPool}; + /// ThreadPool::with_config(Config { + /// thread_count: Some(NonZero::new(1).unwrap()), + /// heartbeat_interval: Duration::from_micros(50), + /// }) + /// .set_global() + /// .unwrap(); + /// ``` + pub fn set_global(self) -> Result<(), Self> { + GLOBAL_THREAD_POOL.set(self) + } + + /// Returns the global thread pool. + /// + /// # Examples + /// + /// ``` + /// # use chili::ThreadPool; + /// let mut s = ThreadPool::global().scope(); + /// + /// let mut vals = [0; 2]; + /// let (left, right) = vals.split_at_mut(1); + /// + /// s.join(|_| left[0] = 1, |_| right[0] = 1); + /// + /// assert_eq!(vals, [1; 2]); + /// ``` + pub fn global() -> &'static ThreadPool { + GLOBAL_THREAD_POOL.get_or_init(ThreadPool::new) } /// Returns a `Scope`d object that you can run fork-join workloads on. @@ -577,13 +623,13 @@ impl ThreadPool { /// /// ``` /// # use chili::ThreadPool; - /// let mut tp = ThreadPool::new().unwrap(); + /// let mut tp = ThreadPool::new(); /// let mut s = tp.scope(); /// /// let mut vals = [0; 2]; /// let (left, right) = vals.split_at_mut(1); /// - /// s.join(|_|left[0] = 1, |_| right[0] = 1); + /// s.join(|_| left[0] = 1, |_| right[0] = 1); /// /// assert_eq!(vals, [1; 2]); /// ``` @@ -627,7 +673,7 @@ mod tests { #[test] fn join_basic() { - let threat_pool = ThreadPool::new().unwrap(); + let threat_pool = ThreadPool::new(); let mut scope = threat_pool.scope(); let mut a = 0; @@ -640,7 +686,7 @@ mod tests { #[test] fn join_long() { - let threat_pool = ThreadPool::new().unwrap(); + let threat_pool = ThreadPool::new(); fn increment(s: &mut Scope, slice: &mut [u32]) { match slice.len() { @@ -663,7 +709,7 @@ mod tests { #[test] fn join_very_long() { - let threat_pool = ThreadPool::new().unwrap(); + let threat_pool = ThreadPool::new(); fn increment(s: &mut Scope, slice: &mut [u32]) { match slice.len() { @@ -688,11 +734,10 @@ mod tests { #[test] fn join_wait() { let threat_pool = ThreadPool::with_config(Config { - thread_count: Some(2), + thread_count: Some(NonZero::new(2).unwrap()), heartbeat_interval: Duration::from_micros(1), ..Default::default() - }) - .unwrap(); + }); fn increment(s: &mut Scope, slice: &mut [u32]) { match slice.len() { @@ -723,10 +768,9 @@ mod tests { #[should_panic(expected = "panicked across threads")] fn join_panic() { let threat_pool = ThreadPool::with_config(Config { - thread_count: Some(2), + thread_count: Some(NonZero::new(2).unwrap()), heartbeat_interval: Duration::from_micros(1), - }) - .unwrap(); + }); if let Some(thread_count) = thread::available_parallelism().ok().map(NonZero::get) { if thread_count == 1 { @@ -782,10 +826,9 @@ mod tests { fn concurrent_scopes() { const NUM_THREADS: u8 = 128; let threat_pool = ThreadPool::with_config(Config { - thread_count: Some(4), + thread_count: Some(NonZero::new(4).unwrap()), ..Default::default() - }) - .unwrap(); + }); let a = AtomicU8::new(0); let b = AtomicU8::new(0);