diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 3754a46e..8982e5c6 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -14,6 +14,7 @@ use std::sync::Arc; use std::time::{Duration, SystemTime}; use tokio::select; use tokio::sync::Notify; +use tokio::sync::Semaphore; pub struct CachedCounterValue { value: AtomicExpiringValue, @@ -132,6 +133,7 @@ pub struct Batcher { notifier: Notify, interval: Duration, priority_flush: AtomicBool, + limiter: Semaphore, } impl Batcher { @@ -141,10 +143,11 @@ impl Batcher { notifier: Default::default(), interval: period, priority_flush: AtomicBool::new(false), + limiter: Semaphore::new(1000), } } - pub fn add(&self, counter: Counter, value: Arc) { + pub async fn add(&self, counter: Counter, value: Arc) { let priority = value.requires_fast_flush(&self.interval); match self.updates.entry(counter.clone()) { Entry::Occupied(needs_merge) => { @@ -154,6 +157,7 @@ impl Batcher { } } Entry::Vacant(miss) => { + self.limiter.acquire().await.unwrap().forget(); miss.insert_entry(value); } }; @@ -189,8 +193,12 @@ impl Batcher { } let result = consumer(result).await; batch.iter().for_each(|counter| { - self.updates - .remove_if(counter, |_, v| v.no_pending_writes()); + if let Some(_) = self + .updates + .remove_if(counter, |_, v| v.no_pending_writes()) + { + self.limiter.add_permits(1); + } }); return result; } else { @@ -285,7 +293,7 @@ impl CountersCache { )) } - pub fn increase_by(&self, counter: &Counter, delta: i64) { + pub async fn increase_by(&self, counter: &Counter, delta: i64) { let val = self.cache.get_with_by_ref(counter, || { if let Some(entry) = self.batcher.updates.get(counter) { entry.value().clone() @@ -294,7 +302,7 @@ impl CountersCache { } }); val.delta(counter, delta); - self.batcher.add(counter.clone(), val.clone()); + self.batcher.add(counter.clone(), val.clone()).await; } fn ttl_from_redis_ttl( @@ -495,8 +503,8 @@ mod tests { assert_eq!(cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), 0); } - #[test] - fn increase_by() { + #[tokio::test] + async fn increase_by() { let current_val = 10; let increase_by = 8; let mut values = HashMap::new(); @@ -520,7 +528,7 @@ mod tests { Duration::from_secs(0), SystemTime::now(), ); - cache.increase_by(&counter, increase_by); + cache.increase_by(&counter, increase_by).await; assert_eq!( cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 0f5c36ff..d5e8b0d1 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -121,7 +121,7 @@ impl AsyncCounterStorage for CachedRedisStorage { // Update cached values for counter in counters.iter() { - self.cached_counters.increase_by(counter, delta); + self.cached_counters.increase_by(counter, delta).await; } Ok(Authorization::Ok) @@ -480,14 +480,17 @@ mod tests { )]); let cache = CountersCacheBuilder::new().build(Duration::from_millis(1)); - cache.batcher().add( - counter.clone(), - Arc::new(CachedCounterValue::from_authority( - &counter, - 2, - Duration::from_secs(60), - )), - ); + cache + .batcher() + .add( + counter.clone(), + Arc::new(CachedCounterValue::from_authority( + &counter, + 2, + Duration::from_secs(60), + )), + ) + .await; cache.insert( counter.clone(), Some(1),