diff --git a/src/middlewares/methods/cache.rs b/src/middlewares/methods/cache.rs index 9f9b45a..2774b54 100644 --- a/src/middlewares/methods/cache.rs +++ b/src/middlewares/methods/cache.rs @@ -75,7 +75,7 @@ impl Middleware> for CacheMiddl let result = self .cache - .get_or_insert_with(&key, || next(request, context).boxed()) + .get_or_insert_with(key.clone(), || next(request, context).boxed()) .await; if let Ok(ref value) = result { diff --git a/src/utils/cache.rs b/src/utils/cache.rs index 8f7e1fa..fa8aba0 100644 --- a/src/utils/cache.rs +++ b/src/utils/cache.rs @@ -96,15 +96,25 @@ impl Cache { pub async fn get_or_insert_with( &self, - key: &CacheKey, + key: CacheKey, f: F, ) -> Result where F: FnOnce() -> BoxFuture<'static, Result>, { - match self.cache.get(key) { + match self.cache.get(&key) { Some(CacheValue::Value(value)) => Ok(value), - Some(CacheValue::Pending(rx)) => rx.borrow().clone().unwrap(), + Some(CacheValue::Pending(mut rx)) => { + { + let value = rx.borrow(); + if value.is_some() { + return value.clone().unwrap(); + } + } + let _ = rx.changed().await; + let value = rx.borrow(); + value.clone().expect("Cache: should always be Some") + } None => { let (tx, rx) = watch::channel(None); self.cache @@ -119,7 +129,7 @@ impl Cache { .await; } Err(_) => { - self.cache.remove(key).await; + self.cache.remove(&key).await; } }; value @@ -137,3 +147,78 @@ impl Cache { self.cache.sync(); } } + +#[cfg(test)] +mod tests { + use super::*; + use futures::FutureExt as _; + use serde_json::json; + + #[tokio::test] + async fn get_insert_remove() { + let cache = Cache::::new(NonZeroUsize::new(1).unwrap(), None); + + let key = CacheKey::::new(&"key".to_string(), &[]); + + assert_eq!(cache.get(&key).await, None); + + cache.insert(key.clone(), json!("value")).await; + + assert_eq!(cache.get(&key).await, Some(json!("value"))); + + cache.remove(&key).await; + + assert_eq!(cache.get(&key).await, None); + } + + #[tokio::test] + async fn get_or_insert_with_basic() { + let cache = Cache::::new(NonZeroUsize::new(1).unwrap(), None); + + let key = CacheKey::::new(&"key".to_string(), &[]); + + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + + let cache2 = cache.clone(); + let key2 = key.clone(); + let h1 = tokio::spawn(async move { + let value = cache2 + .get_or_insert_with(key2.clone(), || { + async move { + let _ = rx.await; + Ok(json!("value")) + } + .boxed() + }) + .await; + assert_eq!(value, Ok(json!("value"))); + }); + + tokio::task::yield_now().await; + + let cache2 = cache.clone(); + let key2 = key.clone(); + let h2 = tokio::spawn(async move { + println!("5"); + + let value = cache2 + .get_or_insert_with(key2, || { + async { + panic!(); + } + .boxed() + }) + .await; + assert_eq!(value, Ok(json!("value"))); + }); + + tokio::task::yield_now().await; + + tx.send(()).unwrap(); + + h1.await.unwrap(); + h2.await.unwrap(); + + assert_eq!(cache.get(&key).await, Some(json!("value"))); + } +}