Skip to content

Commit

Permalink
Added global thread pool and scope.
Browse files Browse the repository at this point in the history
  • Loading branch information
dragostis committed Sep 21, 2024
1 parent fc27f2b commit 6eaa3d5
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 50 deletions.
8 changes: 3 additions & 5 deletions benches/overhead.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use chili::{Scope, ThreadPool};
use chili::Scope;
use divan::Bencher;

struct Node {
Expand Down Expand Up @@ -46,8 +46,7 @@ fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
}

let tree = Node::tree(nodes.0);
let thread_pool = ThreadPool::new().unwrap();
let mut scope = thread_pool.scope();
let mut scope = Scope::global();

bencher.bench_local(move || {
assert_eq!(sum(&tree, &mut scope), nodes.1 as u64);
Expand All @@ -66,8 +65,7 @@ fn chili_overhead(bencher: Bencher, nodes: (usize, usize)) {
}

let tree = Node::tree(nodes.0);
let thread_pool = ThreadPool::new().unwrap();
let mut scope = thread_pool.scope();
let mut scope = Scope::global();

bencher.bench_local(move || {
assert_eq!(sum(&tree, &mut scope), nodes.1 as u64);
Expand Down
139 changes: 94 additions & 45 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
//! # Examples
//!
//! ```
//! # use chili::{Scope, ThreadPool};
//! # use chili::Scope;
//! struct Node {
//! val: u64,
//! left: Option<Box<Node>>,
Expand Down Expand Up @@ -44,10 +44,7 @@
//!
//! let tree = Node::tree(10);
//!
//! let mut thread_pool = ThreadPool::new().unwrap();
//! let mut scope = thread_pool.scope();
//!
//! assert_eq!(sum(&tree, &mut scope), 1023);
//! assert_eq!(sum(&tree, &mut Scope::global()), 1023);
//! ```
use std::{
Expand All @@ -58,7 +55,7 @@ use std::{
panic,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Barrier, Condvar, Mutex, Weak,
Arc, Barrier, Condvar, Mutex, OnceLock, Weak,
},
thread::{self, JoinHandle},
time::{Duration, Instant},
Expand Down Expand Up @@ -226,13 +223,13 @@ impl DerefMut for ThreadJobQueue<'_> {
///
/// ```
/// # use chili::ThreadPool;
/// let mut tp = ThreadPool::new().unwrap();
/// let mut tp = ThreadPool::new();
/// let mut s = tp.scope();
///
/// let mut vals = [0; 2];
/// let (left, right) = vals.split_at_mut(1);
///
/// s.join(|_|left[0] = 1, |_| right[0] = 1);
/// s.join(|_| left[0] = 1, |_| right[0] = 1);
///
/// assert_eq!(vals, [1; 2]);
/// ```
Expand All @@ -245,6 +242,18 @@ pub struct Scope<'s> {
}

impl<'s> Scope<'s> {
/// Returns the global scope.
///
/// # Examples
///
/// ```
/// # use chili::Scope;
/// let _s = Scope::global();
/// ```
pub fn global() -> Scope<'static> {
ThreadPool::global().scope()
}

fn new_from_thread_pool(thread_pool: &'s ThreadPool) -> Self {
let heartbeat = thread_pool.context.lock.lock().unwrap().new_heartbeat();
thread_pool
Expand Down Expand Up @@ -421,14 +430,11 @@ impl<'s> Scope<'s> {
/// # Examples
///
/// ```
/// # use chili::ThreadPool;
/// let mut tp = ThreadPool::new().unwrap();
/// let mut s = tp.scope();
///
/// # use chili::Scope;
/// let mut vals = [0; 2];
/// let (left, right) = vals.split_at_mut(1);
///
/// s.join(|_|left[0] = 1, |_| right[0] = 1);
/// Scope::global().join(|_| left[0] = 1, |_| right[0] = 1);
///
/// assert_eq!(vals, [1; 2]);
/// ```
Expand All @@ -451,15 +457,13 @@ impl<'s> Scope<'s> {
/// # Examples
///
/// ```
/// # use chili::ThreadPool;
/// let mut tp = ThreadPool::new().unwrap();
/// let mut s = tp.scope();
/// # use chili::Scope;
///
/// let mut vals = [0; 2];
/// let (left, right) = vals.split_at_mut(1);
///
/// // Skip checking 7/8 calls to join_with_heartbeat_every.
/// s.join_with_heartbeat_every::<8, _, _, _, _>(|_|left[0] = 1, |_| right[0] = 1);
/// Scope::global().join_with_heartbeat_every::<8, _, _, _, _>(|_| left[0] = 1, |_| right[0] = 1);
///
/// assert_eq!(vals, [1; 2]);
/// ```
Expand Down Expand Up @@ -489,7 +493,7 @@ impl<'s> Scope<'s> {
pub struct Config {
/// The number of threads or `None` to use
/// `std::thread::available_parallelism`.
pub thread_count: Option<usize>,
pub thread_count: Option<NonZero<usize>>,
/// The interval between heartbeats on any particular thread.
pub heartbeat_interval: Duration,
}
Expand All @@ -503,6 +507,8 @@ impl Default for Config {
}
}

static GLOBAL_THREAD_POOL: OnceLock<ThreadPool> = OnceLock::new();

/// A thread pool for running fork-join workloads.
#[derive(Debug)]
pub struct ThreadPool {
Expand All @@ -518,9 +524,9 @@ impl ThreadPool {
///
/// ```
/// # use chili::ThreadPool;
/// let _tp = ThreadPool::new().unwrap();
/// let _tp = ThreadPool::new();
/// ```
pub fn new() -> Option<Self> {
pub fn new() -> Self {
Self::with_config(Config::default())
}

Expand All @@ -529,19 +535,19 @@ impl ThreadPool {
/// # Examples
///
/// ```
/// # use std::time::Duration;
/// # use std::{num::NonZero, time::Duration};
/// # use chili::{Config, ThreadPool};
/// let _tp = ThreadPool::with_config(Config {
/// thread_count: Some(1),
/// thread_count: Some(NonZero::new(1).unwrap()),
/// heartbeat_interval: Duration::from_micros(50),
/// }).unwrap();
/// });
/// ```
pub fn with_config(config: Config) -> Option<Self> {
let thread_count = config.thread_count.or_else(|| {
thread::available_parallelism()
.ok()
.map(NonZero::<usize>::get)
})? - 1;
pub fn with_config(config: Config) -> Self {
let thread_count = config
.thread_count
.or_else(|| thread::available_parallelism().ok())
.map(|thread_count| thread_count.get() - 1)
.unwrap_or_default();
let worker_barrier = Arc::new(Barrier::new(thread_count + 1));

let context = Arc::new(Context {
Expand All @@ -562,13 +568,53 @@ impl ThreadPool {

worker_barrier.wait();

Some(Self {
Self {
context: context.clone(),
worker_handles,
heartbeat_handle: Some(thread::spawn(move || {
execute_heartbeat(context, config.heartbeat_interval, thread_count);
})),
})
}
}

/// Sets the global thread pool to this one.
///
/// The global thread pool can only be set once. Any subsequent call will
/// return the thread pool back.
///
/// # Examples
///
/// ```
/// # use std::{num::NonZero, time::Duration};
/// # use chili::{Config, ThreadPool};
/// ThreadPool::with_config(Config {
/// thread_count: Some(NonZero::new(1).unwrap()),
/// heartbeat_interval: Duration::from_micros(50),
/// })
/// .set_global()
/// .unwrap();
/// ```
pub fn set_global(self) -> Result<(), Self> {
GLOBAL_THREAD_POOL.set(self)
}

/// Returns the global thread pool.
///
/// # Examples
///
/// ```
/// # use chili::ThreadPool;
/// let mut s = ThreadPool::global().scope();
///
/// let mut vals = [0; 2];
/// let (left, right) = vals.split_at_mut(1);
///
/// s.join(|_| left[0] = 1, |_| right[0] = 1);
///
/// assert_eq!(vals, [1; 2]);
/// ```
pub fn global() -> &'static ThreadPool {
GLOBAL_THREAD_POOL.get_or_init(ThreadPool::new)
}

/// Returns a `Scope`d object that you can run fork-join workloads on.
Expand All @@ -577,13 +623,13 @@ impl ThreadPool {
///
/// ```
/// # use chili::ThreadPool;
/// let mut tp = ThreadPool::new().unwrap();
/// let mut tp = ThreadPool::new();
/// let mut s = tp.scope();
///
/// let mut vals = [0; 2];
/// let (left, right) = vals.split_at_mut(1);
///
/// s.join(|_|left[0] = 1, |_| right[0] = 1);
/// s.join(|_| left[0] = 1, |_| right[0] = 1);
///
/// assert_eq!(vals, [1; 2]);
/// ```
Expand All @@ -592,6 +638,12 @@ impl ThreadPool {
}
}

impl Default for ThreadPool {
fn default() -> Self {
Self::new()
}
}

impl Drop for ThreadPool {
fn drop(&mut self) {
self.context
Expand Down Expand Up @@ -627,7 +679,7 @@ mod tests {

#[test]
fn join_basic() {
let threat_pool = ThreadPool::new().unwrap();
let threat_pool = ThreadPool::new();
let mut scope = threat_pool.scope();

let mut a = 0;
Expand All @@ -640,7 +692,7 @@ mod tests {

#[test]
fn join_long() {
let threat_pool = ThreadPool::new().unwrap();
let threat_pool = ThreadPool::new();

fn increment(s: &mut Scope, slice: &mut [u32]) {
match slice.len() {
Expand All @@ -663,7 +715,7 @@ mod tests {

#[test]
fn join_very_long() {
let threat_pool = ThreadPool::new().unwrap();
let threat_pool = ThreadPool::new();

fn increment(s: &mut Scope, slice: &mut [u32]) {
match slice.len() {
Expand All @@ -688,11 +740,10 @@ mod tests {
#[test]
fn join_wait() {
let threat_pool = ThreadPool::with_config(Config {
thread_count: Some(2),
thread_count: Some(NonZero::new(2).unwrap()),
heartbeat_interval: Duration::from_micros(1),
..Default::default()
})
.unwrap();
});

fn increment(s: &mut Scope, slice: &mut [u32]) {
match slice.len() {
Expand Down Expand Up @@ -723,10 +774,9 @@ mod tests {
#[should_panic(expected = "panicked across threads")]
fn join_panic() {
let threat_pool = ThreadPool::with_config(Config {
thread_count: Some(2),
thread_count: Some(NonZero::new(2).unwrap()),
heartbeat_interval: Duration::from_micros(1),
})
.unwrap();
});

if let Some(thread_count) = thread::available_parallelism().ok().map(NonZero::get) {
if thread_count == 1 {
Expand Down Expand Up @@ -782,10 +832,9 @@ mod tests {
fn concurrent_scopes() {
const NUM_THREADS: u8 = 128;
let threat_pool = ThreadPool::with_config(Config {
thread_count: Some(4),
thread_count: Some(NonZero::new(4).unwrap()),
..Default::default()
})
.unwrap();
});

let a = AtomicU8::new(0);
let b = AtomicU8::new(0);
Expand Down

0 comments on commit 6eaa3d5

Please sign in to comment.