diff --git a/Cargo.toml b/Cargo.toml
index d39083a..772d243 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,8 @@ axum = "0.7"
 backoff = { version = "0.4", features = ["tokio"] }
 base64 = "0.22"
 bytes = "1.6"
+clap = { version = "4.5", features = ["derive", "string", "env"] }
+clap_derive = "4.5"
 chacha20poly1305 = "0.10"
 cloudflare = { git = "https://github.com/cloudflare/cloudflare-rs.git", rev = "f14720e42184ee176a97676e85ef2d2d85bc3aae", default-features = false, features = [
     "rustls-tls",
@@ -33,6 +35,7 @@ hickory-resolver = { version = "0.24", features = [
 http = "1.1"
 http-body = "1.0"
 http-body-util = "0.1"
+humantime = "2.1"
 hyper = "1.4"
 hyper-util = { version = "0.1", features = ["full"] }
 instant-acme = { version = "0.7.1", default-features = false, features = [
@@ -71,13 +74,14 @@ strum_macros = "0.26"
 sync_wrapper = "1.0"
 systemstat = "0.2.3"
 thiserror = "1.0"
-tokio = { version = "1.40", features = ["full"] }
+tokio = { version = "1.41", features = ["full"] }
 tokio-util = { version = "0.7", features = ["full"] }
 tokio-rustls = { version = "0.26.0", default-features = false, features = [
     "tls12",
     "logging",
     "ring",
 ] }
+tokio-io-timeout = "1.2"
 tower = { version = "0.5", features = ["util"] }
 tower-service = "0.3"
 tracing = "0.1"
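Editor note: the new `tokio-io-timeout` dependency is what gives `src/http/server.rs` below its per-call read/write deadlines. A minimal sketch of that crate's API as used here — the 60-second values are illustrative assumptions, not taken from this diff:

```rust
use std::time::Duration;
use tokio_io_timeout::TimeoutStream;

/// Wrap any stream so that each individual read/write call gets a deadline.
fn with_io_timeouts<S>(stream: S) -> TimeoutStream<S> {
    let mut s = TimeoutStream::new(stream);
    s.set_read_timeout(Some(Duration::from_secs(60)));
    s.set_write_timeout(Some(Duration::from_secs(60)));
    s
}
```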
diff --git a/src/http/body.rs b/src/http/body.rs
index 54f9209..dcf9d2b 100644
--- a/src/http/body.rs
+++ b/src/http/body.rs
@@ -1,5 +1,6 @@
 use std::{
     pin::{pin, Pin},
+    sync::atomic::{AtomicBool, Ordering},
     task::{Context, Poll},
     time::Duration,
 };
@@ -7,10 +8,14 @@
 use axum::body::Body;
 use bytes::{Buf, Bytes};
 use futures::Stream;
+use futures_util::ready;
 use http_body::{Body as HttpBody, Frame, SizeHint};
 use http_body_util::{BodyExt, LengthLimitError, Limited};
 use sync_wrapper::SyncWrapper;
-use tokio::sync::oneshot::{self, Receiver, Sender};
+use tokio::sync::{
+    mpsc,
+    oneshot::{self, Receiver, Sender},
+};
 
 use super::{calc_headers_size, Error};
 
@@ -79,6 +84,74 @@ impl Stream for SyncBodyDataStream {
     }
 }
 
+/// Body that notifies that it has finished by sending a value over the provided channel.
+/// Uses an AtomicBool flag to make sure we notify only once.
+pub struct NotifyingBody<S: Clone, D, E> {
+    inner: Pin<Box<dyn HttpBody<Data = D, Error = E> + Send + 'static>>,
+    tx: mpsc::Sender<S>,
+    sig: S,
+    sent: AtomicBool,
+}
+
+impl<S: Clone, D, E> NotifyingBody<S, D, E> {
+    pub fn new<B>(inner: B, tx: mpsc::Sender<S>, sig: S) -> Self
+    where
+        B: HttpBody<Data = D, Error = E> + Send + 'static,
+        D: Buf,
+    {
+        Self {
+            inner: Box::pin(inner),
+            tx,
+            sig,
+            sent: AtomicBool::new(false),
+        }
+    }
+
+    fn notify(&self) {
+        if self
+            .sent
+            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
+            == Ok(false)
+        {
+            let _ = self.tx.try_send(self.sig.clone());
+        }
+    }
+}
+
+impl<S: Clone, D, E> HttpBody for NotifyingBody<S, D, E>
+where
+    D: Buf,
+    E: std::string::ToString,
+{
+    type Data = D;
+    type Error = E;
+
+    fn poll_frame(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
+        let poll = ready!(pin!(&mut self.inner).poll_frame(cx));
+        if poll.is_none() {
+            self.notify();
+        }
+
+        Poll::Ready(poll)
+    }
+
+    fn size_hint(&self) -> SizeHint {
+        self.inner.size_hint()
+    }
+
+    fn is_end_stream(&self) -> bool {
+        let end = self.inner.is_end_stream();
+        if end {
+            self.notify();
+        }
+
+        end
+    }
+}
+
 // Body that counts the bytes streamed
 pub struct CountingBody<D, E> {
     inner: Pin<Box<dyn HttpBody<Data = D, Error = E> + Send + 'static>>,
@@ -131,11 +204,11 @@ where
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
-        let poll = pin!(&mut self.inner).poll_frame(cx);
+        let poll = ready!(pin!(&mut self.inner).poll_frame(cx));
 
         match &poll {
             // There is still some data available
-            Poll::Ready(Some(v)) => match v {
+            Some(v) => match v {
                 Ok(buf) => {
                     // Normal data frame
                     if buf.is_data() {
@@ -160,17 +233,14 @@ where
             },
 
             // Nothing left
-            Poll::Ready(None) => {
+            None => {
                 // Make borrow checker happy
                 let x = self.bytes_sent;
                 self.finish(Ok(x));
             }
-
-            // Do nothing
-            Poll::Pending => {}
         }
 
-        poll
+        Poll::Ready(poll)
     }
 
     fn size_hint(&self) -> SizeHint {
@@ -184,7 +254,7 @@ mod test {
     use http_body_util::BodyExt;
 
     #[tokio::test]
-    async fn test_body_stream() {
+    async fn test_counting_body_stream() {
         let data = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbl\
         ahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahbla\
         hfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoob\
@@ -206,7 +276,7 @@ mod test {
     }
 
     #[tokio::test]
-    async fn test_body_full() {
+    async fn test_counting_body_full() {
         let data = vec![0; 512];
         let buf = bytes::Bytes::from_iter(data.clone());
         let body = http_body_util::Full::new(buf);
@@ -221,4 +291,27 @@ mod test {
         let size = rx.await.unwrap().unwrap();
         assert_eq!(size, data.len() as u64);
     }
+
+    #[tokio::test]
+    async fn test_notifying_body() {
+        let data = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbl\
+        ahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahbla\
+        hfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoob\
+        arblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbla\
+        blahfoobarblahblah";
+
+        let stream = tokio_util::io::ReaderStream::new(&data[..]);
+        let body = axum::body::Body::from_stream(stream);
+
+        let sig = 357;
+        let (tx, mut rx) = mpsc::channel(10);
+        let body = NotifyingBody::new(body, tx, sig);
+
+        // Check that the body streams the same data back
+        let body = body.collect().await.unwrap().to_bytes().to_vec();
+        assert_eq!(body, data);
+
+        // Make sure we're notified
+        assert_eq!(sig, rx.recv().await.unwrap());
+    }
 }
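Editor note: `NotifyingBody` sends its signal over an `mpsc` channel exactly once, when the wrapped body finishes streaming. A hedged sketch of the consuming side (the signal enum, channel size, and wiring here are illustrative — the real wiring lives in `server.rs` below):

```rust
use tokio::sync::mpsc;

#[derive(Clone, Debug)]
enum Sig {
    Done,
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Sig>(1024);

    // Each outgoing response body would be wrapped roughly as:
    // let body = NotifyingBody::new(body, tx.clone(), Sig::Done);
    // Here we emulate the notification the wrapper would send on completion:
    let _ = tx.try_send(Sig::Done);
    drop(tx);

    while let Some(sig) = rx.recv().await {
        // e.g. decrement the in-flight counter / re-arm an idle timer
        println!("body finished: {sig:?}");
    }
}
```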
diff --git a/src/http/server.rs b/src/http/server.rs
index 1feca58..7c9dc72 100644
--- a/src/http/server.rs
+++ b/src/http/server.rs
@@ -5,7 +5,7 @@ use std::{
     os::unix::fs::PermissionsExt,
     path::PathBuf,
     sync::{
-        atomic::{AtomicU64, Ordering},
+        atomic::{AtomicU32, AtomicU64, Ordering},
         Arc,
     },
     time::{Duration, Instant},
@@ -14,6 +14,7 @@ use std::{
 use anyhow::{anyhow, Context};
 use async_trait::async_trait;
 use axum::{extract::Request, Router};
+use http::Response;
 use hyper::body::Incoming;
 use hyper_util::{
     rt::{TokioExecutor, TokioIo, TokioTimer},
@@ -27,22 +28,26 @@
 use rustls::{server::ServerConnection, CipherSuite, ProtocolVersion};
 use tokio::{
     io::{AsyncRead, AsyncWrite, AsyncWriteExt},
     net::{TcpListener, TcpSocket, UnixListener, UnixSocket},
-    select,
-    time::sleep,
+    pin, select,
+    sync::mpsc::channel,
+    time::{sleep, timeout},
 };
+use tokio_io_timeout::TimeoutStream;
 use tokio_rustls::TlsAcceptor;
 use tokio_util::{sync::CancellationToken, task::TaskTracker};
 use tower_service::Service;
 use tracing::{debug, info, warn};
 use uuid::Uuid;
 
-use super::{AsyncCounter, Error, Stats, ALPN_ACME};
+use super::{body::NotifyingBody, AsyncCounter, Error, Stats, ALPN_ACME};
 use crate::tasks::Run;
 
 const HANDSHAKE_DURATION_BUCKETS: &[f64] = &[0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.4, 0.8, 1.6];
 const CONN_DURATION_BUCKETS: &[f64] = &[1.0, 8.0, 32.0, 64.0, 256.0, 512.0, 1024.0];
 const CONN_REQUESTS: &[f64] = &[1.0, 4.0, 8.0, 16.0, 32.0, 64.0, 256.0];
 
+const YEAR: Duration = Duration::from_secs(86400 * 365);
+
 // Blanket async read+write trait for streams Box-ing
 trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
 impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}
@@ -144,6 +149,10 @@ impl Metrics {
 #[derive(Clone, Copy)]
 pub struct Options {
     pub backlog: u32,
+    pub tls_handshake_timeout: Duration,
+    pub read_timeout: Option<Duration>,
+    pub write_timeout: Option<Duration>,
+    pub idle_timeout: Duration,
     pub http1_header_read_timeout: Duration,
     pub http2_max_streams: u32,
     pub http2_keepalive_interval: Duration,
@@ -305,21 +314,28 @@ impl Display for Addr {
     }
 }
 
+#[derive(Clone)]
+enum RequestState {
+    Start,
+    End,
+}
+
 struct Conn {
     addr: Addr,
     remote_addr: Addr,
     router: Router,
     builder: Builder,
-    token: CancellationToken,
-    token_close: CancellationToken,
+    token_graceful: CancellationToken,
+    token_forceful: CancellationToken,
     options: Options,
     metrics: Metrics,
+    requests: AtomicU32,
     tls_acceptor: Option<TlsAcceptor>,
 }
 
 impl Display for Conn {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Server {}: {}", self.addr, self.remote_addr,)
+        write!(f, "Server {}: {}", self.addr, self.remote_addr)
     }
 }
@@ -335,10 +351,10 @@ impl Conn {
         let stream = self
             .tls_acceptor
             .as_ref()
-            .unwrap()
+            .unwrap() // Caller makes sure it's Some()
             .accept(stream)
             .await
-            .context("unable to accept TLS")?;
+            .context("TLS accept failed")?;
 
         let duration = start.elapsed();
         let conn = stream.get_ref().1;
@@ -361,7 +377,7 @@ impl Conn {
     async fn handle(&self, stream: Box<dyn AsyncReadWrite>) -> Result<(), Error> {
         let accepted_at = Instant::now();
 
-        debug!("{}: got a new connection", self);
+        debug!("{self}: got a new connection");
 
         // Prepare metric labels
         let addr = self.addr.to_string();
@@ -383,25 +399,31 @@ impl Conn {
             remote_addr: self.remote_addr.clone(),
             traffic: stats.clone(),
             req_count: AtomicU64::new(0),
-            close: self.token_close.clone(),
+            close: self.token_forceful.clone(),
         });
 
         // Perform TLS handshake if we're in TLS mode
         let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) = if self.tls_acceptor.is_some() {
-            let (mut stream, tls_info) = self.tls_handshake(stream).await?;
+            let (mut stream, tls_info) = timeout(
+                self.options.tls_handshake_timeout,
+                self.tls_handshake(stream),
+            )
+            .await
+            .context("TLS handshake timed out")?
+            .context("TLS handshake failed")?;
 
-            // Close the connection if agreed ALPN is ACME - the handshake is enough for challenge
+            // Close the connection if the agreed ALPN is ACME - the handshake is enough for the challenge
             if tls_info
                 .alpn
                 .as_ref()
                 .is_some_and(|x| x.as_bytes() == ALPN_ACME)
             {
-                debug!("{}: ACME ALPN - closing connection", self);
+                debug!("{self}: ACME ALPN - closing connection");
 
-                stream
-                    .shutdown()
+                timeout(Duration::from_secs(5), stream.shutdown())
                     .await
-                    .context("unable to shutdown stream")?;
+                    .context("socket shutdown timed out")?
+                    .context("socket shutdown failed")?;
 
                 return Ok(());
             }
@@ -436,11 +458,11 @@ impl Conn {
         let reqs = conn_info.req_count.load(Ordering::SeqCst);
 
         // force-closed
-        if self.token_close.is_cancelled() {
+        if self.token_forceful.is_cancelled() {
             labels[4] = "yes";
         }
         // recycled
-        if self.token.is_cancelled() {
+        if self.token_graceful.is_cancelled() {
             labels[5] = "yes";
         }
@@ -469,8 +491,8 @@ impl Conn {
         debug!(
             "{self}: connection closed (rcvd: {rcvd}, sent: {sent}, reqs: {reqs}, duration: {dur}, graceful: {}, forced close: {})",
-            self.token.is_cancelled(),
-            self.token_close.is_cancelled(),
+            self.token_graceful.is_cancelled(),
+            self.token_forceful.is_cancelled(),
         );
 
         result
@@ -482,13 +504,26 @@ impl Conn {
         conn_info: Arc<ConnInfo>,
         tls_info: Option<Arc<TlsInfo>>,
     ) -> Result<(), Error> {
+        // Create a timer for idle connection tracking
+        let mut idle_timer = Box::pin(sleep(self.options.idle_timeout));
+
+        // Create a channel for request start/stop notifications.
+        // It's bounded, but large enough to exceed any practical concurrency.
+        let (state_tx, mut state_rx) = channel(65536);
+
+        // Apply timeouts on read/write calls
+        let mut stream = TimeoutStream::new(stream);
+        stream.set_read_timeout(self.options.read_timeout);
+        stream.set_write_timeout(self.options.write_timeout);
+
         // Convert stream from Tokio to Hyper
         let stream = TokioIo::new(stream);
 
-        let max_requests_per_conn = self.options.max_requests_per_conn;
         // Convert router to Hyper service
+        let max_requests_per_conn = self.options.max_requests_per_conn;
         let service = hyper::service::service_fn(move |mut request: Request<Incoming>| {
-            let conn_count = conn_info.req_count.fetch_add(1, Ordering::SeqCst);
+            // Notify that we have started processing the request
+            let _ = state_tx.try_send(RequestState::Start);
 
             // Inject connection information
             request.extensions_mut().insert(conn_info.clone());
@@ -496,17 +531,26 @@ impl Conn {
                 request.extensions_mut().insert(v.clone());
             }
 
-            // Serve the request
+            // Clone the stuff needed in the async block
             let mut router = self.router.clone();
-            let token = self.token.clone();
+            let token = self.token_graceful.clone();
+            let conn_info = conn_info.clone();
+            let state_tx = state_tx.clone();
 
+            // Return the future
            async move {
-                // Get the result
-                let result = router.call(request).await;
+                // Execute the request
+                let result = router.call(request).await.map(|x| {
+                    // Wrap the response body into a notifying one
+                    let (parts, body) = x.into_parts();
+                    let body = NotifyingBody::new(body, state_tx, RequestState::End);
+                    Response::from_parts(parts, body)
+                });
 
                 // Check if we need to gracefully shutdown this connection
                 if let Some(v) = max_requests_per_conn {
-                    if conn_count + 1 >= v {
+                    let req_count = conn_info.req_count.fetch_add(1, Ordering::SeqCst);
+                    if req_count + 1 >= v {
                         token.cancel();
                     }
@@ -516,38 +560,78 @@ impl Conn {
         });
 
         // Serve the connection
-        let conn = self.builder.serve_connection(stream, service);
-        // Using mutable future reference requires pinning, otherwise .await consumes it
-        tokio::pin!(conn);
+        let conn = self.builder.serve_connection(Box::pin(stream), service);
+        // Using a mutable future reference requires pinning
+        pin!(conn);
 
-        select! {
-            biased; // Poll top-down
+        loop {
+            select! {
+                biased; // Poll top-down
 
-            // Immediately close the connection if was requested
-            () = self.token_close.cancelled() => {
-                return Ok(());
-            }
+                // Immediately close the connection if it was requested
+                () = self.token_forceful.cancelled() => {
+                    break;
+                }
 
-            () = self.token.cancelled() => { // Start graceful shutdown of the connection
-                // For H2: sends GOAWAY frames to the client
-                // For H1: disables keepalives
-                conn.as_mut().graceful_shutdown();
-
-                // Wait for the grace period to finish or connection to complete.
-                // Connection must still be polled for the shutdown to proceed.
-                select! {
-                    biased;
-                    () = sleep(self.options.grace_period) => return Ok(()),
-                    _ = conn.as_mut() => {},
-                }
-            }
+                // Start graceful shutdown of the connection
+                () = self.token_graceful.cancelled() => {
+                    // For H2: sends GOAWAY frames to the client
+                    // For H1: disables keepalives
+                    conn.as_mut().graceful_shutdown();
+
+                    // Wait for the grace period to finish or the connection to complete.
+                    // The connection must still be polled for the shutdown to proceed.
+                    // We don't really care about the result.
+                    let _ = timeout(self.options.grace_period, conn.as_mut()).await;
+                    break;
+                },
 
-            v = conn.as_mut() => {
-                if let Err(e) = v {
-                    return Err(anyhow!("unable to serve connection: {e:#}").into());
-                }
-            },
+                // Get request state change notifications
+                Some(v) = state_rx.recv() => {
+                    match v {
+                        RequestState::Start => {
+                            let reqs = self.requests.fetch_add(1, Ordering::SeqCst) + 1;
+                            debug!("{self}: request started, stopping the idle timer (in flight: {reqs})");
+
+                            // Effectively disable the timer by setting it to 1 year into the future.
+                            // TODO improve?
+                            idle_timer.as_mut().reset(tokio::time::Instant::now() + YEAR);
+                        },
+
+                        RequestState::End => {
+                            let reqs = self.requests.fetch_sub(1, Ordering::SeqCst) - 1;
+                            debug!("{self}: request finished (in flight: {reqs})");
+
+                            // Check if the number of outstanding requests is now zero
+                            if reqs == 0 {
+                                debug!("{self}: no outstanding requests, restarting the idle timer");
+                                // Enable the idle timer
+                                idle_timer.as_mut().reset(tokio::time::Instant::now() + self.options.idle_timeout);
+                            }
+                        }
+                    }
+                },
+
+                // See if the idle timeout has kicked in
+                () = idle_timer.as_mut() => {
+                    debug!("{self}: idle timeout triggered, closing");
+
+                    // Signal that we're closing
+                    conn.as_mut().graceful_shutdown();
+                    // Give the client some time to shut down
+                    let _ = timeout(Duration::from_secs(3), conn.as_mut()).await;
+                    break;
+                },
+
+                // Drive the connection by polling it
+                v = conn.as_mut() => {
+                    if let Err(e) = v {
+                        return Err(anyhow!("unable to serve connection: {e:#}").into());
+                    }
+
+                    break;
+                },
+            }
         }
 
         Ok(())
@@ -561,6 +645,7 @@ pub struct Server {
     tracker: TaskTracker,
     options: Options,
     metrics: Metrics,
+    builder: Builder,
     tls_acceptor: Option<TlsAcceptor>,
 }
 
@@ -572,12 +657,28 @@ impl Server {
         metrics: Metrics,
         rustls_cfg: Option<ServerConfig>,
     ) -> Self {
+        // Prepare the Hyper connection builder.
+        // It automatically figures out whether to do HTTP1 or HTTP2.
+        let mut builder = Builder::new(TokioExecutor::new());
+        builder
+            .http1()
+            .timer(TokioTimer::new()) // Needed for the header read timeout below
+            .header_read_timeout(Some(options.http1_header_read_timeout))
+            .keep_alive(true)
+            .http2()
+            .adaptive_window(true)
+            .max_concurrent_streams(Some(options.http2_max_streams))
+            .timer(TokioTimer::new()) // Needed for the keepalives below
+            .keep_alive_interval(Some(options.http2_keepalive_interval))
+            .keep_alive_timeout(options.http2_keepalive_timeout);
+
         Self {
             addr,
             router,
             options,
             metrics,
             tracker: TaskTracker::new(),
+            builder,
             tls_acceptor: rustls_cfg.map(|x| TlsAcceptor::from(Arc::new(x))),
         }
     }
@@ -597,26 +698,46 @@ impl Server {
         self.serve_with_listener(listener, token).await
     }
 
+    fn spawn_connection(
+        &self,
+        stream: Box<dyn AsyncReadWrite>,
+        remote_addr: Addr,
+        token: CancellationToken,
+    ) {
+        // Create a new connection.
+        // Router & TlsAcceptor are both Arc<> inside, so they're cheap to clone.
+        // Builder is a bit more complex, but cloning it is still better than creating it again.
+        let conn = Conn {
+            addr: self.addr.clone(),
+            remote_addr: remote_addr.clone(),
+            router: self.router.clone(),
+            builder: self.builder.clone(),
+            token_graceful: token,
+            token_forceful: CancellationToken::new(),
+            options: self.options,
+            metrics: self.metrics.clone(), // All metrics have Arc inside
+            requests: AtomicU32::new(0),
+            tls_acceptor: self.tls_acceptor.clone(),
+        };
+
+        // Spawn a task to handle the connection & track it
+        self.tracker.spawn(async move {
+            if let Err(e) = conn.handle(stream).await {
+                info!(
+                    "Server {}: {}: failed to handle connection: {e:#}",
+                    conn.addr, remote_addr
+                );
+            }
+
+            debug!("Server {}: {}: connection finished", conn.addr, remote_addr);
+        });
+    }
+
     pub async fn serve_with_listener(
         &self,
         listener: Listener,
         token: CancellationToken,
     ) -> Result<(), Error> {
-        // Prepare Hyper connection builder
-        // It automatically figures out whether to do HTTP1 or HTTP2
-        let mut builder = Builder::new(TokioExecutor::new());
-        builder
-            .http1()
-            .timer(TokioTimer::new())
-            .header_read_timeout(Some(self.options.http1_header_read_timeout))
-            .keep_alive(true)
-            .http2()
-            .adaptive_window(true)
-            .max_concurrent_streams(Some(self.options.http2_max_streams))
-            .timer(TokioTimer::new()) // Needed for the keepalives below
-            .keep_alive_interval(Some(self.options.http2_keepalive_interval))
-            .keep_alive_timeout(self.options.http2_keepalive_timeout);
-
         warn!(
             "Server {}: running (TLS: {})",
             self.addr,
@@ -664,32 +785,7 @@ impl Server {
                 }
             };
 
-            // Create a new connection
-            // Router & TlsAcceptor are both Arc<> inside so it's cheap to clone
-            // Builder is a bit more complex, but cloning is better than to create it again
-            let conn = Conn {
-                addr: self.addr.clone(),
-                remote_addr: remote_addr.clone(),
-                router: self.router.clone(),
-                builder: builder.clone(),
-                token: token.child_token(),
-                token_close: CancellationToken::new(),
-                options: self.options,
-                metrics: self.metrics.clone(), // All metrics have Arc inside
-                tls_acceptor: self.tls_acceptor.clone(),
-            };
-
-            // Spawn a task to handle connection & track it
-            self.tracker.spawn(async move {
-                if let Err(e) = conn.handle(stream).await {
-                    info!("Server {}: {}: failed to handle connection: {e:#}", conn.addr, remote_addr);
-                }
-
-                debug!(
-                    "Server {}: {}: connection finished",
-                    conn.addr, remote_addr
-                );
-            });
+            self.spawn_connection(stream, remote_addr, token.child_token());
         }
     }
 }
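Editor note: the connection loop above implements the idle timeout by re-arming a single `tokio::time::Sleep` — pushed a year into the future while requests are in flight, re-armed with `idle_timeout` once the in-flight count hits zero. A standalone sketch of that pattern (the durations and the `bool` channel payload are assumptions for the demo):

```rust
use std::time::Duration;
use tokio::{pin, select, sync::mpsc, time::{sleep, Instant}};

const FAR_FUTURE: Duration = Duration::from_secs(86400 * 365);

#[tokio::main]
async fn main() {
    let idle_timeout = Duration::from_secs(3);
    let (tx, mut rx) = mpsc::channel::<bool>(16); // true = request started, false = finished

    // Simulate one request that takes a second
    tokio::spawn(async move {
        tx.send(true).await.unwrap();
        sleep(Duration::from_secs(1)).await;
        tx.send(false).await.unwrap();
    });

    let idle = sleep(idle_timeout);
    pin!(idle);

    loop {
        select! {
            Some(started) = rx.recv() => {
                // Disarm while a request is in flight, re-arm when it finishes
                let dur = if started { FAR_FUTURE } else { idle_timeout };
                idle.as_mut().reset(Instant::now() + dur);
            }
            () = idle.as_mut() => {
                println!("idle timeout hit, closing connection");
                break;
            }
        }
    }
}
```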
diff --git a/src/http/shed/cli.rs b/src/http/shed/cli.rs
new file mode 100644
index 0000000..5840abf
--- /dev/null
+++ b/src/http/shed/cli.rs
@@ -0,0 +1,118 @@
+use std::str::FromStr;
+
+use anyhow::{anyhow, Context};
+use clap::Args;
+use humantime::parse_duration;
+
+use super::{sharded::TypeLatency, system::SystemOptions};
+use crate::Error;
+
+/// Generic parser for TypeLatency<T> in "foo:<duration>" format.
+/// Supports anything that implements FromStr.
+impl<T: FromStr> FromStr for TypeLatency<T>
+where
+    T::Err: std::error::Error + Send + Sync + Sized + 'static,
+{
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let (rtype, lat) = s
+            .split_once(":")
+            .ok_or_else(|| anyhow!("incorrect format"))?;
+
+        let rtype = T::from_str(rtype).context("unknown request type")?;
+        let lat = parse_duration(lat).context("unable to parse latency")?;
+
+        Ok(Self(rtype, lat))
+    }
+}
+
+#[derive(Debug, Clone, Args)]
+pub struct ShedSystem {
+    /// EWMA alpha coefficient in the [0.0, 1.0] range.
+    /// It represents the weight of the more recent measurements relative to the older ones.
+    #[clap(env, long, default_value = "0.8")]
+    pub shed_system_ewma: f64,
+
+    /// CPU load where to start shedding, range [0.0, 1.0]
+    #[clap(env, long)]
+    pub shed_system_cpu: Option<f64>,
+
+    /// Memory usage where to start shedding, range [0.0, 1.0]
+    #[clap(env, long)]
+    pub shed_system_memory: Option<f64>,
+
+    /// 1-minute load average where to start shedding, range [0.0, inf)
+    #[clap(env, long)]
+    pub shed_system_load_avg_1: Option<f64>,
+
+    /// 5-minute load average where to start shedding, range [0.0, inf)
+    #[clap(env, long)]
+    pub shed_system_load_avg_5: Option<f64>,
+
+    /// 15-minute load average where to start shedding, range [0.0, inf)
+    #[clap(env, long)]
+    pub shed_system_load_avg_15: Option<f64>,
+}
+
+impl From<ShedSystem> for SystemOptions {
+    fn from(v: ShedSystem) -> Self {
+        Self {
+            cpu: v.shed_system_cpu,
+            memory: v.shed_system_memory,
+            loadavg_1: v.shed_system_load_avg_1,
+            loadavg_5: v.shed_system_load_avg_5,
+            loadavg_15: v.shed_system_load_avg_15,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Args)]
+pub struct ShedSharded<T: FromStr + Clone + Send + Sync + 'static>
+where
+    T::Err: std::error::Error + Send + Sync + 'static,
+{
+    /// EWMA alpha coefficient in the [0.0, 1.0] range.
+    /// It represents the weight of the more recent measurements relative to the older ones.
+    #[clap(env, long, default_value = "0.8")]
+    pub shed_sharded_ewma: f64,
+
+    /// Number of initial requests to allow through without shedding.
+    /// This allows for a gradual load build-up, avoiding false positives.
+    #[clap(env, long, default_value = "1000")]
+    pub shed_sharded_passthrough: u64,
+
+    /// Request types and their target latency, colon-separated, e.g. "query:100ms".
+    /// This specifies the target latency of Little's load-shedding algorithm for the given request type.
+    /// Can be specified several times.
+    /// Important: if a request type is not specified in this list, it is never shed.
+    #[clap(env, long, value_delimiter = ',')]
+    pub shed_sharded_latency: Vec<TypeLatency<T>>,
+}
+
+#[cfg(test)]
+mod test {
+    use std::time::Duration;
+
+    use super::*;
+    use crate::types::RequestType;
+
+    #[test]
+    fn test_type_latency() {
+        assert!(TypeLatency::<RequestType>::from_str("foo").is_err());
+        assert!(TypeLatency::<RequestType>::from_str(":").is_err());
+        assert!(TypeLatency::<RequestType>::from_str("foo:100ms").is_err());
+        assert!(TypeLatency::<RequestType>::from_str("query:").is_err());
+        assert!(TypeLatency::<RequestType>::from_str("query:1gigasecond").is_err());
+
+        assert_eq!(
+            TypeLatency::<RequestType>::from_str("query:100ms").unwrap(),
+            TypeLatency::<RequestType>(RequestType::Query, Duration::from_millis(100))
+        );
+
+        assert_eq!(
+            TypeLatency::<RequestType>::from_str("sync_call:1s").unwrap(),
+            TypeLatency::<RequestType>(RequestType::SyncCall, Duration::from_millis(1000))
+        );
+    }
+}
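Editor note: the "type:duration" parsing above leans on `humantime`. A tiny standalone demonstration of the same steps (values are arbitrary examples):

```rust
use humantime::parse_duration;

fn main() {
    let s = "query:100ms";
    let (rtype, lat) = s.split_once(':').expect("incorrect format");
    let lat = parse_duration(lat).expect("unable to parse latency");
    println!("{rtype} -> {lat:?}"); // query -> 100ms
}
```

On the command line this would look roughly like `--shed-sharded-latency query:100ms,call:1s` (flag name derived from the field above; `value_delimiter = ','` lets several pairs share one flag).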
diff --git a/src/http/shed/ewma.rs b/src/http/shed/ewma.rs
index 9c8f8b3..1f395ad 100644
--- a/src/http/shed/ewma.rs
+++ b/src/http/shed/ewma.rs
@@ -1,5 +1,6 @@
 /// Implementation of Exponential Weighted Moving Average.
 /// https://en.wikipedia.org/wiki/Exponential_smoothing#Basic_(simple)_exponential_smoothing
+#[derive(Debug)]
 pub struct EWMA {
     alpha: f64,
     new: bool,
@@ -39,6 +40,7 @@
 
 /// Implementation of Double Exponential Weighted Moving Average.
 /// https://en.wikipedia.org/wiki/Exponential_smoothing#Double_exponential_smoothing_(Holt_linear)
+#[derive(Debug)]
 pub struct DEWMA {
     alpha: f64,
     beta: f64,
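Editor note: for reference, the smoothing in `ewma.rs` follows the standard update rule s_t = α·x_t + (1 − α)·s_{t−1}. A minimal sketch (field and method names here are illustrative, not the crate's):

```rust
struct Ewma {
    alpha: f64,
    value: Option<f64>,
}

impl Ewma {
    fn add(&mut self, x: f64) {
        self.value = Some(match self.value {
            None => x, // the first sample seeds the average
            Some(s) => self.alpha * x + (1.0 - self.alpha) * s,
        });
    }
}

fn main() {
    let mut e = Ewma { alpha: 0.8, value: None };
    for x in [1.0, 2.0, 3.0] {
        e.add(x);
    }
    // 0.8*3 + 0.2*(0.8*2 + 0.2*1) = 2.76
    assert!((e.value.unwrap() - 2.76).abs() < 1e-9);
}
```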
diff --git a/src/http/shed/mod.rs b/src/http/shed/mod.rs
index 8115292..dfaf95a 100644
--- a/src/http/shed/mod.rs
+++ b/src/http/shed/mod.rs
@@ -1,43 +1,12 @@
+pub mod cli;
 pub mod ewma;
 pub mod little;
+pub mod sharded;
+pub mod system;
 
-use std::{
-    collections::BTreeMap,
-    fmt::Debug,
-    future::Future,
-    pin::Pin,
-    sync::{Arc, RwLock},
-    task::{Context, Poll},
-    time::Duration,
-};
+use std::{fmt::Debug, future::Future, pin::Pin};
 
-use anyhow::{anyhow, Context as _};
-use async_trait::async_trait;
-use ewma::EWMA;
-use little::{LoadShed, LoadShedResponse};
-use systemstat::{Platform, System};
-use tokio_util::sync::CancellationToken;
-use tower::{Layer, Service, ServiceExt};
-use tracing::{error, warn};
-
-use crate::{tasks::Run, Error};
-
-#[async_trait]
-pub trait GetsSystemInfo: Send + Sync + Clone {
-    async fn cpu_usage(&self) -> Result<f64, Error>;
-    fn memory_usage(&self) -> Result<f64, Error>;
-    fn load_avg(&self) -> Result<(f64, f64, f64), Error>;
-}
-
-/// Trait to extract the shedding key from the given request
-pub trait TypeExtractor: Clone + Debug + Send + Sync + 'static {
-    /// The type of the request.
-    type Type: Clone + Debug + Send + Sync + Ord + 'static;
-    type Request: Send + Sync;
-
-    /// Extraction method, should return None response when the extraction failed
-    fn extract(&self, req: &Self::Request) -> Option<Self::Type>;
-}
+pub type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
 
 /// Reason for shedding
 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -56,559 +25,3 @@ pub enum ShedResponse<T> {
     /// The request was shed due to overload.
     Overload(ShedReason),
 }
-
-#[derive(Clone)]
-pub struct SystemInfo(Arc<System>);
-
-impl SystemInfo {
-    pub fn new() -> Self {
-        Self(Arc::new(System::new()))
-    }
-}
-
-impl Default for SystemInfo {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-#[async_trait]
-impl GetsSystemInfo for SystemInfo {
-    async fn cpu_usage(&self) -> Result<f64, Error> {
-        let cpu = self
-            .0
-            .cpu_load_aggregate()
-            .context("unable to measure CPU load")?;
-        tokio::time::sleep(Duration::from_secs(1)).await;
-        let cpu = cpu.done().context("unable to measure CPU load")?;
-
-        Ok(1.0 - cpu.idle as f64)
-    }
-
-    fn memory_usage(&self) -> Result<f64, Error> {
-        let mem = self.0.memory().context("unable to measure memory usage")?;
-        if mem.total.as_u64() == 0 {
-            return Err(anyhow!("total memory is zero").into());
-        }
-
-        Ok(1.0 - mem.free.as_u64() as f64 / mem.total.as_u64() as f64)
-    }
-
-    fn load_avg(&self) -> Result<(f64, f64, f64), Error> {
-        let la = self
-            .0
-            .load_average()
-            .context("unable to measure load average")?;
-
-        Ok((la.one as f64, la.five as f64, la.fifteen as f64))
-    }
-}
-
-struct Averages {
-    cpu: EWMA,
-    memory: EWMA,
-    load_avg: (EWMA, EWMA, EWMA),
-}
-
-impl Averages {
-    fn new(alpha: f64) -> Self {
-        Self {
-            cpu: EWMA::new(alpha),
-            memory: EWMA::new(alpha),
-            load_avg: (EWMA::new(alpha), EWMA::new(alpha), EWMA::new(alpha)),
-        }
-    }
-}
-
-#[derive(Debug, Clone, Copy)]
-pub struct SystemOptions {
-    pub ewma_alpha: f64,
-    pub cpu: Option<f64>,
-    pub memory: Option<f64>,
-    pub loadavg_1: Option<f64>,
-    pub loadavg_5: Option<f64>,
-    pub loadavg_15: Option<f64>,
-}
-
-/// Load shedder that sheds requests when the system load is over the defined thresholds
-pub struct SystemLoadShedder<S: GetsSystemInfo, I> {
-    sys_info: S,
-    avg: RwLock<Averages>,
-    opts: SystemOptions,
-    inner: I,
-}
-
-impl<S: GetsSystemInfo, I> SystemLoadShedder<S, I> {
-    pub fn new(inner: I, opts: SystemOptions, sys_info: S) -> Self {
-        Self {
-            sys_info,
-            avg: RwLock::new(Averages::new(opts.ewma_alpha)),
-            opts,
-            inner,
-        }
-    }
-
-    async fn measure(&self) -> Result<(), Error> {
-        let cpu = self.sys_info.cpu_usage().await?;
-        let mem = self.sys_info.memory_usage()?;
-        let (l1, l5, l15) = self.sys_info.load_avg()?;
-
-        let mut avg = self.avg.write().unwrap();
-        avg.cpu.add(cpu);
-        avg.memory.add(mem);
-        avg.load_avg.0.add(l1);
-        avg.load_avg.1.add(l5);
-        avg.load_avg.2.add(l15);
-        drop(avg); // clippy
-
-        Ok(())
-    }
-
-    fn evaluate(&self) -> Option<ShedReason> {
-        let avg = self.avg.read().unwrap();
-
-        if self
-            .opts
-            .cpu
-            .map(|x| avg.cpu.get().unwrap_or(0.0) > x)
-            .unwrap_or(false)
-        {
-            return Some(ShedReason::CPU);
-        }
-
-        if self
-            .opts
-            .memory
-            .map(|x| avg.memory.get().unwrap_or(0.0) > x)
-            .unwrap_or(false)
-        {
-            return Some(ShedReason::Memory);
-        }
-
-        if self
-            .opts
-            .loadavg_1
-            .map(|x| avg.load_avg.0.get().unwrap_or(0.0) > x)
-            .unwrap_or(false)
-        {
-            return Some(ShedReason::LoadAvg);
-        }
-
-        if self
-            .opts
-            .loadavg_5
-            .map(|x| avg.load_avg.1.get().unwrap_or(0.0) > x)
-            .unwrap_or(false)
-        {
-            return Some(ShedReason::LoadAvg);
-        }
-
-        if self
-            .opts
-            .loadavg_15
-            .map(|x| avg.load_avg.2.get().unwrap_or(0.0) > x)
-            .unwrap_or(false)
-        {
-            return Some(ShedReason::LoadAvg);
-        }
-
-        None
-    }
-}
-
-#[async_trait]
-impl<S: GetsSystemInfo, I: Send + Sync> Run for SystemLoadShedder<S, I> {
-    async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
-        let mut interval = tokio::time::interval(Duration::from_secs(2));
-        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
-
-        loop {
-            tokio::select! {
-                biased;
-
-                () = token.cancelled() => {
-                    warn!("SystemLoadShedder: exiting");
-                    return Ok(());
-                }
-
-                _ = interval.tick() => {
-                    if let Err(e) = self.measure().await {
-                        error!("SystemLoadShedder: error: {e:#}");
-                    }
-                },
-            }
-        }
-    }
-}
-
-type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
-
-// Implement tower service
-impl<R, S: GetsSystemInfo, I> Service<R> for SystemLoadShedder<S, I>
-where
-    R: Send + 'static,
-    I: Service<R> + Clone + Send + Sync + 'static,
-    I::Future: Send,
-{
-    type Response = ShedResponse<I::Response>;
-    type Error = I::Error;
-    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
-
-    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        Poll::Ready(Ok(()))
-    }
-
-    fn call(&mut self, req: R) -> Self::Future {
-        // Check if we need to shed the load
-        if let Some(v) = self.evaluate() {
-            return Box::pin(async move { Ok(ShedResponse::Overload(v)) });
-        }
-
-        let inner = self.inner.clone();
-        Box::pin(async move {
-            let response = inner.oneshot(req).await;
-            Ok(ShedResponse::Inner(response?))
-        })
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct SystemLoadShedderLayer<S: GetsSystemInfo>(SystemOptions, S);
-
-impl<S: GetsSystemInfo> SystemLoadShedderLayer<S> {
-    pub const fn new(opts: SystemOptions, sys_info: S) -> Self {
-        Self(opts, sys_info)
-    }
-}
-
-impl<S: GetsSystemInfo, I> Layer<I> for SystemLoadShedderLayer<S> {
-    type Service = SystemLoadShedder<S, I>;
-
-    fn layer(&self, inner: I) -> Self::Service {
-        SystemLoadShedder::new(inner, self.0, self.1.clone())
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct ShardedOptions<T: TypeExtractor> {
-    pub extractor: T,
-    pub ewma_alpha: f64,
-    pub passthrough_count: u64,
-    pub latencies: Vec<(T::Type, Duration)>,
-}
-
-#[derive(Debug, Clone)]
-pub struct ShardedLittleLoadShedder<T: TypeExtractor, I> {
-    extractor: T,
-    inner: I,
-    shards: BTreeMap<T::Type, LoadShed<I>>,
-}
-
-impl<T: TypeExtractor, I: Clone> ShardedLittleLoadShedder<T, I> {
-    pub fn new(inner: I, opts: ShardedOptions<T>) -> Self {
-        // Generate the shedding shards, one per provided request type
-        let shards = BTreeMap::from_iter(opts.latencies.into_iter().map(|x| {
-            (
-                x.0,
-                LoadShed::new(inner.clone(), opts.ewma_alpha, x.1, opts.passthrough_count),
-            )
-        }));
-
-        Self {
-            extractor: opts.extractor,
-            inner,
-            shards,
-        }
-    }
-
-    // Tries to find a shard corresponding to the given request
-    fn get_shard(&self, req: &T::Request) -> Option<LoadShed<I>> {
-        let req_type = self.extractor.extract(req)?;
-        self.shards.get(&req_type).cloned()
-    }
-}
-
-// Implement tower service
-impl<T: TypeExtractor, I> Service<T::Request> for ShardedLittleLoadShedder<T, I>
-where
-    I: Service<T::Request> + Clone + Send + Sync + 'static,
-    I::Future: Send,
-{
-    type Response = ShedResponse<I::Response>;
-    type Error = I::Error;
-    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
-
-    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        Poll::Ready(Ok(()))
-    }
-
-    fn call(&mut self, req: T::Request) -> Self::Future {
-        // Try to find if we have a shard
-        let Some(shard) = self.get_shard(&req) else {
-            // If we don't - just pass the request to the inner service
-            let inner = self.inner.clone();
-            return Box::pin(async move { Ok(ShedResponse::Inner(inner.oneshot(req).await?)) });
-        };
-
-        Box::pin(async move {
-            // Map response to our
-            shard.oneshot(req).await.map(|x| match x {
-                LoadShedResponse::Overload => ShedResponse::Overload(ShedReason::Latency),
-                LoadShedResponse::Inner(i) => ShedResponse::Inner(i),
-            })
-        })
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct ShardedLittleLoadShedderLayer<T: TypeExtractor>(ShardedOptions<T>);
-
-impl<T: TypeExtractor> ShardedLittleLoadShedderLayer<T> {
-    pub const fn new(opts: ShardedOptions<T>) -> Self {
-        Self(opts)
-    }
-}
-
-impl<T: TypeExtractor, I: Clone> Layer<I> for ShardedLittleLoadShedderLayer<T> {
-    type Service = ShardedLittleLoadShedder<T, I>;
-
-    fn layer(&self, inner: I) -> Self::Service {
-        ShardedLittleLoadShedder::new(inner, self.0.clone())
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use std::sync::{
-        atomic::{AtomicUsize, Ordering},
-        Mutex,
-    };
-
-    use tokio_util::task::TaskTracker;
-
-    use super::*;
-
-    #[derive(Clone, Debug)]
-    struct StubSystemInfoVal {
-        cpu: f64,
-        memory: f64,
-        l1: f64,
-        l5: f64,
-        l15: f64,
-    }
-
-    #[derive(Clone, Debug)]
-    struct StubSystemInfo {
-        v: Arc<Mutex<StubSystemInfoVal>>,
-    }
-
-    #[async_trait]
-    impl GetsSystemInfo for StubSystemInfo {
-        async fn cpu_usage(&self) -> Result<f64, Error> {
-            Ok(self.v.lock().unwrap().cpu)
-        }
-
-        fn memory_usage(&self) -> Result<f64, Error> {
-            Ok(self.v.lock().unwrap().memory)
-        }
-
-        fn load_avg(&self) -> Result<(f64, f64, f64), Error> {
-            let v = self.v.lock().unwrap();
-            Ok((v.l1, v.l5, v.l15))
-        }
-    }
-
-    #[derive(Debug, Clone)]
-    struct StubService;
-
-    impl Service<Duration> for StubService {
-        type Response = ();
-        type Error = Error;
-        type Future = BoxFuture<Result<Self::Response, Self::Error>>;
-
-        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-            Poll::Ready(Ok(()))
-        }
-
-        fn call(&mut self, req: Duration) -> Self::Future {
-            let fut = async move {
-                tokio::time::sleep(req).await;
-                Ok(())
-            };
-
-            Box::pin(fut)
-        }
-    }
-
-    #[derive(Debug, Clone)]
-    struct StubExtractor(u8);
-
-    impl TypeExtractor for StubExtractor {
-        type Request = Duration;
-        type Type = u8;
-
-        fn extract(&self, _req: &Self::Request) -> Option<Self::Type> {
-            return Some(self.0);
-        }
-    }
-
-    #[tokio::test]
-    async fn test_system_shedder() {
-        let inner = StubService;
-        let opts = SystemOptions {
-            ewma_alpha: 0.8,
-            cpu: Some(0.5),
-            memory: Some(0.5),
-            loadavg_1: Some(0.5),
-            loadavg_5: Some(0.5),
-            loadavg_15: Some(0.5),
-        };
-        let sys_info = StubSystemInfo {
-            v: Arc::new(Mutex::new(StubSystemInfoVal {
-                cpu: 0.0,
-                memory: 0.0,
-                l1: 0.0,
-                l5: 0.0,
-                l15: 0.0,
-            })),
-        };
-
-        let mut shedder = SystemLoadShedder::new(inner, opts, sys_info.clone());
-        let _ = shedder.measure().await;
-        let resp = shedder.call(Duration::ZERO).await.unwrap();
-        assert!(matches!(resp, ShedResponse::Inner(_)));
-
-        sys_info.v.lock().unwrap().cpu = 1.0;
-        let _ = shedder.measure().await;
-        let resp = shedder.call(Duration::ZERO).await.unwrap();
-        assert_eq!(resp, ShedResponse::Overload(ShedReason::CPU));
-        sys_info.v.lock().unwrap().cpu = 0.0;
-
-        sys_info.v.lock().unwrap().memory = 1.0;
-        let _ = shedder.measure().await;
-        let resp = shedder.call(Duration::ZERO).await.unwrap();
-        assert_eq!(resp, ShedResponse::Overload(ShedReason::Memory));
-        sys_info.v.lock().unwrap().memory = 0.0;
-
-        sys_info.v.lock().unwrap().l1 = 1.0;
-        let _ = shedder.measure().await;
-        let resp = shedder.call(Duration::ZERO).await.unwrap();
-        assert_eq!(resp, ShedResponse::Overload(ShedReason::LoadAvg));
-        sys_info.v.lock().unwrap().l1 = 0.0;
-
-        sys_info.v.lock().unwrap().l5 = 1.0;
-        let _ = shedder.measure().await;
-        let resp = shedder.call(Duration::ZERO).await.unwrap();
-        assert_eq!(resp, ShedResponse::Overload(ShedReason::LoadAvg));
-        sys_info.v.lock().unwrap().l5 = 0.0;
-
-        sys_info.v.lock().unwrap().l15 = 1.0;
-        let _ = shedder.measure().await;
-        let resp = shedder.call(Duration::ZERO).await.unwrap();
-        assert_eq!(resp, ShedResponse::Overload(ShedReason::LoadAvg));
-        sys_info.v.lock().unwrap().l15 = 0.0;
-
-        let _ = shedder.measure().await;
-        let resp = shedder.call(Duration::ZERO).await.unwrap();
-        assert!(matches!(resp, ShedResponse::Inner(_)));
-    }
-
-    #[tokio::test]
-    async fn test_sharded_shedder() {
-        let opts = ShardedOptions {
-            extractor: StubExtractor(0),
-            passthrough_count: 100,
-            ewma_alpha: 0.9,
-            latencies: vec![(0, Duration::from_millis(1))],
-        };
-        let inner = StubService;
-
-        let mut shedder = ShardedLittleLoadShedder::new(inner, opts);
-
-        // Make sure sequential requests are not shedded no matter the latency
-        for _ in 0..10 {
-            let resp = shedder.call(Duration::from_millis(10)).await.unwrap();
-            assert_eq!(resp, ShedResponse::Inner(()));
-        }
-
-        // Now try 90 of concurrent requests with high latency
-        // They shouldn't be shedded due to passthrough_requests
-        let shedded = Arc::new(AtomicUsize::new(0));
-        let tracker = TaskTracker::new();
-        for _ in 0..90 {
-            let shedder = shedder.clone();
-            let shedded = shedded.clone();
-
-            tracker.spawn(async move {
-                let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
-                if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
-                    shedded.fetch_add(1, Ordering::SeqCst);
-                }
-            });
-        }
-
-        tracker.close();
-        tracker.wait().await;
-        assert_eq!(shedded.load(Ordering::SeqCst), 0);
-
-        // Now try 10 of concurrent requests with high latency
-        // 8 of them should be shedded
-        let shedded = Arc::new(AtomicUsize::new(0));
-        let tracker = TaskTracker::new();
-        for _ in 0..10 {
-            let shedder = shedder.clone();
-            let shedded = shedded.clone();
-
-            tracker.spawn(async move {
-                let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
-                if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
-                    shedded.fetch_add(1, Ordering::SeqCst);
-                }
-            });
-        }
-
-        tracker.close();
-        tracker.wait().await;
-        assert_eq!(shedded.load(Ordering::SeqCst), 8);
-
-        // Now try requests with low latency and limited concurrency
-        let shedded = Arc::new(AtomicUsize::new(0));
-        let tracker = TaskTracker::new();
-        let sem = Arc::new(tokio::sync::Semaphore::new(2));
-
-        for _ in 0..10 {
-            let shedder = shedder.clone();
-            let shedded = shedded.clone();
-            let sem = sem.clone();
-
-            tracker.spawn(async move {
-                let _permit = sem.acquire().await.unwrap();
-
-                let resp = shedder.oneshot(Duration::from_millis(1)).await.unwrap();
-                if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
-                    shedded.fetch_add(1, Ordering::SeqCst);
-                }
-            });
-        }
-
-        tracker.close();
-        tracker.wait().await;
-        assert_eq!(shedded.load(Ordering::SeqCst), 0);
-
-        // Finally it shouldn't shed
-        let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
-        assert_eq!(resp, ShedResponse::Inner(()));
-
-        // Check that non-existant type still works (extractor returns 1 but we configure only 0)
-        let opts = ShardedOptions {
-            extractor: StubExtractor(1),
-            ewma_alpha: 0.9,
-            passthrough_count: 0,
-            latencies: vec![(0, Duration::from_millis(1))],
-        };
-        let inner = StubService;
-        let mut shedder = ShardedLittleLoadShedder::new(inner, opts);
-        let resp = shedder.call(Duration::from_millis(50)).await.unwrap();
-        assert_eq!(resp, ShedResponse::Inner(()));
-    }
-}
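Editor note: the refactor above moves the per-request-type shedder into its own module; the shedding criterion itself lives in `little.rs`, which this diff does not show. For intuition, Little's law states L = λW (concurrency = arrival rate × latency), so once the EWMA-smoothed latency exceeds the configured target, the concurrency that would keep latency at the target has been exceeded too. A rough sketch of such a check — an assumed form only, not the actual `LoadShed` internals:

```rust
/// Decide whether to shed, given the current in-flight count, the smoothed
/// observed latency, and the configured target latency (both in seconds).
fn should_shed(in_flight: u64, ewma_latency: f64, target_latency: f64) -> bool {
    // Little's law: L = lambda * W. At the same arrival rate, the concurrency
    // sustainable while staying at the target latency is:
    let sustainable = (in_flight as f64 / ewma_latency) * target_latency;
    (in_flight as f64) > sustainable
}

fn main() {
    assert!(should_shed(10, 0.200, 0.100)); // running 2x over target: shed
    assert!(!should_shed(10, 0.050, 0.100)); // under target: admit
}
```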
diff --git a/src/http/shed/sharded.rs b/src/http/shed/sharded.rs
new file mode 100644
index 0000000..3b79f5a
--- /dev/null
+++ b/src/http/shed/sharded.rs
@@ -0,0 +1,261 @@
+use std::{
+    collections::BTreeMap,
+    fmt::Debug,
+    sync::Arc,
+    task::{Context, Poll},
+    time::Duration,
+};
+
+use tower::{Layer, Service, ServiceExt};
+
+use super::{
+    little::{LoadShed, LoadShedResponse},
+    BoxFuture, ShedReason, ShedResponse,
+};
+
+/// Trait to extract the shedding key from the given request
+pub trait TypeExtractor: Clone + Debug + Send + Sync + 'static {
+    /// The type of the request
+    type Type: Clone + Debug + Send + Sync + Ord + 'static;
+    type Request: Send;
+
+    /// Extraction method; should return None when the extraction fails
+    fn extract(&self, req: &Self::Request) -> Option<Self::Type>;
+}
+
+#[derive(Debug, Clone, Eq, PartialEq, Hash)]
+pub struct TypeLatency<T>(pub T, pub Duration);
+
+#[derive(Debug, Clone)]
+pub struct ShardedOptions<T: TypeExtractor> {
+    pub extractor: T,
+    pub ewma_alpha: f64,
+    pub passthrough_count: u64,
+    pub latencies: Vec<TypeLatency<T::Type>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct ShardedLittleLoadShedder<T: TypeExtractor, I> {
+    extractor: T,
+    inner: I,
+    shards: Arc<BTreeMap<T::Type, LoadShed<I>>>,
+}
+
+impl<T: TypeExtractor, I: Clone> ShardedLittleLoadShedder<T, I> {
+    pub fn new(inner: I, opts: ShardedOptions<T>) -> Self {
+        // Generate the shedding shards, one per provided request type
+        let shards = Arc::new(BTreeMap::from_iter(opts.latencies.into_iter().map(|x| {
+            (
+                x.0,
+                LoadShed::new(inner.clone(), opts.ewma_alpha, x.1, opts.passthrough_count),
+            )
+        })));
+
+        Self {
+            extractor: opts.extractor,
+            inner,
+            shards,
+        }
+    }
+
+    // Tries to find a shard corresponding to the given request
+    fn get_shard(&self, req: &T::Request) -> Option<LoadShed<I>> {
+        let req_type = self.extractor.extract(req)?;
+        self.shards.get(&req_type).cloned()
+    }
+}
+
+// Implement tower service
+impl<T: TypeExtractor, I> Service<T::Request> for ShardedLittleLoadShedder<T, I>
+where
+    I: Service<T::Request> + Clone + Send + Sync + 'static,
+    I::Future: Send,
+{
+    type Response = ShedResponse<I::Response>;
+    type Error = I::Error;
+    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
+
+    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, req: T::Request) -> Self::Future {
+        // Try to find if we have a shard
+        let Some(shard) = self.get_shard(&req) else {
+            // If we don't - just pass the request to the inner service
+            let inner = self.inner.clone();
+            return Box::pin(async move { Ok(ShedResponse::Inner(inner.oneshot(req).await?)) });
+        };
+
+        Box::pin(async move {
+            // Map the response to ours
+            shard.oneshot(req).await.map(|x| match x {
+                LoadShedResponse::Overload => ShedResponse::Overload(ShedReason::Latency),
+                LoadShedResponse::Inner(i) => ShedResponse::Inner(i),
+            })
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct ShardedLittleLoadShedderLayer<T: TypeExtractor>(ShardedOptions<T>);
+
+impl<T: TypeExtractor> ShardedLittleLoadShedderLayer<T> {
+    pub const fn new(opts: ShardedOptions<T>) -> Self {
+        Self(opts)
+    }
+}
+
+impl<T: TypeExtractor, I: Clone> Layer<I> for ShardedLittleLoadShedderLayer<T> {
+    type Service = ShardedLittleLoadShedder<T, I>;
+
+    fn layer(&self, inner: I) -> Self::Service {
+        ShardedLittleLoadShedder::new(inner, self.0.clone())
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    };
+
+    use tokio_util::task::TaskTracker;
+
+    use super::*;
+    use crate::Error;
+
+    #[derive(Debug, Clone)]
+    struct StubService;
+
+    impl Service<Duration> for StubService {
+        type Response = ();
+        type Error = Error;
+        type Future = BoxFuture<Result<Self::Response, Self::Error>>;
+
+        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+            Poll::Ready(Ok(()))
+        }
+
+        fn call(&mut self, req: Duration) -> Self::Future {
+            let fut = async move {
+                tokio::time::sleep(req).await;
+                Ok(())
+            };
+
+            Box::pin(fut)
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct StubExtractor(u8);
+
+    impl TypeExtractor for StubExtractor {
+        type Request = Duration;
+        type Type = u8;
+
+        fn extract(&self, _req: &Self::Request) -> Option<Self::Type> {
+            Some(self.0)
+        }
+    }
+
+    #[tokio::test]
+    async fn test_sharded_shedder() {
+        let opts = ShardedOptions {
+            extractor: StubExtractor(0),
+            passthrough_count: 100,
+            ewma_alpha: 0.9,
+            latencies: vec![TypeLatency(0, Duration::from_millis(1))],
+        };
+        let inner = StubService;
+
+        let mut shedder = ShardedLittleLoadShedder::new(inner, opts);
+
+        // Make sure sequential requests are never shed, no matter the latency
+        for _ in 0..10 {
+            let resp = shedder.call(Duration::from_millis(10)).await.unwrap();
+            assert_eq!(resp, ShedResponse::Inner(()));
+        }
+
+        // Now try 90 concurrent requests with high latency.
+        // They shouldn't be shed, thanks to passthrough_count.
+        let shedded = Arc::new(AtomicUsize::new(0));
+        let tracker = TaskTracker::new();
+        for _ in 0..90 {
+            let shedder = shedder.clone();
+            let shedded = shedded.clone();
+
+            tracker.spawn(async move {
+                let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
+                if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
+                    shedded.fetch_add(1, Ordering::SeqCst);
+                }
+            });
+        }
+
+        tracker.close();
+        tracker.wait().await;
+        assert_eq!(shedded.load(Ordering::SeqCst), 0);
+
+        // Now try 10 concurrent requests with high latency.
+        // 8 of them should be shed.
+        let shedded = Arc::new(AtomicUsize::new(0));
+        let tracker = TaskTracker::new();
+        for _ in 0..10 {
+            let shedder = shedder.clone();
+            let shedded = shedded.clone();
+
+            tracker.spawn(async move {
+                let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
+                if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
+                    shedded.fetch_add(1, Ordering::SeqCst);
+                }
+            });
+        }
+
+        tracker.close();
+        tracker.wait().await;
+        assert_eq!(shedded.load(Ordering::SeqCst), 8);
+
+        // Now try requests with low latency and limited concurrency
+        let shedded = Arc::new(AtomicUsize::new(0));
+        let tracker = TaskTracker::new();
+        let sem = Arc::new(tokio::sync::Semaphore::new(2));
+
+        for _ in 0..10 {
+            let shedder = shedder.clone();
+            let shedded = shedded.clone();
+            let sem = sem.clone();
+
+            tracker.spawn(async move {
+                let _permit = sem.acquire().await.unwrap();
+
+                let resp = shedder.oneshot(Duration::from_millis(1)).await.unwrap();
+                if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
+                    shedded.fetch_add(1, Ordering::SeqCst);
+                }
+            });
+        }
+
+        tracker.close();
+        tracker.wait().await;
+        assert_eq!(shedded.load(Ordering::SeqCst), 0);
+
+        // Finally it shouldn't shed
+        let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
+        assert_eq!(resp, ShedResponse::Inner(()));
+
+        // Check that a non-existent type still works (the extractor returns 1 but we configure only 0)
+        let opts = ShardedOptions {
+            extractor: StubExtractor(1),
+            ewma_alpha: 0.9,
+            passthrough_count: 0,
+            latencies: vec![TypeLatency(0, Duration::from_millis(1))],
+        };
+        let inner = StubService;
+        let mut shedder = ShardedLittleLoadShedder::new(inner, opts);
+        let resp = shedder.call(Duration::from_millis(50)).await.unwrap();
+        assert_eq!(resp, ShedResponse::Inner(()));
+    }
+}
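Editor note: a hedged example of implementing the `TypeExtractor` trait above for HTTP requests. This `PathExtractor` is hypothetical — the crate's real extractor for `RequestType` is not part of this diff — and it assumes it sits next to the trait inside this crate:

```rust
use http::Request;

// use super::TypeExtractor; // assuming this lives alongside sharded.rs

#[derive(Clone, Debug)]
struct PathExtractor;

impl TypeExtractor for PathExtractor {
    type Type = String;
    type Request = Request<axum::body::Body>;

    fn extract(&self, req: &Self::Request) -> Option<Self::Type> {
        // Shard on the first path segment, e.g. "/query/..." -> "query".
        // Returning None means the request bypasses shedding entirely.
        req.uri().path().split('/').nth(1).map(str::to_owned)
    }
}
```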
diff --git a/src/http/shed/system.rs b/src/http/shed/system.rs
new file mode 100644
index 0000000..a6fea8c
--- /dev/null
+++ b/src/http/shed/system.rs
@@ -0,0 +1,390 @@
+use std::{
+    fmt::Debug,
+    sync::{Arc, RwLock, RwLockWriteGuard},
+    task::{Context, Poll},
+    time::Duration,
+};
+
+use anyhow::{anyhow, Context as _};
+use async_trait::async_trait;
+use systemstat::{Platform, System};
+use tower::{Layer, Service, ServiceExt};
+use tracing::{debug, error};
+
+use super::{ewma::EWMA, BoxFuture, ShedReason, ShedResponse};
+use crate::Error;
+
+#[async_trait]
+pub trait GetsSystemInfo: Send + Sync + Clone + 'static {
+    async fn cpu_usage(&self) -> Result<f64, Error>;
+    fn memory_usage(&self) -> Result<f64, Error>;
+    fn load_avg(&self) -> Result<(f64, f64, f64), Error>;
+}
+
+#[derive(Clone)]
+pub struct SystemInfo(Arc<System>);
+
+impl SystemInfo {
+    pub fn new() -> Self {
+        Self(Arc::new(System::new()))
+    }
+}
+
+impl Default for SystemInfo {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[async_trait]
+impl GetsSystemInfo for SystemInfo {
+    async fn cpu_usage(&self) -> Result<f64, Error> {
+        let cpu = self
+            .0
+            .cpu_load_aggregate()
+            .context("unable to measure CPU load")?;
+        tokio::time::sleep(Duration::from_millis(900)).await;
+        let cpu = cpu.done().context("unable to measure CPU load")?;
+
+        Ok(1.0 - cpu.idle as f64)
+    }
+
+    fn memory_usage(&self) -> Result<f64, Error> {
+        let mem = self.0.memory().context("unable to measure memory usage")?;
+        if mem.total.as_u64() == 0 {
+            return Err(anyhow!("total memory is zero").into());
+        }
+
+        Ok(1.0 - mem.free.as_u64() as f64 / mem.total.as_u64() as f64)
+    }
+
+    fn load_avg(&self) -> Result<(f64, f64, f64), Error> {
+        let la = self
+            .0
+            .load_average()
+            .context("unable to measure load average")?;
+
+        Ok((la.one as f64, la.five as f64, la.fifteen as f64))
+    }
+}
+
+#[derive(Debug)]
+struct StateInner {
+    cpu: EWMA,
+    memory: EWMA,
+    load_avg: (EWMA, EWMA, EWMA),
+    shed_reason: Option<ShedReason>,
+}
+
+impl StateInner {
+    fn new(alpha: f64) -> Self {
+        Self {
+            cpu: EWMA::new(alpha),
+            memory: EWMA::new(alpha),
+            load_avg: (EWMA::new(alpha), EWMA::new(alpha), EWMA::new(alpha)),
+            shed_reason: None,
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct State<S: GetsSystemInfo> {
+    opts: SystemOptions,
+    sys_info: S,
+    inner: RwLock<StateInner>,
+}
+
+impl<S: GetsSystemInfo> State<S> {
+    pub fn new(alpha: f64, opts: SystemOptions, sys_info: S) -> Self {
+        Self {
+            opts,
+            sys_info,
+            inner: RwLock::new(StateInner::new(alpha)),
+        }
+    }
+
+    // Perform system info measurement
+    async fn measure(&self) -> Result<(), Error> {
+        let cpu = self.sys_info.cpu_usage().await?;
+        let mem = self.sys_info.memory_usage()?;
+        let (l1, l5, l15) = self.sys_info.load_avg()?;
+
+        let mut inner = self.inner.write().unwrap();
+        inner.cpu.add(cpu);
+        inner.memory.add(mem);
+        inner.load_avg.0.add(l1);
+        inner.load_avg.1.add(l5);
+        inner.load_avg.2.add(l15);
+
+        // Check if we're overloaded
+        inner.shed_reason = self.evaluate(&inner);
+        debug!("System load: CPU {cpu}, MEM {mem}, LAVG1: {l1}, LAVG5: {l5}, LAVG15: {l15}, Overload: {:?}", inner.shed_reason);
+
+        drop(inner); // clippy
+        Ok(())
+    }
+
+    fn evaluate(&self, state: &RwLockWriteGuard<'_, StateInner>) -> Option<ShedReason> {
+        if self
+            .opts
+            .cpu
+            .map(|x| state.cpu.get().unwrap_or(0.0) > x)
+            .unwrap_or(false)
+        {
+            return Some(ShedReason::CPU);
+        }
+
+        if self
+            .opts
+            .memory
+            .map(|x| state.memory.get().unwrap_or(0.0) > x)
+            .unwrap_or(false)
+        {
+            return Some(ShedReason::Memory);
+        }
+
+        if self
+            .opts
+            .loadavg_1
+            .map(|x| state.load_avg.0.get().unwrap_or(0.0) > x)
+            .unwrap_or(false)
+        {
+            return Some(ShedReason::LoadAvg);
+        }
+
+        if self
+            .opts
+            .loadavg_5
+            .map(|x| state.load_avg.1.get().unwrap_or(0.0) > x)
+            .unwrap_or(false)
+        {
+            return Some(ShedReason::LoadAvg);
+        }
+
+        if self
+            .opts
+            .loadavg_15
+            .map(|x| state.load_avg.2.get().unwrap_or(0.0) > x)
+            .unwrap_or(false)
+        {
+            return Some(ShedReason::LoadAvg);
+        }
+
+        None
+    }
+
+    fn is_overloaded(&self) -> Option<ShedReason> {
+        self.inner.read().unwrap().shed_reason
+    }
+
+    // Periodically run the measurements
+    async fn run(&self) {
+        // The CPU usage measurement takes 900ms, so we run every second
+        let mut interval = tokio::time::interval(Duration::from_secs(1));
+        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
+
+        loop {
+            interval.tick().await;
+
+            if let Err(e) = self.measure().await {
+                error!("SystemLoadShedder: error: {e:#}");
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct SystemOptions {
+    pub cpu: Option<f64>,
+    pub memory: Option<f64>,
+    pub loadavg_1: Option<f64>,
+    pub loadavg_5: Option<f64>,
+    pub loadavg_15: Option<f64>,
+}
+
+/// Load shedder that sheds requests when the system load is over the defined thresholds
+#[derive(Debug, Clone)]
+pub struct SystemLoadShedder<S: GetsSystemInfo, I> {
+    state: Arc<State<S>>,
+    inner: I,
+}
+
+impl<S: GetsSystemInfo, I> SystemLoadShedder<S, I> {
+    pub const fn new(inner: I, state: Arc<State<S>>) -> Self {
+        Self { state, inner }
+    }
+}
+
+// Implement tower service
+impl<R, S: GetsSystemInfo, I> Service<R> for SystemLoadShedder<S, I>
+where
+    R: Send + 'static,
+    I: Service<R> + Clone + Send + Sync + 'static,
+    I::Future: Send,
+{
+    type Response = ShedResponse<I::Response>;
+    type Error = I::Error;
+    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
+
+    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, req: R) -> Self::Future {
+        // Check if we need to shed the load
+        let shed_reason = self.state.is_overloaded();
+        if let Some(v) = shed_reason {
+            return Box::pin(async move { Ok(ShedResponse::Overload(v)) });
+        }
+
+        let inner = self.inner.clone();
+        Box::pin(async move {
+            let response = inner.oneshot(req).await;
+            Ok(ShedResponse::Inner(response?))
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct SystemLoadShedderLayer<S: GetsSystemInfo>(Arc<State<S>>);
+
+impl<S: GetsSystemInfo> SystemLoadShedderLayer<S> {
+    pub fn new(ewma_alpha: f64, opts: SystemOptions, sys_info: S) -> Self {
+        // Create a state that will be shared among all the shedder instances
+        let state = Arc::new(State::new(ewma_alpha, opts, sys_info));
+
+        // Spawn the background task to perform the system measurements
+        let state_bg = state.clone();
+        tokio::spawn(async move { state_bg.run().await });
+
+        Self(state)
+    }
+}
+
+impl<S: GetsSystemInfo, I> Layer<I> for SystemLoadShedderLayer<S> {
+    type Service = SystemLoadShedder<S, I>;
+
+    fn layer(&self, inner: I) -> Self::Service {
+        SystemLoadShedder::new(inner, self.0.clone())
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Mutex;
+
+    use super::*;
+
+    #[derive(Clone, Debug)]
+    struct StubSystemInfoVal {
+        cpu: f64,
+        memory: f64,
+        l1: f64,
+        l5: f64,
+        l15: f64,
+    }
+
+    #[derive(Clone, Debug)]
+    struct StubSystemInfo {
+        v: Arc<Mutex<StubSystemInfoVal>>,
+    }
+
+    #[async_trait]
+    impl GetsSystemInfo for StubSystemInfo {
+        async fn cpu_usage(&self) -> Result<f64, Error> {
+            Ok(self.v.lock().unwrap().cpu)
+        }
+
+        fn memory_usage(&self) -> Result<f64, Error> {
+            Ok(self.v.lock().unwrap().memory)
+        }
+
+        fn load_avg(&self) -> Result<(f64, f64, f64), Error> {
+            let v = self.v.lock().unwrap();
+            Ok((v.l1, v.l5, v.l15))
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct StubService;
+
+    impl Service<Duration> for StubService {
+        type Response = ();
+        type Error = Error;
+        type Future = BoxFuture<Result<Self::Response, Self::Error>>;
+
+        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+            Poll::Ready(Ok(()))
+        }
+
+        fn call(&mut self, req: Duration) -> Self::Future {
+            let fut = async move {
+                tokio::time::sleep(req).await;
+                Ok(())
+            };
+
+            Box::pin(fut)
+        }
+    }
+
+    #[tokio::test]
+    async fn test_system_shedder() {
+        let inner = StubService;
+        let opts = SystemOptions {
+            cpu: Some(0.5),
+            memory: Some(0.5),
+            loadavg_1: Some(0.5),
+            loadavg_5: Some(0.5),
+            loadavg_15: Some(0.5),
+        };
+        let sys_info = StubSystemInfo {
+            v: Arc::new(Mutex::new(StubSystemInfoVal {
+                cpu: 0.0,
+                memory: 0.0,
+                l1: 0.0,
+                l5: 0.0,
+                l15: 0.0,
+            })),
+        };
+
+        let state = Arc::new(State::new(0.8, opts, sys_info.clone()));
+        let mut shedder = SystemLoadShedder::new(inner, state.clone());
+        let _ = state.measure().await;
+        let resp = shedder.call(Duration::ZERO).await.unwrap();
+        assert!(matches!(resp, ShedResponse::Inner(_)));
+
+        sys_info.v.lock().unwrap().cpu = 1.0;
+        let _ = state.measure().await;
+        let resp = shedder.call(Duration::ZERO).await.unwrap();
+        assert_eq!(resp, ShedResponse::Overload(ShedReason::CPU));
+        sys_info.v.lock().unwrap().cpu = 0.0;
+
+        sys_info.v.lock().unwrap().memory = 1.0;
+        let _ = state.measure().await;
+        let resp = shedder.call(Duration::ZERO).await.unwrap();
+        assert_eq!(resp, ShedResponse::Overload(ShedReason::Memory));
+        sys_info.v.lock().unwrap().memory = 0.0;
+
+        sys_info.v.lock().unwrap().l1 = 1.0;
+        let _ = state.measure().await;
+        let resp = shedder.call(Duration::ZERO).await.unwrap();
+        assert_eq!(resp, ShedResponse::Overload(ShedReason::LoadAvg));
+        sys_info.v.lock().unwrap().l1 = 0.0;
+
+        sys_info.v.lock().unwrap().l5 = 1.0;
+        let _ = state.measure().await;
+        let resp = shedder.call(Duration::ZERO).await.unwrap();
+        assert_eq!(resp, ShedResponse::Overload(ShedReason::LoadAvg));
+        sys_info.v.lock().unwrap().l5 = 0.0;
+
+        sys_info.v.lock().unwrap().l15 = 1.0;
+        let _ = state.measure().await;
+        let resp = shedder.call(Duration::ZERO).await.unwrap();
+        assert_eq!(resp, ShedResponse::Overload(ShedReason::LoadAvg));
+        sys_info.v.lock().unwrap().l15 = 0.0;
+
+        let _ = state.measure().await;
+        let resp = shedder.call(Duration::ZERO).await.unwrap();
+        assert!(matches!(resp, ShedResponse::Inner(_)));
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 4ccebcd..cf61859 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@
 pub mod http;
 pub mod tasks;
 pub mod tls;
+pub mod types;
 pub mod vector;
 
 /// Generic error
diff --git a/src/types.rs b/src/types.rs
new file mode 100644
index 0000000..d1c0a6f
--- /dev/null
+++ b/src/types.rs
@@ -0,0 +1,29 @@
+use std::fmt::Debug;
+
+use strum::{Display, EnumString, IntoStaticStr};
+
+/// Type of IC API request
+#[derive(
+    Debug,
+    Default,
+    Clone,
+    Copy,
+    Display,
+    PartialEq,
+    Eq,
+    PartialOrd,
+    Ord,
+    Hash,
+    IntoStaticStr,
+    EnumString,
+)]
+#[strum(serialize_all = "snake_case")]
+pub enum RequestType {
+    #[default]
+    Status,
+    Query,
+    Call,
+    SyncCall,
+    ReadState,
+    ReadStateSubnet,
+}
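Editor note: the strum derives on `RequestType` give it the string conversions that `shed/cli.rs` relies on — `EnumString` parses the snake_case names, `Display` formats them back. A test-style sketch (assumes it lives inside this crate):

```rust
use std::str::FromStr;

use crate::types::RequestType; // from src/types.rs above

#[test]
fn request_type_roundtrip() {
    // EnumString: parse from the snake_case name
    let t = RequestType::from_str("sync_call").unwrap();
    assert_eq!(t, RequestType::SyncCall);
    // Display: format back to the same name
    assert_eq!(t.to_string(), "sync_call");
    // Default as declared in the enum
    assert_eq!(RequestType::default(), RequestType::Status);
}
```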