From 7121fbd4c883e8132e14299623c568d5061a0e57 Mon Sep 17 00:00:00 2001 From: Jacob Greenfield Date: Fri, 20 Dec 2024 07:16:47 -0500 Subject: [PATCH] Return `Some` for only one thread when calling `remove()` (#1143) --- crossbeam-skiplist/src/base.rs | 6 +++- crossbeam-skiplist/tests/base.rs | 53 ++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/crossbeam-skiplist/src/base.rs b/crossbeam-skiplist/src/base.rs index 2b85a2bb8..1b4495e35 100644 --- a/crossbeam-skiplist/src/base.rs +++ b/crossbeam-skiplist/src/base.rs @@ -1158,8 +1158,12 @@ where break; } } + return Some(entry); + } else { + // The node has already been marked. + n.decrement(guard); + return None; } - return Some(entry); } } } diff --git a/crossbeam-skiplist/tests/base.rs b/crossbeam-skiplist/tests/base.rs index 3f717a6b9..d12b7fde7 100644 --- a/crossbeam-skiplist/tests/base.rs +++ b/crossbeam-skiplist/tests/base.rs @@ -970,3 +970,56 @@ fn comparable_get() { assert_eq!(ent.key().b, 2); assert_eq!(*ent.value(), 12); } + +// https://github.com/crossbeam-rs/crossbeam/pull/1143 +#[cfg(target_has_atomic = "32")] +#[test] +fn remove_race() { + use std::sync::atomic::{AtomicU32, Ordering}; + + let nthreads = 16; + #[cfg(miri)] + let key_range = 100u32; + #[cfg(not(miri))] + let key_range = 100_000u32; + + let guard = &epoch::pin(); + let s = SkipList::new(epoch::default_collector().clone()); + + for x in 0..key_range { + s.insert(x, (), guard).release(guard); + } + + let barrier1 = AtomicU32::new(nthreads); + let barrier2 = AtomicU32::new(nthreads); + let mut total_removed = AtomicU32::new(0); + + std::thread::scope(|scope| { + for _ in 0..nthreads { + scope.spawn(|| { + let guard = &epoch::pin(); + let mut removed_entries = Vec::with_capacity(key_range as usize); + + barrier1.fetch_sub(1, Ordering::Relaxed); + while barrier1.load(Ordering::Acquire) != 0 {} + + for x in 0..key_range { + if let Some(entry) = s.remove(&x, guard) { + removed_entries.push(entry); + } + } + + barrier2.fetch_sub(1, Ordering::Relaxed); + while barrier2.load(Ordering::Acquire) != 0 {} + + total_removed.fetch_add(removed_entries.len() as u32, Ordering::Relaxed); + + for entry in removed_entries.drain(..) { + entry.release(guard); + } + }); + } + }); + + assert_eq!(*total_removed.get_mut(), key_range); +}