fixes #295: Use a semaphore to protect the Batcher from unbounded memory growth. #304

Merged 1 commit on May 15, 2024
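The approach in this PR: the `Batcher` now owns a `tokio::sync::Semaphore` sized to `max_cached_counters`. Adding a counter that is not yet in the batch permanently consumes (forgets) one permit, and each entry removed after a flush hands one permit back, so the batch can never hold more than `max_cached_counters` distinct entries and callers of `add` wait instead of growing memory without bound. A minimal standalone illustration of that permit accounting (simplified, not the PR's actual types):

```rust
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // Capacity of 2 distinct pending entries; the PR sizes this with
    // `max_cached_counters`.
    let limiter = Semaphore::new(2);

    // Inserting a brand-new entry: take a permit and forget it, so it is NOT
    // returned automatically when the guard is dropped.
    limiter.acquire().await.unwrap().forget();
    limiter.acquire().await.unwrap().forget();
    assert_eq!(limiter.available_permits(), 0);

    // A third insert would now have to wait...
    assert!(limiter.try_acquire().is_err());

    // ...until a flush removes an entry and explicitly returns its permit.
    limiter.add_permits(1);
    assert!(limiter.try_acquire().is_ok());
}
```

Merges into a counter that is already batched take the `Entry::Occupied` path and never touch the semaphore, so only genuinely new entries consume permits.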
64 changes: 38 additions & 26 deletions limitador/src/storage/redis/counters_cache.rs
@@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::select;
use tokio::sync::Notify;
use tokio::sync::{Notify, Semaphore};

#[derive(Debug)]
pub struct CachedCounterValue {
@@ -129,19 +129,21 @@ pub struct Batcher {
notifier: Notify,
interval: Duration,
priority_flush: AtomicBool,
limiter: Semaphore,
}

impl Batcher {
fn new(period: Duration) -> Self {
fn new(period: Duration, max_cached_counters: usize) -> Self {
Self {
updates: Default::default(),
notifier: Default::default(),
interval: period,
priority_flush: AtomicBool::new(false),
limiter: Semaphore::new(max_cached_counters),
}
}

pub fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
pub async fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
let priority = value.requires_fast_flush(&self.interval);
match self.updates.entry(counter.clone()) {
Entry::Occupied(needs_merge) => {
@@ -151,6 +153,7 @@ impl Batcher {
}
}
Entry::Vacant(miss) => {
self.limiter.acquire().await.unwrap().forget();
miss.insert_entry(value);
}
};
@@ -186,8 +189,12 @@ impl Batcher {
}
let result = consumer(result).await;
batch.iter().for_each(|counter| {
self.updates
let prev = self
.updates
.remove_if(counter, |_, v| v.no_pending_writes());
if prev.is_some() {
self.limiter.add_permits(1);
Member commented:

I think the example in the docs is somewhat oversimplified... Reading the docs, `.add_permits` doesn't account for the initial permits provided (i.e. it doesn't treat them as a maximum), see here. So we need to do all the math ourselves, as this wouldn't account for re-adding permits back when counters expire, for instance. I'd expect this to run out of permits after a while.
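For reference, a standalone snippet (not part of the PR) showing the behaviour the comment refers to: tokio's `Semaphore` has no notion of a maximum, so `add_permits` is not capped at the initial count and the caller has to keep the accounting balanced itself.

```rust
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let sem = Semaphore::new(2);

    // Nothing ties add_permits() back to the 2 permits the semaphore started
    // with; over-adding silently raises the effective capacity.
    sem.add_permits(3);
    assert_eq!(sem.available_permits(), 5);
}
```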

Member commented:

Actually... thinking about it, it might be good enough, as we don't compact the queue anywhere 🤔
So if an entry is passed to the consumer, then even if it has no pending writes (e.g. because it expired), this still removes it from `self.updates` and adds the permit back...
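To make that reasoning concrete, here is a hypothetical mini-model of the invariant (using `dashmap` and `tokio` directly; the names and types are made up, not Limitador's): available permits stay equal to capacity minus the number of batched entries, because a permit is only returned when `remove_if` actually removed an entry.

```rust
use dashmap::DashMap;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let capacity = 3;
    let updates: DashMap<&str, u32> = DashMap::new(); // value = pending writes
    let limiter = Semaphore::new(capacity);

    // Two counters enter the batch: each new entry forgets one permit.
    for (key, pending) in [("expired", 0u32), ("active", 2u32)] {
        limiter.acquire().await.unwrap().forget();
        updates.insert(key, pending);
    }
    assert_eq!(limiter.available_permits(), capacity - updates.len());

    // After a flush, only entries with no pending writes are removed, and only
    // those hand their permit back; "active" stays and keeps its permit consumed.
    for key in ["expired", "active"] {
        if updates.remove_if(key, |_, pending| *pending == 0).is_some() {
            limiter.add_permits(1);
        }
    }
    assert_eq!(updates.len(), 1);
    assert_eq!(limiter.available_permits(), capacity - updates.len());
}
```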

}
});
return result;
} else {
@@ -214,7 +221,7 @@ impl Batcher {

impl Default for Batcher {
fn default() -> Self {
Self::new(Duration::from_millis(100))
Self::new(Duration::from_millis(100), DEFAULT_MAX_CACHED_COUNTERS)
}
}

@@ -272,7 +279,7 @@ impl CountersCache {
))
}

pub fn increase_by(&self, counter: &Counter, delta: i64) {
pub async fn increase_by(&self, counter: &Counter, delta: i64) {
let val = self.cache.get_with_by_ref(counter, || {
if let Some(entry) = self.batcher.updates.get(counter) {
entry.value().clone()
@@ -281,7 +288,7 @@ impl CountersCache {
}
});
val.delta(counter, delta);
self.batcher.add(counter.clone(), val.clone());
self.batcher.add(counter.clone(), val.clone()).await;
}
}

@@ -304,25 +311,28 @@ impl CountersCacheBuilder {
pub fn build(&self, period: Duration) -> CountersCache {
CountersCache {
cache: Cache::new(self.max_cached_counters as u64),
batcher: Batcher::new(period),
batcher: Batcher::new(period, self.max_cached_counters),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::limit::Limit;
use std::collections::HashMap;
use std::ops::Add;
use std::time::UNIX_EPOCH;

use crate::limit::Limit;

use super::*;

mod cached_counter_value {
use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::CachedCounterValue;
use std::ops::{Add, Not};
use std::time::{Duration, SystemTime};

use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::CachedCounterValue;

#[test]
fn records_pending_writes() {
let counter = test_counter(10, None);
@@ -401,15 +411,17 @@ mod tests {
}

mod batcher {
use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::{Batcher, CachedCounterValue};
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::{Batcher, CachedCounterValue};
use crate::storage::redis::DEFAULT_MAX_CACHED_COUNTERS;

#[tokio::test]
async fn consume_waits_when_empty() {
let duration = Duration::from_millis(100);
let batcher = Batcher::new(duration);
let batcher = Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS);
let start = SystemTime::now();
batcher
.consume(2, |items| {
@@ -423,15 +435,15 @@
#[tokio::test]
async fn consume_waits_when_batch_not_filled() {
let duration = Duration::from_millis(100);
let batcher = Arc::new(Batcher::new(duration));
let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS));
let start = SystemTime::now();
{
let batcher = Arc::clone(&batcher);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(40)).await;
let counter = test_counter(6, None);
let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0));
batcher.add(counter, arc);
batcher.add(counter, arc).await;
});
}
batcher
@@ -449,15 +461,15 @@
#[tokio::test]
async fn consume_waits_until_batch_is_filled() {
let duration = Duration::from_millis(100);
let batcher = Arc::new(Batcher::new(duration));
let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS));
let start = SystemTime::now();
{
let batcher = Arc::clone(&batcher);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(40)).await;
let counter = test_counter(6, None);
let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0));
batcher.add(counter, arc);
batcher.add(counter, arc).await;
});
}
batcher
@@ -474,12 +486,12 @@
#[tokio::test]
async fn consume_immediately_when_batch_is_filled() {
let duration = Duration::from_millis(100);
let batcher = Arc::new(Batcher::new(duration));
let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS));
let start = SystemTime::now();
{
let counter = test_counter(6, None);
let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0));
batcher.add(counter, arc);
batcher.add(counter, arc).await;
}
batcher
.consume(1, |items| {
@@ -495,15 +507,15 @@
#[tokio::test]
async fn consume_triggers_on_fast_flush() {
let duration = Duration::from_millis(100);
let batcher = Arc::new(Batcher::new(duration));
let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS));
let start = SystemTime::now();
{
let batcher = Arc::clone(&batcher);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(40)).await;
let counter = test_counter(6, None);
let arc = Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 0));
batcher.add(counter, arc);
batcher.add(counter, arc).await;
});
}
batcher
@@ -570,8 +582,8 @@
);
}

#[test]
fn increase_by() {
#[tokio::test]
async fn increase_by() {
let current_val = 10;
let increase_by = 8;
let counter = test_counter(current_val, None);
@@ -587,7 +599,7 @@
.unwrap()
.as_micros() as i64,
);
cache.increase_by(&counter, increase_by);
cache.increase_by(&counter, increase_by).await;

assert_eq!(
cache.get(&counter).map(|e| e.hits(&counter)).unwrap(),
13 changes: 8 additions & 5 deletions limitador/src/storage/redis/redis_cached.rs
@@ -121,7 +121,7 @@ impl AsyncCounterStorage for CachedRedisStorage {

// Update cached values
for counter in counters.iter() {
self.cached_counters.increase_by(counter, delta);
self.cached_counters.increase_by(counter, delta).await;
}

Ok(Authorization::Ok)
@@ -489,10 +489,13 @@
)]);

let cache = CountersCacheBuilder::new().build(Duration::from_millis(10));
cache.batcher().add(
counter.clone(),
Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 2)),
);
cache
.batcher()
.add(
counter.clone(),
Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 2)),
)
.await;

let cached_counters: Arc<CountersCache> = Arc::new(cache);
let partitioned = Arc::new(AtomicBool::new(false));