From c263a6ceea427bd74132ceef0c14d308634d73f1 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 8 Sep 2024 01:45:41 -0500 Subject: [PATCH] Update Fairings types to fix issues --- contrib/dyn_templates/src/fairing.rs | 5 +- core/codegen/src/attribute/catch/mod.rs | 6 +- core/codegen/src/derive/typed_error.rs | 8 ++- core/codegen/tests/typed_error.rs | 69 +++++++++++++++++++++ core/codegen/tests/ui-fail/typed_error.rs | 32 ++++++++++ core/lib/src/erased.rs | 7 ++- core/lib/src/fairing/ad_hoc.rs | 8 +-- core/lib/src/fairing/mod.rs | 4 +- core/lib/src/lifecycle.rs | 5 +- core/lib/src/local/asynchronous/response.rs | 16 +++-- core/lib/src/server.rs | 4 +- examples/todo/src/main.rs | 10 ++- 12 files changed, 149 insertions(+), 25 deletions(-) create mode 100644 core/codegen/tests/typed_error.rs create mode 100644 core/codegen/tests/ui-fail/typed_error.rs diff --git a/contrib/dyn_templates/src/fairing.rs b/contrib/dyn_templates/src/fairing.rs index c17a53cc3e..6cef441315 100644 --- a/contrib/dyn_templates/src/fairing.rs +++ b/contrib/dyn_templates/src/fairing.rs @@ -66,13 +66,10 @@ impl Fairing for TemplateFairing { } #[cfg(debug_assertions)] - async fn on_request<'r>(&self, req: &'r mut rocket::Request<'_>, _data: &mut rocket::Data<'_>) - -> Result<(), Box + 'r>> - { + async fn on_request(&self, req: &mut rocket::Request<'_>, _data: &mut rocket::Data<'_>) { let cm = req.rocket().state::() .expect("Template ContextManager registered in on_ignite"); cm.reload_if_needed(&self.callback); - Ok(()) } } diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index e1c76ea127..bd8022130c 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -31,7 +31,11 @@ fn error_guard_decl(guard: &ErrorGuard) -> TokenStream { fn request_guard_decl(guard: &Guard) -> TokenStream { let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); quote_spanned! { ty.span() => - let #ident: #ty = match <#ty as #FromError>::from_error(#__status, #__req, __error_init).await { + let #ident: #ty = match <#ty as #FromError>::from_error( + #__status, + #__req, + __error_init + ).await { #_Result::Ok(__v) => __v, #_Result::Err(__e) => { ::rocket::trace::info!( diff --git a/core/codegen/src/derive/typed_error.rs b/core/codegen/src/derive/typed_error.rs index 5df7094c1c..2b95ef856b 100644 --- a/core/codegen/src/derive/typed_error.rs +++ b/core/codegen/src/derive/typed_error.rs @@ -37,7 +37,9 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { .inner_mapper(MapperBuild::new() .with_output(|_, output| quote! { #[allow(unused_variables)] - fn respond_to(&self, request: &'r #Request<'_>) -> #_Result<#Response<'r>, #_Status> { + fn respond_to(&self, request: &'r #Request<'_>) + -> #_Result<#Response<'r>, #_Status> + { #output } }) @@ -76,7 +78,9 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { None => Member::Unnamed(Index { index: field.index as u32, span }) }; - source = Some(quote_spanned!(span => #_Some(&self.#member as &dyn #TypedError<'r>))); + source = Some(quote_spanned!( + span => #_Some(&self.#member as &dyn #TypedError<'r> + ))); } } } diff --git a/core/codegen/tests/typed_error.rs b/core/codegen/tests/typed_error.rs new file mode 100644 index 0000000000..b407811ff4 --- /dev/null +++ b/core/codegen/tests/typed_error.rs @@ -0,0 +1,69 @@ +#[macro_use] extern crate rocket; +use rocket::catcher::TypedError; +use rocket::http::Status; + +fn boxed_error<'r>(_val: Box + 'r>) {} + +#[derive(TypedError)] +pub enum Foo<'r> { + First(String), + Second(Vec), + Third { + #[error(source)] + responder: std::io::Error, + }, + #[error(status = 400)] + Fourth { + string: &'r str, + }, +} + +#[test] +fn validate_foo() { + let first = Foo::First("".into()); + assert_eq!(first.status(), Status::InternalServerError); + assert!(first.source().is_none()); + boxed_error(Box::new(first)); + let second = Foo::Second(vec![]); + assert_eq!(second.status(), Status::InternalServerError); + assert!(second.source().is_none()); + boxed_error(Box::new(second)); + let third = Foo::Third { + responder: std::io::Error::new(std::io::ErrorKind::NotFound, ""), + }; + assert_eq!(third.status(), Status::InternalServerError); + assert!(std::ptr::eq( + third.source().unwrap(), + if let Foo::Third { responder } = &third { responder } else { panic!() } + )); + boxed_error(Box::new(third)); + let fourth = Foo::Fourth { string: "" }; + assert_eq!(fourth.status(), Status::BadRequest); + assert!(fourth.source().is_none()); + boxed_error(Box::new(fourth)); +} + +#[derive(TypedError)] +pub struct InfallibleError { + #[error(source)] + _inner: std::convert::Infallible, +} + +#[derive(TypedError)] +pub struct StaticError { + #[error(source)] + inner: std::string::FromUtf8Error, +} + +#[test] +fn validate_static() { + let val = StaticError { + inner: String::from_utf8(vec![0xFF]).unwrap_err(), + }; + assert_eq!(val.status(), Status::InternalServerError); + assert!(std::ptr::eq( + val.source().unwrap(), + &val.inner, + )); + boxed_error(Box::new(val)); +} diff --git a/core/codegen/tests/ui-fail/typed_error.rs b/core/codegen/tests/ui-fail/typed_error.rs new file mode 100644 index 0000000000..9ede6a1153 --- /dev/null +++ b/core/codegen/tests/ui-fail/typed_error.rs @@ -0,0 +1,32 @@ +#[macro_use] extern crate rocket; + +#[derive(TypedError)] +struct InnerError; +struct InnerNonError; + +#[derive(TypedError)] +struct Thing1<'a, 'b> { + a: &'a str, + b: &'b str, +} + +#[derive(TypedError)] +struct Thing2 { + #[error(source)] + inner: InnerNonError, +} + +#[derive(TypedError)] +enum Thing3<'a, 'b> { + A(&'a str), + B(&'b str), +} + +#[derive(TypedError)] +enum Thing4 { + A(#[error(source)] InnerNonError), + B(#[error(source)] InnerError), +} + +#[derive(TypedError)] +enum EmptyEnum { } diff --git a/core/lib/src/erased.rs b/core/lib/src/erased.rs index 11718fa780..3d7dc99b80 100644 --- a/core/lib/src/erased.rs +++ b/core/lib/src/erased.rs @@ -41,7 +41,7 @@ pub struct ErasedError<'r> { impl<'r> ErasedError<'r> { pub fn new() -> Self { - Self { error: None } + Self { error: None } } pub fn write(&mut self, error: Option + 'r>>) { @@ -142,11 +142,12 @@ impl ErasedRequest { // SAFETY: At this point, ErasedRequest contains a request, which is permitted // to borrow from `Rocket` and `Parts`. They both have stable addresses (due to // `Arc` and `Box`), and the Request will be dropped first (due to drop order). - // SAFETY: Here, we place the `ErasedRequest` (i.e. the `Request`) behind an `Arc` (TODO: Why not Box?) + // SAFETY: Here, we place the `ErasedRequest` (i.e. the `Request`) behind an `Arc` // to ensure it has a stable address, and we again use drop order to ensure the `Request` // is dropped before the values that can borrow from it. let mut parent = Arc::new(self); - // SAFETY: This error is permitted to borrow from the `Request` (as well as `Rocket` and `Parts`). + // SAFETY: This error is permitted to borrow from the `Request` (as well as `Rocket` and + // `Parts`). let mut error = ErasedError { error: None }; let token: T = { let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap(); diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs index c7bd5dc21d..5e1c9226f1 100644 --- a/core/lib/src/fairing/ad_hoc.rs +++ b/core/lib/src/fairing/ad_hoc.rs @@ -61,7 +61,7 @@ enum AdHocKind { Liftoff(Once FnOnce(&'a Rocket) -> BoxFuture<'a, ()> + Send + 'static>), /// An ad-hoc **request** fairing. Called when a request is received. - Request(Box Fn(&'a mut Request<'_>, &'b mut Data<'_>) + Request(Box Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()> + Send + Sync + 'static>), /// An ad-hoc **request_filter** fairing. Called when a request is received. @@ -159,7 +159,7 @@ impl AdHoc { /// }); /// ``` pub fn on_request(name: &'static str, f: F) -> AdHoc - where F: for<'a, 'b> Fn(&'a mut Request<'_>, &'b mut Data<'_>) + where F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()> { AdHoc { name, kind: AdHocKind::Request(Box::new(f)) } @@ -407,7 +407,7 @@ impl AdHoc { let _ = self.routes(rocket); } - async fn on_request<'r>(&self, req: &'r mut Request<'_>, _: &mut Data<'_>) { + async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) { // If the URI has no trailing slash, it routes as before. if req.uri().is_normalized_nontrailing() { return; @@ -458,7 +458,7 @@ impl Fairing for AdHoc { } } - async fn on_request<'r>(&self, req: &'r mut Request<'_>, data: &mut Data<'_>) { + async fn on_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) { if let AdHocKind::Request(ref f) = self.kind { f(req, data).await } diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index 88bede2026..181fce7554 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -511,7 +511,7 @@ pub trait Fairing: Send + Sync + Any + 'static { /// ## Default Implementation /// /// The default implementation of this method does nothing. - async fn on_request<'r>(&self, _req: &'r mut Request<'_>, _data: &mut Data<'_>) { } + async fn on_request(&self, _req: &mut Request<'_>, _data: &mut Data<'_>) { } /// The request filter callback. /// @@ -582,7 +582,7 @@ impl Fairing for std::sync::Arc { } #[inline] - async fn on_request<'r>(&self, req: &'r mut Request<'_>, data: &mut Data<'_>) { + async fn on_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) { (self as &T).on_request(req, data).await } diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index c9d20c733f..3ad5cb0149 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -108,8 +108,9 @@ impl Rocket { let mut response = if error_ptr.is_some() { // error_ptr is always some here, we just checked. self.dispatch_error(error_ptr.get().unwrap().status(), request, error_ptr.get()).await - // We MUST wait until we are inside this block to call `get`, since we HAVE to borrow it for `'r`. - // (And it's invariant, so we can't downcast the borrow to a shorter lifetime) + // We MUST wait until we are inside this block to call `get`, since we HAVE to borrow + // it for `'r`. (And it's invariant, so we can't downcast the borrow to a shorter + // lifetime) } else { match self.route(request, data).await { Outcome::Success(response) => response, diff --git a/core/lib/src/local/asynchronous/response.rs b/core/lib/src/local/asynchronous/response.rs index 47204baf15..0faaa690c3 100644 --- a/core/lib/src/local/asynchronous/response.rs +++ b/core/lib/src/local/asynchronous/response.rs @@ -67,10 +67,13 @@ impl Drop for LocalResponse<'_> { } impl<'c> LocalResponse<'c> { - pub(crate) fn new(req: Request<'c>, mut data: Data<'c>, preprocess: P, f: F) -> impl Future> - where P: FnOnce(&'c mut Request<'c>, &'c mut Data<'c>, &'c mut ErasedError<'c>) -> PO + Send, + pub(crate) fn new(req: Request<'c>, mut data: Data<'c>, preprocess: P, f: F) + -> impl Future> + where P: FnOnce(&'c mut Request<'c>, &'c mut Data<'c>, &'c mut ErasedError<'c>) + -> PO + Send, PO: Future + Send + 'c, - F: FnOnce(RequestToken, &'c Request<'c>, Data<'c>, &'c mut ErasedError<'c>) -> O + Send, + F: FnOnce(RequestToken, &'c Request<'c>, Data<'c>, &'c mut ErasedError<'c>) + -> O + Send, O: Future> + Send + 'c { // `LocalResponse` is a self-referential structure. In particular, @@ -121,7 +124,12 @@ impl<'c> LocalResponse<'c> { // the value is used to set cookie defaults. // SAFETY: The type of `preprocess` ensures that all of these types have the correct // lifetime ('c). - let response: Response<'c> = f(token, request, data, unsafe { transmute(&mut error) }).await; + let response: Response<'c> = f( + token, + request, + data, + unsafe { transmute(&mut error) } + ).await; let mut cookies = CookieJar::new(None, request.rocket()); for cookie in response.cookies() { cookies.add_original(cookie.into_owned()); diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 9c551dcc71..5137e0f908 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -42,7 +42,9 @@ impl Rocket { span_debug!("request headers" => request.inner().headers().iter().trace_all_debug()); let mut response = request.into_response( stream, - |rocket, request, data, error_ptr| Box::pin(rocket.preprocess(request, data, error_ptr)), + |rocket, request, data, error_ptr| { + Box::pin(rocket.preprocess(request, data, error_ptr)) + }, |token, rocket, request, data, error_ptr| Box::pin(async move { if !request.errors.is_empty() { error_ptr.write(Some(Box::new(RequestErrors::new(&request.errors)))); diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index 9655142e64..d3086493d3 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -70,7 +70,10 @@ async fn toggle(id: i32, conn: DbConn) -> Either { Ok(_) => Either::Left(Redirect::to("/")), Err(e) => { error!("DB toggle({id}) error: {e}"); - Either::Right(Template::render("index", Context::err(&conn, "Failed to toggle task.").await)) + Either::Right(Template::render( + "index", + Context::err(&conn, "Failed to toggle task.").await + )) } } } @@ -81,7 +84,10 @@ async fn delete(id: i32, conn: DbConn) -> Either, Template> { Ok(_) => Either::Left(Flash::success(Redirect::to("/"), "Todo was deleted.")), Err(e) => { error!("DB deletion({id}) error: {e}"); - Either::Right(Template::render("index", Context::err(&conn, "Failed to delete task.").await)) + Either::Right(Template::render( + "index", + Context::err(&conn, "Failed to delete task.").await + )) } } }