Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

keep CompleteMultipartUpload alive #141

Merged
merged 9 commits into from
May 31, 2024
Merged
56 changes: 40 additions & 16 deletions codegen/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes) {
}

if is_xml_output(ty) {
g!("http::set_xml_body(&mut res, &x)?;");
if op.name == "CompleteMultipartUpload" {
g!("http::set_xml_body_no_decl(&mut res, &x)?;");
} else {
g!("http::set_xml_body(&mut res, &x)?;");
}
} else if let Some(field) = ty.fields.iter().find(|x| x.position == "payload") {
match field.type_.as_str() {
"Policy" => {
Expand Down Expand Up @@ -652,25 +656,45 @@ fn codegen_op_http_call(op: &Operation) {
g!("let overrided_headers = super::get_object::extract_overrided_response_headers(&s3_req)?;");
}

g!("let result = s3.{method}(s3_req).await;");
if op.name == "CompleteMultipartUpload" {
g!("let s3 = s3.clone();");
g!("let fut = async move {{");
g!("let result = s3.{method}(s3_req).await;");
g!("match result {{");
glines![
"Ok(s3_resp) => {
let mut resp = Self::serialize_http(s3_resp.output)?;
resp.headers.extend(s3_resp.headers);
Ok(resp)
}"
];
g!("Err(err) => super::serialize_error(err, true).map_err(Into::into),");
g!("}}");
g!("}};");
g!("let mut resp = http::Response::with_status(http::StatusCode::OK);");
g!("http::set_keep_alive_xml_body(&mut resp, sync_wrapper::SyncFuture::new(fut), std::time::Duration::from_millis(100))?;");
g!("http::add_opt_header(&mut resp, \"trailer\", Some([X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED.as_str(), X_AMZ_EXPIRATION.as_str(), X_AMZ_REQUEST_CHARGED.as_str(), X_AMZ_SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID.as_str(), X_AMZ_SERVER_SIDE_ENCRYPTION.as_str(), X_AMZ_VERSION_ID.as_str()].join(\",\")))?;");
} else {
g!("let result = s3.{method}(s3_req).await;");

glines![
"let s3_resp = match result {"
" Ok(val) => val,"
" Err(err) => return super::serialize_error(err),"
"};"
];
glines![
"let s3_resp = match result {"
" Ok(val) => val,"
" Err(err) => return super::serialize_error(err, false),"
"};"
];

g!("let mut resp = Self::serialize_http(s3_resp.output)?;");
g!("let mut resp = Self::serialize_http(s3_resp.output)?;");

if op.name == "GetObject" {
g!("resp.headers.extend(overrided_headers);");
g!("super::get_object::merge_custom_headers(&mut resp, s3_resp.headers);");
} else {
g!("resp.headers.extend(s3_resp.headers);");
}
if op.name == "GetObject" {
g!("resp.headers.extend(overrided_headers);");
g!("super::get_object::merge_custom_headers(&mut resp, s3_resp.headers);");
} else {
g!("resp.headers.extend(s3_resp.headers);");
}

g!("resp.extensions.extend(s3_resp.extensions);");
g!("resp.extensions.extend(s3_resp.extensions);");
}
g!("Ok(resp)");

g!("}}");
Expand Down
3 changes: 3 additions & 0 deletions crates/s3s/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,8 @@ transform-stream = "0.3.0"
urlencoding = "2.1.3"
zeroize = "1.6.0"

sync_wrapper = { version = "1.0.0", default-features = false }
tokio = { version = "1.31.0", features = ["time"] }

[dev-dependencies]
tokio = { version = "1.31.0", features = ["full"] }
30 changes: 30 additions & 0 deletions crates/s3s/src/http/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use crate::dto::SelectObjectContentEventStream;
use crate::dto::{Metadata, StreamingBlob, Timestamp, TimestampFormat};
use crate::error::{S3Error, S3Result};
use crate::http::{HeaderName, HeaderValue};
use crate::keep_alive_body::KeepAliveBody;
use crate::utils::format::fmt_timestamp;
use crate::xml;
use crate::StdError;

use std::convert::Infallible;
use std::fmt::Write as _;
Expand Down Expand Up @@ -105,6 +107,34 @@ pub fn set_xml_body<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result
Ok(())
}

#[allow(clippy::declare_interior_mutable_const)]
const TRANSFER_ENCODING_CHUNKED: HeaderValue = HeaderValue::from_static("chunked");

pub fn set_keep_alive_xml_body(
res: &mut Response,
fut: impl std::future::Future<Output = Result<Response, StdError>> + Send + Sync + 'static,
duration: std::time::Duration,
) -> S3Result {
let mut buf = Vec::with_capacity(40);
let mut ser = xml::Serializer::new(&mut buf);
ser.decl().map_err(S3Error::internal_error)?;

res.body = Body::http_body(KeepAliveBody::with_initial_body(fut, buf.into(), duration));
res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
res.headers
.insert(hyper::header::TRANSFER_ENCODING, TRANSFER_ENCODING_CHUNKED);
Ok(())
}

pub fn set_xml_body_no_decl<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result {
let mut buf = Vec::with_capacity(256);
let mut ser = xml::Serializer::new(&mut buf);
val.serialize(&mut ser).map_err(S3Error::internal_error)?;
res.body = Body::from(buf);
res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
Ok(())
}

pub fn set_stream_body(res: &mut Response, stream: StreamingBlob) {
res.body = Body::from(stream);
}
Expand Down
163 changes: 163 additions & 0 deletions crates/s3s/src/keep_alive_body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use bytes::Bytes;
use http_body::{Body, Frame};
use tokio::time::Interval;

use crate::{http::Response, StdError};

// sends whitespace while the future is pending
pin_project_lite::pin_project! {

pub struct KeepAliveBody<F> {
#[pin]
inner: F,
initial_body: Option<Bytes>,
response: Option<Response>,
interval: Interval,
done: bool,
}
}
impl<F> KeepAliveBody<F> {
pub fn new(inner: F, interval: Duration) -> Self {
Self {
inner,
initial_body: None,
response: None,
interval: tokio::time::interval(interval),
done: false,
}
}

pub fn with_initial_body(inner: F, initial_body: Bytes, interval: Duration) -> Self {
Self {
inner,
initial_body: Some(initial_body),
response: None,
interval: tokio::time::interval(interval),
done: false,
}
}
}

impl<F> Body for KeepAliveBody<F>
where
F: Future<Output = Result<Response, StdError>>,
{
type Data = Bytes;

type Error = StdError;

fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
if self.done {
return Poll::Ready(None);
}
let mut this = self.project();
if let Some(initial_body) = this.initial_body.take() {
cx.waker().wake_by_ref();
return Poll::Ready(Some(Ok(Frame::data(initial_body))));
}
loop {
if let Some(response) = &mut this.response {
let frame = std::task::ready!(Pin::new(&mut response.body).poll_frame(cx)?);
if let Some(frame) = frame {
return Poll::Ready(Some(Ok(frame)));
}
*this.done = true;
return Poll::Ready(Some(Ok(Frame::trailers(std::mem::take(&mut response.headers)))));
}
match this.inner.as_mut().poll(cx) {
Poll::Ready(response) => match response {
Ok(response) => {
*this.response = Some(response);
}
Err(e) => {
*this.done = true;
return Poll::Ready(Some(Err(e)));
}
},
Poll::Pending => match this.interval.poll_tick(cx) {
Poll::Ready(_) => return Poll::Ready(Some(Ok(Frame::data(Bytes::from_static(b" "))))),
Poll::Pending => return Poll::Pending,
},
}
}
}

fn is_end_stream(&self) -> bool {
self.done
}
}

#[cfg(test)]
mod tests {
use http_body_util::BodyExt;
use hyper::{header::HeaderValue, StatusCode};

use super::*;

#[tokio::test]
async fn keep_alive_body() {
let body = KeepAliveBody::with_initial_body(
async {
let mut res = Response::with_status(StatusCode::OK);
res.body = Bytes::from_static(b" world").into();
res.headers.insert("key", HeaderValue::from_static("value"));
Ok(res)
},
Bytes::from_static(b"hello"),
Duration::from_secs(1),
);

let aggregated = body.collect().await.unwrap();

assert_eq!(aggregated.trailers().unwrap().get("key").unwrap(), "value");

let buf = aggregated.to_bytes();

assert_eq!(buf, b"hello world".as_slice());
}

#[tokio::test]
async fn keep_alive_body_no_initial() {
let body = KeepAliveBody::new(
async {
let mut res = Response::with_status(StatusCode::OK);
res.body = Bytes::from_static(b"hello world").into();
Ok(res)
},
Duration::from_secs(1),
);

let aggregated = body.collect().await.unwrap();

let buf = aggregated.to_bytes();

assert_eq!(buf, b"hello world".as_slice());
}

#[tokio::test]
async fn keep_alive_body_fill_withespace() {
let body = KeepAliveBody::new(
async {
tokio::time::sleep(Duration::from_millis(50)).await;

let mut res = Response::with_status(StatusCode::OK);
res.body = Bytes::from_static(b"hello world").into();
Ok(res)
},
Duration::from_millis(10),
);

let aggregated = body.collect().await.unwrap();

let buf = aggregated.to_bytes();

assert_eq!(buf, b" hello world".as_slice());
}
}
1 change: 1 addition & 0 deletions crates/s3s/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub mod service;
pub mod stream;

pub mod checksum;
pub mod keep_alive_body;

pub use self::error::*;
pub use self::http::Body;
Expand Down
Loading