diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 3949d0c6..800b3efe 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicBool, AtomicI64, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime}; use tokio::select; -use tokio::sync::Notify; +use tokio::sync::{Notify, Semaphore}; #[derive(Debug)] pub struct CachedCounterValue { @@ -129,19 +129,21 @@ pub struct Batcher { notifier: Notify, interval: Duration, priority_flush: AtomicBool, + limiter: Semaphore, } impl Batcher { - fn new(period: Duration) -> Self { + fn new(period: Duration, max_cached_counters: usize) -> Self { Self { updates: Default::default(), notifier: Default::default(), interval: period, priority_flush: AtomicBool::new(false), + limiter: Semaphore::new(max_cached_counters), } } - pub fn add(&self, counter: Counter, value: Arc) { + pub async fn add(&self, counter: Counter, value: Arc) { let priority = value.requires_fast_flush(&self.interval); match self.updates.entry(counter.clone()) { Entry::Occupied(needs_merge) => { @@ -151,6 +153,7 @@ impl Batcher { } } Entry::Vacant(miss) => { + self.limiter.acquire().await.unwrap().forget(); miss.insert_entry(value); } }; @@ -186,8 +189,12 @@ impl Batcher { } let result = consumer(result).await; batch.iter().for_each(|counter| { - self.updates + let prev = self + .updates .remove_if(counter, |_, v| v.no_pending_writes()); + if prev.is_some() { + self.limiter.add_permits(1); + } }); return result; } else { @@ -214,7 +221,7 @@ impl Batcher { impl Default for Batcher { fn default() -> Self { - Self::new(Duration::from_millis(100)) + Self::new(Duration::from_millis(100), DEFAULT_MAX_CACHED_COUNTERS) } } @@ -272,7 +279,7 @@ impl CountersCache { )) } - pub fn increase_by(&self, counter: &Counter, delta: i64) { + pub async fn increase_by(&self, counter: &Counter, delta: i64) { let val = self.cache.get_with_by_ref(counter, || { if let Some(entry) = self.batcher.updates.get(counter) { entry.value().clone() @@ -281,7 +288,7 @@ impl CountersCache { } }); val.delta(counter, delta); - self.batcher.add(counter.clone(), val.clone()); + self.batcher.add(counter.clone(), val.clone()).await; } } @@ -304,25 +311,28 @@ impl CountersCacheBuilder { pub fn build(&self, period: Duration) -> CountersCache { CountersCache { cache: Cache::new(self.max_cached_counters as u64), - batcher: Batcher::new(period), + batcher: Batcher::new(period, self.max_cached_counters), } } } #[cfg(test)] mod tests { - use super::*; - use crate::limit::Limit; use std::collections::HashMap; use std::ops::Add; use std::time::UNIX_EPOCH; + use crate::limit::Limit; + + use super::*; + mod cached_counter_value { - use crate::storage::redis::counters_cache::tests::test_counter; - use crate::storage::redis::counters_cache::CachedCounterValue; use std::ops::{Add, Not}; use std::time::{Duration, SystemTime}; + use crate::storage::redis::counters_cache::tests::test_counter; + use crate::storage::redis::counters_cache::CachedCounterValue; + #[test] fn records_pending_writes() { let counter = test_counter(10, None); @@ -401,15 +411,17 @@ mod tests { } mod batcher { - use crate::storage::redis::counters_cache::tests::test_counter; - use crate::storage::redis::counters_cache::{Batcher, CachedCounterValue}; use std::sync::Arc; use std::time::{Duration, SystemTime}; + use crate::storage::redis::counters_cache::tests::test_counter; + use crate::storage::redis::counters_cache::{Batcher, CachedCounterValue}; + use crate::storage::redis::DEFAULT_MAX_CACHED_COUNTERS; + #[tokio::test] async fn consume_waits_when_empty() { let duration = Duration::from_millis(100); - let batcher = Batcher::new(duration); + let batcher = Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS); let start = SystemTime::now(); batcher .consume(2, |items| { @@ -423,7 +435,7 @@ mod tests { #[tokio::test] async fn consume_waits_when_batch_not_filled() { let duration = Duration::from_millis(100); - let batcher = Arc::new(Batcher::new(duration)); + let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS)); let start = SystemTime::now(); { let batcher = Arc::clone(&batcher); @@ -431,7 +443,7 @@ mod tests { tokio::time::sleep(Duration::from_millis(40)).await; let counter = test_counter(6, None); let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0)); - batcher.add(counter, arc); + batcher.add(counter, arc).await; }); } batcher @@ -449,7 +461,7 @@ mod tests { #[tokio::test] async fn consume_waits_until_batch_is_filled() { let duration = Duration::from_millis(100); - let batcher = Arc::new(Batcher::new(duration)); + let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS)); let start = SystemTime::now(); { let batcher = Arc::clone(&batcher); @@ -457,7 +469,7 @@ mod tests { tokio::time::sleep(Duration::from_millis(40)).await; let counter = test_counter(6, None); let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0)); - batcher.add(counter, arc); + batcher.add(counter, arc).await; }); } batcher @@ -474,12 +486,12 @@ mod tests { #[tokio::test] async fn consume_immediately_when_batch_is_filled() { let duration = Duration::from_millis(100); - let batcher = Arc::new(Batcher::new(duration)); + let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS)); let start = SystemTime::now(); { let counter = test_counter(6, None); let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0)); - batcher.add(counter, arc); + batcher.add(counter, arc).await; } batcher .consume(1, |items| { @@ -495,7 +507,7 @@ mod tests { #[tokio::test] async fn consume_triggers_on_fast_flush() { let duration = Duration::from_millis(100); - let batcher = Arc::new(Batcher::new(duration)); + let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS)); let start = SystemTime::now(); { let batcher = Arc::clone(&batcher); @@ -503,7 +515,7 @@ mod tests { tokio::time::sleep(Duration::from_millis(40)).await; let counter = test_counter(6, None); let arc = Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 0)); - batcher.add(counter, arc); + batcher.add(counter, arc).await; }); } batcher @@ -570,8 +582,8 @@ mod tests { ); } - #[test] - fn increase_by() { + #[tokio::test] + async fn increase_by() { let current_val = 10; let increase_by = 8; let counter = test_counter(current_val, None); @@ -587,7 +599,7 @@ mod tests { .unwrap() .as_micros() as i64, ); - cache.increase_by(&counter, increase_by); + cache.increase_by(&counter, increase_by).await; assert_eq!( cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 698d8151..d3f094e6 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -121,7 +121,7 @@ impl AsyncCounterStorage for CachedRedisStorage { // Update cached values for counter in counters.iter() { - self.cached_counters.increase_by(counter, delta); + self.cached_counters.increase_by(counter, delta).await; } Ok(Authorization::Ok) @@ -489,10 +489,13 @@ mod tests { )]); let cache = CountersCacheBuilder::new().build(Duration::from_millis(10)); - cache.batcher().add( - counter.clone(), - Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 2)), - ); + cache + .batcher() + .add( + counter.clone(), + Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 2)), + ) + .await; let cached_counters: Arc = Arc::new(cache); let partitioned = Arc::new(AtomicBool::new(false));