diff --git a/caching/src/locked/caching.rs b/caching/src/locked/caching.rs index b817165..54ef254 100644 --- a/caching/src/locked/caching.rs +++ b/caching/src/locked/caching.rs @@ -1,37 +1,63 @@ use std::collections::HashMap; +use std::fmt::Debug; use std::future::Future; use std::sync::Arc; use async_trait::async_trait; use serde::de::DeserializeOwned; use serde::Serialize; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; use novax::caching::{CachingDurationStrategy, CachingStrategy}; use novax::errors::NovaXError; +#[allow(type_alias_bounds)] +pub type CachingLocked = BaseCachingLocked>>; + +#[async_trait] +pub trait Locker: Send + Sync + Clone + Debug { + fn new() -> Self; + async fn read(&self) -> RwLockReadGuard<'_, ()>; + async fn write(&self) -> RwLockWriteGuard<'_, ()>; +} + +#[async_trait] +impl Locker for Arc> { + fn new() -> Self { + Self::new(RwLock::new(())) + } + + async fn read(&self) -> RwLockReadGuard<'_, ()> { + self.read().await + } + + async fn write(&self) -> RwLockWriteGuard<'_, ()> { + self.write().await + } +} + #[derive(Clone, Debug)] -pub struct CachingLocked { +pub struct BaseCachingLocked { pub caching: C, - _lockers_map: Arc>>>> + _lockers_map: Arc>> } -impl CachingLocked { - pub fn new(caching: C) -> CachingLocked { - CachingLocked { +impl BaseCachingLocked { + pub fn new(caching: C) -> BaseCachingLocked { + BaseCachingLocked { caching, _lockers_map: Arc::new(Mutex::new(HashMap::new())) } } } -impl CachingLocked { - async fn get_locker(&self, key: u64) -> Result>, NovaXError> { +impl BaseCachingLocked { + async fn get_locker(&self, key: u64) -> Result { let mut lockers_map = self._lockers_map.lock().await; let locker = if let Some(locker) = lockers_map.get(&key) { locker.clone() } else { - let locker = Arc::new(RwLock::new(())); + let locker = L::new(); lockers_map.insert(key, locker.clone()); locker }; @@ -41,7 +67,7 @@ impl CachingLocked { } #[async_trait] -impl CachingStrategy for CachingLocked { +impl CachingStrategy for BaseCachingLocked { async fn get_cache(&self, key: u64) -> Result, NovaXError> { let locker = self.get_locker(key).await?; let lock_value = locker.read().await; @@ -83,7 +109,7 @@ impl CachingStrategy for CachingLocked { } fn with_duration_strategy(&self, strategy: CachingDurationStrategy) -> Self { - CachingLocked::new(self.caching.with_duration_strategy(strategy)) + Self::new(self.caching.with_duration_strategy(strategy)) } }