From b9983b6bc6d202f94f029bed6c6fa685aa0818c8 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 22 Nov 2023 23:03:41 +0100 Subject: [PATCH] convert multer's Field headers to http 1.0 --- Cargo.toml | 3 --- axum-extra/src/extract/multipart.rs | 22 +++++++++++++++++++--- axum/src/extract/multipart.rs | 25 +++++++++++++++++++++++-- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a6e3812efd..f9e9c8b59e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,3 @@ resolver = "2" [patch.crates-io] # for http 1.0. PR to update is merged but not published headers = { git = "https://github.com/hyperium/headers", rev = "4400aa90c47a7" } - -# for http 1.0. PR=https://github.com/rousan/multer-rs/pull/59 -multer = { git = "https://github.com/davidpdrsn/multer-rs", rev = "abe0e3a42a1fc" } diff --git a/axum-extra/src/extract/multipart.rs b/axum-extra/src/extract/multipart.rs index 8c78a77722..3f9e10bef2 100644 --- a/axum-extra/src/extract/multipart.rs +++ b/axum-extra/src/extract/multipart.rs @@ -12,7 +12,7 @@ use axum::{ use futures_util::stream::Stream; use http::{ header::{HeaderMap, CONTENT_TYPE}, - Request, StatusCode, + HeaderName, HeaderValue, Request, StatusCode, }; use std::{ error::Error, @@ -115,7 +115,22 @@ impl Multipart { .map_err(MultipartError::from_multer)?; if let Some(field) = field { - Ok(Some(Field { inner: field })) + // multer still uses http 0.2 which means we cannot directly expose + // `multer::Field::headers`. Instead we have to eagerly convert the headers into http + // 1.0 + // + // When the next major version of multer is published we can remove this. + let mut headers = HeaderMap::with_capacity(field.headers().len()); + headers.extend(field.headers().clone().into_iter().map(|(name, value)| { + let name = name.map(|name| HeaderName::from_bytes(name.as_ref()).unwrap()); + let value = HeaderValue::from_bytes(value.as_ref()).unwrap(); + (name, value) + })); + + Ok(Some(Field { + inner: field, + headers, + })) } else { Ok(None) } @@ -134,6 +149,7 @@ impl Multipart { #[derive(Debug)] pub struct Field { inner: multer::Field<'static>, + headers: HeaderMap, } impl Stream for Field { @@ -168,7 +184,7 @@ impl Field { /// Get a map of headers as [`HeaderMap`]. pub fn headers(&self) -> &HeaderMap { - self.inner.headers() + &self.headers } /// Get the full data of the field as [`Bytes`]. diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 227e983a4b..2231f4eec5 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -11,6 +11,8 @@ use axum_core::response::{IntoResponse, Response}; use axum_core::RequestExt; use futures_util::stream::Stream; use http::header::{HeaderMap, CONTENT_TYPE}; +use http::HeaderName; +use http::HeaderValue; use http::StatusCode; use std::error::Error; use std::{ @@ -87,8 +89,21 @@ impl Multipart { .map_err(MultipartError::from_multer)?; if let Some(field) = field { + // multer still uses http 0.2 which means we cannot directly expose + // `multer::Field::headers`. Instead we have to eagerly convert the headers into http + // 1.0 + // + // When the next major version of multer is published we can remove this. + let mut headers = HeaderMap::with_capacity(field.headers().len()); + headers.extend(field.headers().clone().into_iter().map(|(name, value)| { + let name = name.map(|name| HeaderName::from_bytes(name.as_ref()).unwrap()); + let value = HeaderValue::from_bytes(value.as_ref()).unwrap(); + (name, value) + })); + Ok(Some(Field { inner: field, + headers, _multipart: self, })) } else { @@ -101,6 +116,7 @@ impl Multipart { #[derive(Debug)] pub struct Field<'a> { inner: multer::Field<'static>, + headers: HeaderMap, // multer requires there to only be one live `multer::Field` at any point. This enforces that // statically, which multer does not do, it returns an error instead. _multipart: &'a mut Multipart, @@ -138,7 +154,7 @@ impl<'a> Field<'a> { /// Get a map of headers as [`HeaderMap`]. pub fn headers(&self) -> &HeaderMap { - self.inner.headers() + &self.headers } /// Get the full data of the field as [`Bytes`]. @@ -320,6 +336,7 @@ mod tests { assert_eq!(field.file_name().unwrap(), FILE_NAME); assert_eq!(field.content_type().unwrap(), CONTENT_TYPE); + assert_eq!(field.headers()["foo"], "bar"); assert_eq!(field.bytes().await.unwrap(), BYTES); assert!(multipart.next_field().await.unwrap().is_none()); @@ -334,7 +351,11 @@ mod tests { reqwest::multipart::Part::bytes(BYTES) .file_name(FILE_NAME) .mime_str(CONTENT_TYPE) - .unwrap(), + .unwrap() + .headers(reqwest::header::HeaderMap::from_iter([( + reqwest::header::HeaderName::from_static("foo"), + reqwest::header::HeaderValue::from_static("bar"), + )])), ); client.post("/").multipart(form).send().await;