Skip to content

Commit

Permalink
Replace async_trait with AFIT / RPITIT (#2308)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonas Platte <[email protected]>
  • Loading branch information
lz1998 and jplatte authored Sep 28, 2024
1 parent dda5a27 commit 19101f6
Show file tree
Hide file tree
Showing 69 changed files with 115 additions and 266 deletions.
1 change: 0 additions & 1 deletion axum-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ tracing = ["dep:tracing"]
__private_docs = ["dep:tower-http"]

[dependencies]
async-trait = "0.1.67"
bytes = "1.2"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "1.0.0"
Expand Down
3 changes: 0 additions & 3 deletions axum-core/src/ext_traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ mod tests {
use std::convert::Infallible;

use crate::extract::{FromRef, FromRequestParts};
use async_trait::async_trait;
use http::request::Parts;

#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct State<S>(pub(crate) S);

#[async_trait]
impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
where
InnerState: FromRef<OuterState>,
Expand All @@ -33,7 +31,6 @@ mod tests {
#[allow(dead_code)]
pub(crate) struct RequiresState(pub(crate) String);

#[async_trait]
impl<S> FromRequestParts<S> for RequiresState
where
S: Send + Sync,
Expand Down
54 changes: 23 additions & 31 deletions axum-core/src/ext_traits/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::body::Body;
use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts, Request};
use futures_util::future::BoxFuture;
use std::future::Future;

mod sealed {
pub trait Sealed {}
Expand All @@ -20,7 +20,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
///
/// ```
/// use axum::{
/// async_trait,
/// extract::{Request, FromRequest},
/// body::Body,
/// http::{header::CONTENT_TYPE, StatusCode},
Expand All @@ -30,7 +29,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
///
/// struct FormOrJson<T>(T);
///
/// #[async_trait]
/// impl<S, T> FromRequest<S> for FormOrJson<T>
/// where
/// Json<T>: FromRequest<()>,
Expand Down Expand Up @@ -67,7 +65,7 @@ pub trait RequestExt: sealed::Sealed + Sized {
/// }
/// }
/// ```
fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
fn extract<E, M>(self) -> impl Future<Output = Result<E, E::Rejection>> + Send
where
E: FromRequest<(), M> + 'static,
M: 'static;
Expand All @@ -83,7 +81,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
///
/// ```
/// use axum::{
/// async_trait,
/// body::Body,
/// extract::{Request, FromRef, FromRequest},
/// RequestExt,
Expand All @@ -93,7 +90,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
/// requires_state: RequiresState,
/// }
///
/// #[async_trait]
/// impl<S> FromRequest<S> for MyExtractor
/// where
/// String: FromRef<S>,
Expand All @@ -111,7 +107,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
/// // some extractor that consumes the request body and requires state
/// struct RequiresState { /* ... */ }
///
/// #[async_trait]
/// impl<S> FromRequest<S> for RequiresState
/// where
/// String: FromRef<S>,
Expand All @@ -124,7 +119,10 @@ pub trait RequestExt: sealed::Sealed + Sized {
/// # }
/// }
/// ```
fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
fn extract_with_state<E, S, M>(
self,
state: &S,
) -> impl Future<Output = Result<E, E::Rejection>> + Send
where
E: FromRequest<S, M> + 'static,
S: Send + Sync;
Expand All @@ -137,7 +135,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
///
/// ```
/// use axum::{
/// async_trait,
/// extract::{Path, Request, FromRequest},
/// response::{IntoResponse, Response},
/// body::Body,
Expand All @@ -154,7 +151,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
/// payload: T,
/// }
///
/// #[async_trait]
/// impl<S, T> FromRequest<S> for MyExtractor<T>
/// where
/// S: Send + Sync,
Expand All @@ -179,7 +175,7 @@ pub trait RequestExt: sealed::Sealed + Sized {
/// }
/// }
/// ```
fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
fn extract_parts<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
where
E: FromRequestParts<()> + 'static;

Expand All @@ -191,7 +187,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
///
/// ```
/// use axum::{
/// async_trait,
/// extract::{Request, FromRef, FromRequest, FromRequestParts},
/// http::request::Parts,
/// response::{IntoResponse, Response},
Expand All @@ -204,7 +199,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
/// payload: T,
/// }
///
/// #[async_trait]
/// impl<S, T> FromRequest<S> for MyExtractor<T>
/// where
/// String: FromRef<S>,
Expand Down Expand Up @@ -234,7 +228,6 @@ pub trait RequestExt: sealed::Sealed + Sized {
///
/// struct RequiresState {}
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for RequiresState
/// where
/// String: FromRef<S>,
Expand All @@ -250,7 +243,7 @@ pub trait RequestExt: sealed::Sealed + Sized {
fn extract_parts_with_state<'a, E, S>(
&'a mut self,
state: &'a S,
) -> BoxFuture<'a, Result<E, E::Rejection>>
) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
where
E: FromRequestParts<S> + 'static,
S: Send + Sync;
Expand All @@ -267,33 +260,36 @@ pub trait RequestExt: sealed::Sealed + Sized {
}

impl RequestExt for Request {
fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
fn extract<E, M>(self) -> impl Future<Output = Result<E, E::Rejection>> + Send
where
E: FromRequest<(), M> + 'static,
M: 'static,
{
self.extract_with_state(&())
}

fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
fn extract_with_state<E, S, M>(
self,
state: &S,
) -> impl Future<Output = Result<E, E::Rejection>> + Send
where
E: FromRequest<S, M> + 'static,
S: Send + Sync,
{
E::from_request(self, state)
}

fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
fn extract_parts<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
where
E: FromRequestParts<()> + 'static,
{
self.extract_parts_with_state(&())
}

fn extract_parts_with_state<'a, E, S>(
async fn extract_parts_with_state<'a, E, S>(
&'a mut self,
state: &'a S,
) -> BoxFuture<'a, Result<E, E::Rejection>>
) -> Result<E, E::Rejection>
where
E: FromRequestParts<S> + 'static,
S: Send + Sync,
Expand All @@ -306,17 +302,15 @@ impl RequestExt for Request {
*req.extensions_mut() = std::mem::take(self.extensions_mut());
let (mut parts, ()) = req.into_parts();

Box::pin(async move {
let result = E::from_request_parts(&mut parts, state).await;
let result = E::from_request_parts(&mut parts, state).await;

*self.version_mut() = parts.version;
*self.method_mut() = parts.method.clone();
*self.uri_mut() = parts.uri.clone();
*self.headers_mut() = std::mem::take(&mut parts.headers);
*self.extensions_mut() = std::mem::take(&mut parts.extensions);
*self.version_mut() = parts.version;
*self.method_mut() = parts.method.clone();
*self.uri_mut() = parts.uri.clone();
*self.headers_mut() = std::mem::take(&mut parts.headers);
*self.extensions_mut() = std::mem::take(&mut parts.extensions);

result
})
result
}

fn with_limited_body(self) -> Request {
Expand Down Expand Up @@ -345,7 +339,6 @@ mod tests {
ext_traits::tests::{RequiresState, State},
extract::FromRef,
};
use async_trait::async_trait;
use http::Method;

#[tokio::test]
Expand Down Expand Up @@ -414,7 +407,6 @@ mod tests {
body: String,
}

#[async_trait]
impl<S> FromRequest<S> for WorksForCustomExtractor
where
S: Send + Sync,
Expand Down
17 changes: 5 additions & 12 deletions axum-core/src/ext_traits/request_parts.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::extract::FromRequestParts;
use futures_util::future::BoxFuture;
use http::request::Parts;
use std::future::Future;

mod sealed {
pub trait Sealed {}
Expand All @@ -21,7 +21,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized {
/// response::{Response, IntoResponse},
/// http::request::Parts,
/// RequestPartsExt,
/// async_trait,
/// };
/// use std::collections::HashMap;
///
Expand All @@ -30,7 +29,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized {
/// query_params: HashMap<String, String>,
/// }
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for MyExtractor
/// where
/// S: Send + Sync,
Expand All @@ -54,7 +52,7 @@ pub trait RequestPartsExt: sealed::Sealed + Sized {
/// }
/// }
/// ```
fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
fn extract<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
where
E: FromRequestParts<()> + 'static;

Expand All @@ -70,14 +68,12 @@ pub trait RequestPartsExt: sealed::Sealed + Sized {
/// response::{Response, IntoResponse},
/// http::request::Parts,
/// RequestPartsExt,
/// async_trait,
/// };
///
/// struct MyExtractor {
/// requires_state: RequiresState,
/// }
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for MyExtractor
/// where
/// String: FromRef<S>,
Expand All @@ -97,7 +93,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized {
/// struct RequiresState { /* ... */ }
///
/// // some extractor that requires a `String` in the state
/// #[async_trait]
/// impl<S> FromRequestParts<S> for RequiresState
/// where
/// String: FromRef<S>,
Expand All @@ -113,14 +108,14 @@ pub trait RequestPartsExt: sealed::Sealed + Sized {
fn extract_with_state<'a, E, S>(
&'a mut self,
state: &'a S,
) -> BoxFuture<'a, Result<E, E::Rejection>>
) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
where
E: FromRequestParts<S> + 'static,
S: Send + Sync;
}

impl RequestPartsExt for Parts {
fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
fn extract<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
where
E: FromRequestParts<()> + 'static,
{
Expand All @@ -130,7 +125,7 @@ impl RequestPartsExt for Parts {
fn extract_with_state<'a, E, S>(
&'a mut self,
state: &'a S,
) -> BoxFuture<'a, Result<E, E::Rejection>>
) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
where
E: FromRequestParts<S> + 'static,
S: Send + Sync,
Expand All @@ -148,7 +143,6 @@ mod tests {
ext_traits::tests::{RequiresState, State},
extract::FromRef,
};
use async_trait::async_trait;
use http::{Method, Request};

#[tokio::test]
Expand Down Expand Up @@ -181,7 +175,6 @@ mod tests {
from_state: String,
}

#[async_trait]
impl<S> FromRequestParts<S> for WorksForCustomExtractor
where
S: Send + Sync,
Expand Down
Loading

0 comments on commit 19101f6

Please sign in to comment.