diff --git a/ipa-core/src/helpers/transport/stream/buffered.rs b/ipa-core/src/helpers/transport/stream/buffered.rs new file mode 100644 index 000000000..7efc12112 --- /dev/null +++ b/ipa-core/src/helpers/transport/stream/buffered.rs @@ -0,0 +1,311 @@ +use std::{ + mem, + num::NonZeroUsize, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures::Stream; +use pin_project::pin_project; + +use crate::helpers::BytesStream; + +/// An adaptor to buffer items coming from the upstream +/// [`BytesStream`](BytesStream) until the buffer is full, or the upstream is +/// done. This may need to be used when writing into HTTP streams as Hyper +/// does not provide any buffering functionality and we turn NODELAY on +#[pin_project] +pub struct BufferedBytesStream { + /// Inner stream to poll + #[pin] + inner: S, + /// Buffer of bytes pending release + buffer: Vec, + /// Number of bytes released per single poll. + /// All items except the last one are guaranteed to have + /// exactly this number of bytes written to them. + sz: usize, +} + +impl BufferedBytesStream { + fn new(inner: S, buf_size: NonZeroUsize) -> Self { + Self { + inner, + buffer: Vec::with_capacity(buf_size.get()), + sz: buf_size.get(), + } + } +} + +impl Stream for BufferedBytesStream { + type Item = S::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn take_next(buf: &mut Vec) -> Vec { + mem::replace(buf, Vec::with_capacity(buf.len())) + } + + let mut this = self.as_mut().project(); + loop { + // If we are at capacity, return what we have + if this.buffer.len() >= *this.sz { + // if we have more than we need in the buffer, split it + // otherwise, return the whole buffer to the reader + let next = if this.buffer.len() > *this.sz { + this.buffer.drain(..*this.sz).collect() + } else { + take_next(this.buffer) + }; + break Poll::Ready(Some(Ok(Bytes::from(next)))); + } + + match this.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(item)) => { + // Received next portion of data, buffer it + match item { + Ok(bytes) => { + this.buffer.extend(bytes); + } + Err(e) => { + break Poll::Ready(Some(Err(e))); + } + } + } + Poll::Ready(None) => { + // yield what we have because the upstream is done + let next = if this.buffer.is_empty() { + None + } else { + Some(Ok(Bytes::from(take_next(this.buffer)))) + }; + + break Poll::Ready(next); + } + Poll::Pending => { + // we don't have enough data in the buffer (otherwise we wouldn't be here) + break Poll::Pending; + } + } + } + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{ + cmp::min, + mem, + num::NonZeroUsize, + pin::Pin, + sync::{Arc, Mutex}, + task, + task::Poll, + }; + + use bytes::Bytes; + use futures::{stream::TryStreamExt, FutureExt, Stream, StreamExt}; + use pin_project::pin_project; + use proptest::{ + prop_compose, proptest, + strategy::{Just, Strategy}, + }; + use task::Context; + + use crate::{ + error::BoxError, helpers::transport::stream::buffered::BufferedBytesStream, + test_executor::run, + }; + + #[test] + fn success() { + run(|| async move { + verify_success(infallible_stream(11, 2), 3).await; + verify_success(infallible_stream(12, 3), 3).await; + verify_success(infallible_stream(12, 5), 12).await; + verify_success(infallible_stream(12, 12), 12).await; + verify_success(infallible_stream(24, 12), 12).await; + verify_success(infallible_stream(24, 12), 1).await; + }); + } + + #[test] + fn fails_on_first_error() { + run(|| async move { + let stream = fallible_stream(12, 3, 5); + let mut buffered = BufferedBytesStream::new(stream, NonZeroUsize::try_from(2).unwrap()); + let mut buf = Vec::new(); + while let Some(next) = buffered.next().await { + match next { + Ok(bytes) => { + assert_eq!(2, bytes.len()); + buf.extend(bytes); + } + Err(_) => { + break; + } + } + } + + // we could only receive 2 bytes from the stream and here is why. + // first read puts 3 bytes into the buffer and we take 2 bytes off it. + // second read does not have sufficient bytes in the buffer, and we need + // to read from the stream again. Next read results in an error and we + // return it immediately + assert_eq!(2, buf.len()); + }); + } + + #[test] + fn pending() { + let status = Arc::new(Mutex::new(vec![1, 2])); + let stream = futures::stream::poll_fn({ + let status = Arc::clone(&status); + move |_cx| { + let mut vec = status.lock().unwrap(); + if vec.is_empty() { + Poll::Pending + } else { + Poll::Ready(Some(Ok(Bytes::from(mem::take(&mut *vec))))) + } + } + }); + + let mut buffered = BufferedBytesStream::new(stream, NonZeroUsize::try_from(4).unwrap()); + let mut fut = std::pin::pin!(buffered.next()); + assert!(fut.as_mut().now_or_never().is_none()); + + status.lock().unwrap().extend([3, 4]); + let actual = fut.now_or_never().flatten().unwrap().unwrap(); + assert_eq!(Bytes::from(vec![1, 2, 3, 4]), actual); + } + + async fn verify_success(input: TestStream, chunk_size: usize) { + let total_size = input.total_size; + assert!(total_size >= chunk_size); + let expected = input.clone(); + let mut buffered = BufferedBytesStream::new(input, chunk_size.try_into().unwrap()); + + let mut last_chunk_size = None; + let mut actual = Vec::new(); + while let Ok(Some(bytes)) = buffered.try_next().await { + assert!(bytes.len() <= chunk_size); + // All chunks except the last one must be exactly of `chunk_size` size. + if let Some(last) = last_chunk_size { + assert_eq!(last, chunk_size); + } + last_chunk_size = Some(bytes.len()); + actual.extend(bytes); + } + + // compare with what the original stream returned + assert_eq!(actual.len(), total_size); + let expected = expected + .try_collect::>() + .await + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(expected, actual); + } + + #[derive(Debug, Clone)] + struct TestStream { + total_size: usize, + remaining: usize, + chunk: usize, + } + + #[pin_project] + struct FallibleTestStream { + #[pin] + inner: TestStream, + error_after: usize, + } + + fn infallible_stream(total_size: usize, chunk: usize) -> TestStream { + TestStream { + total_size, + remaining: total_size, + chunk, + } + } + + fn fallible_stream(total_size: usize, chunk: usize, error_after: usize) -> FallibleTestStream { + FallibleTestStream { + inner: TestStream { + total_size, + remaining: total_size, + chunk, + }, + error_after, + } + } + + impl Stream for TestStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.remaining == 0 { + return Poll::Ready(None); + } + let next_chunk_size = min(self.remaining, self.chunk); + let next_chunk = (0..next_chunk_size) + .map(|v| u8::try_from(v % 256).unwrap()) + .collect::>(); + + self.remaining -= next_chunk_size; + Poll::Ready(Some(Ok(Bytes::from(next_chunk)))) + } + } + + impl Stream for FallibleTestStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + match this.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(bytes))) => { + if this.inner.total_size - this.inner.remaining >= *this.error_after { + Poll::Ready(Some(Err("error".into()))) + } else { + Poll::Ready(Some(Ok(bytes))) + } + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + } + + prop_compose! { + fn arb_infallible_stream(max_size: u16) + (total_size in 1..max_size) + (total_size in Just(total_size), chunk in 1..total_size) + -> TestStream { + TestStream { + total_size: total_size as usize, + remaining: total_size as usize, + chunk: chunk as usize, + } + } + } + + fn stream_and_chunk() -> impl Strategy { + arb_infallible_stream(24231).prop_flat_map(|stream| { + let len = stream.total_size; + (Just(stream), 1..len) + }) + } + + proptest! { + #[test] + fn proptest_success((stream, chunk) in stream_and_chunk()) { + run(move || async move { + verify_success(stream, chunk).await; + }); + } + } +} diff --git a/ipa-core/src/helpers/transport/stream/mod.rs b/ipa-core/src/helpers/transport/stream/mod.rs index 5925b62f5..ac39fbb22 100644 --- a/ipa-core/src/helpers/transport/stream/mod.rs +++ b/ipa-core/src/helpers/transport/stream/mod.rs @@ -1,6 +1,8 @@ #[cfg(feature = "web-app")] mod axum_body; mod box_body; +#[allow(dead_code)] +mod buffered; mod collection; mod input;