diff --git a/caching/src/local/caching_local.rs b/caching/src/local/caching_local.rs index b0fdcf0..19f1606 100644 --- a/caching/src/local/caching_local.rs +++ b/caching/src/local/caching_local.rs @@ -7,7 +7,7 @@ use std::time::Duration; use async_trait::async_trait; use serde::de::DeserializeOwned; use serde::Serialize; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; use tokio::task; use novax::caching::{CachingDurationStrategy, CachingStrategy}; @@ -15,30 +15,34 @@ use novax::errors::CachingError; use novax::errors::NovaXError; use crate::date::get_current_timestamp::{get_current_timestamp, GetDuration}; -use crate::utils::lock::MutexLike; +use crate::utils::lock::{Locker, MutexLike}; -pub type CachingLocal = BaseCachingLocal>>, Mutex>, Mutex, Mutex>; +pub type CachingLocal = BaseCachingLocal>, RwLock>>>, RwLock, RwLock>>, Mutex, Mutex>; -pub struct BaseCachingLocal +pub struct BaseCachingLocal where - MutexValue: MutexLike>>, - MutexExpiration: MutexLike>, - MutexCleanupInterval: MutexLike, - MutexIsCleanupProcessStarted: MutexLike + LockerValue: Locker> + Debug, + LockerValueHashMap: Locker> + Debug, + LockerExpiration: Locker + Debug, + LockerExpirationHashMap: Locker> + Debug, + MutexCleanupInterval: MutexLike + Debug, + MutexIsCleanupProcessStarted: MutexLike + Debug { duration_strategy: CachingDurationStrategy, - value_map: Arc, - expiration_timestamp_map: Arc, + value_map: Arc, + expiration_timestamp_map: Arc, cleanup_interval: Arc, is_cleanup_process_started: Arc, } -impl Clone for BaseCachingLocal +impl Clone for BaseCachingLocal where - MutexValue: MutexLike>>, - MutexExpiration: MutexLike>, - MutexCleanupInterval: MutexLike, - MutexIsCleanupProcessStarted: MutexLike + LockerValue: Locker> + Debug, + LockerValueHashMap: Locker> + Debug, + LockerExpiration: Locker + Debug, + LockerExpirationHashMap: Locker> + Debug, + MutexCleanupInterval: MutexLike + Debug, + MutexIsCleanupProcessStarted: MutexLike + Debug { fn clone(&self) -> Self { Self { @@ -51,12 +55,14 @@ where } } -impl Debug for BaseCachingLocal +impl Debug for BaseCachingLocal where - MutexValue: MutexLike>>, - MutexExpiration: MutexLike>, - MutexCleanupInterval: MutexLike, - MutexIsCleanupProcessStarted: MutexLike + LockerValue: Locker> + Debug, + LockerValueHashMap: Locker> + Debug, + LockerExpiration: Locker + Debug, + LockerExpirationHashMap: Locker> + Debug, + MutexCleanupInterval: MutexLike + Debug, + MutexIsCleanupProcessStarted: MutexLike + Debug { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("BaseCachingLocal") @@ -69,34 +75,74 @@ where } } -impl BaseCachingLocal +impl BaseCachingLocal where - MutexValue: MutexLike>>, - MutexExpiration: MutexLike>, - MutexCleanupInterval: MutexLike, - MutexIsCleanupProcessStarted: MutexLike + LockerValue: Locker> + Debug, + LockerValueHashMap: Locker> + Debug, + LockerExpiration: Locker + Debug, + LockerExpirationHashMap: Locker> + Debug, + MutexCleanupInterval: MutexLike + Debug, + MutexIsCleanupProcessStarted: MutexLike + Debug { - pub fn empty(duration_strategy: CachingDurationStrategy) -> CachingLocal { - CachingLocal { + pub fn empty(duration_strategy: CachingDurationStrategy) -> Self { + BaseCachingLocal { duration_strategy, - value_map: Arc::new(Mutex::new(HashMap::new())), - expiration_timestamp_map: Arc::new(Mutex::new(HashMap::new())), - cleanup_interval: Arc::new(Mutex::new(Duration::from_secs(0))), - is_cleanup_process_started: Arc::new(Mutex::new(false)), + value_map: Arc::new(LockerValueHashMap::new(HashMap::new())), + expiration_timestamp_map: Arc::new(LockerExpirationHashMap::new(HashMap::new())), + cleanup_interval: Arc::new(MutexCleanupInterval::new(Duration::from_secs(0))), + is_cleanup_process_started: Arc::new(MutexIsCleanupProcessStarted::new(false)), } } async fn remove_key(&self, key: u64) { - let _ = self.expiration_timestamp_map.lock().await.remove(&key); - let _ = self.value_map.lock().await.remove(&key); + let contains_key = { + let expiration_timestamp_read_guard = self.expiration_timestamp_map.read().await; + expiration_timestamp_read_guard.contains_key(&key) + }; + + if contains_key { + let mut expiration_write_guard = self.expiration_timestamp_map.write().await; + let mut value_map_write_guard = self.value_map.write().await; + + expiration_write_guard.remove(&key); + value_map_write_guard.remove(&key); + } } async fn set_value(&self, key: u64, value: &T) -> Result<(), NovaXError> { + let contains_key = { + let expiration_timestamp_read_guard = self.expiration_timestamp_map.read().await; + expiration_timestamp_read_guard.contains_key(&key) + }; + let expiration_timestamp = self.duration_strategy.get_duration_timestamp(&get_current_timestamp()?)?; - self.expiration_timestamp_map.lock().await.insert(key, expiration_timestamp); - let Ok(serialized) = rmp_serde::to_vec(value) else { return Err(CachingError::UnableToSerialize.into())}; - self.value_map.lock().await.insert(key, serialized); + + if contains_key { + let expiration_timestamp_map_read_guard = self.expiration_timestamp_map.read().await; + // Important: the key might have been removed since the contains_key assignment. + // If so, we won't set the cache here, but go to the "!contains_key" scope. + // We could lock the whole map but this has a terrible performance impact by creating a bottleneck. + if let Some(expiration_timestamp_locker) = expiration_timestamp_map_read_guard.get(&key) { + let mut expiration_timestamp_write = expiration_timestamp_locker.write().await; + + // Let's do the same for the value + let value_map_read_guard = self.value_map.read().await; + if let Some(value_locker) = value_map_read_guard.get(&key) { + let mut value_write = value_locker.write().await; + *expiration_timestamp_write = expiration_timestamp; + *value_write = serialized; + + return Ok(()); + }; + }; + } + + // The key is not found, we have to lock everything. + let mut expiration_map_write_guard = self.expiration_timestamp_map.write().await; + let mut value_map_write_guard = self.value_map.write().await; + expiration_map_write_guard.insert(key, LockerExpiration::new(expiration_timestamp)); + value_map_write_guard.insert(key, LockerValue::new(serialized)); Ok(()) } @@ -165,14 +211,30 @@ impl CachingLocal } async fn perform_cleanup(&self) -> Result<(), NovaXError> { + // Can create a bottleneck, be sure to not run this function too frequently. let current_timestamp = get_current_timestamp()?; - let mut value_map_locked = self.value_map.lock().await; - let mut expiration_map_locked = self.expiration_timestamp_map.lock().await; + let mut expiration_map_write_guard = self.expiration_timestamp_map.write().await; + let mut value_map_write_guard = self.value_map.write().await; + + let keys: Vec = expiration_map_write_guard + .keys() + .copied() + .collect(); + + for key in keys { + let should_remove = { + let Some(duration_locker) = expiration_map_write_guard.get(&key) else { + continue; + }; + + let duration_read = duration_locker.read().await; + + current_timestamp > *duration_read + }; - for (key, duration) in expiration_map_locked.clone().into_iter() { - if current_timestamp > duration { - value_map_locked.remove(&key); - expiration_map_locked.remove(&key); + if should_remove { + value_map_write_guard.remove(&key); + expiration_map_write_guard.remove(&key); } } @@ -181,27 +243,45 @@ impl CachingLocal } #[async_trait] -impl CachingStrategy for BaseCachingLocal +impl CachingStrategy for BaseCachingLocal where - MutexValue: MutexLike>>, - MutexExpiration: MutexLike>, - MutexCleanupInterval: MutexLike, - MutexIsCleanupProcessStarted: MutexLike + LockerValue: Locker> + Debug, + LockerValueHashMap: Locker> + Debug, + LockerExpiration: Locker + Debug, + LockerExpirationHashMap: Locker> + Debug, + MutexCleanupInterval: MutexLike + Debug, + MutexIsCleanupProcessStarted: MutexLike + Debug { async fn get_cache(&self, key: u64) -> Result, NovaXError> { - let Some(expiration_timestamp) = self.expiration_timestamp_map.lock().await.get(&key).cloned() else { return Ok(None) }; + { + let expiration_timestamp = { + let read_guard = self.expiration_timestamp_map.read().await; + let Some(expiration_timestamp_locker) = read_guard.get(&key) else { + return Ok(None); + }; - if get_current_timestamp()? >= expiration_timestamp { - self.remove_key(key).await; - Ok(None) - } else { - let Some(encoded_value) = self.value_map.lock().await.get(&key).cloned() else { return Ok(None) }; - let Ok(value) = rmp_serde::from_slice(&encoded_value) else { - return Err(CachingError::UnableToDeserialize.into()) + let expiration_timestamp_read = expiration_timestamp_locker.read().await; + *expiration_timestamp_read }; - Ok(Some(value)) - } + if get_current_timestamp()? >= expiration_timestamp { + self.remove_key(key).await; + return Ok(None) + } + }; + + let value_map_read_guard = self.value_map.read().await; + let Some(encoded_value_locked) = value_map_read_guard.get(&key) else { + return Ok(None); + }; + + let encoded_value = encoded_value_locked.read().await; + + let Ok(value) = rmp_serde::from_slice(&encoded_value) else { + return Err(CachingError::UnableToDeserialize.into()) + }; + + Ok(Some(value)) } async fn set_cache(&self, key: u64, value: &T) -> Result<(), NovaXError> { @@ -224,8 +304,11 @@ where } async fn clear(&self) -> Result<(), NovaXError> { - self.expiration_timestamp_map.lock().await.clear(); - self.value_map.lock().await.clear(); + let mut expiration_map_write_guard = self.expiration_timestamp_map.write().await; + let mut value_map_write_guard = self.value_map.write().await; + + expiration_map_write_guard.clear(); + value_map_write_guard.clear(); Ok(()) } @@ -443,8 +526,8 @@ mod test { caching.set_cache(2, &"test2".to_string()).await?; caching.clear().await?; - assert!(caching.value_map.lock().await.is_empty()); - assert!(caching.expiration_timestamp_map.lock().await.is_empty()); + assert!(caching.value_map.write().await.is_empty()); + assert!(caching.expiration_timestamp_map.write().await.is_empty()); Ok(()) } @@ -461,8 +544,8 @@ mod test { caching.perform_cleanup().await?; - let value_map_locked = caching.value_map.lock().await; - let expiration_timestamp_locked = caching.expiration_timestamp_map.lock().await; + let value_map_locked = caching.value_map.write().await; + let expiration_timestamp_locked = caching.expiration_timestamp_map.write().await; assert_eq!(value_map_locked.len(), 1); assert_eq!(expiration_timestamp_locked.len(), 1); @@ -482,8 +565,8 @@ mod test { caching.perform_cleanup().await?; - let value_map_locked = caching.value_map.lock().await; - let expiration_timestamp_locked = caching.expiration_timestamp_map.lock().await; + let value_map_locked = caching.value_map.write().await; + let expiration_timestamp_locked = caching.expiration_timestamp_map.write().await; assert!(value_map_locked.is_empty()); assert!(expiration_timestamp_locked.is_empty()); @@ -513,8 +596,8 @@ mod test { set_mock_time(Duration::from_secs(11)); { - let value_map_locked = caching.value_map.lock().await; - let expiration_timestamp_locked = caching.expiration_timestamp_map.lock().await; + let value_map_locked = caching.value_map.write().await; + let expiration_timestamp_locked = caching.expiration_timestamp_map.write().await; assert_eq!(value_map_locked.len(), 2); assert_eq!(expiration_timestamp_locked.len(), 2); @@ -523,8 +606,8 @@ mod test { caching.perform_cleanup().await?; { - let value_map_locked = caching.value_map.lock().await; - let expiration_timestamp_locked = caching.expiration_timestamp_map.lock().await; + let value_map_locked = caching.value_map.write().await; + let expiration_timestamp_locked = caching.expiration_timestamp_map.write().await; assert_eq!(value_map_locked.len(), 1); assert_eq!(expiration_timestamp_locked.len(), 1); diff --git a/caching/src/locked/caching.rs b/caching/src/locked/caching.rs index 8d9d968..772e166 100644 --- a/caching/src/locked/caching.rs +++ b/caching/src/locked/caching.rs @@ -18,7 +18,7 @@ pub type CachingLocked = BaseCachingLocked, Mu pub struct BaseCachingLocked where C: CachingStrategy, - L: Locker, + L: Locker, M: MutexLike>>, { pub caching: C, @@ -28,7 +28,7 @@ where impl Clone for BaseCachingLocked where C: CachingStrategy, - L: Locker, + L: Locker, M: MutexLike>>, { fn clone(&self) -> Self { @@ -42,7 +42,7 @@ where impl Debug for BaseCachingLocked where C: CachingStrategy, - L: Locker, + L: Locker, M: MutexLike>>, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -56,7 +56,7 @@ where impl BaseCachingLocked where C: CachingStrategy, - L: Locker, + L: Locker, M: MutexLike>>, { pub fn new(caching: C) -> BaseCachingLocked { @@ -70,7 +70,7 @@ where impl BaseCachingLocked where C: CachingStrategy, - L: Locker, + L: Locker, M: MutexLike>>, { async fn get_locker(&self, key: u64) -> Result, NovaXError> { @@ -78,7 +78,7 @@ where let locker = if let Some(locker) = lockers_map.get(&key) { locker.clone() } else { - let locker = Arc::new(L::new()); + let locker = Arc::new(L::new(())); lockers_map.insert(key, locker.clone()); locker }; @@ -91,7 +91,7 @@ where impl CachingStrategy for BaseCachingLocked where C: CachingStrategy, - L: Locker, + L: Locker, M: MutexLike>>, { async fn get_cache(&self, key: u64) -> Result, NovaXError> { diff --git a/caching/src/utils/lock.rs b/caching/src/utils/lock.rs index a4e6291..62f0c7f 100644 --- a/caching/src/utils/lock.rs +++ b/caching/src/utils/lock.rs @@ -25,22 +25,25 @@ impl MutexLike for Mutex { #[async_trait] pub trait Locker: Send + Sync { - fn new() -> Self; - async fn read(&self) -> RwLockReadGuard<'_, ()>; - async fn write(&self) -> RwLockWriteGuard<'_, ()>; + type T; + fn new(value: Self::T) -> Self; + async fn read(&self) -> RwLockReadGuard<'_, Self::T>; + async fn write(&self) -> RwLockWriteGuard<'_, Self::T>; } #[async_trait] -impl Locker for RwLock<()> { - fn new() -> Self { - Self::new(()) +impl Locker for RwLock { + type T = T; + + fn new(value: T) -> Self { + Self::new(value) } - async fn read(&self) -> RwLockReadGuard<'_, ()> { + async fn read(&self) -> RwLockReadGuard<'_, T> { RwLock::read(self).await } - async fn write(&self) -> RwLockWriteGuard<'_, ()> { + async fn write(&self) -> RwLockWriteGuard<'_, T> { RwLock::write(self).await } } \ No newline at end of file