From 3895c8a8746f293deaf10fe2748e94114850790c Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Mon, 13 May 2024 10:32:16 -0400 Subject: [PATCH 1/2] [distributed store] use a single map Vec -> Counters map By keying using Vec we reduce how often we need to encode/decode the the counter keys. Signed-off-by: Hiram Chirino --- limitador/src/storage/distributed/grpc/mod.rs | 46 +-- limitador/src/storage/distributed/mod.rs | 318 +++++++----------- 2 files changed, 131 insertions(+), 233 deletions(-) diff --git a/limitador/src/storage/distributed/grpc/mod.rs b/limitador/src/storage/distributed/grpc/mod.rs index 2339e1cf..79d3d985 100644 --- a/limitador/src/storage/distributed/grpc/mod.rs +++ b/limitador/src/storage/distributed/grpc/mod.rs @@ -1,11 +1,10 @@ -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{error::Error, io::ErrorKind, pin::Pin}; -use moka::sync::Cache; use tokio::sync::mpsc::Sender; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::time::sleep; @@ -14,15 +13,12 @@ use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt}; use tonic::{Code, Request, Response, Status, Streaming}; use tracing::debug; -use crate::counter::Counter; -use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::packet::Message; use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient; use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer}; use crate::storage::distributed::grpc::v1::{ CounterUpdate, Hello, MembershipUpdate, Packet, Peer, Pong, }; -use crate::storage::distributed::CounterKey; // clippy will barf on protobuff generated code for enum variants in // v3::socket_option::SocketState, so allow this lint @@ -187,34 +183,7 @@ impl Session { } Some(Message::CounterUpdate(update)) => { debug!("peer: '{}': CounterUpdate", self.peer_id); - - let counter_key = postcard::from_bytes::(update.key.as_slice()) - .map_err(|err| { - Status::internal(format!("failed to decode counter key: {:?}", err)) - })?; - - let values = BTreeMap::from_iter( - update - .values - .iter() - .map(|(k, v)| (k.to_owned(), v.to_owned())), - ); - - let counter = >::into(counter_key); - if counter.is_qualified() { - if let Some(counter) = self.broker_state.qualified_counters.get(&counter) { - counter.merge( - (UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into(), - ); - } - } else { - let counters = self.broker_state.limits_for_namespace.read().unwrap(); - let limits = counters.get(counter.namespace()).unwrap(); - let value = limits.get(counter.limit()).unwrap(); - value.merge( - (UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into(), - ); - }; + (self.broker_state.on_counter_update)(update); } _ => { debug!("peer: '{}': unsupported packet: {:?}", self.peer_id, packet); @@ -348,12 +317,13 @@ impl MessageSender { } } +type CounterUpdateFn = Pin>; + #[derive(Clone)] struct BrokerState { id: String, - limits_for_namespace: Arc>, - qualified_counters: Arc>>>, publisher: broadcast::Sender, + on_counter_update: Arc, } #[derive(Clone)] @@ -369,8 +339,7 @@ impl Broker { id: String, listen_address: SocketAddr, peer_urls: Vec, - limits_for_namespace: Arc>, - qualified_counters: Arc>>>, + on_counter_update: CounterUpdateFn, ) -> Broker { let (tx, _) = broadcast::channel(16); let publisher: broadcast::Sender = tx; @@ -381,8 +350,7 @@ impl Broker { broker_state: BrokerState { id, publisher, - limits_for_namespace, - qualified_counters, + on_counter_update: Arc::new(on_counter_update), }, replication_state: Arc::new(RwLock::new(ReplicationState { discovered_urls: HashSet::new(), diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 7a9e7761..65c1394c 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -1,11 +1,9 @@ use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::net::ToSocketAddrs; -use std::ops::Deref; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use moka::sync::Cache; use serde::{Deserialize, Serialize}; use crate::counter::Counter; @@ -18,89 +16,57 @@ use crate::storage::{Authorization, CounterStorage, StorageErr}; mod cr_counter_value; mod grpc; -type NamespacedLimitCounters = HashMap>; +pub type LimitsMap = HashMap, CrCounterValue>; pub struct CrInMemoryStorage { identifier: String, - limits_for_namespace: Arc>>>, - qualified_counters: Arc>>>, + limits: Arc>, broker: Broker, } impl CounterStorage for CrInMemoryStorage { #[tracing::instrument(skip_all)] fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { - let limits_by_namespace = self.limits_for_namespace.read().unwrap(); + let limits = self.limits.read().unwrap(); let mut value = 0; - - if counter.is_qualified() { - if let Some(counter) = self.qualified_counters.get(counter) { - value = counter.read(); - } - } else if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) { - if let Some(counter) = limits.get(counter.limit()) { - value = counter.read(); - } + let key = encode_counter_to_key(counter); + if let Some(counter_value) = limits.get(&key) { + value = counter_value.read() } - Ok(counter.max_value() >= value + delta) } #[tracing::instrument(skip_all)] fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr> { if limit.variables().is_empty() { - let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); - limits_by_namespace - .entry(limit.namespace().clone()) - .or_default() - .entry(limit.clone()) - .or_insert(CrCounterValue::new( - self.identifier.clone(), - Duration::from_secs(limit.seconds()), - )); + let mut limits = self.limits.write().unwrap(); + let key = encode_limit_to_key(limit); + limits.entry(key).or_insert(CrCounterValue::new( + self.identifier.clone(), + Duration::from_secs(limit.seconds()), + )); } Ok(()) } #[tracing::instrument(skip_all)] fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { - let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + let mut limits = self.limits.write().unwrap(); let now = SystemTime::now(); - if counter.is_qualified() { - let value = match self.qualified_counters.get(counter) { - None => self.qualified_counters.get_with(counter.clone(), || { - Arc::new(CrCounterValue::new( - self.identifier.clone(), - counter.window(), - )) - }), - Some(counter) => counter, - }; - self.increment_counter(counter.clone(), &value, delta, now); - } else { - match limits_by_namespace.entry(counter.limit().namespace().clone()) { - Entry::Vacant(v) => { - let mut limits = HashMap::new(); - let counter_val = - CrCounterValue::new(self.identifier.clone(), counter.window()); - self.increment_counter(counter.clone(), &counter_val, delta, now); - limits.insert(counter.limit().clone(), counter_val); - v.insert(limits); - } - Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) { - Entry::Vacant(v) => { - let counter_value = - CrCounterValue::new(self.identifier.clone(), counter.window()); - self.increment_counter(counter.clone(), &counter_value, delta, now); - v.insert(counter_value); - } - Entry::Occupied(o) => { - self.increment_counter(counter.clone(), o.get(), delta, now); - } - }, + + let key = encode_counter_to_key(counter); + match limits.entry(key.clone()) { + Entry::Vacant(entry) => { + let duration = counter.window(); + let store_value = CrCounterValue::new(self.identifier.clone(), duration); + self.increment_counter(counter, key, &store_value, delta, now); + entry.insert(store_value); } - } + Entry::Occupied(entry) => { + self.increment_counter(counter, key, entry.get(), delta, now); + } + }; Ok(()) } @@ -111,11 +77,8 @@ impl CounterStorage for CrInMemoryStorage { delta: u64, load_counters: bool, ) -> Result { - let limits_by_namespace = self.limits_for_namespace.write().unwrap(); let mut first_limited = None; - let mut counter_values_to_update: Vec<(&CrCounterValue, Counter)> = Vec::new(); - let mut qualified_counter_values_to_updated: Vec<(Arc>, Counter)> = - Vec::new(); + let mut counter_values_to_update: Vec<(Counter, Vec)> = Vec::new(); let now = SystemTime::now(); let mut process_counter = @@ -138,39 +101,44 @@ impl CounterStorage for CrInMemoryStorage { }; // Process simple counters - for counter in counters.iter_mut().filter(|c| !c.is_qualified()) { - let atomic_expiring_value: &CrCounterValue = limits_by_namespace - .get(counter.limit().namespace()) - .and_then(|limits| limits.get(counter.limit())) - .unwrap(); - - if let Some(limited) = process_counter(counter, atomic_expiring_value.read(), delta) { - if !load_counters { - return Ok(limited); + for counter in counters.iter_mut() { + let key = encode_counter_to_key(counter); + + // most of the time the counter should exist, so first try with a read only lock + // since that will allow us to have higher concurrency + let counter_existed = { + let key = key.clone(); + let limits = self.limits.read().unwrap(); + match limits.get(&key) { + None => false, + Some(store_value) => { + if let Some(limited) = process_counter(counter, store_value.read(), delta) { + if !load_counters { + return Ok(limited); + } + } + counter_values_to_update.push((counter.clone(), key)); + true + } } - } - counter_values_to_update.push((atomic_expiring_value, counter.clone())); - } - - // Process qualified counters - for counter in counters.iter_mut().filter(|c| c.is_qualified()) { - let value = match self.qualified_counters.get(counter) { - None => self.qualified_counters.get_with(counter.clone(), || { - Arc::new(CrCounterValue::new( - self.identifier.clone(), - counter.window(), - )) - }), - Some(counter) => counter, }; - if let Some(limited) = process_counter(counter, value.read(), delta) { - if !load_counters { - return Ok(limited); + // we need to take the slow path since we need to mutate the limits map. + if !counter_existed { + // try again with a write lock to create the counter if it's still missing. + let mut limits = self.limits.write().unwrap(); + let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new( + self.identifier.clone(), + counter.window(), + )); + + if let Some(limited) = process_counter(counter, store_value.read(), delta) { + if !load_counters { + return Ok(limited); + } } + counter_values_to_update.push((counter.clone(), key)); } - - qualified_counter_values_to_updated.push((value, counter.clone())); } if let Some(limited) = first_limited { @@ -178,15 +146,12 @@ impl CounterStorage for CrInMemoryStorage { } // Update counters + let limits = self.limits.read().unwrap(); counter_values_to_update .into_iter() - .for_each(|(v, counter)| { - self.increment_counter(counter, v, delta, now); - }); - qualified_counter_values_to_updated - .into_iter() - .for_each(|(v, counter)| { - self.increment_counter(counter, v.deref(), delta, now); + .for_each(|(counter, key)| { + let store_value = limits.get(&key).unwrap(); + self.increment_counter(&counter, key, store_value, delta, now); }); Ok(Authorization::Ok) @@ -196,40 +161,17 @@ impl CounterStorage for CrInMemoryStorage { fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { let mut res = HashSet::new(); - let namespaces: HashSet<&Namespace> = limits.iter().map(Limit::namespace).collect(); - let limits_by_namespace = self.limits_for_namespace.read().unwrap(); - - for namespace in namespaces { - if let Some(limits) = limits_by_namespace.get(namespace) { - for limit in limits.keys() { - if limits.contains_key(limit) { - for (counter, expiring_value) in self.counters_in_namespace(namespace) { - let mut counter_with_val = counter.clone(); - counter_with_val.set_remaining( - counter_with_val.max_value() - expiring_value.read(), - ); - counter_with_val.set_expires_in(expiring_value.ttl()); - if counter_with_val.expires_in().unwrap() > Duration::ZERO { - res.insert(counter_with_val); - } - } - } - } - } - } - - for (counter, expiring_value) in self.qualified_counters.iter() { + let limits_map = self.limits.read().unwrap(); + for (key, counter_value) in limits_map.iter() { + let mut counter: Counter = decode_counter_key(key).unwrap().into(); if limits.contains(counter.limit()) { - let mut counter_with_val = counter.deref().clone(); - counter_with_val - .set_remaining(counter_with_val.max_value() - expiring_value.read()); - counter_with_val.set_expires_in(expiring_value.ttl()); - if counter_with_val.expires_in().unwrap() > Duration::ZERO { - res.insert(counter_with_val); + counter.set_remaining(counter.max_value() - counter_value.read()); + counter.set_expires_in(counter_value.ttl()); + if counter.expires_in().unwrap() > Duration::ZERO { + res.insert(counter); } } } - Ok(res) } @@ -243,35 +185,38 @@ impl CounterStorage for CrInMemoryStorage { #[tracing::instrument(skip_all)] fn clear(&self) -> Result<(), StorageErr> { - self.limits_for_namespace.write().unwrap().clear(); + self.limits.write().unwrap().clear(); Ok(()) } } -pub type LimitsMap = HashMap>>; - impl CrInMemoryStorage { pub fn new( identifier: String, - cache_size: u64, + _cache_size: u64, listen_address: String, peer_urls: Vec, ) -> Self { - // let (sender, mut rx) = mpsc::channel(1000); - let listen_address = listen_address.to_socket_addrs().unwrap().next().unwrap(); let peer_urls = peer_urls.clone(); + let limits = Arc::new(RwLock::new(LimitsMap::new())); - let limits_for_namespace = Arc::new(RwLock::new(LimitsMap::new())); - let qualified_counters: Arc>>> = - Arc::new(Cache::new(cache_size)); - + let limits_clone = limits.clone(); let broker = grpc::Broker::new( identifier.clone(), listen_address, peer_urls, - limits_for_namespace.clone(), - qualified_counters.clone(), + Box::pin(move |update: CounterUpdate| { + let values = BTreeMap::from_iter( + update + .values + .iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())), + ); + let limits = limits_clone.read().unwrap(); + let value = limits.get(&update.key).unwrap(); + value.merge((UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into()); + }), ); { @@ -283,45 +228,14 @@ impl CrInMemoryStorage { Self { identifier, - limits_for_namespace, - qualified_counters, + limits, broker, } } - fn counters_in_namespace( - &self, - namespace: &Namespace, - ) -> HashMap> { - let mut res: HashMap> = HashMap::new(); - - if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) { - for (limit, value) in counters_by_limit { - res.insert( - Counter::new(limit.clone(), HashMap::default()), - value.clone(), - ); - } - } - - for (counter, value) in self.qualified_counters.iter() { - if counter.namespace() == namespace { - res.insert(counter.deref().clone(), value.deref().clone()); - } - } - - res - } - fn delete_counters_of_limit(&self, limit: &Limit) { - if let Some(counters_by_limit) = self - .limits_for_namespace - .write() - .unwrap() - .get_mut(limit.namespace()) - { - counters_by_limit.remove(limit); - } + let key = encode_limit_to_key(limit); + self.limits.write().unwrap().remove(&key); } fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool { @@ -333,19 +247,17 @@ impl CrInMemoryStorage { fn increment_counter( &self, - key: Counter, - counter: &CrCounterValue, + counter: &Counter, + store_key: Vec, + store_value: &CrCounterValue, delta: u64, when: SystemTime, ) { - counter.inc_at(delta, key.window(), when); - let counter = counter.clone(); - let (expiry, values) = counter.into_inner(); - let key: CounterKey = key.into(); - let key = postcard::to_stdvec(&key).unwrap(); + store_value.inc_at(delta, counter.window(), when); + let (expiry, values) = store_value.clone().into_inner(); self.broker.publish(CounterUpdate { - key, + key: store_key, values: values.into_iter().collect(), expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), }) @@ -356,21 +268,39 @@ impl CrInMemoryStorage { struct CounterKey { namespace: Namespace, seconds: u64, + max_value: u64, conditions: HashSet, variables: HashSet, vars: HashMap, } -impl From for CounterKey { - fn from(value: Counter) -> Self { - Self { - namespace: value.namespace().clone(), - seconds: value.window().as_secs(), - variables: value.limit().variables(), - conditions: value.limit().conditions(), - vars: value.set_variables().clone(), - } - } +fn encode_counter_to_key(counter: &Counter) -> Vec { + let limit = counter.limit(); + let key = CounterKey { + namespace: limit.namespace().clone(), + max_value: limit.max_value(), + seconds: limit.seconds(), + variables: limit.variables().clone(), + conditions: limit.conditions().clone(), + vars: counter.set_variables().clone(), + }; + postcard::to_stdvec(&key).unwrap() +} + +fn encode_limit_to_key(limit: &Limit) -> Vec { + let key = CounterKey { + namespace: limit.namespace().clone(), + max_value: limit.max_value(), + seconds: limit.seconds(), + variables: limit.variables().clone(), + conditions: limit.conditions().clone(), + vars: HashMap::default(), + }; + postcard::to_stdvec(&key).unwrap() +} + +fn decode_counter_key(key: &Vec) -> postcard::Result { + postcard::from_bytes(key.as_slice()) } impl From for Counter { @@ -378,7 +308,7 @@ impl From for Counter { Self::new( Limit::new( value.namespace, - 0, + value.max_value, value.seconds, value.conditions, value.vars.keys(), From 15718198df98b622178bc79515b2ffc69b5b042c Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Thu, 23 May 2024 11:08:25 -0400 Subject: [PATCH 2/2] Move the max_value field from the CounterKey to CrCounterValue. Signed-off-by: Hiram Chirino --- .../storage/distributed/cr_counter_value.rs | 53 ++++++----- limitador/src/storage/distributed/mod.rs | 94 +++++++++++-------- 2 files changed, 89 insertions(+), 58 deletions(-) diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index e76b4b05..38220e5f 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -1,13 +1,15 @@ -use crate::storage::atomic_expiring_value::AtomicExpiryTime; use std::collections::btree_map::Entry; use std::collections::BTreeMap; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::RwLock; use std::time::{Duration, SystemTime}; +use crate::storage::atomic_expiring_value::AtomicExpiryTime; + #[derive(Debug)] pub struct CrCounterValue { ourselves: A, + max_value: u64, value: AtomicU64, others: RwLock>, expiry: AtomicExpiryTime, @@ -15,15 +17,20 @@ pub struct CrCounterValue { #[allow(dead_code)] impl CrCounterValue { - pub fn new(actor: A, time_window: Duration) -> Self { + pub fn new(actor: A, max_value: u64, time_window: Duration) -> Self { Self { ourselves: actor, + max_value, value: Default::default(), others: RwLock::default(), expiry: AtomicExpiryTime::new(SystemTime::now() + time_window), } } + pub fn max_value(&self) -> u64 { + self.max_value + } + pub fn read(&self) -> u64 { self.read_at(SystemTime::now()) } @@ -116,6 +123,7 @@ impl CrCounterValue { pub fn into_inner(self) -> (SystemTime, BTreeMap) { let Self { ourselves, + max_value: _, value, others, expiry, @@ -137,6 +145,7 @@ impl Clone for CrCounterValue { fn clone(&self) -> Self { Self { ourselves: self.ourselves.clone(), + max_value: self.max_value, value: AtomicU64::new(self.value.load(Ordering::SeqCst)), others: RwLock::new(self.others.read().unwrap().clone()), expiry: self.expiry.clone(), @@ -148,6 +157,7 @@ impl From<(SystemTime, BTreeMap)> for CrCounte fn from(value: (SystemTime, BTreeMap)) -> Self { Self { ourselves: A::default(), + max_value: 0, value: Default::default(), others: RwLock::new(value.1), expiry: value.0.into(), @@ -157,13 +167,14 @@ impl From<(SystemTime, BTreeMap)> for CrCounte #[cfg(test)] mod tests { - use crate::storage::distributed::cr_counter_value::CrCounterValue; use std::time::{Duration, SystemTime}; + use crate::storage::distributed::cr_counter_value::CrCounterValue; + #[test] fn local_increments_are_readable() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); + let a = CrCounterValue::new('A', u64::MAX, window); a.inc(3, window); assert_eq!(3, a.read()); a.inc(2, window); @@ -173,7 +184,7 @@ mod tests { #[test] fn local_increments_expire() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); + let a = CrCounterValue::new('A', u64::MAX, window); let now = SystemTime::now(); a.inc_at(3, window, now); assert_eq!(3, a.read()); @@ -184,7 +195,7 @@ mod tests { #[test] fn other_increments_are_readable() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); + let a = CrCounterValue::new('A', u64::MAX, window); a.inc_actor('B', 3, window); assert_eq!(3, a.read()); a.inc_actor('B', 2, window); @@ -194,7 +205,7 @@ mod tests { #[test] fn other_increments_expire() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); + let a = CrCounterValue::new('A', u64::MAX, window); let now = SystemTime::now(); a.inc_actor_at('B', 3, window, now); assert_eq!(3, a.read()); @@ -205,8 +216,8 @@ mod tests { #[test] fn merges() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); + let a = CrCounterValue::new('A', u64::MAX, window); + let b = CrCounterValue::new('B', u64::MAX, window); a.inc(3, window); b.inc(2, window); a.merge(b); @@ -216,8 +227,8 @@ mod tests { #[test] fn merges_symetric() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); + let a = CrCounterValue::new('A', u64::MAX, window); + let b = CrCounterValue::new('B', u64::MAX, window); a.inc(3, window); b.inc(2, window); b.merge(a); @@ -227,8 +238,8 @@ mod tests { #[test] fn merges_overrides_with_larger_value() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); + let a = CrCounterValue::new('A', u64::MAX, window); + let b = CrCounterValue::new('B', u64::MAX, window); a.inc(3, window); b.inc(2, window); b.inc_actor('A', 2, window); // older value! @@ -239,8 +250,8 @@ mod tests { #[test] fn merges_ignore_lesser_values() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); + let a = CrCounterValue::new('A', u64::MAX, window); + let b = CrCounterValue::new('B', u64::MAX, window); a.inc(3, window); b.inc(2, window); b.inc_actor('A', 5, window); // newer value! @@ -251,9 +262,9 @@ mod tests { #[test] fn merge_ignores_expired_sets() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', Duration::ZERO); + let a = CrCounterValue::new('A', u64::MAX, Duration::ZERO); a.inc(3, Duration::ZERO); - let b = CrCounterValue::new('B', window); + let b = CrCounterValue::new('B', u64::MAX, window); b.inc(2, window); b.merge(a); assert_eq!(b.read(), 2); @@ -262,9 +273,9 @@ mod tests { #[test] fn merge_ignores_expired_sets_symmetric() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', Duration::ZERO); + let a = CrCounterValue::new('A', u64::MAX, Duration::ZERO); a.inc(3, Duration::ZERO); - let b = CrCounterValue::new('B', window); + let b = CrCounterValue::new('B', u64::MAX, window); b.inc(2, window); a.merge(b); assert_eq!(a.read(), 2); @@ -273,9 +284,9 @@ mod tests { #[test] fn merge_uses_earliest_expiry() { let later = Duration::from_secs(1); - let a = CrCounterValue::new('A', later); + let a = CrCounterValue::new('A', u64::MAX, later); let sooner = Duration::from_millis(200); - let b = CrCounterValue::new('B', sooner); + let b = CrCounterValue::new('B', u64::MAX, sooner); a.inc(3, later); b.inc(2, later); a.merge(b); diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 65c1394c..452b3aa2 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -44,6 +44,7 @@ impl CounterStorage for CrInMemoryStorage { let key = encode_limit_to_key(limit); limits.entry(key).or_insert(CrCounterValue::new( self.identifier.clone(), + limit.max_value(), Duration::from_secs(limit.seconds()), )); } @@ -59,7 +60,8 @@ impl CounterStorage for CrInMemoryStorage { match limits.entry(key.clone()) { Entry::Vacant(entry) => { let duration = counter.window(); - let store_value = CrCounterValue::new(self.identifier.clone(), duration); + let store_value = + CrCounterValue::new(self.identifier.clone(), counter.max_value(), duration); self.increment_counter(counter, key, &store_value, delta, now); entry.insert(store_value); } @@ -129,6 +131,7 @@ impl CounterStorage for CrInMemoryStorage { let mut limits = self.limits.write().unwrap(); let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new( self.identifier.clone(), + counter.max_value(), counter.window(), )); @@ -161,10 +164,22 @@ impl CounterStorage for CrInMemoryStorage { fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { let mut res = HashSet::new(); + let limits: HashSet<_> = limits.iter().map(encode_limit_to_key).collect(); + let limits_map = self.limits.read().unwrap(); for (key, counter_value) in limits_map.iter() { - let mut counter: Counter = decode_counter_key(key).unwrap().into(); - if limits.contains(counter.limit()) { + let counter_key = decode_counter_key(key).unwrap(); + let limit_key = if !counter_key.vars.is_empty() { + let mut cloned = counter_key.clone(); + cloned.vars = HashMap::default(); + cloned.encode() + } else { + key.clone() + }; + + if limits.contains(&limit_key) { + let counter = (&counter_key, counter_value); + let mut counter: Counter = counter.into(); counter.set_remaining(counter.max_value() - counter_value.read()); counter.set_expires_in(counter_value.ttl()); if counter.expires_in().unwrap() > Duration::ZERO { @@ -264,56 +279,61 @@ impl CrInMemoryStorage { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] struct CounterKey { namespace: Namespace, seconds: u64, - max_value: u64, conditions: HashSet, variables: HashSet, vars: HashMap, } +impl CounterKey { + fn new(limit: &Limit, vars: HashMap) -> Self { + CounterKey { + namespace: limit.namespace().clone(), + seconds: limit.seconds(), + variables: limit.variables().clone(), + conditions: limit.conditions().clone(), + vars, + } + } + + fn encode(&self) -> Vec { + postcard::to_stdvec(self).unwrap() + } +} + +impl From<(&CounterKey, &CrCounterValue)> for Counter { + fn from(value: (&CounterKey, &CrCounterValue)) -> Self { + let (counter_key, store_value) = value; + let max_value = store_value.max_value(); + let mut counter = Self::new( + Limit::new( + counter_key.namespace.clone(), + max_value, + counter_key.seconds, + counter_key.conditions.clone(), + counter_key.vars.keys(), + ), + counter_key.vars.clone(), + ); + counter.set_remaining(max_value - store_value.read()); + counter.set_expires_in(store_value.ttl()); + counter + } +} + fn encode_counter_to_key(counter: &Counter) -> Vec { - let limit = counter.limit(); - let key = CounterKey { - namespace: limit.namespace().clone(), - max_value: limit.max_value(), - seconds: limit.seconds(), - variables: limit.variables().clone(), - conditions: limit.conditions().clone(), - vars: counter.set_variables().clone(), - }; + let key = CounterKey::new(counter.limit(), counter.set_variables().clone()); postcard::to_stdvec(&key).unwrap() } fn encode_limit_to_key(limit: &Limit) -> Vec { - let key = CounterKey { - namespace: limit.namespace().clone(), - max_value: limit.max_value(), - seconds: limit.seconds(), - variables: limit.variables().clone(), - conditions: limit.conditions().clone(), - vars: HashMap::default(), - }; + let key = CounterKey::new(limit, HashMap::default()); postcard::to_stdvec(&key).unwrap() } fn decode_counter_key(key: &Vec) -> postcard::Result { postcard::from_bytes(key.as_slice()) } - -impl From for Counter { - fn from(value: CounterKey) -> Self { - Self::new( - Limit::new( - value.namespace, - value.max_value, - value.seconds, - value.conditions, - value.vars.keys(), - ), - value.vars, - ) - } -}