diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 9195b0ff76..62890aad9d 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -66,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Implement `IntoResponse` for `(R,) where R: IntoResponse` ([#2143]) - **changed:** For SSE, add space between field and value for compatibility ([#2149]) - **added:** Add `NestedPath` extractor ([#1924]) +- **added:** Add `handle_error` function to existing `ServiceExt` trait ([#2235]) - **breaking:** `impl IntoResponse(Parts) for Extension` now requires `T: Clone`, as that is required by the http crate ([#1882]) - **added:** Add `axum::Json::from_bytes` ([#2244]) @@ -92,6 +93,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#2140]: https://github.com/tokio-rs/axum/pull/2140 [#2143]: https://github.com/tokio-rs/axum/pull/2143 [#2149]: https://github.com/tokio-rs/axum/pull/2149 +[#2235]: https://github.com/tokio-rs/axum/pull/2235 [#2244]: https://github.com/tokio-rs/axum/pull/2244 [#2328]: https://github.com/tokio-rs/axum/pull/2328 diff --git a/axum/src/routing/tests/handle_error.rs b/axum/src/routing/tests/handle_error.rs index a1af97af70..9b81a20f1d 100644 --- a/axum/src/routing/tests/handle_error.rs +++ b/axum/src/routing/tests/handle_error.rs @@ -95,3 +95,17 @@ async fn handler_multiple_methods_last() { let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } + +#[crate::test] +async fn handler_service_ext() { + let fallible_service = tower::service_fn(|_| async { Err::<(), ()>(()) }); + let handle_error_service = + fallible_service.handle_error(|_| async { StatusCode::INTERNAL_SERVER_ERROR }); + + let app = Router::new().route("/", get_service(handle_error_service)); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); +} diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index f53e1baa4b..6c00a6de67 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -12,7 +12,7 @@ use crate::{ tracing_helpers::{capture_tracing, TracingEvent}, *, }, - BoxError, Extension, Json, Router, + BoxError, Extension, Json, Router, ServiceExt, }; use axum_core::extract::Request; use futures_util::stream::StreamExt; @@ -30,7 +30,7 @@ use std::{ task::{Context, Poll}, time::Duration, }; -use tower::{service_fn, util::MapResponseLayer, ServiceExt}; +use tower::{service_fn, util::MapResponseLayer, ServiceExt as TowerServiceExt}; use tower_http::{ limit::RequestBodyLimitLayer, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer, diff --git a/axum/src/service_ext.rs b/axum/src/service_ext.rs index e603d65f16..1b49f244b6 100644 --- a/axum/src/service_ext.rs +++ b/axum/src/service_ext.rs @@ -1,3 +1,4 @@ +use crate::error_handling::HandleError; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::routing::IntoMakeService; @@ -30,6 +31,17 @@ pub trait ServiceExt: Service + Sized { /// [`ConnectInfo`]: crate::extract::connect_info::ConnectInfo #[cfg(feature = "tokio")] fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo; + + /// Convert this service into a [`HandleError`], that will handle errors + /// by converting them into responses. + /// + /// See ["error handling model"] for more details. + /// + /// [`HandleError`]: crate::error_handling::HandleError + /// ["error handling model"]: crate::error_handling#axums-error-handling-model + fn handle_error(self, f: F) -> HandleError { + HandleError::new(self, f) + } } impl ServiceExt for S