From 265c87043164fb961f7a72ccb64457844353ea1e Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Fri, 5 Apr 2024 15:57:29 +0100 Subject: [PATCH] Refactor to use RateLimitData in http server --- limitador-server/src/http_api/server.rs | 69 ++++++++++++++-------- limitador-server/src/main.rs | 2 - limitador-server/src/prometheus_metrics.rs | 12 +--- 3 files changed, 47 insertions(+), 36 deletions(-) diff --git a/limitador-server/src/http_api/server.rs b/limitador-server/src/http_api/server.rs index 91514a91..d5063d3e 100644 --- a/limitador-server/src/http_api/server.rs +++ b/limitador-server/src/http_api/server.rs @@ -14,6 +14,24 @@ use paperclip::actix::{ use std::fmt; use std::sync::Arc; +struct RateLimitData { + limiter: Arc, + metrics: Arc, +} + +impl RateLimitData { + fn new(limiter: Arc, metrics: Arc) -> Self { + Self { limiter, metrics } + } + fn limiter(&self) -> &Limiter { + self.limiter.as_ref() + } + + fn metrics(&self) -> &PrometheusMetrics { + self.metrics.as_ref() + } +} + #[api_v2_errors(429, 500)] #[derive(Debug)] enum ErrorResponse { @@ -47,20 +65,18 @@ async fn status() -> web::Json<()> { #[tracing::instrument(skip(data))] #[api_v2_operation] -async fn metrics(data: web::Data<(Arc, Arc)>) -> String { - let (_, metrics) = data.get_ref(); - metrics.gather_metrics() +async fn metrics(data: web::Data) -> String { + data.get_ref().metrics().gather_metrics() } #[api_v2_operation] #[tracing::instrument(skip(data))] async fn get_limits( - data: web::Data<(Arc, Arc)>, + data: web::Data, namespace: web::Path, ) -> Result>, ErrorResponse> { let namespace = &namespace.into_inner().into(); - let (limiter, _) = data.get_ref(); - let limits = match limiter.as_ref() { + let limits = match data.get_ref().limiter() { Limiter::Blocking(limiter) => limiter.get_limits(namespace), Limiter::Async(limiter) => limiter.get_limits(namespace), }; @@ -71,12 +87,11 @@ async fn get_limits( #[tracing::instrument(skip(data))] #[api_v2_operation] async fn get_counters( - data: web::Data<(Arc, Arc)>, + data: web::Data, namespace: web::Path, ) -> Result>, ErrorResponse> { let namespace = namespace.into_inner().into(); - let (limiter, _) = data.get_ref(); - let get_counters_result = match limiter.as_ref() { + let get_counters_result = match data.get_ref().limiter() { Limiter::Blocking(limiter) => limiter.get_counters(&namespace), Limiter::Async(limiter) => limiter.get_counters(&namespace).await, }; @@ -96,7 +111,7 @@ async fn get_counters( #[tracing::instrument(skip(state))] #[api_v2_operation] async fn check( - state: web::Data<(Arc, Arc)>, + state: web::Data, request: web::Json, ) -> Result, ErrorResponse> { let CheckAndReportInfo { @@ -105,8 +120,7 @@ async fn check( delta, } = request.into_inner(); let namespace = namespace.into(); - let (limiter, _) = state.get_ref(); - let is_rate_limited_result = match limiter.as_ref() { + let is_rate_limited_result = match state.get_ref().limiter() { Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &values, delta), Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &values, delta).await, }; @@ -126,7 +140,7 @@ async fn check( #[tracing::instrument(skip(data))] #[api_v2_operation] async fn report( - data: web::Data<(Arc, Arc)>, + data: web::Data, request: web::Json, ) -> Result, ErrorResponse> { let CheckAndReportInfo { @@ -135,8 +149,7 @@ async fn report( delta, } = request.into_inner(); let namespace = namespace.into(); - let (limiter, _) = data.get_ref(); - let update_counters_result = match limiter.as_ref() { + let update_counters_result = match data.get_ref().limiter() { Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &values, delta), Limiter::Async(limiter) => limiter.update_counters(&namespace, &values, delta).await, }; @@ -150,7 +163,7 @@ async fn report( #[tracing::instrument(skip(data))] #[api_v2_operation] async fn check_and_report( - data: web::Data<(Arc, Arc)>, + data: web::Data, request: web::Json, ) -> Result, ErrorResponse> { let CheckAndReportInfo { @@ -159,8 +172,8 @@ async fn check_and_report( delta, } = request.into_inner(); let namespace = namespace.into(); - let (limiter, metrics) = data.get_ref(); - let rate_limited_and_update_result = match limiter.as_ref() { + let rate_limit_data = data.get_ref(); + let rate_limited_and_update_result = match rate_limit_data.limiter() { Limiter::Blocking(limiter) => { limiter.check_rate_limited_and_update(&namespace, &values, delta, false) } @@ -174,10 +187,12 @@ async fn check_and_report( match rate_limited_and_update_result { Ok(is_rate_limited) => { if is_rate_limited.limited { - metrics.incr_limited_calls(&namespace, is_rate_limited.limit_name.as_deref()); + rate_limit_data + .metrics() + .incr_limited_calls(&namespace, is_rate_limited.limit_name.as_deref()); Err(ErrorResponse::TooManyRequests) } else { - metrics.incr_authorized_calls(&namespace); + rate_limit_data.metrics().incr_authorized_calls(&namespace); Ok(Json(())) } } @@ -190,7 +205,7 @@ pub async fn run_http_server( rate_limiter: Arc, prometheus_metrics: Arc, ) -> std::io::Result<()> { - let data = web::Data::new((rate_limiter, prometheus_metrics)); + let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics)); // This uses the paperclip crate to generate an OpenAPI spec. // Ref: https://paperclip.waffles.space/actix-plugin.html @@ -243,7 +258,8 @@ mod tests { async fn test_metrics() { let rate_limiter: Arc = Arc::new(Limiter::new(Configuration::default()).await.unwrap()); - let data = web::Data::new(rate_limiter); + let prometheus_metrics: Arc = Arc::new(PrometheusMetrics::default()); + let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics)); let app = test::init_service( App::new() .app_data(data.clone()) @@ -267,7 +283,8 @@ mod tests { let limit = create_test_limit(&limiter, namespace, 10).await; let rate_limiter: Arc = Arc::new(limiter); - let data = web::Data::new(rate_limiter); + let prometheus_metrics: Arc = Arc::new(PrometheusMetrics::default()); + let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics)); let app = test::init_service( App::new() .app_data(data.clone()) @@ -293,7 +310,8 @@ mod tests { let namespace = "test_namespace"; let _limit = create_test_limit(&limiter, namespace, 1).await; let rate_limiter: Arc = Arc::new(limiter); - let data = web::Data::new(rate_limiter); + let prometheus_metrics: Arc = Arc::new(PrometheusMetrics::default()); + let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics)); let app = test::init_service( App::new() .app_data(data.clone()) @@ -337,7 +355,8 @@ mod tests { let _limit = create_test_limit(&limiter, namespace, 1).await; let rate_limiter: Arc = Arc::new(limiter); - let data = web::Data::new(rate_limiter); + let prometheus_metrics: Arc = Arc::new(PrometheusMetrics::default()); + let data = web::Data::new(RateLimitData::new(rate_limiter, prometheus_metrics)); let app = test::init_service( App::new() .app_data(data.clone()) diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index 62633bae..664f1a32 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -310,8 +310,6 @@ async fn main() -> Result<(), Box> { .init(); }; - prometheus_metrics.set_use_limit_name_in_label(limit_name_in_metrics); - info!("Version: {}", version); info!("Using config: {:?}", config); (config, prometheus_metrics) diff --git a/limitador-server/src/prometheus_metrics.rs b/limitador-server/src/prometheus_metrics.rs index 286c45fb..3fbe807a 100644 --- a/limitador-server/src/prometheus_metrics.rs +++ b/limitador-server/src/prometheus_metrics.rs @@ -2,7 +2,6 @@ use limitador::limit::Namespace; use prometheus::{ Encoder, Histogram, HistogramOpts, IntCounterVec, IntGauge, Opts, Registry, TextEncoder, }; -use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; const NAMESPACE_LABEL: &str = "limitador_namespace"; @@ -13,7 +12,7 @@ pub struct PrometheusMetrics { authorized_calls: IntCounterVec, limited_calls: IntCounterVec, counter_latency: Histogram, - use_limit_name_label: AtomicBool, + use_limit_name_label: bool, } impl Default for PrometheusMetrics { @@ -35,11 +34,6 @@ impl PrometheusMetrics { Self::new_with_options(true) } - pub fn set_use_limit_name_in_label(&self, use_limit_name_in_label: bool) { - self.use_limit_name_label - .store(use_limit_name_in_label, Ordering::SeqCst) - } - pub fn incr_authorized_calls(&self, namespace: &Namespace) { self.authorized_calls .with_label_values(&[namespace.as_ref()]) @@ -52,7 +46,7 @@ impl PrometheusMetrics { { let mut labels = vec![namespace.as_ref()]; - if self.use_limit_name_label.load(Ordering::Relaxed) { + if self.use_limit_name_label { // If we have configured the metric to accept 2 labels we need to // set values for them. labels.push(limit_name.into().unwrap_or("")); @@ -106,7 +100,7 @@ impl PrometheusMetrics { authorized_calls: authorized_calls_counter, limited_calls: limited_calls_counter, counter_latency, - use_limit_name_label: AtomicBool::new(use_limit_name_label), + use_limit_name_label, } }