Skip to content

Commit

Permalink
keep CompleteMultipartUpload alive (#141)
Browse files Browse the repository at this point in the history
* keep CompleteMultipartUpload alive

* use s3_resp.headers in CompleteMultipartUpload

* propagate error in keepalive body

* unify serialize_error

* set transfer-encoding and trailer header

* add keep_alive_body unit tests

---------

Co-authored-by: Nugine <[email protected]>
  • Loading branch information
lperlaki and Nugine authored May 31, 2024
1 parent 7b949d0 commit dda3c09
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 121 deletions.
56 changes: 40 additions & 16 deletions codegen/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes) {
}

if is_xml_output(ty) {
g!("http::set_xml_body(&mut res, &x)?;");
if op.name == "CompleteMultipartUpload" {
g!("http::set_xml_body_no_decl(&mut res, &x)?;");
} else {
g!("http::set_xml_body(&mut res, &x)?;");
}
} else if let Some(field) = ty.fields.iter().find(|x| x.position == "payload") {
match field.type_.as_str() {
"Policy" => {
Expand Down Expand Up @@ -652,25 +656,45 @@ fn codegen_op_http_call(op: &Operation) {
g!("let overrided_headers = super::get_object::extract_overrided_response_headers(&s3_req)?;");
}

g!("let result = s3.{method}(s3_req).await;");
if op.name == "CompleteMultipartUpload" {
g!("let s3 = s3.clone();");
g!("let fut = async move {{");
g!("let result = s3.{method}(s3_req).await;");
g!("match result {{");
glines![
"Ok(s3_resp) => {
let mut resp = Self::serialize_http(s3_resp.output)?;
resp.headers.extend(s3_resp.headers);
Ok(resp)
}"
];
g!("Err(err) => super::serialize_error(err, true).map_err(Into::into),");
g!("}}");
g!("}};");
g!("let mut resp = http::Response::with_status(http::StatusCode::OK);");
g!("http::set_keep_alive_xml_body(&mut resp, sync_wrapper::SyncFuture::new(fut), std::time::Duration::from_millis(100))?;");
g!("http::add_opt_header(&mut resp, \"trailer\", Some([X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED.as_str(), X_AMZ_EXPIRATION.as_str(), X_AMZ_REQUEST_CHARGED.as_str(), X_AMZ_SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID.as_str(), X_AMZ_SERVER_SIDE_ENCRYPTION.as_str(), X_AMZ_VERSION_ID.as_str()].join(\",\")))?;");
} else {
g!("let result = s3.{method}(s3_req).await;");

glines![
"let s3_resp = match result {"
" Ok(val) => val,"
" Err(err) => return super::serialize_error(err),"
"};"
];
glines![
"let s3_resp = match result {"
" Ok(val) => val,"
" Err(err) => return super::serialize_error(err, false),"
"};"
];

g!("let mut resp = Self::serialize_http(s3_resp.output)?;");
g!("let mut resp = Self::serialize_http(s3_resp.output)?;");

if op.name == "GetObject" {
g!("resp.headers.extend(overrided_headers);");
g!("super::get_object::merge_custom_headers(&mut resp, s3_resp.headers);");
} else {
g!("resp.headers.extend(s3_resp.headers);");
}
if op.name == "GetObject" {
g!("resp.headers.extend(overrided_headers);");
g!("super::get_object::merge_custom_headers(&mut resp, s3_resp.headers);");
} else {
g!("resp.headers.extend(s3_resp.headers);");
}

g!("resp.extensions.extend(s3_resp.extensions);");
g!("resp.extensions.extend(s3_resp.extensions);");
}
g!("Ok(resp)");

g!("}}");
Expand Down
3 changes: 3 additions & 0 deletions crates/s3s/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,8 @@ transform-stream = "0.3.0"
urlencoding = "2.1.3"
zeroize = "1.6.0"

sync_wrapper = { version = "1.0.0", default-features = false }
tokio = { version = "1.31.0", features = ["time"] }

[dev-dependencies]
tokio = { version = "1.31.0", features = ["full"] }
30 changes: 30 additions & 0 deletions crates/s3s/src/http/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use crate::dto::SelectObjectContentEventStream;
use crate::dto::{Metadata, StreamingBlob, Timestamp, TimestampFormat};
use crate::error::{S3Error, S3Result};
use crate::http::{HeaderName, HeaderValue};
use crate::keep_alive_body::KeepAliveBody;
use crate::utils::format::fmt_timestamp;
use crate::xml;
use crate::StdError;

use std::convert::Infallible;
use std::fmt::Write as _;
Expand Down Expand Up @@ -105,6 +107,34 @@ pub fn set_xml_body<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result
Ok(())
}

#[allow(clippy::declare_interior_mutable_const)]
const TRANSFER_ENCODING_CHUNKED: HeaderValue = HeaderValue::from_static("chunked");

pub fn set_keep_alive_xml_body(
res: &mut Response,
fut: impl std::future::Future<Output = Result<Response, StdError>> + Send + Sync + 'static,
duration: std::time::Duration,
) -> S3Result {
let mut buf = Vec::with_capacity(40);
let mut ser = xml::Serializer::new(&mut buf);
ser.decl().map_err(S3Error::internal_error)?;

res.body = Body::http_body(KeepAliveBody::with_initial_body(fut, buf.into(), duration));
res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
res.headers
.insert(hyper::header::TRANSFER_ENCODING, TRANSFER_ENCODING_CHUNKED);
Ok(())
}

pub fn set_xml_body_no_decl<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result {
let mut buf = Vec::with_capacity(256);
let mut ser = xml::Serializer::new(&mut buf);
val.serialize(&mut ser).map_err(S3Error::internal_error)?;
res.body = Body::from(buf);
res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
Ok(())
}

pub fn set_stream_body(res: &mut Response, stream: StreamingBlob) {
res.body = Body::from(stream);
}
Expand Down
163 changes: 163 additions & 0 deletions crates/s3s/src/keep_alive_body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use bytes::Bytes;
use http_body::{Body, Frame};
use tokio::time::Interval;

use crate::{http::Response, StdError};

// sends whitespace while the future is pending
pin_project_lite::pin_project! {

pub struct KeepAliveBody<F> {
#[pin]
inner: F,
initial_body: Option<Bytes>,
response: Option<Response>,
interval: Interval,
done: bool,
}
}
impl<F> KeepAliveBody<F> {
pub fn new(inner: F, interval: Duration) -> Self {
Self {
inner,
initial_body: None,
response: None,
interval: tokio::time::interval(interval),
done: false,
}
}

pub fn with_initial_body(inner: F, initial_body: Bytes, interval: Duration) -> Self {
Self {
inner,
initial_body: Some(initial_body),
response: None,
interval: tokio::time::interval(interval),
done: false,
}
}
}

impl<F> Body for KeepAliveBody<F>
where
F: Future<Output = Result<Response, StdError>>,
{
type Data = Bytes;

type Error = StdError;

fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
if self.done {
return Poll::Ready(None);
}
let mut this = self.project();
if let Some(initial_body) = this.initial_body.take() {
cx.waker().wake_by_ref();
return Poll::Ready(Some(Ok(Frame::data(initial_body))));
}
loop {
if let Some(response) = &mut this.response {
let frame = std::task::ready!(Pin::new(&mut response.body).poll_frame(cx)?);
if let Some(frame) = frame {
return Poll::Ready(Some(Ok(frame)));
}
*this.done = true;
return Poll::Ready(Some(Ok(Frame::trailers(std::mem::take(&mut response.headers)))));
}
match this.inner.as_mut().poll(cx) {
Poll::Ready(response) => match response {
Ok(response) => {
*this.response = Some(response);
}
Err(e) => {
*this.done = true;
return Poll::Ready(Some(Err(e)));
}
},
Poll::Pending => match this.interval.poll_tick(cx) {
Poll::Ready(_) => return Poll::Ready(Some(Ok(Frame::data(Bytes::from_static(b" "))))),
Poll::Pending => return Poll::Pending,
},
}
}
}

fn is_end_stream(&self) -> bool {
self.done
}
}

#[cfg(test)]
mod tests {
use http_body_util::BodyExt;
use hyper::{header::HeaderValue, StatusCode};

use super::*;

#[tokio::test]
async fn keep_alive_body() {
let body = KeepAliveBody::with_initial_body(
async {
let mut res = Response::with_status(StatusCode::OK);
res.body = Bytes::from_static(b" world").into();
res.headers.insert("key", HeaderValue::from_static("value"));
Ok(res)
},
Bytes::from_static(b"hello"),
Duration::from_secs(1),
);

let aggregated = body.collect().await.unwrap();

assert_eq!(aggregated.trailers().unwrap().get("key").unwrap(), "value");

let buf = aggregated.to_bytes();

assert_eq!(buf, b"hello world".as_slice());
}

#[tokio::test]
async fn keep_alive_body_no_initial() {
let body = KeepAliveBody::new(
async {
let mut res = Response::with_status(StatusCode::OK);
res.body = Bytes::from_static(b"hello world").into();
Ok(res)
},
Duration::from_secs(1),
);

let aggregated = body.collect().await.unwrap();

let buf = aggregated.to_bytes();

assert_eq!(buf, b"hello world".as_slice());
}

#[tokio::test]
async fn keep_alive_body_fill_withespace() {
let body = KeepAliveBody::new(
async {
tokio::time::sleep(Duration::from_millis(50)).await;

let mut res = Response::with_status(StatusCode::OK);
res.body = Bytes::from_static(b"hello world").into();
Ok(res)
},
Duration::from_millis(10),
);

let aggregated = body.collect().await.unwrap();

let buf = aggregated.to_bytes();

assert_eq!(buf, b" hello world".as_slice());
}
}
1 change: 1 addition & 0 deletions crates/s3s/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub mod service;
pub mod stream;

pub mod checksum;
pub mod keep_alive_body;

pub use self::error::*;
pub use self::http::Body;
Expand Down
Loading

0 comments on commit dda3c09

Please sign in to comment.