Skip to content

Commit

Permalink
refactor(core): use BodyDataStream
Browse files Browse the repository at this point in the history
  • Loading branch information
fundon committed Sep 30, 2024
1 parent 3983c73 commit f6f514d
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 60 deletions.
42 changes: 1 addition & 41 deletions viz-core/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
};

use bytes::Bytes;
use futures_util::{Stream, TryStream, TryStreamExt};
use futures_util::{TryStream, TryStreamExt};
use http_body_util::{combinators::UnsyncBoxBody, BodyExt, BodyStream, Full, StreamBody};
use hyper::body::{Frame, Incoming, SizeHint};
use sync_wrapper::SyncWrapper;
Expand Down Expand Up @@ -53,7 +53,6 @@ impl Body {
B::Data: Into<Bytes>,
B::Error: Into<BoxError>,
{
// Copied from Axum, thanks.
let mut body = Some(body);
<dyn std::any::Any>::downcast_mut::<Option<UnsyncBoxBody<Bytes, Error>>>(&mut body)
.and_then(Option::take)
Expand Down Expand Up @@ -133,45 +132,6 @@ impl HttpBody for Body {
}
}

impl Stream for Body {
type Item = Result<Bytes, std::io::Error>;

#[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match match self.get_mut() {
Self::Empty => return Poll::Ready(None),
Self::Full(inner) => Pin::new(inner)
.poll_frame(cx)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
Self::Boxed(inner) => Pin::new(inner)
.get_pin_mut()
.poll_frame(cx)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
Self::Incoming(inner) => Pin::new(inner)
.poll_frame(cx)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
} {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(frame)) => Poll::Ready(frame.into_data().map(Ok).ok()),
}
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let sh = match self {
Self::Empty => return (0, Some(0)),
Self::Full(inner) => inner.size_hint(),
Self::Boxed(_) => return (0, None),
Self::Incoming(inner) => inner.size_hint(),
};
(
usize::try_from(sh.lower()).unwrap_or(usize::MAX),
sh.upper().map(|v| usize::try_from(v).unwrap_or(usize::MAX)),
)
}
}

impl From<()> for Body {
fn from((): ()) -> Self {
Self::Empty
Expand Down
7 changes: 6 additions & 1 deletion viz-core/src/middleware/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
use std::str::FromStr;

use async_compression::tokio::bufread;
use futures_util::TryStreamExt;
use http_body_util::BodyExt;
use tokio_util::io::{ReaderStream, StreamReader};

use crate::{
Expand Down Expand Up @@ -78,7 +80,10 @@ impl<T: IntoResponse> IntoResponse for Compress<T> {
match self.algo {
ContentCoding::Gzip | ContentCoding::Deflate | ContentCoding::Brotli => {
res = res.map(|body| {
let body = StreamReader::new(body);
let body = StreamReader::new(
body.into_data_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
);
if self.algo == ContentCoding::Gzip {
Body::from_stream(ReaderStream::new(bufread::GzipEncoder::new(body)))
} else if self.algo == ContentCoding::Deflate {
Expand Down
7 changes: 5 additions & 2 deletions viz-core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,10 @@ impl RequestExt for Request {
.ok_or(PayloadError::MissingBoundary)?
.as_str();

Ok(Multipart::new(self.incoming()?, boundary))
Ok(Multipart::new(
self.incoming()?.into_data_stream(),
boundary,
))
}

#[cfg(feature = "state")]
Expand Down Expand Up @@ -495,7 +498,7 @@ impl RequestLimitsExt for Request {
.ok_or(PayloadError::MissingBoundary)?
.as_str();
Ok(Multipart::with_limits(
self.incoming()?,
self.incoming()?.into_data_stream(),
boundary,
self.extensions()
.get::<std::sync::Arc<crate::types::MultipartLimits>>()
Expand Down
3 changes: 2 additions & 1 deletion viz-core/src/types/multipart.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Represents a Multipart extractor.
use form_data::FormData;
use http_body_util::BodyDataStream;

use crate::{Body, Error, FromRequest, IntoResponse, Request, RequestExt, Response, StatusCode};

Expand All @@ -9,7 +10,7 @@ use super::{Payload, PayloadError};
pub use form_data::{Error as MultipartError, Limits as MultipartLimits};

/// Extracts the data from the multipart body of a request.
pub type Multipart<T = Body> = FormData<T>;
pub type Multipart<T = BodyDataStream<Body>> = FormData<T>;

impl<T> Payload for Multipart<T> {
const NAME: &'static str = "multipart";
Expand Down
4 changes: 2 additions & 2 deletions viz-core/tests/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use test::Bencher;

use futures_util::{stream, Stream, StreamExt};
use headers::{ContentDisposition, ContentType, HeaderMapExt};
use http_body_util::{BodyExt, Full};
use http_body_util::{BodyDataStream, BodyExt, Full};
use serde::{Deserialize, Serialize};
use viz_core::{
header::{CONTENT_DISPOSITION, CONTENT_LOCATION, LOCATION},
Expand Down Expand Up @@ -78,7 +78,7 @@ async fn response_ext() -> Result<()> {

let resp = Response::stream(stream::repeat("viz").take(2).map(Result::<_, Error>::Ok));
assert!(resp.ok());
let body: Body = resp.into_body();
let body: BodyDataStream<_> = resp.into_body().into_data_stream();
assert_eq!(Stream::size_hint(&body), (0, None));
let (item, stream) = body.into_future().await;
assert_eq!(item.unwrap().unwrap().to_vec(), b"viz");
Expand Down
33 changes: 21 additions & 12 deletions viz-test/tests/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,26 @@ async fn incoming_stream() -> Result<()> {
use viz::Router;
use viz_test::TestServer;

let empty = Body::Empty;
let empty = Body::Empty
.into_data_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
assert_eq!(Stream::size_hint(&empty), (0, Some(0)));
let mut reader =
TryStreamExt::map_err(empty, |e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read();
let mut reader = empty.into_async_read();
let mut buf = Vec::new();
reader.read_to_end(&mut buf).await?;
assert!(buf.is_empty());

let router = Router::new()
.post("/login-empty", |mut req: Request| async move {
let mut body = req.incoming()?;
let mut body = req.incoming()?.into_data_stream();
let size_hint = Stream::size_hint(&body);
assert_eq!(size_hint.0, 0);
assert_eq!(size_hint.1, Some(0));
assert!(body.next().await.is_none());
Ok(())
})
.post("/login", |mut req: Request| async move {
let mut body = req.incoming()?;
let mut body = req.incoming()?.into_data_stream();
let size_hint = Stream::size_hint(&body);
assert_eq!(size_hint.0, 12);
assert_eq!(size_hint.1, Some(12));
Expand Down Expand Up @@ -195,35 +195,44 @@ async fn outgoing_body() -> Result<()> {
async fn outgoing_stream() -> Result<()> {
use futures_util::{AsyncReadExt, Stream, StreamExt, TryStreamExt};

let empty = Body::<Bytes>::Empty;
let empty = Body::<Bytes>::Empty
.into_data_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
assert_eq!(Stream::size_hint(&empty), (0, Some(0)));
let mut reader = empty.into_async_read();
let mut buf = Vec::new();
reader.read_to_end(&mut buf).await?;
assert!(buf.is_empty());

let full_none = Body::from(Full::new(Bytes::new()));
let full_none = Body::from(Full::new(Bytes::new()))
.into_data_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
assert_eq!(Stream::size_hint(&full_none), (0, Some(0)));
let mut reader = full_none.into_async_read();
let mut buf = Vec::new();
reader.read_to_end(&mut buf).await?;
assert!(buf.is_empty());

let mut full_some: Body = Full::new(Bytes::from(vec![1, 0, 2, 4])).into();
let mut full_some = Full::new(Bytes::from(vec![1, 0, 2, 4]))
.into_data_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
assert_eq!(Stream::size_hint(&full_some), (4, Some(4)));
assert_eq!(full_some.next().await.unwrap().unwrap(), vec![1, 0, 2, 4]);
assert_eq!(Stream::size_hint(&full_some), (0, Some(0)));
assert!(full_some.next().await.is_none());

let boxed: Body = UnsyncBoxBody::new(Full::new(Bytes::new()).map_err(Into::into)).into();
let boxed = UnsyncBoxBody::new(Full::new(Bytes::new()))
.into_data_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
assert_eq!(Stream::size_hint(&boxed), (0, None));
let mut reader = boxed.into_async_read();
let mut buf = Vec::new();
reader.read_to_end(&mut buf).await?;
assert!(buf.is_empty());

let mut boxed: Body =
UnsyncBoxBody::new(Full::new(Bytes::from(vec![2, 0, 4, 8])).map_err(Into::into)).into();
let mut boxed = UnsyncBoxBody::new(Full::new(Bytes::from(vec![2, 0, 4, 8])))
.into_data_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
assert_eq!(Stream::size_hint(&boxed), (0, None));
assert_eq!(boxed.next().await.unwrap().unwrap(), vec![2, 0, 4, 8]);
assert_eq!(Stream::size_hint(&boxed), (0, None));
Expand Down
2 changes: 1 addition & 1 deletion viz-test/tests/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async fn response_ext() -> Result<()> {

let resp = Response::stream(stream::repeat("viz").take(2).map(Result::<_, Error>::Ok));
assert!(resp.ok());
let body: Body = resp.into_body();
let body = resp.into_body().into_data_stream();
assert_eq!(Stream::size_hint(&body), (0, None));
let (item, stream) = body.into_future().await;
assert_eq!(item.unwrap().unwrap().to_vec(), b"viz");
Expand Down

0 comments on commit f6f514d

Please sign in to comment.