Skip to content

Commit

Permalink
feat: Add Mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
scottbot95 committed Mar 14, 2024
1 parent 30aca0d commit f46095f
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 2 deletions.
9 changes: 9 additions & 0 deletions screeps-async/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub mod error;
pub mod job;
pub mod runtime;
pub mod time;
pub mod sync;

use crate::error::RuntimeError;
use crate::job::JobHandle;
Expand Down Expand Up @@ -142,6 +143,7 @@ mod utils {

#[cfg(test)]
mod tests {
use crate::error::RuntimeError;
use crate::runtime::Builder;
use std::cell::RefCell;

Expand All @@ -164,4 +166,11 @@ mod tests {

Builder::new().apply()
}

/// Calls [crate::run] and increments [GAME_TIME] if [crate::run] succeeded
pub(crate) fn tick() -> Result<(), RuntimeError> {
crate::run()?;
GAME_TIME.with_borrow_mut(|t| *t += 1);
Ok(())
}
}
5 changes: 5 additions & 0 deletions screeps-async/src/sync/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//! Synchronization primitives for async contexts
mod mutex;

pub use mutex::*;
202 changes: 202 additions & 0 deletions screeps-async/src/sync/mutex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
use std::cell::UnsafeCell;
use std::collections::VecDeque;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll, Waker};

/// An async mutex
///
/// Locks will be acquired in the order they are requested
///
/// # Examples
/// ```
/// # use std::rc::Rc;
/// # use screeps_async::sync::Mutex;
/// # screeps_async::initialize();
/// let mutex = Rc::new(Mutex::new(0));
/// screeps_async::spawn(async move {
/// let mut val = mutex.lock().await;
/// *val = 1;
/// }).detach();
/// ```
pub struct Mutex<T> {
/// Whether the mutex is currently locked.
state: AtomicBool,
/// Wrapped value
data: UnsafeCell<T>,
/// Queue of futures to wake when a lock is released
wakers: UnsafeCell<VecDeque<Rc<UnsafeCell<Waker>>>>,
}

impl<T> Mutex<T> {
/// Construct a new [Mutex] in the unlocked state wrapping the given value
pub fn new(val: T) -> Self {
Self {
state: AtomicBool::new(false),
data: UnsafeCell::new(val),
wakers: UnsafeCell::new(VecDeque::new()),
}
}

/// Acquire the mutex.
///
/// Returns a guard that release the mutex when dropped
pub async fn lock(&self) -> MutexGuard<'_, T> {
MutexLockFuture::new(self).await
}

/// Try to acquire the mutex.
///
/// If the mutex could not be acquired at this time return [`None`], otherwise
/// returns a guard that will release the mutex when dropped.
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
self.state
.compare_exchange(false, true, Ordering::Acquire, Ordering::Acquire)
.ok()?;
Some(MutexGuard::new(self))
}

/// Consumes the mutex, returning the underlying data
pub fn into_inner(self) -> T {
self.data.into_inner()
}

fn unlock(&self) {
self.state.swap(false, Ordering::Release);

unsafe {
if let Some(waker) = (*self.wakers.get()).pop_front() {
(*waker.get()).wake_by_ref();
}
}
}
}

pub struct MutexGuard<'a, T> {
lock: &'a Mutex<T>,
}

impl<'a, T> MutexGuard<'a, T> {
fn new(lock: &'a Mutex<T>) -> Self {
Self { lock }
}
}

impl<T> Deref for MutexGuard<'_, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
unsafe { &*self.lock.data.get() }
}
}

impl<T> DerefMut for MutexGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.lock.data.get() }
}
}

impl<T> Drop for MutexGuard<'_, T> {
fn drop(&mut self) {
self.lock.unlock();
}
}

pub struct MutexLockFuture<'a, T> {
mutex: &'a Mutex<T>,
wake: Option<Rc<UnsafeCell<Waker>>>,
}

impl<'a, T> MutexLockFuture<'a, T> {
fn new(mutex: &'a Mutex<T>) -> Self {
Self { mutex, wake: None }
}
}

impl<'a, T> Future for MutexLockFuture<'a, T> {
type Output = MutexGuard<'a, T>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(val) = self.mutex.try_lock() {
return Poll::Ready(val);
}

if let Some(waker) = &self.wake {
unsafe {
(*waker.get()).clone_from(cx.waker());
}
} else {
let waker = Rc::new(UnsafeCell::new(cx.waker().clone()));
self.wake = Some(waker.clone());
unsafe {
(*self.mutex.wakers.get()).push_back(waker);
}
}

Poll::Pending
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::time::delay_ticks;

#[test]
fn single_lock() {
crate::tests::init_test();

let mutex = Rc::new(Mutex::new(vec![]));
{
let mutex = mutex.clone();
crate::spawn(async move {
let mut vec = mutex.lock().await;
vec.push(0);
})
.detach();
}

crate::run().unwrap();

let expected = vec![0];
let actual = Rc::into_inner(mutex).unwrap().into_inner();
assert_eq!(expected, actual);
}

#[test]
fn cannot_lock_twice() {
let mutex = Mutex::new(());
let _guard = mutex.try_lock().unwrap();

assert!(mutex.try_lock().is_none());
}

#[test]
fn await_multiple_locks() {
crate::tests::init_test();

let mutex = Rc::new(Mutex::new(vec![]));
const N: u32 = 10;
for i in 0..N {
let mutex = mutex.clone();
crate::spawn(async move {
let mut vec = mutex.lock().await;
// Release the lock next tick to guarantee blocked tasks
delay_ticks(1).await;
vec.push(i);
})
.detach();
}

for _ in 0..=N {
crate::tests::tick().unwrap();
}

let expected = (0..10).collect::<Vec<_>>();
let actual = Rc::into_inner(mutex).unwrap().into_inner();
assert_eq!(expected, actual);
}
}
3 changes: 1 addition & 2 deletions screeps-async/src/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ mod tests {

// Should complete within `dur` ticks (since we have infinite cpu time in this test)
while game_time() <= dur {
crate::run().unwrap();
crate::tests::GAME_TIME.with_borrow_mut(|t| *t += 1);
crate::tests::tick().unwrap()
}

// Future has been run
Expand Down

0 comments on commit f46095f

Please sign in to comment.