Skip to content

Commit

Permalink
s3s: error: add headers field
Browse files Browse the repository at this point in the history
  • Loading branch information
Nugine committed Mar 7, 2024
1 parent a4aa637 commit 6d8a5c3
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
17 changes: 17 additions & 0 deletions crates/s3s/src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::fmt;
use std::io::Write;
use std::str::FromStr;

use hyper::HeaderMap;
use hyper::StatusCode;

pub type StdError = Box<dyn std::error::Error + Send + Sync + 'static>;
Expand All @@ -26,6 +27,7 @@ struct Inner {
request_id: Option<String>,
status_code: Option<StatusCode>,
source: Option<StdError>,
headers: Option<HeaderMap>,
}

impl S3Error {
Expand All @@ -38,6 +40,7 @@ impl S3Error {
request_id: None,
status_code: None,
source: None,
headers: None,
}))
}

Expand Down Expand Up @@ -75,6 +78,10 @@ impl S3Error {
self.0.status_code = Some(val);
}

pub fn set_headers(&mut self, val: HeaderMap) {
self.0.headers = Some(val);
}

#[must_use]
pub fn code(&self) -> &S3ErrorCode {
&self.0.code
Expand All @@ -100,6 +107,16 @@ impl S3Error {
self.0.status_code.or_else(|| self.0.code.status_code())
}

#[must_use]
pub fn headers(&self) -> Option<&HeaderMap> {
self.0.headers.as_ref()
}

#[must_use]
pub(crate) fn take_headers(&mut self) -> Option<HeaderMap> {
self.0.headers.take()
}

#[must_use]
pub fn internal_error<E>(source: E) -> Self
where
Expand Down
11 changes: 7 additions & 4 deletions crates/s3s/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,14 @@ fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
}
}

fn serialize_error(x: S3Error) -> S3Result<Response> {
let status = x.status_code().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
fn serialize_error(mut e: S3Error) -> S3Result<Response> {
let status = e.status_code().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let mut res = Response::with_status(status);
http::set_xml_body(&mut res, &x)?;
drop(x);
http::set_xml_body(&mut res, &e)?;
if let Some(headers) = e.take_headers() {
res.headers = headers;
}
drop(e);
Ok(res)
}

Expand Down
29 changes: 29 additions & 0 deletions crates/s3s/src/ops/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,32 @@ fn track_future_size() {
assert_eq!(size, expected, "{name:?} size changed: prev {expected}, now {size}");
}
}

#[test]
fn error_custom_headers() {
fn redirect307(location: &str) -> S3Error {
let mut err = S3Error::new(S3ErrorCode::TemporaryRedirect);

err.set_headers({
let mut headers = HeaderMap::new();
headers.insert(crate::header::LOCATION, location.parse().unwrap());
headers
});

err
}

let res = serialize_error(redirect307("http://example.com")).unwrap();
assert_eq!(res.status, StatusCode::TEMPORARY_REDIRECT);
assert_eq!(res.headers.get("location").unwrap(), "http://example.com");

let body = res.body.bytes().unwrap();
let body = std::str::from_utf8(&body).unwrap();
assert_eq!(
body,
concat!(
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>",
"<Error><Code>TemporaryRedirect</Code></Error>"
)
);
}

0 comments on commit 6d8a5c3

Please sign in to comment.