Skip to content

Commit

Permalink
fixes Kuadrant#295: Use a semaphore to protect the Batcher from unbou…
Browse files Browse the repository at this point in the history
…nded memory growth.

Signed-off-by: Hiram Chirino <[email protected]>
  • Loading branch information
chirino committed Apr 30, 2024
1 parent 31ee24c commit 3488c3a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
28 changes: 18 additions & 10 deletions limitador/src/storage/redis/counters_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::select;
use tokio::sync::Notify;
use tokio::sync::Semaphore;

pub struct CachedCounterValue {
value: AtomicExpiringValue,
Expand Down Expand Up @@ -132,19 +133,21 @@ pub struct Batcher {
notifier: Notify,
interval: Duration,
priority_flush: AtomicBool,
limiter: Semaphore,
}

impl Batcher {
fn new(period: Duration) -> Self {
fn new(period: Duration, max_cached_counters: usize) -> Self {
Self {
updates: Default::default(),
notifier: Default::default(),
interval: period,
priority_flush: AtomicBool::new(false),
limiter: Semaphore::new(max_cached_counters),
}
}

pub fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
pub async fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
let priority = value.requires_fast_flush(&self.interval);
match self.updates.entry(counter.clone()) {
Entry::Occupied(needs_merge) => {
Expand All @@ -154,6 +157,7 @@ impl Batcher {
}
}
Entry::Vacant(miss) => {
self.limiter.acquire().await.unwrap().forget();
miss.insert_entry(value);
}
};
Expand Down Expand Up @@ -189,8 +193,12 @@ impl Batcher {
}
let result = consumer(result).await;
batch.iter().for_each(|counter| {
self.updates
.remove_if(counter, |_, v| v.no_pending_writes());
if let Some(_) = self
.updates
.remove_if(counter, |_, v| v.no_pending_writes())
{
self.limiter.add_permits(1);
}
});
return result;
} else {
Expand Down Expand Up @@ -285,7 +293,7 @@ impl CountersCache {
))
}

pub fn increase_by(&self, counter: &Counter, delta: i64) {
pub async fn increase_by(&self, counter: &Counter, delta: i64) {
let val = self.cache.get_with_by_ref(counter, || {
if let Some(entry) = self.batcher.updates.get(counter) {
entry.value().clone()
Expand All @@ -294,7 +302,7 @@ impl CountersCache {
}
});
val.delta(counter, delta);
self.batcher.add(counter.clone(), val.clone());
self.batcher.add(counter.clone(), val.clone()).await;
}

fn ttl_from_redis_ttl(
Expand Down Expand Up @@ -377,7 +385,7 @@ impl CountersCacheBuilder {
max_ttl_cached_counters: self.max_ttl_cached_counters,
ttl_ratio_cached_counters: self.ttl_ratio_cached_counters,
cache: Cache::new(self.max_cached_counters as u64),
batcher: Batcher::new(period),
batcher: Batcher::new(period, self.max_cached_counters),
}
}
}
Expand Down Expand Up @@ -495,8 +503,8 @@ mod tests {
assert_eq!(cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), 0);
}

#[test]
fn increase_by() {
#[tokio::test]
async fn increase_by() {
let current_val = 10;
let increase_by = 8;
let mut values = HashMap::new();
Expand All @@ -520,7 +528,7 @@ mod tests {
Duration::from_secs(0),
SystemTime::now(),
);
cache.increase_by(&counter, increase_by);
cache.increase_by(&counter, increase_by).await;

assert_eq!(
cache.get(&counter).map(|e| e.hits(&counter)).unwrap(),
Expand Down
21 changes: 12 additions & 9 deletions limitador/src/storage/redis/redis_cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl AsyncCounterStorage for CachedRedisStorage {

// Update cached values
for counter in counters.iter() {
self.cached_counters.increase_by(counter, delta);
self.cached_counters.increase_by(counter, delta).await;
}

Ok(Authorization::Ok)
Expand Down Expand Up @@ -480,14 +480,17 @@ mod tests {
)]);

let cache = CountersCacheBuilder::new().build(Duration::from_millis(1));
cache.batcher().add(
counter.clone(),
Arc::new(CachedCounterValue::from_authority(
&counter,
2,
Duration::from_secs(60),
)),
);
cache
.batcher()
.add(
counter.clone(),
Arc::new(CachedCounterValue::from_authority(
&counter,
2,
Duration::from_secs(60),
)),
)
.await;
cache.insert(
counter.clone(),
Some(1),
Expand Down

0 comments on commit 3488c3a

Please sign in to comment.