From 0cfc09ff92767ed967320e5dcbc3de45f9e8ff1b Mon Sep 17 00:00:00 2001 From: MrCroxx Date: Mon, 28 Oct 2024 06:49:55 +0000 Subject: [PATCH] fix: fix atomic reference count and reclamation with CAS Signed-off-by: MrCroxx --- foyer-memory-v2/src/cache.rs | 2 +- foyer-memory-v2/src/raw.rs | 54 +++++++------ foyer-memory-v2/src/record.rs | 147 ++++++++++++++++++++++++++++++++-- 3 files changed, 171 insertions(+), 32 deletions(-) diff --git a/foyer-memory-v2/src/cache.rs b/foyer-memory-v2/src/cache.rs index cae672b5..e578260f 100644 --- a/foyer-memory-v2/src/cache.rs +++ b/foyer-memory-v2/src/cache.rs @@ -211,7 +211,7 @@ where } /// External reference count of the cached entry. - pub fn refs(&self) -> usize { + pub fn refs(&self) -> isize { match self { CacheEntry::Fifo(entry) => entry.refs(), CacheEntry::Lru(entry) => entry.refs(), diff --git a/foyer-memory-v2/src/raw.rs b/foyer-memory-v2/src/raw.rs index 54c4edeb..b78340a2 100644 --- a/foyer-memory-v2/src/raw.rs +++ b/foyer-memory-v2/src/raw.rs @@ -69,7 +69,7 @@ use std::{ ops::Deref, pin::Pin, ptr::NonNull, - sync::{atomic::Ordering, Arc}, + sync::Arc, task::{Context, Poll}, }; @@ -85,7 +85,7 @@ use foyer_common::{ metrics::Metrics, runtime::SingletonHandle, scope::Scope, - strict_assert, strict_assert_eq, + strict_assert, }; use itertools::Itertools; use parking_lot::Mutex; @@ -177,8 +177,10 @@ where self.metrics.memory_evict.increment(1); strict_assert!(unsafe { evicted.as_ref().is_in_indexer() }); strict_assert!(unsafe { !evicted.as_ref().is_in_eviction() }); - if unsafe { evicted.as_ref().refs().load(Ordering::SeqCst) } == 0 { - if let Some(garbage) = self.release(evicted, false) { + + // Try to free the record if this thread get the permission. + if unsafe { evicted.as_ref() }.need_reclaim() { + if let Some(garbage) = self.reclaim(evicted, false) { garbages.push(garbage); } } @@ -194,8 +196,9 @@ where } strict_assert!(!unsafe { old.as_ref() }.is_in_eviction()); // Because the `old` handle is removed from the indexer, it will not be reinserted again. - if unsafe { old.as_ref().refs().load(Ordering::SeqCst) } == 0 { - if let Some(garbage) = self.release(old, false) { + // Try to free the record if this thread get the permission. + if unsafe { old.as_ref() }.need_reclaim() { + if let Some(garbage) = self.reclaim(old, false) { garbages.push(garbage); } } @@ -213,18 +216,20 @@ where self.usage += weight; self.metrics.memory_usage.increment(weight as f64); // Increase the reference count within the lock section. - unsafe { ptr.as_ref().refs().fetch_add(waiters.len() + 1, Ordering::SeqCst) }; + // The reference count of the new record must be at the moment. + let refs = waiters.len() as isize + 1; + let inc = unsafe { ptr.as_ref() }.inc_refs(refs); + assert_eq!(refs, inc); ptr } #[fastrace::trace(name = "foyer::memory::raw::shard::release")] - fn release(&mut self, mut ptr: NonNull>, reinsert: bool) -> Option> { + fn reclaim(&mut self, mut ptr: NonNull>, reinsert: bool) -> Option> { let record = unsafe { ptr.as_mut() }; - if record.refs().load(Ordering::SeqCst) > 0 { - return None; - } + // Assert the record is in the reclamation phase. + assert_eq!(record.refs(), -1); if record.is_in_indexer() && record.is_ephemeral() { // The entry is ephemeral, remove it from indexer. Ignore reinsertion. @@ -263,7 +268,6 @@ where // Here the handle is neither in the indexer nor in the eviction container. strict_assert!(!record.is_in_indexer()); strict_assert!(!record.is_in_eviction()); - strict_assert_eq!(record.refs().load(Ordering::SeqCst), 0); self.metrics.memory_release.increment(1); self.usage -= record.weight(); @@ -290,13 +294,13 @@ where if record.is_in_eviction() { self.eviction.remove(ptr); } - record.refs().fetch_add(1, Ordering::SeqCst); - strict_assert!(!record.is_in_indexer()); strict_assert!(!record.is_in_eviction()); self.metrics.memory_remove.increment(1); + record.inc_refs_cas(1)?; + Some(ptr) } @@ -337,7 +341,8 @@ where strict_assert!(record.is_in_indexer()); record.set_ephemeral(false); - record.refs().fetch_add(1, Ordering::SeqCst); + + record.inc_refs_cas(1)?; Some(ptr) } @@ -363,8 +368,8 @@ where count += 1; strict_assert!(unsafe { !ptr.as_ref().is_in_indexer() }); strict_assert!(unsafe { !ptr.as_ref().is_in_eviction() }); - if unsafe { ptr.as_ref().refs().load(Ordering::SeqCst) } == 0 { - if let Some(garbage) = self.release(ptr, false) { + if unsafe { ptr.as_ref() }.need_reclaim() { + if let Some(garbage) = self.reclaim(ptr, false) { garbages.push(garbage); } } @@ -737,12 +742,14 @@ where I: Indexer, { fn drop(&mut self) { - if unsafe { self.ptr.as_ref() }.refs().fetch_sub(1, Ordering::SeqCst) == 1 { - let hash = unsafe { self.ptr.as_ref() }.hash(); + let record = unsafe { self.ptr.as_ref() }; + + if record.dec_ref_cas() == -1 { + let hash = record.hash(); let shard = hash as usize % self.inner.shards.len(); let garbage = self.inner.shards[shard] .write() - .with(|mut shard| shard.release(self.ptr, true)); + .with(|mut shard| shard.reclaim(self.ptr, true)); // Do not deallocate data within the lock section. if let Some(listener) = self.inner.event_listener.as_ref() { if let Some(Data { key, value, .. }) = garbage { @@ -760,7 +767,8 @@ where I: Indexer, { fn clone(&self) -> Self { - unsafe { self.ptr.as_ref() }.refs().fetch_add(1, Ordering::SeqCst); + let old = unsafe { self.ptr.as_ref() }.inc_refs(1); + assert!(old > 0); Self { inner: self.inner.clone(), ptr: self.ptr, @@ -823,8 +831,8 @@ where unsafe { self.ptr.as_ref() }.weight() } - pub fn refs(&self) -> usize { - unsafe { self.ptr.as_ref() }.refs().load(Ordering::SeqCst) + pub fn refs(&self) -> isize { + unsafe { self.ptr.as_ref() }.refs() } pub fn is_outdated(&self) -> bool { diff --git a/foyer-memory-v2/src/record.rs b/foyer-memory-v2/src/record.rs index c53882d9..ae0abafb 100644 --- a/foyer-memory-v2/src/record.rs +++ b/foyer-memory-v2/src/record.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicIsize, AtomicU64, Ordering}; use bitflags::bitflags; @@ -69,7 +69,7 @@ where pub(crate) state: E::State, hash: u64, weight: usize, - refs: AtomicUsize, + refs: AtomicIsize, flags: AtomicU64, token: Option, } @@ -87,7 +87,7 @@ where state: data.state, hash: data.hash, weight: data.weight, - refs: AtomicUsize::new(0), + refs: AtomicIsize::new(0), flags: AtomicU64::new(0), // Temporarily set to None, update after inserted into slab. token: None, @@ -160,11 +160,6 @@ where self.weight } - /// Get the record atomic refs. - pub fn refs(&self) -> &AtomicUsize { - &self.refs - } - /// Set in eviction flag with relaxed memory order. pub fn set_in_eviction(&self, val: bool) { self.set_flags(Flags::IN_EVICTION, val, Ordering::Release); @@ -207,4 +202,140 @@ where pub fn get_flags(&self, flags: Flags, order: Ordering) -> bool { self.flags.load(order) & flags.bits() == flags.bits() } + + /// Get the atomic reference count. + /// + /// Return a non-negative value when the record is alive, + /// otherwise, return -1 that implies the record is in the reclamation phase. + pub fn refs(&self) -> isize { + self.refs.load(Ordering::Acquire) + } + + /// Increase the atomic reference count. + /// + /// This function returns the new reference count after the op. + pub fn inc_refs(&self, val: isize) -> isize { + let old = self.refs.fetch_add(val, Ordering::SeqCst); + tracing::trace!( + "[record]: inc record (hash: {}) refs: {} => {}", + self.hash, + old, + old + val + ); + old + val + } + + // /// Decrease the atomic reference count. + // /// + // /// This function returns the new reference count after the op. + // pub fn dec_refs(&self, val: isize) -> isize { + // let old = self.refs.fetch_sub(val, Ordering::SeqCst); + // tracing::trace!( + // "[record]: dec record (hash: {}) refs: {} => {}", + // self.hash, + // old, + // old - val + // ); + // old - val + // } + + /// Increase the atomic reference count with a cas operation, + /// to prevent from increasing the record in the reclamation phase. + /// + /// This function returns the new reference count after the op if the record is not in the reclamation phase. + pub fn inc_refs_cas(&self, val: isize) -> Option { + let mut current = self.refs.load(Ordering::Relaxed); + loop { + if current == -1 { + tracing::trace!( + "[record]: inc record (hash: {}) refs (cas) skipped for it is in reclamation phase", + self.hash + ); + return None; + } + match self + .refs + .compare_exchange(current, current + val, Ordering::SeqCst, Ordering::Acquire) + { + Err(cur) => current = cur, + Ok(_) => { + tracing::trace!( + "[record]: inc record (hash: {}) refs (cas): {} => {}", + self.hash, + current, + current + val + ); + return Some(current + val); + } + } + } + } + + /// Decrease the atomic reference count by 1 with a cas operation. + /// + /// If the refs hits 0 after decreasing, get the permission to reclaim the record. + /// + /// This function returns the new reference count after the op if the record is not in the reclamation phase. + pub fn dec_ref_cas(&self) -> isize { + let mut current = self.refs.load(Ordering::Relaxed); + loop { + match current { + 1 => match self.refs.compare_exchange(1, -1, Ordering::SeqCst, Ordering::Acquire) { + Ok(_) => { + tracing::trace!( + "[record]: dec record (hash: {}) refs from 1 and got reclamation permission", + self.hash + ); + return -1; + } + Err(cur) => current = cur, + }, + c => match self + .refs + .compare_exchange(c, c - 1, Ordering::SeqCst, Ordering::Acquire) + { + Ok(_) => { + tracing::trace!("[record]: dec record (hash: {}) refs: {} => {}", self.hash, c, c - 1); + return c - 1; + } + Err(cur) => current = cur, + }, + } + } + } + + /// Try to acquire the permission to reclaim the record. + /// + /// If `true` is returned, the caller MUST reclaim the record. + pub fn need_reclaim(&self) -> bool { + let current = self.refs.load(Ordering::Acquire); + if current != 0 { + tracing::trace!( + "[record]: check if record (hash: {}) needs reclamation: {} with refs {}", + self.hash, + false, + current + ); + return false; + } + self.refs + .compare_exchange(0, -1, Ordering::SeqCst, Ordering::Acquire) + .inspect(|refs| { + tracing::trace!( + "[record]: check if record (hash: {}) needs reclamation: {} with refs {}", + self.hash, + true, + refs + ) + }) + .inspect_err(|refs| { + tracing::trace!( + "[record]: check if record (hash: {}) needs reclamation: {} with refs {}", + self.hash, + false, + refs + ) + }) + .is_ok() + } }