From fe998ee736802c113be7c26a9aead6d3a38058b1 Mon Sep 17 00:00:00 2001 From: Hiram Chirino <hiram@hiramchirino.com> Date: Tue, 30 Apr 2024 10:40:41 -0400 Subject: [PATCH] fixes #295: Use a semaphore to protect the Batcher from unbounded memory growth. Signed-off-by: Hiram Chirino <hiram@hiramchirino.com> --- limitador/src/storage/redis/counters_cache.rs | 30 ++++++++++++------- limitador/src/storage/redis/redis_cached.rs | 21 +++++++------ 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 3754a46e..bd562986 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -14,6 +14,7 @@ use std::sync::Arc; use std::time::{Duration, SystemTime}; use tokio::select; use tokio::sync::Notify; +use tokio::sync::Semaphore; pub struct CachedCounterValue { value: AtomicExpiringValue, @@ -132,19 +133,21 @@ pub struct Batcher { notifier: Notify, interval: Duration, priority_flush: AtomicBool, + limiter: Semaphore, } impl Batcher { - fn new(period: Duration) -> Self { + fn new(period: Duration, max_cached_counters: usize) -> Self { Self { updates: Default::default(), notifier: Default::default(), interval: period, priority_flush: AtomicBool::new(false), + limiter: Semaphore::new(max_cached_counters), } } - pub fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) { + pub async fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) { let priority = value.requires_fast_flush(&self.interval); match self.updates.entry(counter.clone()) { Entry::Occupied(needs_merge) => { @@ -154,6 +157,7 @@ impl Batcher { } } Entry::Vacant(miss) => { + self.limiter.acquire().await.unwrap().forget(); miss.insert_entry(value); } }; @@ -189,8 +193,12 @@ impl Batcher { } let result = consumer(result).await; batch.iter().for_each(|counter| { - self.updates - .remove_if(counter, |_, v| v.no_pending_writes()); + if let Some(_) = self + .updates + .remove_if(counter, |_, v| v.no_pending_writes()) + { + self.limiter.add_permits(1); + } }); return result; } else { @@ -217,7 +225,7 @@ impl Batcher { impl Default for Batcher { fn default() -> Self { - Self::new(Duration::from_millis(100)) + Self::new(Duration::from_millis(100), 100) } } @@ -285,7 +293,7 @@ impl CountersCache { )) } - pub fn increase_by(&self, counter: &Counter, delta: i64) { + pub async fn increase_by(&self, counter: &Counter, delta: i64) { let val = self.cache.get_with_by_ref(counter, || { if let Some(entry) = self.batcher.updates.get(counter) { entry.value().clone() @@ -294,7 +302,7 @@ impl CountersCache { } }); val.delta(counter, delta); - self.batcher.add(counter.clone(), val.clone()); + self.batcher.add(counter.clone(), val.clone()).await; } fn ttl_from_redis_ttl( @@ -377,7 +385,7 @@ impl CountersCacheBuilder { max_ttl_cached_counters: self.max_ttl_cached_counters, ttl_ratio_cached_counters: self.ttl_ratio_cached_counters, cache: Cache::new(self.max_cached_counters as u64), - batcher: Batcher::new(period), + batcher: Batcher::new(period, self.max_cached_counters), } } } @@ -495,8 +503,8 @@ mod tests { assert_eq!(cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), 0); } - #[test] - fn increase_by() { + #[tokio::test] + async fn increase_by() { let current_val = 10; let increase_by = 8; let mut values = HashMap::new(); @@ -520,7 +528,7 @@ mod tests { Duration::from_secs(0), SystemTime::now(), ); - cache.increase_by(&counter, increase_by); + cache.increase_by(&counter, increase_by).await; assert_eq!( cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 0f5c36ff..d5e8b0d1 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -121,7 +121,7 @@ impl AsyncCounterStorage for CachedRedisStorage { // Update cached values for counter in counters.iter() { - self.cached_counters.increase_by(counter, delta); + self.cached_counters.increase_by(counter, delta).await; } Ok(Authorization::Ok) @@ -480,14 +480,17 @@ mod tests { )]); let cache = CountersCacheBuilder::new().build(Duration::from_millis(1)); - cache.batcher().add( - counter.clone(), - Arc::new(CachedCounterValue::from_authority( - &counter, - 2, - Duration::from_secs(60), - )), - ); + cache + .batcher() + .add( + counter.clone(), + Arc::new(CachedCounterValue::from_authority( + &counter, + 2, + Duration::from_secs(60), + )), + ) + .await; cache.insert( counter.clone(), Some(1),