Skip to content

Commit

Permalink
feat: Add async RwLock
Browse files Browse the repository at this point in the history
  • Loading branch information
scottbot95 committed Mar 18, 2024
1 parent 99f9ac1 commit 8c65e1d
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 4 deletions.
6 changes: 3 additions & 3 deletions screeps-async/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! screeps-async = "0.2.0"
//! ```
//!
//! # The [`#[screeps_async::main]`](screeps_async::main) macro
//! # The [`#[screeps_async::main]`](main) macro
//! ```
//! #[screeps_async::main]
//! pub fn game_loop() {
Expand All @@ -29,13 +29,13 @@
//! println!("Hello!");
//! });
//!
//! screeps_async::run();
//! screeps_async::run().unwrap();
//! }
//! ```
pub mod macros;

pub use macros::*;

use std::cell::RefCell;
pub mod error;
pub mod job;
Expand Down
3 changes: 3 additions & 0 deletions screeps-async/src/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
mod mutex;
pub use mutex::*;

mod rwlock;
pub use rwlock::*;
2 changes: 1 addition & 1 deletion screeps-async/src/sync/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::task::{Context, Poll, Waker};
pub struct Mutex<T> {
/// Whether the mutex is currently locked.
///
/// Use [Cell<bool>] instead of [AtomicBool] since we don't really need atomics
/// Use [`Cell<bool>`] instead of [AtomicBool] since we don't really need atomics
/// and [Cell] is more general
state: Cell<bool>,
/// Wrapped value
Expand Down
315 changes: 315 additions & 0 deletions screeps-async/src/sync/rwlock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
use std::cell::{Ref, RefCell, RefMut, UnsafeCell};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll, Waker};

/// An async RwLock
///
/// Locks will be acquired in the order they are requested. When any task is waiting
/// on a [write](RwLock::write) lock, no new [read](RwLock::read) locks can be acquired
pub struct RwLock<T> {
/// Inner RwLock
inner: RefCell<T>,
/// Queue of futures to wake when a write lock is released
read_wakers: UnsafeCell<Vec<Waker>>,
/// Queue of futures to wake when a read lock is released
write_wakers: UnsafeCell<Vec<Waker>>,
}

impl<T> RwLock<T> {
/// Construct a new [RwLock] wrapping `val`
pub fn new(val: T) -> Self {
Self {
inner: RefCell::new(val),
read_wakers: UnsafeCell::new(Vec::new()),
write_wakers: UnsafeCell::new(Vec::new()),
}
}

/// Block until the wrapped value can be immutably borrowed
pub fn read(&self) -> RwLockFuture<'_, T, RwLockReadGuard<'_, T>> {
RwLockFuture {
lock: self,
borrow: Self::try_read,
is_writer: false,
}
}

/// Attempt to immutably borrow the wrapped value.
///
/// Returns [None] if the value is currently mutably borrowed or
/// a task is waiting on a mutable reference.
pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
unsafe { RwLockReadGuard::new(self) }
}

/// Block until the wrapped value can be mutably borrowed
pub fn write(&self) -> RwLockFuture<'_, T, RwLockWriteGuard<'_, T>> {
RwLockFuture {
lock: self,
borrow: Self::try_write,
is_writer: true,
}
}

/// Attempt to mutably borrow the wrapped value.
///
/// Returns [None] if the value is already borrowed (mutably or immutably)
pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
RwLockWriteGuard::new(self)
}

/// Consumes this [RwLock] and returns ownership of the wrapped value
pub fn into_inner(self) -> T {
self.inner.into_inner()
}

/// Convenience method to consume [`Rc<RwLock<T>>`] and return the wrapped value
///
/// # Panics
/// This method panics if the Rc has more than one strong reference
pub fn into_inner_rc(self: Rc<Self>) -> T {
Rc::into_inner(self).unwrap().into_inner()
}
}

impl<T> RwLock<T> {
unsafe fn unlock(&self) {
let wakers = &mut *self.write_wakers.get();
wakers.drain(..).for_each(Waker::wake);

let wakers = &mut *self.read_wakers.get();
wakers.drain(..).for_each(Waker::wake);
}
}

/// An RAII guard that releases a read lock when dropped
pub struct RwLockReadGuard<'a, T> {
inner: &'a RwLock<T>,
data: Ref<'a, T>,
}

impl<'a, T> RwLockReadGuard<'a, T> {
unsafe fn new(lock: &'a RwLock<T>) -> Option<Self> {
if !(*lock.write_wakers.get()).is_empty() {
return None; // Cannot take new reads if a writer is waiting
}

let data = lock.inner.try_borrow().ok()?;

Some(RwLockReadGuard { data, inner: lock })
}
}

impl<T> Drop for RwLockReadGuard<'_, T> {
fn drop(&mut self) {
unsafe { self.inner.unlock() }
}
}

impl<T> Deref for RwLockReadGuard<'_, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.data
}
}

/// An RAII guard that releases the write lock when dropped
pub struct RwLockWriteGuard<'a, T> {
inner: &'a RwLock<T>,
data: RefMut<'a, T>,
}

impl<'a, T> RwLockWriteGuard<'a, T> {
fn new(lock: &'a RwLock<T>) -> Option<Self> {
let data = lock.inner.try_borrow_mut().ok()?;

Some(Self { inner: lock, data })
}

/// Immediately drop the guard and release the write lock
///
/// Equivalent to [drop(self)], but is more self-documenting
pub fn unlock(self) {
drop(self);
}

/// Release the write lock and immediately yield control back to the async runtime
///
/// This essentially just calls [Self::unlock] then [yield_now()](crate::time::yield_now)
pub async fn unlock_fair(self) {
self.unlock();
crate::time::yield_now().await;
}
}

impl<T> Drop for RwLockWriteGuard<'_, T> {
fn drop(&mut self) {
unsafe { self.inner.unlock() }
}
}

impl<T> Deref for RwLockWriteGuard<'_, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.data
}
}

impl<T> DerefMut for RwLockWriteGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.data
}
}

/// A [Future] that blocks until the [RwLock] can be acquired.
pub struct RwLockFuture<'a, T, G> {
lock: &'a RwLock<T>,
borrow: fn(&'a RwLock<T>) -> Option<G>,
is_writer: bool,
}

impl<T, G> Future for RwLockFuture<'_, T, G> {
type Output = G;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(guard) = (self.borrow)(self.lock) {
return Poll::Ready(guard);
}

let wakers = if self.is_writer {
self.lock.write_wakers.get()
} else {
self.lock.read_wakers.get()
};
let wakers = unsafe { &mut *wakers };

wakers.push(cx.waker().clone());

Poll::Pending
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::time::delay_ticks;

#[test]
fn can_read_multiple_times() {
crate::tests::init_test();

let lock = Rc::new(RwLock::new(()));
const N: usize = 10;
for _ in 0..N {
let lock = lock.clone();
crate::spawn(async move {
let _guard = lock.read().await;
// Lock should acquire first tick
assert_eq!(0, crate::tests::game_time());
// don't release till next tick to check if we can hold multiple read locks at once
delay_ticks(1).await;
})
.detach();
}

for _ in 0..=N {
crate::tests::tick().unwrap();
}
}

#[test]
fn cannot_write_multiple_times() {
crate::tests::init_test();

let lock = Rc::new(RwLock::new(0));
{
let lock = lock.clone();
crate::spawn(async move {
let mut guard = lock.write().await;
assert_eq!(0, crate::tests::game_time());
delay_ticks(1).await;
*guard += 1;
})
.detach();
}
{
let lock = lock.clone();
crate::spawn(async move {
let mut guard = lock.write().await;
assert_eq!(1, crate::tests::game_time());
delay_ticks(1).await;
*guard += 1;
})
.detach();
}

crate::tests::tick().unwrap();
crate::tests::tick().unwrap();
crate::tests::tick().unwrap();

assert_eq!(2, lock.into_inner_rc());
}

#[test]
fn cannot_read_while_writer_waiting() {
crate::tests::init_test();

let lock = Rc::new(RwLock::new(0));
{
let lock = lock.clone();
crate::spawn(async move {
let mut guard = lock.write().await;
println!("write 1 acquired");
assert_eq!(0, crate::tests::game_time());
delay_ticks(1).await;
*guard += 1;
})
.detach();
}
{
let lock = lock.clone();
crate::spawn(async move {
let guard = lock.read().await;
println!("read 1 acquired");
// this should happen after second write
assert_eq!(2, crate::tests::game_time());
delay_ticks(1).await;
assert_eq!(2, *guard);
})
.detach();
}
{
let lock = lock.clone();
crate::spawn(async move {
let mut guard = lock.write().await;
println!("write 2 acquired");
assert_eq!(1, crate::tests::game_time());
delay_ticks(1).await;
*guard += 1;
})
.detach();
}
{
let lock = lock.clone();
crate::spawn(async move {
let guard = lock.read().await;
println!("read 2 acquired");
assert_eq!(2, crate::tests::game_time());
assert_eq!(2, *guard);
})
.detach();
}

crate::tests::tick().unwrap();
crate::tests::tick().unwrap();
crate::tests::tick().unwrap();
crate::tests::tick().unwrap();

assert_eq!(2, lock.into_inner_rc());
}
}

0 comments on commit 8c65e1d

Please sign in to comment.