Skip to content

Commit

Permalink
Implement thread::park and Thread::unpark (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbornholt authored Aug 23, 2022
1 parent 6a369f9 commit bae3598
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 5 deletions.
37 changes: 37 additions & 0 deletions src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub(crate) struct Task {
pub(super) id: TaskId,
pub(super) state: TaskState,
pub(super) detached: bool,
park_state: ParkState,

pub(super) continuation: Rc<RefCell<PooledContinuation>>,

Expand Down Expand Up @@ -79,6 +80,7 @@ impl Task {
waker,
woken: false,
detached: false,
park_state: ParkState::Unavailable,
name,
local_storage: StorageMap::new(),
}
Expand Down Expand Up @@ -245,6 +247,30 @@ impl Task {
pub(crate) fn pop_local(&mut self) -> Option<Box<dyn Any>> {
self.local_storage.pop()
}

/// Park the task if its park token is unavailable. Returns true if the token was unavailable.
pub(crate) fn park(&mut self) -> bool {
match self.park_state {
ParkState::Unparked => {
self.park_state = ParkState::Unavailable;
false
}
ParkState::Unavailable => {
self.park_state = ParkState::Parked;
self.block();
true
}
ParkState::Parked => unreachable!("cannot park a task that's already parked"),
}
}

/// Make the task's park token available, and unblock the task if it was parked.
pub(crate) fn unpark(&mut self) {
if std::mem::replace(&mut self.park_state, ParkState::Unparked) == ParkState::Parked {
assert!(self.blocked());
self.unblock();
}
}
}

#[derive(PartialEq, Eq, Clone, Copy, Debug)]
Expand All @@ -259,6 +285,17 @@ pub(crate) enum TaskState {
Finished,
}

#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub(crate) enum ParkState {
/// The task has parked itself and not yet been unparked, so the park token is unavailable.
/// Invariant: if ParkState is Parked, the task is Blocked
Parked,
/// Another task has unparked this one, so the park token is available.
Unparked,
/// The park token is not available. The task should enter Parked state on the next `park` call.
Unavailable,
}

/// A `TaskId` is a unique identifier for a task. `TaskId`s are never reused within a single
/// execution.
#[derive(PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, Debug)]
Expand Down
32 changes: 31 additions & 1 deletion src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ impl Thread {
pub fn id(&self) -> ThreadId {
self.id
}

/// Atomically makes the handle's token available if it is not already.
pub fn unpark(&self) {
ExecutionState::with(|s| {
s.get_mut(self.id.task_id).unpark();
});

// Making the token available is a yield point
thread::switch();
}
}

/// Spawn a new thread, returning a JoinHandle for it.
Expand Down Expand Up @@ -170,7 +180,27 @@ pub fn current() -> Thread {
}
}

// TODO: Implement park(), unpark()
/// Blocks unless or until the current thread's token is made available.
pub fn park() {
let switch = ExecutionState::with(|s| s.current_mut().park());

// We only need to context switch if the park token was unavailable. If it was available, then
// any execution reachable by context switching here would also be reachable by having not
// chosen this thread at the last context switch, because the park state of a thread is only
// observable by the thread itself.
if switch {
thread::switch();
}
}

/// Blocks unless or until the current thread's token is made available or the specified duration
/// has been reached (may wake spuriously).
///
/// Note that Shuttle does not module time, so this behaves identically to `park`. It cannot
/// spuriously wake.
pub fn park_timeout(_dur: Duration) {
park();
}

/// Thread factory, which can be used in order to configure the properties of a new thread.
#[derive(Debug, Default)]
Expand Down
3 changes: 0 additions & 3 deletions tests/asynch/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// It's convenient to not specify eval order for `await`s in our tests
#![allow(clippy::mixed_read_write_in_expression)]

mod basic;
mod channel;
mod countdown_timer;
Expand Down
119 changes: 118 additions & 1 deletion tests/basic/thread.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use shuttle::sync::Mutex;
use shuttle::sync::{Condvar, Mutex};
use shuttle::{check_dfs, check_random, thread};
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::sync::Arc;
use test_log::test;
Expand Down Expand Up @@ -376,3 +377,119 @@ mod thread_local {
)
}
}

#[test]
fn thread_park() {
check_dfs(
|| {
let flag = Arc::new(AtomicBool::new(false));
let thd = {
let flag = Arc::clone(&flag);
thread::spawn(move || {
thread::park();
assert!(flag.load(Ordering::SeqCst));
})
};

flag.store(true, Ordering::SeqCst);
thd.thread().unpark();
thd.join().unwrap();
},
None,
)
}

#[test]
#[should_panic(expected = "deadlock")]
fn thread_park_deadlock() {
check_dfs(
|| {
thread::park();
},
None,
)
}

// From the docs: "Because the token is initially absent, `unpark` followed by `park` will result in
// the second call returning immediately"
#[test]
fn thread_unpark_park() {
check_dfs(
|| {
thread::current().unpark();
thread::park();
},
None,
)
}

// Unparking a thread should not unconditionally unblock it (e.g., if it's blocked waiting on a lock
// rather than parked)
#[test]
fn thread_unpark_unblock() {
check_dfs(
|| {
let lock = Arc::new(Mutex::new(false));
let condvar = Arc::new(Condvar::new());

let reader = {
let lock = Arc::clone(&lock);
let condvar = Arc::clone(&condvar);
thread::spawn(move || {
let mut guard = lock.lock().unwrap();
while !*guard {
guard = condvar.wait(guard).unwrap();
}
})
};

let _writer = {
let lock = Arc::clone(&lock);
let condvar = Arc::clone(&condvar);
thread::spawn(move || {
let mut guard = lock.lock().unwrap();
*guard = true;
condvar.notify_one();
})
};

reader.thread().unpark();
},
None,
)
}

// Calling `unpark` on a thread that has already been unparked should be a no-op
#[test]
fn thread_double_unpark() {
let seen_unparks = Arc::new(std::sync::Mutex::new(HashSet::new()));
let seen_unparks_clone = Arc::clone(&seen_unparks);

check_dfs(
move || {
let unpark_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let parkee = {
let seen_unparks = Arc::clone(&seen_unparks);
let unpark_count = Arc::clone(&unpark_count);
thread::spawn(move || {
thread::park();
let unpark_count = unpark_count.load(Ordering::SeqCst);
seen_unparks.lock().unwrap().insert(unpark_count);
// If this is 1 we know `unpark` will be uncalled again, so this won't deadlock
if unpark_count == 1 {
thread::park();
}
})
};

unpark_count.fetch_add(1, Ordering::SeqCst);
parkee.thread().unpark();
unpark_count.fetch_add(1, Ordering::SeqCst);
parkee.thread().unpark();
},
None,
);

let seen_unparks = Arc::try_unwrap(seen_unparks_clone).unwrap().into_inner().unwrap();
assert_eq!(seen_unparks, HashSet::from([1, 2]));
}

0 comments on commit bae3598

Please sign in to comment.