From fe998ee736802c113be7c26a9aead6d3a38058b1 Mon Sep 17 00:00:00 2001
From: Hiram Chirino <hiram@hiramchirino.com>
Date: Tue, 30 Apr 2024 10:40:41 -0400
Subject: [PATCH] fixes #295: Use a semaphore to protect the Batcher from
 unbounded memory growth.

Signed-off-by: Hiram Chirino <hiram@hiramchirino.com>
---
 limitador/src/storage/redis/counters_cache.rs | 30 ++++++++++++-------
 limitador/src/storage/redis/redis_cached.rs   | 21 +++++++------
 2 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs
index 3754a46e..bd562986 100644
--- a/limitador/src/storage/redis/counters_cache.rs
+++ b/limitador/src/storage/redis/counters_cache.rs
@@ -14,6 +14,7 @@ use std::sync::Arc;
 use std::time::{Duration, SystemTime};
 use tokio::select;
 use tokio::sync::Notify;
+use tokio::sync::Semaphore;
 
 pub struct CachedCounterValue {
     value: AtomicExpiringValue,
@@ -132,19 +133,21 @@ pub struct Batcher {
     notifier: Notify,
     interval: Duration,
     priority_flush: AtomicBool,
+    limiter: Semaphore,
 }
 
 impl Batcher {
-    fn new(period: Duration) -> Self {
+    fn new(period: Duration, max_cached_counters: usize) -> Self {
         Self {
             updates: Default::default(),
             notifier: Default::default(),
             interval: period,
             priority_flush: AtomicBool::new(false),
+            limiter: Semaphore::new(max_cached_counters),
         }
     }
 
-    pub fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
+    pub async fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
         let priority = value.requires_fast_flush(&self.interval);
         match self.updates.entry(counter.clone()) {
             Entry::Occupied(needs_merge) => {
@@ -154,6 +157,7 @@ impl Batcher {
                 }
             }
             Entry::Vacant(miss) => {
+                self.limiter.acquire().await.unwrap().forget();
                 miss.insert_entry(value);
             }
         };
@@ -189,8 +193,12 @@ impl Batcher {
                 }
                 let result = consumer(result).await;
                 batch.iter().for_each(|counter| {
-                    self.updates
-                        .remove_if(counter, |_, v| v.no_pending_writes());
+                    let removed = self
+                        .updates
+                        .remove_if(counter, |_, v| v.no_pending_writes());
+                    if removed.is_some() {
+                        self.limiter.add_permits(1); // return the permit `add` forgot
+                    }
                 });
                 return result;
             } else {
@@ -217,7 +225,7 @@ impl Batcher {
 
 impl Default for Batcher {
     fn default() -> Self {
-        Self::new(Duration::from_millis(100))
+        Self::new(Duration::from_millis(100), 100)
     }
 }
 
@@ -285,7 +293,7 @@ impl CountersCache {
         ))
     }
 
-    pub fn increase_by(&self, counter: &Counter, delta: i64) {
+    pub async fn increase_by(&self, counter: &Counter, delta: i64) {
         let val = self.cache.get_with_by_ref(counter, || {
             if let Some(entry) = self.batcher.updates.get(counter) {
                 entry.value().clone()
@@ -294,7 +302,7 @@ impl CountersCache {
             }
         });
         val.delta(counter, delta);
-        self.batcher.add(counter.clone(), val.clone());
+        self.batcher.add(counter.clone(), val.clone()).await;
     }
 
     fn ttl_from_redis_ttl(
@@ -377,7 +385,7 @@ impl CountersCacheBuilder {
             max_ttl_cached_counters: self.max_ttl_cached_counters,
             ttl_ratio_cached_counters: self.ttl_ratio_cached_counters,
             cache: Cache::new(self.max_cached_counters as u64),
-            batcher: Batcher::new(period),
+            batcher: Batcher::new(period, self.max_cached_counters),
         }
     }
 }
@@ -495,8 +503,8 @@ mod tests {
         assert_eq!(cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), 0);
     }
 
-    #[test]
-    fn increase_by() {
+    #[tokio::test]
+    async fn increase_by() {
         let current_val = 10;
         let increase_by = 8;
         let mut values = HashMap::new();
@@ -520,7 +528,7 @@ mod tests {
             Duration::from_secs(0),
             SystemTime::now(),
         );
-        cache.increase_by(&counter, increase_by);
+        cache.increase_by(&counter, increase_by).await;
 
         assert_eq!(
             cache.get(&counter).map(|e| e.hits(&counter)).unwrap(),
diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs
index 0f5c36ff..d5e8b0d1 100644
--- a/limitador/src/storage/redis/redis_cached.rs
+++ b/limitador/src/storage/redis/redis_cached.rs
@@ -121,7 +121,7 @@ impl AsyncCounterStorage for CachedRedisStorage {
 
         // Update cached values
         for counter in counters.iter() {
-            self.cached_counters.increase_by(counter, delta);
+            self.cached_counters.increase_by(counter, delta).await;
         }
 
         Ok(Authorization::Ok)
@@ -480,14 +480,17 @@ mod tests {
         )]);
 
         let cache = CountersCacheBuilder::new().build(Duration::from_millis(1));
-        cache.batcher().add(
-            counter.clone(),
-            Arc::new(CachedCounterValue::from_authority(
-                &counter,
-                2,
-                Duration::from_secs(60),
-            )),
-        );
+        cache
+            .batcher()
+            .add(
+                counter.clone(),
+                Arc::new(CachedCounterValue::from_authority(
+                    &counter,
+                    2,
+                    Duration::from_secs(60),
+                )),
+            )
+            .await;
         cache.insert(
             counter.clone(),
             Some(1),