Skip to content

Commit

Permalink
Refactor to use RateLimitData in http server
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-cattermole committed Apr 5, 2024
1 parent 602bfcc commit 265c870
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 36 deletions.
69 changes: 44 additions & 25 deletions limitador-server/src/http_api/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,24 @@ use paperclip::actix::{
use std::fmt;
use std::sync::Arc;

struct RateLimitData {
limiter: Arc<Limiter>,
metrics: Arc<PrometheusMetrics>,
}

impl RateLimitData {
fn new(limiter: Arc<Limiter>, metrics: Arc<PrometheusMetrics>) -> Self {
Self { limiter, metrics }
}
fn limiter(&self) -> &Limiter {
self.limiter.as_ref()
}

fn metrics(&self) -> &PrometheusMetrics {
self.metrics.as_ref()
}
}

#[api_v2_errors(429, 500)]
#[derive(Debug)]
enum ErrorResponse {
Expand Down Expand Up @@ -47,20 +65,18 @@ async fn status() -> web::Json<()> {

#[tracing::instrument(skip(data))]
#[api_v2_operation]
async fn metrics(data: web::Data<(Arc<Limiter>, Arc<PrometheusMetrics>)>) -> String {
let (_, metrics) = data.get_ref();
metrics.gather_metrics()
async fn metrics(data: web::Data<RateLimitData>) -> String {
data.get_ref().metrics().gather_metrics()
}

#[api_v2_operation]
#[tracing::instrument(skip(data))]
async fn get_limits(
data: web::Data<(Arc<Limiter>, Arc<PrometheusMetrics>)>,
data: web::Data<RateLimitData>,
namespace: web::Path<String>,
) -> Result<web::Json<Vec<Limit>>, ErrorResponse> {
let namespace = &namespace.into_inner().into();
let (limiter, _) = data.get_ref();
let limits = match limiter.as_ref() {
let limits = match data.get_ref().limiter() {
Limiter::Blocking(limiter) => limiter.get_limits(namespace),
Limiter::Async(limiter) => limiter.get_limits(namespace),
};
Expand All @@ -71,12 +87,11 @@ async fn get_limits(
#[tracing::instrument(skip(data))]
#[api_v2_operation]
async fn get_counters(
data: web::Data<(Arc<Limiter>, Arc<PrometheusMetrics>)>,
data: web::Data<RateLimitData>,
namespace: web::Path<String>,
) -> Result<web::Json<Vec<Counter>>, ErrorResponse> {
let namespace = namespace.into_inner().into();
let (limiter, _) = data.get_ref();
let get_counters_result = match limiter.as_ref() {
let get_counters_result = match data.get_ref().limiter() {
Limiter::Blocking(limiter) => limiter.get_counters(&namespace),
Limiter::Async(limiter) => limiter.get_counters(&namespace).await,
};
Expand All @@ -96,7 +111,7 @@ async fn get_counters(
#[tracing::instrument(skip(state))]
#[api_v2_operation]
async fn check(
state: web::Data<(Arc<Limiter>, Arc<PrometheusMetrics>)>,
state: web::Data<RateLimitData>,
request: web::Json<CheckAndReportInfo>,
) -> Result<web::Json<()>, ErrorResponse> {
let CheckAndReportInfo {
Expand All @@ -105,8 +120,7 @@ async fn check(
delta,
} = request.into_inner();
let namespace = namespace.into();
let (limiter, _) = state.get_ref();
let is_rate_limited_result = match limiter.as_ref() {
let is_rate_limited_result = match state.get_ref().limiter() {
Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &values, delta),
Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &values, delta).await,
};
Expand All @@ -126,7 +140,7 @@ async fn check(
#[tracing::instrument(skip(data))]
#[api_v2_operation]
async fn report(
data: web::Data<(Arc<Limiter>, Arc<PrometheusMetrics>)>,
data: web::Data<RateLimitData>,
request: web::Json<CheckAndReportInfo>,
) -> Result<web::Json<()>, ErrorResponse> {
let CheckAndReportInfo {
Expand All @@ -135,8 +149,7 @@ async fn report(
delta,
} = request.into_inner();
let namespace = namespace.into();
let (limiter, _) = data.get_ref();
let update_counters_result = match limiter.as_ref() {
let update_counters_result = match data.get_ref().limiter() {
Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &values, delta),
Limiter::Async(limiter) => limiter.update_counters(&namespace, &values, delta).await,
};
Expand All @@ -150,7 +163,7 @@ async fn report(
#[tracing::instrument(skip(data))]
#[api_v2_operation]
async fn check_and_report(
data: web::Data<(Arc<Limiter>, Arc<PrometheusMetrics>)>,
data: web::Data<RateLimitData>,
request: web::Json<CheckAndReportInfo>,
) -> Result<web::Json<()>, ErrorResponse> {
let CheckAndReportInfo {
Expand All @@ -159,8 +172,8 @@ async fn check_and_report(
delta,
} = request.into_inner();
let namespace = namespace.into();
let (limiter, metrics) = data.get_ref();
let rate_limited_and_update_result = match limiter.as_ref() {
let rate_limit_data = data.get_ref();
let rate_limited_and_update_result = match rate_limit_data.limiter() {
Limiter::Blocking(limiter) => {
limiter.check_rate_limited_and_update(&namespace, &values, delta, false)
}
Expand All @@ -174,10 +187,12 @@ async fn check_and_report(
match rate_limited_and_update_result {
Ok(is_rate_limited) => {
if is_rate_limited.limited {
metrics.incr_limited_calls(&namespace, is_rate_limited.limit_name.as_deref());
rate_limit_data
.metrics()
.incr_limited_calls(&namespace, is_rate_limited.limit_name.as_deref());
Err(ErrorResponse::TooManyRequests)
} else {
metrics.incr_authorized_calls(&namespace);
rate_limit_data.metrics().incr_authorized_calls(&namespace);
Ok(Json(()))
}
}
Expand All @@ -190,7 +205,7 @@ pub async fn run_http_server(
rate_limiter: Arc<Limiter>,
prometheus_metrics: Arc<PrometheusMetrics>,
) -> std::io::Result<()> {
let data = web::Data::new((rate_limiter, prometheus_metrics));
let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics));

// This uses the paperclip crate to generate an OpenAPI spec.
// Ref: https://paperclip.waffles.space/actix-plugin.html
Expand Down Expand Up @@ -243,7 +258,8 @@ mod tests {
async fn test_metrics() {
let rate_limiter: Arc<Limiter> =
Arc::new(Limiter::new(Configuration::default()).await.unwrap());
let data = web::Data::new(rate_limiter);
let prometheus_metrics: Arc<PrometheusMetrics> = Arc::new(PrometheusMetrics::default());
let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics));
let app = test::init_service(
App::new()
.app_data(data.clone())
Expand All @@ -267,7 +283,8 @@ mod tests {

let limit = create_test_limit(&limiter, namespace, 10).await;
let rate_limiter: Arc<Limiter> = Arc::new(limiter);
let data = web::Data::new(rate_limiter);
let prometheus_metrics: Arc<PrometheusMetrics> = Arc::new(PrometheusMetrics::default());
let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics));
let app = test::init_service(
App::new()
.app_data(data.clone())
Expand All @@ -293,7 +310,8 @@ mod tests {
let namespace = "test_namespace";
let _limit = create_test_limit(&limiter, namespace, 1).await;
let rate_limiter: Arc<Limiter> = Arc::new(limiter);
let data = web::Data::new(rate_limiter);
let prometheus_metrics: Arc<PrometheusMetrics> = Arc::new(PrometheusMetrics::default());
let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics));
let app = test::init_service(
App::new()
.app_data(data.clone())
Expand Down Expand Up @@ -337,7 +355,8 @@ mod tests {
let _limit = create_test_limit(&limiter, namespace, 1).await;

let rate_limiter: Arc<Limiter> = Arc::new(limiter);
let data = web::Data::new(rate_limiter);
let prometheus_metrics: Arc<PrometheusMetrics> = Arc::new(PrometheusMetrics::default());
let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics));
let app = test::init_service(
App::new()
.app_data(data.clone())
Expand Down
2 changes: 0 additions & 2 deletions limitador-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.init();
};

prometheus_metrics.set_use_limit_name_in_label(limit_name_in_metrics);

info!("Version: {}", version);
info!("Using config: {:?}", config);
(config, prometheus_metrics)
Expand Down
12 changes: 3 additions & 9 deletions limitador-server/src/prometheus_metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use limitador::limit::Namespace;
use prometheus::{
Encoder, Histogram, HistogramOpts, IntCounterVec, IntGauge, Opts, Registry, TextEncoder,
};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;

const NAMESPACE_LABEL: &str = "limitador_namespace";
Expand All @@ -13,7 +12,7 @@ pub struct PrometheusMetrics {
authorized_calls: IntCounterVec,
limited_calls: IntCounterVec,
counter_latency: Histogram,
use_limit_name_label: AtomicBool,
use_limit_name_label: bool,
}

impl Default for PrometheusMetrics {
Expand All @@ -35,11 +34,6 @@ impl PrometheusMetrics {
Self::new_with_options(true)
}

pub fn set_use_limit_name_in_label(&self, use_limit_name_in_label: bool) {
self.use_limit_name_label
.store(use_limit_name_in_label, Ordering::SeqCst)
}

pub fn incr_authorized_calls(&self, namespace: &Namespace) {
self.authorized_calls
.with_label_values(&[namespace.as_ref()])
Expand All @@ -52,7 +46,7 @@ impl PrometheusMetrics {
{
let mut labels = vec![namespace.as_ref()];

if self.use_limit_name_label.load(Ordering::Relaxed) {
if self.use_limit_name_label {
// If we have configured the metric to accept 2 labels we need to
// set values for them.
labels.push(limit_name.into().unwrap_or(""));
Expand Down Expand Up @@ -106,7 +100,7 @@ impl PrometheusMetrics {
authorized_calls: authorized_calls_counter,
limited_calls: limited_calls_counter,
counter_latency,
use_limit_name_label: AtomicBool::new(use_limit_name_label),
use_limit_name_label,
}
}

Expand Down

0 comments on commit 265c870

Please sign in to comment.