Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First step: Limit storage #352

Merged
merged 4 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 17 additions & 58 deletions limitador-server/src/envoy_rls/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ use opentelemetry::propagation::Extractor;
use std::collections::HashMap;
use std::sync::Arc;

use limitador::CheckResult;
use tonic::codegen::http::HeaderMap;
use tonic::{transport, transport::Server, Request, Response, Status};
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;

use limitador::counter::Counter;

use crate::envoy_rls::server::envoy::config::core::v3::HeaderValue;
use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_response::Code;
use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_service_server::{
Expand All @@ -29,6 +28,21 @@ pub enum RateLimitHeaders {
DraftVersion03,
}

impl RateLimitHeaders {
pub fn headers(&self, response: &mut CheckResult) -> Vec<HeaderValue> {
let mut headers = match self {
RateLimitHeaders::None => Vec::default(),
RateLimitHeaders::DraftVersion03 => response
.response_header()
.into_iter()
.map(|(key, value)| HeaderValue { key, value })
.collect(),
};
headers.sort_by(|a, b| a.key.cmp(&b.key));
headers
}
}

pub struct MyRateLimiter {
limiter: Arc<Limiter>,
rate_limit_headers: RateLimitHeaders,
Expand Down Expand Up @@ -142,10 +156,7 @@ impl RateLimitService for MyRateLimiter {
overall_code: resp_code.into(),
statuses: vec![],
request_headers_to_add: vec![],
response_headers_to_add: to_response_header(
&self.rate_limit_headers,
&mut rate_limited_resp.counters,
),
response_headers_to_add: self.rate_limit_headers.headers(&mut rate_limited_resp),
raw_body: vec![],
dynamic_metadata: None,
quota: None,
Expand All @@ -155,58 +166,6 @@ impl RateLimitService for MyRateLimiter {
}
}

pub fn to_response_header(
rate_limit_headers: &RateLimitHeaders,
counters: &mut [Counter],
) -> Vec<HeaderValue> {
let mut headers = Vec::new();
match rate_limit_headers {
RateLimitHeaders::None => {}

// creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html
RateLimitHeaders::DraftVersion03 => {
// sort by the limit remaining..
counters.sort_by(|a, b| {
let a_remaining = a.remaining().unwrap_or(a.max_value());
let b_remaining = b.remaining().unwrap_or(b.max_value());
a_remaining.cmp(&b_remaining)
});

let mut all_limits_text = String::with_capacity(20 * counters.len());
counters.iter_mut().for_each(|counter| {
all_limits_text.push_str(
format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(),
);
if let Some(name) = counter.limit().name() {
all_limits_text
.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str());
}
});

if let Some(counter) = counters.first() {
headers.push(HeaderValue {
key: "X-RateLimit-Limit".to_string(),
value: format!("{}{}", counter.max_value(), all_limits_text),
});

let remaining = counter.remaining().unwrap_or(counter.max_value());
headers.push(HeaderValue {
key: "X-RateLimit-Remaining".to_string(),
value: format!("{}", remaining),
});

if let Some(duration) = counter.expires_in() {
headers.push(HeaderValue {
key: "X-RateLimit-Reset".to_string(),
value: format!("{}", duration.as_secs()),
});
}
}
}
};
headers
}

struct RateLimitRequestHeaders {
inner: HeaderMap,
}
Expand Down
54 changes: 14 additions & 40 deletions limitador-server/src/http_api/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::prometheus_metrics::PrometheusMetrics;
use crate::Limiter;
use actix_web::{http::StatusCode, HttpResponse, HttpResponseBuilder, ResponseError};
use actix_web::{App, HttpServer};
use limitador::CheckResult;
use paperclip::actix::{
api_v2_errors,
api_v2_operation,
Expand Down Expand Up @@ -209,7 +210,7 @@ async fn check_and_report(
add_response_header(
&mut resp,
response_headers.as_str(),
&mut is_rate_limited.counters,
&mut is_rate_limited,
);
resp.json(())
}
Expand All @@ -224,7 +225,7 @@ async fn check_and_report(
add_response_header(
&mut resp,
response_headers.as_str(),
&mut is_rate_limited.counters,
&mut is_rate_limited,
);
resp.json(())
}
Expand All @@ -238,48 +239,21 @@ async fn check_and_report(
pub fn add_response_header(
resp: &mut HttpResponseBuilder,
rate_limit_headers: &str,
counters: &mut [limitador::counter::Counter],
result: &mut CheckResult,
) {
match rate_limit_headers {
if rate_limit_headers == "DraftVersion03" {
// creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html
"DraftVersion03" => {
// sort by the limit remaining..
counters.sort_by(|a, b| {
let a_remaining = a.remaining().unwrap_or(a.max_value());
let b_remaining = b.remaining().unwrap_or(b.max_value());
a_remaining.cmp(&b_remaining)
});

let mut all_limits_text = String::with_capacity(20 * counters.len());
counters.iter_mut().for_each(|counter| {
all_limits_text.push_str(
format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(),
);
if let Some(name) = counter.limit().name() {
all_limits_text
.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str());
}
});

if let Some(counter) = counters.first() {
resp.insert_header((
"X-RateLimit-Limit",
format!("{}{}", counter.max_value(), all_limits_text),
));

let remaining = counter.remaining().unwrap_or(counter.max_value());
resp.insert_header((
"X-RateLimit-Remaining".to_string(),
format!("{}", remaining),
));

if let Some(duration) = counter.expires_in() {
resp.insert_header(("X-RateLimit-Reset", format!("{}", duration.as_secs())));
}
let headers = result.response_header();
if let Some(limit) = headers.get("X-RateLimit-Limit") {
resp.insert_header(("X-RateLimit-Limit", limit.clone()));
}
if let Some(remaining) = headers.get("X-RateLimit-Remaining") {
resp.insert_header(("X-RateLimit-Remaining".to_string(), remaining.clone()));
if let Some(duration) = headers.get("X-RateLimit-Reset") {
resp.insert_header(("X-RateLimit-Reset", duration.clone()));
}
}
_default => {}
};
}
}

pub async fn run_http_server(
Expand Down
17 changes: 8 additions & 9 deletions limitador/src/counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use crate::limit::{Limit, Namespace};
use serde::{Deserialize, Serialize, Serializer};
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;

#[derive(Eq, Clone, Debug, Serialize, Deserialize)]
pub struct Counter {
limit: Limit,
limit: Arc<Limit>,

// Need to sort to generate the same object when using the JSON as a key or
// value in Redis.
Expand All @@ -26,9 +27,10 @@ where
}

impl Counter {
pub fn new(limit: Limit, set_variables: HashMap<String, String>) -> Self {
pub fn new<L: Into<Arc<Limit>>>(limit: L, set_variables: HashMap<String, String>) -> Self {
// TODO: check that all the variables defined in the limit are set.

let limit = limit.into();
let mut vars = set_variables;
vars.retain(|var, _| limit.has_variable(var));

Expand All @@ -43,7 +45,7 @@ impl Counter {
#[cfg(any(feature = "redis_storage", feature = "disk_storage"))]
pub(crate) fn key(&self) -> Self {
Self {
limit: self.limit.clone(),
limit: Arc::clone(&self.limit),
set_variables: self.set_variables.clone(),
remaining: None,
expires_in: None,
Expand All @@ -58,12 +60,9 @@ impl Counter {
self.limit.max_value()
}

pub fn update_to_limit(&mut self, limit: &Limit) -> bool {
if limit == &self.limit {
self.limit.set_max_value(limit.max_value());
if let Some(name) = limit.name() {
self.limit.set_name(name.to_string());
}
pub fn update_to_limit(&mut self, limit: Arc<Limit>) -> bool {
if limit == self.limit {
self.limit = limit;
return true;
}
false
Expand Down
64 changes: 58 additions & 6 deletions limitador/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
#![allow(clippy::multiple_crate_versions)]

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::counter::Counter;
use crate::errors::LimitadorError;
Expand Down Expand Up @@ -226,6 +227,49 @@ pub struct CheckResult {
pub limit_name: Option<String>,
}

impl CheckResult {
pub fn response_header(&mut self) -> HashMap<String, String> {
Copy link
Member Author

@alexsnaps alexsnaps Jun 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking this should return something other than String as keys to our Map here, but not a big deal (I think) for now as there is only one supported impl.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for refactoring this here.

let mut headers = HashMap::new();
// sort by the limit remaining..
self.counters.sort_by(|a, b| {
let a_remaining = a.remaining().unwrap_or(a.max_value());
let b_remaining = b.remaining().unwrap_or(b.max_value());
a_remaining.cmp(&b_remaining)
});

let mut all_limits_text = String::with_capacity(20 * self.counters.len());
self.counters.iter_mut().for_each(|counter| {
all_limits_text.push_str(
format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(),
);
if let Some(name) = counter.limit().name() {
all_limits_text.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str());
}
});

if let Some(counter) = self.counters.first() {
headers.insert(
"X-RateLimit-Limit".to_string(),
format!("{}{}", counter.max_value(), all_limits_text),
);

let remaining = counter.remaining().unwrap_or(counter.max_value());
headers.insert(
"X-RateLimit-Remaining".to_string(),
format!("{}", remaining),
);

if let Some(duration) = counter.expires_in() {
headers.insert(
"X-RateLimit-Reset".to_string(),
format!("{}", duration.as_secs()),
);
}
}
headers
}
}

impl From<CheckResult> for bool {
fn from(value: CheckResult) -> Self {
value.limited
Expand Down Expand Up @@ -298,7 +342,11 @@ impl RateLimiter {
}

pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
self.storage.get_limits(namespace)
self.storage
.get_limits(namespace)
.iter()
.map(|l| (**l).clone())
.collect()
}

pub fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> {
Expand Down Expand Up @@ -432,12 +480,12 @@ impl RateLimiter {
namespace: &Namespace,
values: &HashMap<String, String>,
) -> Result<Vec<Counter>, LimitadorError> {
let limits = self.get_limits(namespace);
let limits = self.storage.get_limits(namespace);

let counters = limits
.iter()
.filter(|lim| lim.applies(values))
.map(|lim| Counter::new(lim.clone(), values.clone()))
.map(|lim| Counter::new(Arc::clone(lim), values.clone()))
.collect();

Ok(counters)
Expand Down Expand Up @@ -470,7 +518,11 @@ impl AsyncRateLimiter {
}

pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
self.storage.get_limits(namespace)
self.storage
.get_limits(namespace)
.iter()
.map(|l| (**l).clone())
.collect()
}

pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> {
Expand Down Expand Up @@ -610,12 +662,12 @@ impl AsyncRateLimiter {
namespace: &Namespace,
values: &HashMap<String, String>,
) -> Result<Vec<Counter>, LimitadorError> {
let limits = self.get_limits(namespace);
let limits = self.storage.get_limits(namespace);

let counters = limits
.iter()
.filter(|lim| lim.applies(values))
.map(|lim| Counter::new(lim.clone(), values.clone()))
.map(|lim| Counter::new(Arc::clone(lim), values.clone()))
.collect();

Ok(counters)
Expand Down
12 changes: 7 additions & 5 deletions limitador/src/storage/disk/rocksdb_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use rocksdb::{
DB,
};
use std::collections::{BTreeSet, HashSet};
use std::ops::Deref;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tracing::trace_span;

Expand Down Expand Up @@ -91,7 +93,7 @@ impl CounterStorage for RocksDbStorage {
}

#[tracing::instrument(skip_all)]
fn get_counters(&self, limits: &HashSet<Limit>) -> Result<HashSet<Counter>, StorageErr> {
fn get_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<HashSet<Counter>, StorageErr> {
let mut counters = HashSet::default();
let namepaces: BTreeSet<&str> = limits.iter().map(|l| l.namespace().as_ref()).collect();
for ns in namepaces {
Expand All @@ -113,8 +115,8 @@ impl CounterStorage for RocksDbStorage {
}
let value: ExpiringValue = value.as_ref().try_into()?;
for limit in limits {
if limit == counter.limit() {
counter.update_to_limit(limit);
if limit.deref() == counter.limit() {
counter.update_to_limit(Arc::clone(limit));
let ttl = value.ttl();
counter.set_expires_in(ttl);
counter.set_remaining(limit.max_value() - value.value());
Expand All @@ -133,8 +135,8 @@ impl CounterStorage for RocksDbStorage {
}

#[tracing::instrument(skip_all)]
fn delete_counters(&self, limits: HashSet<Limit>) -> Result<(), StorageErr> {
let counters = self.get_counters(&limits)?;
fn delete_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<(), StorageErr> {
let counters = self.get_counters(limits)?;
for counter in &counters {
let span = trace_span!("datastore");
let _entered = span.enter();
Expand Down
Loading
Loading