From 2b86062df8575e50b5629199d437e1cf5acdd930 Mon Sep 17 00:00:00 2001 From: scottbot95 Date: Sun, 17 Mar 2024 21:35:48 -0700 Subject: [PATCH] feat: Add `Mutex` (#4) --- screeps-async/src/lib.rs | 9 ++ screeps-async/src/sync/mod.rs | 4 + screeps-async/src/sync/mutex.rs | 244 ++++++++++++++++++++++++++++++++ screeps-async/src/time.rs | 3 +- 4 files changed, 258 insertions(+), 2 deletions(-) create mode 100644 screeps-async/src/sync/mod.rs create mode 100644 screeps-async/src/sync/mutex.rs diff --git a/screeps-async/src/lib.rs b/screeps-async/src/lib.rs index c6d78bb..4eea581 100644 --- a/screeps-async/src/lib.rs +++ b/screeps-async/src/lib.rs @@ -40,6 +40,7 @@ use std::cell::RefCell; pub mod error; pub mod job; pub mod runtime; +pub mod sync; pub mod time; use crate::error::RuntimeError; @@ -142,6 +143,7 @@ mod utils { #[cfg(test)] mod tests { + use crate::error::RuntimeError; use crate::runtime::Builder; use std::cell::RefCell; @@ -164,4 +166,11 @@ mod tests { Builder::new().apply() } + + /// Calls [crate::run] and increments [GAME_TIME] if [crate::run] succeeded + pub(crate) fn tick() -> Result<(), RuntimeError> { + crate::run()?; + GAME_TIME.with_borrow_mut(|t| *t += 1); + Ok(()) + } } diff --git a/screeps-async/src/sync/mod.rs b/screeps-async/src/sync/mod.rs new file mode 100644 index 0000000..3b27afb --- /dev/null +++ b/screeps-async/src/sync/mod.rs @@ -0,0 +1,4 @@ +//! Synchronization primitives for async contexts + +mod mutex; +pub use mutex::*; diff --git a/screeps-async/src/sync/mutex.rs b/screeps-async/src/sync/mutex.rs new file mode 100644 index 0000000..8ffe070 --- /dev/null +++ b/screeps-async/src/sync/mutex.rs @@ -0,0 +1,244 @@ +use std::cell::{Cell, UnsafeCell}; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; + +/// An async mutex +/// +/// Locks will be acquired in the order they are requested +/// +/// # Examples +/// ``` +/// # use std::rc::Rc; +/// # use screeps_async::sync::Mutex; +/// # screeps_async::initialize(); +/// let mutex = Rc::new(Mutex::new(0)); +/// screeps_async::spawn(async move { +/// let mut val = mutex.lock().await; +/// *val = 1; +/// }).detach(); +/// ``` +pub struct Mutex { + /// Whether the mutex is currently locked. + /// + /// Use [Cell] instead of [AtomicBool] since we don't really need atomics + /// and [Cell] is more general + state: Cell, + /// Wrapped value + data: UnsafeCell, + /// Queue of futures to wake when a lock is released + wakers: UnsafeCell>, +} + +impl Mutex { + /// Construct a new [Mutex] in the unlocked state wrapping the given value + pub fn new(val: T) -> Self { + Self { + state: Cell::new(false), + data: UnsafeCell::new(val), + wakers: UnsafeCell::new(Vec::new()), + } + } + + /// Acquire the mutex. + /// + /// Returns a guard that release the mutex when dropped + pub fn lock(&self) -> MutexLockFuture<'_, T> { + MutexLockFuture::new(self) + } + + /// Try to acquire the mutex. + /// + /// If the mutex could not be acquired at this time return [`None`], otherwise + /// returns a guard that will release the mutex when dropped. + pub fn try_lock(&self) -> Option> { + (!self.state.replace(true)).then(|| MutexGuard::new(self)) + } + + /// Consumes the mutex, returning the underlying data + pub fn into_inner(self) -> T { + self.data.into_inner() + } + + fn unlock(&self) { + self.state.set(false); + let wakers = unsafe { &mut *self.wakers.get() }; + wakers.drain(..).for_each(Waker::wake); + } +} + +/// An RAII guard that releases the mutex when dropped +pub struct MutexGuard<'a, T> { + lock: &'a Mutex, +} + +impl<'a, T> MutexGuard<'a, T> { + fn new(lock: &'a Mutex) -> Self { + Self { lock } + } + + /// Immediately drops the guard, and consequently unlocks the mutex. + /// + /// This function is equivalent to calling [`drop`] on the guard but is more self-documenting. + pub fn unlock(self) { + drop(self); + } + + /// Release the lock and immediately yield control back to the async runtime + /// + /// This essentially just calls [Self::unlock] then [yield_now()](crate::time::yield_now) + pub async fn unlock_fair(self) { + self.unlock(); + crate::time::yield_now().await; + } +} + +impl Deref for MutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.data.get() } + } +} + +impl DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.data.get() } + } +} + +impl Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + self.lock.unlock(); + } +} + +/// A [Future] that blocks until the [Mutex] can be locked, then returns the [MutexGuard] +pub struct MutexLockFuture<'a, T> { + mutex: &'a Mutex, +} + +impl<'a, T> MutexLockFuture<'a, T> { + fn new(mutex: &'a Mutex) -> Self { + Self { mutex } + } +} + +impl<'a, T> Future for MutexLockFuture<'a, T> { + type Output = MutexGuard<'a, T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(val) = self.mutex.try_lock() { + return Poll::Ready(val); + } + + unsafe { + (*self.mutex.wakers.get()).push(cx.waker().clone()); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::time::delay_ticks; + use std::rc::Rc; + + #[test] + fn single_lock() { + crate::tests::init_test(); + + let mutex = Rc::new(Mutex::new(vec![])); + { + let mutex = mutex.clone(); + crate::spawn(async move { + let mut vec = mutex.lock().await; + vec.push(0); + }) + .detach(); + } + + crate::run().unwrap(); + + let expected = vec![0]; + let actual = Rc::into_inner(mutex).unwrap().into_inner(); + assert_eq!(expected, actual); + } + + #[test] + fn cannot_lock_twice() { + let mutex = Mutex::new(()); + let _guard = mutex.try_lock().unwrap(); + + assert!(mutex.try_lock().is_none()); + } + + #[test] + fn await_multiple_locks() { + crate::tests::init_test(); + + let mutex = Rc::new(Mutex::new(vec![])); + const N: u32 = 10; + for i in 0..N { + let mutex = mutex.clone(); + crate::spawn(async move { + let mut vec = mutex.lock().await; + // Release the lock next tick to guarantee blocked tasks + delay_ticks(1).await; + vec.push(i); + }) + .detach(); + } + + for _ in 0..=N { + crate::tests::tick().unwrap(); + } + + let expected = (0..10).collect::>(); + let actual = Rc::into_inner(mutex).unwrap().into_inner(); + assert_eq!(expected, actual); + } + + #[test] + fn handles_dropped_futures() { + crate::tests::init_test(); + + let mutex = Rc::new(Mutex::new(vec![])); + { + let mutex = mutex.clone(); + crate::spawn(async move { + let mut _guard = mutex.lock().await; + delay_ticks(1).await; + _guard.push(0); + }) + .detach(); + } + let to_drop = { + let mutex = mutex.clone(); + crate::spawn(async move { + let mut _guard = mutex.lock().await; + _guard.push(1); + }) + }; + { + let mutex = mutex.clone(); + crate::spawn(async move { + let mut _guard = mutex.lock().await; + _guard.push(2); + }) + .detach(); + } + + crate::tests::tick().unwrap(); + drop(to_drop); + crate::tests::tick().unwrap(); + + let expected = vec![0, 2]; + let actual = Rc::into_inner(mutex).unwrap().into_inner(); + + assert_eq!(expected, actual); + } +} diff --git a/screeps-async/src/time.rs b/screeps-async/src/time.rs index c5eb185..e34dbab 100644 --- a/screeps-async/src/time.rs +++ b/screeps-async/src/time.rs @@ -139,8 +139,7 @@ mod tests { // Should complete within `dur` ticks (since we have infinite cpu time in this test) while game_time() <= dur { - crate::run().unwrap(); - crate::tests::GAME_TIME.with_borrow_mut(|t| *t += 1); + crate::tests::tick().unwrap() } // Future has been run