Skip to content

Commit

Permalink
Update Fairings types to fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
the10thWiz committed Sep 8, 2024
1 parent 3fbf1b4 commit c263a6c
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 25 deletions.
5 changes: 1 addition & 4 deletions contrib/dyn_templates/src/fairing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,10 @@ impl Fairing for TemplateFairing {
}

#[cfg(debug_assertions)]
async fn on_request<'r>(&self, req: &'r mut rocket::Request<'_>, _data: &mut rocket::Data<'_>)
-> Result<(), Box<dyn TypedError<'r> + 'r>>
{
async fn on_request(&self, req: &mut rocket::Request<'_>, _data: &mut rocket::Data<'_>) {
let cm = req.rocket().state::<ContextManager>()
.expect("Template ContextManager registered in on_ignite");

cm.reload_if_needed(&self.callback);
Ok(())
}
}
6 changes: 5 additions & 1 deletion core/codegen/src/attribute/catch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ fn error_guard_decl(guard: &ErrorGuard) -> TokenStream {
fn request_guard_decl(guard: &Guard) -> TokenStream {
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
quote_spanned! { ty.span() =>
let #ident: #ty = match <#ty as #FromError>::from_error(#__status, #__req, __error_init).await {
let #ident: #ty = match <#ty as #FromError>::from_error(
#__status,
#__req,
__error_init
).await {
#_Result::Ok(__v) => __v,
#_Result::Err(__e) => {
::rocket::trace::info!(
Expand Down
8 changes: 6 additions & 2 deletions core/codegen/src/derive/typed_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream {
.inner_mapper(MapperBuild::new()
.with_output(|_, output| quote! {
#[allow(unused_variables)]
fn respond_to(&self, request: &'r #Request<'_>) -> #_Result<#Response<'r>, #_Status> {
fn respond_to(&self, request: &'r #Request<'_>)
-> #_Result<#Response<'r>, #_Status>
{
#output
}
})
Expand Down Expand Up @@ -76,7 +78,9 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream {
None => Member::Unnamed(Index { index: field.index as u32, span })
};

source = Some(quote_spanned!(span => #_Some(&self.#member as &dyn #TypedError<'r>)));
source = Some(quote_spanned!(
span => #_Some(&self.#member as &dyn #TypedError<'r>
)));
}
}
}
Expand Down
69 changes: 69 additions & 0 deletions core/codegen/tests/typed_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#[macro_use] extern crate rocket;
use rocket::catcher::TypedError;
use rocket::http::Status;

fn boxed_error<'r>(_val: Box<dyn TypedError<'r> + 'r>) {}

#[derive(TypedError)]
pub enum Foo<'r> {
First(String),
Second(Vec<u8>),
Third {
#[error(source)]
responder: std::io::Error,
},
#[error(status = 400)]
Fourth {
string: &'r str,
},
}

#[test]
fn validate_foo() {
let first = Foo::First("".into());
assert_eq!(first.status(), Status::InternalServerError);
assert!(first.source().is_none());
boxed_error(Box::new(first));
let second = Foo::Second(vec![]);
assert_eq!(second.status(), Status::InternalServerError);
assert!(second.source().is_none());
boxed_error(Box::new(second));
let third = Foo::Third {
responder: std::io::Error::new(std::io::ErrorKind::NotFound, ""),
};
assert_eq!(third.status(), Status::InternalServerError);
assert!(std::ptr::eq(
third.source().unwrap(),
if let Foo::Third { responder } = &third { responder } else { panic!() }
));
boxed_error(Box::new(third));
let fourth = Foo::Fourth { string: "" };
assert_eq!(fourth.status(), Status::BadRequest);
assert!(fourth.source().is_none());
boxed_error(Box::new(fourth));
}

#[derive(TypedError)]
pub struct InfallibleError {
#[error(source)]
_inner: std::convert::Infallible,
}

#[derive(TypedError)]
pub struct StaticError {
#[error(source)]
inner: std::string::FromUtf8Error,
}

#[test]
fn validate_static() {
let val = StaticError {
inner: String::from_utf8(vec![0xFF]).unwrap_err(),
};
assert_eq!(val.status(), Status::InternalServerError);
assert!(std::ptr::eq(
val.source().unwrap(),
&val.inner,
));
boxed_error(Box::new(val));
}
32 changes: 32 additions & 0 deletions core/codegen/tests/ui-fail/typed_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#[macro_use] extern crate rocket;

#[derive(TypedError)]
struct InnerError;
struct InnerNonError;

#[derive(TypedError)]
struct Thing1<'a, 'b> {
a: &'a str,
b: &'b str,
}

#[derive(TypedError)]
struct Thing2 {
#[error(source)]
inner: InnerNonError,
}

#[derive(TypedError)]
enum Thing3<'a, 'b> {
A(&'a str),
B(&'b str),
}

#[derive(TypedError)]
enum Thing4 {
A(#[error(source)] InnerNonError),
B(#[error(source)] InnerError),
}

#[derive(TypedError)]
enum EmptyEnum { }
7 changes: 4 additions & 3 deletions core/lib/src/erased.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct ErasedError<'r> {

impl<'r> ErasedError<'r> {
pub fn new() -> Self {
Self { error: None }
Self { error: None }
}

pub fn write(&mut self, error: Option<Box<dyn TypedError<'r> + 'r>>) {
Expand Down Expand Up @@ -142,11 +142,12 @@ impl ErasedRequest {
// SAFETY: At this point, ErasedRequest contains a request, which is permitted
// to borrow from `Rocket` and `Parts`. They both have stable addresses (due to
// `Arc` and `Box`), and the Request will be dropped first (due to drop order).
// SAFETY: Here, we place the `ErasedRequest` (i.e. the `Request`) behind an `Arc` (TODO: Why not Box?)
// SAFETY: Here, we place the `ErasedRequest` (i.e. the `Request`) behind an `Arc`
// to ensure it has a stable address, and we again use drop order to ensure the `Request`
// is dropped before the values that can borrow from it.
let mut parent = Arc::new(self);
// SAFETY: This error is permitted to borrow from the `Request` (as well as `Rocket` and `Parts`).
// SAFETY: This error is permitted to borrow from the `Request` (as well as `Rocket` and
// `Parts`).
let mut error = ErasedError { error: None };
let token: T = {
let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap();
Expand Down
8 changes: 4 additions & 4 deletions core/lib/src/fairing/ad_hoc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ enum AdHocKind {
Liftoff(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),

/// An ad-hoc **request** fairing. Called when a request is received.
Request(Box<dyn for<'a, 'b> Fn(&'a mut Request<'_>, &'b mut Data<'_>)
Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>)
-> BoxFuture<'a, ()> + Send + Sync + 'static>),

/// An ad-hoc **request_filter** fairing. Called when a request is received.
Expand Down Expand Up @@ -159,7 +159,7 @@ impl AdHoc {
/// });
/// ```
pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where F: for<'a, 'b> Fn(&'a mut Request<'_>, &'b mut Data<'_>)
where F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>)
-> BoxFuture<'a, ()>
{
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
Expand Down Expand Up @@ -407,7 +407,7 @@ impl AdHoc {
let _ = self.routes(rocket);
}

async fn on_request<'r>(&self, req: &'r mut Request<'_>, _: &mut Data<'_>) {
async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
// If the URI has no trailing slash, it routes as before.
if req.uri().is_normalized_nontrailing() {
return;
Expand Down Expand Up @@ -458,7 +458,7 @@ impl Fairing for AdHoc {
}
}

async fn on_request<'r>(&self, req: &'r mut Request<'_>, data: &mut Data<'_>) {
async fn on_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) {
if let AdHocKind::Request(ref f) = self.kind {
f(req, data).await
}
Expand Down
4 changes: 2 additions & 2 deletions core/lib/src/fairing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ pub trait Fairing: Send + Sync + Any + 'static {
/// ## Default Implementation
///
/// The default implementation of this method does nothing.
async fn on_request<'r>(&self, _req: &'r mut Request<'_>, _data: &mut Data<'_>) { }
async fn on_request(&self, _req: &mut Request<'_>, _data: &mut Data<'_>) { }

/// The request filter callback.
///
Expand Down Expand Up @@ -582,7 +582,7 @@ impl<T: Fairing + ?Sized> Fairing for std::sync::Arc<T> {
}

#[inline]
async fn on_request<'r>(&self, req: &'r mut Request<'_>, data: &mut Data<'_>) {
async fn on_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) {
(self as &T).on_request(req, data).await
}

Expand Down
5 changes: 3 additions & 2 deletions core/lib/src/lifecycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ impl Rocket<Orbit> {
let mut response = if error_ptr.is_some() {
// error_ptr is always some here, we just checked.
self.dispatch_error(error_ptr.get().unwrap().status(), request, error_ptr.get()).await
// We MUST wait until we are inside this block to call `get`, since we HAVE to borrow it for `'r`.
// (And it's invariant, so we can't downcast the borrow to a shorter lifetime)
// We MUST wait until we are inside this block to call `get`, since we HAVE to borrow
// it for `'r`. (And it's invariant, so we can't downcast the borrow to a shorter
// lifetime)
} else {
match self.route(request, data).await {
Outcome::Success(response) => response,
Expand Down
16 changes: 12 additions & 4 deletions core/lib/src/local/asynchronous/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@ impl Drop for LocalResponse<'_> {
}

impl<'c> LocalResponse<'c> {
pub(crate) fn new<P, PO, F, O>(req: Request<'c>, mut data: Data<'c>, preprocess: P, f: F) -> impl Future<Output = LocalResponse<'c>>
where P: FnOnce(&'c mut Request<'c>, &'c mut Data<'c>, &'c mut ErasedError<'c>) -> PO + Send,
pub(crate) fn new<P, PO, F, O>(req: Request<'c>, mut data: Data<'c>, preprocess: P, f: F)
-> impl Future<Output = LocalResponse<'c>>
where P: FnOnce(&'c mut Request<'c>, &'c mut Data<'c>, &'c mut ErasedError<'c>)
-> PO + Send,
PO: Future<Output = RequestToken> + Send + 'c,
F: FnOnce(RequestToken, &'c Request<'c>, Data<'c>, &'c mut ErasedError<'c>) -> O + Send,
F: FnOnce(RequestToken, &'c Request<'c>, Data<'c>, &'c mut ErasedError<'c>)
-> O + Send,
O: Future<Output = Response<'c>> + Send + 'c
{
// `LocalResponse` is a self-referential structure. In particular,
Expand Down Expand Up @@ -121,7 +124,12 @@ impl<'c> LocalResponse<'c> {
// the value is used to set cookie defaults.
// SAFETY: The type of `preprocess` ensures that all of these types have the correct
// lifetime ('c).
let response: Response<'c> = f(token, request, data, unsafe { transmute(&mut error) }).await;
let response: Response<'c> = f(
token,
request,
data,
unsafe { transmute(&mut error) }
).await;
let mut cookies = CookieJar::new(None, request.rocket());
for cookie in response.cookies() {
cookies.add_original(cookie.into_owned());
Expand Down
4 changes: 3 additions & 1 deletion core/lib/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ impl Rocket<Orbit> {
span_debug!("request headers" => request.inner().headers().iter().trace_all_debug());
let mut response = request.into_response(
stream,
|rocket, request, data, error_ptr| Box::pin(rocket.preprocess(request, data, error_ptr)),
|rocket, request, data, error_ptr| {
Box::pin(rocket.preprocess(request, data, error_ptr))
},
|token, rocket, request, data, error_ptr| Box::pin(async move {
if !request.errors.is_empty() {
error_ptr.write(Some(Box::new(RequestErrors::new(&request.errors))));
Expand Down
10 changes: 8 additions & 2 deletions examples/todo/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ async fn toggle(id: i32, conn: DbConn) -> Either<Redirect, Template> {
Ok(_) => Either::Left(Redirect::to("/")),
Err(e) => {
error!("DB toggle({id}) error: {e}");
Either::Right(Template::render("index", Context::err(&conn, "Failed to toggle task.").await))
Either::Right(Template::render(
"index",
Context::err(&conn, "Failed to toggle task.").await
))
}
}
}
Expand All @@ -81,7 +84,10 @@ async fn delete(id: i32, conn: DbConn) -> Either<Flash<Redirect>, Template> {
Ok(_) => Either::Left(Flash::success(Redirect::to("/"), "Todo was deleted.")),
Err(e) => {
error!("DB deletion({id}) error: {e}");
Either::Right(Template::render("index", Context::err(&conn, "Failed to delete task.").await))
Either::Right(Template::render(
"index",
Context::err(&conn, "Failed to delete task.").await
))
}
}
}
Expand Down

0 comments on commit c263a6c

Please sign in to comment.