Skip to content

Commit

Permalink
Add a separate trait for optional extractors (#2475)
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte authored Dec 10, 2024
1 parent fd11d8e commit ec75ee3
Show file tree
Hide file tree
Showing 17 changed files with 306 additions and 84 deletions.
8 changes: 8 additions & 0 deletions axum-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

# Unreleased

- **breaking:**: `Option<T>` as an extractor now requires `T` to implement the
new trait `OptionalFromRequest` (if used as the last extractor) or
`OptionalFromRequestParts` (other extractors) ([#2475])

[#2475]: https://github.com/tokio-rs/axum/pull/2475

# 0.5.0

## alpha.1
Expand Down
34 changes: 6 additions & 28 deletions axum-core/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@ pub mod rejection;

mod default_body_limit;
mod from_ref;
mod option;
mod request_parts;
mod tuple;

pub(crate) use self::default_body_limit::DefaultBodyLimitKind;
pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef};
pub use self::{
default_body_limit::DefaultBodyLimit,
from_ref::FromRef,
option::{OptionalFromRequest, OptionalFromRequestParts},
};

/// Type alias for [`http::Request`] whose body type defaults to [`Body`], the most common body
/// type used with axum.
Expand Down Expand Up @@ -102,33 +107,6 @@ where
}
}

impl<S, T> FromRequestParts<S> for Option<T>
where
T: FromRequestParts<S>,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request_parts(parts, state).await.ok())
}
}

impl<S, T> FromRequest<S> for Option<T>
where
T: FromRequest<S>,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req, state).await.ok())
}
}

impl<S, T> FromRequestParts<S> for Result<T, T::Rejection>
where
T: FromRequestParts<S>,
Expand Down
63 changes: 63 additions & 0 deletions axum-core/src/extract/option.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::future::Future;

use http::request::Parts;

use crate::response::IntoResponse;

use super::{private, FromRequest, FromRequestParts, Request};

/// Customize the behavior of `Option<Self>` as a [`FromRequestParts`]
/// extractor.
pub trait OptionalFromRequestParts<S>: Sized {
/// If the extractor fails, it will use this "rejection" type.
///
/// A rejection is a kind of error that can be converted into a response.
type Rejection: IntoResponse;

/// Perform the extraction.
fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> impl Future<Output = Result<Option<Self>, Self::Rejection>> + Send;
}

/// Customize the behavior of `Option<Self>` as a [`FromRequest`] extractor.
pub trait OptionalFromRequest<S, M = private::ViaRequest>: Sized {
/// If the extractor fails, it will use this "rejection" type.
///
/// A rejection is a kind of error that can be converted into a response.
type Rejection: IntoResponse;

/// Perform the extraction.
fn from_request(
req: Request,
state: &S,
) -> impl Future<Output = Result<Option<Self>, Self::Rejection>> + Send;
}

impl<S, T> FromRequestParts<S> for Option<T>
where
T: OptionalFromRequestParts<S>,
S: Send + Sync,
{
type Rejection = T::Rejection;

fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> impl Future<Output = Result<Option<T>, Self::Rejection>> {
T::from_request_parts(parts, state)
}
}

impl<S, T> FromRequest<S> for Option<T>
where
T: OptionalFromRequest<S>,
S: Send + Sync,
{
type Rejection = T::Rejection;

async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
T::from_request(req, state).await
}
}
4 changes: 4 additions & 0 deletions axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning].

# Unreleased

- **breaking:** `Option<Query<T>>` no longer swallows all error conditions, instead rejecting the
request in many cases; see its documentation for details ([#2475])
- **changed:** Deprecated `OptionalPath<T>` and `OptionalQuery<T>` ([#2475])
- **fixed:** `Host` extractor includes port number when parsing authority ([#2242])
- **changed:** The `multipart` feature is no longer on by default ([#3058])
- **added:** Add `RouterExt::typed_connect` ([#2961])
Expand All @@ -16,6 +19,7 @@ and this project adheres to [Semantic Versioning].
- **added:** Add `FileStream` for easy construction of file stream responses ([#3047])

[#2242]: https://github.com/tokio-rs/axum/pull/2242
[#2475]: https://github.com/tokio-rs/axum/pull/2475
[#3058]: https://github.com/tokio-rs/axum/pull/3058
[#2961]: https://github.com/tokio-rs/axum/pull/2961
[#2962]: https://github.com/tokio-rs/axum/pull/2962
Expand Down
11 changes: 7 additions & 4 deletions axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ pub mod multipart;
#[cfg(feature = "scheme")]
mod scheme;

pub use self::{
cached::Cached, host::Host, optional_path::OptionalPath, with_rejection::WithRejection,
};
#[allow(deprecated)]
pub use self::optional_path::OptionalPath;
pub use self::{cached::Cached, host::Host, with_rejection::WithRejection};

#[cfg(feature = "cookie")]
pub use self::cookie::CookieJar;
Expand All @@ -41,7 +41,10 @@ pub use self::cookie::SignedCookieJar;
pub use self::form::{Form, FormRejection};

#[cfg(feature = "query")]
pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejection};
#[allow(deprecated)]
pub use self::query::OptionalQuery;
#[cfg(feature = "query")]
pub use self::query::{OptionalQueryRejection, Query, QueryRejection};

#[cfg(feature = "multipart")]
pub use self::multipart::Multipart;
Expand Down
18 changes: 8 additions & 10 deletions axum-extra/src/extract/optional_path.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use axum::{
extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path},
extract::{rejection::PathRejection, FromRequestParts, Path},
RequestPartsExt,
};
use serde::de::DeserializeOwned;
Expand Down Expand Up @@ -31,9 +31,11 @@ use serde::de::DeserializeOwned;
/// .route("/blog/{page}", get(render_blog));
/// # let app: Router = app;
/// ```
#[deprecated = "Use Option<Path<_>> instead"]
#[derive(Debug)]
pub struct OptionalPath<T>(pub Option<T>);

#[allow(deprecated)]
impl<T, S> FromRequestParts<S> for OptionalPath<T>
where
T: DeserializeOwned + Send + 'static,
Expand All @@ -45,19 +47,15 @@ where
parts: &mut http::request::Parts,
_: &S,
) -> Result<Self, Self::Rejection> {
match parts.extract::<Path<T>>().await {
Ok(Path(params)) => Ok(Self(Some(params))),
Err(PathRejection::FailedToDeserializePathParams(e))
if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) =>
{
Ok(Self(None))
}
Err(e) => Err(e),
}
parts
.extract::<Option<Path<T>>>()
.await
.map(|opt| Self(opt.map(|Path(x)| x)))
}
}

#[cfg(test)]
#[allow(deprecated)]
mod tests {
use std::num::NonZeroU32;

Expand Down
41 changes: 40 additions & 1 deletion axum-extra/src/extract/query.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use axum::{
extract::FromRequestParts,
extract::{FromRequestParts, OptionalFromRequestParts},
response::{IntoResponse, Response},
Error,
};
Expand All @@ -18,6 +18,19 @@ use std::fmt;
/// with the `multiple` attribute. Those values can be collected into a `Vec` or other sequential
/// container.
///
/// # `Option<Query<T>>` behavior
///
/// If `Query<T>` itself is used as an extractor and there is no query string in
/// the request URL, `T`'s `Deserialize` implementation is called on an empty
/// string instead.
///
/// You can avoid this by using `Option<Query<T>>`, which gives you `None` in
/// the case that there is no query string in the request URL.
///
/// Note that an empty query string is not the same as no query string, that is
/// `https://example.org/` and `https://example.org/?` are not treated the same
/// in this case.
///
/// # Example
///
/// ```rust,no_run
Expand Down Expand Up @@ -96,6 +109,27 @@ where
}
}

impl<T, S> OptionalFromRequestParts<S> for Query<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = QueryRejection;

async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_html_form::from_str(query)
.map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?;
Ok(Some(Self(value)))
} else {
Ok(None)
}
}
}

axum_core::__impl_deref!(Query);

/// Rejection used for [`Query`].
Expand Down Expand Up @@ -182,9 +216,11 @@ impl std::error::Error for QueryRejection {
///
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[deprecated = "Use Option<Query<_>> instead"]
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionalQuery<T>(pub Option<T>);

#[allow(deprecated)]
impl<T, S> FromRequestParts<S> for OptionalQuery<T>
where
T: DeserializeOwned,
Expand All @@ -204,6 +240,7 @@ where
}
}

#[allow(deprecated)]
impl<T> std::ops::Deref for OptionalQuery<T> {
type Target = Option<T>;

Expand All @@ -213,6 +250,7 @@ impl<T> std::ops::Deref for OptionalQuery<T> {
}
}

#[allow(deprecated)]
impl<T> std::ops::DerefMut for OptionalQuery<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
Expand Down Expand Up @@ -260,6 +298,7 @@ impl std::error::Error for OptionalQueryRejection {
}

#[cfg(test)]
#[allow(deprecated)]
mod tests {
use super::*;
use crate::test_helpers::*;
Expand Down
26 changes: 25 additions & 1 deletion axum-extra/src/typed_header.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Extractor and response for typed headers.
use axum::{
extract::FromRequestParts,
extract::{FromRequestParts, OptionalFromRequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use headers::{Header, HeaderMapExt};
Expand Down Expand Up @@ -78,6 +78,30 @@ where
}
}

impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
where
T: Header,
S: Send + Sync,
{
type Rejection = TypedHeaderRejection;

async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
let mut values = parts.headers.get_all(T::name()).iter();
let is_missing = values.size_hint() == (0, Some(0));
match T::decode(&mut values) {
Ok(res) => Ok(Some(Self(res))),
Err(_) if is_missing => Ok(None),
Err(err) => Err(TypedHeaderRejection {
name: T::name(),
reason: TypedHeaderRejectionReason::Error(err),
}),
}
}
}

axum_core::__impl_deref!(TypedHeader);

impl<T> IntoResponseParts for TypedHeader<T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ struct UsersShow {
id: String,
}

async fn option_handler(_: Option<UsersShow>) {}

async fn result_handler(_: Result<UsersShow, PathRejection>) {}

#[derive(TypedPath, Deserialize)]
Expand All @@ -20,7 +18,6 @@ async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {}

fn main() {
_ = axum::Router::<()>::new()
.typed_get(option_handler)
.typed_post(result_handler)
.typed_post(result_handler_unit_struct);
}
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
This allows middleware to add bodies to requests without needing to manually set `content-length` ([#2897])
- **breaking:** Remove `WebSocket::close`.
Users should explicitly send close messages themselves. ([#2974])
- **breaking:** `Option<Path<T>>` and `Option<Query<T>>` no longer swallow all error conditions,
instead rejecting the request in many cases; see their documentation for details ([#2475])
- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`)
This new variant captures both `key`, `value`, and `message` from named path parameters parse errors,
instead of only deserialization error message in `ErrorKind::Message`. ([#2720])
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])

[#2475]: https://github.com/tokio-rs/axum/pull/2475
[#2897]: https://github.com/tokio-rs/axum/pull/2897
[#2903]: https://github.com/tokio-rs/axum/pull/2903
[#2894]: https://github.com/tokio-rs/axum/pull/2894
Expand Down
1 change: 1 addition & 0 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ features = [

[dev-dependencies]
anyhow = "1.0"
axum-extra = { path = "../axum-extra", features = ["typed-header"] }
axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = { version = "1.1.0", features = ["client"] }
quickcheck = "1.0"
Expand Down
Loading

0 comments on commit ec75ee3

Please sign in to comment.