Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BOUN-1215 BOUN-1282): Finish load shedding, add various idle timers #4

Merged
merged 6 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ axum = "0.7"
backoff = { version = "0.4", features = ["tokio"] }
base64 = "0.22"
bytes = "1.6"
clap = { version = "4.5", features = ["derive", "string", "env"] }
clap_derive = "4.5"
chacha20poly1305 = "0.10"
cloudflare = { git = "https://github.com/cloudflare/cloudflare-rs.git", rev = "f14720e42184ee176a97676e85ef2d2d85bc3aae", default-features = false, features = [
"rustls-tls",
Expand All @@ -33,6 +35,7 @@ hickory-resolver = { version = "0.24", features = [
http = "1.1"
http-body = "1.0"
http-body-util = "0.1"
humantime = "2.1"
hyper = "1.4"
hyper-util = { version = "0.1", features = ["full"] }
instant-acme = { version = "0.7.1", default-features = false, features = [
Expand Down Expand Up @@ -71,13 +74,14 @@ strum_macros = "0.26"
sync_wrapper = "1.0"
systemstat = "0.2.3"
thiserror = "1.0"
tokio = { version = "1.40", features = ["full"] }
tokio = { version = "1.41", features = ["full"] }
tokio-util = { version = "0.7", features = ["full"] }
tokio-rustls = { version = "0.26.0", default-features = false, features = [
"tls12",
"logging",
"ring",
] }
tokio-io-timeout = "1.2"
tower = { version = "0.5", features = ["util"] }
tower-service = "0.3"
tracing = "0.1"
Expand Down
113 changes: 103 additions & 10 deletions src/http/body.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
use std::{
pin::{pin, Pin},
sync::atomic::{AtomicBool, Ordering},
task::{Context, Poll},
time::Duration,
};

use axum::body::Body;
use bytes::{Buf, Bytes};
use futures::Stream;
use futures_util::ready;
use http_body::{Body as HttpBody, Frame, SizeHint};
use http_body_util::{BodyExt, LengthLimitError, Limited};
use sync_wrapper::SyncWrapper;
use tokio::sync::oneshot::{self, Receiver, Sender};
use tokio::sync::{
mpsc,
oneshot::{self, Receiver, Sender},
};

use super::{calc_headers_size, Error};

Expand Down Expand Up @@ -79,6 +84,74 @@ impl Stream for SyncBodyDataStream {
}
}

/// Body that notifies that it has finished by sending a value over the provided channel.
/// Use AtomicBool flag to make sure we notify only once.
pub struct NotifyingBody<D, E, S: Clone + Unpin> {
inner: Pin<Box<dyn HttpBody<Data = D, Error = E> + Send + 'static>>,
tx: mpsc::Sender<S>,
sig: S,
sent: AtomicBool,
}

impl<D, E, S: Clone + Unpin> NotifyingBody<D, E, S> {
pub fn new<B>(inner: B, tx: mpsc::Sender<S>, sig: S) -> Self
where
B: HttpBody<Data = D, Error = E> + Send + 'static,
D: Buf,
{
Self {
inner: Box::pin(inner),
tx,
sig,
sent: AtomicBool::new(false),
}
}

fn notify(&self) {
if self
.sent
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
== Ok(false)
{
let _ = self.tx.try_send(self.sig.clone()).is_ok();
}
}
}

impl<D, E, S: Clone + Unpin> HttpBody for NotifyingBody<D, E, S>
where
D: Buf,
E: std::string::ToString,
{
type Data = D;
type Error = E;

fn poll_frame(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let poll = ready!(pin!(&mut self.inner).poll_frame(cx));
if poll.is_none() {
self.notify();
}

Poll::Ready(poll)
}

fn size_hint(&self) -> SizeHint {
self.inner.size_hint()
}

fn is_end_stream(&self) -> bool {
let end = self.inner.is_end_stream();
if end {
self.notify();
}

end
}
}

// Body that counts the bytes streamed
pub struct CountingBody<D, E> {
inner: Pin<Box<dyn HttpBody<Data = D, Error = E> + Send + 'static>>,
Expand Down Expand Up @@ -131,11 +204,11 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let poll = pin!(&mut self.inner).poll_frame(cx);
let poll = ready!(pin!(&mut self.inner).poll_frame(cx));

match &poll {
// There is still some data available
Poll::Ready(Some(v)) => match v {
Some(v) => match v {
Ok(buf) => {
// Normal data frame
if buf.is_data() {
Expand All @@ -160,17 +233,14 @@ where
},

// Nothing left
Poll::Ready(None) => {
None => {
// Make borrow checker happy
let x = self.bytes_sent;
self.finish(Ok(x));
}

// Do nothing
Poll::Pending => {}
}

poll
Poll::Ready(poll)
}

fn size_hint(&self) -> SizeHint {
Expand All @@ -184,7 +254,7 @@ mod test {
use http_body_util::BodyExt;

#[tokio::test]
async fn test_body_stream() {
async fn test_counting_body_stream() {
let data = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbl\
ahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahbla\
hfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoob\
Expand All @@ -206,7 +276,7 @@ mod test {
}

#[tokio::test]
async fn test_body_full() {
async fn test_counting_body_full() {
let data = vec![0; 512];
let buf = bytes::Bytes::from_iter(data.clone());
let body = http_body_util::Full::new(buf);
Expand All @@ -221,4 +291,27 @@ mod test {
let size = rx.await.unwrap().unwrap();
assert_eq!(size, data.len() as u64);
}

#[tokio::test]
async fn test_notifying_body() {
let data = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbl\
ahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahbla\
hfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoob\
arblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbla\
blahfoobarblahblah";

let stream = tokio_util::io::ReaderStream::new(&data[..]);
let body = axum::body::Body::from_stream(stream);

let sig = 357;
let (tx, mut rx) = mpsc::channel(10);
let body = NotifyingBody::new(body, tx, sig);

// Check that the body streams the same data back
let body = body.collect().await.unwrap().to_bytes().to_vec();
assert_eq!(body, data);

// Make sure we're notified
assert_eq!(sig, rx.recv().await.unwrap());
}
}
Loading