Skip to content

Commit

Permalink
convert multer's Field headers to http 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed Nov 22, 2023
1 parent f93a242 commit b9983b6
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 8 deletions.
3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,3 @@ resolver = "2"
[patch.crates-io]
# for http 1.0. PR to update is merged but not published
headers = { git = "https://github.com/hyperium/headers", rev = "4400aa90c47a7" }

# for http 1.0. PR=https://github.com/rousan/multer-rs/pull/59
multer = { git = "https://github.com/davidpdrsn/multer-rs", rev = "abe0e3a42a1fc" }
22 changes: 19 additions & 3 deletions axum-extra/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use axum::{
use futures_util::stream::Stream;
use http::{
header::{HeaderMap, CONTENT_TYPE},
Request, StatusCode,
HeaderName, HeaderValue, Request, StatusCode,
};
use std::{
error::Error,
Expand Down Expand Up @@ -115,7 +115,22 @@ impl Multipart {
.map_err(MultipartError::from_multer)?;

if let Some(field) = field {
Ok(Some(Field { inner: field }))
// multer still uses http 0.2 which means we cannot directly expose
// `multer::Field::headers`. Instead we have to eagerly convert the headers into http
// 1.0
//
// When the next major version of multer is published we can remove this.
let mut headers = HeaderMap::with_capacity(field.headers().len());
headers.extend(field.headers().clone().into_iter().map(|(name, value)| {
let name = name.map(|name| HeaderName::from_bytes(name.as_ref()).unwrap());
let value = HeaderValue::from_bytes(value.as_ref()).unwrap();
(name, value)
}));

Ok(Some(Field {
inner: field,
headers,
}))
} else {
Ok(None)
}
Expand All @@ -134,6 +149,7 @@ impl Multipart {
#[derive(Debug)]
pub struct Field {
inner: multer::Field<'static>,
headers: HeaderMap,
}

impl Stream for Field {
Expand Down Expand Up @@ -168,7 +184,7 @@ impl Field {

/// Get a map of headers as [`HeaderMap`].
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
&self.headers
}

/// Get the full data of the field as [`Bytes`].
Expand Down
25 changes: 23 additions & 2 deletions axum/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use axum_core::response::{IntoResponse, Response};
use axum_core::RequestExt;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use http::HeaderName;
use http::HeaderValue;
use http::StatusCode;
use std::error::Error;
use std::{
Expand Down Expand Up @@ -87,8 +89,21 @@ impl Multipart {
.map_err(MultipartError::from_multer)?;

if let Some(field) = field {
// multer still uses http 0.2 which means we cannot directly expose
// `multer::Field::headers`. Instead we have to eagerly convert the headers into http
// 1.0
//
// When the next major version of multer is published we can remove this.
let mut headers = HeaderMap::with_capacity(field.headers().len());
headers.extend(field.headers().clone().into_iter().map(|(name, value)| {
let name = name.map(|name| HeaderName::from_bytes(name.as_ref()).unwrap());
let value = HeaderValue::from_bytes(value.as_ref()).unwrap();
(name, value)
}));

Ok(Some(Field {
inner: field,
headers,
_multipart: self,
}))
} else {
Expand All @@ -101,6 +116,7 @@ impl Multipart {
#[derive(Debug)]
pub struct Field<'a> {
inner: multer::Field<'static>,
headers: HeaderMap,
// multer requires there to only be one live `multer::Field` at any point. This enforces that
// statically, which multer does not do, it returns an error instead.
_multipart: &'a mut Multipart,
Expand Down Expand Up @@ -138,7 +154,7 @@ impl<'a> Field<'a> {

/// Get a map of headers as [`HeaderMap`].
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
&self.headers
}

/// Get the full data of the field as [`Bytes`].
Expand Down Expand Up @@ -320,6 +336,7 @@ mod tests {

assert_eq!(field.file_name().unwrap(), FILE_NAME);
assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
assert_eq!(field.headers()["foo"], "bar");
assert_eq!(field.bytes().await.unwrap(), BYTES);

assert!(multipart.next_field().await.unwrap().is_none());
Expand All @@ -334,7 +351,11 @@ mod tests {
reqwest::multipart::Part::bytes(BYTES)
.file_name(FILE_NAME)
.mime_str(CONTENT_TYPE)
.unwrap(),
.unwrap()
.headers(reqwest::header::HeaderMap::from_iter([(
reqwest::header::HeaderName::from_static("foo"),
reqwest::header::HeaderValue::from_static("bar"),
)])),
);

client.post("/").multipart(form).send().await;
Expand Down

0 comments on commit b9983b6

Please sign in to comment.