Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[distributed store] Batch up updates per session. #345

Merged
merged 1 commit into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions limitador/src/storage/distributed/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::{error::Error, io::ErrorKind, pin::Pin};

use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::Sender;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::sync::mpsc::{Permit, Sender};
use tokio::sync::{broadcast, mpsc, Notify, RwLock};
use tokio::time::sleep;
use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
use tonic::{Code, Request, Response, Status, Streaming};
use tracing::debug;

use crate::storage::distributed::cr_counter_value::CrCounterValue;
use crate::storage::distributed::grpc::v1::packet::Message;
use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient;
use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer};
Expand Down Expand Up @@ -145,12 +146,47 @@ impl Session {
TrySendError::Closed(_) => Status::unavailable("re-sync channel closed"),
})?;

let mut updates = self.broker_state.publisher.subscribe();
let mut udpates_to_send = self.broker_state.publisher.subscribe();
let mut tx_updates_by_key = HashMap::new();
let mut tx_updates_order = vec![];
let notifier = Notify::default();

loop {
tokio::select! {
update = updates.recv() => {
update = udpates_to_send.recv() => {
let update = update.map_err(|_| Status::unknown("broadcast error"))?;
self.send(Message::CounterUpdate(update)).await?;
// Multiple updates collapse into a single update for the same key
if !tx_updates_by_key.contains_key(&update.key) {
tx_updates_by_key.insert(update.key.clone(), update.value);
tx_updates_order.push(update.key);
notifier.notify_one();
}
}
_ = notifier.notified() => {
// while we have pending updates to send...
while !tx_updates_order.is_empty() {
// and we have space on the transmission channel to send the update...
match self.out_stream.clone().try_reserve() {
Err(_) => {
break
},
Ok(permit) => {

let key = tx_updates_order.remove(0);
let cr_counter_value = tx_updates_by_key.remove(&key).unwrap().clone();
let (expiry, values) = (*cr_counter_value).clone().into_inner();

// only send the update if it has not expired.
if expiry > SystemTime::now() {
permit.send(Ok(Message::CounterUpdate(CounterUpdate {
key,
values: values.into_iter().collect(),
expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(),
})))?;
}
}
}
}
}
result = in_stream.next() => {
match result {
Expand Down Expand Up @@ -354,14 +390,66 @@ impl MessageSender {
},
}
}
fn try_reserve(&self) -> Result<MessagePermit<'_>, Status> {
match self {
MessageSender::Client(sender) => {
let permit = sender
.try_reserve()
.map_err(|_| Status::unknown("send error"))?;
Ok(MessagePermit::Client(permit))
}
MessageSender::Server(sender) => {
let permit = sender
.try_reserve()
.map_err(|_| Status::unknown("send error"))?;
Ok(MessagePermit::Server(permit))
}
}
}
}

enum MessagePermit<'a> {
Server(Permit<'a, Result<Packet, Status>>),
Client(Permit<'a, Packet>),
}
impl<'a> MessagePermit<'a> {
fn send(self, message: Result<Message, Status>) -> Result<(), Status> {
match self {
MessagePermit::Server(sender) => {
let value = message.map(|x| Packet { message: Some(x) });
sender.send(value);
Ok(())
}
MessagePermit::Client(sender) => match message {
Ok(message) => {
sender.send(Packet {
message: Some(message),
});
Ok(())
}
Err(err) => Err(err),
},
}
}
}

type CounterUpdateFn = Pin<Box<dyn Fn(CounterUpdate) + Sync + Send>>;
#[derive(Clone, Debug)]
pub struct CounterEntry {
pub key: Vec<u8>,
pub value: Arc<CrCounterValue<String>>,
}

impl CounterEntry {
pub fn new(key: Vec<u8>, value: Arc<CrCounterValue<String>>) -> Self {
Self { key, value }
}
}

#[derive(Clone)]
struct BrokerState {
id: String,
publisher: broadcast::Sender<CounterUpdate>,
publisher: broadcast::Sender<CounterEntry>,
on_counter_update: Arc<CounterUpdateFn>,
on_re_sync: Arc<Sender<Sender<Option<CounterUpdate>>>>,
}
Expand All @@ -383,7 +471,7 @@ impl Broker {
on_re_sync: Sender<Sender<Option<CounterUpdate>>>,
) -> Broker {
let (tx, _) = broadcast::channel(16);
let publisher: broadcast::Sender<CounterUpdate> = tx;
let publisher: broadcast::Sender<CounterEntry> = tx;

Broker {
listen_address,
Expand All @@ -401,7 +489,7 @@ impl Broker {
}
}

pub fn publish(&self, counter_update: CounterUpdate) {
pub fn publish(&self, counter_update: CounterEntry) {
// ignore the send error, it just means there are no active subscribers
_ = self.broker_state.publisher.send(counter_update);
}
Expand Down
52 changes: 25 additions & 27 deletions limitador/src/storage/distributed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ use crate::counter::Counter;
use crate::limit::{Limit, Namespace};
use crate::storage::distributed::cr_counter_value::CrCounterValue;
use crate::storage::distributed::grpc::v1::CounterUpdate;
use crate::storage::distributed::grpc::Broker;
use crate::storage::distributed::grpc::{Broker, CounterEntry};
use crate::storage::{Authorization, CounterStorage, StorageErr};

mod cr_counter_value;
mod grpc;

pub type LimitsMap = HashMap<Vec<u8>, CrCounterValue<String>>;
pub type LimitsMap = HashMap<Vec<u8>, Arc<CrCounterValue<String>>>;

pub struct CrInMemoryStorage {
identifier: String,
Expand All @@ -45,11 +45,11 @@ impl CounterStorage for CrInMemoryStorage {
if limit.variables().is_empty() {
let mut limits = self.limits.write().unwrap();
let key = encode_limit_to_key(limit);
limits.entry(key).or_insert(CrCounterValue::new(
limits.entry(key).or_insert(Arc::new(CrCounterValue::new(
self.identifier.clone(),
limit.max_value(),
Duration::from_secs(limit.seconds()),
));
)));
}
Ok(())
}
Expand All @@ -63,13 +63,16 @@ impl CounterStorage for CrInMemoryStorage {
match limits.entry(key.clone()) {
Entry::Vacant(entry) => {
let duration = counter.window();
let store_value =
CrCounterValue::new(self.identifier.clone(), counter.max_value(), duration);
self.increment_counter(counter, key, &store_value, delta, now);
let store_value = Arc::new(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
duration,
));
self.increment_counter(counter, key, store_value.clone(), delta, now);
entry.insert(store_value);
}
Entry::Occupied(entry) => {
self.increment_counter(counter, key, entry.get(), delta, now);
self.increment_counter(counter, key, entry.get().clone(), delta, now);
}
};
Ok(())
Expand Down Expand Up @@ -132,11 +135,14 @@ impl CounterStorage for CrInMemoryStorage {
if !counter_existed {
// try again with a write lock to create the counter if it's still missing.
let mut limits = self.limits.write().unwrap();
let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
counter.window(),
));
let store_value =
limits
.entry(key.clone())
.or_insert(Arc::new(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
counter.window(),
)));

if let Some(limited) = process_counter(counter, store_value.read(), delta) {
if !load_counters {
Expand All @@ -157,7 +163,7 @@ impl CounterStorage for CrInMemoryStorage {
.into_iter()
.for_each(|(counter, key)| {
let store_value = limits.get(&key).unwrap();
self.increment_counter(&counter, key, store_value, delta, now);
self.increment_counter(&counter, key, store_value.clone(), delta, now);
});

Ok(Authorization::Ok)
Expand All @@ -181,7 +187,7 @@ impl CounterStorage for CrInMemoryStorage {
};

if limits.contains(&limit_key) {
let counter = (&counter_key, counter_value);
let counter = (&counter_key, &*counter_value.clone());
let mut counter: Counter = counter.into();
counter.set_remaining(counter.max_value() - counter_value.read());
counter.set_expires_in(counter_value.ttl());
Expand Down Expand Up @@ -280,25 +286,17 @@ impl CrInMemoryStorage {
&self,
counter: &Counter,
store_key: Vec<u8>,
store_value: &CrCounterValue<String>,
store_value: Arc<CrCounterValue<String>>,
delta: u64,
when: SystemTime,
) {
store_value.inc_at(delta, counter.window(), when);

let (expiry, values) = store_value.clone().into_inner();
self.broker.publish(CounterUpdate {
key: store_key,
values: values.into_iter().collect(),
expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(),
})
self.broker
.publish(CounterEntry::new(store_key, store_value))
}
}

async fn process_re_sync(
limits: &Arc<RwLock<HashMap<Vec<u8>, CrCounterValue<String>>>>,
sender: Sender<Option<CounterUpdate>>,
) {
async fn process_re_sync(limits: &Arc<RwLock<LimitsMap>>, sender: Sender<Option<CounterUpdate>>) {
// sending all the counters to the peer might take a while, so we don't want to lock
// the limits map for too long, lets figure first get the list of keys that needs to be sent.
let keys: Vec<_> = {
Expand Down
Loading