From f46095fa3a49f8caef628379101710b35bce6c13 Mon Sep 17 00:00:00 2001 From: Scott Techau Date: Tue, 12 Mar 2024 23:48:19 -0700 Subject: [PATCH] feat: Add `Mutex` --- screeps-async/src/lib.rs | 9 ++ screeps-async/src/sync/mod.rs | 5 + screeps-async/src/sync/mutex.rs | 202 ++++++++++++++++++++++++++++++++ screeps-async/src/time.rs | 3 +- 4 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 screeps-async/src/sync/mod.rs create mode 100644 screeps-async/src/sync/mutex.rs diff --git a/screeps-async/src/lib.rs b/screeps-async/src/lib.rs index c6d78bb..59e861d 100644 --- a/screeps-async/src/lib.rs +++ b/screeps-async/src/lib.rs @@ -41,6 +41,7 @@ pub mod error; pub mod job; pub mod runtime; pub mod time; +pub mod sync; use crate::error::RuntimeError; use crate::job::JobHandle; @@ -142,6 +143,7 @@ mod utils { #[cfg(test)] mod tests { + use crate::error::RuntimeError; use crate::runtime::Builder; use std::cell::RefCell; @@ -164,4 +166,11 @@ mod tests { Builder::new().apply() } + + /// Calls [crate::run] and increments [GAME_TIME] if [crate::run] succeeded + pub(crate) fn tick() -> Result<(), RuntimeError> { + crate::run()?; + GAME_TIME.with_borrow_mut(|t| *t += 1); + Ok(()) + } } diff --git a/screeps-async/src/sync/mod.rs b/screeps-async/src/sync/mod.rs new file mode 100644 index 0000000..57ea2f2 --- /dev/null +++ b/screeps-async/src/sync/mod.rs @@ -0,0 +1,5 @@ +//! Synchronization primitives for async contexts + +mod mutex; + +pub use mutex::*; \ No newline at end of file diff --git a/screeps-async/src/sync/mutex.rs b/screeps-async/src/sync/mutex.rs new file mode 100644 index 0000000..174fc8c --- /dev/null +++ b/screeps-async/src/sync/mutex.rs @@ -0,0 +1,202 @@ +use std::cell::UnsafeCell; +use std::collections::VecDeque; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::rc::Rc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll, Waker}; + +/// An async mutex +/// +/// Locks will be acquired in the order they are requested +/// +/// # Examples +/// ``` +/// # use std::rc::Rc; +/// # use screeps_async::sync::Mutex; +/// # screeps_async::initialize(); +/// let mutex = Rc::new(Mutex::new(0)); +/// screeps_async::spawn(async move { +/// let mut val = mutex.lock().await; +/// *val = 1; +/// }).detach(); +/// ``` +pub struct Mutex { + /// Whether the mutex is currently locked. + state: AtomicBool, + /// Wrapped value + data: UnsafeCell, + /// Queue of futures to wake when a lock is released + wakers: UnsafeCell>>>, +} + +impl Mutex { + /// Construct a new [Mutex] in the unlocked state wrapping the given value + pub fn new(val: T) -> Self { + Self { + state: AtomicBool::new(false), + data: UnsafeCell::new(val), + wakers: UnsafeCell::new(VecDeque::new()), + } + } + + /// Acquire the mutex. + /// + /// Returns a guard that release the mutex when dropped + pub async fn lock(&self) -> MutexGuard<'_, T> { + MutexLockFuture::new(self).await + } + + /// Try to acquire the mutex. + /// + /// If the mutex could not be acquired at this time return [`None`], otherwise + /// returns a guard that will release the mutex when dropped. + pub fn try_lock(&self) -> Option> { + self.state + .compare_exchange(false, true, Ordering::Acquire, Ordering::Acquire) + .ok()?; + Some(MutexGuard::new(self)) + } + + /// Consumes the mutex, returning the underlying data + pub fn into_inner(self) -> T { + self.data.into_inner() + } + + fn unlock(&self) { + self.state.swap(false, Ordering::Release); + + unsafe { + if let Some(waker) = (*self.wakers.get()).pop_front() { + (*waker.get()).wake_by_ref(); + } + } + } +} + +pub struct MutexGuard<'a, T> { + lock: &'a Mutex, +} + +impl<'a, T> MutexGuard<'a, T> { + fn new(lock: &'a Mutex) -> Self { + Self { lock } + } +} + +impl Deref for MutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.data.get() } + } +} + +impl DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.data.get() } + } +} + +impl Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + self.lock.unlock(); + } +} + +pub struct MutexLockFuture<'a, T> { + mutex: &'a Mutex, + wake: Option>>, +} + +impl<'a, T> MutexLockFuture<'a, T> { + fn new(mutex: &'a Mutex) -> Self { + Self { mutex, wake: None } + } +} + +impl<'a, T> Future for MutexLockFuture<'a, T> { + type Output = MutexGuard<'a, T>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(val) = self.mutex.try_lock() { + return Poll::Ready(val); + } + + if let Some(waker) = &self.wake { + unsafe { + (*waker.get()).clone_from(cx.waker()); + } + } else { + let waker = Rc::new(UnsafeCell::new(cx.waker().clone())); + self.wake = Some(waker.clone()); + unsafe { + (*self.mutex.wakers.get()).push_back(waker); + } + } + + Poll::Pending + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::time::delay_ticks; + + #[test] + fn single_lock() { + crate::tests::init_test(); + + let mutex = Rc::new(Mutex::new(vec![])); + { + let mutex = mutex.clone(); + crate::spawn(async move { + let mut vec = mutex.lock().await; + vec.push(0); + }) + .detach(); + } + + crate::run().unwrap(); + + let expected = vec![0]; + let actual = Rc::into_inner(mutex).unwrap().into_inner(); + assert_eq!(expected, actual); + } + + #[test] + fn cannot_lock_twice() { + let mutex = Mutex::new(()); + let _guard = mutex.try_lock().unwrap(); + + assert!(mutex.try_lock().is_none()); + } + + #[test] + fn await_multiple_locks() { + crate::tests::init_test(); + + let mutex = Rc::new(Mutex::new(vec![])); + const N: u32 = 10; + for i in 0..N { + let mutex = mutex.clone(); + crate::spawn(async move { + let mut vec = mutex.lock().await; + // Release the lock next tick to guarantee blocked tasks + delay_ticks(1).await; + vec.push(i); + }) + .detach(); + } + + for _ in 0..=N { + crate::tests::tick().unwrap(); + } + + let expected = (0..10).collect::>(); + let actual = Rc::into_inner(mutex).unwrap().into_inner(); + assert_eq!(expected, actual); + } +} diff --git a/screeps-async/src/time.rs b/screeps-async/src/time.rs index c5eb185..e34dbab 100644 --- a/screeps-async/src/time.rs +++ b/screeps-async/src/time.rs @@ -139,8 +139,7 @@ mod tests { // Should complete within `dur` ticks (since we have infinite cpu time in this test) while game_time() <= dur { - crate::run().unwrap(); - crate::tests::GAME_TIME.with_borrow_mut(|t| *t += 1); + crate::tests::tick().unwrap() } // Future has been run