Skip to content

Commit

Permalink
[distributed store] Batch up updates per session.
Browse files Browse the repository at this point in the history
This allows connecting slow peers without affecting how quickly 
We can replicate to fast peers.

Signed-off-by: Hiram Chirino <[email protected]>
  • Loading branch information
chirino committed May 24, 2024
1 parent d3e5110 commit 2ffd28e
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 30 deletions.
100 changes: 93 additions & 7 deletions limitador/src/storage/distributed/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::{error::Error, io::ErrorKind, pin::Pin};

use tokio::sync::mpsc::Sender;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::sync::mpsc::{Permit, Sender};
use tokio::sync::{broadcast, mpsc, Notify, RwLock};
use tokio::time::sleep;

use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
use tonic::{Code, Request, Response, Status, Streaming};
use tracing::debug;

use crate::storage::distributed::cr_counter_value::CrCounterValue;
use crate::storage::distributed::grpc::v1::packet::Message;
use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient;
use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer};
Expand Down Expand Up @@ -106,12 +106,46 @@ impl Session {
.await?;

let mut udpates_to_send = self.broker_state.publisher.subscribe();
let mut tx_updates_by_key = HashMap::new();
let mut tx_updates_order = vec![];
let notifier = Notify::default();

loop {
tokio::select! {
update = udpates_to_send.recv() => {
let update = update.map_err(|_| Status::unknown("broadcast error"))?;
self.send(Message::CounterUpdate(update)).await?;
// Multiple updates collapse into a single update for the same key
if !tx_updates_by_key.contains_key(&update.key) {
tx_updates_by_key.insert(update.key.clone(), update.value);
tx_updates_order.push(update.key);
notifier.notify_one();
}
}
_ = notifier.notified() => {
// while we have pending updates to send...
while !tx_updates_order.is_empty() {
// and we have space on the transmission channel to send the update...
match self.out_stream.clone().try_reserve() {
Err(_) => {
break
},
Ok(permit) => {

let key = tx_updates_order.remove(0);
let cr_counter_value = tx_updates_by_key.remove(&key).unwrap().clone();
let (expiry, values) = (*cr_counter_value).clone().into_inner();

// only send the update if it has not expired.
if expiry > SystemTime::now() {
permit.send(Ok(Message::CounterUpdate(CounterUpdate {
key,
values: values.into_iter().collect(),
expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(),
})))?;
}
}
}
}
}
result = in_stream.next() => {
match result {
Expand Down Expand Up @@ -315,14 +349,66 @@ impl MessageSender {
},
}
}
fn try_reserve(&self) -> Result<MessagePermit<'_>, Status> {
match self {
MessageSender::Client(sender) => {
let permit = sender
.try_reserve()
.map_err(|_| Status::unknown("send error"))?;
Ok(MessagePermit::Client(permit))
}
MessageSender::Server(sender) => {
let permit = sender
.try_reserve()
.map_err(|_| Status::unknown("send error"))?;
Ok(MessagePermit::Server(permit))
}
}
}
}

enum MessagePermit<'a> {
Server(Permit<'a, Result<Packet, Status>>),
Client(Permit<'a, Packet>),
}
impl<'a> MessagePermit<'a> {
fn send(self, message: Result<Message, Status>) -> Result<(), Status> {
match self {
MessagePermit::Server(sender) => {
let value = message.map(|x| Packet { message: Some(x) });
sender.send(value);
Ok(())
}
MessagePermit::Client(sender) => match message {
Ok(message) => {
sender.send(Packet {
message: Some(message),
});
Ok(())
}
Err(err) => Err(err),
},
}
}
}

type CounterUpdateFn = Pin<Box<dyn Fn(CounterUpdate) + Sync + Send>>;
#[derive(Clone, Debug)]
pub struct CounterEntry {
pub key: Vec<u8>,
pub value: Arc<CrCounterValue<String>>,
}

impl CounterEntry {
pub fn new(key: Vec<u8>, value: Arc<CrCounterValue<String>>) -> Self {
Self { key, value }
}
}

#[derive(Clone)]
struct BrokerState {
id: String,
publisher: broadcast::Sender<CounterUpdate>,
publisher: broadcast::Sender<CounterEntry>,
on_counter_update: Arc<CounterUpdateFn>,
}

Expand All @@ -342,7 +428,7 @@ impl Broker {
on_counter_update: CounterUpdateFn,
) -> Broker {
let (tx, _) = broadcast::channel(16);
let publisher: broadcast::Sender<CounterUpdate> = tx;
let publisher: broadcast::Sender<CounterEntry> = tx;

Broker {
listen_address,
Expand All @@ -359,7 +445,7 @@ impl Broker {
}
}

pub fn publish(&self, counter_update: CounterUpdate) {
pub fn publish(&self, counter_update: CounterEntry) {
// ignore the send error, it just means there are no active subscribers
_ = self.broker_state.publisher.send(counter_update);
}
Expand Down
47 changes: 24 additions & 23 deletions limitador/src/storage/distributed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ use crate::counter::Counter;
use crate::limit::{Limit, Namespace};
use crate::storage::distributed::cr_counter_value::CrCounterValue;
use crate::storage::distributed::grpc::v1::CounterUpdate;
use crate::storage::distributed::grpc::Broker;
use crate::storage::distributed::grpc::{Broker, CounterEntry};
use crate::storage::{Authorization, CounterStorage, StorageErr};

mod cr_counter_value;
mod grpc;

pub type LimitsMap = HashMap<Vec<u8>, CrCounterValue<String>>;
pub type LimitsMap = HashMap<Vec<u8>, Arc<CrCounterValue<String>>>;

pub struct CrInMemoryStorage {
identifier: String,
Expand All @@ -42,11 +42,11 @@ impl CounterStorage for CrInMemoryStorage {
if limit.variables().is_empty() {
let mut limits = self.limits.write().unwrap();
let key = encode_limit_to_key(limit);
limits.entry(key).or_insert(CrCounterValue::new(
limits.entry(key).or_insert(Arc::new(CrCounterValue::new(
self.identifier.clone(),
limit.max_value(),
Duration::from_secs(limit.seconds()),
));
)));
}
Ok(())
}
Expand All @@ -60,13 +60,16 @@ impl CounterStorage for CrInMemoryStorage {
match limits.entry(key.clone()) {
Entry::Vacant(entry) => {
let duration = counter.window();
let store_value =
CrCounterValue::new(self.identifier.clone(), counter.max_value(), duration);
self.increment_counter(counter, key, &store_value, delta, now);
let store_value = Arc::new(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
duration,
));
self.increment_counter(counter, key, store_value.clone(), delta, now);
entry.insert(store_value);
}
Entry::Occupied(entry) => {
self.increment_counter(counter, key, entry.get(), delta, now);
self.increment_counter(counter, key, entry.get().clone(), delta, now);
}
};
Ok(())
Expand Down Expand Up @@ -129,11 +132,14 @@ impl CounterStorage for CrInMemoryStorage {
if !counter_existed {
// try again with a write lock to create the counter if it's still missing.
let mut limits = self.limits.write().unwrap();
let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
counter.window(),
));
let store_value =
limits
.entry(key.clone())
.or_insert(Arc::new(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
counter.window(),
)));

if let Some(limited) = process_counter(counter, store_value.read(), delta) {
if !load_counters {
Expand All @@ -154,7 +160,7 @@ impl CounterStorage for CrInMemoryStorage {
.into_iter()
.for_each(|(counter, key)| {
let store_value = limits.get(&key).unwrap();
self.increment_counter(&counter, key, store_value, delta, now);
self.increment_counter(&counter, key, store_value.clone(), delta, now);
});

Ok(Authorization::Ok)
Expand All @@ -178,7 +184,7 @@ impl CounterStorage for CrInMemoryStorage {
};

if limits.contains(&limit_key) {
let counter = (&counter_key, counter_value);
let counter = (&counter_key, &*counter_value.clone());
let mut counter: Counter = counter.into();
counter.set_remaining(counter.max_value() - counter_value.read());
counter.set_expires_in(counter_value.ttl());
Expand Down Expand Up @@ -264,18 +270,13 @@ impl CrInMemoryStorage {
&self,
counter: &Counter,
store_key: Vec<u8>,
store_value: &CrCounterValue<String>,
store_value: Arc<CrCounterValue<String>>,
delta: u64,
when: SystemTime,
) {
store_value.inc_at(delta, counter.window(), when);

let (expiry, values) = store_value.clone().into_inner();
self.broker.publish(CounterUpdate {
key: store_key,
values: values.into_iter().collect(),
expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(),
})
self.broker
.publish(CounterEntry::new(store_key, store_value))
}
}

Expand Down

0 comments on commit 2ffd28e

Please sign in to comment.