diff --git a/limitador/src/storage/distributed/grpc/mod.rs b/limitador/src/storage/distributed/grpc/mod.rs index 1e04fef0..20548850 100644 --- a/limitador/src/storage/distributed/grpc/mod.rs +++ b/limitador/src/storage/distributed/grpc/mod.rs @@ -6,13 +6,14 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{error::Error, io::ErrorKind, pin::Pin}; use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::mpsc::Sender; -use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::sync::mpsc::{Permit, Sender}; +use tokio::sync::{broadcast, mpsc, Notify, RwLock}; use tokio::time::sleep; use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt}; use tonic::{Code, Request, Response, Status, Streaming}; use tracing::debug; +use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::packet::Message; use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient; use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer}; @@ -145,12 +146,47 @@ impl Session { TrySendError::Closed(_) => Status::unavailable("re-sync channel closed"), })?; - let mut updates = self.broker_state.publisher.subscribe(); + let mut udpates_to_send = self.broker_state.publisher.subscribe(); + let mut tx_updates_by_key = HashMap::new(); + let mut tx_updates_order = vec![]; + let notifier = Notify::default(); + loop { tokio::select! { - update = updates.recv() => { + update = udpates_to_send.recv() => { let update = update.map_err(|_| Status::unknown("broadcast error"))?; - self.send(Message::CounterUpdate(update)).await?; + // Multiple updates collapse into a single update for the same key + if !tx_updates_by_key.contains_key(&update.key) { + tx_updates_by_key.insert(update.key.clone(), update.value); + tx_updates_order.push(update.key); + notifier.notify_one(); + } + } + _ = notifier.notified() => { + // while we have pending updates to send... + while !tx_updates_order.is_empty() { + // and we have space on the transmission channel to send the update... + match self.out_stream.clone().try_reserve() { + Err(_) => { + break + }, + Ok(permit) => { + + let key = tx_updates_order.remove(0); + let cr_counter_value = tx_updates_by_key.remove(&key).unwrap().clone(); + let (expiry, values) = (*cr_counter_value).clone().into_inner(); + + // only send the update if it has not expired. + if expiry > SystemTime::now() { + permit.send(Ok(Message::CounterUpdate(CounterUpdate { + key, + values: values.into_iter().collect(), + expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), + })))?; + } + } + } + } } result = in_stream.next() => { match result { @@ -354,14 +390,66 @@ impl MessageSender { }, } } + fn try_reserve(&self) -> Result, Status> { + match self { + MessageSender::Client(sender) => { + let permit = sender + .try_reserve() + .map_err(|_| Status::unknown("send error"))?; + Ok(MessagePermit::Client(permit)) + } + MessageSender::Server(sender) => { + let permit = sender + .try_reserve() + .map_err(|_| Status::unknown("send error"))?; + Ok(MessagePermit::Server(permit)) + } + } + } +} + +enum MessagePermit<'a> { + Server(Permit<'a, Result>), + Client(Permit<'a, Packet>), +} +impl<'a> MessagePermit<'a> { + fn send(self, message: Result) -> Result<(), Status> { + match self { + MessagePermit::Server(sender) => { + let value = message.map(|x| Packet { message: Some(x) }); + sender.send(value); + Ok(()) + } + MessagePermit::Client(sender) => match message { + Ok(message) => { + sender.send(Packet { + message: Some(message), + }); + Ok(()) + } + Err(err) => Err(err), + }, + } + } } type CounterUpdateFn = Pin>; +#[derive(Clone, Debug)] +pub struct CounterEntry { + pub key: Vec, + pub value: Arc>, +} + +impl CounterEntry { + pub fn new(key: Vec, value: Arc>) -> Self { + Self { key, value } + } +} #[derive(Clone)] struct BrokerState { id: String, - publisher: broadcast::Sender, + publisher: broadcast::Sender, on_counter_update: Arc, on_re_sync: Arc>>>, } @@ -383,7 +471,7 @@ impl Broker { on_re_sync: Sender>>, ) -> Broker { let (tx, _) = broadcast::channel(16); - let publisher: broadcast::Sender = tx; + let publisher: broadcast::Sender = tx; Broker { listen_address, @@ -401,7 +489,7 @@ impl Broker { } } - pub fn publish(&self, counter_update: CounterUpdate) { + pub fn publish(&self, counter_update: CounterEntry) { // ignore the send error, it just means there are no active subscribers _ = self.broker_state.publisher.send(counter_update); } diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 00506b24..428f5178 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -13,13 +13,13 @@ use crate::counter::Counter; use crate::limit::{Limit, Namespace}; use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::CounterUpdate; -use crate::storage::distributed::grpc::Broker; +use crate::storage::distributed::grpc::{Broker, CounterEntry}; use crate::storage::{Authorization, CounterStorage, StorageErr}; mod cr_counter_value; mod grpc; -pub type LimitsMap = HashMap, CrCounterValue>; +pub type LimitsMap = HashMap, Arc>>; pub struct CrInMemoryStorage { identifier: String, @@ -45,11 +45,11 @@ impl CounterStorage for CrInMemoryStorage { if limit.variables().is_empty() { let mut limits = self.limits.write().unwrap(); let key = encode_limit_to_key(limit); - limits.entry(key).or_insert(CrCounterValue::new( + limits.entry(key).or_insert(Arc::new(CrCounterValue::new( self.identifier.clone(), limit.max_value(), Duration::from_secs(limit.seconds()), - )); + ))); } Ok(()) } @@ -63,13 +63,16 @@ impl CounterStorage for CrInMemoryStorage { match limits.entry(key.clone()) { Entry::Vacant(entry) => { let duration = counter.window(); - let store_value = - CrCounterValue::new(self.identifier.clone(), counter.max_value(), duration); - self.increment_counter(counter, key, &store_value, delta, now); + let store_value = Arc::new(CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + duration, + )); + self.increment_counter(counter, key, store_value.clone(), delta, now); entry.insert(store_value); } Entry::Occupied(entry) => { - self.increment_counter(counter, key, entry.get(), delta, now); + self.increment_counter(counter, key, entry.get().clone(), delta, now); } }; Ok(()) @@ -132,11 +135,14 @@ impl CounterStorage for CrInMemoryStorage { if !counter_existed { // try again with a write lock to create the counter if it's still missing. let mut limits = self.limits.write().unwrap(); - let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new( - self.identifier.clone(), - counter.max_value(), - counter.window(), - )); + let store_value = + limits + .entry(key.clone()) + .or_insert(Arc::new(CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + counter.window(), + ))); if let Some(limited) = process_counter(counter, store_value.read(), delta) { if !load_counters { @@ -157,7 +163,7 @@ impl CounterStorage for CrInMemoryStorage { .into_iter() .for_each(|(counter, key)| { let store_value = limits.get(&key).unwrap(); - self.increment_counter(&counter, key, store_value, delta, now); + self.increment_counter(&counter, key, store_value.clone(), delta, now); }); Ok(Authorization::Ok) @@ -181,7 +187,7 @@ impl CounterStorage for CrInMemoryStorage { }; if limits.contains(&limit_key) { - let counter = (&counter_key, counter_value); + let counter = (&counter_key, &*counter_value.clone()); let mut counter: Counter = counter.into(); counter.set_remaining(counter.max_value() - counter_value.read()); counter.set_expires_in(counter_value.ttl()); @@ -280,25 +286,17 @@ impl CrInMemoryStorage { &self, counter: &Counter, store_key: Vec, - store_value: &CrCounterValue, + store_value: Arc>, delta: u64, when: SystemTime, ) { store_value.inc_at(delta, counter.window(), when); - - let (expiry, values) = store_value.clone().into_inner(); - self.broker.publish(CounterUpdate { - key: store_key, - values: values.into_iter().collect(), - expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), - }) + self.broker + .publish(CounterEntry::new(store_key, store_value)) } } -async fn process_re_sync( - limits: &Arc, CrCounterValue>>>, - sender: Sender>, -) { +async fn process_re_sync(limits: &Arc>, sender: Sender>) { // sending all the counters to the peer might take a while, so we don't want to lock // the limits map for too long, lets figure first get the list of keys that needs to be sent. let keys: Vec<_> = {