From 4ecdf0de44c5a3c59c48b06fae945a16a8ed25e3 Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Mon, 13 May 2024 10:32:16 -0400 Subject: [PATCH] [distributed store] use a single map Vec -> Counters map By keying using Vec we reduce how often we need to encode/decode the the counter keys. Signed-off-by: Hiram Chirino --- limitador/src/storage/distributed/grpc/mod.rs | 46 +-- limitador/src/storage/distributed/mod.rs | 321 +++++++----------- 2 files changed, 133 insertions(+), 234 deletions(-) diff --git a/limitador/src/storage/distributed/grpc/mod.rs b/limitador/src/storage/distributed/grpc/mod.rs index 2339e1cf..79d3d985 100644 --- a/limitador/src/storage/distributed/grpc/mod.rs +++ b/limitador/src/storage/distributed/grpc/mod.rs @@ -1,11 +1,10 @@ -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{error::Error, io::ErrorKind, pin::Pin}; -use moka::sync::Cache; use tokio::sync::mpsc::Sender; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::time::sleep; @@ -14,15 +13,12 @@ use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt}; use tonic::{Code, Request, Response, Status, Streaming}; use tracing::debug; -use crate::counter::Counter; -use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::packet::Message; use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient; use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer}; use crate::storage::distributed::grpc::v1::{ CounterUpdate, Hello, MembershipUpdate, Packet, Peer, Pong, }; -use crate::storage::distributed::CounterKey; // clippy will barf on protobuff generated code for enum variants in // v3::socket_option::SocketState, so allow this lint @@ -187,34 +183,7 @@ impl Session { } Some(Message::CounterUpdate(update)) => { debug!("peer: '{}': CounterUpdate", self.peer_id); - - let counter_key = postcard::from_bytes::(update.key.as_slice()) - .map_err(|err| { - Status::internal(format!("failed to decode counter key: {:?}", err)) - })?; - - let values = BTreeMap::from_iter( - update - .values - .iter() - .map(|(k, v)| (k.to_owned(), v.to_owned())), - ); - - let counter = >::into(counter_key); - if counter.is_qualified() { - if let Some(counter) = self.broker_state.qualified_counters.get(&counter) { - counter.merge( - (UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into(), - ); - } - } else { - let counters = self.broker_state.limits_for_namespace.read().unwrap(); - let limits = counters.get(counter.namespace()).unwrap(); - let value = limits.get(counter.limit()).unwrap(); - value.merge( - (UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into(), - ); - }; + (self.broker_state.on_counter_update)(update); } _ => { debug!("peer: '{}': unsupported packet: {:?}", self.peer_id, packet); @@ -348,12 +317,13 @@ impl MessageSender { } } +type CounterUpdateFn = Pin>; + #[derive(Clone)] struct BrokerState { id: String, - limits_for_namespace: Arc>, - qualified_counters: Arc>>>, publisher: broadcast::Sender, + on_counter_update: Arc, } #[derive(Clone)] @@ -369,8 +339,7 @@ impl Broker { id: String, listen_address: SocketAddr, peer_urls: Vec, - limits_for_namespace: Arc>, - qualified_counters: Arc>>>, + on_counter_update: CounterUpdateFn, ) -> Broker { let (tx, _) = broadcast::channel(16); let publisher: broadcast::Sender = tx; @@ -381,8 +350,7 @@ impl Broker { broker_state: BrokerState { id, publisher, - limits_for_namespace, - qualified_counters, + on_counter_update: Arc::new(on_counter_update), }, replication_state: Arc::new(RwLock::new(ReplicationState { discovered_urls: HashSet::new(), diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 7a9e7761..00c7b06e 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -1,11 +1,9 @@ use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::net::ToSocketAddrs; -use std::ops::Deref; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use moka::sync::Cache; use serde::{Deserialize, Serialize}; use crate::counter::Counter; @@ -18,89 +16,57 @@ use crate::storage::{Authorization, CounterStorage, StorageErr}; mod cr_counter_value; mod grpc; -type NamespacedLimitCounters = HashMap>; +pub type LimitsMap = HashMap, CrCounterValue>; pub struct CrInMemoryStorage { identifier: String, - limits_for_namespace: Arc>>>, - qualified_counters: Arc>>>, + limits: Arc>, broker: Broker, } impl CounterStorage for CrInMemoryStorage { #[tracing::instrument(skip_all)] fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { - let limits_by_namespace = self.limits_for_namespace.read().unwrap(); + let limits = self.limits.read().unwrap(); let mut value = 0; - - if counter.is_qualified() { - if let Some(counter) = self.qualified_counters.get(counter) { - value = counter.read(); - } - } else if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) { - if let Some(counter) = limits.get(counter.limit()) { - value = counter.read(); - } + let key = encode_counter_to_key(counter); + if let Some(counter_value) = limits.get(&key) { + value = counter_value.read() } - Ok(counter.max_value() >= value + delta) } #[tracing::instrument(skip_all)] fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr> { if limit.variables().is_empty() { - let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); - limits_by_namespace - .entry(limit.namespace().clone()) - .or_default() - .entry(limit.clone()) - .or_insert(CrCounterValue::new( - self.identifier.clone(), - Duration::from_secs(limit.seconds()), - )); + let mut limits = self.limits.write().unwrap(); + let key = encode_limit_to_key(&limit); + limits.entry(key).or_insert(CrCounterValue::new( + self.identifier.clone(), + Duration::from_secs(limit.seconds()), + )); } Ok(()) } #[tracing::instrument(skip_all)] fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { - let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + let mut limits = self.limits.write().unwrap(); let now = SystemTime::now(); - if counter.is_qualified() { - let value = match self.qualified_counters.get(counter) { - None => self.qualified_counters.get_with(counter.clone(), || { - Arc::new(CrCounterValue::new( - self.identifier.clone(), - counter.window(), - )) - }), - Some(counter) => counter, - }; - self.increment_counter(counter.clone(), &value, delta, now); - } else { - match limits_by_namespace.entry(counter.limit().namespace().clone()) { - Entry::Vacant(v) => { - let mut limits = HashMap::new(); - let counter_val = - CrCounterValue::new(self.identifier.clone(), counter.window()); - self.increment_counter(counter.clone(), &counter_val, delta, now); - limits.insert(counter.limit().clone(), counter_val); - v.insert(limits); - } - Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) { - Entry::Vacant(v) => { - let counter_value = - CrCounterValue::new(self.identifier.clone(), counter.window()); - self.increment_counter(counter.clone(), &counter_value, delta, now); - v.insert(counter_value); - } - Entry::Occupied(o) => { - self.increment_counter(counter.clone(), o.get(), delta, now); - } - }, + + let key = encode_counter_to_key(counter); + match limits.entry(key.clone()) { + Entry::Vacant(entry) => { + let duration = counter.window(); + let store_value = CrCounterValue::new(self.identifier.clone(), duration); + self.increment_counter(&counter, key, &store_value, delta, now); + entry.insert(store_value); } - } + Entry::Occupied(entry) => { + self.increment_counter(&counter, key, entry.get(), delta, now); + } + }; Ok(()) } @@ -111,11 +77,8 @@ impl CounterStorage for CrInMemoryStorage { delta: u64, load_counters: bool, ) -> Result { - let limits_by_namespace = self.limits_for_namespace.write().unwrap(); let mut first_limited = None; - let mut counter_values_to_update: Vec<(&CrCounterValue, Counter)> = Vec::new(); - let mut qualified_counter_values_to_updated: Vec<(Arc>, Counter)> = - Vec::new(); + let mut counter_values_to_update: Vec<(Counter, Vec)> = Vec::new(); let now = SystemTime::now(); let mut process_counter = @@ -138,39 +101,44 @@ impl CounterStorage for CrInMemoryStorage { }; // Process simple counters - for counter in counters.iter_mut().filter(|c| !c.is_qualified()) { - let atomic_expiring_value: &CrCounterValue = limits_by_namespace - .get(counter.limit().namespace()) - .and_then(|limits| limits.get(counter.limit())) - .unwrap(); - - if let Some(limited) = process_counter(counter, atomic_expiring_value.read(), delta) { - if !load_counters { - return Ok(limited); + for counter in counters.iter_mut() { + let key = encode_counter_to_key(counter); + + // most of the time the counter should exist, so first try with a read only lock + // since that will allow us to have higher concurrency + let counter_existed = { + let key = key.clone(); + let limits = self.limits.read().unwrap(); + match limits.get(&key) { + None => false, + Some(store_value) => { + if let Some(limited) = process_counter(counter, store_value.read(), delta) { + if !load_counters { + return Ok(limited); + } + } + counter_values_to_update.push((counter.clone(), key)); + true + } } - } - counter_values_to_update.push((atomic_expiring_value, counter.clone())); - } - - // Process qualified counters - for counter in counters.iter_mut().filter(|c| c.is_qualified()) { - let value = match self.qualified_counters.get(counter) { - None => self.qualified_counters.get_with(counter.clone(), || { - Arc::new(CrCounterValue::new( - self.identifier.clone(), - counter.window(), - )) - }), - Some(counter) => counter, }; - if let Some(limited) = process_counter(counter, value.read(), delta) { - if !load_counters { - return Ok(limited); + // we need to take the slow path since we need to mutate the limits map. + if !counter_existed { + // try again with a write lock to create the counter if it's still missing. + let mut limits = self.limits.write().unwrap(); + let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new( + self.identifier.clone(), + counter.window(), + )); + + if let Some(limited) = process_counter(counter, store_value.read(), delta) { + if !load_counters { + return Ok(limited); + } } + counter_values_to_update.push((counter.clone(), key)); } - - qualified_counter_values_to_updated.push((value, counter.clone())); } if let Some(limited) = first_limited { @@ -178,15 +146,12 @@ impl CounterStorage for CrInMemoryStorage { } // Update counters + let limits = self.limits.read().unwrap(); counter_values_to_update .into_iter() - .for_each(|(v, counter)| { - self.increment_counter(counter, v, delta, now); - }); - qualified_counter_values_to_updated - .into_iter() - .for_each(|(v, counter)| { - self.increment_counter(counter, v.deref(), delta, now); + .for_each(|(counter, key)| { + let store_value = limits.get(&key).unwrap(); + self.increment_counter(&counter, key, store_value, delta, now); }); Ok(Authorization::Ok) @@ -196,40 +161,17 @@ impl CounterStorage for CrInMemoryStorage { fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { let mut res = HashSet::new(); - let namespaces: HashSet<&Namespace> = limits.iter().map(Limit::namespace).collect(); - let limits_by_namespace = self.limits_for_namespace.read().unwrap(); - - for namespace in namespaces { - if let Some(limits) = limits_by_namespace.get(namespace) { - for limit in limits.keys() { - if limits.contains_key(limit) { - for (counter, expiring_value) in self.counters_in_namespace(namespace) { - let mut counter_with_val = counter.clone(); - counter_with_val.set_remaining( - counter_with_val.max_value() - expiring_value.read(), - ); - counter_with_val.set_expires_in(expiring_value.ttl()); - if counter_with_val.expires_in().unwrap() > Duration::ZERO { - res.insert(counter_with_val); - } - } - } + let limits_map = self.limits.read().unwrap(); + for (key, counter_value) in limits_map.iter() { + let mut counter: Counter = decode_counter_key(key).unwrap().into(); + if limits.contains(&counter.limit()) { + counter.set_remaining(counter.max_value() - counter_value.read()); + counter.set_expires_in(counter_value.ttl()); + if counter.expires_in().unwrap() > Duration::ZERO { + res.insert(counter); } } } - - for (counter, expiring_value) in self.qualified_counters.iter() { - if limits.contains(counter.limit()) { - let mut counter_with_val = counter.deref().clone(); - counter_with_val - .set_remaining(counter_with_val.max_value() - expiring_value.read()); - counter_with_val.set_expires_in(expiring_value.ttl()); - if counter_with_val.expires_in().unwrap() > Duration::ZERO { - res.insert(counter_with_val); - } - } - } - Ok(res) } @@ -243,35 +185,39 @@ impl CounterStorage for CrInMemoryStorage { #[tracing::instrument(skip_all)] fn clear(&self) -> Result<(), StorageErr> { - self.limits_for_namespace.write().unwrap().clear(); + self.limits.write().unwrap().clear(); Ok(()) } } -pub type LimitsMap = HashMap>>; - impl CrInMemoryStorage { pub fn new( identifier: String, - cache_size: u64, + _cache_size: u64, listen_address: String, peer_urls: Vec, ) -> Self { - // let (sender, mut rx) = mpsc::channel(1000); - let listen_address = listen_address.to_socket_addrs().unwrap().next().unwrap(); let peer_urls = peer_urls.clone(); + let limits = Arc::new(RwLock::new(LimitsMap::new())); - let limits_for_namespace = Arc::new(RwLock::new(LimitsMap::new())); - let qualified_counters: Arc>>> = - Arc::new(Cache::new(cache_size)); - + let limits_clone = limits.clone(); let broker = grpc::Broker::new( identifier.clone(), listen_address, peer_urls, - limits_for_namespace.clone(), - qualified_counters.clone(), + Box::pin(move |update: CounterUpdate| { + let values = BTreeMap::from_iter( + update + .values + .iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())), + ); + let limits = limits_clone.read().unwrap(); + let value = limits.get(&update.key).unwrap(); + value.merge((UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into()); + return; + }), ); { @@ -283,45 +229,14 @@ impl CrInMemoryStorage { Self { identifier, - limits_for_namespace, - qualified_counters, + limits, broker, } } - fn counters_in_namespace( - &self, - namespace: &Namespace, - ) -> HashMap> { - let mut res: HashMap> = HashMap::new(); - - if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) { - for (limit, value) in counters_by_limit { - res.insert( - Counter::new(limit.clone(), HashMap::default()), - value.clone(), - ); - } - } - - for (counter, value) in self.qualified_counters.iter() { - if counter.namespace() == namespace { - res.insert(counter.deref().clone(), value.deref().clone()); - } - } - - res - } - fn delete_counters_of_limit(&self, limit: &Limit) { - if let Some(counters_by_limit) = self - .limits_for_namespace - .write() - .unwrap() - .get_mut(limit.namespace()) - { - counters_by_limit.remove(limit); - } + let key = encode_limit_to_key(limit); + self.limits.write().unwrap().remove(&key); } fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool { @@ -333,19 +248,17 @@ impl CrInMemoryStorage { fn increment_counter( &self, - key: Counter, - counter: &CrCounterValue, + counter: &Counter, + store_key: Vec, + store_value: &CrCounterValue, delta: u64, when: SystemTime, ) { - counter.inc_at(delta, key.window(), when); - let counter = counter.clone(); - let (expiry, values) = counter.into_inner(); - let key: CounterKey = key.into(); - let key = postcard::to_stdvec(&key).unwrap(); + store_value.inc_at(delta, counter.window(), when); + let (expiry, values) = store_value.clone().into_inner(); self.broker.publish(CounterUpdate { - key, + key: store_key, values: values.into_iter().collect(), expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), }) @@ -356,21 +269,39 @@ impl CrInMemoryStorage { struct CounterKey { namespace: Namespace, seconds: u64, + max_value: u64, conditions: HashSet, variables: HashSet, vars: HashMap, } -impl From for CounterKey { - fn from(value: Counter) -> Self { - Self { - namespace: value.namespace().clone(), - seconds: value.window().as_secs(), - variables: value.limit().variables(), - conditions: value.limit().conditions(), - vars: value.set_variables().clone(), - } - } +fn encode_counter_to_key(counter: &Counter) -> Vec { + let limit = counter.limit(); + let key = CounterKey { + namespace: limit.namespace().clone(), + max_value: limit.max_value(), + seconds: limit.seconds().clone(), + variables: limit.variables().clone(), + conditions: limit.conditions().clone(), + vars: counter.set_variables().clone(), + }; + postcard::to_stdvec(&key).unwrap() +} + +fn encode_limit_to_key(limit: &Limit) -> Vec { + let key = CounterKey { + namespace: limit.namespace().clone(), + max_value: limit.max_value(), + seconds: limit.seconds(), + variables: limit.variables().clone(), + conditions: limit.conditions().clone(), + vars: HashMap::default(), + }; + postcard::to_stdvec(&key).unwrap() +} + +fn decode_counter_key(key: &Vec) -> postcard::Result { + postcard::from_bytes(key.as_slice()) } impl From for Counter { @@ -378,7 +309,7 @@ impl From for Counter { Self::new( Limit::new( value.namespace, - 0, + value.max_value, value.seconds, value.conditions, value.vars.keys(),