diff --git a/foyer-memory/src/cache.rs b/foyer-memory/src/cache.rs index bdb23b3a..79f3f7f3 100644 --- a/foyer-memory/src/cache.rs +++ b/foyer-memory/src/cache.rs @@ -33,7 +33,7 @@ use crate::{ s3fifo::{S3Fifo, S3FifoHandle}, sanity::SanityEviction, }, - generic::{GenericCache, GenericCacheConfig, GenericCacheEntry, GenericFetch, Weighter}, + generic::{FetchState, GenericCache, GenericCacheConfig, GenericCacheEntry, GenericFetch, Weighter}, indexer::{hash_table::HashTableIndexer, sanity::SanityIndexer}, FifoConfig, LfuConfig, LruConfig, S3FifoConfig, }; @@ -42,22 +42,22 @@ pub type FifoCache = GenericCache>, SanityIndexer>>, S>; pub type FifoCacheEntry = GenericCacheEntry>, SanityIndexer>>, S>; -pub type FifoFetch = - GenericFetch>, SanityIndexer>>, S, ER>; +pub type FifoFetch = + GenericFetch>, SanityIndexer>>, S, ER, T>; pub type LruCache = GenericCache>, SanityIndexer>>, S>; pub type LruCacheEntry = GenericCacheEntry>, SanityIndexer>>, S>; -pub type LruFetch = - GenericFetch>, SanityIndexer>>, S, ER>; +pub type LruFetch = + GenericFetch>, SanityIndexer>>, S, ER, T>; pub type LfuCache = GenericCache>, SanityIndexer>>, S>; pub type LfuCacheEntry = GenericCacheEntry>, SanityIndexer>>, S>; -pub type LfuFetch = - GenericFetch>, SanityIndexer>>, S, ER>; +pub type LfuFetch = + GenericFetch>, SanityIndexer>>, S, ER, T>; pub type S3FifoCache = GenericCache>, SanityIndexer>>, S>; @@ -68,8 +68,15 @@ pub type S3FifoCacheEntry = GenericCacheEntry< SanityIndexer>>, S, >; -pub type S3FifoFetch = - GenericFetch>, SanityIndexer>>, S, ER>; +pub type S3FifoFetch = GenericFetch< + K, + V, + SanityEviction>, + SanityIndexer>>, + S, + ER, + T, +>; /// A cached entry holder of the in-memory cache. #[derive(Debug)] @@ -657,72 +664,73 @@ where } /// A future that is used to get entry value from the remote storage for the in-memory cache. -pub enum Fetch +pub enum Fetch where K: Key, V: Value, S: HashBuilder, { /// A future that is used to get entry value from the remote storage for the in-memory FIFO cache. - Fifo(FifoFetch), + Fifo(FifoFetch), /// A future that is used to get entry value from the remote storage for the in-memory LRU cache. - Lru(LruFetch), + Lru(LruFetch), /// A future that is used to get entry value from the remote storage for the in-memory LFU cache. - Lfu(LfuFetch), + Lfu(LfuFetch), /// A future that is used to get entry value from the remote storage for the in-memory S3FIFO cache. - S3Fifo(S3FifoFetch), + S3Fifo(S3FifoFetch), } -impl From> for Fetch +impl From> for Fetch where K: Key, V: Value, S: HashBuilder, { - fn from(entry: FifoFetch) -> Self { + fn from(entry: FifoFetch) -> Self { Self::Fifo(entry) } } -impl From> for Fetch +impl From> for Fetch where K: Key, V: Value, S: HashBuilder, { - fn from(entry: LruFetch) -> Self { + fn from(entry: LruFetch) -> Self { Self::Lru(entry) } } -impl From> for Fetch +impl From> for Fetch where K: Key, V: Value, S: HashBuilder, { - fn from(entry: LfuFetch) -> Self { + fn from(entry: LfuFetch) -> Self { Self::Lfu(entry) } } -impl From> for Fetch +impl From> for Fetch where K: Key, V: Value, S: HashBuilder, { - fn from(entry: S3FifoFetch) -> Self { + fn from(entry: S3FifoFetch) -> Self { Self::S3Fifo(entry) } } -impl Future for Fetch +impl Future for Fetch where K: Key, V: Value, ER: From, S: HashBuilder, + T: Default, { type Output = std::result::Result, ER>; @@ -736,18 +744,7 @@ where } } -/// The state of [`Fetch`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum FetchState { - /// Cache hit. - Hit, - /// Cache miss, but wait in queue. - Wait, - /// Cache miss, and there is no other waiters at the moment. - Miss, -} - -impl Fetch +impl Fetch where K: Key, V: Value, @@ -756,18 +753,21 @@ where /// Get the fetch state. pub fn state(&self) -> FetchState { match self { - Fetch::Fifo(FifoFetch::Hit(_)) - | Fetch::Lru(LruFetch::Hit(_)) - | Fetch::Lfu(LfuFetch::Hit(_)) - | Fetch::S3Fifo(S3FifoFetch::Hit(_)) => FetchState::Hit, - Fetch::Fifo(FifoFetch::Wait(_)) - | Fetch::Lru(LruFetch::Wait(_)) - | Fetch::Lfu(LfuFetch::Wait(_)) - | Fetch::S3Fifo(S3FifoFetch::Wait(_)) => FetchState::Wait, - Fetch::Fifo(FifoFetch::Miss(_)) - | Fetch::Lru(LruFetch::Miss(_)) - | Fetch::Lfu(LfuFetch::Miss(_)) - | Fetch::S3Fifo(S3FifoFetch::Miss(_)) => FetchState::Miss, + Fetch::Fifo(fetch) => fetch.state(), + Fetch::Lru(fetch) => fetch.state(), + Fetch::Lfu(fetch) => fetch.state(), + Fetch::S3Fifo(fetch) => fetch.state(), + } + } + + /// Get the ext of the fetch. + #[doc(hidden)] + pub fn ext(&self) -> &T { + match self { + Fetch::Fifo(fetch) => fetch.ext(), + Fetch::Lru(fetch) => fetch.ext(), + Fetch::Lfu(fetch) => fetch.ext(), + Fetch::S3Fifo(fetch) => fetch.ext(), } } } @@ -826,17 +826,18 @@ where /// /// The concurrent fetch requests will be deduplicated. #[doc(hidden)] - pub fn fetch_inner( + pub fn fetch_inner( &self, key: K, context: CacheContext, fetch: F, runtime: &tokio::runtime::Handle, - ) -> Fetch + ) -> Fetch where F: FnOnce() -> FU, - FU: Future> + Send + 'static, + FU: Future> + Send + 'static, ER: Send + 'static + Debug, + T: Default + Send + Sync + 'static, { match self { Cache::Fifo(cache) => Fetch::from(cache.fetch_inner(key, context, fetch, runtime)), diff --git a/foyer-memory/src/generic.rs b/foyer-memory/src/generic.rs index 416bd2a9..47a94609 100644 --- a/foyer-memory/src/generic.rs +++ b/foyer-memory/src/generic.rs @@ -24,7 +24,7 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; use ahash::RandomState; @@ -35,6 +35,7 @@ use foyer_common::{ object_pool::ObjectPool, strict_assert, strict_assert_eq, }; +use futures::FutureExt; use hashbrown::hash_map::{Entry as HashMapEntry, HashMap}; use itertools::Itertools; use parking_lot::{lock_api::MutexGuard, Mutex, RawMutex}; @@ -396,10 +397,99 @@ where type GenericFetchHit = Option>; type GenericFetchWait = InSpan>>; -type GenericFetchMiss = JoinHandle, ER>>; +type GenericFetchMiss = + JoinHandle, T), ER>>; + +/// The state of [`Fetch`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FetchState { + /// Cache hit. + Hit, + /// Cache miss, but wait in queue. + Wait, + /// Cache miss, and there is no other waiters at the moment. + Miss, +} + +#[pin_project] +pub struct GenericFetch +where + K: Key, + V: Value, + E: Eviction, + E::Handle: KeyedHandle, + I: Indexer, + S: HashBuilder, +{ + #[pin] + inner: GenericFetchInner, + + ext: T, +} + +impl GenericFetch +where + K: Key, + V: Value, + E: Eviction, + E::Handle: KeyedHandle, + I: Indexer, + S: HashBuilder, +{ + fn new(inner: GenericFetchInner) -> Self + where + T: Default, + { + Self { + inner, + ext: T::default(), + } + } + + pub fn state(&self) -> FetchState { + match self.inner { + GenericFetchInner::Hit(_) => FetchState::Hit, + GenericFetchInner::Wait(_) => FetchState::Wait, + GenericFetchInner::Miss(_) => FetchState::Miss, + } + } + + pub fn ext(&self) -> &T { + &self.ext + } +} + +impl Future for GenericFetch +where + K: Key, + V: Value, + E: Eviction, + E::Handle: KeyedHandle, + I: Indexer, + S: HashBuilder, + ER: From, + T: Default, +{ + type Output = std::result::Result, ER>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let res = ready!(this.inner.poll(cx)); + + let res = match res { + Ok((entry, ext)) => { + *this.ext = ext; + Ok(entry) + } + Err(e) => Err(e), + }; + + Poll::Ready(res) + } +} #[pin_project(project = GenericFetchInnerProj)] -pub enum GenericFetch +enum GenericFetchInner where K: Key, V: Value, @@ -410,10 +500,10 @@ where { Hit(GenericFetchHit), Wait(#[pin] GenericFetchWait), - Miss(#[pin] GenericFetchMiss), + Miss(#[pin] GenericFetchMiss), } -impl Future for GenericFetch +impl Future for GenericFetchInner where K: Key, V: Value, @@ -422,13 +512,17 @@ where I: Indexer, S: HashBuilder, ER: From, + T: Default, { - type Output = std::result::Result, ER>; + type Output = std::result::Result<(GenericCacheEntry, T), ER>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project() { - GenericFetchInnerProj::Hit(opt) => Poll::Ready(Ok(opt.take().unwrap())), - GenericFetchInnerProj::Wait(waiter) => waiter.poll(cx).map_err(|err| err.into()), + GenericFetchInnerProj::Hit(opt) => Poll::Ready(Ok((opt.take().unwrap(), T::default()))), + GenericFetchInnerProj::Wait(waiter) => waiter + .poll(cx) + .map(|res| res.map(|v| (v, T::default()))) + .map_err(|err| err.into()), GenericFetchInnerProj::Miss(handle) => handle.poll(cx).map(|join| join.unwrap()), } } @@ -711,7 +805,12 @@ where FU: Future> + Send + 'static, ER: Send + 'static + Debug, { - self.fetch_inner(key, CacheContext::default(), fetch, &tokio::runtime::Handle::current()) + self.fetch_inner( + key, + CacheContext::default(), + || fetch().map(|res| res.map(|v| (v, ()))), + &tokio::runtime::Handle::current(), + ) } pub fn fetch_with_context( @@ -725,20 +824,26 @@ where FU: Future> + Send + 'static, ER: Send + 'static + Debug, { - self.fetch_inner(key, context, fetch, &tokio::runtime::Handle::current()) + self.fetch_inner( + key, + context, + || fetch().map(|res| res.map(|v| (v, ()))), + &tokio::runtime::Handle::current(), + ) } - pub fn fetch_inner( + pub fn fetch_inner( self: &Arc, key: K, context: CacheContext, fetch: F, runtime: &tokio::runtime::Handle, - ) -> GenericFetch + ) -> GenericFetch where F: FnOnce() -> FU, - FU: Future> + Send + 'static, + FU: Future> + Send + 'static, ER: Send + 'static + Debug, + T: Send + Sync + 'static + Default, { let hash = self.hash_builder.hash_one(&key); @@ -746,19 +851,19 @@ where let mut shard = self.shard(hash as usize % self.shards.len()); if let Some(ptr) = unsafe { shard.get(hash, &key) } { - return GenericFetch::Hit(Some(GenericCacheEntry { + return GenericFetch::new(GenericFetchInner::Hit(Some(GenericCacheEntry { cache: self.clone(), ptr, - })); + }))); } match shard.waiters.entry(key.clone()) { HashMapEntry::Occupied(mut o) => { let (tx, rx) = oneshot::channel(); o.get_mut().push(tx); shard.state.metrics.memory_queue.increment(1); - return GenericFetch::Wait(rx.in_span(Span::enter_with_local_parent( + return GenericFetch::new(GenericFetchInner::Wait(rx.in_span(Span::enter_with_local_parent( "foyer::memory::generic::fetch_with_runtime::wait", - ))); + )))); } HashMapEntry::Vacant(v) => { v.insert(vec![]); @@ -771,13 +876,13 @@ where let future = fetch(); let join = runtime.spawn( async move { - let value = match future + let (value, ext) = match future .in_span(Span::enter_with_local_parent( "foyer::memory::generic::fetch_with_runtime::fn", )) .await { - Ok(value) => value, + Ok((value, ext)) => (value, ext), Err(e) => { let mut shard = cache.shard(hash as usize % cache.shards.len()); tracing::debug!("[fetch]: error raise while fetching, all waiter are dropped, err: {e:?}"); @@ -786,13 +891,13 @@ where } }; let entry = cache.insert_with_context(key, value, context); - Ok(entry) + Ok((entry, ext)) } .in_span(Span::enter_with_local_parent( "foyer::memory::generic::fetch_with_runtime::spawn", )), ); - GenericFetch::Miss(join) + GenericFetch::new(GenericFetchInner::Miss(join)) } } diff --git a/foyer-memory/src/prelude.rs b/foyer-memory/src/prelude.rs index 04e39342..3970a3a6 100644 --- a/foyer-memory/src/prelude.rs +++ b/foyer-memory/src/prelude.rs @@ -13,9 +13,9 @@ // limitations under the License. pub use crate::{ - cache::{Cache, CacheBuilder, CacheEntry, EvictionConfig, Fetch, FetchState}, + cache::{Cache, CacheBuilder, CacheEntry, EvictionConfig, Fetch}, context::CacheContext, eviction::{fifo::FifoConfig, lfu::LfuConfig, lru::LruConfig, s3fifo::S3FifoConfig}, - generic::Weighter, + generic::{FetchState, Weighter}, }; pub use ahash::RandomState; diff --git a/foyer/src/hybrid/cache.rs b/foyer/src/hybrid/cache.rs index 88e3ac2f..fe23c8d6 100644 --- a/foyer/src/hybrid/cache.rs +++ b/foyer/src/hybrid/cache.rs @@ -35,6 +35,7 @@ use foyer_common::{ }; use foyer_memory::{Cache, CacheContext, CacheEntry, Fetch, FetchState}; use foyer_storage::{DeviceStats, Storage, Store}; +use futures::FutureExt; use minitrace::prelude::*; use pin_project::pin_project; use tokio::sync::oneshot; @@ -402,9 +403,8 @@ where S: HashBuilder + Debug, { #[pin] - inner: Fetch, + inner: Fetch, - enqueue: Arc, storage: Store, } @@ -417,11 +417,11 @@ where type Output = anyhow::Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let res = ready!(this.inner.poll(cx)); + let mut this = self.project(); + let res = ready!(this.inner.as_mut().poll(cx)); if let Ok(entry) = res.as_ref() { - if this.enqueue.load(Ordering::Acquire) { + if *this.inner.ext() { this.storage.enqueue(entry.clone(), false); } } @@ -436,7 +436,7 @@ where V: StorageValue, S: HashBuilder + Debug, { - type Target = Fetch; + type Target = Fetch; fn deref(&self) -> &Self::Target { &self.inner @@ -477,14 +477,14 @@ where let now = Instant::now(); let store = self.storage.clone(); - let enqueue = Arc::::default(); + let future = fetch(); let inner = self.memory.fetch_inner( key.clone(), context, || { let metrics = self.metrics.clone(); - let enqueue = enqueue.clone(); + async move { match store.load(&key).await.map_err(anyhow::Error::from)? { None => {} @@ -493,16 +493,15 @@ where metrics.hybrid_hit.increment(1); metrics.hybrid_hit_duration.record(now.elapsed()); - return Ok(v); + return Ok((v, false)); } } metrics.hybrid_miss.increment(1); metrics.hybrid_miss_duration.record(now.elapsed()); - enqueue.store(true, Ordering::Release); - future + .map(|res| res.map(|v| (v, true))) .in_span(Span::enter_with_local_parent("foyer::hybrid::fetch::fn")) .await .map_err(anyhow::Error::from) @@ -518,7 +517,6 @@ where let inner = HybridFetchInner { inner, - enqueue, storage: self.storage.clone(), };