Skip to content

Commit

Permalink
Rollup merge of #133406 - EFanZh:lock-value-accessors, r=Noratrieb
Browse files Browse the repository at this point in the history
Add value accessor methods to `Mutex` and `RwLock`

- ACP: rust-lang/libs-team#485.
- Tracking issue: #133407.

This PR adds `get`, `set` and `replace` methods to the `Mutex` and `RwLock` types for quick access to their contained values.

One possible optimization would be to check for poisoning first and return an error immediately, without attempting to acquire the lock. I didn’t implement this because I consider poisoning to be relatively rare, adding this extra check could slow down common use cases.
  • Loading branch information
Zalathar authored Dec 15, 2024
2 parents acdcd3a + 242c6c3 commit 6667908
Show file tree
Hide file tree
Showing 5 changed files with 517 additions and 86 deletions.
110 changes: 104 additions & 6 deletions library/std/src/sync/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ mod tests;
use crate::cell::UnsafeCell;
use crate::fmt;
use crate::marker::PhantomData;
use crate::mem::ManuallyDrop;
use crate::mem::{self, ManuallyDrop};
use crate::ops::{Deref, DerefMut};
use crate::ptr::NonNull;
use crate::sync::{LockResult, TryLockError, TryLockResult, poison};
use crate::sync::{LockResult, PoisonError, TryLockError, TryLockResult, poison};
use crate::sys::sync as sys;

/// A mutual exclusion primitive useful for protecting shared data
Expand Down Expand Up @@ -273,6 +273,100 @@ impl<T> Mutex<T> {
pub const fn new(t: T) -> Mutex<T> {
Mutex { inner: sys::Mutex::new(), poison: poison::Flag::new(), data: UnsafeCell::new(t) }
}

/// Returns the contained value by cloning it.
///
/// # Errors
///
/// If another user of this mutex panicked while holding the mutex, then
/// this call will return an error instead.
///
/// # Examples
///
/// ```
/// #![feature(lock_value_accessors)]
///
/// use std::sync::Mutex;
///
/// let mut mutex = Mutex::new(7);
///
/// assert_eq!(mutex.get_cloned().unwrap(), 7);
/// ```
#[unstable(feature = "lock_value_accessors", issue = "133407")]
pub fn get_cloned(&self) -> Result<T, PoisonError<()>>
where
T: Clone,
{
match self.lock() {
Ok(guard) => Ok((*guard).clone()),
Err(_) => Err(PoisonError::new(())),
}
}

/// Sets the contained value.
///
/// # Errors
///
/// If another user of this mutex panicked while holding the mutex, then
/// this call will return an error containing the provided `value` instead.
///
/// # Examples
///
/// ```
/// #![feature(lock_value_accessors)]
///
/// use std::sync::Mutex;
///
/// let mut mutex = Mutex::new(7);
///
/// assert_eq!(mutex.get_cloned().unwrap(), 7);
/// mutex.set(11).unwrap();
/// assert_eq!(mutex.get_cloned().unwrap(), 11);
/// ```
#[unstable(feature = "lock_value_accessors", issue = "133407")]
pub fn set(&self, value: T) -> Result<(), PoisonError<T>> {
if mem::needs_drop::<T>() {
// If the contained value has non-trivial destructor, we
// call that destructor after the lock being released.
self.replace(value).map(drop)
} else {
match self.lock() {
Ok(mut guard) => {
*guard = value;

Ok(())
}
Err(_) => Err(PoisonError::new(value)),
}
}
}

/// Replaces the contained value with `value`, and returns the old contained value.
///
/// # Errors
///
/// If another user of this mutex panicked while holding the mutex, then
/// this call will return an error containing the provided `value` instead.
///
/// # Examples
///
/// ```
/// #![feature(lock_value_accessors)]
///
/// use std::sync::Mutex;
///
/// let mut mutex = Mutex::new(7);
///
/// assert_eq!(mutex.replace(11).unwrap(), 7);
/// assert_eq!(mutex.get_cloned().unwrap(), 11);
/// ```
#[unstable(feature = "lock_value_accessors", issue = "133407")]
pub fn replace(&self, value: T) -> LockResult<T> {
match self.lock() {
Ok(mut guard) => Ok(mem::replace(&mut *guard, value)),
Err(_) => Err(PoisonError::new(value)),
}
}
}

impl<T: ?Sized> Mutex<T> {
Expand All @@ -290,7 +384,8 @@ impl<T: ?Sized> Mutex<T> {
/// # Errors
///
/// If another user of this mutex panicked while holding the mutex, then
/// this call will return an error once the mutex is acquired.
/// this call will return an error once the mutex is acquired. The acquired
/// mutex guard will be contained in the returned error.
///
/// # Panics
///
Expand Down Expand Up @@ -331,7 +426,8 @@ impl<T: ?Sized> Mutex<T> {
///
/// If another user of this mutex panicked while holding the mutex, then
/// this call will return the [`Poisoned`] error if the mutex would
/// otherwise be acquired.
/// otherwise be acquired. An acquired lock guard will be contained
/// in the returned error.
///
/// If the mutex could not be acquired because it is already locked, then
/// this call will return the [`WouldBlock`] error.
Expand Down Expand Up @@ -438,7 +534,8 @@ impl<T: ?Sized> Mutex<T> {
/// # Errors
///
/// If another user of this mutex panicked while holding the mutex, then
/// this call will return an error instead.
/// this call will return an error containing the the underlying data
/// instead.
///
/// # Examples
///
Expand All @@ -465,7 +562,8 @@ impl<T: ?Sized> Mutex<T> {
/// # Errors
///
/// If another user of this mutex panicked while holding the mutex, then
/// this call will return an error instead.
/// this call will return an error containing a mutable reference to the
/// underlying data instead.
///
/// # Examples
///
Expand Down
161 changes: 138 additions & 23 deletions library/std/src/sync/mutex/tests.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,34 @@
use crate::fmt::Debug;
use crate::ops::FnMut;
use crate::panic::{self, AssertUnwindSafe};
use crate::sync::atomic::{AtomicUsize, Ordering};
use crate::sync::mpsc::channel;
use crate::sync::{Arc, Condvar, MappedMutexGuard, Mutex, MutexGuard, TryLockError};
use crate::thread;
use crate::{hint, mem, thread};

struct Packet<T>(Arc<(Mutex<T>, Condvar)>);

#[derive(Eq, PartialEq, Debug)]
struct NonCopy(i32);

#[derive(Eq, PartialEq, Debug)]
struct NonCopyNeedsDrop(i32);

impl Drop for NonCopyNeedsDrop {
fn drop(&mut self) {
hint::black_box(());
}
}

#[test]
fn test_needs_drop() {
assert!(!mem::needs_drop::<NonCopy>());
assert!(mem::needs_drop::<NonCopyNeedsDrop>());
}

#[derive(Clone, Eq, PartialEq, Debug)]
struct Cloneable(i32);

#[test]
fn smoke() {
let m = Mutex::new(());
Expand Down Expand Up @@ -57,6 +78,21 @@ fn try_lock() {
*m.try_lock().unwrap() = ();
}

fn new_poisoned_mutex<T>(value: T) -> Mutex<T> {
let mutex = Mutex::new(value);

let catch_unwind_result = panic::catch_unwind(AssertUnwindSafe(|| {
let _guard = mutex.lock().unwrap();

panic!("test panic to poison mutex");
}));

assert!(catch_unwind_result.is_err());
assert!(mutex.is_poisoned());

mutex
}

#[test]
fn test_into_inner() {
let m = Mutex::new(NonCopy(10));
Expand All @@ -83,21 +119,31 @@ fn test_into_inner_drop() {

#[test]
fn test_into_inner_poison() {
let m = Arc::new(Mutex::new(NonCopy(10)));
let m2 = m.clone();
let _ = thread::spawn(move || {
let _lock = m2.lock().unwrap();
panic!("test panic in inner thread to poison mutex");
})
.join();
let m = new_poisoned_mutex(NonCopy(10));

assert!(m.is_poisoned());
match Arc::try_unwrap(m).unwrap().into_inner() {
match m.into_inner() {
Err(e) => assert_eq!(e.into_inner(), NonCopy(10)),
Ok(x) => panic!("into_inner of poisoned Mutex is Ok: {x:?}"),
}
}

#[test]
fn test_get_cloned() {
let m = Mutex::new(Cloneable(10));

assert_eq!(m.get_cloned().unwrap(), Cloneable(10));
}

#[test]
fn test_get_cloned_poison() {
let m = new_poisoned_mutex(Cloneable(10));

match m.get_cloned() {
Err(e) => assert_eq!(e.into_inner(), ()),
Ok(x) => panic!("get of poisoned Mutex is Ok: {x:?}"),
}
}

#[test]
fn test_get_mut() {
let mut m = Mutex::new(NonCopy(10));
Expand All @@ -107,21 +153,90 @@ fn test_get_mut() {

#[test]
fn test_get_mut_poison() {
let m = Arc::new(Mutex::new(NonCopy(10)));
let m2 = m.clone();
let _ = thread::spawn(move || {
let _lock = m2.lock().unwrap();
panic!("test panic in inner thread to poison mutex");
})
.join();
let mut m = new_poisoned_mutex(NonCopy(10));

assert!(m.is_poisoned());
match Arc::try_unwrap(m).unwrap().get_mut() {
match m.get_mut() {
Err(e) => assert_eq!(*e.into_inner(), NonCopy(10)),
Ok(x) => panic!("get_mut of poisoned Mutex is Ok: {x:?}"),
}
}

#[test]
fn test_set() {
fn inner<T>(mut init: impl FnMut() -> T, mut value: impl FnMut() -> T)
where
T: Debug + Eq,
{
let m = Mutex::new(init());

assert_eq!(*m.lock().unwrap(), init());
m.set(value()).unwrap();
assert_eq!(*m.lock().unwrap(), value());
}

inner(|| NonCopy(10), || NonCopy(20));
inner(|| NonCopyNeedsDrop(10), || NonCopyNeedsDrop(20));
}

#[test]
fn test_set_poison() {
fn inner<T>(mut init: impl FnMut() -> T, mut value: impl FnMut() -> T)
where
T: Debug + Eq,
{
let m = new_poisoned_mutex(init());

match m.set(value()) {
Err(e) => {
assert_eq!(e.into_inner(), value());
assert_eq!(m.into_inner().unwrap_err().into_inner(), init());
}
Ok(x) => panic!("set of poisoned Mutex is Ok: {x:?}"),
}
}

inner(|| NonCopy(10), || NonCopy(20));
inner(|| NonCopyNeedsDrop(10), || NonCopyNeedsDrop(20));
}

#[test]
fn test_replace() {
fn inner<T>(mut init: impl FnMut() -> T, mut value: impl FnMut() -> T)
where
T: Debug + Eq,
{
let m = Mutex::new(init());

assert_eq!(*m.lock().unwrap(), init());
assert_eq!(m.replace(value()).unwrap(), init());
assert_eq!(*m.lock().unwrap(), value());
}

inner(|| NonCopy(10), || NonCopy(20));
inner(|| NonCopyNeedsDrop(10), || NonCopyNeedsDrop(20));
}

#[test]
fn test_replace_poison() {
fn inner<T>(mut init: impl FnMut() -> T, mut value: impl FnMut() -> T)
where
T: Debug + Eq,
{
let m = new_poisoned_mutex(init());

match m.replace(value()) {
Err(e) => {
assert_eq!(e.into_inner(), value());
assert_eq!(m.into_inner().unwrap_err().into_inner(), init());
}
Ok(x) => panic!("replace of poisoned Mutex is Ok: {x:?}"),
}
}

inner(|| NonCopy(10), || NonCopy(20));
inner(|| NonCopyNeedsDrop(10), || NonCopyNeedsDrop(20));
}

#[test]
fn test_mutex_arc_condvar() {
let packet = Packet(Arc::new((Mutex::new(false), Condvar::new())));
Expand Down Expand Up @@ -269,7 +384,7 @@ fn test_mapping_mapped_guard() {
fn panic_while_mapping_unlocked_poison() {
let lock = Mutex::new(());

let _ = crate::panic::catch_unwind(|| {
let _ = panic::catch_unwind(|| {
let guard = lock.lock().unwrap();
let _guard = MutexGuard::map::<(), _>(guard, |_| panic!());
});
Expand All @@ -282,7 +397,7 @@ fn panic_while_mapping_unlocked_poison() {
Err(TryLockError::Poisoned(_)) => {}
}

let _ = crate::panic::catch_unwind(|| {
let _ = panic::catch_unwind(|| {
let guard = lock.lock().unwrap();
let _guard = MutexGuard::try_map::<(), _>(guard, |_| panic!());
});
Expand All @@ -295,7 +410,7 @@ fn panic_while_mapping_unlocked_poison() {
Err(TryLockError::Poisoned(_)) => {}
}

let _ = crate::panic::catch_unwind(|| {
let _ = panic::catch_unwind(|| {
let guard = lock.lock().unwrap();
let guard = MutexGuard::map::<(), _>(guard, |val| val);
let _guard = MappedMutexGuard::map::<(), _>(guard, |_| panic!());
Expand All @@ -309,7 +424,7 @@ fn panic_while_mapping_unlocked_poison() {
Err(TryLockError::Poisoned(_)) => {}
}

let _ = crate::panic::catch_unwind(|| {
let _ = panic::catch_unwind(|| {
let guard = lock.lock().unwrap();
let guard = MutexGuard::map::<(), _>(guard, |val| val);
let _guard = MappedMutexGuard::try_map::<(), _>(guard, |_| panic!());
Expand Down
Loading

0 comments on commit 6667908

Please sign in to comment.